
I am using matplotlib to create a table with 5 rows and 4 columns. I would like to use color to show how the values differ within each individual column. Ideally, each column would get its own colormap scale, so that the scale for a column spans exactly the range of that column's values.

To clarify: the values in the second column lie roughly between 800 and 1200, while those in the first column lie between 120 and 230. When a single colormap is applied over the whole table's range, the differences within the first column are much harder to see than they would be if that column's colormap spanned only 120-230 instead of 120-1200.
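To put numbers on this: with one normalization over the whole table (min 122, max 1240), the TP column only reaches about 0.097 on the 0-1 colormap scale, so all five of its cells get nearly the same color. With a scale for that column alone it spans the full 0 to 1. A small NumPy sketch using the data from the code below:

```python
import numpy as np

conf_data = np.array([[230, 847, 784, 208],
                      [156, 1240, 391, 282],
                      [146, 1212, 419, 292],
                      [130, 1148, 483, 308],
                      [122, 1173, 458, 316]])

tp = conf_data[:, 0]  # first column (TP)

# One scale for the whole table: 122 .. 1240
global_scaled = (tp - conf_data.min()) / (conf_data.max() - conf_data.min())

# One scale for this column only: 122 .. 230
column_scaled = (tp - tp.min()) / (tp.max() - tp.min())

print(global_scaled.round(3))
print(column_scaled.round(3))
```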

This does not seem possible with matplotlib, because the colormap is applied to the entire table. What I want may also simply be a terrible and confusing way to present the data, so if there is a better way to show this, please let me know!

This is what I have now:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

fig, ax = plt.subplots()

rows = ['%d nodes' % x for x in (10, 30, 50, 75, 100)]
columns=['TP', 'TN', 'FP', 'FN']

conf_data = np.array(
[[ 230,  847,  784,  208],
 [ 156, 1240,  391,  282],
 [ 146, 1212,  419,  292],
 [ 130, 1148,  483,  308],
 [ 122, 1173,  458,  316]]
)  

# A single scale for the whole table (min 122, max 1240)
normal = plt.Normalize(np.min(conf_data), np.max(conf_data))
fig.patch.set_visible(False)
ax.axis('off')
ax.axis('tight')
ax.table(cellText=conf_data,
         rowLabels=rows,
         colLabels=columns,
         cellColours=cm.GnBu(normal(conf_data)),
         loc='center',
         colWidths=[0.1 for x in columns])
fig.tight_layout()
plt.show()
JohanC
jpc

2 Answers


You could subtract the minimum of each column and divide by that column's ptp. np.ptp, the peak-to-peak distance, is the difference between the maximum and the minimum. Save the result in a new array and use it for the color values.

To avoid the very dark blue for the highest values, you could multiply the result by something like 0.8, so that only the lower 80% of the colormap is used. (Alternatively, you could change the text color, which needs some additional code.)
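The normalization step on its own can be sketched as a small function. `normalize_columns` is a hypothetical helper name, and the guard for columns whose values are all equal (ptp of 0, which would otherwise divide by zero) is an addition of mine, not part of the answer. `keepdims=True` keeps the per-column minima and ranges as shape `(1, n_cols)`, which broadcasts row by row against the data.

```python
import numpy as np

def normalize_columns(data):
    """Scale each column of a 2-D array to [0, 1] independently.

    Hypothetical helper: a column whose values are all equal (ptp == 0)
    is mapped to 0 instead of triggering a division by zero.
    """
    data = np.asarray(data, dtype=float)
    col_min = data.min(axis=0, keepdims=True)        # shape (1, n_cols)
    col_range = np.ptp(data, axis=0, keepdims=True)  # max - min per column
    safe_range = np.where(col_range == 0, 1.0, col_range)
    return (data - col_min) / safe_range

conf_data = np.array([[230, 847, 784, 208],
                      [156, 1240, 391, 282],
                      [146, 1212, 419, 292],
                      [130, 1148, 483, 308],
                      [122, 1173, 458, 316]])

normed = normalize_columns(conf_data)
print(normed.min(axis=0))  # every column starts at 0
print(normed.max(axis=0))  # and ends at 1
```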

import numpy as np
from matplotlib import pyplot as plt

fig, ax = plt.subplots()

rows = ['%d nodes' % x for x in (10, 30, 50, 75, 100)]
columns = ['TP', 'TN', 'FP', 'FN']
conf_data = np.array([[230, 847, 784, 208],
                      [156, 1240, 391, 282],
                      [146, 1212, 419, 292],
                      [130, 1148, 483, 308],
                      [122, 1173, 458, 316]])
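# Scale each column to 0..1 on its own: (value - column min) / (column max - column min)
# np.ptp(...) instead of the ndarray.ptp method, which was removed in NumPy 2.0
normed_data = (conf_data - conf_data.min(axis=0, keepdims=True)) / np.ptp(conf_data, axis=0, keepdims=True)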

fig.patch.set_visible(False)
ax.axis('off')
ax.axis('tight')
table = ax.table(cellText=conf_data,
                 rowLabels=rows,
                 colLabels=columns,
                 cellColours=plt.cm.GnBu(normed_data*0.8),
                 loc='center',
                 colWidths=[0.1 for x in columns])
table.scale(2, 2) # make table a little bit larger
fig.tight_layout()
plt.show()

Below is the result with two different colormaps: 'GnBu' on the left with normed values between 0 and 0.8, and 'coolwarm' on the right with normed values between 0.1 and 0.9.

[image: resulting plot]
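The answer does not show the code for the side-by-side comparison, so here is one possible way to produce it: a sketch that assumes the 0.1 to 0.9 range is obtained by squeezing the normed values with `lo + (hi - lo) * normed_data`. The `variants` list and the titles are my own choices.

```python
import numpy as np
from matplotlib import pyplot as plt

rows = ['%d nodes' % x for x in (10, 30, 50, 75, 100)]
columns = ['TP', 'TN', 'FP', 'FN']
conf_data = np.array([[230, 847, 784, 208],
                      [156, 1240, 391, 282],
                      [146, 1212, 419, 292],
                      [130, 1148, 483, 308],
                      [122, 1173, 458, 316]])
normed_data = ((conf_data - conf_data.min(axis=0, keepdims=True))
               / np.ptp(conf_data, axis=0, keepdims=True))

# (colormap name, lowest and highest fraction of the colormap to use)
variants = [('GnBu', 0.0, 0.8), ('coolwarm', 0.1, 0.9)]

fig, axes = plt.subplots(ncols=2, figsize=(10, 3))
tables = []
for ax, (cmap_name, lo, hi) in zip(axes, variants):
    cmap = plt.get_cmap(cmap_name)
    # Squeeze [0, 1] into [lo, hi] to skip the extreme ends of the colormap
    colours = cmap(lo + (hi - lo) * normed_data)
    ax.axis('off')
    ax.set_title(f'{cmap_name} ({lo} to {hi})')
    tables.append(ax.table(cellText=conf_data,
                           rowLabels=rows,
                           colLabels=columns,
                           cellColours=colours,
                           loc='center'))
fig.tight_layout()
plt.show()
```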

PS: Another way to improve the contrast between the cell text and the background color is to set each cell's alpha:

# get_celld() returns a dict {(row, col): Cell}; this also fades the header and row-label cells
for cell in table.get_celld().values():
    cell.set_alpha(.6)
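The text-color alternative mentioned at the start of this answer could look roughly like this. It is a sketch, not the answer's own code: it keeps the full colormap range and picks black or white text per cell from an approximate brightness of the background, using Rec. 709 luma weights and a threshold of 0.5, both of which are assumptions.

```python
import numpy as np
from matplotlib import pyplot as plt

rows = ['%d nodes' % x for x in (10, 30, 50, 75, 100)]
columns = ['TP', 'TN', 'FP', 'FN']
conf_data = np.array([[230, 847, 784, 208],
                      [156, 1240, 391, 282],
                      [146, 1212, 419, 292],
                      [130, 1148, 483, 308],
                      [122, 1173, 458, 316]])
normed_data = (conf_data - conf_data.min(axis=0)) / np.ptp(conf_data, axis=0)
colours = plt.get_cmap('GnBu')(normed_data)  # full range, no 0.8 factor

fig, ax = plt.subplots()
ax.axis('off')
table = ax.table(cellText=conf_data, rowLabels=rows, colLabels=columns,
                 cellColours=colours, loc='center')

for i in range(conf_data.shape[0]):
    for j in range(conf_data.shape[1]):
        red, green, blue, _ = colours[i, j]
        # Approximate perceived brightness (assumed Rec. 709 luma weights)
        luma = 0.2126 * red + 0.7152 * green + 0.0722 * blue
        text_colour = 'white' if luma < 0.5 else 'black'
        # Data cells start at table row 1; row 0 holds the column labels
        table[i + 1, j].get_text().set_color(text_colour)

plt.show()
```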
JohanC

You can calculate a separate norm for each column:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm

fig, ax = plt.subplots()

rows = ['%d nodes' % x for x in (10, 30, 50, 75, 100)]
columns=['TP', 'TN', 'FP', 'FN']

conf_data = np.array(
[[ 230,  847,  784,  208],
 [ 156, 1240,  391,  282],
 [ 146, 1212,  419,  292],
 [ 130, 1148,  483,  308],
 [ 122, 1173,  458,  316]]
)  

# One RGBA color per cell: shape (rows, columns, 4)
colores = np.zeros((conf_data.shape[0], conf_data.shape[1], 4))
for i in range(conf_data.shape[1]):
    col_data = conf_data[:, i]
    # This Normalize maps the column's own min..max onto 0..1
    normal = plt.Normalize(np.min(col_data), np.max(col_data))
    colores[:, i] = cm.Reds(normal(col_data))

#fig.patch.set_visible(False)
ax.axis('off')
ax.axis('tight')
ax.table(cellText=conf_data,
         rowLabels=rows,
         colLabels=columns,
         cellColours=colores,
         loc='center',
         colWidths=[0.1 for x in columns])
fig.tight_layout()
plt.show()
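The per-column `plt.Normalize` loop above should give the same 0..1 values as the vectorized min/ptp expression in the other answer, since both compute (value - column min) / (column max - column min). A quick sketch to check that on this data (the names `loop_scaled` and `vec_scaled` are just for illustration):

```python
import numpy as np
from matplotlib import pyplot as plt

conf_data = np.array([[230, 847, 784, 208],
                      [156, 1240, 391, 282],
                      [146, 1212, 419, 292],
                      [130, 1148, 483, 308],
                      [122, 1173, 458, 316]])

# Per-column loop with plt.Normalize, as in this answer
loop_scaled = np.zeros(conf_data.shape)
for i in range(conf_data.shape[1]):
    col_data = conf_data[:, i]
    normal = plt.Normalize(np.min(col_data), np.max(col_data))
    loop_scaled[:, i] = normal(col_data)

# Vectorized min/ptp version
vec_scaled = (conf_data - conf_data.min(axis=0)) / np.ptp(conf_data, axis=0)

print(np.allclose(loop_scaled, vec_scaled))  # expected: True
```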
jjsantoso