python pandas numpy matplotlib subplot

Group by a dataframe and iterate over subplots for each group


I have a dataframe. In step 1, I want to group it by the column 'Main' (creating df_M1 and df_M2) and make one set of subplots for each resulting dataframe. In step 2, I want to group each of those dataframes by the column 'Sub' (creating df_S1, df_S2, df_S3 and df_S4). Then, in step 3, I want to loop over the columns col1 and col2 of each Sub dataframe and plot them on one of the plots. So step 1 creates two dataframes, and step 2 creates four dataframes for each dataframe from step 1. I therefore want two figures (one for df_M1 and one for df_M2), each with four plots in a 2x2 grid, one per Sub dataframe (df_S1, df_S2, df_S3, df_S4). Each plot contains line graphs of col1 and col2, with the column 'Date' on the x-axis and the values of col1 and col2 on the y-axis. I wrote the following script, but I have a problem defining the subplots for the line graphs. Can anyone please tell me how to assign a subplot (axes) to each Sub dataframe and plot its columns (col1 and col2) on it?

Create a sample dataframe

data = {'Date': [2000,2001,2000,2001,2000,2001,2000,2001,2000,2001,2000,2001,2000,2001,2000,2001],
        'Main': ['A','A','A','A','A','A','A','A','B','B','B','B','B','B','B','B'],
       'Sub' : ['A1','A1','A2','A2','A3','A3','A4','A4','B1','B1','B2','B2','B3','B3','B4','B4'],
       'col1' : [1,2,4,5,1,2,6,4,8,5,7,2,4,5,1,2],
       'col2' : [5,6,1,4,5,4,5,1,5,4,5,6,4,5,8,4]}

df = pd.DataFrame(data)
df_M = [x for _, x in df.groupby(['Main'])]
for i in df_M:
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12,6),
                             sharex=True, linewidth=1, edgecolor='black')
    df_S = [z for _, z in i.groupby(['Sub'])]
    for j in df_S:
        for col in j.columns.values[3:5]:
            ax.plot(j[col])
            plt.show()

Solution

  • You need to index into the axes object to reach the individual axes, i.e. axes[0, 0] is the axis in the first row and first column of the 2x2 subplot grid, and calling axes[0, 0].plot(...) draws on it.

    EDIT: added Date as the x-axis.

    import matplotlib.pyplot as plt
    import pandas as pd
    
    data = {
        "Date": [2000,2001,2000,2001,2000,2001,2000,2001,2000,2001,2000,2001,2000,2001,2000,2001,],
        "Main": ["A","A","A","A","A","A","A","A","B","B","B","B","B","B","B","B",],
        "Sub": ["A1","A1","A2","A2","A3","A3","A4","A4","B1","B1","B2","B2","B3","B3","B4","B4",],
        "col1": [1, 2, 4, 5, 1, 2, 6, 4, 8, 5, 7, 2, 4, 5, 1, 2],
        "col2": [5, 6, 1, 4, 5, 4, 5, 1, 5, 4, 5, 6, 4, 5, 8, 4],
    }
    
    df = pd.DataFrame(data)
    
    # One figure (a 2x2 grid of axes) per value of "Main"
    for groupName, groupdf in df.groupby("Main"):
        fig, axes = plt.subplots(
            nrows=2, ncols=2, figsize=(12, 6), sharex=True, linewidth=1, edgecolor="black"
        )
    
        # One axis per value of "Sub": idx 0..3 maps to (0,0), (0,1), (1,0), (1,1)
        for idx, (subGroupName, subGroupdf) in enumerate(groupdf.groupby("Sub")):
            row = idx // 2
            col = idx % 2
            # Draw col1 and col2 against Date on the same axis
            for plottingCol in ["col1", "col2"]:
                axes[row, col].plot(subGroupdf["Date"], subGroupdf[plottingCol])
    
        plt.show()
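
A possible variation on the same idea, offered only as a sketch: instead of computing row and column indices by hand, you can iterate over axes.flat, which walks the 2x2 grid row by row, and pair each axis with one 'Sub' group using zip. The helper name plot_groups, its value_cols parameter, the titles, the legends and the headless "Agg" backend are my own assumptions for illustration; they are not part of the question or the answer above.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere; remove to open windows
import matplotlib.pyplot as plt
import pandas as pd

data = {
    "Date": [2000, 2001] * 8,
    "Main": ["A"] * 8 + ["B"] * 8,
    "Sub": ["A1", "A1", "A2", "A2", "A3", "A3", "A4", "A4",
            "B1", "B1", "B2", "B2", "B3", "B3", "B4", "B4"],
    "col1": [1, 2, 4, 5, 1, 2, 6, 4, 8, 5, 7, 2, 4, 5, 1, 2],
    "col2": [5, 6, 1, 4, 5, 4, 5, 1, 5, 4, 5, 6, 4, 5, 8, 4],
}
df = pd.DataFrame(data)


def plot_groups(frame, value_cols=("col1", "col2")):
    """Hypothetical helper: return {main_name: figure}, one 2x2 figure per 'Main' group."""
    figures = {}
    for main_name, main_df in frame.groupby("Main"):
        fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 6), sharex=True)
        fig.suptitle(f"Main = {main_name}")
        # axes.flat yields the four axes row by row; zip pairs each with one 'Sub' group
        for ax, (sub_name, sub_df) in zip(axes.flat, main_df.groupby("Sub")):
            for col in value_cols:
                ax.plot(sub_df["Date"], sub_df[col], label=col)
            ax.set_title(sub_name)
            ax.legend()
        figures[main_name] = fig
    return figures


figures = plot_groups(df)
```

To look at the result, call figures["A"].savefig("A.png"), or drop the matplotlib.use("Agg") line and call plt.show(). Note that zip stops at the shorter of its inputs, so a 'Main' group with more than four 'Sub' values would silently lose the extra ones in this sketch; the grid size would have to be adapted for that case.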