
I have to implement the mean shift algorithm and use it to segment an image. The image shows vegetables, and I have to choose a bandwidth such that the vegetables look as separated as possible. I calculated the bandwidth once with sklearn's estimate_bandwidth and hard-coded the result, but I am not allowed to use sklearn; I can only use NumPy, PIL, or matplotlib to implement this. Here is what I tried:

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Load the image (force 3 channels so the grayscale step below always works)
img = np.array(Image.open("peppers.jpg").convert("RGB"))

# Convert the image to grayscale (values in 0..255)
gray_img = np.mean(img, axis=2)

# Flatten the image to an (N, 1) array of pixel values
flat_img = gray_img.reshape((-1, 1))

# Define the distance metric row-wise, so it works for (N, d) vs (d,) and for (d,) vs (d,).
# Without axis=-1 it returns ONE number for the whole array, and the mask below breaks.
def euclidean_distance(x1, x2):
    return np.sqrt(np.sum((x1 - x2) ** 2, axis=-1))

# Bandwidth hard-coded from sklearn's estimate_bandwidth. It must be on the same scale
# as flat_img (0..255). Gray levels here are multiples of 1/3, so any bandwidth below 1/3
# puts only identical gray levels in each window (one cluster per gray level).
bandwidth = 0.24570638879032147

# Perform Mean Shift clustering: move each point to the mean of its window until it stops moving
centroids = []
for point in flat_img:
    centroid = point.astype(float)
    while True:
        points_within_bandwidth = flat_img[euclidean_distance(flat_img, centroid) < bandwidth]
        new_centroid = np.mean(points_within_bandwidth, axis=0)
        if euclidean_distance(new_centroid, centroid) < 1e-5:
            break
        centroid = new_centroid
    centroids.append(new_centroid)

# Merge converged means closer than the bandwidth, then give each pixel its cluster index
modes = []
labels = np.zeros(len(flat_img), dtype=int)
for i, centroid in enumerate(centroids):
    for k, mode in enumerate(modes):
        if euclidean_distance(centroid, mode) < bandwidth:
            labels[i] = k
            break
    else:
        modes.append(centroid)
        labels[i] = len(modes) - 1

# Reshape the labels to the shape of the original image
segmented_img = labels.reshape(gray_img.shape)

# Display the segmented image
plt.imshow(segmented_img)
plt.show()
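Because sklearn is not allowed, the hard-coded bandwidth could be computed with NumPy instead. The sketch below follows the idea behind sklearn's estimate_bandwidth as I understand it: average, over the points, the distance from each point to its k-th nearest neighbour (the point itself counts as the first), with k a fraction `quantile` of the points. The name `estimate_bandwidth_np` is hypothetical, and the default of 500 random samples is my own assumption to keep the pairwise distance matrix small (sklearn uses all points unless `n_samples` is given). The result is in the same units as the input, so for `flat_img` it is on the 0..255 gray scale.

```python
import numpy as np


def estimate_bandwidth_np(X, quantile=0.3, n_samples=500, seed=0):
    """Average distance from each (sampled) point to its k-th nearest neighbour,
    where k = quantile * number of points and the point itself is neighbour no. 1.
    Hypothetical NumPy-only stand-in for sklearn's estimate_bandwidth."""
    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        X = X[:, None]                       # treat a 1-D array as N points in 1-D
    if n_samples is not None and n_samples < len(X):
        rng = np.random.default_rng(seed)    # random subsample keeps the matrix small
        X = X[rng.choice(len(X), size=n_samples, replace=False)]
    k = max(1, int(len(X) * quantile))
    # Pairwise Euclidean distances between the points, shape (n, n)
    d = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1))
    d.sort(axis=1)                           # row i ascending; d[i, 0] == 0 is point i itself
    return d[:, k - 1].mean()                # distance to the k-th nearest neighbour
```

With the variables from the question this would be `bandwidth = estimate_bandwidth_np(flat_img)`; different `quantile` values give smaller or larger bandwidths, and therefore more or fewer clusters.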

Running my code took a very long time, and it did not show the right output.
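Most of the running time comes from looping over every pixel and measuring its distance to every other pixel on every iteration. Pixels with the same gray value follow exactly the same mean shift path, and `gray_img` (the mean of three 0..255 channels) has at most 766 distinct values. One possible speed-up, sketched below under that assumption, is to run the same flat-kernel mean shift once per distinct gray level, weighting each level by its pixel count, and then map the result back to the pixels. `mean_shift_gray` is a hypothetical helper; modes closer than the bandwidth are merged greedily, and each gray level is then assigned to its nearest merged centre.

```python
import numpy as np


def mean_shift_gray(gray, bandwidth, tol=1e-3, max_iter=300):
    """Flat-kernel mean shift on a grayscale image, run once per distinct gray
    level (weighted by its pixel count) instead of once per pixel.

    Returns (labels, centers): labels has gray's shape, centers[k] is the mode of cluster k.
    """
    values, inverse, counts = np.unique(np.ravel(gray), return_inverse=True,
                                        return_counts=True)
    values = values.astype(float)
    modes = values.copy()                       # one seed per distinct gray level
    for _ in range(max_iter):
        # in_window[i, j] is True when level j lies inside the window around mode i
        in_window = np.abs(modes[:, None] - values[None, :]) < bandwidth
        weights = in_window * counts[None, :]   # pixels of level j counted in window i
        new_modes = (weights @ values) / weights.sum(axis=1)
        shift = np.abs(new_modes - modes).max()
        modes = new_modes
        if shift < tol:                         # no mode moved noticeably: converged
            break

    # Greedily merge modes that ended up closer than the bandwidth
    centers = []
    for m in modes:
        if not any(abs(m - c) < bandwidth for c in centers):
            centers.append(m)
    centers = np.array(centers)

    # Assign each gray level to its nearest center, then map levels back to pixels
    level_label = np.argmin(np.abs(modes[:, None] - centers[None, :]), axis=1)
    labels = level_label[inverse].reshape(np.shape(gray))
    return labels, centers
```

With the question's variables, `labels, centers = mean_shift_gray(gray_img, bandwidth)` followed by `plt.imshow(labels)` would display the segmentation. The bandwidth still has to be on the 0..255 scale: with any value below 1/3, every gray level stays its own cluster.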

timp bill
