I've been struggling with wrapping the columns of subplots in Seaborn histogram plots (`kdeplot`, `histplot`). I've tried various things, including `fig, ax` and `enumerate(zip(df.columns, ax.flatten()))`.
Here's the dataset. My current code draws each column in a separate figure:
# Draw one standalone 3 x 3 inch histogram (with KDE) per column of df,
# colored by 'Dataset'; each figure is shown before the next is created.
for column_name in df.columns:
    plt.figure(figsize=(3, 3))
    sns.histplot(data=df, x=column_name, kde=True, bins=40, hue='Dataset', fill=True)
    plt.show()
How can the plots be done with other seaborn plots or plots with facet wrap functionality?
- `seaborn.displot` with `kind='hist'` can be used to create subplots / facets, where `col_wrap` specifies the number of columns.
- When using axes-level plots (e.g. `sns.histplot`), create the grid with `plt.subplots` and specify `nrows` and `ncols` instead.
- `'Female'` and `'Male'` should be shown separately, because gender statistics are often different, so presenting them together can skew the impression of the data.
- Separating the plots by `'Gender'` produces the best display option.
- Versions: `python 3.11.3`, `pandas 2.0.1`, `matplotlib 3.7.1`, `seaborn 0.12.2`
import pandas as pd
import seaborn as sns

# load the dataset downloaded from https://www.kaggle.com/uciml/indian-liver-patient-records
df = pd.read_csv('d:/data/kaggle/indian_liver_patient.csv')

# convert the data to a long form: 'Gender' and 'Dataset' stay as identifier
# columns, every other column is stacked into the 'variable' / 'value' pair
dfm = df.melt(id_vars=['Gender', 'Dataset'])

# plot the data for each gender: one facet per variable, wrapped 3 per row
for gender, data in dfm.groupby('Gender'):
    g = sns.displot(
        kind='hist', data=data, x='value', col='variable', hue='Dataset',
        hue_order=[1, 2], common_norm=False, common_bins=False,
        multiple='dodge', kde=True, col_wrap=3, height=2.5, aspect=2,
        facet_kws={'sharey': False, 'sharex': False}, palette='tab10',
    )
    # FacetGrid.fig is deprecated; FacetGrid.figure is the supported accessor
    fig = g.figure
    fig.suptitle(f'Gender: {gender}', y=1.02)
    fig.savefig(f'hist_{gender}.png', bbox_inches='tight')
`common_bins=False` means the bins of the two hue groups don't align. However, setting it to `True` causes `sharex=False` to be ignored, so all of the x-axis limits will be 0 - 2000, as can be seen in this plot.
`col_wrap` can't be used if `row` is also in use.
# `col_wrap` can't be used if `row` is also in use, so this version facets by
# row='Dataset' and col='variable', and colors the bars by 'Gender'.
g = sns.displot(kind='hist', data=dfm, x='value', row='Dataset', col='variable', hue='Gender',
                common_norm=False, common_bins=False, multiple='dodge', kde=True,
                facet_kws={'sharey': False, 'sharex': False})
# FacetGrid.fig is deprecated; FacetGrid.figure is the supported accessor
g.figure.savefig('hist.png')
Both genders can also be plotted together, without separating the data by `'Gender'`.
# Both genders together: one facet per variable (3 per row), colored by 'Dataset'.
g = sns.displot(
    data=dfm,
    kind='hist',
    x='value',
    hue='Dataset',
    palette='tab10',
    col='variable',
    col_wrap=3,
    height=2.5,
    aspect=2,
    multiple='dodge',
    kde=True,
    common_norm=False,
    common_bins=False,
    facet_kws={'sharex': False, 'sharey': False},
)
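Creating the grid with `plt.subplots` and drawing each variable with the axes-level `sns.histplot` allows `common_bins=True` to be used, while each axes keeps its own x-axis limits.
import seaborn as sns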
import matplotlib.pyplot as plt  # needed for plt.subplots below; missing in the original
import numpy as np
import pandas as pd

# load the dataset
df = pd.read_csv('d:/data/kaggle/indian_liver_patient.csv')
# convert the data to a long form: 'Gender' and 'Dataset' stay as identifier
# columns, every other column is stacked into 'variable' / 'value'
dfm = df.melt(id_vars=['Gender', 'Dataset'])

# iterate through the data for each gender
for gen, data in dfm.groupby('Gender'):
    # create a 3 x 3 grid of independent axes (no shared x or y limits)
    fig, axes = plt.subplots(3, 3, figsize=(11, 5), sharex=False, sharey=False, tight_layout=True)
    # flatten the 2-D array of axes so it can be zipped with the variables
    axes = axes.flatten()
    # iterate through each axes and variable category
    for ax, (var, sel) in zip(axes, data.groupby('variable')):
        # common_bins=True aligns the bins of both 'Dataset' groups within one axes
        sns.histplot(data=sel, x='value', hue='Dataset', hue_order=[1, 2], kde=True, ax=ax,
                     common_norm=False, common_bins=True, multiple='dodge', palette='tab10')
        ax.set(xlabel='', title=var.replace('_', ' ').title())
        ax.spines[['top', 'right']].set_visible(False)
    # keep only the legend of axes[5] (Aspartate Aminotransferase), which is moved
    # outside the grid and serves the whole figure; axes without a legend
    # (e.g. unused grid cells) are skipped instead of raising AttributeError
    for i, ax in enumerate(axes):
        legend = ax.get_legend()
        if i != 5 and legend is not None:
            legend.remove()
    sns.move_legend(axes[5], bbox_to_anchor=(1, 0.5), loc='center left', frameon=False)
    fig.suptitle(f'Gender: {gen}', y=1.02)
    fig.savefig(f'hist_{gen}.png', bbox_inches='tight')
Several columns in `df` have significant outliers. Removing them will improve the histogram visualization.
from scipy.stats import zscore
from typing import Literal
def remove_outliers(
    data: pd.DataFrame,
    method: Literal['std', 'z'] = 'std',
    column: str = 'value',
    n_std: float = 3,
) -> pd.DataFrame:
    """Return the rows of *data* whose *column* value is not an outlier.

    Parameters
    ----------
    data : pd.DataFrame
        Frame holding the numeric column to screen (``'value'`` for the
        long-form frame produced by ``df.melt``).
    method : {'std', 'z'}
        ``'std'`` keeps values inside ``mean +/- n_std * std`` (pandas sample
        std, ``ddof=1``; bounds inclusive). Any other value uses
        ``scipy.stats.zscore`` (population std, ``ddof=0``) and keeps rows
        whose absolute z-score is below ``n_std``.
    column : str
        Name of the column to screen. Defaults to ``'value'``.
    n_std : float
        Cut-off expressed in standard deviations. Defaults to 3.

    Returns
    -------
    pd.DataFrame
        The subset of rows of *data* that pass the filter. Rows whose value
        is missing are dropped by both methods. A constant column (std == 0)
        contains no outliers, so all of its non-missing rows are kept.
    """
    values = data[column]
    if method == 'std':
        mean, std = values.mean(), values.std()
        return data[values.between(mean - n_std * std, mean + n_std * std)]

    # nan_policy='omit' computes the z-scores from the non-missing values only;
    # with the default ('propagate') one NaN makes every score NaN and the
    # whole frame would be dropped.
    z = np.abs(zscore(values.to_numpy(dtype=float, na_value=np.nan), nan_policy='omit'))
    # NaN scores of non-missing values only arise from a zero std (0 / 0), so
    # "not beyond the cut-off" keeps them, while missing values are dropped.
    keep = values.notna().to_numpy() & ~(z >= n_std)
    return data[keep]
import matplotlib.pyplot as plt  # plt is never imported elsewhere in this file

# Same plot as the previous example, but outliers are removed from the columns
# that have them (columns 2-6 of df). Build the lookup once, outside the loops.
outlier_columns = set(df.columns[2:7])

# iterate through the data for each gender
for gen, data in dfm.groupby('Gender'):
    # create a 3 x 3 grid of independent axes and flatten it for zipping
    fig, axes = plt.subplots(3, 3, figsize=(11, 5), sharex=False, sharey=False, tight_layout=True)
    axes = axes.flatten()
    # iterate through each axes and variable category
    for ax, (var, sel) in zip(axes, data.groupby('variable')):
        # remove outliers of specified columns
        if var in outlier_columns:
            sel = remove_outliers(sel)
        sns.histplot(data=sel, x='value', hue='Dataset', hue_order=[1, 2], kde=True, ax=ax,
                     common_norm=False, common_bins=True, multiple='dodge', palette='tab10')
        ax.set(xlabel='', title=var.replace('_', ' ').title())
        ax.spines[['top', 'right']].set_visible(False)
    # keep only the legend of axes[5] and move it outside the grid for the figure;
    # axes without a legend are skipped
    for i, ax in enumerate(axes):
        legend = ax.get_legend()
        if i != 5 and legend is not None:
            legend.remove()
    sns.move_legend(axes[5], bbox_to_anchor=(1, 0.5), loc='center left', frameon=False)
    fig.suptitle(f'Gender: {gen}', y=1.02)
    fig.savefig(f'hist_{gen}.png', bbox_inches='tight')