
Grouping boxplots in seaborn when input is a DataFrame


I intend to plot multiple columns of a pandas DataFrame, all grouped by another column, using groupby inside seaborn.boxplot. There is a nice answer to a similar problem in matplotlib (matplotlib: Group boxplots), but given that seaborn.boxplot comes with a groupby option, I thought it could be much easier to do this in seaborn.

Here is a reproducible example that fails:

import seaborn as sns
import pandas as pd
df = pd.DataFrame([[2, 4, 5, 6, 1], [4, 5, 6, 7, 2], [5, 4, 5, 5, 1],
                   [10, 4, 7, 8, 2], [9, 3, 4, 6, 2], [3, 3, 4, 4, 1]],
                  columns=['a1', 'a2', 'a3', 'a4', 'b'])

# display(df)
#    a1  a2  a3  a4  b
# 0   2   4   5   6  1
# 1   4   5   6   7  2
# 2   5   4   5   5  1
# 3  10   4   7   8  2
# 4   9   3   4   6  2
# 5   3   3   4   4  1

# Plotting with seaborn
sns.boxplot(df[['a1','a2', 'a3', 'a4']], groupby=df.b)

What I get completely ignores the groupby option:

[Image: failed groupby]

Whereas if I do this with one column, it works, thanks to another SO question (Seaborn groupby pandas Series):

sns.boxplot(df.a1, groupby=df.b)

[Image: seaborn plot that does not fail]
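For comparison, here is a minimal sketch of the same single-column plot written in the x/y/data style used by the solution below. It assumes a seaborn version whose boxplot accepts data, x and y (such as the 0.13.0 release the solution was tested with). The grouping column is passed as x instead of through groupby:

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Same sample data as in the question
df = pd.DataFrame([[2, 4, 5, 6, 1], [4, 5, 6, 7, 2], [5, 4, 5, 5, 1],
                   [10, 4, 7, 8, 2], [9, 3, 4, 6, 2], [3, 3, 4, 4, 1]],
                  columns=['a1', 'a2', 'a3', 'a4', 'b'])

# One box of a1 values per level of b: the grouping column goes on x
fig, ax = plt.subplots(figsize=(4, 4))
sns.boxplot(data=df, x="b", y="a1", ax=ax)
```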

So I would like to get all my columns in one plot (all columns are on a similar scale).

EDIT:

The SO question mentioned above has been edited and now includes a 'not clean' answer to this problem, but it would be nice if someone had a better idea.


Solution

  • You can directly use sns.boxplot, an axes-level function, or sns.catplot with kind='box', a figure-level function. See Figure-level vs. axes-level functions for further details.

    sns.catplot has col and row parameters, which create subplots (facets) for the levels of another variable.

    The default palette is determined by the type of variable passed to hue: continuous (numeric) or categorical.

    As explained by @mwaskom, you have to melt the sample dataframe into "long-form" data, where each column is a variable and each row is an observation.

    Tested in python 3.12.0, pandas 2.1.2, matplotlib 3.8.1, seaborn 0.13.0

    df_long = pd.melt(df, id_vars="b", var_name="a", value_name="c")
    
    # display(df_long.head())
    #    b   a   c
    # 0  1  a1   2
    # 1  2  a1   4
    # 2  1  a1   5
    # 3  2  a1  10
    # 4  2  a1   9
    

    sns.boxplot

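    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(5, 5))
    sns.boxplot(x="a", hue="b", y="c", data=df_long, ax=ax)
    # Hide the top and right spines, then move the legend outside the axes on the right
    ax.spines[['top', 'right']].set_visible(False)
    sns.move_legend(ax, bbox_to_anchor=(1, 0.5), loc='center left', frameon=False)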
    

    sns.catplot

    Create the same plot as sns.boxplot with fewer lines of code.

    g = sns.catplot(kind='box', data=df_long, x='a', y='c', hue='b', height=5, aspect=1)
    

    Resulting Plot

    [Image: boxplots of c for each a, colored by b]
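Building on the point that sns.catplot can create facets with col and row, here is a minimal sketch that puts each level of b in its own subplot instead of encoding it with hue. It reuses the question's data and the same melt step; the height and aspect values are arbitrary choices:

```python
import pandas as pd
import seaborn as sns

# Same sample data as in the question
df = pd.DataFrame([[2, 4, 5, 6, 1], [4, 5, 6, 7, 2], [5, 4, 5, 5, 1],
                   [10, 4, 7, 8, 2], [9, 3, 4, 6, 2], [3, 3, 4, 4, 1]],
                  columns=['a1', 'a2', 'a3', 'a4', 'b'])

# Long form: one row per (b, original column name, value) observation
df_long = pd.melt(df, id_vars="b", var_name="a", value_name="c")

# col="b" draws one facet per level of b, each showing boxes for a1..a4
g = sns.catplot(kind="box", data=df_long, x="a", y="c", col="b", height=4, aspect=1)
```

With only col set, the FacetGrid lays the facets out in a single row, one per unique value of b.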