cr.sparse.cluster.vq.kmeans
- cr.sparse.cluster.vq.kmeans(key, points, k, iter=20, thresh=1e-05, max_iters=100)
Clusters points using the k-means algorithm.
- Parameters
key – A PRNG key that supplies the randomness for the k-means runs
points (jax.numpy.ndarray) – Each row of the points matrix is a point. From the statistical point of view, each row is an observation vector and each column is a feature.
k (int) – The number of clusters
iter (int) – The number of times k-means is restarted with different seeds. The result with the least distortion is returned.
thresh (float) – Convergence threshold on the change in distortion
max_iters (int) – Maximum number of iterations for each replicate of the k-means algorithm
- Returns
A named tuple consisting of:
centroids: the centroid of each cluster
assignment: the assignment of each point to a cluster
distortion: the distortion under the returned assignment
key: the PRNG key that seeded the k-means run with the least distortion
iterations: the number of iterations taken to converge
- Return type
Let the k centroids be \(m_1, m_2, \dots, m_k\). Let the n points be \(x_1, x_2, \dots, x_n\). Let the cluster assignments be \(a_1, a_2, \dots, a_n\), where \(a_i = j\) with \(1 \leq j \leq k\) means that the i-th point is assigned to the j-th cluster.
Then the distance of the i-th point from its centroid is given by:
\[d_i = \| x_i - m_{a_i} \|_2 \tag{1}\]
The distortion is given by the mean of all the distances.
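The definition of distortion above can be illustrated with a minimal JAX sketch. It is not the library's implementation: the helper names `assign` and `distortion` and the toy data are hypothetical, and cluster indices are 0-based here, whereas the text uses 1-based indices.

```python
import jax.numpy as jnp


def assign(points, centroids):
    """Assign each point to its nearest centroid (hypothetical helper)."""
    # Pairwise Euclidean distances, shape (n, k).
    diffs = points[:, None, :] - centroids[None, :, :]
    pairwise = jnp.linalg.norm(diffs, axis=2)
    # 0-based index of the closest centroid for each point.
    return jnp.argmin(pairwise, axis=1)


def distortion(points, centroids, assignment):
    """Mean of d_i = ||x_i - m_{a_i}||_2 over all points (hypothetical helper)."""
    # Distance of each point from the centroid it is assigned to.
    distances = jnp.linalg.norm(points - centroids[assignment], axis=1)
    return jnp.mean(distances)


# Toy data: four 2-D points (rows) and two centroids.
points = jnp.array([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 4.0]])
centroids = jnp.array([[0.0, 1.0], [10.0, 2.0]])

assignment = assign(points, centroids)
value = distortion(points, centroids, assignment)
```

With this toy data, the first two points lie at distance 1 from the first centroid and the last two at distance 2 from the second, so the distortion is the mean of 1, 1, 2, 2.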