Search code examples
python pandas matplotlib group-by

Create subplots by overlaying two dataframes of different shapes and column names, for every group/id


I have the below two dataframes with different shapes and column names:

#Load the required libraries
import pandas as pd
import matplotlib.pyplot as plt

# Raw records for dataset_1; each list is laid out one row per id (1, 2, 3, 4)
data_set_1 = {
    'id': [
        1, 1, 1, 1, 1, 1, 1, 1, 1,
        2, 2, 2, 2,
        3, 3, 3, 3, 3, 3, 3, 3,
        4, 4, 4, 4,
    ],
    'cycle_1': [
        0.0, 0.2, 0.4, 0.6, 0.8, 1, 1.2, 1.4, 1.6,
        0.0, 0.2, 0.4, 0.6,
        0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4,
        0.0, 0.2, 0.4, 0.6,
    ],
    'Salary_1': [
        6, 7, 7, 7, 8, 9, 10, 11, 12,
        3, 4, 4, 4,
        2, 8, 9, 10, 11, 12, 13, 14,
        1, 8, 9, 10,
    ],
    'Children_1': [
        'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'No', 'Yes',
        'Yes', 'Yes', 'Yes', 'No',
        'Yes', 'No', 'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes',
        'Yes', 'Yes', 'No', 'Yes',
    ],
    'Days_1': [
        141, 123, 128, 66, 66, 120, 141, 52, 52,
        141, 96, 120, 120,
        141, 15, 123, 128, 66, 120, 141, 141,
        141, 141, 123, 128,
    ],
}

# Build dataframe_1 from the dictionary and echo it
df_1 = pd.DataFrame(data_set_1)
print("\n df_1 = \n", df_1)



# Raw records for dataset_2; each list is laid out one row per id (1, 2, 3, 4)
data_set_2 = {
    'id': [
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        2, 2, 2, 2, 2, 2, 2,
        3, 3, 3, 3, 3, 3, 3, 3,
        4, 4, 4, 4, 4, 4,
    ],
    'cycle_2': [
        0.0, 0.2, 0.4, 0.6, 0.8, 1, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2,
        0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2,
        0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4,
        0.0, 0.2, 0.4, 0.6, 0.8, 1.0,
    ],
    'Salary_2': [
        7, 8, 8, 8, 8, 9, 14, 21, 12, 19, 14, 20,
        1, 6, 3, 8, 4, 9, 8,
        6, 4, 9, 10, 4, 12, 13, 6,
        1, 4, 9, 10, 9, 4,
    ],
    'Children_2': [
        'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes', 'Yes', 'No',
        'Yes', 'Yes', 'Yes', 'No', 'Yes', 'Yes', 'Yes',
        'Yes', 'No', 'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes',
        'Yes', 'Yes', 'No', 'Yes', 'Yes', 'Yes',
    ],
    'Days_2': [
        141, 123, 128, 66, 66, 120, 141, 52, 96, 120, 141, 52,
        141, 96, 120, 120, 141, 52, 96,
        141, 15, 123, 128, 66, 120, 141, 141,
        141, 141, 123, 128, 66, 67,
    ],
}

# Build dataframe_2 from the dictionary and echo it
df_2 = pd.DataFrame(data_set_2)
print("\n df_2 = \n", df_2)

Now I want to plot cycle_1 vs Salary_1 and overlay cycle_2 vs Salary_2 on the same axes, with a separate subplot for every id.

To do this, I currently call the subplot function once per id, as follows:

## Plot for all ids: one figure, four stacked subplots (one per id).
## Each section below repeats the same steps with only the id changed:
## select the id's rows from df_1 and df_2 via groupby/get_group, then
## draw both Salary-vs-cycle curves on the same axes.
plt_fig_verify = plt.figure(figsize=(10,8))

## id1: first panel of the 4-row x 1-column grid
plt.subplot(4,1,1)
# df_1 curve in blue ('b'), df_2 curve in red ('r').
# NOTE(review): linewidth is passed as the string '1' rather than a number.
plt.plot(df_1.groupby(by="id").get_group(1)['cycle_1'], df_1.groupby(by="id").get_group(1)['Salary_1'], 'b',  linewidth = '1', label ='id1: Salary_1 of df_1')
plt.plot(df_2.groupby(by="id").get_group(1)['cycle_2'], df_2.groupby(by="id").get_group(1)['Salary_2'], 'r',  linewidth = '1', label ='id1: Salary_2 of df_2')
plt.xlabel('cycle')
plt.ylabel('Salary')
plt.legend()

## id2: second panel
plt.subplot(4,1,2)
plt.plot(df_1.groupby(by="id").get_group(2)['cycle_1'], df_1.groupby(by="id").get_group(2)['Salary_1'], 'b',  linewidth = '1', label ='id2: Salary_1 of df_1')
plt.plot(df_2.groupby(by="id").get_group(2)['cycle_2'], df_2.groupby(by="id").get_group(2)['Salary_2'], 'r',  linewidth = '1', label ='id2: Salary_2 of df_2')
plt.xlabel('cycle')
plt.ylabel('Salary')
plt.legend()

## id3: third panel
plt.subplot(4,1,3)
plt.plot(df_1.groupby(by="id").get_group(3)['cycle_1'], df_1.groupby(by="id").get_group(3)['Salary_1'], 'b',  linewidth = '1', label ='id3: Salary_1 of df_1')
plt.plot(df_2.groupby(by="id").get_group(3)['cycle_2'], df_2.groupby(by="id").get_group(3)['Salary_2'], 'r',  linewidth = '1', label ='id3: Salary_2 of df_2')
plt.xlabel('cycle')
plt.ylabel('Salary')
plt.legend()

## id4: fourth panel
plt.subplot(4,1,4)
plt.plot(df_1.groupby(by="id").get_group(4)['cycle_1'], df_1.groupby(by="id").get_group(4)['Salary_1'], 'b',  linewidth = '1', label ='id4: Salary_1 of df_1')
plt.plot(df_2.groupby(by="id").get_group(4)['cycle_2'], df_2.groupby(by="id").get_group(4)['Salary_2'], 'r',  linewidth = '1', label ='id4: Salary_2 of df_2')
plt.xlabel('cycle')
plt.ylabel('Salary')
plt.legend()

# Render the figure with all four panels
plt.show()

The resulting plot looks like this:

enter image description here

However, with this approach I have to write the subplot code four times, once for each id in the dataframes, changing the id and column names each time before overlaying the two curves.

Is there a way to iterate over the ids, so that the subplot code is written only once and still produces all four overlapped subplots?

Can somebody please let me know how to achieve this task in Python?


Solution

  • You can give the two dataframes the same header, concatenate them under a key per source, and then loop over the id groups so the plotting code is written only once:

    # One line colour per source dataframe; the keys double as the outer
    # index level of the concatenated frame below.
    colors = {"df_1": "blue", "df_2": "red"}

    # Give df_2 the same header as df_1, stack both frames under their source
    # key, then strip the "_1" suffix so the shared columns are simply
    # "id", "cycle", "Salary", "Children" and "Days".
    df = (
        pd.concat([df_1, df_2.set_axis(df_1.columns, axis=1)], keys=colors)
        .rename(lambda x: x.split("_")[0], axis=1)
    )

    # nunique() matches the number of groups groupby() yields (both ignore NaN).
    # squeeze=False always returns a 2-D array of Axes, so .flatten() also
    # works when there is only a single id.
    fig, axs = plt.subplots(
        figsize=(10, 8), nrows=df["id"].nunique(), squeeze=False
    )

    # One subplot per id; within it, one line per source dataframe.
    for (n, g), ax in zip(df.groupby("id"), axs.flatten()):
        sources_in_group = g.index.get_level_values(0)
        for i, s in enumerate(df.index.levels[0], start=1):
            # An id may be present in only one of the two dataframes.
            if s not in sources_in_group:
                continue
            g.loc[s].plot(
                x="cycle", y="Salary",
                xlabel="Cycle", ylabel="Salary",
                label=f"id{n}: Salary_{i} of {s}",
                color=colors[s],
                ax=ax
            )

    plt.tight_layout()

    plt.show()
    

    Output :

    enter image description here