
I'm trying to use numpy to implement k-means in order to perform basic image segmentation based on pixel color. However, when I run my program and have it print the cost function and the locations of the centroids after each iteration, it seems like something is wrong. The cost function is oscillating and centroids don't converge to a local optimum.

import numpy as np
from matplotlib import pyplot as plt
from matplotlib import image

rng = np.random.default_rng()

img = image.imread('fruits_small.jpg') / 256.0
h, w = img.shape[:2]

copy = np.array(img)

plt.subplot(3, 3, 1)
plt.title('Original')
plt.imshow(img)

for plot, k in enumerate(range(5, 13)):
    centroids = rng.choice(copy.reshape((-1, 3)), size=k, replace=False)
    clusters = np.empty((h, w))

    print(centroids)

    while True:
        for y, x in np.ndindex(img.shape[:2]):
            v = copy[y, x]
            clusters[y, x] = np.argmin(np.linalg.norm(centroids - v, axis=1))

        cost = 0
        for i in range(k):
            cost += np.linalg.norm(copy[clusters == i] - centroids[i], axis=1).sum()

        print(f'cost = {cost}')

        d = 0
        for i in range(k):
           new_centroid = copy[clusters == i].mean(axis=0)
           d += np.linalg.norm(centroids[i] - new_centroid)
           centroids[i] = new_centroid

        if d == 0:
            break

        print(centroids)

    for i in range(k):
        img[clusters == i] = centroids[i]

    plt.subplot(3, 3, plot + 1)
    plt.title(f'k = {k}')
    plt.imshow(img)

plt.show()

Have I made a mistake in my implementation somewhere?

Bradley Garagan
One of the problems with your algorithm is in the initialization. Think about what happens when two or more centroids are set to the same value. They have to be unique to avoid this problem. – yann ziselman Jun 23 '21 at 07:38
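
A minimal sketch of the initialization this comment suggests. `replace=False` only prevents the same pixel position from being drawn twice; two different pixels can still share a color. The helper name `init_unique_centroids` and the toy `pixels` array are assumptions for illustration, standing in for `copy.reshape((-1, 3))`.

```python
import numpy as np

def init_unique_centroids(pixels, k, rng):
    """Pick k centroids with pairwise-distinct colors from an (n, 3) pixel array."""
    # np.unique with axis=0 collapses duplicate rows, so no color appears twice.
    unique_colors = np.unique(pixels, axis=0)
    if len(unique_colors) < k:
        raise ValueError(f"only {len(unique_colors)} distinct colors, need {k}")
    # Sampling without replacement from distinct rows yields distinct centroids.
    return rng.choice(unique_colors, size=k, replace=False)

rng = np.random.default_rng(0)
# Toy data: many pixels, but only three distinct colors.
pixels = np.array([[0.0, 0.0, 0.0]] * 50 + [[1.0, 1.0, 1.0]] * 50 + [[1.0, 0.0, 0.0]] * 2)
centroids = init_unique_centroids(pixels, 3, rng)
print(centroids)
```

In the question's loop this would replace the `rng.choice(copy.reshape((-1, 3)), ...)` line, assuming the image has at least `k` distinct colors.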

1 Answer


There is a very good explanation of k-means convergence and oscillation that you can check out here.
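
One thing that may be worth checking when the printed cost oscillates: the usual convergence argument for k-means applies to the sum of squared distances, because the mean of a cluster minimizes squared distance, not plain distance. A sum of unsquared norms is not covered by that argument and can go up after a centroid update. The sketch below is an illustration on synthetic data, not the question's image; `lloyd_costs` is a hypothetical helper that records the squared-distance cost after each assignment step.

```python
import numpy as np

def lloyd_costs(points, k, rng, max_iter=100):
    """Run plain k-means and record the squared-distance cost after each assignment."""
    # Start from distinct points so no two centroids coincide.
    centroids = rng.choice(np.unique(points, axis=0), size=k, replace=False)
    costs = []
    labels = None
    for _ in range(max_iter):
        # Assignment step: squared distance from every point to every centroid.
        sq_dist = ((points[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        new_labels = sq_dist.argmin(axis=1)
        costs.append(sq_dist[np.arange(len(points)), new_labels].sum())
        if labels is not None and np.array_equal(new_labels, labels):
            break  # assignments are stable, so the centroids will not move again
        labels = new_labels
        # Update step: move each non-empty cluster's centroid to its mean.
        for i in range(k):
            members = points[labels == i]
            if len(members) > 0:
                centroids[i] = members.mean(axis=0)
    return centroids, labels, np.array(costs)

rng = np.random.default_rng(0)
points = rng.random((500, 3))  # synthetic stand-in for the image's pixels
centroids, labels, costs = lloyd_costs(points, 5, rng)
print(costs)
```

With this cost, each printed value should be no larger than the previous one (up to floating-point rounding), which makes a genuine bug easier to spot.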

I spent some time reviewing your implementation and did not find any problems with it, except that in some cases (a different starting point or image) a cluster can end up empty (you can read more here), which led to the following warning in my experiments:

RuntimeWarning: Mean of empty slice.
  print(copy[clusters == i].mean())
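
To make the failure mode concrete: the mean of an empty slice is NaN, so the affected centroid becomes NaN, `d` becomes NaN, and a test like `d == 0` can never succeed. The sketch below reproduces an empty cluster on toy data and guards the update step. `update_centroids` is a hypothetical helper, and re-seeding an empty cluster with a random pixel is one assumed policy among several.

```python
import numpy as np

def update_centroids(pixels, labels, centroids, rng):
    """Return new centroids; an empty cluster is re-seeded with a random pixel."""
    new_centroids = centroids.copy()
    for i in range(len(centroids)):
        members = pixels[labels == i]
        if len(members) == 0:
            # Assumed policy: re-seed instead of taking the mean of an empty slice.
            new_centroids[i] = pixels[rng.integers(len(pixels))]
        else:
            new_centroids[i] = members.mean(axis=0)
    return new_centroids

rng = np.random.default_rng(0)
pixels = np.array([[0.0, 0.0, 0.0], [0.2, 0.2, 0.2], [1.0, 1.0, 1.0]])
# Two identical centroids: argmin always picks index 0, so cluster 1 stays empty.
centroids = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
dist = np.linalg.norm(pixels[:, None, :] - centroids[None, :, :], axis=2)
labels = dist.argmin(axis=1)
new_centroids = update_centroids(pixels, labels, centroids, rng)
print(labels, new_centroids)
```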

As an unprincipled workaround, you can force the pixel nearest to each centroid to be a member of that centroid's cluster:

import numpy as np
from matplotlib import pyplot as plt
from matplotlib import image
from scipy.cluster.vq import vq

rng = np.random.default_rng()

img = image.imread('images.jpg') / 256.0
h, w = img.shape[:2]

copy = np.array(img)

plt.subplot(3, 3, 1)
plt.title('Original')
plt.imshow(img)

for plot, k in enumerate(range(5, 13)):
    centroids = rng.choice(copy.reshape((-1, 3)), size=k, replace=False)
    clusters = np.empty((h, w))

    print(centroids)

    while True:
        for y, x in np.ndindex(img.shape[:2]):
            v = copy[y, x]
            clusters[y, x] = np.argmin(np.linalg.norm(centroids - v, axis=1))
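        # Workaround: snap each centroid to its nearest pixel and force
        # that pixel into the centroid's cluster.
        closest, _ = vq(centroids, copy.reshape(-1, 3))
        centroids = copy.reshape(-1, 3)[closest]

        for i in range(k):
            # closest[i] is a flat index into the h * w pixels:
            # row = index // w, column = index % w
            clusters[closest[i] // w, closest[i] % w] = i
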
        cost = 0
        for i in range(k):
            cost += np.linalg.norm(copy[clusters == i] - centroids[i], axis=1).sum()

        print(f'cost = {cost}')
        for i in range(k):
            print(copy[clusters == i].mean())
        d = 0
        for i in range(k):
           new_centroid = copy[clusters == i].mean(axis=0)
           d += np.linalg.norm(centroids[i] - new_centroid)
           centroids[i] = new_centroid

        if d <= 0.1:  # d is already a non-negative scalar
            break

        print(centroids)

    for i in range(k):
        img[clusters == i] = centroids[i]

    plt.subplot(3, 3, plot + 2)
    plt.title(f'k = {k}')
    plt.imshow(img)

plt.show()
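
As a side note, the per-pixel assignment loop is the slowest part of the code above. A vectorized sketch of the same step, using broadcasting, is shown below. `assign_clusters` is a hypothetical helper; it builds an intermediate array of shape `(h * w, k, 3)`, so it assumes the image is small enough for that to fit in memory.

```python
import numpy as np

def assign_clusters(img, centroids):
    """Label every pixel of an (h, w, 3) image with the index of its nearest centroid."""
    h, w = img.shape[:2]
    pixels = img.reshape(-1, 3)
    # Broadcasting builds an (h * w, k) matrix of pixel-to-centroid distances.
    dist = np.linalg.norm(pixels[:, None, :] - centroids[None, :, :], axis=2)
    return dist.argmin(axis=1).reshape(h, w)

# Toy 2 x 3 image: mostly black, with two reddish pixels.
img = np.zeros((2, 3, 3))
img[0, 1] = [1.0, 0.0, 0.0]
img[1, 2] = [0.9, 0.1, 0.0]
centroids = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
clusters = assign_clusters(img, centroids)
print(clusters)
```

The result is an integer array, so `clusters == i` works the same way as with the float array used above.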

A final point: in RGB color space, Euclidean distance (the norm used here) does not always match perceived color difference, so a pale yellow may end up looking closer to purple than to a bolder yellow. If you use the average of the channels, you may get more intuitive results. Also be aware that your algorithm does not take spatial dependencies between pixels into account!
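
A rough sketch of both ideas, to be read as assumptions rather than a recommendation: `channel_mean_features` reduces each pixel to the average of its channels, and `color_xy_features` appends normalized pixel coordinates so that nearby pixels are more likely to share a cluster. Both helper names and the `xy_weight` parameter are hypothetical. Either feature array could be clustered in place of `copy.reshape((-1, 3))`.

```python
import numpy as np

def channel_mean_features(img):
    """One feature per pixel: the mean of its R, G and B values."""
    return img.reshape(-1, 3).mean(axis=1, keepdims=True)

def color_xy_features(img, xy_weight=0.5):
    """Color plus weighted, normalized (row, col) coordinates for each pixel."""
    h, w = img.shape[:2]
    rows, cols = np.mgrid[0:h, 0:w]
    # Scale coordinates to [0, 1] so they are comparable to the color range.
    xy = np.stack([rows / max(h - 1, 1), cols / max(w - 1, 1)], axis=-1)
    return np.concatenate([img.reshape(-1, 3), xy_weight * xy.reshape(-1, 2)], axis=1)

img = np.random.default_rng(0).random((4, 5, 3))  # synthetic stand-in image
gray = channel_mean_features(img)   # shape (h * w, 1)
feats = color_xy_features(img)      # shape (h * w, 5)
print(gray.shape, feats.shape)
```

A larger `xy_weight` makes the clusters more spatially compact; a value of 0 reduces to plain color clustering.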

meti