Tags: matplotlib, seaborn, histogram, scatter-plot, pairplot

Equivalent of hist()'s layout parameter in sns.pairplot()?


I am trying to find an equivalent of hist()'s figsize and layout parameters for sns.pairplot().

I have a pairplot that gives me nice scatter plots of each X column against y. However, all the plots are laid out in a single horizontal row, and as far as I know there is no equivalent of the layout parameter to wrap them onto several rows. Four plots per row would be great.

This is my current sns.pairplot():

sns.pairplot(X_train,
  x_vars = X_train.select_dtypes(exclude=['object']).columns,
  y_vars = ["SalePrice"])


This is what I would like it to look like (source):

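import matplotlib.pyplot as plt

# train_df is the training DataFrame; keep only its non-object (numeric) columns
num_cols = train_df.select_dtypes(exclude=['object'])
# layout=(rows, columns): 4 rows of 10 histograms each
num_cols.hist(figsize=(30, 15), layout=(4, 10))
plt.show()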



Solution

  • What you want to achieve isn't currently supported by sns.pairplot, but you can use one of the other figure-level functions (sns.displot, sns.catplot, ...). sns.lmplot creates a grid of scatter plots. For this to work, the dataframe needs to be in "long form".
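
To make "long form" concrete, here is a minimal sketch with a tiny made-up frame. The column names LotArea and YearBuilt are hypothetical stand-ins, not columns taken from the question's data. melt stacks the feature columns into a variable/value pair and repeats SalePrice once per feature.

```python
import pandas as pd

# hypothetical wide-form data: one row per house, one column per feature
wide = pd.DataFrame({
    "SalePrice": [200, 300],
    "LotArea": [8000, 9500],
    "YearBuilt": [1995, 2003],
})

# long form: one row per (house, feature) pair
long_df = wide.melt(id_vars="SalePrice", value_vars=["LotArea", "YearBuilt"])

# columns are now SalePrice, variable (the feature name) and value
print(list(long_df.columns))
# 2 houses x 2 features give 4 rows
print(long_df.shape)
```

The variable column is what the figure-level functions use to split the data into one subplot per feature.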

    Here is a simple example. sns.lmplot has parameters to leave out the regression line (fit_reg=False), to set the height of the individual subplots (height=...), to set its aspect ratio (aspect=..., where the subplot width will be height times aspect ratio), and many more. If all y ranges are similar, you can use the default sharey=True.

    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    import numpy as np
    
    # create some test data with different y-ranges
    np.random.seed(20230209)
    X_train = pd.DataFrame({"".join(np.random.choice([*'uvwxyz'], np.random.randint(3, 8))):
                                np.random.randn(100).cumsum() + np.random.randint(100, 1000) for _ in range(10)})
    X_train['SalePrice'] = np.random.randint(10000, 100000, 100)
    
    # convert the dataframe to long form
    # 'SalePrice' will get excluded automatically via `melt`
    compare_columns = X_train.select_dtypes(exclude=['object']).columns
    long_df = X_train.melt(id_vars='SalePrice', value_vars=compare_columns)
    
    # create a grid of scatter plots
    g = sns.lmplot(data=long_df, x='SalePrice', y='value', col='variable', col_wrap=4, sharey=False)
    g.set(ylabel='')
    plt.show()
    

    [Figure: sns.lmplot for a grid of scatter plots]
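
As a sketch of how the parameters described above combine, the following variant puts each feature on the x-axis and SalePrice on the y-axis, matching the orientation in the question. The demo frame and its column names are made up for illustration. Passing sharex through facet_kws assumes a seaborn version whose lmplot accepts that argument (0.11.2 or later, as far as I know).

```python
import numpy as np
import pandas as pd
import seaborn as sns

# hypothetical stand-in for X_train: five numeric features plus SalePrice
rng = np.random.default_rng(0)
demo = pd.DataFrame({f"feature_{i}": rng.normal(100 * (i + 1), 10, 50) for i in range(5)})
demo["SalePrice"] = rng.integers(10000, 100000, 50)

# numeric feature columns, with the target explicitly left out
features = [c for c in demo.select_dtypes(include="number").columns if c != "SalePrice"]
long_df = demo.melt(id_vars="SalePrice", value_vars=features)

g = sns.lmplot(
    data=long_df, x="value", y="SalePrice",
    col="variable", col_wrap=4,    # four plots per row
    fit_reg=False,                 # scatter only, no regression line
    height=2.5, aspect=1.2,        # each subplot is 2.5 high and 2.5 * 1.2 wide (inches)
    facet_kws={"sharex": False},   # every feature keeps its own x range
)
g.set(xlabel="")
g.set_titles("{col_name}")         # show only the feature name as the subplot title
```

Call plt.show() afterwards (with matplotlib.pyplot imported as plt) to display the figure. The y-axis stays shared here because every panel plots the same SalePrice values.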

    Here is another example, with histograms of the mpg dataset:

    import matplotlib.pyplot as plt
    import seaborn as sns
    
    mpg = sns.load_dataset('mpg')
    
    compare_columns = mpg.select_dtypes(exclude=['object']).columns
    mpg_long = mpg.melt(value_vars=compare_columns)
    g = sns.displot(data=mpg_long, kde=True, x='value', common_bins=False, col='variable', col_wrap=4, color='crimson',
                    facet_kws={'sharex': False, 'sharey': False})
    g.set(xlabel='')
    plt.show()
    

    [Figure: sns.displot for a list of numeric columns]
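
If exact control over rows and columns is needed, similar to hist()'s layout=(rows, columns), one alternative that is not part of the answer above is to build the grid by hand with plt.subplots and draw each panel with the axes-level sns.scatterplot. This is only a sketch: the helper name manual_scatter_grid, its parameters, and the demo data are all hypothetical.

```python
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


def manual_scatter_grid(df, target, ncols=4, panel_size=(3, 2.5)):
    """Draw one scatter plot per numeric feature against `target`, `ncols` per row."""
    features = [c for c in df.select_dtypes(include="number").columns if c != target]
    nrows = max(1, math.ceil(len(features) / ncols))
    fig, axes = plt.subplots(
        nrows, ncols, squeeze=False,  # squeeze=False keeps `axes` two-dimensional
        figsize=(panel_size[0] * ncols, panel_size[1] * nrows),
    )
    for ax, col in zip(axes.flat, features):
        sns.scatterplot(data=df, x=col, y=target, ax=ax)
    # hide the axes left over in the last row
    for ax in axes.flat[len(features):]:
        ax.set_visible(False)
    fig.tight_layout()
    return fig, axes


# made-up data: six numeric features plus SalePrice
rng = np.random.default_rng(0)
demo = pd.DataFrame({f"feature_{i}": rng.normal(size=40) for i in range(6)})
demo["SalePrice"] = rng.integers(10000, 100000, 40)

fig, axes = manual_scatter_grid(demo, "SalePrice", ncols=4)
```

Unlike the figure-level functions, this works on the wide dataframe directly, at the cost of handling the layout and empty panels yourself. Call plt.show() to display the result.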