Is there a way to add the mean and median to Seaborn's displot?
import seaborn as sns

penguins = sns.load_dataset("penguins")
g = sns.displot(
    data=penguins, x='body_mass_g',
    col='species',
    facet_kws=dict(sharey=False, sharex=False)
)
Based on "Add mean and variability to seaborn FacetGrid distplots", I see that I can define a FacetGrid and map a function. Can I pass a custom function to displot?
The reason for trying to use displot directly is that the plots are much prettier out of the box, without tweaking tick label size, axis label size, etc., and they are visually consistent with other plots I am making.
def specs(x, **kwargs):
    ax = sns.histplot(x=x)
    ax.axvline(x.mean(), color='k', lw=2)
    ax.axvline(x.median(), color='k', ls='--', lw=2)

g = sns.FacetGrid(data=penguins, col='species')
g.map(specs, 'body_mass_g')
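Using FacetGrid directly is not recommended; the seaborn documentation advises using a figure-level function such as seaborn.displot in most cases.
Figure-level functions return the FacetGrid they draw on, so seaborn.FacetGrid.map can be called on the object returned by displot.
Tested in python 3.8.11, pandas 1.3.2, matplotlib 3.4.3, seaborn 0.11.2.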
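Use plt.axvline instead of ax.axvline. In the question's code, the vlines are drawn on the ax returned by histplot, but here the figure is already created by displot before .map is called. FacetGrid.map makes each facet's axes the current axes before calling the function, so plt.axvline draws on the correct facet.
import matplotlib.pyplot as plt
import seaborn as sns

penguins = sns.load_dataset("penguins")
g = sns.displot(
    data=penguins, x='body_mass_g',
    col='species',
    facet_kws=dict(sharey=False, sharex=False)
)

def specs(x, **kwargs):
    # x is the body_mass_g values for the current facet's species
    plt.axvline(x.mean(), c='k', ls='-', lw=2.5)
    plt.axvline(x.median(), c='orange', ls='--', lw=2.5)

g.map(specs, 'body_mass_g')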
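Alternatively, compute the mean and median separately with groupby, then draw them on each axes of the figure returned by displot.
import seaborn as sns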
import pandas as pd
# load the data
pen = sns.load_dataset("penguins")
# groupby to get mean and median
pen_g = pen.groupby('species').body_mass_g.agg(['mean', 'median'])
g = sns.displot(
data=pen, x='body_mass_g',
col='species',
facet_kws=dict(sharey=False, sharex=False)
)
# extract and flatten the axes from the figure
axes = g.axes.flatten()
# iterate through each axes
for ax in axes:
    # extract the species name
    spec = ax.get_title().split(' = ')[1]
    # select the data for the species
    data = pen_g.loc[spec, :]
    # print data as needed or comment out
    print(data)
    # plot the lines
    ax.axvline(x=data['mean'], c='k', ls='-', lw=2.5)
    ax.axvline(x=data['median'], c='orange', ls='--', lw=2.5)