python pandas matplotlib seaborn subplot

How to add an additional plot to multiple subplots


I want to generate pairs of lineplots where one of them is used as a benchmark.

I can generate a plot like this with the code below.


However, I would like 6 paired plots instead, each showing one region alongside Arkhangelsk as the benchmark line.

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

data = {'year': ['1998','1998','1998','1998','1998','1998','1998','1999','1999','1999','1999','1999','1999','1999'],'region':['Adygea','Altai Krai','Amur Oblast','Arkhangelsk','Astrakhan','Bashkortostan','Belgorod','Adygea','Altai Krai','Amur Oblast','Arkhangelsk','Astrakhan','Bashkortostan','Belgorod'],  'sales':[8.8, 19.2,21.2, 10.6,18,17.5,23, 10, 17.8, 20.5, 12.6, 19.9, 16, 21]}

df1 = pd.DataFrame(data)

plt.figure(figsize=(12, 6))

palette1 = {c:'#079b51' if c=='Astrakhan' else 'grey' for c in df1['region'].unique()}

sns.lineplot(x= 'year', y='sales', data=df1,hue='region', palette=palette1) # or sns.relplot
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

My code is reproducible.

I have tried the following, which does not work. I am not sure how to loop through each of my 6 regions and compare it to the Astrakhan line plot. It should probably contain a condition like equal/not equal to 'Astrakhan'? Thank you.

fig, axs = plt.subplots(2,3, figsize=(18,15))
for i in enumerate(df1.region.unique()):
    sns.lineplot(x= 'year', y='sales', data=df1,ax = axs[i])


I also tried this, which raises ValueError: Could not interpret value Adygea for parameter y.


df2 = df1.pivot(index='year', columns='region', values='sales') # converting the rows into columns
df_a = df2[['Arkhangelsk']]
df_r = df2.loc[:, ~df2.columns.isin(['Arkhangelsk'])] ## all other columns

fig, axes = plt.subplots(2, 3)
for col, ax in zip(df2.columns, axes.ravel()):
    sns.lineplot(x = "year", y = col, data = df_a, ax = ax, linestyle="--")
    sns.lineplot(x = "year", y = col, data = df_r, ax = ax)


Solution

    # df1 is the DataFrame built in the question
    # convert the year column to an int
    df1.year = df1.year.astype(int)
    
    # pivot the data to a wide format
    dfp = df1.pivot(index='year', columns='region', values='sales')
    
    # get the columns to plot and compare
    compare = 'Arkhangelsk'
    cols = dfp.columns.tolist()
    cols.remove(compare)
    
    # set color dict
    color = {c:'#079b51' if c=='Arkhangelsk' else 'grey' for c in df1['region'].unique()}
    
    # plot the data with subplots
    axes = dfp.plot(y=cols, subplots=True, layout=(2, 3), figsize=(16, 10), sharey=True, xticks=dfp.index, color=color)
    
    # flatten the array
    axes = axes.flat  # .ravel() and .flatten() also work
    
    # extract the figure object
    fig = axes[0].get_figure()
    
    # iterate through each axes
    for ax in axes:
        
        # plot the comparison column
        dfp.plot(y=compare, ax=ax, color=color)
        
        # adjust the legend if desired
        ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.10), frameon=False, ncol=2)
        
    fig.suptitle('My Plots', fontsize=22, y=0.95)
    plt.show()
    
