I am working with the iris dataset from sklearn. As you may know, it has 3 classes: ['setosa', 'versicolor', 'virginica']. I have made a scatter plot of the first two features of this dataset. The code is as follows:
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
iris=load_iris()
Y_train=iris.target
X_train=iris.data
class_labels=iris.target_names
plt.scatter(X_train[:,0], X_train[:,1], c=Y_train)
plt.xlabel('attr1')
plt.ylabel('attr2')
plt.show()
The resulting scatter plot shows yellow, green and purple dots. I want to know which colour corresponds to which class ('setosa', 'versicolor', 'virginica'). How can I display a legend that maps each colour to its class name?
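One direction that might work is sketched below. It assumes matplotlib 3.1 or newer, where the object returned by `scatter` has a `legend_elements()` method, and it relies on the integer targets 0, 1, 2 lining up with the order of `iris.target_names`. The function name `plot_iris_with_legend` is only a placeholder for illustration.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris


def plot_iris_with_legend():
    """Scatter the first two iris features and label each colour by class."""
    iris = load_iris()
    X_train, Y_train = iris.data, iris.target
    class_labels = list(iris.target_names)

    fig, ax = plt.subplots()
    scatter = ax.scatter(X_train[:, 0], X_train[:, 1], c=Y_train)

    # legend_elements() builds one handle per unique value in c (0, 1, 2 here),
    # in ascending order, which is assumed to match iris.target_names.
    handles, _ = scatter.legend_elements()
    legend = ax.legend(handles, class_labels, title="class")

    ax.set_xlabel(iris.feature_names[0])
    ax.set_ylabel(iris.feature_names[1])
    return fig, legend
```

Calling `plot_iris_with_legend()` followed by `plt.show()` should display the same scatter plot with a legend entry per class. On older matplotlib versions, an alternative would be to call `scatter` once per class with a `label=` argument and then call `legend()`.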