Search code examples
Tags: python, pandas, seaborn, histogram, facet-grid

How to set multiple histograms in a FacetGrid


I have a dataframe with 95 columns containing the Max(), Min() and Avg() values of different measures. I want to plot their histograms on a FacetGrid of 3 columns and 32 rows, where the 1st column is the max value, the 2nd the avg value and the 3rd the min value, and each row is a measure type.

I have this code right now:

# Asker's original attempt, kept exactly as posted.
# NOTE(review): two defects explain why plots end up in the wrong place:
#   1. The loop iterates `columns`, which this snippet never defines;
#      presumably `columnas_numeric` (built just below) was meant.
#   2. The row index int(i/3) assumes every measure contributes exactly three
#      consecutive columns. Measures such as Var2 and Var4 have only Avg and
#      Max columns, so from the first such measure onward the plots drift into
#      other rows (and can be drawn onto an axes that already holds a plot).
# 32 x 3 grid; nrows is hard-coded rather than derived from the columns.
fig, axes = plt.subplots(nrows=32, ncols=3, figsize=(20, 96))
# Numeric columns only, excluding 'season'.
columnas_numeric = df_agg_new.select_dtypes(include=['float64', 'int64']).columns
columnas_numeric = columnas_numeric.drop('season')

for i, colum in enumerate(columns):
    if colum.endswith('Max'):
        # Max statistics are drawn in the first grid column.
        sns.histplot(
            data     = df_agg_new,
            x        = colum,
            stat     = "count",
            kde      = True,
            line_kws = {'linewidth': 2},
            alpha    = 0.3,
            ax       = axes[int(i/3)][0]
        )
        axes[int(i/3)][0].set_title(colum, fontsize = 7, fontweight = "bold")
        axes[int(i/3)][0].tick_params(labelsize = 6)
        axes[int(i/3)][0].set_xlabel("")
    elif colum.endswith('Avg'):
        # Avg statistics are drawn in the second grid column.
        sns.histplot(
            data     = df_agg_new,
            x        = colum,
            stat     = "count",
            kde      = True,
            line_kws = {'linewidth': 2},
            alpha    = 0.3,
            ax       = axes[int(i/3)][1]
        )
        axes[int(i/3)][1].set_title(colum, fontsize = 7, fontweight = "bold")
        axes[int(i/3)][1].tick_params(labelsize = 6)
        axes[int(i/3)][1].set_xlabel("")
    else:
        # Every other column (intended: the Min columns) goes to the third grid column.
        sns.histplot(
            data     = df_agg_new,
            x        = colum,
            stat     = "count",
            kde      = True,
            line_kws = {'linewidth': 2},
            alpha    = 0.3,
            ax       = axes[int(i/3)][2]
        )
        axes[int(i/3)][2].set_title(colum, fontsize = 7, fontweight = "bold")
        axes[int(i/3)][2].tick_params(labelsize = 6)
        axes[int(i/3)][2].set_xlabel("")
    
    
fig.tight_layout()
plt.subplots_adjust(top = 0.97)
fig.suptitle('Distribution plots', fontsize = 10, fontweight = "bold");

But it doesn't work, because some measure values end up in the wrong rows.

This is my list of columns:

Index(['Var1_phase2_Avg', 'Var1_phase2_Max',
       'Var1_phase2_Min', 'Var1_phase3_Avg',
       'Var1_phase3_Max', 'Var1_phase3_Min',
       'Var1_phase1_Avg', 'Var1_phase1_Max',
       'Var1_phase1_Min', 'Var1_phase4_Avg',
       'Var1_phase4_Max', 'Var1_phase4_Min',
       'Var2_phase2_Avg', 'Var2_phase2_Max',
       'Var2_phase3_Avg', 'Var2_phase3_Max',
       'Var2_phase1_Avg', 'Var2_phase1_Max',
       'Var2_phase4_Avg', 'Var2_phase4_Max',
       'Var3_phase2_Avg', 'Var3_phase2_Max',
       'Var3_phase2_Min', 'Var3_phase3_Avg',
       'Var3_phase3_Max', 'Var3_phase3_Min',
       'Var3_phase1_Avg', 'Var3_phase1_Max',
       'Var3_phase1_Min', 'Var3_phase4_Avg',
       'Var3_phase4_Max', 'Var3_phase4_Min', 
       'Var4_phase2_Avg', 'Var4_phase2_Max', 
       'Var4_phase3_Avg', 'Var4_phase3_Max', 
       'Var4_phase1_Avg', 'Var4_phase1_Max', 
       'Var4_phase4_Avg', 'Var4_phase4_Max', 
       'Var5_phase2_Avg', 'Var5_phase2_Max',
       'Var5_phase2_Min', 'Var5_phase3_Avg', 
       'Var5_phase3_Max', 'Var5_phase3_Min', 
       'Var5_phase1_Avg', 'Var5_phase1_Max', 
       'Var5_phase1_Min', 'Var5_phase4_Avg', 
       'Var5_phase4_Max', 'Var5_phase4_Min', 
       'Var6_phase2_Avg', 'Var6_phase2_Max', 
       'Var6_phase2_Min', 'Var6_phase3_Avg', 
       'Var6_phase3_Max', 'Var6_phase1_Avg', 
       'Var6_phase1_Max', 'Var7_phase2_Avg', 
       'Var7_phase2_Max', 'Var7_phase2_Min',
       'Var7_phase3_Avg', 'Var7_phase3_Max',
       'Var7_phase3_Min', 'Var7_phase1_Avg', 
       'Var7_phase1_Max', 'Var7_phase1_Min', 
       'Var7_phase4_Avg', 'Var7_phase4_Max', 
       'Var7_phase4_Min', 'Var8_phase2_Avg', 
       'Var8_phase2_Max', 'Var8_phase2_Min', 
       'Var8_phase3_Avg', 'Var8_phase3_Max', 
       'Var8_phase3_Min', 'Var8_phase1_Avg', 
       'Var8_phase1_Max', 'Var8_phase1_Min',
       'Var8_phase4_Avg', 'Var8_phase4_Max', 
       'Var8_phase4_Min'],
      dtype='object')

And this is how I want them laid out (each one drawn as a histplot):

       'Var1_phase2_Avg', 'Var1_phase2_Max', 'Var1_phase2_Min', 
       'Var1_phase3_Avg', 'Var1_phase3_Max', 'Var1_phase3_Min',
       'Var1_phase1_Avg', 'Var1_phase1_Max', 'Var1_phase1_Min',
       'Var1_phase4_Avg', 'Var1_phase4_Max', 'Var1_phase4_Min',
       
       'Var2_phase2_Avg', 'Var2_phase2_Max',
       'Var2_phase3_Avg', 'Var2_phase3_Max',
       'Var2_phase1_Avg', 'Var2_phase1_Max',
       'Var2_phase4_Avg', 'Var2_phase4_Max',

       'Var3_phase2_Avg', 'Var3_phase2_Max', 'Var3_phase2_Min',
       'Var3_phase3_Avg', 'Var3_phase3_Max', 'Var3_phase3_Min',
       'Var3_phase1_Avg', 'Var3_phase1_Max', 'Var3_phase1_Min',
       'Var3_phase4_Avg', 'Var3_phase4_Max', 'Var3_phase4_Min', 
       
       'Var4_phase2_Avg', 'Var4_phase2_Max', 
       'Var4_phase3_Avg', 'Var4_phase3_Max', 
       'Var4_phase1_Avg', 'Var4_phase1_Max', 
       'Var4_phase4_Avg', 'Var4_phase4_Max', 

       'Var5_phase2_Avg', 'Var5_phase2_Max', 'Var5_phase2_Min', 
       'Var5_phase3_Avg', 'Var5_phase3_Max', 'Var5_phase3_Min', 
       'Var5_phase1_Avg', 'Var5_phase1_Max', 'Var5_phase1_Min',
       'Var5_phase4_Avg', 'Var5_phase4_Max', 'Var5_phase4_Min', 
       
       'Var6_phase2_Avg', 'Var6_phase2_Max', 'Var6_phase2_Min',
       'Var6_phase3_Avg', 'Var6_phase3_Max',
       'Var6_phase1_Avg', 'Var6_phase1_Max',

       'Var7_phase2_Avg', 'Var7_phase2_Max', 'Var7_phase2_Min',
       'Var7_phase3_Avg', 'Var7_phase3_Max', 'Var7_phase3_Min',
       'Var7_phase1_Avg', 'Var7_phase1_Max', 'Var7_phase1_Min',
       'Var7_phase4_Avg', 'Var7_phase4_Max', 'Var7_phase4_Min',

       'Var8_phase2_Avg', 'Var8_phase2_Max', 'Var8_phase2_Min', 
       'Var8_phase3_Avg', 'Var8_phase3_Max', 'Var8_phase3_Min',
       'Var8_phase1_Avg', 'Var8_phase1_Max', 'Var8_phase1_Min',
       'Var8_phase4_Avg', 'Var8_phase4_Max', 'Var8_phase4_Min'

Each column is for the Avg, Max or Min values, and each row is for one phase of the day, with the rows grouped by measure.


Solution

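  • Your use of axes[int(i/3)][0] assumes that the list of columns comes in neat groups of 3 (Max, Avg, Min), which is not the case: measures such as Var2 have no Min column, so from that point on every plot is shifted into the wrong row.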

    Given how your columns are named, you could:

    • create a list of the names without the suffix (so e.g. 'Var1_phase2' for 'Var1_phase2_Avg')
    • optionally sort that list
    • use the number of names to calculate nrows=len(filas)
    • loop through the row names, and check whether a column with that name and one of the suffixes exists
    • if the column exists, draw the plot
    • if the column doesn't exist, remove the empty plot

    Optionally, you could share the x and/or y axes of the plots, which makes them easier to compare. These are parameters of plt.subplots(..., sharex=True, sharey=True). This also suppresses the repeated tick labels on the inner axes, which saves some space (but could be unwanted if you have many rows; if so, you can re-enable them via axes[i, j].tick_params(..., labelbottom=True)).

    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd

    # Histogram grid: one row per measure/phase prefix (e.g. 'Var1_phase2') and
    # one column per statistic, in `sufijos` order. A cell whose column does not
    # exist is removed. Such measures leave a gap instead of shifting every
    # later plot into the wrong row.

    # Numeric measure columns; 'season' is numeric but is not a measure to plot.
    columnas_numeric = df_agg_new.select_dtypes(include=['float64', 'int64']).columns
    columnas_numeric = columnas_numeric.drop('season')
    # Sample column list, handy for reproducing the layout without the real data:
    # columnas_numeric = ['Var1_phase2_Avg', 'Var1_phase2_Max', 'Var1_phase2_Min', 'Var1_phase3_Avg', 'Var1_phase3_Max', 'Var1_phase3_Min', 'Var1_phase1_Avg', 'Var1_phase1_Max', 'Var1_phase1_Min', 'Var1_phase4_Avg', 'Var1_phase4_Max', 'Var1_phase4_Min', 'Var2_phase2_Avg', 'Var2_phase2_Max', 'Var2_phase3_Avg', 'Var2_phase3_Max', 'Var2_phase1_Avg', 'Var2_phase1_Max', 'Var2_phase4_Avg', 'Var2_phase4_Max', 'Var3_phase2_Avg', 'Var3_phase2_Max', 'Var3_phase2_Min', 'Var3_phase3_Avg', 'Var3_phase3_Max', 'Var3_phase3_Min', 'Var3_phase1_Avg', 'Var3_phase1_Max', 'Var3_phase1_Min', 'Var3_phase4_Avg', 'Var3_phase4_Max', 'Var3_phase4_Min', 'Var4_phase2_Avg', 'Var4_phase2_Max', 'Var4_phase3_Avg', 'Var4_phase3_Max', 'Var4_phase1_Avg', 'Var4_phase1_Max', 'Var4_phase4_Avg', 'Var4_phase4_Max', 'Var5_phase2_Avg', 'Var5_phase2_Max', 'Var5_phase2_Min', 'Var5_phase3_Avg', 'Var5_phase3_Max', 'Var5_phase3_Min', 'Var5_phase1_Avg', 'Var5_phase1_Max', 'Var5_phase1_Min', 'Var5_phase4_Avg', 'Var5_phase4_Max', 'Var5_phase4_Min', 'Var6_phase2_Avg', 'Var6_phase2_Max', 'Var6_phase2_Min', 'Var6_phase3_Avg', 'Var6_phase3_Max', 'Var6_phase1_Avg', 'Var6_phase1_Max', 'Var7_phase2_Avg', 'Var7_phase2_Max', 'Var7_phase2_Min', 'Var7_phase3_Avg', 'Var7_phase3_Max', 'Var7_phase3_Min', 'Var7_phase1_Avg', 'Var7_phase1_Max', 'Var7_phase1_Min', 'Var7_phase4_Avg', 'Var7_phase4_Max', 'Var7_phase4_Min', 'Var8_phase2_Avg', 'Var8_phase2_Max', 'Var8_phase2_Min', 'Var8_phase3_Avg', 'Var8_phase3_Max', 'Var8_phase3_Min', 'Var8_phase1_Avg', 'Var8_phase1_Max', 'Var8_phase1_Min', 'Var8_phase4_Avg', 'Var8_phase4_Max', 'Var8_phase4_Min']

    sufijos = ['Max', 'Avg', 'Min']  # left-to-right column order of the grid
    alto_fila = 3                    # figure height per grid row, in inches

    # Set of available column names, used for the membership tests below.
    disponibles = set(columnas_numeric)

    # Row names = column names with their '_Max' / '_Avg' / '_Min' suffix removed.
    # Every suffix is considered (not only '_Avg'), so a measure that lacks an
    # Avg column still gets its own row. The set comprehension collapses the
    # duplicates produced by the statistics of one measure.
    filas = sorted({  # optionally sort the rows (plain lexicographic order)
        colum[:-len(sufijo) - 1]
        for colum in disponibles
        for sufijo in sufijos
        if colum.endswith('_' + sufijo)
    })
    if not filas:
        raise ValueError("No column ends in '_Max', '_Avg' or '_Min'; nothing to plot.")

    # squeeze=False keeps `axes` 2-D even when there is only one row, so the
    # nested loops below work for any number of rows. The figure height scales
    # with the number of rows (3 inches each, i.e. 96 inches for 32 rows).
    fig, axes = plt.subplots(
        nrows=len(filas),
        ncols=len(sufijos),
        figsize=(20, alto_fila * len(filas)),
        squeeze=False,
    )

    for ax_row, fila in zip(axes, filas):
        for ax, sufijo in zip(ax_row, sufijos):
            colum = f'{fila}_{sufijo}'
            if colum not in disponibles:
                ax.remove()  # remove empty subplot when there is no data
                continue
            sns.histplot(
                data=df_agg_new,
                x=colum,
                stat="count",
                kde=True,
                line_kws={'linewidth': 2},
                alpha=0.3,
                ax=ax,
            )
            ax.set_title(colum, fontsize=7, fontweight="bold")
            # Keep x tick labels on every subplot (they would be hidden with sharex=True).
            ax.tick_params(labelsize=6, labelbottom=True)
            ax.set_xlabel("")

    fig.tight_layout()
    plt.subplots_adjust(top=0.97)  # leave room above the grid for the suptitle
    fig.suptitle('Distribution plots', fontsize=10, fontweight="bold")
    plt.show()