I tried finding the OpenCV method for mean shift, but nothing came up. I am looking for a way to find clusters in an image and replace them with their mean value using Python and OpenCV. Any leads would be appreciated.
For example:
Input:
Output:
Here is a result from sklearn:
Note that the image is smoothed first to reduce noise. Also, this is not quite the algorithm from the image segmentation paper, because the image and the kernels are flattened: the pixels are clustered by color alone, without their spatial positions.
Here is the code:
import numpy as np
import cv2 as cv
from sklearn.cluster import MeanShift, estimate_bandwidth
img = cv.imread(your_image)  # your_image: path to the input image
# filter to reduce noise
img = cv.medianBlur(img, 3)
# flatten the image
flat_image = img.reshape((-1,3))
flat_image = np.float32(flat_image)
# meanshift
bandwidth = estimate_bandwidth(flat_image, quantile=.06, n_samples=3000)
ms = MeanShift(bandwidth=bandwidth, max_iter=800, bin_seeding=True)
ms.fit(flat_image)
labeled=ms.labels_
# get number of segments
segments = np.unique(labeled)
print('Number of segments: ', segments.shape[0])
# get the average color of each segment
total = np.zeros((segments.shape[0], 3), dtype=float)
count = np.zeros(total.shape, dtype=float)
for i, label in enumerate(labeled):
    total[label] = total[label] + flat_image[i]
    count[label] += 1
avg = total/count
avg = np.uint8(avg)
# cast the labeled image into the corresponding average color
res = avg[labeled]
result = res.reshape((img.shape))
# show the result
cv.imshow('result',result)
cv.waitKey(0)
cv.destroyAllWindows()
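The per-pixel Python loop that averages each segment can be slow on large images. As a rough sketch, the same averaging can be done with NumPy alone. The helper name `mean_color_per_label` and the toy data are made up for illustration, and the code assumes the labels are non-negative integers starting at 0, which is what `ms.labels_` gives with the default settings.

```python
import numpy as np

def mean_color_per_label(flat_image, labels):
    """Replace every pixel with the mean color of the label it belongs to.

    flat_image: (N, C) array of pixel colors
    labels:     (N,) array of non-negative integer labels (assumed 0..K-1)
    """
    labels = np.asarray(labels)
    flat = np.asarray(flat_image, dtype=np.float64)
    n_labels = int(labels.max()) + 1

    # number of pixels in each segment
    counts = np.bincount(labels, minlength=n_labels)

    # per-segment color sums; np.add.at accumulates correctly for repeated labels
    sums = np.zeros((n_labels, flat.shape[1]))
    np.add.at(sums, labels, flat)

    # guard against empty labels, then map each pixel to its segment mean
    means = sums / np.maximum(counts, 1)[:, None]
    return means.astype(np.uint8)[labels]

# toy example: three "pixels", the first two share a label
demo_pixels = np.array([[10, 20, 30], [30, 40, 50], [200, 200, 200]], dtype=np.float32)
demo_labels = np.array([0, 0, 1])
demo_result = mean_color_per_label(demo_pixels, demo_labels)
```

With the variables from the answer above, this would be called as `mean_color_per_label(flat_image, labeled).reshape(img.shape)`.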
For people who think it's straightforward to find mean shift in cv2, I have some information to share with you.
The meanShift() function given by @fmw42 is not what you want. I believe that many people (including me) have googled it and found that function. The meanShift() function of OpenCV is used for object tracking.
What we want might be another function named pyrMeanShiftFiltering().
I haven't tried it. Just FYI.