
Matrix of scatterplots by month-year


My data is in a dataframe with two columns, y and x, indexed by date and covering the past few years. Here is some dummy data:

import numpy as np
import pandas as pd

np.random.seed(167)
rng = pd.date_range('2017-04-03', periods=365*3)  # daily dates, about three years

df = pd.DataFrame(
    {"y": np.cumsum([np.random.uniform(-0.01, 0.01) for _ in range(365*3)]),
     "x": np.cumsum([np.random.uniform(-0.01, 0.01) for _ in range(365*3)])
    }, index=rng
)

In my first attempt, I plotted a scatterplot with seaborn using the following code:

import seaborn as sns
import matplotlib.pyplot as plt

def plot_scatter(data, title, figsize):
    fig, ax = plt.subplots(figsize=figsize)
    ax.set_title(title)
    # pass column names together with data=, and draw on the axes created above
    sns.scatterplot(data=data, x='x', y='y', ax=ax)

plot_scatter(data=df, title='dummy title', figsize=(10,7))
plt.show()

However, I would like to generate a 4x3 matrix of 12 scatterplots, one for each month, with the year as hue. I thought I could add a third column to my dataframe holding the year, so I tried the following:

import seaborn as sns
import matplotlib.pyplot as plt

def plot_scatter(data, title, figsize):
    fig, ax = plt.subplots(figsize=figsize)
    ax.set_title(title)
    sns.scatterplot(data=data, x='x', y='y',
                    hue='year',  # refer to the new column by name rather than by position
                    ax=ax)

df['year'] = df.index.year  # add the year as a third column
plot_scatter(data=df, title='dummy title', figsize=(10,7))
plt.show()

While this lets me distinguish the years, it still shows all the data in a single scatterplot instead of one scatterplot per month, so it doesn't offer the level of detail I need.

I could slice the data by month and build a for loop that plots one scatterplot per month, but I actually want a matrix in which all the scatterplots use the same axis scales. Does anyone know an efficient way to achieve that?
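For reference, the loop approach described here can also give a common scale if the grid is created with plt.subplots(3, 4, sharex=True, sharey=True). The following is a minimal matplotlib-only sketch, not the answer's method: it regenerates the dummy data with the vectorized size= form of np.random.uniform (my substitution for the list comprehension), and the names colors and legend_handles are illustrative. Each year gets a fixed colour because the panels contain different sets of years (January has no 2017 data, April has all four years), so relying on the default colour cycle would colour the same year differently across panels.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Dummy data generated the same way as in the question (size= draws all values at once)
np.random.seed(167)
rng = pd.date_range('2017-04-03', periods=365 * 3)
df = pd.DataFrame(
    {"y": np.cumsum(np.random.uniform(-0.01, 0.01, size=365 * 3)),
     "x": np.cumsum(np.random.uniform(-0.01, 0.01, size=365 * 3))},
    index=rng,
)

# One fixed colour per year, so a year looks the same in every panel
years = sorted(df.index.year.unique())
colors = {year: f'C{i}' for i, year in enumerate(years)}

# sharex/sharey give all 12 panels one common x range and one common y range
fig, axes = plt.subplots(3, 4, figsize=(16, 12), sharex=True, sharey=True)
legend_handles = {}

for month, ax in zip(range(1, 13), axes.flat):
    subset = df[df.index.month == month]
    for year, group in subset.groupby(subset.index.year):
        points = ax.scatter(group['x'], group['y'], s=10, color=colors[year])
        legend_handles[year] = points  # keep one handle per year for the shared legend
    ax.set_title(pd.Timestamp(2000, month, 1).strftime('%B'))

# A single legend for the whole figure, in chronological order
fig.legend([legend_handles[y] for y in years], [str(y) for y in years],
           title='year', loc='upper right')
plt.show()
```

This works, but it needs noticeably more bookkeeping (colours, legend, titles) than a figure-level seaborn function, which handles all of that itself.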


Solution

  • To create multiple subplots at once, seaborn provides figure-level functions such as sns.relplot(). The col= argument names the dataframe column whose values split the data into subplots, and col_wrap= sets how many subplots are placed side by side before a new row starts. By default, the subplots share their x and y axes, so they all use the same scales.

    Note that you shouldn't create a figure yourself, because the function creates its own new figure. The size of each individual subplot is set with height= (in inches) and aspect= (width = aspect * height).

    The code below calls sns.relplot() with col='month'. An extra column for the months is created and made categorical to fix their order (January to December).

    To remove the "month = " prefix from the titles, you can loop through the generated axes (axes_dict requires a recent seaborn version). With sns.set(font_scale=...) you can scale the default size of all text elements.

    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    import numpy as np
    
    np.random.seed(167)
    dates = pd.date_range('2017-04-03', periods=365 * 3, freq='D')
    
    df = pd.DataFrame({"y": np.cumsum([np.random.uniform(-0.01, 0.01) for _ in range(365 * 3)]),
                       "x": np.cumsum([np.random.uniform(-0.01, 0.01) for _ in range(365 * 3)])
                       }, index=dates)
    
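    df['year'] = df.index.year
    # 'MS' (month start) yields the same 12 month names; the 'M' alias is deprecated in newer pandas
    month_names = pd.date_range('2017-01-01', periods=12, freq='MS').strftime('%B')
    # categorical column: the subplots follow the category order January..December
    df['month'] = pd.Categorical.from_codes(df.index.month - 1, month_names)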
    
    sns.set(font_scale=1.7)
    g = sns.relplot(kind='scatter', data=df, x='x', y='y', hue='year', col='month', col_wrap=4, height=4, aspect=1)
    # optionally remove the `month=` in the title
    for name, ax in g.axes_dict.items():
        ax.set_title(name)
    plt.setp(g.axes, xlabel='', ylabel='')  # remove all x and y labels
    g.axes[-2].set_xlabel('x', loc='left')  # set an x label at the left of the second to last subplot
    g.axes[4].set_ylabel('y')  # set a y label to 5th subplot
    
    plt.subplots_adjust(left=0.06, bottom=0.06)  # set some more spacing at the left and bottom
    plt.show()
    

    [Figure: sns.relplot scatterplots, one per month, colored by year]
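The categorical month column matters because seaborn orders the facets of a plain (non-categorical) string column by first appearance, which for this data would start the grid at April. The pandas-only sketch below contrasts the two; the variable names as_strings and as_category are illustrative. Passing col_order=list(month_names) to sns.relplot() would be another way to fix the order without changing the column type.

```python
import pandas as pd

dates = pd.date_range('2017-04-03', periods=365 * 3, freq='D')
month_names = pd.date_range('2017-01-01', periods=12, freq='MS').strftime('%B')

# Plain strings: the unique values come out in order of first appearance
as_strings = pd.Series(dates.strftime('%B'))
# Categorical built from month codes: the categories keep calendar order
as_category = pd.Series(pd.Categorical.from_codes(dates.month - 1, month_names))

print(list(as_strings.unique()[:3]))     # first months that occur in the data
print(list(as_category.cat.categories))  # the facet order seaborn would use
```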
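As an alternative to looping over axes_dict, FacetGrid.set_titles() accepts a col_template, and '{col_name}' keeps only the month name; FacetGrid.set_axis_labels() labels the outer subplots in one call. The sketch below uses both on the answer's setup and leaves out the font scaling and manual label placement. It is a variation, not the answer's code, and it uses the vectorized size= form of np.random.uniform in place of the list comprehension.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

np.random.seed(167)
dates = pd.date_range('2017-04-03', periods=365 * 3, freq='D')
df = pd.DataFrame({"y": np.cumsum(np.random.uniform(-0.01, 0.01, size=365 * 3)),
                   "x": np.cumsum(np.random.uniform(-0.01, 0.01, size=365 * 3))},
                  index=dates)

df['year'] = df.index.year
month_names = pd.date_range('2017-01-01', periods=12, freq='MS').strftime('%B')
df['month'] = pd.Categorical.from_codes(df.index.month - 1, month_names)

g = sns.relplot(kind='scatter', data=df, x='x', y='y', hue='year',
                col='month', col_wrap=4, height=4, aspect=1)
# The template is filled in per panel; '{col_name}' drops the "month = " prefix
g.set_titles(col_template='{col_name}')
# Labels only the bottom row (x) and the left column (y)
g.set_axis_labels('x', 'y')
plt.show()
```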