I am working on a scatter plot using seaborn on a pandas dataframe, with various categories represented as styles, hues and sizes. I print the legend above the figure using several columns, but because the number of cases differs between categories, the layout of the columns is not very clear (see the image). I would like each column break to correspond to a change of category.
Here is a minimal reproducible example:
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
f, ax = plt.subplots()
df = pd.DataFrame({'col1': [1, 2, 3, 1, 2, 3, 4, 1, 2, 3],
                   'col1b': [1.1, 1.9, 3.1, 0.9, 2.1, 3.2, 4.1, 1, 1.8, 3.2],
                   'col2': [1, 1, 1, 2, 2, 2, 2, 3, 3, 3],
                   'col3': [10, 20, 15, 10, 20, 15, 30, 10, 20, 15]})
sns.scatterplot(data=df, x='col1', y='col1b', hue='col2', style='col3', ax=ax)
sns.move_legend(
    ax, "lower center",
    bbox_to_anchor=(.5, 1), ncol=2, title=None, frameon=False)
plt.savefig('mwe.png')
As far as I know, there is no easy parameter that would force the legend columns to be aligned by column name/category.
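Short of such a parameter, one workaround that keeps a single legend (a different technique from the separate-legends approach described next) is to pad each category's block of entries with invisible placeholders so that all blocks have the same length. Matplotlib fills legend columns top to bottom, so with ncol equal to the number of blocks, each block then occupies exactly one column. This is a minimal sketch, assuming seaborn's legend starts the style section with a subtitle entry labelled 'col3' (as it does when several semantics are mapped); the output file name is arbitrary.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D

df = pd.DataFrame({'col1': [1, 2, 3, 1, 2, 3, 4, 1, 2, 3],
                   'col1b': [1.1, 1.9, 3.1, 0.9, 2.1, 3.2, 4.1, 1, 1.8, 3.2],
                   'col2': [1, 1, 1, 2, 2, 2, 2, 3, 3, 3],
                   'col3': [10, 20, 15, 10, 20, 15, 30, 10, 20, 15]})
f, ax = plt.subplots()
sns.scatterplot(data=df, x='col1', y='col1b', hue='col2', style='col3', ax=ax)

# Read the entries from the legend seaborn drew
# (the attribute is `legend_handles` in Matplotlib >= 3.7, `legendHandles` before).
old = ax.get_legend()
handles = old.legend_handles if hasattr(old, "legend_handles") else old.legendHandles
labels = [t.get_text() for t in old.get_texts()]
old.remove()

# Split at the 'col3' subtitle entry (assumption): one block per category.
split = labels.index('col3')
blocks = [(handles[:split], labels[:split]), (handles[split:], labels[split:])]
height = max(len(block_labels) for _, block_labels in blocks)

padded_handles, padded_labels = [], []
for block_handles, block_labels in blocks:
    n_pad = height - len(block_labels)
    # A Line2D with no line style and no marker draws nothing: an invisible filler.
    padded_handles += list(block_handles) + [Line2D([], [], linestyle='none') for _ in range(n_pad)]
    padded_labels += list(block_labels) + [''] * n_pad

# Columns are filled top to bottom, so each equally long block gets its own column.
legend = ax.legend(padded_handles, padded_labels, ncol=len(blocks),
                   loc='lower center', bbox_to_anchor=(.5, 1), frameon=False)
f.savefig('padded_legend.png', bbox_inches='tight')
```

The subtitle entries ('col2', 'col3') stay in the legend and act as column headers.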
One solution I can think of is to create a separate legend for each category and add each of them to the chart individually, which has been discussed e.g. here.
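Steps:
1. Get the legend handles and labels that seaborn created.
2. Split them into one group per category (here col2 for the hue and col3 for the style).
3. Create a separate legend for each group.
4. Add the first legend back to the axes, because each call to ax.legend replaces the previous legend.
5. Position the legends next to each other.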
Given the example provided, one possible implementation is:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
sns.set()
f, ax = plt.subplots()
df = pd.DataFrame({'col1': [1, 2, 3, 1, 2, 3, 4, 1, 2, 3],
                   'col1b': [1.1, 1.9, 3.1, 0.9, 2.1, 3.2, 4.1, 1, 1.8, 3.2],
                   'col2': [1, 1, 1, 2, 2, 2, 2, 3, 3, 3],
                   'col3': [10, 20, 15, 10, 20, 15, 30, 10, 20, 15]})
scatter = sns.scatterplot(data=df, x='col1', y='col1b', hue='col2', style='col3', ax=ax)
# Number of hue (col2) entries, plus one for the 'col2' subtitle entry that
# seaborn places at the top of that section of the legend
col_2_values = df['col2'].nunique() + 1
# Get handles and labels for all legend entries, then split them by category
handles, labels = ax.get_legend_handles_labels()
col2_handles = handles[:col_2_values]
col2_labels = labels[:col_2_values]
col3_handles = handles[col_2_values:]
col3_labels = labels[col_2_values:]
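# Create a separate legend for each category; loc='upper left' makes the
# anchor points set below refer to each legend's upper-left corner
legend_col2 = ax.legend(col2_handles, col2_labels, loc='upper left', frameon=False)
legend_col3 = ax.legend(col3_handles, col3_labels, loc='upper left', frameon=False)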
# Add the `legend_col2` back to the chart, since the second call of `legend` above removes it
ax.add_artist(legend_col2)
# Adjust positioning (anchor points are in axes coordinates)
legend_col2.set_bbox_to_anchor((1, 1))
legend_col3.set_bbox_to_anchor((1.4, 1))
plt.show()
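The split above hard-codes where the col2 entries end. A slightly more general sketch is below; the helper name split_legend_by_variable is hypothetical (not a seaborn or matplotlib API). It locates each category by its subtitle entry, which assumes seaborn labels that entry with the column name (the same entry the +1 above accounts for), reads the entries from the existing legend object much like seaborn's move_legend does, and places one legend per category above the axes, as the question asks. The even horizontal spacing is a simplifying assumption.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def split_legend_by_variable(ax, variables):
    """Replace the seaborn legend on `ax` with one legend per variable.

    Hypothetical helper: assumes each variable's block in the legend starts
    with a subtitle entry whose text is exactly the variable (column) name.
    """
    old = ax.get_legend()
    # The attribute is `legend_handles` in Matplotlib >= 3.7, `legendHandles` before.
    if hasattr(old, "legend_handles"):
        handles = old.legend_handles
    else:
        handles = old.legendHandles
    labels = [text.get_text() for text in old.get_texts()]
    old.remove()

    # Locate each subtitle entry; a block runs until the next subtitle (or the end).
    starts = sorted(labels.index(var) for var in variables)
    stops = starts[1:] + [len(labels)]

    legends = []
    n = len(starts)
    for i, (start, stop) in enumerate(zip(starts, stops)):
        legends.append(ax.legend(
            handles[start + 1:stop],   # entries of this block, without the subtitle
            labels[start + 1:stop],
            title=labels[start],       # the subtitle becomes the legend title
            loc="lower center",
            bbox_to_anchor=((i + 0.5) / n, 1.0),  # evenly spaced above the axes
            frameon=False,
        ))

    # ax.legend() keeps only the most recent legend, so re-attach the earlier ones.
    for legend in legends[:-1]:
        ax.add_artist(legend)
    return legends


df = pd.DataFrame({'col1': [1, 2, 3, 1, 2, 3, 4, 1, 2, 3],
                   'col1b': [1.1, 1.9, 3.1, 0.9, 2.1, 3.2, 4.1, 1, 1.8, 3.2],
                   'col2': [1, 1, 1, 2, 2, 2, 2, 3, 3, 3],
                   'col3': [10, 20, 15, 10, 20, 15, 30, 10, 20, 15]})
f, ax = plt.subplots()
sns.scatterplot(data=df, x='col1', y='col1b', hue='col2', style='col3', ax=ax)
legends = split_legend_by_variable(ax, ['col2', 'col3'])
# bbox_extra_artists makes the tight bounding box include the legends outside the axes.
f.savefig('split_legend.png', bbox_inches='tight', bbox_extra_artists=legends)
```

With wide labels the evenly spaced anchors may overlap, in which case the x positions in bbox_to_anchor need to be set by hand, as in the answer above.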