
I found an excellent tutorial on drawing a heatmap for a confusion matrix, but I want to add some errors of commission and omission on the sides.

I'll try to explain using this image:

confusion matrix

This means:

  1. I need to insert a number beside each of the boxes containing 0, 6, and 9, just to the right of the image's right edge and to the left of the legend.

  2. I need to insert a number above each of the boxes containing 13, 0, and 0, just above the top edge of the image and just below the title.

(so 6 numbers in total)

Is this even possible? I know nothing about the plotting functions in Python, as I'm new to the language. It just seems like a very difficult task from where I'm standing.

SonicProtein

3 Answers


You could do this using ticks.

Let me present this approach with the following easy plot:

from matplotlib import pyplot as plt

ax = plt.axes()
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)

for i in range(3):
    for j in range(3):
        ax.fill_between((i, i+1), j, j+1)

plt.show()


I will not focus on the colors or the tick style here, but you can change both very easily.

You can create an Axes object that will share ax's Y axis, with ax.twiny(). Then, you can add X ticks on this new Axes, which will appear on top of the plot:

from matplotlib import pyplot as plt

ax = plt.axes()
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)

for i in range(3):
    for j in range(3):
        ax.fill_between((i, i+1), j, j+1)

ax2 = ax.twiny()
ax2.set_xlim(ax.get_xlim())
ax2.set_xticks([0.5, 1.5, 2.5])
ax2.set_xticklabels([13, 0, 0])

plt.show()


To display ticks along the top X axis, you have to create an Axes object that shares ax's Y axis, using ax.twiny(). This might seem counter-intuitive, but if you used ax.twinx() instead, modifying ax2's X ticks would modify ax's as well, because the two would share the same X axis. Next, set ax2's X limits to match ax's, so that it spans the same three squares. Then set the ticks, one per square at its horizontal center, i.e. at [0.5, 1.5, 2.5]. Finally, set the tick labels to the values you want to display.
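If the twinx/twiny naming feels backwards, a quick check makes the sharing visible. This is only an illustrative sketch: the names side and top are made up for the example, and it compares axis limits rather than ticks because limits are the simplest thing to read back.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
from matplotlib import pyplot as plt

fig, ax = plt.subplots()
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)

# twinx(): shares the X axis with ax and gets its own Y axis on the right
side = ax.twinx()
side.set_xlim(0, 10)
x_after_twinx = ax.get_xlim()  # ax follows, because the X axis is shared

# twiny(): shares the Y axis with ax and gets its own X axis on top
top = ax.twiny()
top.set_xlim(5, 6)
x_after_twiny = ax.get_xlim()  # ax keeps its X limits; top's X axis is independent
top.set_ylim(1, 2)
y_after_twiny = ax.get_ylim()  # but the shared Y limits do propagate to ax

print(x_after_twinx, x_after_twiny, y_after_twiny)
```

So the Axes returned by twiny() is the one whose X ticks you can change without touching the original plot.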

Then, you just do the same with the Y ticks:

from matplotlib import pyplot as plt

ax = plt.axes()
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)

for i in range(3):
    for j in range(3):
        ax.fill_between((i, i+1), j, j+1)

ax2 = ax.twiny()
ax2.set_xlim(ax.get_xlim())
ax2.set_xticks([0.5, 1.5, 2.5])
ax2.set_xticklabels([13, 0, 0])

ax3 = ax.twinx()
ax3.set_ylim(ax.get_ylim())
ax3.set_yticks([0.5, 1.5, 2.5])
ax3.set_yticklabels([0, 6, 9])

plt.show()

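Since the question is about errors of commission and omission, here is a rough sketch of how those numbers could be computed and passed in as the tick labels. It rests on a few assumptions that are not in the original post: the matrix cm is made up, rows are taken to be the true classes and columns the predicted classes, omission error is computed per row, commission error per column, and no row or column sums to zero.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import numpy as np
from matplotlib import pyplot as plt

# Hypothetical confusion matrix (rows = true class, columns = predicted class).
cm = np.array([[13, 0, 0],
               [0, 10, 6],
               [0, 0, 9]])

correct = np.diag(cm)
omission = 1 - correct / cm.sum(axis=1)    # per true class: missed / row total
commission = 1 - correct / cm.sum(axis=0)  # per predicted class: wrong / column total

n_rows, n_cols = cm.shape
fig, ax = plt.subplots()
ax.pcolormesh(cm, cmap="Blues")
ax.set_xlim(0, n_cols)
ax.set_ylim(n_rows, 0)  # inverted so that row 0 is drawn at the top

# Commission errors above the columns, one label centered on each cell
top = ax.twiny()
top.set_xlim(ax.get_xlim())
top.set_xticks(np.arange(n_cols) + 0.5)
top.set_xticklabels([f"{v:.1%}" for v in commission])

# Omission errors to the right of the rows
right = ax.twinx()
right.set_ylim(ax.get_ylim())
right.set_yticks(np.arange(n_rows) + 0.5)
right.set_yticklabels([f"{v:.1%}" for v in omission])

fig.canvas.draw()  # render once; use plt.show() or fig.savefig(...) in practice
```

If your matrix is laid out the other way round (columns as true classes), swap the two axis arguments.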

Right leg

Use the following modified function. The idea is as follows:

  • Add two twin axes - one to the right and other to the top.
  • Set the limits of the twin axes equal to that of the original axes
  • Set the positions of the ticks on the twin axes to be the same as that of the original axes
  • Hide the tick marks and assign the tick-labels
  • Shift the title a bit upward using y=1.1

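import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels


def plot_confusion_matrix(y_true, y_pred, classes, normalize=False,
                          title=None, cmap=plt.cm.Blues):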
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'

    cm = confusion_matrix(y_true, y_pred)
    classes = classes[unique_labels(y_true, y_pred)]
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    fig, ax = plt.subplots(figsize=(6.5,6))
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           ylabel='True label',
           xlabel='Predicted label')
    ax.set_title(title, y=1.1)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

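    # Labels on the right: a twin Axes sharing the X axis
    # (here the last column of cm; substitute your own values)
    ax2 = ax.twinx()
    ax2.set_ylim(ax.get_ylim())
    ax2.set_yticks(np.arange(cm.shape[0]))
    ax2.set_yticklabels(cm[:, -1])
    ax2.tick_params(axis="y", right=False)  # hide the tick marks, keep the labels

    # Labels on top: a twin Axes sharing the Y axis
    # (here the first column of cm; substitute your own values)
    ax3 = ax.twiny()
    ax3.set_xlim(ax.get_xlim())
    ax3.set_xticks(np.arange(cm.shape[1]))
    ax3.set_xticklabels(cm[:, 0])
    ax3.tick_params(axis="x", top=False)  # hide the tick marks, keep the labels
    ax.set_aspect('auto')  # let the cells stretch instead of forcing them square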

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax


Sheldore
  • I'm not sure why but I keep getting "RuntimeError: adjustable='datalim' is not allowed when both axes are shared." when I try to add your code into the method. – SonicProtein Apr 29 '19 at 21:31
  • @SonicProtein: I have included the complete function. – Sheldore Apr 29 '19 at 21:42
  • According to the traceback, the error is raised in apply_aspect in matplotlib\axes\_base.py (per the documentation, that method does this: "Adjust the Axes for a specified data aspect ratio."), right after the if statement "if shared_x and shared_y:". I'm not sure what all of this means, however. – SonicProtein Apr 29 '19 at 21:42
  • @SonicProtein: Can you comment out `fig.tight_layout()` and then try again? – Sheldore Apr 29 '19 at 21:44

A rather manual approach would consist of combining the following items until the result is satisfactory:

  • using twinx and twiny to get new axes on the top and on the right: twinax = ax.twinx().twiny()
  • using twinax.set(xlim=ax.get_xlim(), ylim=ax.get_ylim()) to match their range with the range of the original axes, then...
  • using twinax.set(xticks=ax.get_xticks(), yticks=ax.get_yticks(), xticklabels=('0','1','2'), yticklabels=('0','1','2')) to set the labels on the new axes as was done in your example (these two calls can be combined if you like).
  • (If you don't want the actual ticks (only the labels) you can give them 0 length through tick_params.)
  • You can reposition the axes with set_position.
  • See this question for info on how to move the colorbar.
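The first few items could look roughly like the sketch below. Note that it is a variation rather than the exact recipe above: it keeps two separate handles (right and top, names chosen for this example) instead of the chained ax.twinx().twiny(), so that the tick marks on each new axis can be hidden individually. The limits, tick positions and labels are placeholder values.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
from matplotlib import pyplot as plt

fig, ax = plt.subplots()
ax.set(xlim=(0, 3), ylim=(0, 3),
       xticks=(0.5, 1.5, 2.5), yticks=(0.5, 1.5, 2.5))

right = ax.twinx()  # new Y axis on the right
top = ax.twiny()    # new X axis on top

# Match range and tick positions with the original axes, then set the labels
right.set(ylim=ax.get_ylim(), yticks=ax.get_yticks(),
          yticklabels=("0", "1", "2"))
top.set(xlim=ax.get_xlim(), xticks=ax.get_xticks(),
        xticklabels=("0", "1", "2"))

# Keep the labels but give the tick marks zero length
right.tick_params(axis="y", length=0)
top.tick_params(axis="x", length=0)

fig.canvas.draw()  # render once; use plt.show() in an interactive session
```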
Koen G.