Search code examples
python pandas matplotlib seaborn

Group axis labels for seaborn box plots


I want grouped axis labels for box plots, for example a bit like this bar chart, where the x-axis is hierarchical: enter image description here

I am struggling to work with groupby objects to extract the values for the box plot.

I have found this heatmap example, which references this stacked bar answer from @Stein, but I can't get it to work for my box plots (I know I don't want the 'sum' of the groups, but I can't figure out how to get the values I want grouped correctly). In my real data the group sizes will differ; they will not all be the same as in the example data. I don't want to use seaborn's 'hue' as a solution, as I want all the boxes to be the same color.

This is the closest I have got, thanks:

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from itertools import groupby

def test_table():
    """Return a 48-row demo inventory table.

    Columns are ``Room``, ``Shelf``, ``Staple`` and ``Quantity``.  There are
    2 rooms x 2 shelves x 6 staples, and every (Room, Shelf, Staple)
    combination occurs exactly twice.  ``Quantity`` holds random integers
    drawn with ``np.random.randint(1, 20, 48)``, so the values differ on
    every call unless the NumPy global seed is fixed beforehand.
    """
    # The snippet's import block does not import pandas, so bring it into
    # scope here; otherwise the call fails with NameError on ``pd``.
    import pandas as pd

    data_table = pd.DataFrame({
        'Room': ['Room A'] * 24 + ['Room B'] * 24,
        'Shelf': (['Shelf 1'] * 12 + ['Shelf 2'] * 12) * 2,
        'Staple': ['Milk', 'Water', 'Sugar', 'Honey', 'Wheat', 'Corn'] * 8,
        'Quantity': np.random.randint(1, 20, 48),
    })
    return data_table

def add_line(ax, xpos, ypos):
    """Draw a short black vertical separator on ``ax``.

    The segment sits at horizontal position ``xpos`` and runs between
    ``ypos`` and ``ypos + .1``; both are interpreted in axes coordinates
    because the line uses ``ax.transAxes``.
    """
    x_coords = [xpos, xpos]
    y_coords = [ypos + .1, ypos]
    separator = plt.Line2D(x_coords, y_coords,
                           transform=ax.transAxes, color='black')
    # Turn clipping off so the segment is not cut at the axes rectangle.
    separator.set_clip_on(False)
    ax.add_line(separator)

def label_len(my_index,level):
    """Return ``(label, run_length)`` pairs for one level of an index.

    Consecutive equal labels at ``level`` are collapsed into a single
    pair.  Labels that repeat non-contiguously produce separate pairs,
    because ``itertools.groupby`` only groups adjacent items.
    """
    level_values = my_index.get_level_values(level)
    runs = []
    for key, members in groupby(level_values):
        runs.append((key, len(list(members))))
    return runs
    
def label_group_bar_table(ax, df):
    """Write hierarchical labels for ``df.index`` underneath ``ax``.

    Index levels are processed from the last level to the first.  Each
    level gets its own row of text, starting at y = -.1 in axes
    coordinates and moving .1 lower per level.  Every run of equal labels
    is centred over its share of the axes width (each index entry takes
    ``1 / df.index.size`` of the width) and delimited by vertical
    separator lines.
    """
    index = df.index
    unit = 1./index.size
    row_y = -.1
    for depth in reversed(range(index.nlevels)):
        cursor = 0
        for text, span in label_len(index, depth):
            centre = (cursor + .5 * span)*unit
            ax.text(centre, row_y, text, ha='center', transform=ax.transAxes)
            # Separator on the left edge of this run.
            add_line(ax, cursor*unit, row_y)
            cursor += span
        # Closing separator after the last run of the level.
        add_line(ax, cursor*unit, row_y)
        row_y -= .1

# Aggregate the demo table: one row per (Room, Shelf, Staple) combination
# holding the summed Quantity, indexed by those three columns.
df = test_table().groupby(['Room','Shelf','Staple']).sum()

# NOTE(review): the first figure is created and then immediately replaced
# by the second assignment, so an extra empty figure is left behind.
fig = plt.figure()
fig = plt.figure(figsize = (15, 10))
ax = fig.add_subplot(111)

# NOTE(review): x and y are both the summed Quantity column here, so the
# boxes are not grouped by Room/Shelf/Staple -- this is the part the
# question is asking how to fix.
sns.boxplot(x=df.Quantity, y=df.Quantity,data=df)

# Blank out the default tick labels and the x-axis title so that only the
# hierarchical labels drawn below remain.
labels = ['' for item in ax.get_xticklabels()]
ax.set_xticklabels(labels)
ax.set_xlabel('')

# Draw one row of labels per index level, then reserve .1 of the figure
# height per level at the bottom so the labels fit inside the figure.
label_group_bar_table(ax, df)
fig.subplots_adjust(bottom=.1*df.index.nlevels)
plt.show()

Which gives:

box plots with multi level axis labels


Solution

  • You can use:

    from itertools import groupby
    
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import seaborn as sns
    
    def test_table():
        """Build the 48-row demo inventory used by the example plot.

        The frame has the columns ``Room``, ``Shelf``, ``Staple`` and
        ``Quantity``.  Every (Room, Shelf, Staple) combination appears
        twice, and ``Quantity`` is filled with random integers from
        ``np.random.randint(1, 20, 48)``.
        """
        rooms = ['Room A'] * 24 + ['Room B'] * 24
        shelves = (['Shelf 1'] * 12 + ['Shelf 2'] * 12) * 2
        staples = ['Milk', 'Water', 'Sugar', 'Honey', 'Wheat', 'Corn'] * 8
        quantities = np.random.randint(1, 20, 48)
        return pd.DataFrame({
            'Room': rooms,
            'Shelf': shelves,
            'Staple': staples,
            'Quantity': quantities,
        })
    
    def add_line(ax, xpos, ypos):
        """Add a black vertical tick to ``ax`` at ``xpos``.

        The tick runs between ``ypos`` and ``ypos + .1``.  Coordinates are
        in axes units, since the line is built with ``ax.transAxes``.
        """
        xs = [xpos, xpos]
        ys = [ypos + .1, ypos]
        tick = plt.Line2D(xs, ys, transform=ax.transAxes, color='black')
        # Clipping is disabled so the tick is not cut at the axes rectangle.
        tick.set_clip_on(False)
        ax.add_line(tick)
    
    def label_len(my_index,level):
        """Collapse one index level into ``(label, count)`` runs.

        Only adjacent equal labels are merged (``itertools.groupby``
        semantics), so a label that reappears later yields another pair.
        """
        values = my_index.get_level_values(level)
        pairs = []
        for label, run in groupby(values):
            pairs.append((label, len(list(run))))
        return pairs
    
    # HERE: Replace all df.index occurrences with levels     
    def label_group_bar_table(ax, levels):
        """Write hierarchical labels for the index ``levels`` under ``ax``.

        Levels are handled from the last to the first.  Each level is
        written on its own row, starting at y = -.1 in axes coordinates
        and dropping .1 per level.  Each run of equal labels is centred
        over its share of the axes width (one entry of ``levels`` takes
        ``1 / levels.size`` of the width) and bounded by vertical lines.
        """
        step = 1./levels.size
        row_y = -.1
        for depth in reversed(range(levels.nlevels)):
            offset = 0
            for text, width in label_len(levels, depth):
                middle = (offset + .5 * width)*step
                ax.text(middle, row_y, text, ha='center', transform=ax.transAxes)
                # Left boundary of the current run.
                add_line(ax, offset*step, row_y)
                offset += width
            # Right boundary after the final run of this level.
            add_line(ax, offset*step , row_y)
            row_y -= .1
    
    # HERE: Don't aggregate your data -- keep every raw row so that each
    # group still has several Quantity values to draw a box from.
    df = test_table()
    
    # HERE: Create a unique group identifier -- ngroup() gives every row the
    # integer number of its (Room, Shelf, Staple) group.
    df['ID'] = df.groupby(['Room', 'Shelf', 'Staple']).ngroup()
    
    # HERE: Create an ordered MultiIndex for label_group_bar_table -- one row
    # per ID, sorted by ID, keeping only the three grouping columns.
    # NOTE(review): presumably this ordering is what lines the labels up
    # with the boxes; confirm against seaborn's category ordering.
    levels = df.drop_duplicates('ID').sort_values('ID')[['Room', 'Shelf', 'Staple']]
    levels = pd.MultiIndex.from_frame(levels)
    
    fig = plt.figure(figsize = (15, 10))
    ax = fig.add_subplot(111)
    
    # HERE: Set 'ID' as x-axis, so there is one box per group identifier.
    sns.boxplot(x='ID', y='Quantity', data=df)
    
    # Blank out the default tick labels and x-axis title; the hierarchical
    # labels drawn below take their place.
    labels = ['' for item in ax.get_xticklabels()]
    ax.set_xticklabels(labels)
    ax.set_xlabel('')
    ax.set_xticks([])  # HERE: Remove xticks
    
    # Draw one row of labels per level, then reserve .1 of the figure height
    # per level at the bottom so the labels fit inside the figure.
    label_group_bar_table(ax, levels)
    fig.subplots_adjust(bottom=.1*levels.nlevels)
    plt.show()
    

    To get:

    enter image description here