
I tried to find the OpenCV method for mean shift, but nothing came up. I am looking for a way to find clusters in an image and replace them with their mean value using Python and OpenCV. Any leads would be appreciated.

For example:

Input:

[input image]

Output:

[output image]

Hissaan Ali
  • If you are not limited to OpenCV and mean shift, you could follow [this example of normalized cuts](https://scikit-image.org/docs/dev/auto_examples/segmentation/plot_ncut.html#sphx-glr-auto-examples-segmentation-plot-ncut-py). I also recommend performing the segmentation in the LAB color space. There are other approaches [here](https://scikit-image.org/docs/dev/api/skimage.segmentation.html?highlight=segmentation#module-skimage.segmentation). – JoOkuma Jun 25 '20 at 13:21
  • What have you tried? See https://docs.opencv.org/4.1.1/dc/d6b/group__video__track.html#ga432a563c94eaf179533ff1e83dbb65ea – fmw42 Jun 25 '20 at 16:50
  • You probably want mean_shift from Python Wand, which is based upon ImageMagick. See https://imagemagick.org/discourse-server/viewtopic.php?f=4&t=25504 – fmw42 Jun 25 '20 at 17:00

2 Answers


Here is a result from sklearn:

[result image]

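Notice that the image is smoothed first to reduce noise. Also, this is not quite the algorithm from the image segmentation paper, because the image and the kernels are flattened: each pixel is treated as a point in color space only, so pixel positions play no part in the clustering.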

Here is the code:

import numpy as np
import cv2 as cv
from sklearn.cluster import MeanShift, estimate_bandwidth


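# your_image is a placeholder: put the path to your input file here
img = cv.imread(your_image)
if img is None:
    raise FileNotFoundError(your_image)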

# filter to reduce noise
img = cv.medianBlur(img, 3)

# flatten the image
flat_image = img.reshape((-1,3))
flat_image = np.float32(flat_image)

# meanshift
bandwidth = estimate_bandwidth(flat_image, quantile=.06, n_samples=3000)
# recent scikit-learn versions require bandwidth to be passed by keyword
ms = MeanShift(bandwidth=bandwidth, max_iter=800, bin_seeding=True)
ms.fit(flat_image)
labeled = ms.labels_


# get number of segments
segments = np.unique(labeled)
print('Number of segments: ', segments.shape[0])

# get the average color of each segment
# (np.add.at sums the pixel colors per label, np.bincount counts pixels per label)
total = np.zeros((segments.shape[0], 3), dtype=float)
np.add.at(total, labeled, flat_image)
count = np.bincount(labeled, minlength=segments.shape[0])[:, None]
avg = total/count
avg = np.uint8(avg)

# cast the labeled image into the corresponding average color
res = avg[labeled]
result = res.reshape((img.shape))

# show the result
cv.imshow('result',result)
cv.waitKey(0)
cv.destroyAllWindows()
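The answer notes that this is not quite the paper's algorithm because only colors are clustered. One hypothetical way to bring pixel positions back in is to append the (x, y) coordinates to each pixel's color and run the same scikit-learn `MeanShift` on the joint features. The function name `mean_shift_spatial` and the scale factors `hs` (spatial) and `hr` (color) are assumptions for illustration; a single flat kernel over the rescaled features only approximates the separate spatial and color bandwidths used in the paper.

```python
import numpy as np
from sklearn.cluster import MeanShift


def mean_shift_spatial(img, hs=8.0, hr=16.0):
    """Mean shift on joint (x, y, color) features (hypothetical sketch).

    Coordinates are divided by hs and colors by hr, so a bandwidth of 1.0
    behaves roughly like a spatial radius hs and a color radius hr.
    Returns (image painted with per-segment mean colors, label map).
    """
    h, w, c = img.shape
    ys, xs = np.mgrid[0:h, 0:w]
    flat = img.reshape(-1, c).astype(np.float64)
    feats = np.column_stack([xs.ravel() / hs, ys.ravel() / hs, flat / hr])

    labels = MeanShift(bandwidth=1.0, bin_seeding=True).fit_predict(feats)

    # replace every segment with the mean color of its pixels
    n = labels.max() + 1
    sums = np.zeros((n, c))
    np.add.at(sums, labels, flat)
    counts = np.maximum(np.bincount(labels, minlength=n), 1)[:, None]
    means = sums / counts
    out = np.round(means[labels]).astype(np.uint8).reshape(img.shape)
    return out, labels.reshape(h, w)
```

Because position is now a feature, one flat-colored area can be split into several neighboring segments with almost the same mean color; merging those would need an extra step that this sketch does not include.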

N. Osil

For people who think it's straightforward to find mean shift in cv2, I have some information to share with you.

The meanShift() function linked by @fmw42 is not what you want. I believe many people (including me) have googled and found it. OpenCV's meanShift() is used for object tracking.

What we want might be another function named pyrMeanShiftFiltering().

I haven't tried it. Just FYI.

hiankun