
Is there a way to group labels when using plotnine's facet_grid function?

For example, take the following DataFrame (I made the example up so that you can reproduce my problem; my real DataFrame is much bigger, but its structure is almost identical):

import pandas as pd

df_test = pd.DataFrame(data={'rate_1': [0.0, 0.1, 0.2, 0.3, 0.0, 0.1, 0.2, 0.3, 0.3, 0.2, 0.1, 0.0],
                             'rate_2': [0.0, 0.1, 0.2, 0.3, 0.0, 0.1, 0.2, 0.3, 0.3, 0.2, 0.1, 0.0],
                             'rate_3': [0.0, 0.1, 0.2, 0.3, 0.0, 0.1, 0.2, 0.3, 0.3, 0.2, 0.1, 0.0],
                             'rate_4': [0.0, 0.1, 0.2, 0.3, 0.1, 0.2, 0.3, 0.0, 0.2, 0.3, 0.0, 0.1],
                             'samples': [50000, 100000, 50000, 100000, 50000, 100000, 50000, 100000, 50000, 100000, 50000, 100000],
                             'model': ['model 1', 'model 2', 'model 3', 'model 4', 'model 1', 'model 2', 'model 3', 'model 4', 'model 1', 'model 2', 'model 3', 'model 4'],
                             'metric': [0.5, 0.4, 0.3, 0.2, 0.5, 0.4, 0.3, 0.2, 1, 0.3, 0.4, 0.3]})

I wrote the following code to plot the metric as a function of the rates, sample size, and model:

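from plotnine import (ggplot, aes, geom_tile, theme, element_text, facet_grid, label_value,
                      scale_x_continuous, scale_y_continuous, scale_fill_gradient2)

# Plotting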
my_plot = ggplot(data=df_test, mapping=aes('rate_1', 'rate_2', fill='metric'))
my_plot += geom_tile()
my_plot += theme(axis_text_x=element_text(rotation=90), figure_size=(16, 8), strip_background_x=element_text(width=1.))
my_plot += scale_x_continuous(breaks=[.0, .1, .2, .3, .4, .5])
my_plot += scale_y_continuous(breaks=[.0, .1, .2, .3, .4, .5])
my_plot += facet_grid('rate_3 + samples  ~ model + rate_4', labeller=label_value)
my_plot += scale_fill_gradient2(midpoint=0, low='blue', mid="white", high="red")
my_plot

The resulting plot looks like:

[Image: the resulting plot]

Is there a way to group model names (model 1, model 2, model 3, model 4) and rate_4 when labeling the facet_grid? I'm looking for a result such that the 'columns' are named like the following (I used Excel for the illustration):

[Image: desired column labels, illustrated in Excel]

Thanks!

Jannik

1 Answer


Thanks to this hint, I was able to "group" the labels by manipulating the matplotlib text and background objects of the drawn figure:

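import matplotlib.pyplot as plt
from plotnine import (ggplot, aes, geom_tile, theme, element_text, facet_grid, label_value,
                      scale_x_continuous, scale_y_continuous, scale_fill_gradient2)

# Plotting (df_test as defined in the question)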
my_plot = ggplot(data=df_test, mapping=aes('rate_1', 'rate_2', fill='metric'))
my_plot += geom_tile()
my_plot += theme(axis_text_x=element_text(rotation=90), figure_size=(16, 8), strip_background_x=element_text(width=1.))
my_plot += scale_x_continuous(breaks=[.0, .1, .2, .3, .4, .5])
my_plot += scale_y_continuous(breaks=[.0, .1, .2, .3, .4, .5])
my_plot += facet_grid('rate_3 + samples  ~ model + rate_4', labeller=label_value)
my_plot += scale_fill_gradient2(midpoint=0, low='blue', mid="white", high="red")

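# Draw the plot to get the matplotlib figure, then edit the strip labels and backgrounds.
# Note: _themeable is a private (underscore-prefixed) attribute, so it may differ between plotnine versions.
fig = my_plot.draw()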
fig._themeable['strip_text_x'][0].set_text('\n 0.0')
fig._themeable['strip_text_x'][1].set_text('model 1 \n 0.10')
fig._themeable['strip_text_x'][2].set_text('\n 0.20')
fig._themeable['strip_background_x'][0].set_width(3.0)

fig._themeable['strip_text_x'][3].set_text('\n 0.00')
fig._themeable['strip_text_x'][4].set_text('model 2 \n 0.10')
fig._themeable['strip_text_x'][5].set_text('\n 0.20')
fig._themeable['strip_background_x'][3].set_width(3.0)

fig._themeable['strip_text_x'][6].set_text('\n 0.00')
fig._themeable['strip_text_x'][7].set_text('model 3 \n 0.10')
fig._themeable['strip_text_x'][8].set_text('\n 0.20')
fig._themeable['strip_background_x'][6].set_width(3.0)

fig._themeable['strip_text_x'][9].set_text('\n 0.00')
fig._themeable['strip_text_x'][10].set_text('model 4 \n 0.10')
fig._themeable['strip_text_x'][11].set_text('\n 0.20')
fig._themeable['strip_background_x'][9].set_width(3.0)

plt.show()

The result looks like:

[Image: Screenshot 2020-08-21 at 07 50 11]
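The twelve hard-coded `set_text` calls can probably be generated instead. Below is a minimal sketch, assuming the facet columns come in order as (model, rate_4) pairs; the helper name `grouped_strip_labels` and its return values are my own invention for illustration, not part of plotnine. It writes the group name only on the middle column of each group and reports where each group starts and how many columns it spans:

```python
from itertools import groupby


def grouped_strip_labels(columns, fmt="{:.2f}"):
    """Build one two-line label per facet column from (group, value) pairs.

    Hypothetical helper: the group name appears only on the middle column of
    each run of consecutive columns that share a group; the other columns get
    an empty first line. Also returns the index of the first column of each
    group and the number of columns each group spans.
    """
    labels, starts, sizes = [], [], []
    for group, run in groupby(columns, key=lambda c: c[0]):
        values = [value for _, value in run]
        middle = (len(values) - 1) // 2
        starts.append(len(labels))
        sizes.append(len(values))
        for i, value in enumerate(values):
            head = str(group) if i == middle else ""
            labels.append(f"{head}\n{fmt.format(value)}")
    return labels, starts, sizes


# Assumed column layout: four models with three rate_4 values each, as in the answer.
columns = [(f"model {m}", r) for m in range(1, 5) for r in (0.0, 0.1, 0.2)]
labels, starts, sizes = grouped_strip_labels(columns)

# Applying the result to the drawn figure, assuming fig._themeable is available
# as in the answer (a private attribute, so this is untested across versions):
# for text, label in zip(fig._themeable['strip_text_x'], labels):
#     text.set_text(label)
# for start, size in zip(starts, sizes):
#     fig._themeable['strip_background_x'][start].set_width(float(size))
```

Widening only the first background of each group mirrors what the answer does by hand with `set_width(3.0)`; the column order passed in has to match the order in which the strips were actually drawn.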

Jannik