Search code examples
pandas seaborn multi-index

Select appropriate colors in stacked Seaborn barplot


I want to create a stacked barplot using Seaborn with this MultiIndex DataFrame:

# Two rows ('JC', 'TTo') under a two-level column index: '#' on top and one
# of the four category labels underneath.
categories = ['TE', 'SS', 'M', 'MR']
header = pd.MultiIndex.from_product([['#'], categories])
dat = [
    [100, 20, 21, 35],
    [100, 12, 5, 15],
]
wide = pd.DataFrame(dat, index=['JC', 'TTo'], columns=header)

# Move the category level from the columns into the row index, leaving a
# single '#' column indexed by (name, category).
df = wide.stack()

# Sort by value, largest first, then regroup the rows by the level-0 label
# without sorting the remaining (category) level.
by_value = df.sort_values('#', ascending=False)
df = by_value.sort_index(level=0, sort_remaining=False)

enter image description here

The code I'm using for the plot is:

# NOTE(review): `df2` is not defined anywhere in this snippet — presumably it
# is the stacked, sorted `df` built earlier; confirm before running.
fontP = FontProperties()
fontP.set_size('medium')
colors = {'TE': 'green', 'SS': 'blue', 'M': 'yellow', 'MR': 'red'}
kwargs = {'alpha': 0.5}

plt.figure(figsize=(12, 9))

# Overlay four semi-transparent bar layers, one for each of the first four
# index entries of df2. Each layer's colour is looked up from element [1] of
# that index entry.
x_labels = df2.index.get_level_values(0).unique()
for row_pos in range(4):
    row_key = df2.index[row_pos]
    bottom_plot = sns.barplot(x=x_labels,
                              y=df2.loc[pd.IndexSlice[:, row_key], '#'],
                              color=colors[row_key[1]], **kwargs)

# Stand-alone rectangles used only as legend swatches; sizes and colours are
# listed in the same order as the legend labels below.
swatch_specs = ((1, 'green'), (0, 'yellow'), (2, 'red'), (3, 'blue'))
swatches = [
    plt.Rectangle((0, 0), side, side, fc=shade, edgecolor="None")
    for side, shade in swatch_specs
]
legend = plt.legend(swatches, ["TE", "M", 'MR', 'SS'],
                    bbox_to_anchor=(0.95, 1),
                    loc='upper left',
                    prop=fontP)
legend.draw_frame(False)

sns.despine()
bottom_plot.set_ylabel("#")

# Grid lines on the y axis of the current axes.
plt.gca().yaxis.grid()

And I get:

enter image description here

My problem is the order of the colors in the second bar ('TTo'). I want the colors to be selected automatically based on the level 1 index value (['TE', 'SS', 'M', 'MR']) so that they are ordered correctly: the segment with the highest value at the back in its corresponding color, the one with the next highest value in front of it in its color, and so on, as the first bar ('JC') shows.

Maybe there is a simpler way to do this in Seaborn than the one I'm using...


Solution

  • I'm not sure how to create such a plot with seaborn. Here is a way to create it by looping through the rows and adding one matplotlib bar at each step:

    import pandas as pd
    import seaborn as sns
    from matplotlib import pyplot as plt
    
    sns.set()

    # Two rows ('JC', 'TTo') under a ('#', category) column MultiIndex.
    header = pd.MultiIndex.from_product([['#'], ['TE', 'SS', 'M', 'MR']])
    dat = [[100, 20, 21, 35],
           [100, 12, 5, 15]]
    df = pd.DataFrame(dat, index=['JC', 'TTo'], columns=header).stack()
    # Sort by value (largest first), then regroup by the level-0 label
    # without sorting the remaining level.
    df = df.sort_values('#', ascending=False)
    df = df.sort_index(level=0, sort_remaining=False)

    colors = {'TE': 'green', 'SS': 'blue', 'M': 'yellow', 'MR': 'red'}

    # Draw one bar segment per row, placing each on top of the running total
    # for its level-0 label; the total restarts whenever that label changes.
    current_group = None
    running_top = 0
    for row in df.itertuples():
        (group, category), height = row
        if group != current_group:
            current_group = group
            running_top = 0
        plt.bar(group, height, fc=colors[category], ec='none',
                bottom=running_top, label=category)
        running_top += height

    # Zero-sized rectangles act as legend handles, one per entry in `colors`.
    legend_handles = [
        plt.Rectangle((0, 0), 0, 0, color=shade, label=name)
        for name, shade in colors.items()
    ]
    plt.legend(handles=legend_handles)
    plt.show()
    

    resulting plot

    To plot the bars back to front without stacking, the code can be simplified:

    colors = {'TE': 'forestgreen', 'SS': 'cornflowerblue', 'M': 'gold', 'MR': 'crimson'}

    # No `bottom` offset is passed here, so the rows are simply drawn one
    # after another in DataFrame order, each bar coloured by its category.
    for row in df.itertuples():
        (group, category), height = row
        plt.bar(group, height, fc=colors[category], ec='none', label=category)

    # Zero-sized rectangles serve as legend handles, one per entry in `colors`.
    legend_handles = [
        plt.Rectangle((0, 0), 0, 0, color=shade, label=name, ec='black')
        for name, shade in colors.items()
    ]
    plt.legend(handles=legend_handles, bbox_to_anchor=(1.02, 1.02), loc='upper left')
    plt.tight_layout()
    

    non-stacked bars