
How can I make my matplotlib plot interactive? For example, when I hover the mouse over a cell of a confusion matrix, I'd like to display the instance of that prediction.

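import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# confusion_mat, pred_spectrum and actual_spectrum come from my own code
confusion_mat_df = pd.DataFrame(confusion_mat, columns=pred_spectrum, index=actual_spectrum)

plt.figure(figsize=(7, 5))  # width, height in inches
sns.heatmap(confusion_mat_df, annot=True)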
edited by Trenton McKinney
asked by yishairasowsky
  • Have a look at [mplcursors](https://mplcursors.readthedocs.io/en/stable/); the docs contain quite a few examples, and [this one](https://mplcursors.readthedocs.io/en/stable/examples/image.html) might be helpful. – JohanC Jan 19 '20 at 16:22
  • Yes, I want to use these tools, but I don't yet see how to apply them to an sklearn confusion matrix. – yishairasowsky Jan 20 '20 at 11:56
  • As I updated in the question, I actually convert the matrix to a DataFrame and then plot it with seaborn. Any tips? Thanks! – yishairasowsky Jan 20 '20 at 13:31

1 Answer


Here is an example to illustrate how to use mplcursors for an sklearn confusion matrix.

Unfortunately, mplcursors doesn't work with seaborn heatmaps. Seaborn uses a QuadMesh for the heatmap, which doesn't support the necessary coordinate picking.
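If you would rather keep the seaborn heatmap, one possible workaround that does not need mplcursors is matplotlib's own event API: connect a `motion_notify_event` handler and work out the hovered cell from the mouse's data coordinates. The sketch below assumes seaborn's heatmap layout, in which cell (i, j) spans x in [j, j + 1] and y in [i, i + 1] with row 0 drawn at the top. The helper names `cell_under_cursor` and `attach_hover` are hypothetical, and the demo only mimics that layout with `pcolormesh` plus `invert_yaxis()` so that it runs with just matplotlib and NumPy.

```python
import math

import matplotlib.pyplot as plt
import numpy as np


def cell_under_cursor(xdata, ydata, n_rows, n_cols):
    """Return the (row, col) of the cell under the cursor, or None.

    Assumes seaborn's heatmap layout: cell (i, j) spans x in [j, j + 1]
    and y in [i, i + 1].
    """
    if xdata is None or ydata is None:  # cursor is outside the Axes
        return None
    row, col = math.floor(ydata), math.floor(xdata)
    if 0 <= row < n_rows and 0 <= col < n_cols:
        return row, col
    return None


def attach_hover(ax, matrix, row_labels, col_labels):
    """Show actual/predicted/count in a tooltip while the mouse is over ax."""
    matrix = np.asarray(matrix)
    n_rows, n_cols = matrix.shape
    annot = ax.annotate(
        "", xy=(0, 0), xytext=(15, 15), textcoords="offset points",
        bbox=dict(boxstyle="round", fc="papayawhip", alpha=0.9),
        arrowprops=dict(arrowstyle="->"),
    )
    annot.set_visible(False)
    fig = ax.figure

    def on_move(event):
        cell = None
        if event.inaxes is ax:
            cell = cell_under_cursor(event.xdata, event.ydata, n_rows, n_cols)
        if cell is None:
            if annot.get_visible():  # hide the tooltip when leaving the grid
                annot.set_visible(False)
                fig.canvas.draw_idle()
            return
        i, j = cell
        annot.xy = (j + 0.5, i + 0.5)  # point the arrow at the cell centre
        annot.set_text(f"Actual: {row_labels[i]}\n"
                       f"Predicted: {col_labels[j]}\n"
                       f"Count: {matrix[i, j]}")
        annot.set_visible(True)
        fig.canvas.draw_idle()

    cid = fig.canvas.mpl_connect("motion_notify_event", on_move)
    return annot, cid


if __name__ == "__main__":
    labels = ["ant", "bird", "cat", "dog"]
    mat = np.array([[2, 0, 0, 0], [0, 0, 1, 0], [1, 0, 2, 0], [0, 0, 0, 1]])
    fig, ax = plt.subplots()
    ax.pcolormesh(mat, cmap="viridis")  # cells span [j, j+1] x [i, i+1]
    ax.invert_yaxis()                   # row 0 at the top, as in seaborn
    attach_hover(ax, mat, labels, labels)
    plt.show()
```

With seaborn you would pass the Axes that `sns.heatmap` returns, for example `ax = sns.heatmap(confusion_mat_df, annot=True)` followed by `attach_hover(ax, confusion_mat_df.values, confusion_mat_df.index, confusion_mat_df.columns)`. Computing the cell with `floor` keeps the handler independent of the artist type, which is why it sidesteps the QuadMesh limitation.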

In the code below I added the count at the center of each cell, similar to seaborn's `annot=True`. I also changed the colors of the texts and arrows to make them easier to read. You'll need to adapt the colors and sizes to your situation.

from sklearn.metrics import confusion_matrix
from matplotlib import pyplot as plt
import mplcursors

y_true = ["cat", "ant", "cat", "cat", "ant", "bird", "dog"]
y_pred = ["ant", "ant", "cat", "cat", "ant", "cat", "dog"]
labels = ["ant", "bird", "cat", "dog"]
# sklearn convention: row i = actual label, column j = predicted label
confusion_mat = confusion_matrix(y_true, y_pred, labels=labels)

fig, ax = plt.subplots()
heatmap = ax.imshow(confusion_mat, cmap="jet", interpolation='nearest')

# write the count in the center of each cell (like seaborn's annot=True);
# note that xy is (x, y) = (column, row)
for i in range(len(labels)):
    for j in range(len(labels)):
        ax.annotate(str(confusion_mat[i, j]), xy=(j, i),
                    ha='center', va='center', fontsize=18, color='white')

plt.colorbar(heatmap)
plt.xticks(range(len(labels)), labels)
plt.yticks(range(len(labels)), labels)
plt.xlabel('Predicted Values')  # columns
plt.ylabel('Actual Values')     # rows

# hover=True shows the annotation on mouse-over instead of on click
cursor = mplcursors.cursor(heatmap, hover=True)

@cursor.connect("add")
def on_add(sel):
    i, j = sel.target.index  # (row, column) of the cell under the mouse
    sel.annotation.set_text(f'{labels[i]} - {labels[j]} : {confusion_mat[i, j]}')
    sel.annotation.set_fontsize(12)
    sel.annotation.get_bbox_patch().set(fc="papayawhip", alpha=0.9, ec='white')
    sel.annotation.arrow_patch.set_color('white')

plt.show()
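Why `xy=(j, i)` puts each count in the middle of its cell: with the default `origin='upper'`, `imshow` centers element `[i, j]` at data coordinates (x=j, y=i), each cell extending 0.5 in every direction. A minimal check, using the confusion matrix that the example above produces (hard-coded here so that only matplotlib and NumPy are needed):

```python
import matplotlib.pyplot as plt
import numpy as np

# confusion matrix of the example above (rows = actual, columns = predicted)
mat = np.array([[2, 0, 0, 0],
                [0, 0, 1, 0],
                [1, 0, 2, 0],
                [0, 0, 0, 1]])

fig, ax = plt.subplots()
im = ax.imshow(mat, interpolation="nearest")

# Default origin="upper": element [i, j] covers x in [j - 0.5, j + 0.5]
# and y in [i - 0.5, i + 0.5], so its center is at (x=j, y=i).
left, right, bottom, top = im.get_extent()
print(left, right, bottom, top)  # -0.5 3.5 3.5 -0.5

# That is why annotate(..., xy=(j, i)) lands in the middle of each cell.
for i, j in np.ndindex(mat.shape):
    ax.annotate(str(mat[i, j]), xy=(j, i), ha="center", va="center")
```

Seaborn's heatmap uses a different layout (cells span [j, j + 1] x [i, i + 1]), which is why its cell centers sit at half-integer coordinates instead.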


PS: The annotation can be multiline, for example:

sel.annotation.set_text(f'Actual: {labels[i]}\nPredicted: {labels[j]}\n{confusion_mat[i, j]:5}')
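If "instance" in the question means the individual samples that ended up in a cell, one option (not part of the code above) is to precompute, for each (actual, predicted) cell, the positions of the samples that fall into it and show that list in the tooltip. A sketch with a hypothetical helper `instances_per_cell`, using the same `y_true`, `y_pred` and `labels` as the example:

```python
from collections import defaultdict


def instances_per_cell(y_true, y_pred, labels):
    """Map each confusion-matrix cell (i, j) to the positions of its samples.

    Follows sklearn's convention: row i is the true label labels[i],
    column j is the predicted label labels[j].
    """
    index_of = {label: k for k, label in enumerate(labels)}
    cells = defaultdict(list)
    for pos, (actual, predicted) in enumerate(zip(y_true, y_pred)):
        cells[(index_of[actual], index_of[predicted])].append(pos)
    return dict(cells)


y_true = ["cat", "ant", "cat", "cat", "ant", "bird", "dog"]
y_pred = ["ant", "ant", "cat", "cat", "ant", "cat", "dog"]
labels = ["ant", "bird", "cat", "dog"]

cells = instances_per_cell(y_true, y_pred, labels)
print(cells[(2, 2)])  # [2, 3]: cats correctly predicted as cats
print(cells[(2, 0)])  # [0]: the cat that was predicted as an ant
```

Inside `on_add` one could then write something like `sel.annotation.set_text(f'{confusion_mat[i, j]} samples: {cells.get((i, j), [])}')`; empty cells are simply missing from the dict, hence the `.get` with a default. For large datasets you would probably want to truncate the list.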
answered by JohanC