Setting the y-axis range for Seaborn/Matplotlib countplot to a specified range above/below the max and min data points


I am doing EDA for a machine learning project and am plotting a grid of countplots with Seaborn. In the visualization, the difference in height of the bars in each subplot is almost indiscernible as the y-axis starts at 0 and there are many observations. I would like to set the y-axis range for each plot to a range above/below its max and min data points.

Since Seaborn is built on Matplotlib, I tried calling plt.ylim(a, b) or ax.set(ylim=(a, b)) after each countplot call, but this only changed the y-axis range of the lower-right subplot (SubtitlesEnabled).

Any help as to how I can apply this to each subplot within the plot function I have defined is greatly appreciated.

The code I used to generate the plot grid is below:

import matplotlib.pyplot as plt
import seaborn as sns

# Figure dimensions
fig, axes = plt.subplots(5, 2, figsize = (18, 18))

# List of categorical features column names
categorical_features_list = list(categorical_features)

# Countplot function
def make_countplot():
    i = 0
    for n in range(0,5):
        m = 0
        sns.countplot(ax = axes[n, m], x = categorical_features_list[i], data = train_df, color = 'blue',
                      order = train_df[categorical_features_list[i]].value_counts().index);
        m += 1
        sns.countplot(ax = axes[n, m], x = categorical_features_list[i+1], data = train_df, color = 'blue', 
                      order = train_df[categorical_features_list[i+1]].value_counts().index);
        m -= 1
        i += 2
        
make_countplot()

Here is the plot grid itself: [image: Seaborn Countplot Grid]


Solution

  • You might want to read about the difference between the plt and the ax interfaces in Matplotlib, and about avoiding explicit indices in Python loops. To set the y limits of a specific subplot, call ax.set_ylim(...) on that subplot's Axes.

    Here is an example using Seaborn's titanic dataset:

    import matplotlib.pyplot as plt
    import seaborn as sns
    
    titanic = sns.load_dataset('titanic')
    
    categorical_features = ['survived', 'sex', 'sibsp', 'parch', 'class', 'who', 'deck', 'embark_town']
    
    # Figure dimensions
    fig, axes = plt.subplots((len(categorical_features) + 1) // 2, 2, figsize=(10, 15))
    
    for feature, ax in zip(categorical_features, axes.flat):
        counts = titanic[feature].value_counts()
        sns.countplot(ax=ax, x=feature, data=titanic, color='dodgerblue',
                      order=counts.index)
        min_val = counts.min()
        max_val = counts.max()
        # pad the limits by 10% of the count range above the highest and below the lowest bar (never below 0)
        delta = (max_val - min_val) * 0.10
        ax.set_ylim(max(0, min_val - delta), max_val + delta)
        # remove the xlabel and set the feature name inside the plot (to save some space)
        ax.set_xlabel('')
        ax.text(0.5, 0.98, feature, fontsize=14, transform=ax.transAxes, ha='center', va='top')
    
    plt.tight_layout()
    plt.show()
    

    [image: seaborn countplot with custom y limits]
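
As a side note on why plt.ylim(a, b) only changed the lower-right subplot in the question: pyplot functions act on the "current" axes, which after plt.subplots() is normally the last subplot created. Below is a minimal sketch of that behaviour using empty subplots. The 2x2 grid, the limits (5, 10) and the variable names last_after_plt and first_after_plt are made up for illustration, and the Agg backend is selected only so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; an assumption so this runs headless
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2)

# plt.ylim() targets the current axes, not every subplot in the figure.
plt.ylim(5, 10)
last_after_plt = tuple(axes[1, 1].get_ylim())   # lower-right subplot: limits changed
first_after_plt = tuple(axes[0, 0].get_ylim())  # upper-left subplot: still the defaults
print(last_after_plt, first_after_plt)

# Calling set_ylim() on each Axes object sets every subplot explicitly.
for ax in axes.flat:
    ax.set_ylim(5, 10)
print([tuple(ax.get_ylim()) for ax in axes.flat])
```

Passing ax=... to sns.countplot draws on that subplot but does not appear to make it the current axes, which would explain why a plt.ylim call placed after each countplot kept hitting the same subplot. Using the ax object directly, as in the loop above, avoids depending on that implicit state.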