I would like to only include certain values in a seaborn heatmap legend. Specifically, I have a "nan" category that I don't want to see in the legend.

I am trying to plot ward movements for patients in a hospital as a kind of categorical heatmap, with different colours representing different wards. I have borrowed code from the question "heatmap-like plot, but for categorical variables in seaborn" to configure my input table for the heatmap. Blank cells mean the patient was not in the hospital on those dates.

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import pandas as pd

data = {'12/3': [np.nan, 'Ward_B', np.nan],
        '13/3': [np.nan, 'Ward_B', np.nan],
        '14/3': [np.nan, 'Ward_B', 'ED'],
        '15/3': ['ED', 'Ward_A', 'Ward_C'],
        '16/3': ['ED', 'Ward_A', 'Ward_C'],
        '17/3': ['Ward_A', 'Ward_A', 'Ward_C'],
        '18/3': ['Ward_A', np.nan, 'Ward_C'],
        '19/3': ['Ward_A', np.nan, 'Ward_A'],
        '20/3': [np.nan, np.nan, 'Ward_A']}

df = pd.DataFrame (data, columns = ['12/3',
                                    '13/3',
                                    '14/3',
                                    '15/3',
                                    '16/3',
                                    '17/3',
                                    '18/3',
                                    '19/3',
                                    '20/3'])

# Create dataframe of patient IDs
patient_codes_df = pd.DataFrame(['Patient_A', 'Patient_B', 'Patient_C'])
# change heading
patient_codes_df = patient_codes_df.rename(columns={0:'Patient'})
# Merge
df2 = pd.concat([patient_codes_df, df], axis=1)
# Make Patient column the index
df3 = df2.set_index('Patient')
df3

df3 is what my input data looks like.

And this is how I'm plotting the heatmap:

value_to_int = {j:i for i,j in enumerate(pd.unique(df3.values.ravel()))}
n = len(value_to_int)

cmap = sns.color_palette("Accent", n) # set colours

fig, ax = plt.subplots(1, 1, figsize = (6, 2), dpi=300)

mask = df3.isnull()
ax = sns.heatmap(df3.replace(value_to_int), cmap=cmap, mask=mask, linewidths=0.1, linecolor='#b5b5b5') 

ax.set_ylabel('')

# modify colorbar:
colorbar = ax.collections[0].colorbar 
r = colorbar.vmax - colorbar.vmin 
colorbar.set_ticks([colorbar.vmin + r / n * (0.5 + i) for i in range(n)])
colorbar.set_ticklabels(list(value_to_int.keys()))  
plt.xticks(rotation=90)
plt.show()

I would like to get rid of "nan" from the legend, and also reorder the entries so they appear in a sensible order: ED, Ward_A, Ward_B, Ward_C.

Thanks for your help.

Will Hamilton

2 Answers

You have to drop the NaNs when you define value_to_int. The code you borrowed is good, but I think the more straightforward way is to define the category-to-integer mapping manually in a dictionary, then use it to replace the values in your DataFrame while plotting:

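# Map each category to an integer; NaN is left out, and 0 ends up at the bottom of the colorbar
lvls = {'ED': 0, 'Ward_A': 1, 'Ward_B': 2, 'Ward_C': 3}
cmap = sns.color_palette("Accent", len(lvls)) # set colours

fig, ax = plt.subplots(1, 1, figsize = (6, 2), dpi=300)
# astype(float) keeps the data numeric even if replace() returns object-dtype columns
sns.heatmap(df3.replace(lvls).astype(float),cmap=cmap,mask=df3.isnull(),linewidths=.1,
            linecolor='#b5b5b5',ax=ax)

# Put one tick in the middle of each colour band and label it with the category name
colorbar = ax.collections[0].colorbar 
r = colorbar.vmax - colorbar.vmin
n = len(lvls)
colorbar.set_ticks([colorbar.vmin + r / n * (0.5 + i) for i in range(n)])
colorbar.set_ticklabels(list(lvls.keys()))
plt.xticks(rotation=90)
plt.show()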
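
If you would rather not type the mapping by hand, one option is to derive it from the data: drop the missing values, then sort what is left. This is only a minimal sketch and assumes alphabetical order is the order you want. The names build_levels and demo are invented here for illustration and are not part of pandas or seaborn.

```python
import numpy as np
import pandas as pd


def build_levels(frame):
    """Map each non-missing category in the frame to an integer, in sorted order."""
    values = pd.Series(frame.to_numpy().ravel()).dropna()
    return {name: i for i, name in enumerate(sorted(values.unique()))}


# Small stand-in for df3 from the question (hypothetical subset of the dates)
demo = pd.DataFrame({'14/3': [np.nan, 'Ward_B', 'ED'],
                     '15/3': ['ED', 'Ward_A', 'Ward_C']},
                    index=['Patient_A', 'Patient_B', 'Patient_C'])

lvls = build_levels(demo)
print(lvls)  # {'ED': 0, 'Ward_A': 1, 'Ward_B': 2, 'Ward_C': 3}
```

The resulting dictionary can then be passed to df3.replace(...) in place of a hand-written one.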

StupidWolf
  • Thanks, that looks more elegant than my solution of drawing my own legend from scratch. But I get an error in line 4 with your code - "NameError: name 'tcks' is not defined". I'm probably doing something stupid... – Will Hamilton May 09 '20 at 11:58
  • sorry about that.. thanks for the feedback. I was trying something and the tcks was kept. I have removed that argument, the code above should work now – StupidWolf May 09 '20 at 12:01
  • Now I get error from line 7, "AttributeError: 'NoneType' object has no attribute 'vmax'" – Will Hamilton May 09 '20 at 14:46
  • arh ok.. you did not call ```fig, ax = plt.subplots(1, 1, figsize = (6, 2), dpi=300)```, ok i include that – StupidWolf May 09 '20 at 14:55
  • Great that worked thanks! Is there a way to specify the order of the color bar? – Will Hamilton May 09 '20 at 15:06
  • That's the purpose of the dictionary, lvls: whatever is mapped to 0 is always at the bottom of the color bar, and 3 is always at the top. You have to assign the labels in that order; for example, ```{ 'Ward_A': 0, 'Ward_B': 1, 'Ward_C': 2,'ED': 3}``` would put Ward_A at the bottom and ED at the top – StupidWolf May 09 '20 at 15:30
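
To make the ordering in the last comment concrete, here is a small worked example of the tick formula used in the answer. It is a sketch only: band_centres is a made-up helper name, and it assumes the colorbar limits equal the smallest and largest integers in the mapping, which should hold when every category appears in the unmasked data.

```python
def band_centres(vmin, vmax, n):
    """Return the midpoint of each of n equal-width bands between vmin and vmax."""
    width = (vmax - vmin) / n
    return [vmin + width * (0.5 + i) for i in range(n)]


# Ordering from the comment above: Ward_A at the bottom, ED at the top
lvls = {'Ward_A': 0, 'Ward_B': 1, 'Ward_C': 2, 'ED': 3}
ticks = band_centres(min(lvls.values()), max(lvls.values()), len(lvls))

for label, tick in zip(lvls, ticks):
    print(label, tick)
# Ward_A 0.375
# Ward_B 1.125
# Ward_C 1.875
# ED 2.625
```

Each label sits in the middle of its colour band, and because the ticks increase with the integer codes, the dictionary order is also the bottom-to-top order on the color bar.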

So one approach which worked was to remove the color bar and create my own legend manually:

# Define wards in legend order, with one colour each (NaN cells are masked, so they need no colour)
wards = ['ED', 'Ward_A', 'Ward_B', 'Ward_C']
cols = ['#beaed4', '#7fc97f', '#fd4396', '#3da1bf']

# Transform categorical variables into numbers for heatmap
value_to_int = {ward: i for i, ward in enumerate(wards)}
n = len(value_to_int)
cmap = sns.color_palette(cols, n) # set colours

# Plot figure (set the style before the figure is created so that it applies)
sns.set(font_scale=1.27, style='whitegrid')
fig, ax = plt.subplots(1, 1, figsize = (8, 3), dpi=300)
mask = df3.isnull()
ax = sns.heatmap(df3.replace(value_to_int).astype(float),
                 cmap=cmap, mask=mask, linewidths=0.1, linecolor='#b5b5b5',
                 vmin=0, vmax=n - 1, # Pin each integer to its own colour
                 cbar=False, # Remove the color bar legend
                 xticklabels=2)

# Tweak the figure
ax.set_ylabel('')
plt.xticks(rotation=90)

# Create a new legend
ED_patch = mpatches.Patch(color='#beaed4', label='ED')
A_patch = mpatches.Patch(color='#7fc97f', label='Ward A')
B_patch = mpatches.Patch(color='#fd4396', label='Ward B')
C_patch = mpatches.Patch(color='#3da1bf', label='Ward C')

ax.legend(handles=[ED_patch,A_patch,B_patch,C_patch],
         bbox_to_anchor=(1.22, 1),
         prop={'size': 12})
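
The legend handles could also be generated from a single label-to-colour mapping rather than written out one by one, which makes it harder for the legend and the plot to drift apart. A rough sketch, assuming the hex codes from the legend code above; ward_colours and legend_handles are hypothetical names introduced here.

```python
import matplotlib.patches as mpatches
from matplotlib.colors import to_hex

# Hypothetical mapping; the hex codes are copied from the hand-written legend patches
ward_colours = {'ED': '#beaed4', 'Ward A': '#7fc97f',
                'Ward B': '#fd4396', 'Ward C': '#3da1bf'}


def legend_handles(colour_map):
    """Build one legend patch per label, keeping the mapping's order."""
    return [mpatches.Patch(color=colour, label=label)
            for label, colour in colour_map.items()]


handles = legend_handles(ward_colours)
for handle in handles:
    print(handle.get_label(), to_hex(handle.get_facecolor()))

# With an existing Axes object `ax`, the handles would be used as before:
# ax.legend(handles=handles, bbox_to_anchor=(1.22, 1), prop={'size': 12})
```

The same mapping can supply the colour list for the heatmap, so both are defined in one place.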
Will Hamilton