Search code examples
python pandas seaborn displot histplot

How can I wrap subplot columns


I've been struggling to wrap subplot columns in Seaborn histogram plots (kdeplot, histplot). I've tried various things, including fig, ax and enumerate(zip(df.columns, ax.flatten())).

Here's the dataset

 # Draw one small, standalone figure per column of df, coloured by 'Dataset'.
 # Every figure is created and shown on its own, so nothing here arranges the
 # plots into a wrapped grid.
 for column in df.columns:
     plt.figure(figsize=(3, 3))
     sns.histplot(data=df, x=column, hue='Dataset', bins=40, kde=True, fill=True)
     plt.show()

How can the plots be done with other seaborn plots or plots with facet wrap functionality?


Solution

  • import pandas as pd
    import seaborn as sns
    
    # load the dataset downloaded from https://www.kaggle.com/uciml/indian-liver-patient-records
    df = pd.read_csv('d:/data/kaggle/indian_liver_patient.csv')

    # convert the data to a long form: 'Gender' and 'Dataset' are kept as
    # identifiers, every other column is stacked into 'variable' / 'value'
    dfm = df.melt(id_vars=['Gender', 'Dataset'])

    # plot the data for each gender: one wrapped grid of histograms per gender
    for gender, data in dfm.groupby('Gender'):

        # one facet per variable, three facets per row; the axes are not shared
        # because each variable has its own scale
        g = sns.displot(kind='hist', data=data, x='value', col='variable', hue='Dataset',
                        hue_order=[1, 2], common_norm=False, common_bins=False,
                        multiple='dodge', kde=True, col_wrap=3, height=2.5, aspect=2,
                        facet_kws={'sharey': False, 'sharex': False}, palette='tab10')

        # `FacetGrid.figure` is the supported accessor; `FacetGrid.fig` is a
        # deprecated alias for the same matplotlib Figure
        fig = g.figure

        # y > 1 places the title above the top row of facets
        fig.suptitle(f'Gender: {gender}', y=1.02)

        # bbox_inches='tight' keeps the raised title inside the saved image
        fig.savefig(f'hist_{gender}.png', bbox_inches='tight')
    
    • The only problem with this option is that common_bins=False means the bins of the two hue groups don't align. However, setting it to True causes sharex=False to be ignored, so every x-axis will span 0 - 2000, as can be seen in this plot.

    enter image description here

    enter image description here


    • The plot generated by the following code has too many columns
      • col_wrap can't be used if row is also in use.
    # Facet by 'Dataset' (rows) and 'variable' (columns), coloured by 'Gender'.
    # Every variable gets its own column because col_wrap cannot be combined
    # with `row`.
    g = sns.displot(kind='hist', data=dfm, x='value', row='Dataset', col='variable', hue='Gender',
                    common_norm=False, common_bins=False, multiple='dodge', kde=True,
                    facet_kws={'sharey': False, 'sharex': False})

    # `FacetGrid.figure` is the supported accessor; `FacetGrid.fig` is a
    # deprecated alias for the same matplotlib Figure
    g.figure.savefig('hist.png')
    
    • The following plot does not separate the data by 'Gender'.
    # One facet per variable, wrapped three to a row and coloured by 'Dataset';
    # 'Gender' is not used anywhere in this call.
    g = sns.displot(
        data=dfm,
        kind='hist',
        x='value',
        hue='Dataset',
        col='variable',
        col_wrap=3,
        multiple='dodge',
        kde=True,
        common_norm=False,
        common_bins=False,
        height=2.5,
        aspect=2,
        palette='tab10',
        facet_kws={'sharex': False, 'sharey': False},
    )
    

    • The following option correctly allows common_bins=True to be used.
    import matplotlib.pyplot as plt  # required: plt.subplots is used below
    import numpy as np
    import pandas as pd
    import seaborn as sns

    # load the dataset
    df = pd.read_csv('d:/data/kaggle/indian_liver_patient.csv')

    # convert the data to a long form
    dfm = df.melt(id_vars=['Gender', 'Dataset'])

    # position (in the flattened axes array) of the subplot whose legend is kept
    legend_idx = 5

    # iterate through the data for each gender
    for gen, data in dfm.groupby('Gender'):

        # create the figure and axes
        fig, axes = plt.subplots(3, 3, figsize=(11, 5), sharex=False, sharey=False, tight_layout=True)

        # flatten the array of axes
        axes = axes.flatten()

        # iterate through each axes and variable category
        for ax, (var, sel) in zip(axes, data.groupby('variable')):

            # each call only sees one variable, so common_bins=True aligns the
            # bins of the two hue groups without forcing a shared x-axis
            sns.histplot(data=sel, x='value', hue='Dataset', hue_order=[1, 2], kde=True, ax=ax,
                         common_norm=False, common_bins=True, multiple='dodge', palette='tab10')

            ax.set(xlabel='', title=var.replace('_', ' ').title())
            ax.spines[['top', 'right']].set_visible(False)

        # remove every per-axes legend except the one on axes[legend_idx]
        # (Aspartate Aminotransferase), which is moved outside its axes and
        # used as the legend for the whole figure
        for idx, ax in enumerate(axes):
            if idx == legend_idx:
                continue
            legend = ax.get_legend()
            # an axes that received no plot has no legend to remove
            if legend is not None:
                legend.remove()

        sns.move_legend(axes[legend_idx], bbox_to_anchor=(1, 0.5), loc='center left', frameon=False)

        fig.suptitle(f'Gender: {gen}', y=1.02)

        fig.savefig(f'hist_{gen}.png', bbox_inches='tight')
    

    enter image description here

    enter image description here


    • Some columns in df have significant outliers. Removing them will improve the histogram visualization.
    from scipy.stats import zscore
    from typing import Literal
    
    
    def remove_outliers(data: pd.DataFrame, method: Literal['std', 'z'] = 'std') -> pd.DataFrame:
        """Return the rows of ``data`` whose ``value`` is not an outlier.

        Parameters
        ----------
        data
            Frame with a numeric ``value`` column.
        method
            ``'std'`` keeps rows whose value lies within three sample standard
            deviations of the mean (bounds inclusive). Any other value, normally
            ``'z'``, keeps rows whose absolute z-score is below 3.

        Returns
        -------
        pd.DataFrame
            The kept rows of ``data``; rows whose ``value`` is NaN are dropped
            by both methods.
        """
        values = data['value']
        if method == 'std':
            mean = values.mean()
            spread = values.std() * 3
            keep = values.between(mean - spread, mean + spread)
        else:
            # nan_policy='omit' computes the mean/std from the non-NaN values.
            # With the default ('propagate') one NaN turns every z-score into
            # NaN and the whole frame would be discarded.
            keep = np.abs(zscore(values.to_numpy(), nan_policy='omit')) < 3
        return data[keep]
    
    
    # columns of df whose values are filtered for outliers; built once instead
    # of re-slicing df.columns for every subplot
    outlier_columns = set(df.columns[2:7])

    # iterate through the data for each gender
    for gen, data in dfm.groupby('Gender'):

        ...  # create the figure and the flattened axes as in the previous example

        # iterate through each axes and variable category
        for ax, (var, sel) in zip(axes, data.groupby('variable')):

            # remove outliers of specified columns
            if var in outlier_columns:
                sel = remove_outliers(sel)

            sns.histplot(data=sel, x='value', hue='Dataset', hue_order=[1, 2], kde=True, ax=ax,
                         common_norm=False, common_bins=True, multiple='dodge', palette='tab10')

            ...  # remaining per-axes formatting as in the previous example

        ...  # legend handling, figure title and saving as in the previous example
    

    enter image description here

    enter image description here