The following code:

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

data_dict = {'Best fit': [395.0, 401.0, 358.0, 443.0, 357.0, 378.0, 356.0, 356.0, 403.0, 380.0, 397.0, 406.0, 409.0, 414.0, 350.0, 433.0, 345.0, 376.0, 374.0, 379.0, 9.0, 13.0, 10.0, 13.0, 16.0, 12.0, 6.0, 11.0, 20.0, 10.0, 12.0, 11.0, 15.0, 11.0, 11.0, 11.0, 15.0, 10.0, 8.0, 18.0, 864.0, 803.0, 849.0, 858.0, 815.0, 856.0, 927.0, 878.0, 834.0, 837.0, 811.0, 857.0, 848.0, 869.0, 861.0, 820.0, 887.0, 842.0, 834.0, np.nan], 'MDP': [332, 321, 304, 377, 304, 313, 289, 314, 341, 321, 348, 334, 361, 348, 292, 362, 285, 316, 291, 318, 3, 6, 5, 5, 4, 5, 4, 3, 8, 6, 4, 0, 8, 1, 4, 0, 9, 5, 3, 8, 770, 770, 819, 751, 822, 842, 758, 825, 886, 830, 774, 839, 779, 821, 812, 850, 822, 786, 874, 831], 'Q-Learning': [358, 329, 309, 381, 302, 319, 296, 315, 343, 318, 338, 336, 360, 357, 299, 363, 287, 337, 301, 334, 3, 6, 5, 5, 4, 5, 4, 3, 8, 6, 4, 0, 8, 1, 4, 0, 9, 5, 3, 8, 771, 833, 757, 837, 831, 784, 806, 890, 843, 775, 838, 776, 824, 830, 834, 827, 791, 868, 816, 806], 'parametrized_factor': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2]}
data = pd.DataFrame(data_dict)

# figure size
plt.figure(figsize=(12, 8))

# melt the dataframe into a long form
dfm = data.melt(id_vars='parametrized_factor')

# plot
ax = sns.boxplot(data=dfm, x='variable', y='value', hue='parametrized_factor', linewidth=0.7, palette="Set3")

ax.yaxis.grid(True)  # Show the horizontal gridlines
ax.xaxis.grid(True)  # Show the vertical gridlines
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

# Set the y-axis label and clear the x-axis label.
ax.set_ylabel('Rejection ratio')
ax.set_xlabel('')

plt.show()

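This produces a grouped box plot: one group per method ('Best fit', 'MDP', 'Q-Learning') and, within each group, one box per parametrized_factor value.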

Is there a way to connect, for example, 'Best fit', 'MDP' and 'Q-Learning' for each legend category?

In other words, how can box plots of the same color be connected by a line through their mean values?

Trenton McKinney
krm76

2 Answers

  • Calculate the mean for each group, then add the means to the existing ax with seaborn.lineplot.
  • Set dodge=False in seaborn.boxplot so the boxes for every hue level share the same x position.
  • Remember that the line inside each box is the median, not the mean.
    • To show the means on the boxes, pass showmeans=True to the boxplot; marker='o' can then be dropped from the lineplot, if desired.
  • As pointed out in JohanC's answer:
    • sns.pointplot(data=dfm, x='variable', y='value', hue='parametrized_factor', ax=ax) can be used without calculating dfm_mean. However, it has no legend=False parameter, so the legend must be managed manually.
    • Also, I think it's more straightforward to use dodge=False than to calculate the offsets.
    • Either answer is viable, depending on your requirements.

# calculate the mean for each group and convert to long format with melt
dfm_mean = data.groupby('parametrized_factor', as_index=False).mean().melt(id_vars='parametrized_factor')

# plot
# figure size
plt.figure(figsize=(12, 8))

# create the boxplot but set dodge to false, so all plots are on the same x-axis line
ax = sns.boxplot(data=dfm, x='variable', y='value', hue='parametrized_factor', linewidth=0.7, palette="Set3", dodge=False)

# plot a line plot with markers for the means
sns.lineplot(data=dfm_mean, x='variable', y='value', hue='parametrized_factor', marker='o', ax=ax, legend=False)

# set the legend outside
ax.legend(title='Factor', bbox_to_anchor=(1.05, 1), loc='upper left')

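  • If dodge isn't False, the boxes are offset within each group while the mean lines stay centered on each category, so the lines no longer pass through their own boxes.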
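To illustrate the note above that the line inside each box is the median rather than the mean, here is a small check on one group from the question's data ('Best fit' at parametrized_factor 0.2). It only uses the standard library; the variable names are made up for this sketch.

```python
from statistics import mean, median

# 'Best fit' values for parametrized_factor == 0.2, copied from data_dict
best_fit_02 = [9.0, 13.0, 10.0, 13.0, 16.0, 12.0, 6.0, 11.0, 20.0, 10.0,
               12.0, 11.0, 15.0, 11.0, 11.0, 11.0, 15.0, 10.0, 8.0, 18.0]

group_mean = mean(best_fit_02)      # what the lineplot connects
group_median = median(best_fit_02)  # what the line inside the box shows

print(f"mean={group_mean:.2f}, median={group_median:.2f}")
```

For this group the mean is about 12.1 while the median is 11, so a line through the means will not necessarily cross the box at its median line.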

Trenton McKinney

You can create a point plot and recalculate the dodge width. In the box plot, the 3 boxes are spread evenly over the default total width of 0.8, so the outermost box centers are 0.8 - 0.8 / 3 apart. A pointplot places its outermost lines at the limits of the dodge width, so dodge has to be scaled to 0.8 - 0.8 / 3 for the points to line up with the box centers. See this GitHub issue for extra information.

Note that you don't need to calculate the means, as the mean is the default estimator for pointplot. Error bars for the mean can be suppressed with ci=None (newer seaborn versions use errorbar=None instead).
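The dodge arithmetic can be checked with a few lines of plain Python. This is only a sketch of the reasoning above, assuming the boxes are spread evenly over a total width of 0.8; the helper names box_center_offsets and pointplot_dodge are hypothetical and not part of seaborn.

```python
def box_center_offsets(n_hue, width=0.8):
    """Offsets of the box centers from the category position,
    assuming n_hue boxes evenly share a total width of `width`."""
    box_width = width / n_hue
    return [-width / 2 + box_width * (i + 0.5) for i in range(n_hue)]


def pointplot_dodge(n_hue, width=0.8):
    """Dodge value whose extremes coincide with the outermost box centers."""
    return width - width / n_hue


offsets = box_center_offsets(3)
dodge = pointplot_dodge(3)

# The distance between the first and last box center equals the dodge value.
print(offsets)
print(dodge, offsets[-1] - offsets[0])
```

With a different number of hue levels, the same formula gives the matching dodge, e.g. pointplot_dodge(2) is 0.4.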

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

data_dict = {'Best fit': [395.0, 401.0, 358.0, 443.0, 357.0, 378.0, 356.0, 356.0, 403.0, 380.0, 397.0, 406.0, 409.0, 414.0, 350.0, 433.0, 345.0, 376.0, 374.0, 379.0, 9.0, 13.0, 10.0, 13.0, 16.0, 12.0, 6.0, 11.0, 20.0, 10.0, 12.0, 11.0, 15.0, 11.0, 11.0, 11.0, 15.0, 10.0, 8.0, 18.0, 864.0, 803.0, 849.0, 858.0, 815.0, 856.0, 927.0, 878.0, 834.0, 837.0, 811.0, 857.0, 848.0, 869.0, 861.0, 820.0, 887.0, 842.0, 834.0, np.nan], 'MDP': [332, 321, 304, 377, 304, 313, 289, 314, 341, 321, 348, 334, 361, 348, 292, 362, 285, 316, 291, 318, 3, 6, 5, 5, 4, 5, 4, 3, 8, 6, 4, 0, 8, 1, 4, 0, 9, 5, 3, 8, 770, 770, 819, 751, 822, 842, 758, 825, 886, 830, 774, 839, 779, 821, 812, 850, 822, 786, 874, 831], 'Q-Learning': [358, 329, 309, 381, 302, 319, 296, 315, 343, 318, 338, 336, 360, 357, 299, 363, 287, 337, 301, 334, 3, 6, 5, 5, 4, 5, 4, 3, 8, 6, 4, 0, 8, 1, 4, 0, 9, 5, 3, 8, 771, 833, 757, 837, 831, 784, 806, 890, 843, 775, 838, 776, 824, 830, 834, 827, 791, 868, 816, 806], 'parametrized_factor': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2]}
data = pd.DataFrame(data_dict)

sns.set_style('darkgrid')
plt.figure(figsize=(12, 8))

dfm = data.melt(id_vars='parametrized_factor')

ax = sns.boxplot(data=dfm, x='variable', y='value', hue='parametrized_factor', linewidth=0.7, palette="Set3")
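# pointplot computes the means itself; dodge is scaled so the points sit on the box centers
sns.pointplot(data=dfm, x='variable', y='value', hue='parametrized_factor', ci=None,
              dodge=.8 - .8 / 3, scale=0.3, color='black', markers='D', ax=ax)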

handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles[:4], labels=labels[:3] + ["means"], title="parametrized factor",
          bbox_to_anchor=(1.02, 1.02), loc='upper left')

ax.set_ylabel('Rejection ratio')
ax.set_xlabel('')
plt.tight_layout()
plt.show()

Figure: pointplot connecting the means of the box plots.

JohanC