
I am trying to display multiple legends within multiple subplots. Currently, I can create figures with subplots, and I can display multiple legends within one plot, but I can't combine the two: the legends seem to default to the last subplot.

This is the code I found that got me the closest:

import matplotlib.pyplot as plt

f, axs = plt.subplots(2)
l1, = axs[0].plot([0, 1], [1, 0], label="line1")
h1 = [l1]
l2, = axs[0].plot([0, 1], [0, 1], "--", label="line2")
h2 = [l2]

lab1 = [h.get_label() for h in h1]
lab2 = [h.get_label() for h in h2]

# plt.legend() and plt.gca() act on the current Axes, which after
# plt.subplots(2) is the last subplot (axs[1]), not axs[0]
leg1 = plt.legend(h1, lab1, loc=1)
leg2 = plt.legend(h2, lab2, loc=4)  # replaces leg1 as the Axes legend

plt.gca().add_artist(leg1)  # re-attaches leg1 to that same Axes
plt.show()

But I can't control where the legend added with plt.gca().add_artist(leg1) ends up.

Hopefully there's a better way to do this.

  • An alternative I found is plt.figlegend(), but I'm not completely satisfied with that option. – sunnyklein1 Apr 11 '22 at 16:02
  • What are you trying to get? A legend for each subplot? Or a subplot containing all legend entries of another subplot? – Davide_sd Apr 11 '22 at 16:03
  • Two legends for each subplot. One legend is for data specific to the experiment (one experiment = one subplot), while the other legend labels the average of each experiment across each subplot. – sunnyklein1 Apr 11 '22 at 16:29
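For comparison, here is a minimal sketch of the two-legends-per-subplot goal without twin axes. It follows the "multiple legends on the same Axes" pattern from Matplotlib's legend guide: call legend() on each Axes object rather than plt.legend() (which only targets plt.gca()), and use ax.add_artist() so the first legend survives the second legend() call. The data, the labels, and the name make_figure are hypothetical placeholders for one experiment line plus one "average" line per subplot.

```python
import matplotlib.pyplot as plt


def make_figure():
    """Two subplots, each carrying its own pair of legends."""
    fig, axs = plt.subplots(2)
    for i, ax in enumerate(axs):
        # Hypothetical data: one line per experiment plus an "average" line.
        data_line, = ax.plot([0, 1], [i, 1 - i], label=f"experiment {i + 1}")
        avg_line, = ax.plot([0, 1], [0.5, 0.5], "--", label="average")

        # Use the Axes method, not plt.legend(), which only targets plt.gca().
        first = ax.legend(handles=[data_line], loc="upper right")
        # A second ax.legend() call replaces the Axes legend, so keep the
        # first one by adding it to this Axes as an ordinary artist.
        ax.add_artist(first)
        ax.legend(handles=[avg_line], loc="lower right")
    return fig, axs


if __name__ == "__main__":
    make_figure()
    plt.show()
```

Because every call goes through a specific `ax`, nothing depends on which subplot happens to be current.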



Figured it out:

The trick was to plot on two separate Axes using .twinx(), so that each Axes keeps its own legend. To make it look better, I linked the twin's y-limits to the host with ax1prime.sharey(ax1) (older code used ax1.get_shared_y_axes().join(ax1, ax1prime), which is deprecated in recent Matplotlib releases) and hid the twin's y-axis with ax1prime.get_yaxis().set_visible(False).

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import host_subplot

ax1 = host_subplot(211)
ax1.plot([0, 100], [0, 100], label='Data Set 1')
ax1.legend(loc=2)

ax1prime = ax1.twinx()
ax1prime.plot([0, 100], [100, 0], label='Data Set 2')
ax1prime.legend(loc=1)

# Share the y-limits with the host and hide the twin's duplicate y-axis
ax1prime.sharey(ax1)
ax1prime.get_yaxis().set_visible(False)

ax2 = host_subplot(212)
ax2.plot([0, 100], [0, 100], label='Data Set 3')
ax2.legend(loc=2)

ax2prime = ax2.twinx()
ax2prime.plot([0, 100], [100, 0], label='Data Set 4')
ax2prime.legend(loc=1)

ax2prime.sharey(ax2)
ax2prime.get_yaxis().set_visible(False)

plt.show()
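The same twin-axes idea can also be sketched with plain plt.subplots() instead of mpl_toolkits.axes_grid1.host_subplot, looping over a number of experiments. This assumes Matplotlib 3.3 or newer, where Axes.sharey() exists; make_twin_figure and n_experiments are hypothetical names chosen for illustration, and the data mirrors the answer above.

```python
import matplotlib.pyplot as plt


def make_twin_figure(n_experiments=2):
    """One subplot per experiment, each with a host legend and a twin legend."""
    fig, axs = plt.subplots(n_experiments, squeeze=False)
    hosts = axs[:, 0]
    twins = []
    for i, ax in enumerate(hosts):
        ax.plot([0, 100], [0, 100], label=f"Data Set {2 * i + 1}")
        ax.legend(loc="upper left")

        # twinx() adds a second Axes sharing this subplot's x-axis; each Axes
        # holds its own legend, so the two legends do not replace each other.
        twin = ax.twinx()
        twin.plot([0, 100], [100, 0], label=f"Data Set {2 * i + 2}")
        twin.legend(loc="upper right")

        # Link the twin's y-limits to the host, then hide the duplicate y-axis.
        twin.sharey(ax)
        twin.yaxis.set_visible(False)
        twins.append(twin)
    return fig, hosts, twins


if __name__ == "__main__":
    make_twin_figure()
    plt.show()
```

Each host Axes carries the experiment-specific legend and each twin carries the second one, so no legend has to be re-attached with add_artist().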