
I have made a pairplot of the Iris dataset using the code below, but the regression lines are drawn separately for each of the three iris varieties rather than for the dataset as a whole, which is what I am looking for. I want each panel of the pairplot to color the data points by the Iris variety they came from, but I want a single regression line fitted to the whole sample rather than three separate regression lines in each panel. Is this possible?

(df is my DataFrame created from a downloaded .csv file for the dataset.)

import matplotlib.pyplot as plt
import seaborn as sns

iris_pairplot = sns.pairplot(df, hue="variety", palette="Dark2", height=3, aspect=1, corner=True, kind="reg")
iris_pairplot.fig.suptitle("Pairplot of traits for full Iris sample", fontsize="xx-large")
plt.tight_layout()
plt.savefig('iris_pairplot.png')

[Image: Iris pairplot]

I have tried to use sns.regplot(), but that seems to be meant for individual scatterplots rather than a pairplot?

Trenton McKinney
jmce23

1 Answer


pairplot cannot do this on its own: when hue is set, kind='reg' fits a separate regression for each hue level. Instead, create the pairplot first without the kind='reg' option, so it draws only the colored scatterplots, and then add the regression lines yourself.

Then use map_offdiag() to apply a plotting function to every subplot except the diagonal ones (with corner=True, that is the lower triangle). In that function, draw one regression line with sns.regplot() fitted to the full DataFrame instead of the per-hue subset. Note: sns.load_dataset('iris') names the column species, not variety, so you may need to rename your column (or keep hue="variety"). Hope this is what you are looking for...
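If your CSV calls the flower-type column variety, the rename is a one-liner with pandas. A minimal sketch, using a tiny hypothetical DataFrame as a stand-in for the CSV:

```python
import pandas as pd

# Hypothetical stand-in for a CSV whose flower-type column is called "variety"
df = pd.DataFrame({
    "sepal_length": [5.1, 7.0, 6.3],
    "variety": ["Setosa", "Versicolor", "Virginica"],
})

# Rename the column so that hue="species" finds it
df = df.rename(columns={"variety": "species"})
print(df.columns.tolist())  # ['sepal_length', 'species']
```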

import matplotlib.pyplot as plt
import seaborn as sns

df = sns.load_dataset('iris')
## Note: kind="reg" removed, so pairplot draws plain scatterplots
iris_pairplot = sns.pairplot(df, hue="species", palette="Dark2", height=3, aspect=1, corner=True)
iris_pairplot.fig.suptitle("Pairplot of traits for full Iris sample", fontsize="xx-large")

## Define a function that plots a single regression line fitted to the full DataFrame
def regline(x, y, **kwargs):
    # x and y arrive as Series; their names select the same columns from the full data
    sns.regplot(data=kwargs['data'], x=x.name, y=y.name, scatter=False, color=kwargs['color'])

## Call the function for each non-diagonal subplot within the pairplot
iris_pairplot.map_offdiag(regline, color='red', data=df)

plt.tight_layout()
plt.show()


Redox