I am trying to find an equivalent of hist()'s figsize and layout parameters for sns.pairplot().

I have a pairplot that gives me nice scatterplots of each X against y. However, all the plots are laid out in a single horizontal row, and as far as I know there is no layout parameter to wrap them onto several rows. 4 plots per row would be great.

This is my current sns.pairplot():

sns.pairplot(X_train,
  x_vars = X_train.select_dtypes(exclude=['object']).columns,
  y_vars = ["SalePrice"])


This is what I would like it to look like: Source

num_mask = train_df.dtypes != object
num_cols = train_df.loc[:, num_mask]
num_cols.hist(figsize=(30, 15), layout=(4, 10))
plt.show()


Katsu
  • What is `hist()`? – mwaskom Feb 09 '23 at 11:44
  • @mwaskom: Probably pandas' [`DataFrame.hist`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.hist.html) is meant here. Avoiding the need for `melt` in seaborn's `FacetGrid` functions seems quite complex given the many options available, and would create more confusion than it would solve. – JohanC Feb 09 '23 at 12:54
  • Yep, pandas' `DataFrame.hist`. – Katsu Feb 09 '23 at 19:18

1 Answer

sns.pairplot doesn't currently support this layout, but one of seaborn's other figure-level functions (sns.displot, sns.catplot, ...) can do it. For a grid of scatter plots, use sns.lmplot. These functions need the dataframe in "long form": one row per observation, with one column naming the variable and another holding its value.

Here is a simple example. sns.lmplot has parameters to leave out the regression line (fit_reg=False), to set the height of each subplot (height=...), and to set its aspect ratio (aspect=..., so that each subplot is height times aspect wide), among many others. col_wrap=4 puts four plots in each row. The example turns off the shared y axis because the test columns have different ranges; if all y ranges are similar, you can keep the default sharey=True.

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# create some test data with different y-ranges
np.random.seed(20230209)
X_train = pd.DataFrame({"".join(np.random.choice([*'uvwxyz'], np.random.randint(3, 8))):
                            np.random.randn(100).cumsum() + np.random.randint(100, 1000) for _ in range(10)})
X_train['SalePrice'] = np.random.randint(10000, 100000, 100)

# convert the dataframe to long form
# 'SalePrice' will get excluded automatically via `melt`
compare_columns = X_train.select_dtypes(exclude=['object']).columns
long_df = X_train.melt(id_vars='SalePrice', value_vars=compare_columns)

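# create a grid of scatter plots; sharey is passed via facet_kws because
# newer seaborn versions deprecate it as a direct lmplot argument
g = sns.lmplot(data=long_df, x='SalePrice', y='value', col='variable', col_wrap=4,
               facet_kws={'sharey': False})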
g.set(ylabel='')
plt.show()

sns.lmplot for a grid of scatter plots
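
The question's pairplot had the features on the x axis and SalePrice on the y axis. A minimal sketch of that orientation, using the same melt-then-facet approach, could look like the following. The `demo` dataframe and its column names `a` to `e` are made up for illustration, and the `height` and `aspect` values are arbitrary choices that play roughly the role `figsize` plays for `DataFrame.hist`.

```python
import numpy as np
import pandas as pd
import seaborn as sns

# hypothetical stand-in for X_train: five numeric features plus the target
rng = np.random.default_rng(0)
demo = pd.DataFrame(rng.normal(size=(50, 5)), columns=list("abcde"))
demo["SalePrice"] = rng.integers(10_000, 100_000, size=50)

# long form: one row per (SalePrice, feature name, feature value)
demo_long = demo.melt(id_vars="SalePrice", value_vars=list("abcde"))

g = sns.lmplot(
    data=demo_long, x="value", y="SalePrice",
    col="variable", col_wrap=4,    # start a new row after every 4 plots
    height=2.5, aspect=1.2,        # each subplot: 2.5 in tall, 2.5 * 1.2 in wide
    fit_reg=False,                 # scatter only, no regression line
    facet_kws={"sharex": False},   # every feature gets its own x range
)
g.set(xlabel="")
```

With these settings the whole figure comes out at roughly 4 * 2.5 * 1.2 inches wide and 2 * 2.5 inches tall, so changing `height`, `aspect`, and `col_wrap` is how you control the overall size. Call `plt.show()` or `g.savefig(...)` to view the result.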

Here is another example, with histograms of the mpg dataset:

import matplotlib.pyplot as plt
import seaborn as sns

mpg = sns.load_dataset('mpg')

compare_columns = mpg.select_dtypes(exclude=['object']).columns
mpg_long = mpg.melt(value_vars=compare_columns)
g = sns.displot(data=mpg_long, kde=True, x='value', common_bins=False, col='variable', col_wrap=4, color='crimson',
                facet_kws={'sharex': False, 'sharey': False})
g.set(xlabel='')
plt.show()

sns.displot for a list of numeric columns

JohanC
  • Awesome it looks like it worked for me! One follow up, is there a way to make the font bigger? It is really small and `g.set(font_scale = 2)` does not work on my end. – Katsu Feb 09 '23 at 19:09
  • You can set the scale before creating the plot by calling `sns.set(font_scale=2)` – JohanC Feb 09 '23 at 19:22