Tags: python, pandas, matplotlib, seaborn, facet-grid

How to add a horizontal mean line and annotation to each facet


I have a simple FacetGrid with 2 rows and 1 column, where each facet is a line plot for a different Category (image below).

import matplotlib.pyplot as plt
import seaborn as sns

# line plot for each Category over the last three years
g = sns.FacetGrid(df, row="Category", sharey=False, sharex=False, height=2.5, aspect=3)
g = g.map(plt.plot, 'Date', 'Count')

[Image: the two-row FacetGrid of Count over Date, one line plot per Category]

How do I add a reference line and annotation showing the mean Count for each facet?

Sample Data

  • Build the sample DataFrame with:
import pandas as pd

data = {'Category': ['Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1', 'Group 1',
                     'Group 1', 'Group 1', 'Group 1', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2',
                     'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2', 'Group 2'],
        'Date': ['2017-01-31', '2017-02-28', '2017-03-31', '2017-04-30', '2017-05-31', '2017-06-30', '2017-07-31', '2017-08-31', '2017-09-30', '2017-10-31', '2017-11-30', '2017-12-31', '2018-01-31', '2018-02-28', '2018-03-31', '2018-04-30', '2018-05-31', '2018-06-30', '2018-07-31', '2018-08-31', '2018-09-30', '2018-10-31', '2018-11-30', '2018-12-31',
                 '2019-01-31', '2019-02-28', '2019-03-31', '2019-04-30', '2019-05-31', '2019-06-30', '2019-07-31', '2019-08-31', '2019-09-30', '2017-01-31', '2017-02-28', '2017-03-31', '2017-04-30', '2017-05-31', '2017-06-30', '2017-07-31', '2017-08-31', '2017-09-30', '2017-10-31', '2017-11-30', '2017-12-31', '2018-01-31', '2018-02-28', '2018-03-31',
                 '2018-04-30', '2018-05-31', '2018-06-30', '2018-07-31', '2018-08-31', '2018-09-30', '2018-10-31', '2018-11-30', '2018-12-31', '2019-01-31', '2019-02-28', '2019-03-31', '2019-04-30', '2019-05-31', '2019-06-30', '2019-07-31', '2019-08-31', '2019-09-30'],
        'Count': [226, 235, 236, 221, 187, 218, 225, 221, 248, 224, 204, 224, 218, 241, 196, 246, 256, 217, 229, 230, 222, 215, 226, 227, 232, 233, 224, 214, 243, 214, 235, 218, 208, 208, 254, 223, 227, 245, 222, 226, 235, 225, 226, 258, 234, 257, 224, 228, 222, 227, 256, 217, 243, 230, 250, 197, 232, 248, 232, 259, 259, 229, 228, 234, 218, 231]}

df = pd.DataFrame(data)
df.Date = pd.to_datetime(df.Date)

df.head()

  Category       Date  Count
0  Group 1 2017-01-31    226
1  Group 1 2017-02-28    235
2  Group 1 2017-03-31    236
3  Group 1 2017-04-30    221
4  Group 1 2017-05-31    187

Solution

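  • Use a figure-level function such as sns.relplot, then draw a line and an annotation on each facet's Axes:
    import seaborn as sns

    g = sns.relplot(data=df, kind='line', x='Date', y='Count', row='Category', height=2.5, aspect=3, facet_kws={'sharey': True, 'sharex': False})
    g.fig.tight_layout()

    # draw a mean line and a "Mean: N" label on each facet
    for m, ax in zip(df.groupby('Category').Count.mean(), g.axes.ravel()):
        ax.hlines(m, *ax.get_xlim())
        ax.annotate(f'Mean: {m:0.0f}', xy=(ax.get_xlim()[1], m))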
    

    [Image: relplot output, each facet with a horizontal line and a "Mean" label at its category mean]

    • The same loop works with other figure-level functions, such as sns.catplot:
    g = sns.catplot(data=df, kind='bar', x='Date', y='Count', row='Category', height=2.5, aspect=3)
    g.set_xticklabels(rotation=90)
    
    # draw lines:
    for m, ax in zip(df.groupby('Category').Count.mean(), g.axes.ravel()):
        ax.hlines(m, *ax.get_xlim())
        ax.annotate(f'Mean: {m:0.0f}', xy=(ax.get_xlim()[1], m))
    

    [Image: catplot bar-chart facets with the same mean lines and labels]


    You can manually draw the horizontal line on each of the axes:

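    zip(list1, list2) pairs up elements, yielding (list1[0], list2[0]), (list1[1], list2[1]), and so on. In this code, m is a category's mean and ax is the Axes of the corresponding facet. g.axes is a 2-D NumPy array of Axes (2 rows by 1 column here), and ravel() flattens it to 1-D so it can be zipped with the means. ax.hlines(y_val, x_min, x_max) draws a horizontal line at y_val from x_min to x_max; here the two x values are unpacked from ax.get_xlim(). Note that the pairing is positional: groupby sorts the categories, so the means only line up with the right facets if the facets appear in that same order, as they do for this data.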

    import matplotlib.pyplot as plt
    import seaborn as sns

    g = sns.FacetGrid(df, row="Category", sharey=False, sharex=False, height=2.5, aspect=3)
    g = g.map(plt.plot, 'Date', 'Count')

    # draw a horizontal line at each category's mean
    for m, ax in zip(df.groupby('Category').Count.mean(), g.axes.ravel()):
        ax.hlines(m, *ax.get_xlim())
    

    Output:

    [Image: the two-row FacetGrid with a horizontal line at each category's mean]