Search code examples
Tags: python, pandas, matplotlib, swarmplot

Multiple swarm plots inside one figure (pandas)


This question is related to "Group multiple plots in one figure (Python)", which produces 28 individual plots. This is my code:

# Code as posted in the question: it makes one figure per column, so with more
# than 100 columns it writes more than 100 separate PNG files.
# Column 0 is skipped -- presumably the grouping column 'GF' used as x below.
for column in df.columns[1:]:
    sns.set()
    # A brand-new 3 x 3 grid is created for every single column...
    fig, ax = plt.subplots(nrows=3, ncols=3) # tried 9 plots in one figure
    sns.set(style="whitegrid")
    # ...but no ax= is passed here, so seaborn draws on the current axes
    # (presumably the last subplot created) and the other eight cells stay
    # empty -- the "empty plots beside the normal one" described below.
    sns.swarmplot(x='GF', y=column, data=df,order=["WT", 'Eulomin'])  # Choose column
    sns.despine(offset=10, trim=True) #?
    # Saving inside the loop writes one '<column>.png' per column.
    plt.savefig('{}.png'.format(column), bbox_inches='tight')  #  filename
plt.show()

I have more than 100 columns. The code saves every column to its own file and shows empty plots beside the one real plot. How do I save 9 plots per figure, continuing until only a remainder is left (e.g. 5 columns), which should also go into one final figure?


Solution

  • Instead of iterating through the columns one at a time, iterate in steps of 9 with range. Index the data frame by column number and place each swarmplot into the ax array returned by plt.subplots:

    from itertools import product
    ...                                                      # (load df here)
    sns.set(style="whitegrid")

    # Column 0 is skipped, as in the question (presumably the grouping
    # column 'GF' used as x). The range is derived from the frame itself
    # rather than a hard-coded 100, so frames with more than 100 columns
    # are not silently truncated.
    value_cols = df.columns[1:]

    for start in range(0, len(value_cols), 9):               # ITERATE WITH STEPS
        chunk = value_cols[start:start + 9]                   # up to 9 columns
        fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(12, 6))

        # TRAVERSE 3 X 3 MATRIX, pairing each cell with one column; zip stops
        # when the chunk runs out, so no "does this column exist" check is needed
        for (r, c), column in zip(product(range(3), range(3)), chunk):
            # USE ax ARGUMENT WITH MATRIX INDEX
            sns.swarmplot(x='GF', y=column, data=df, ax=ax[r, c],
                          order=["WT", 'Eulomin'])
            # Despine only this subplot instead of the whole figure each time
            sns.despine(ax=ax[r, c], offset=10, trim=True)

        # The final figure may hold fewer than nine plots: hide the empty cells
        for unused in ax.flat[len(chunk):]:
            unused.set_visible(False)

        plt.tight_layout()
        first, last = start + 1, start + len(chunk)           # 1-based column positions
        plt.savefig('SwarmPlots_{0}-{1}.png'.format(first, last), bbox_inches='tight')
        # Release each figure once saved; many columns means many figures
        plt.close(fig)
    

    To demonstrate, here is random, seeded data (500 rows × 100 columns) for reproducibility:

    Data

    import numpy as np
    import pandas as pd

    # Seeded so the demo frame (and therefore the plots) is reproducible.
    np.random.seed(362020)
    cols = [f'Col{i}' for i in range(1, 100)]               # Col1 .. Col99
    # 500 rows of 99 standard-normal values plus a categorical 'GF' column,
    # then reorder so 'GF' is column 0 (the plotting loops skip column 0).
    # set_axis returns a new frame by default; its `inplace` keyword was
    # removed in pandas 2.0, so passing inplace=False raises TypeError there.
    df = (pd.DataFrame([np.random.randn(99) for _ in range(500)])
            .assign(GF=np.random.choice(['r', 'python', 'julia'], 500))
            .set_axis(cols + ['GF'], axis='columns')
            .reindex(['GF'] + cols, axis='columns')
         )

    df.shape
    # (500, 100)
    

    Plot

    import matplotlib.pyplot as plt
    import seaborn as sns
    from itertools import product

    sns.set(style="whitegrid")

    # 'GF' is column 0 of the demo frame; every remaining column gets its own
    # swarm plot, nine per 3 x 3 figure. Using len(df.columns) instead of a
    # hard-coded 100 keeps this correct for frames of any width.
    value_cols = df.columns[1:]

    for start in range(0, len(value_cols), 9):
        chunk = value_cols[start:start + 9]
        fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(12, 6))

        # Pair grid cells (row-major) with columns; zip stops at the chunk's end.
        for (r, c), column in zip(product(range(3), range(3)), chunk):
            sns.swarmplot(x='GF', y=column, data=df, ax=ax[r, c])

        # The last figure may hold fewer than nine plots; hide the unused cells.
        for unused in ax.flat[len(chunk):]:
            unused.set_visible(False)

        plt.tight_layout()
        first, last = start + 1, start + len(chunk)   # 1-based column positions
        plt.savefig('SwarmPlots_{0}-{1}.png'.format(first, last), bbox_inches='tight')
        # Close every figure once it is written to disk (previously only the
        # last one was closed, after all figures had accumulated in memory).
        plt.close(fig)
    

    Output (first plot)

    Plot Output