I am working with the iris dataset from sklearn, which has 3 classes: ['setosa', 'versicolor', 'virginica']. I made a scatter plot of the first two features as follows:

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

iris = load_iris()
Y_train = iris.target
X_train = iris.data
class_labels = iris.target_names
plt.scatter(X_train[:, 0], X_train[:, 1], c=Y_train)
plt.xlabel('attr1')
plt.ylabel('attr2')
plt.show()

Scatter plot:

The scatter plot shows yellow, green and purple dots. I want to know which colour belongs to which class ('setosa', 'versicolor', 'virginica'). How can I display a legend that shows which colour represents which class?

1 Answer

In this case, you can create a custom legend by looping through the labels and using the same colormap and norm as the scatter plot. By default, matplotlib uses the 'viridis' colormap and a norm that maps the minimum color value to zero and the maximum to one.

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

iris = load_iris()
Y_train = iris.target
X_train = iris.data
class_labels = iris.target_names
# Use the same colormap and normalization for the points and the legend handles.
cmap = plt.get_cmap('viridis')
norm = plt.Normalize(Y_train.min(), Y_train.max())
plt.scatter(X_train[:, 0], X_train[:, 1], c=Y_train, cmap=cmap, norm=norm)
# One proxy marker per class, colored like the points of that class.
handles = [plt.Line2D([0, 0], [0, 0], color=cmap(norm(i)), marker='o', linestyle='', label=label)
           for i, label in enumerate(class_labels)]
plt.legend(handles=handles, title='Species')
plt.show()

scatter plot with legend
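
As a possible shortcut to building the handles by hand, the object returned by scatter has a legend_elements() method that produces one handle per unique color value. The sketch below is only an illustration of that idea, not part of the approach above; the function name plot_iris_with_legend is made up for this example.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris


def plot_iris_with_legend():
    """Scatter the first two iris features and label the legend by class name."""
    iris = load_iris()
    fig, ax = plt.subplots()
    scatter = ax.scatter(iris.data[:, 0], iris.data[:, 1], c=iris.target)
    # legend_elements() returns one handle per unique value in `c`
    # (here the three class indices 0, 1, 2), colored like the points.
    handles, _ = scatter.legend_elements()
    # Replace the numeric labels with the class names, in index order.
    ax.legend(handles, list(iris.target_names), title='Species')
    return ax


ax = plot_iris_with_legend()
```

Call plt.show() afterwards to display the figure. This relies on the class indices being sorted in the same order as iris.target_names, which holds for this dataset.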

You could also use seaborn, although currently setting the legend labels isn't straightforward.

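import seaborn as sns

sns.set()
ax = sns.scatterplot(x=X_train[:, 0], y=X_train[:, 1], hue=Y_train, palette='viridis')
# Reuse seaborn's legend handles, but relabel them with the class names.
# (ax.legend_.legendHandles was renamed to legend_handles in newer matplotlib releases.)
handles, _ = ax.get_legend_handles_labels()
ax.legend(handles, class_labels, title='Species')
plt.show()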
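
Another option that may avoid relabeling the legend altogether is to pass the class names themselves as hue, so seaborn writes them into the legend. This is a minimal sketch under that assumption; the variable name species is introduced here only for illustration.

```python
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris

iris = load_iris()
# Index the array of class names by the integer targets
# to get one class name per sample.
species = iris.target_names[iris.target]

fig, ax = plt.subplots()
sns.scatterplot(x=iris.data[:, 0], y=iris.data[:, 1],
                hue=species, palette='viridis', ax=ax)
# seaborn creates the legend; only the title is set here.
ax.get_legend().set_title('Species')
```

Call plt.show() to display the result. Because hue is now categorical, seaborn picks three discrete colors from the palette, so the exact shades can differ slightly from the numeric version.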
JohanC