I am doing EDA for a machine learning project and am plotting a grid of countplots with Seaborn. In the visualization, the difference in height of the bars in each subplot is almost indiscernible, as the y-axis starts at 0 and there are many observations. I would like to set the y-axis range for each plot so that it runs from slightly below its smallest count to slightly above its largest count.
As I know Seaborn is Matplotlib-based, I tried using plt.ylim(a, b) or ax.set(ylim=(a, b)) following each countplot function, but this would only change the y-axis range for the lower-right subplot (SubtitlesEnabled).
Any help as to how I can apply this to each subplot within the plot function I have defined is greatly appreciated.
The code I used to generate the plot grid is below:
# Figure dimensions
fig, axes = plt.subplots(5, 2, figsize = (18, 18))
# List of categorical features column names
categorical_features_list = list(categorical_features)
# Countplot function
def make_countplot():
    i = 0
    for n in range(0,5):
        m = 0
        sns.countplot(ax = axes[n, m], x = categorical_features_list[i], data = train_df, color = 'blue',
                      order = train_df[categorical_features_list[i]].value_counts().index);
        m += 1
        sns.countplot(ax = axes[n, m], x = categorical_features_list[i+1], data = train_df, color = 'blue',
                      order = train_df[categorical_features_list[i+1]].value_counts().index);
        m -= 1
        i += 2
make_countplot()
You might want to read about the difference between the plt and the ax interface, and about avoiding indices in Python. plt.ylim(...) only acts on the current axes, which is why only the last subplot changed. You can call ax.set_ylim(...) to set the y limits of a specific subplot.
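To make the plt versus ax difference concrete, here is a minimal sketch that uses plain Matplotlib bars instead of Seaborn countplots. The bar labels and heights are made-up illustration values, and the Agg backend is selected only so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend (assumption: no display needed)
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2)
for ax in axes.flat:
    # made-up counts that are large and close together
    ax.bar(["a", "b"], [1000, 1010])

# plt.ylim acts on the "current" axes only; right after plt.subplots
# that is the last subplot created (lower right)
plt.ylim(990, 1020)
first_before = tuple(axes[0, 0].get_ylim())   # still the autoscaled range
last_before = tuple(axes[1, 1].get_ylim())    # changed by plt.ylim
print(first_before, last_before)

# ax.set_ylim targets one specific subplot, so it can be used in a loop
for ax in axes.flat:
    ax.set_ylim(990, 1020)
```

Passing ax=... to sns.countplot does not change which axes is current, so calling plt.ylim after each countplot keeps hitting the same lower-right subplot.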
Here is an example using Seaborn's titanic dataset:
import matplotlib.pyplot as plt
import seaborn as sns
titanic = sns.load_dataset('titanic')
categorical_features = ['survived', 'sex', 'sibsp', 'parch', 'class', 'who', 'deck', 'embark_town']
# Figure dimensions
fig, axes = plt.subplots((len(categorical_features) + 1) // 2, 2, figsize=(10, 15))
for feature, ax in zip(categorical_features, axes.flat):
    counts = titanic[feature].value_counts()
    sns.countplot(ax=ax, x=feature, data=titanic, color='dodgerblue',
                  order=counts.index)
    min_val = counts.min()
    max_val = counts.max()
    # set the limits 10% higher than the highest and 10% lower than the lowest
    delta = (max_val - min_val) * 0.10
    ax.set_ylim(max(0, min_val - delta), max_val + delta)
    # remove the xlabel and set the feature name inside the plot (to save some space)
    ax.set_xlabel('')
    ax.text(0.5, 0.98, feature, fontsize=14, transform=ax.transAxes, ha='center', va='top')
plt.tight_layout()
plt.show()
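The limit calculation in the loop above can also be pulled out into a small function. The sketch below is one possible way to do that. The name padded_limits and its parameters are hypothetical and not part of Matplotlib or Seaborn. It also adds a fallback, not in the code above, for a feature whose categories all have the same count, where a 10% margin of a zero range would give identical lower and upper limits.

```python
def padded_limits(counts, pad=0.10, floor=0):
    """Return (low, high) y limits padded by `pad` times the data range.

    Hypothetical helper for illustration; `counts` is any iterable of
    bar heights, e.g. the result of Series.value_counts().
    """
    lo, hi = min(counts), max(counts)
    delta = (hi - lo) * pad
    if delta == 0:
        # all bars have the same height: pad relative to the value itself,
        # falling back to 1 if that value is 0
        delta = abs(hi) * pad or 1
    # never go below `floor` (counts cannot be negative)
    return max(floor, lo - delta), hi + delta


low, high = padded_limits([480, 500, 520])
print(low, high)
```

Inside the loop this could be used as ax.set_ylim(*padded_limits(counts)), and the same call should work on the asker's train_df columns, assuming each column has at least one non-missing value.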