# Copyright 2021 CR-Suite Development Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The implementation in this file is based on the example provided by
Sabrina J. Mielke
in https://colab.research.google.com/drive/1AwS4haUx6swF82w3nXr6QKhajdF8aSvA#scrollTo=LJyoi46rIJr7
"""
from typing import NamedTuple
from jax import lax, jit, vmap, random
import jax.numpy as jnp
from jax.numpy.linalg import norm
[docs]class KMeansState(NamedTuple):
"""The state for K-means algorithm
"""
centroids: jnp.ndarray
"""Current set of centroids"""
assignment: jnp.ndarray
"""Current assignment of points to centroids"""
distortion : float
""" Current mean distance"""
prev_distortion : float
""" Previous mean distance"""
iterations: int
"""The number of iterations it took to complete"""
[docs]class KMeansSolution(NamedTuple):
"""The solution for K-means algorithm
"""
centroids: jnp.ndarray
"""Current set of centroids"""
assignment: jnp.ndarray
"""Current assignment of points to centroids"""
distortion : float
""" Current mean distance"""
key: jnp.ndarray
""" The PRNG key seed for the k-means run with least distortion"""
iterations: int
"""The number of iterations it took to complete"""
[docs]def find_nearest(point, centroids):
"""Returns the index of the nearest centroid for a specific point
Args:
point (jax.numpy.ndarray) : A specific point
centroids (jax.numpy.ndarray) : An array of centroids
Returns:
(int) : The index of the nearest centroid
"""
return jnp.argmin(vmap(norm)(centroids - point))
find_nearest_jit = jit(find_nearest)
[docs]def find_assignment(points, centroids):
"""Finds the assignment of each point to a specific centroid
Args:
points (jax.numpy.ndarray) : Each row of the points matrix is a point.
centroids (jax.numpy.ndarray) : An array of centroids
Returns:
(jax.numpy.ndarray, jax.numpy.ndarray): A tuple consisting of
#. An assignment array of each point to a cluster
#. Distance of each point from corresponding cluster centroid
"""
assignment = vmap(lambda point: find_nearest(point, centroids))(points)
errors = centroids[assignment, :] - points
distances = vmap(norm)(errors)
return assignment, distances
find_assignment_jit = jit(find_assignment)
def assignment_counts(assignment, k):
"""Returns the number of points in each cluster based on the current assignment
If a cluster has no points, we return 1.
"""
return ((assignment[jnp.newaxis, :] == jnp.arange(k)[:, jnp.newaxis])
.sum(axis=1, keepdims=True)
.clip(min=1))
[docs]def find_new_centroids(assignment, points, k):
"""Finds new centroids based on current assignment
Args:
assignment (jax.numpy.ndarray) : current assignment of each point to a specific cluster
points (jax.numpy.ndarray) : Each row of the points matrix is a point.
k (int): The number of clusters
"""
counts = assignment_counts(assignment, k)
new_centroids = jnp.sum(
jnp.where(
# axes: (data points, clusters, data dimension)
assignment[:, jnp.newaxis, jnp.newaxis] \
== jnp.arange(k)[jnp.newaxis, :, jnp.newaxis],
points[:, jnp.newaxis, :],
0.,
),
axis=0,
) / counts
return new_centroids
find_new_centroids_jit = jit(find_new_centroids, static_argnums=(2,))
[docs]def kmeans_with_seed(key, points, k, thresh=1e-5, max_iters=100):
"""Runs the k-means algorithm for a specific random initialization
Args:
key: a PRNG key used as the random key for choosing initial centroids
points (jax.numpy.ndarray): Each row of the points matrix is a point.
k (int): The number of clusters
thresh (float): Convergence threshold on change in distortion
max_iters (int): Maximum number of iterations for k-means algorithm
Returns:
(KMeansState): A named tuple consisting of:
centroids for each cluster, assignment of each point to a cluster,
current distorition, previous distortion, number of iterations
for convergence.
"""
# number of points
n = points.shape[0]
def init():
# select k points as initial centroids randomly
indices = random.permutation(key, jnp.arange(n))[:k]
# the initial centroids
centroids = points[indices, :]
# assign all points to centroids and compute distances
assignment, distances = find_assignment(points, centroids)
distortion = jnp.mean(distances)
# algorithm state
return KMeansState(centroids=centroids, assignment=assignment,
distortion=distortion, prev_distortion=jnp.inf, iterations=0)
def body(state):
# update centroids
centroids = find_new_centroids(state.assignment, points, k)
# update assignment
assignment, distances = find_assignment(points, centroids)
# mean distance
distortion = jnp.mean(distances)
# algorithm state
return KMeansState(centroids=centroids, assignment=assignment,
distortion=distortion, prev_distortion=state.distortion,
iterations=state.iterations+1)
def cond(state):
# check if the mean distance has updated enough
gap = state.prev_distortion - state.distortion
# print(state.prev_distortion, state.distortion, gap, thresh, gap > thresh)
return jnp.logical_and(gap > thresh, state.iterations < max_iters)
# state = init()
# while cond(state):
# state = body(state)
state = lax.while_loop(cond, body, init())
return state
kmeans_with_seed_jit = jit(kmeans_with_seed, static_argnums=(2,3))
[docs]def kmeans(key, points, k, iter=20, thresh=1e-5, max_iters=100):
r"""Clusters points using k-means algorithm
Args:
key: a PRNG key used as the random key
points (jax.numpy.ndarray): Each row of the points matrix is a point.
From the statistical point of view, each row is an observation vector
and each column is a feature.
k (int): The number of clusters
iter (int): The number of times k-means will be restarted with
different seeds. The result with least amount of distortion is returned.
thresh (float): Convergence threshold on change in distortion
max_iters (int): Maximum number of iterations for each replicate of k-means algorithm
Returns:
(KMeansSolution): A named tuple consisting of:
* centroids : centroid for each cluster
* assignment: assignment of each point to a cluster
* distortion: distortion after current assignment
* key: The PRNG key seed for the k-means run with the least distortion
* iterations: number of iterations taken in convergence
Let the k centroids be :math:`m_1, m_2, \dots, m_k`.
Let the n points be :math:`x_1, x_2, \dots, x_n`.
Let the assignment of i-th point to j-th cluster be given by
:math:`a_1, a_2, \dots, a_n` where :math:`1 \leq a_i = j \leq k`.
Then the distance of i-th point from its centroid is given by:
.. math::
d_i = \| x_i - m_{a_i} \|_2
The distortion is given by the mean of all the distances.
"""
# keys for each restart of kmeans algorithm
keys = random.split(key, iter)
# individual run of k-means algorithm
kmeans_core = lambda key: kmeans_with_seed(key, points, k,
thresh=thresh, max_iters=max_iters)
# Run all restarts of kmeans using vmap
results = vmap(kmeans_core, 0, 0)(keys)
# Find the run with the least distortion
i = jnp.argmin(results.distortion)
return KMeansSolution(centroids=results.centroids[i], assignment=results.assignment[i],
distortion=results.distortion[i], key=keys[i],
iterations=results.iterations[i])
kmeans_jit = jit(kmeans, static_argnums=(2,3,4,5))