
How can I have one category per column in the legend using seaborn and pandas?


I am working on a scatter plot using seaborn on a pandas dataframe, with various categories represented as styles, hues and sizes. I place the legend above the figure using several columns, but because the number of cases differs between categories, the layout across the columns is not very clear (see the image). I would like each column break to correspond to a change of category.

Here is a minimal reproducible example:

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

f, ax = plt.subplots()
df = pd.DataFrame({'col1': [1, 2, 3, 1, 2, 3, 4, 1, 2, 3],
                   'col1b': [1.1, 1.9, 3.1, 0.9, 2.1, 3.2, 4.1, 1, 1.8, 3.2],
                   'col2': [1, 1, 1, 2, 2, 2, 2, 3, 3, 3],
                   'col3': [10, 20, 15, 10, 20, 15, 30, 10, 20, 15]})

sns.scatterplot(data=df, x='col1', y='col1b', hue='col2', style='col3', ax=ax)
sns.move_legend(
    ax, "lower center",
    bbox_to_anchor=(.5, 1), ncol=2, title=None, frameon=False)

plt.savefig('mwe.png')

[Image: legend produced by the example above]
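Why the break lands in the middle of a category: matplotlib fills a multi-column legend column by column, and when the number of entries is not a multiple of `ncol`, the leftmost columns get one extra row. The function below is a plain-Python approximation of that rule (`column_layout` is a hypothetical name, not a matplotlib API), applied to the nine entries the example's legend is assumed to contain: a title entry for each variable followed by its values.

```python
def column_layout(labels, ncol):
    """Approximate how a column-major legend distributes entries over columns.

    Hypothetical helper: mimics the rule that the first (len % ncol) columns
    receive one extra entry; it does not call matplotlib.
    """
    if not labels:
        return []
    ncol = min(ncol, len(labels))
    rows, n_large = divmod(len(labels), ncol)
    # The first `n_large` columns hold rows + 1 entries, the rest hold `rows`.
    sizes = [rows + 1] * n_large + [rows] * (ncol - n_large)
    columns, start = [], 0
    for size in sizes:
        columns.append(labels[start:start + size])
        start += size
    return columns


# Assumed legend labels for the question's data: a title entry per variable, then its values
labels = ["col2", "1", "2", "3", "col3", "10", "15", "20", "30"]
for column in column_layout(labels, ncol=2):
    print(column)
# ['col2', '1', '2', '3', 'col3']
# ['10', '15', '20', '30']
```

With 9 entries and `ncol=2`, the first column takes 5 entries, so the `col3` title ends up at the bottom of the first column instead of starting the second one.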


Solution

  • As far as I know, there is no easy parameter that would force the legend columns to be aligned by column name/category.

    One solution I can think of is to create a separate legend for each category and add each one to the chart separately; this approach has been discussed, e.g., here.

    Steps:

    • Get the list of handles and labels using get_legend_handles_labels()
    • Split the handles and labels by category column, so that each category gets its own legend
    • Add legends to the chart and adjust the properties
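The split in step 2 can also be done by looking for the title entries instead of counting values. When more than one semantic variable is mapped, seaborn inserts an entry labelled with the variable name (here `col2` and `col3`) before each group of values, and the sketch below groups on those. `split_by_titles` is a hypothetical helper, not part of seaborn or matplotlib; it assumes no data value is labelled with a column name, and it returns nothing useful for a single-variable legend, which has no title entries.

```python
def split_by_titles(handles, labels, titles):
    """Group flat legend entries into {title: (handles, labels)}.

    A new group starts at every entry whose label is in `titles`; the title
    entry itself is kept as the first item of its group. Entries that appear
    before the first title are dropped.
    """
    groups = {}
    current = None
    for handle, label in zip(handles, labels):
        if label in titles:
            current = label
            groups[current] = ([], [])
        if current is not None:
            group_handles, group_labels = groups[current]
            group_handles.append(handle)
            group_labels.append(label)
    return groups


# Stand-in data shaped like the question's legend; with a real plot use
# handles, labels = ax.get_legend_handles_labels()
handles = [f"h{i}" for i in range(9)]
labels = ["col2", "1", "2", "3", "col3", "10", "15", "20", "30"]
groups = split_by_titles(handles, labels, titles={"col2", "col3"})
print(groups["col2"][1])  # ['col2', '1', '2', '3']
print(groups["col3"][1])  # ['col3', '10', '15', '20', '30']
```

Each group can then be passed as `ax.legend(*groups["col2"], frameon=False)`, in the same way as `col2_handles, col2_labels` in the implementation that follows.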

    Given the example provided, one possible implementation is:

    import seaborn as sns
    import pandas as pd
    import matplotlib.pyplot as plt
    
    sns.set()
    
    f, ax = plt.subplots()
    df = pd.DataFrame({'col1': [1, 2, 3, 1, 2, 3, 4, 1, 2, 3],
                       'col1b': [1.1, 1.9, 3.1, 0.9, 2.1, 3.2, 4.1, 1, 1.8, 3.2],
                       'col2': [1, 1, 1, 2, 2, 2, 2, 3, 3, 3],
                       'col3': [10, 20, 15, 10, 20, 15, 30, 10, 20, 15]})
    
    sns.scatterplot(data=df, x='col1', y='col1b', hue='col2', style='col3', ax=ax)
    
    # Number of entries in the hue (col2) group: one per unique value,
    # plus one for the "col2" title entry that seaborn inserts before them
    col_2_values = df['col2'].nunique() + 1
    
    # Get handles and labels for both categories (split the list into categories needed)
    handles, labels = ax.get_legend_handles_labels()
    col2_handles = handles[:col_2_values]
    col2_labels = labels[:col_2_values]
    col3_handles = handles[col_2_values:]
    col3_labels = labels[col_2_values:]
    
    # Create separate legends for both category columns
    legend_col2 = ax.legend(col2_handles, col2_labels, frameon=False)
    legend_col3 = ax.legend(col3_handles, col3_labels, frameon=False)
    
    # Add `legend_col2` back to the chart, since the second call of `legend` above replaces it
    ax.add_artist(legend_col2)
    
    # Adjust positioning (axes coordinates: x=1 is the right edge of the axes)
    legend_col2.set_bbox_to_anchor((1, 1))
    legend_col3.set_bbox_to_anchor((1.4, 1))
    
    plt.show()
    

    This generates:

    [Image: resulting plot with one legend per category column]
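The question places the legend above the plot rather than to the right. Below is a sketch of the same two-legend approach anchored above the axes. The anchor points (0.48 and 0.52 in axes coordinates) and the `lower right`/`lower left` locations are illustrative choices, not values from the original answer, and the count-based split assumes seaborn puts a `col2` title entry before the hue values.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df = pd.DataFrame({'col1': [1, 2, 3, 1, 2, 3, 4, 1, 2, 3],
                   'col1b': [1.1, 1.9, 3.1, 0.9, 2.1, 3.2, 4.1, 1, 1.8, 3.2],
                   'col2': [1, 1, 1, 2, 2, 2, 2, 3, 3, 3],
                   'col3': [10, 20, 15, 10, 20, 15, 30, 10, 20, 15]})

f, ax = plt.subplots()
sns.scatterplot(data=df, x='col1', y='col1b', hue='col2', style='col3', ax=ax)

handles, labels = ax.get_legend_handles_labels()
# Hue group = "col2" title entry + one entry per unique value
n_col2 = df['col2'].nunique() + 1

# Both legends sit just above the axes (y=1 in axes coordinates) and grow upwards:
# the first one ends left of the centre, the second one starts right of it
legend_col2 = ax.legend(handles[:n_col2], labels[:n_col2], frameon=False,
                        loc="lower right", bbox_to_anchor=(0.48, 1.0))
ax.add_artist(legend_col2)  # keep it when the next ax.legend() call replaces the axes legend
legend_col3 = ax.legend(handles[n_col2:], labels[n_col2:], frameon=False,
                        loc="lower left", bbox_to_anchor=(0.52, 1.0))

# bbox_inches="tight" enlarges the saved area so legends outside the axes are not cut off
f.savefig("mwe.png", bbox_inches="tight")
```

`bbox_inches="tight"` only affects saving; in an interactive window, legends placed outside the axes can still be clipped unless you leave room for them, e.g. with `f.subplots_adjust(top=...)`.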