
How to customize the legend of a Seaborn plot


I was trying to customize the legend of my seaborn plot so that each entry also shows its value, right after the color and the name of the x entry. In other words, each legend entry should read: color, x value, y value.

After some attempts I figured it out myself, so I decided to share the code, in case it helps someone who ran into the same problem or someone who would like to improve it. Thanks.

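import matplotlib.pyplot as plt
import seaborn as sns

# df: DataFrame with "Orbit" and "Class" columns (not shown here)
ax = sns.barplot(data=df, x="Orbit", y="Class", estimator="mean", hue="Orbit", errorbar=None, legend="full")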

# Create custom legend labels including both x-entry and its corresponding y-value
mean_values = df.groupby("Orbit")["Class"].mean()

legend_labels = []
for orbit, mean in mean_values.items():
    legend_labels.append(f"{orbit} : {mean:.4f}")

# Add legend with custom labels
handles, _ = ax.get_legend_handles_labels()
ax.legend(handles, legend_labels, loc="lower left", bbox_to_anchor=(1.05, 0.29), title="Orbit")

plt.show()

And this is the result:

[Image: the resulting bar plot with the custom legend]


Solution

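  • The order of the legend labels does not match the order of the bars and their colors. By default, seaborn uses the order in which the values first appear in the dataframe, whereas groupby sorts the group keys (alphabetically, for strings). As a result, the handles and the custom labels are paired up incorrectly.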

    You can get the same order by using sns.barplot(..., order=mean_values.index, hue_order=mean_values.index) (in that case mean_values needs to be calculated up front).

    Another straightforward way to force an order on the values is to make that column explicitly categorical (df['Orbit'] = df['Orbit'].astype('category')). Both seaborn and groupby then follow the order of the categories.

    (As an aside, in Python it is usually considered more idiomatic to build such a list with a list comprehension. That way you don't need to create an empty list in advance, and the code is easier to maintain.)

    Here is some example code. Note that the code would need to be different if there were multiple bars per hue value.

    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    import numpy as np
    
    orbits = ['Io', 'Gm', 'By', 'Es', 'Kq', 'Jp', 'Lr', 'Ax', 'Hn', 'Ft', 'Dr', 'Cu']
    df = pd.DataFrame({'Orbit': np.random.choice(orbits, 80),
                       'Class': np.random.rand(80) ** 0.5})
    df['Orbit'] = df['Orbit'].astype('category') # forces an order
    ax = sns.barplot(data=df, x="Orbit", y="Class", estimator="mean", hue="Orbit", errorbar=None, legend='full')
    sns.despine()
    
    mean_values = df.groupby("Orbit", observed=False)["Class"].mean()
    legend_labels = [f"{orbit} : {mean:.4f}" for orbit, mean in mean_values.items()]
    # Add legend with custom labels
    handles, _ = ax.get_legend_handles_labels()
    ax.legend(handles, legend_labels, loc="lower left", bbox_to_anchor=(1.01, 0.29), title="Orbit")
    plt.tight_layout()
    plt.show()
    

    bar plot with custom legend

    By the way, showing the categories both on the x-axis and in the legend makes the plot carry superfluous information. You could instead put the values directly on top of the bars via ax.bar_label(). That way, there are no ordering problems, and there is no need to calculate the means separately (by default ax.bar_label() works with the heights of the bars).

    ax = sns.barplot(data=df, x="Orbit", y="Class", estimator="mean", hue="Orbit", errorbar=None, legend=False)
    sns.despine()
    for bars in ax.containers:
        ax.bar_label(bars, fmt='%.4f', fontsize=8)
    

    sns.barplot with bar_label