I have a subplot containing a sns.scatterplot
that needs to distinguish two characteristics:
fig,ax= plt.subplots(figsize=(16/2.54,7/2.54),ncols=2,sharey=True)
sns.scatterplot(data=df,x='size',y='width_tot', hue='name_short',style='color',ax=ax[1])
The legend now offers the following entries:
name_short
* name12
* name45
color
* c
+ r
* g
As you can see, the column names and values in my dataframe are working names rather than the labels I want to see in my plot. If I drop seaborn's hue and style arguments instead, I lose important information in the plot. So my question is: how can I overwrite the legend entries, specifically the legend titles and the hue entries? I could live with the short forms of the colors.
The following attempt did not solve it. Overwriting the whole legend did not work either, because I lose the marker symbols of the style differentiation.
legend_titles = ['Clear Name 1', 'Clear Name 2']
legend = ax[1].legend()
for i, title in enumerate(legend_titles):
    legend.get_texts()[i].set_text(title)
To change the titles, you can temporarily rename the "working names" to the names you want to see in your plot. To change other legend parameters, you can use sns.move_legend().
Here is an example:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
df = pd.DataFrame({
    'size': np.random.rand(20) * 20 + 1,
    'width_tot': np.random.rand(20) * 10 + 2,
    'name_short': np.random.choice(['name12', 'name45'], 20),
    'color': np.random.choice(['c', 'r', 'g'], 20),
})
# Set the desired column names
df1 = df.rename(columns={'size': 'Size', 'width_tot': 'Total Width', 'name_short': 'Short Name', 'color': 'Color'})
# Set the full version of the short names
df1['Short Name'] = df1['Short Name'].map({'name12': 'Twelve', 'name45': 'Forty Five'})
df1['Color'] = df1['Color'].map({'c': 'Cyan', 'r': 'Red', 'g': 'Green'})
sns.set()
# sns.scatterplot(data=df, x='size', y='width_tot', hue='name_short', style='color')
sns.scatterplot(data=df1, x='Size', y='Total Width', hue='Short Name', style='Color')
plt.tight_layout()
plt.show()
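If you would rather not touch the dataframe, another option is to relabel the legend that seaborn has already drawn, which keeps the hue colors and style markers as they are. The following is a minimal sketch, not the only way to do it: the helper relabel_legend and the labels dictionary are hypothetical names made up for this example, and the data is a small fixed stand-in for the real dataframe. It assumes that seaborn puts the section headings ("name_short", "color") into the legend either as ordinary text entries or as the legend title, so it checks both.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Small fixed stand-in for the real data (assumption for this example).
df = pd.DataFrame({
    'size': [1, 2, 3, 4, 5, 6],
    'width_tot': [2, 4, 3, 6, 5, 7],
    'name_short': ['name12', 'name45', 'name12', 'name45', 'name12', 'name45'],
    'color': ['c', 'r', 'g', 'c', 'r', 'g'],
})

# Hypothetical mapping from working names to display names.
# Entries that are not listed (here: c, r, g) stay unchanged.
labels = {
    'name_short': 'Name',
    'color': 'Color',
    'name12': 'Clear Name 1',
    'name45': 'Clear Name 2',
}


def relabel_legend(legend, mapping):
    """Replace legend texts in place; handles (colors, markers) are untouched."""
    title = legend.get_title()
    title.set_text(mapping.get(title.get_text(), title.get_text()))
    for text in legend.get_texts():
        text.set_text(mapping.get(text.get_text(), text.get_text()))


fig, ax = plt.subplots(figsize=(16 / 2.54, 7 / 2.54), ncols=2, sharey=True)
sns.scatterplot(data=df, x='size', y='width_tot',
                hue='name_short', style='color', ax=ax[1])

# Use the legend seaborn created; calling ax[1].legend() would build a new one.
relabel_legend(ax[1].get_legend(), labels)
```

Because the texts are looked up by their current content rather than by position, the heading entries do not shift the indices the way they do in the enumerate-based attempt from the question. Call plt.show() or fig.savefig(...) afterwards as usual.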