Tags: python, pandas, for-loop, matplotlib, statsmodels

What is the most efficient way to create multiple subplots with `statsmodels.api.qqplot()`?


Currently this code block generates 8 separate plots, one below the other (8x1). I want it to generate a single 2x4 grid of subplots in the most efficient way. Please help!

import matplotlib.pyplot as plt
import statsmodels.api as sm
import pandas as pd

# load the dataset
url = 'https://raw.githubusercontent.com/jbrownlee/Datasets/master/wine.csv'
df = pd.read_csv(url, header=None)
normal_stat = df.loc[:,:7]

# generate subplots
for i in range(len(normal_stat.columns)):
    plt.figure(figsize=(20,10))
    ax=plt.subplot(2,4,1+i)
    sm.qqplot(df[normal_stat.columns[i]], line='s', ax=ax)
    ax.set_title(str(normal_stat.columns[i]) + ' QQ Plot')
    plt.show()

Solution

  • Edit:

    To plot only the first 7 qq-plots:

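    import numpy as np

    fig, axes = plt.subplots(ncols=4, nrows=2, sharex=True, figsize=(4*3, 2*3))
    # zip() pairs the first 8 column labels (0..7) with the 8 flattened axes
    for k, ax in zip(df.columns, np.ravel(axes)):
        if k >= 7:
            # hide the 8th subplot instead of leaving an empty box
            ax.set_visible(False)
        else:
            sm.qqplot(df[k], line='s', ax=ax)
            ax.set_title(f'{k} QQ Plot')
    plt.tight_layout()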
    


    Note: we could have limited the zip iterator to the first 7 columns with for k, ax in zip(df.columns[:7], np.ravel(axes)), but that would leave the 8th subplot visible as an empty box. The method above explicitly hides the subplot that we don't want.
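
The hiding step can also be written without hard-coding the index 7. Below is a minimal sketch, assuming a hypothetical helper named plot_grid (it is not part of matplotlib or statsmodels) that accepts any per-axis drawing callable. A plain scatter of sorted random values stands in for sm.qqplot so that the snippet runs without downloading the dataset.

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_grid(keys, draw, ncols=4, cell=3):
    """Hypothetical helper: draw one subplot per key and hide leftover axes."""
    nrows = max(1, math.ceil(len(keys) / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(ncols * cell, nrows * cell),
                             squeeze=False)  # always return a 2-D array
    flat = axes.ravel()
    for key, ax in zip(keys, flat):
        draw(key, ax)
        ax.set_title(f"{key} QQ Plot")
    # Axes beyond len(keys) have nothing drawn on them, so hide them.
    for ax in flat[len(keys):]:
        ax.set_visible(False)
    fig.tight_layout()
    return fig, axes


# Stand-in drawing function. With the question's data this would be
#   lambda k, ax: sm.qqplot(df[k], line='s', ax=ax)
rng = np.random.default_rng(0)
fig, axes = plot_grid(list(range(7)),
                      lambda k, ax: ax.plot(np.sort(rng.normal(size=50)), "o"))
```

With 7 keys and ncols=4, the helper builds a 2x4 grid and hides only the last, unused subplot, which matches the behaviour of the if k >= 7 check above.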

    Original answer

    You can do this:

    import numpy as np
    
    fig, axes = plt.subplots(ncols=4, nrows=2, sharex=True, figsize=(4*3, 2*3))
    for k, ax in zip(df.columns, np.ravel(axes)):
        sm.qqplot(df[k], line='s', ax=ax)
        ax.set_title(f'{k} QQ Plot')
    plt.tight_layout()
    

    [Figure: resulting 2x4 grid of QQ plots]

    Explanation

    The ncols=4, nrows=2 arguments request 8 subplots, arranged in two rows of four. The returned axes is a 2-D NumPy array of shape (2, 4), so we use np.ravel to flatten it into a 1-D sequence of 8 axes. Then we use zip() to iterate over the columns of df and these axes in parallel. zip() stops as soon as the shorter of its inputs is exhausted, which in this case is the 8 axes (so we don't even have to slice the df columns explicitly).
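
To see the flattening and the zip() truncation in isolation, here is a small sketch. The list of 14 integer labels is an assumption standing in for df.columns (what pandas would produce for a 14-column file read with header=None); no data is loaded.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(nrows=2, ncols=4)
print(type(axes).__name__, axes.shape)   # container type and grid shape

flat_axes = np.ravel(axes)               # 1-D, row 0 first, then row 1
columns = list(range(14))                # assumed stand-in for df.columns

pairs = list(zip(columns, flat_axes))    # zip stops at the shorter input
print(len(pairs))                        # number of (column, axes) pairs
print([k for k, _ in pairs])             # column labels that got an axes
```

axes.flat or axes.flatten() would serve the same purpose as np.ravel(axes) here.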

    Finally, plt.tight_layout() adjusts the spacing between subplots so that titles and axis labels don't overlap and the subplots are cleanly separated.