
Is there any way to extract the distance between the nodes and the centroid in a k-means cluster?

I have run k-means clustering on a text-embedding dataset, and I want to find the nodes in each cluster that are far away from the centroid, so that I can check which of those nodes' features are making the difference.

Thanks in advance!

Arav

3 Answers


KMeans.transform() returns an array of shape (n_samples, n_clusters) holding the distance of each sample to every cluster center.
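A minimal sketch to confirm what transform() returns, assuming a small synthetic dataset (make_blobs with n_samples=100, random_state=0) and KMeans(n_clusters=3, n_init=10, random_state=0); these parameter values are illustrative choices, not part of the original answer. It compares transform() against Euclidean distances computed by hand with NumPy broadcasting.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Illustrative, reproducible data (random_state fixed for repeatability)
X, _ = make_blobs(n_samples=100, centers=3, random_state=0)
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)

# transform() gives one column per cluster center
X_dist = kmeans.transform(X)
print(X_dist.shape)  # (100, 3)

# The same distances by hand: X[:, None, :] is (100, 1, 2),
# cluster_centers_[None, :, :] is (1, 3, 2); the norm is taken over the feature axis
manual = np.linalg.norm(X[:, None, :] - kmeans.cluster_centers_[None, :, :], axis=2)
print(np.allclose(X_dist, manual))  # True
```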

import numpy as np

from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans

import matplotlib.pyplot as plt
plt.style.use('ggplot')
import seaborn as sns

# Generate some random clusters
X, y = make_blobs()
kmeans = KMeans(n_clusters=3).fit(X)

# plot the cluster centers and samples 
# (seaborn >= 0.12 requires x= and y= as keyword arguments)
sns.scatterplot(x=kmeans.cluster_centers_[:, 0], y=kmeans.cluster_centers_[:, 1],
                marker='+',
                color='black',
                s=200);
sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=kmeans.labels_,
                palette=sns.color_palette("Set1", n_colors=3));

[Plot: the samples coloured by cluster, with the three cluster centers marked by black + signs]

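Transform X and take the sum of each row (axis=1) to identify the samples furthest from the centers. Note that each row sum adds up the squared distances to all of the cluster centers, not just to the sample's own center.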

# squared distance to cluster center
X_dist = kmeans.transform(X)**2

# do something useful...
import pandas as pd
df = pd.DataFrame(X_dist.sum(axis=1).round(2), columns=['sqdist'])
df['label'] = kmeans.labels_  # k-means cluster labels, not the make_blobs ground truth

df.head()
    sqdist  label
0   211.12  0
1   257.58  0
2   347.08  1
3   209.69  0
4   244.54  0

A visual check: the same plot, only this time with the furthest point in each cluster highlighted:

# for each k-means cluster, find the point with the largest summed squared distance
max_indices = []
for label in np.unique(kmeans.labels_):
    X_label_indices = np.where(kmeans.labels_ == label)[0]
    max_label_idx = X_label_indices[np.argmax(X_dist[kmeans.labels_ == label].sum(axis=1))]
    max_indices.append(max_label_idx)

# replot, but highlight the furthest point
sns.scatterplot(x=kmeans.cluster_centers_[:, 0], y=kmeans.cluster_centers_[:, 1],
                marker='+',
                color='black',
                s=200);
sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=kmeans.labels_,
                palette=sns.color_palette("Set1", n_colors=3));
# highlight the furthest point in black
sns.scatterplot(x=X[max_indices, 0], y=X[max_indices, 1], color='black');

[Plot: the same scatter plot, with the furthest point in each cluster highlighted in black]

Kevin
  • That was a perfect answer. Along with finding the farthest node in a cluster, I am trying to set a threshold for each cluster and filter out all the nodes whose 'sqdist' is greater than that threshold. I'm taking the mean of the 'sqdist' within each cluster as the threshold; does that sound sensible from your point of view? – Arav Jan 18 '19 at 11:53
  • It's difficult to say whether it's sensible without knowing your downstream use case, but it seems like a reasonable method to get all the points that are closest to the centroid. If you think this was the best answer to your original question, feel free to [accept the answer](https://stackoverflow.com/help/someone-answers). – Kevin Jan 18 '19 at 12:09
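The per-cluster mean threshold discussed in the comments can be sketched with a pandas groupby. This is a minimal sketch under stated assumptions: it uses synthetic make_blobs data (n_samples=150, random_state=0), and it measures each sample's squared distance to its *own* cluster center rather than the summed 'sqdist' from the answer above. That substitution is a choice made for this sketch, not what the commenter described.

```python
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=150, centers=3, random_state=0)
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)
labels = kmeans.labels_

# Assumption: squared distance of each sample to its OWN cluster center
own_sqdist = kmeans.transform(X)[np.arange(len(X)), labels] ** 2

df = pd.DataFrame({"sqdist": own_sqdist, "label": labels})

# Per-cluster threshold = mean squared distance within that cluster
df["threshold"] = df.groupby("label")["sqdist"].transform("mean")

# Keep only the samples lying further out than their cluster's mean
outliers = df[df["sqdist"] > df["threshold"]]
print(outliers.groupby("label").size())
```

A mean threshold will usually flag a sizeable share of every cluster. If only the most extreme nodes are of interest, a higher per-cluster quantile (for example `transform(lambda s: s.quantile(0.95))`) is one possible alternative.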

If you are using Python and scikit-learn:

From the KMeans documentation (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans) you can see that a fitted model exposes labels_ and cluster_centers_.

Next, define a distance function that takes the vector of each node and its cluster center. Filter by labels_ and calculate the distance for each point inside each label.
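The steps described in this answer can be sketched as follows. The helper `euclidean` is hypothetical (any distance function would do), and the make_blobs data with n_samples=150 and random_state=0 is illustrative. The loop filters the samples by `labels_` and measures each one against its cluster's entry in `cluster_centers_`.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs


def euclidean(a, b):
    """Hypothetical helper: Euclidean distance between a node vector and a center."""
    return np.linalg.norm(a - b)


X, _ = make_blobs(n_samples=150, centers=3, random_state=0)
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)

dists_by_label = {}
for label in np.unique(kmeans.labels_):
    label = int(label)
    members = X[kmeans.labels_ == label]          # nodes assigned to this cluster
    center = kmeans.cluster_centers_[label]       # that cluster's centroid
    dists_by_label[label] = np.array([euclidean(x, center) for x in members])

# Largest distance to the centroid within each cluster
for label, d in dists_by_label.items():
    print(label, d.max())
```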

avchauzov

Kevin has a great answer above, but I feel it does not answer the question that was asked (maybe I am reading it completely wrong). If you are trying to look at each individual cluster center and get the point in that cluster that is furthest from that center, you need to use the cluster labels to get the distance of each point to the centroid of its own cluster. The code above finds, in each cluster, the point with the largest summed distance to ALL the cluster centers (which you can see in the picture: the points are always on the far side of the cluster, away from the other two clusters). To look at the individual clusters you would need something like the following:

# pick, for each point, the column of X_dist that belongs to its own k-means cluster
center_dists = np.array([X_dist[i][x] for i, x in enumerate(kmeans.labels_)])

This gives the distance (squared, since Kevin's X_dist is squared) of each point to the centroid of its own cluster. Running almost the same loop as Kevin's above then gives the point that is furthest away in each cluster.

max_indices = []
for label in np.unique(kmeans.labels_):
    X_label_indices = np.where(kmeans.labels_ == label)[0]
    max_label_idx = X_label_indices[np.argmax(center_dists[kmeans.labels_ == label])]
    max_indices.append(max_label_idx)
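Since the question asks for the nodes (plural) that are far from each centroid, here is a hedged, vectorised variant of the same idea that returns the k furthest points per cluster instead of only one. Assumptions: synthetic make_blobs data (n_samples=150, random_state=0) and k=3, which is an arbitrary illustrative value. The fancy-indexing line is equivalent to the list comprehension above, but it uses plain (not squared) distances.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=150, centers=3, random_state=0)
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)
labels = kmeans.labels_

# Distance of every sample to every center, shape (n_samples, n_clusters)
X_dist = kmeans.transform(X)

# Vectorised lookup: for row i, take the column of sample i's own cluster label
center_dists = X_dist[np.arange(len(X)), labels]

# The k samples furthest from their own center in each cluster (k is arbitrary here)
k = 3
furthest = {}
for label in np.unique(labels):
    idx = np.where(labels == label)[0]
    # argsort is ascending; reversing it puts the largest distances first
    order = np.argsort(center_dists[idx])[::-1]
    furthest[int(label)] = idx[order[:k]]

print(furthest)
```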
JBarrett
  • Does this center_dists code replace X_dist in Kevin's code? – KSp Jan 05 '22 at 14:15
  • Yes. Kevin's X_dist is used to compute center_dists, which is then plugged into the same spot in the for loop to get the point in each cluster that is furthest from its cluster center. – JBarrett Jan 25 '22 at 01:56