
Following an example on SO for plotting annotated heatmaps, I am running into an issue with legends. The poster used a workaround in which the legends are created from invisible bar plots. This works for only two axes; if I add any more legends, they begin to run into each other. Here is a copy of that code with my own modifications (COMMENTS IN ALL CAPS):

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from functools import reduce

#load data
networks = sns.load_dataset("brain_networks", index_col=0, header=[0, 1, 2])

# network colors
network_labels = networks.columns.get_level_values("network")
network_pal = sns.cubehelix_palette(network_labels.unique().size, light=.9, dark=.1, reverse=True, start=1, rot=-2)
network_lut = dict(zip(map(str, network_labels.unique()), network_pal))

# create network index
network_colors = pd.Series(network_labels, index=networks.columns).map(network_lut)


# node colors
node_labels = networks.columns.get_level_values("node")
node_pal = sns.cubehelix_palette(node_labels.unique().size)
node_lut = dict(zip(map(str, node_labels.unique()), node_pal))

# create node index
node_colors = pd.Series(node_labels, index=networks.columns).map(node_lut)

#df of row and col maps
network_node_colors = pd.DataFrame(network_colors).join(pd.DataFrame(node_colors))

#THIS STUFF IS NEW
# HEMI COLORS
hemi_labels = networks.columns.get_level_values('hemi')
hemi_pal = sns.hls_palette(hemi_labels.unique().size)
hemi_lut = dict(zip(map(str, hemi_labels.unique()), hemi_pal))

# CREATE HEMI INDEX
hemi_colors = pd.Series(hemi_labels, index=networks.columns).map(hemi_lut)

# JOIN HEMI COLORS WITH THE NETWORK AND NODE COLORS
network_node_hemi_colors = network_node_colors.join(pd.DataFrame(hemi_colors))


g = sns.clustermap(networks.corr(),
                   # Cluster both rows and columns
                   row_cluster=True, col_cluster=True,
                   # Add colored class labels from the network/node (and hemi) color data frames
                   row_colors=network_node_colors,
                   col_colors=network_node_hemi_colors,
                   # Make the plot look better when many rows/cols
                   linewidths=0,
                   xticklabels=False, yticklabels=False,
                   center=0, cmap="vlag")

# network legend
for label in network_labels.unique():
    g.ax_col_dendrogram.bar(0, 0, color=network_lut[str(label)], label=label, linewidth=0)

l1 = g.ax_col_dendrogram.legend(title='Network', loc="upper right", ncol=1, bbox_to_anchor=(1.2,0.55), bbox_transform=plt.gcf().transFigure)

# node legend
for label in node_labels.unique():
    g.ax_row_dendrogram.bar(0, 0, color=node_lut[label], label=label, linewidth=0)

l2 = g.ax_row_dendrogram.legend(title='Node', loc='upper right', ncol=1, bbox_to_anchor=(6.47, 1))

# HEMI LEGEND
for label in hemi_labels.unique():
    g.ax_row_dendrogram.bar(0, 0, color=hemi_lut[label], label=label, linewidth=0)

l3 = g.ax_row_dendrogram.legend(title='hemi', loc='upper right', ncol=1, bbox_to_anchor=(6.47, 1))

#output graph and legends
plt.savefig('heatmap_annotated.svg', format='svg')

Here is what I get: [screenshot of the resulting clustermap not included]

Bonus points for getting the legends to appear on the output SVG!
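One detail in the code above is likely part of the problem: an Axes keeps only one legend, so `l3 = g.ax_row_dendrogram.legend(...)` replaces the `Node` legend `l2` instead of adding a second one. Because no handles are passed, that call also collects every labeled bar on the row dendrogram, node and hemi entries alike. The Matplotlib legend guide's pattern for several legends on one Axes is to re-add the first legend with `ax.add_artist`. The sketch below shows that pattern on a plain Axes; the color tables and legend positions are hypothetical stand-ins, not the question's data.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; omit this line in an interactive session
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

fig, ax = plt.subplots()

# Hypothetical color tables standing in for node_lut and hemi_lut.
node_lut = {"1": "0.3", "2": "0.7"}
hemi_lut = {"lh": "tab:green", "rh": "tab:red"}

# Explicit handles, so each legend lists only its own group.
node_handles = [Patch(facecolor=c, label=k) for k, c in node_lut.items()]
hemi_handles = [Patch(facecolor=c, label=k) for k, c in hemi_lut.items()]

node_legend = ax.legend(handles=node_handles, title="Node", loc="upper left")
# Re-add the first legend as an ordinary artist; otherwise the next
# ax.legend() call replaces it as the Axes' single legend.
ax.add_artist(node_legend)
hemi_legend = ax.legend(handles=hemi_handles, title="Hemi", loc="lower left")
```

Both legends are then drawn; only the second one is what `ax.get_legend()` reports.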

Thomas Matthew
There is no need for an invisible bar plot. You can simply create proxy artists for the legend handles. I warmly recommend the [documentation](https://matplotlib.org/tutorials/intermediate/legend_guide.html#creating-artists-specifically-for-adding-to-the-legend-aka-proxy-artists), which is quite good on that topic. – Paul Brodersen Feb 12 '20 at 10:38
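A minimal sketch of the proxy-artist approach, assuming the legends are attached to the figure rather than to the dendrogram Axes: `Figure.legend()` appends to `fig.legends` instead of replacing an earlier legend, so one call per color table keeps all of them. The helper `add_stacked_legends` and its `x`, `top`, and `gap` parameters are hypothetical names for illustration; it measures each legend after a draw and starts the next one just below it so they do not overlap. The demo uses random data and made-up color tables instead of the brain_networks clustermap.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; omit this line in an interactive session
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Patch


def add_stacked_legends(fig, luts, x=1.0, top=0.95, gap=0.02):
    """Hypothetical helper: one figure-level legend per {label: color} table,
    stacked top to bottom starting at figure coordinates (x, top)."""
    legends = []
    y = top
    to_fig = fig.transFigure.inverted()  # display (pixel) coords -> figure fractions
    for title, lut in luts.items():
        # Proxy artists: Patch objects that exist only to be shown in the legend.
        handles = [Patch(facecolor=color, label=str(label)) for label, color in lut.items()]
        leg = fig.legend(handles=handles, title=title, loc="upper left",
                         bbox_to_anchor=(x, y), bbox_transform=fig.transFigure)
        # Draw once so the legend's size is known, then start the next legend
        # a small gap below the bottom edge of this one.
        fig.canvas.draw()
        y = leg.get_window_extent().transformed(to_fig).y0 - gap
        legends.append(leg)
    return legends


# Hypothetical demo: random data and made-up color tables stand in for the
# clustermap and network_lut / node_lut / hemi_lut from the question.
fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(np.random.default_rng(0).normal(size=(10, 10)), cmap="coolwarm")
luts = {
    "Network": {"1": "tab:blue", "2": "tab:orange", "3": "tab:purple"},
    "Node": {"1": "0.2", "2": "0.5", "3": "0.8"},
    "Hemi": {"lh": "tab:green", "rh": "tab:red"},
}
legends = add_stacked_legends(fig, luts)

# The legends sit outside the figure (x=1.0); bbox_inches="tight" enlarges the
# saved area so they end up in the SVG instead of being clipped.
fig.savefig("heatmap_annotated.svg", format="svg", bbox_inches="tight",
            bbox_extra_artists=legends)
```

With the clustermap from the question, the three invisible-bar blocks could be replaced by something like `add_stacked_legends(g.fig, {"Network": network_lut, "Node": node_lut, "Hemi": hemi_lut})`, followed by `g.fig.savefig("heatmap_annotated.svg", format="svg", bbox_inches="tight")`. The default offsets are guesses and will probably need tuning for that layout.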
