This is a good use case for numba: it lets you express the computation as a simple double loop without a big performance penalty, which in turn lets you avoid the extra memory that `np.tile` needs to replicate the data across a third dimension just so the computation can be vectorized.

Borrowing the standard vectorized numpy implementation from the other answer, I have these two implementations:

    import numba
    import numpy as np


    def kmeans_assignment(centroids, points):
        num_centroids, dim = centroids.shape
        num_points, _ = points.shape

        # Tile and reshape both arrays into `[num_points, num_centroids, dim]`.
        centroids = np.tile(centroids, [num_points, 1]).reshape([num_points, num_centroids, dim])
        points = np.tile(points, [1, num_centroids]).reshape([num_points, num_centroids, dim])

        # Compute all distances (for all points and all centroids) at once and
        # select the min centroid for each point.
        distances = np.sum(np.square(centroids - points), axis=2)
        return np.argmin(distances, axis=1)


    @numba.jit
    def kmeans_assignment2(centroids, points):
        P, C = points.shape[0], centroids.shape[0]
        distances = np.zeros((P, C), dtype=np.float32)
        for p in range(P):
            for c in range(C):
                distances[p, c] = np.sum(np.square(centroids[c] - points[p]))
        return np.argmin(distances, axis=1)

Then for some sample data, I did a few timing experiments:

    In [12]: points = np.random.rand(10000, 50)

    In [13]: centroids = np.random.rand(30, 50)

    In [14]: %timeit kmeans_assignment(centroids, points)
    196 ms ± 6.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

    In [15]: %timeit kmeans_assignment2(centroids, points)
    127 ms ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

I won't go so far as to say that the numba version is definitely faster than the `np.tile` version, but clearly it's very close while not incurring the extra memory cost of `np.tile`.
In fact, on my laptop, when I make the shapes larger and use `(10000, 1000)` for the shape of `points` and `(200, 1000)` for the shape of `centroids`, the `np.tile` version raises a `MemoryError`, while the numba function runs in under 5 seconds with no memory error.
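
To see why the tiled version runs out of memory at those shapes, here is a back-of-the-envelope calculation. It assumes the arrays are `float64` (which is what `np.random.rand` produces; the dtype of the larger test arrays is an assumption on my part) and only counts a single tiled array, even though the vectorized function builds two of them plus temporaries for the difference and the square.

```python
import numpy as np

# Shapes from the larger experiment described above.
num_points, num_centroids, dim = 10_000, 200, 1_000

# Assumption: inputs are float64, as np.random.rand would produce.
float64_size = np.dtype(np.float64).itemsize  # 8 bytes
float32_size = np.dtype(np.float32).itemsize  # 4 bytes

# One tiled array of shape [num_points, num_centroids, dim].
tiled_bytes = num_points * num_centroids * dim * float64_size

# The float32 distance matrix of shape [num_points, num_centroids]
# that the loop-based version allocates instead.
distances_bytes = num_points * num_centroids * float32_size

print(f"one tiled array: {tiled_bytes / 1e9:.1f} GB")      # 16.0 GB
print(f"distance matrix: {distances_bytes / 1e6:.1f} MB")  # 8.0 MB
```

So each tiled array alone would need about 16 GB, while the loop-based version only ever holds the 8 MB distance matrix on top of the inputs.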
Separately, I actually noticed a slowdown when using `numba.jit` on the first version (with `np.tile`), which is likely due to the extra array creation inside the jitted function, combined with the fact that there's not much numba can optimize when the code already consists entirely of vectorized calls.
I also did not notice any significant improvement in the second version when trying to shorten the code by using broadcasting. For example, shortening the double loop to

    for p in range(P):
        distances[p, :] = np.sum(np.square(centroids - points[p, :]), axis=1)

did not really help (and it uses more memory, since each iteration broadcasts `points[p, :]` across all of `centroids` and creates temporary arrays of the same shape as `centroids`).
This is one of the really nice benefits of numba. You can write algorithms in a straightforward, loop-based way that matches their standard descriptions and gives you finer control over how the code translates into memory consumption or broadcasting, all without giving up runtime performance.
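
As a rough illustration of that finer control, the sketch below goes one step further than the loop-based function above: it drops the `[num_points, num_centroids]` distance matrix entirely and keeps only a running minimum per point. The name `kmeans_assignment_running_min` and the fallback decorator are hypothetical, introduced here for illustration only. This variant is not part of the timings above and has not been benchmarked.

```python
import numpy as np

try:
    import numba
    jit = numba.njit
except ImportError:
    # Assumption for illustration: fall back to plain (slow) Python
    # so the sketch still runs where numba is not installed.
    def jit(func):
        return func


@jit
def kmeans_assignment_running_min(centroids, points):
    P, C, D = points.shape[0], centroids.shape[0], points.shape[1]
    labels = np.empty(P, dtype=np.int64)
    for p in range(P):
        best_dist = np.inf
        best_c = 0
        for c in range(C):
            # Accumulate the squared Euclidean distance element by element,
            # so no temporary arrays are created.
            dist = 0.0
            for d in range(D):
                diff = centroids[c, d] - points[p, d]
                dist += diff * diff
            # Strict comparison keeps the first minimum, like np.argmin.
            if dist < best_dist:
                best_dist = dist
                best_c = c
        labels[p] = best_c
    return labels


# Tiny hand-checkable example: two centroids, three points.
example_centroids = np.array([[0.0, 0.0], [10.0, 10.0]])
example_points = np.array([[1.0, 1.0], [9.0, 9.0], [0.0, 2.0]])
example_labels = kmeans_assignment_running_min(example_centroids, example_points)
print(example_labels)  # [0 1 0]
```

Apart from the inputs, the only allocation is the output array of labels, so memory use no longer depends on the number of centroids.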