Tags: loops, matplotlib, legend, scatter-plot, subplot

Matplotlib scatterplot subplot legends overwrite one another


I have a scatterplot figure with subplots generated in a for loop. I am trying to create a single legend for the whole figure, but each time a subplot is rendered its legend is overwritten by the next one, so the final figure contains a single legend that describes only the last subplot. I would like the legend to cover all subplots (i.e., it should include the years 2019, 2020, 2021 and 2022). Here is my code; please let me know how I can tweak it.

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches

df = pd.read_excel(path)

spp = df.SPP.unique()

fig, axs = plt.subplots(nrows=8, ncols=4, figsize=(14, 14))

for spp_i, ax in zip(spp, axs.flat):
    df_1 = df[df['SPP'] == spp_i]
    labels = list(df_1.Year.unique())
    x = df_1['Length_mm']
    y = df_1['Weight_g']
    levels, categories = pd.factorize(df_1['Year'])
    colors = [plt.cm.tab10(i) for i in levels]
    handles = [matplotlib.patches.Patch(color=plt.cm.tab10(i), label=c) for i, c in enumerate(categories)]
    ax.scatter(x, y, c=colors)
    plt.legend(handles=handles)

plt.savefig('Test.png', bbox_inches='tight', pad_inches=0.1, dpi=600)

In the resulting figure, the only legend appears in the bottom-right subplot and covers that last subplot alone.


Solution

  • Creating this type of plot is quite cumbersome with plain matplotlib. Seaborn automates many of the steps, including a single legend shared by all subplots.

    In this case, sns.relplot(...) can be used. By default all subplots share the same x and y ranges; to give each subplot its own ranges, add facet_kws={'sharex': False, 'sharey': False}.

    The height of each subplot (in inches) is set via height=, and its width is the height multiplied by aspect=. col_wrap= sets how many subplots are placed in a row before a new row starts.
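As a quick check of how height=, aspect= and col_wrap= combine, here is a minimal sketch on made-up data (the column names group, x and y are placeholders, not part of the question). It deliberately omits hue=, on the assumption that without a legend nothing widens the figure, so the overall size should come out as col_wrap * height * aspect wide and n_rows * height tall.

```python
import numpy as np
import pandas as pd
import seaborn as sns

# Made-up data: 5 groups, so 5 subplots (names are placeholders)
rng = np.random.default_rng(1)
groups = ["a", "b", "c", "d", "e"]
demo = pd.DataFrame({
    "group": np.repeat(groups, 20),
    "x": rng.normal(size=100),
    "y": rng.normal(size=100),
})

height, aspect, col_wrap = 2, 1.5, 3
g = sns.relplot(data=demo, x="x", y="y", col="group",
                height=height, aspect=aspect, col_wrap=col_wrap)

# 5 subplots wrapped at 3 per row need ceil(5 / 3) rows
n_rows = int(np.ceil(len(groups) / col_wrap))
fig_width, fig_height = g.fig.get_size_inches()

print("actual figure size:  ", fig_width, fig_height)
print("expected from kwargs:", col_wrap * height * aspect, n_rows * height)
```

With hue= set, seaborn places the legend outside the grid and makes the figure wider to fit it, so the width will then be larger than this product.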

    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    import numpy as np
    
    spp_list = ["Aeloria", "Baelun", "Caelondia", "Draeden", "Eldrida", "Faerun", "Gorandor", "Haldira", "Ilysium",
                "Jordheim", "Kaltara", "Lorlandia", "Myridia", "Nirathia", "Oakenfort"]
    df = pd.DataFrame({'SPP': np.repeat(spp_list, 100),
                       'Year': np.tile(np.repeat(np.arange(2019, 2023), 25), 15),
                       'Length_mm': np.abs(np.random.randn(1500).cumsum()) + 10,
                       'Weight_g': np.abs(np.random.randn(1500).cumsum()) + 20})
    
    g = sns.relplot(data=df, x='Length_mm', y='Weight_g', col='SPP', col_order=spp_list,
                    hue='Year', palette='turbo',
                    height=3, aspect=1.5, col_wrap=6,
                    facet_kws={'sharex': False, 'sharey': False})
    g.set_axis_labels(x_var='Length (mm)', y_var='Weight (g)', clear_inner=True)
    
    g.fig.tight_layout()  # nicely fit subplots with their titles, labels and ticks
    g.fig.subplots_adjust(right=0.97)  # space for the legend after fitting the subplots
    plt.show()
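If you would rather stay with plain matplotlib, the loop in the question can also be adjusted. plt.legend(...) attaches the legend to the current axes, which after plt.subplots(...) is the last subplot created, so every call in the loop replaces the previous legend in the bottom-right panel. The sketch below is one possible fix, not the only one: it builds a single year-to-color mapping from the whole DataFrame (so a year has the same color in every subplot) and calls fig.legend(...) once after the loop. The data is a synthetic stand-in that assumes the same column names as the question, and color_of and all_years are names made up for this example.

```python
import matplotlib.patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic stand-in for the Excel file (assumed columns: SPP, Year, Length_mm, Weight_g)
rng = np.random.default_rng(0)
spp_list = ["A", "B", "C", "D", "E", "F"]
df = pd.DataFrame({
    "SPP": np.repeat(spp_list, 40),
    "Year": np.tile(np.repeat(np.arange(2019, 2023), 10), len(spp_list)),
    "Length_mm": rng.uniform(10, 100, 240),
    "Weight_g": rng.uniform(20, 200, 240),
})

# One global year -> color mapping, computed once for all subplots
all_years = sorted(df["Year"].unique())
color_of = {year: plt.cm.tab10(i) for i, year in enumerate(all_years)}

fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 6))
for spp_i, ax in zip(spp_list, axs.flat):
    df_1 = df[df["SPP"] == spp_i]
    colors = df_1["Year"].map(color_of).tolist()
    ax.scatter(df_1["Length_mm"], df_1["Weight_g"], c=colors, s=12)
    ax.set_title(spp_i)

# A single figure-level legend, created once after the loop
handles = [matplotlib.patches.Patch(color=color_of[year], label=str(year))
           for year in all_years]
legend = fig.legend(handles=handles, title="Year", loc="center right")
fig.subplots_adjust(right=0.9)  # leave room on the right for the legend
```

From here, fig.savefig('Test.png', bbox_inches='tight', pad_inches=0.1, dpi=600) saves the figure as in the question. The value right=0.9 is a guess that may need tuning for other figure sizes.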
    
