Tags: python, matplotlib, seaborn, bar-chart, catplot

How to annotate grouped bars in a FacetGrid with custom strings


My seaborn plot is generated by the code below. Is there a way to add the value from the flag column (which will always be a single character or an empty string) at the center (or top) of the corresponding bars? Ideally, the answer would not require redoing the plot.

This answer seems to have some pointers, but I am not sure how to connect each bar back to the original DataFrame to pull the value from the flag column.

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df = pd.DataFrame([
    ['C', 'G1', 'gbt',    'auc', 0.7999, "†"],
    ['C', 'G1', 'gbtv2',  'auc', 0.8199, "*"],
    ['C', 'G1', 'gbt',  'pr@2%', 0.0883, "*"],
    ['C', 'G1', 'gbt', 'pr@10%', 0.0430,  ""],
    ['C', 'G2', 'gbt',    'auc', 0.7554,  ""],
    ['C', 'G2', 'gbt',  'pr@2%', 0.0842,  ""],
    ['C', 'G2', 'gbt', 'pr@10%', 0.0572,  ""],
    ['C', 'G3', 'gbt',    'auc', 0.7442,  ""],
    ['C', 'G3', 'gbt',  'pr@2%', 0.0894,  ""],
    ['C', 'G3', 'gbt', 'pr@10%', 0.0736,  ""],
    ['E', 'G1', 'gbt',    'auc', 0.7988,  ""],
    ['E', 'G1', 'gbt',  'pr@2%', 0.0810,  ""],
    ['E', 'G1', 'gbt', 'pr@10%', 0.0354,  ""],
    ['E', 'G1', 'gbtv3','pr@10%',0.0454,  ""],
    ['E', 'G2', 'gbt',    'auc', 0.7296,  ""],
    ['E', 'G2', 'gbt',  'pr@2%', 0.1071,  ""],
    ['E', 'G2', 'gbt', 'pr@10%', 0.0528,  "†"],
    ['E', 'G3', 'gbt',    'auc', 0.6958,  ""],
    ['E', 'G3', 'gbt',  'pr@2%', 0.1007,  ""],
    ['E', 'G3', 'gbt', 'pr@10%', 0.0536,  "†"],
  ], columns=["src","grp","model","metric","val","flag"])

cat = sns.catplot(data=df, x="grp", y="val", hue="model", kind="bar",
                  sharey=False, col="metric", row="src")
plt.show()

Solution

    • The key is that, for each Axes and for each container within that Axes, the corresponding rows of the DataFrame must be selected. For example:
      • The first facet has src = C and metric = auc, and it consists of 3 containers, one for each unique value of 'model'.
    • The labels= parameter of .bar_label expects a list with one label per bar in the container. With seaborn 0.12, each container has one bar per x-axis tick, even where no bar is visible (its height is nan).
      • The list comprehension labels = [...] puts each flag at the correct index and fills missing labels with ''.
    • Tested in python 3.11.2, pandas 2.0.0, matplotlib 3.7.1, seaborn 0.12.2
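A minimal, plot-free sketch of the two steps described above, assuming seaborn's default facet title template ("{row_var} = {row_name} | {col_var} = {col_name}"): parse the title back into the facet values, then build one label per bar, using '' wherever the bar height is nan. The helper names parse_facet_title and align_labels are hypothetical, and the flags argument stands in for the rows that the answer filters out of df.

```python
import math


def parse_facet_title(title):
    """Parse a seaborn facet title such as 'src = C | metric = auc' into a dict.

    Assumes seaborn's default "{var} = {name}" pieces joined by ' | '.
    """
    parts = (part.split(" = ", 1) for part in title.split(" | "))
    return {var: name for var, name in parts}


def align_labels(heights, x_levels, flags):
    """Return one label per bar, with '' where the bar is missing (nan height).

    heights  -- bar heights in x-axis order (nan for missing bars)
    x_levels -- x-axis category levels, in the same order as the bars
    flags    -- hypothetical mapping of x level -> flag for this container
    """
    return [
        "" if math.isnan(h) else flags.get(level, "")
        for h, level in zip(heights, x_levels)
    ]


print(parse_facet_title("src = C | metric = auc"))  # {'src': 'C', 'metric': 'auc'}
print(align_labels([0.8199, float("nan"), float("nan")],
                   ["G1", "G2", "G3"], {"G1": "*"}))  # ['*', '', '']
```

The sample heights come from the "C auc gbtv2" container in the printed output, and the second call reproduces its label list.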
    import pandas as pd
    import seaborn as sns
    import numpy as np
    
    # plot the dataframe from the OP
    g = sns.catplot(data=df, x="grp", y="val", hue="model", kind="bar", sharey=False, col="metric", row="src")
    
    # get the unique values from the grp column, which corresponds to the x-axis tick labels
    grp_unique = df.grp.unique()
    
    # iterate through axes
    for ax in g.axes.flat:
        
        # get the components of the title to filter the current data
        src, metric = [s.split(' = ')[1] for s in ax.get_title().split(' | ')]
        
        # iterate through the containers of the current axes
        for c in ax.containers:
            
            # get the hue label of the current container
            model = c.get_label()
            
            # filter the corresponding data
            data = df.loc[df.src.eq(src) & df.metric.eq(metric) & df.model.eq(model)]
            
            # if the DataFrame, data, isn't empty (i.e. there are bars for the current model)
            if not data.empty:
                
                # for each grp on the x-axis, map it to (bar height, flag); the height is nan where no bar exists
                # this only shows the intermediate data and labels - it can be removed
                gh = {grp: (h := v.get_height(), data.loc[data.grp.eq(grp), 'flag'].tolist()[0] if not np.isnan(h) else '') for v, grp in zip(c, grp_unique)}
    
                # custom labels from the flag column
                labels = [data.loc[data.grp.eq(grp), 'flag'].tolist()[0] if not np.isnan(v.get_height()) else '' for v, grp in zip(c, grp_unique)]
    
                # shows the different data being used - can be removed
                print(src, metric, model)
                display(data)  # display() is available in Jupyter/IPython; use print(data) elsewhere
                print(gh)
                print(labels)
                print('\n')
                
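                # add the labels at the top of the bars; label_type='center' places them in the middle instead
                ax.bar_label(c, labels=labels, label_type='edge')
        # add headroom above the bars so the labels fit inside the axes
        ax.margins(y=0.2)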
    


    Printed Output

    C auc gbt
    |    | src   | grp   | model   | metric   |    val | flag   |
    |---:|:------|:------|:--------|:---------|-------:|:-------|
    |  0 | C     | G1    | gbt     | auc      | 0.7999 | †      |
    |  4 | C     | G2    | gbt     | auc      | 0.7554 |        |
    |  7 | C     | G3    | gbt     | auc      | 0.7442 |        |
    {'G1': (0.7999, '†'), 'G2': (0.7554, ''), 'G3': (0.7442, '')}
    ['†', '', '']
    
    
    C auc gbtv2
    |    | src   | grp   | model   | metric   |    val | flag   |
    |---:|:------|:------|:--------|:---------|-------:|:-------|
    |  1 | C     | G1    | gbtv2   | auc      | 0.8199 | *      |
    {'G1': (0.8199, '*'), 'G2': (nan, ''), 'G3': (nan, '')}
    ['*', '', '']
    
    
    C pr@2% gbt
    |    | src   | grp   | model   | metric   |    val | flag   |
    |---:|:------|:------|:--------|:---------|-------:|:-------|
    |  2 | C     | G1    | gbt     | pr@2%    | 0.0883 | *      |
    |  5 | C     | G2    | gbt     | pr@2%    | 0.0842 |        |
    |  8 | C     | G3    | gbt     | pr@2%    | 0.0894 |        |
    {'G1': (0.0883, '*'), 'G2': (0.0842, ''), 'G3': (0.0894, '')}
    ['*', '', '']
    
    
    C pr@10% gbt
    |    | src   | grp   | model   | metric   |    val | flag   |
    |---:|:------|:------|:--------|:---------|-------:|:-------|
    |  3 | C     | G1    | gbt     | pr@10%   | 0.043  |        |
    |  6 | C     | G2    | gbt     | pr@10%   | 0.0572 |        |
    |  9 | C     | G3    | gbt     | pr@10%   | 0.0736 |        |
    {'G1': (0.043, ''), 'G2': (0.0572, ''), 'G3': (0.0736, '')}
    ['', '', '']
    
    
    E auc gbt
    |    | src   | grp   | model   | metric   |    val | flag   |
    |---:|:------|:------|:--------|:---------|-------:|:-------|
    | 10 | E     | G1    | gbt     | auc      | 0.7988 |        |
    | 14 | E     | G2    | gbt     | auc      | 0.7296 |        |
    | 17 | E     | G3    | gbt     | auc      | 0.6958 |        |
    {'G1': (0.7988, ''), 'G2': (0.7296, ''), 'G3': (0.6958, '')}
    ['', '', '']
    
    
    E pr@2% gbt
    |    | src   | grp   | model   | metric   |    val | flag   |
    |---:|:------|:------|:--------|:---------|-------:|:-------|
    | 11 | E     | G1    | gbt     | pr@2%    | 0.081  |        |
    | 15 | E     | G2    | gbt     | pr@2%    | 0.1071 |        |
    | 18 | E     | G3    | gbt     | pr@2%    | 0.1007 |        |
    {'G1': (0.081, ''), 'G2': (0.1071, ''), 'G3': (0.1007, '')}
    ['', '', '']
    
    
    E pr@10% gbt
    |    | src   | grp   | model   | metric   |    val | flag   |
    |---:|:------|:------|:--------|:---------|-------:|:-------|
    | 12 | E     | G1    | gbt     | pr@10%   | 0.0354 |        |
    | 16 | E     | G2    | gbt     | pr@10%   | 0.0528 | †      |
    | 19 | E     | G3    | gbt     | pr@10%   | 0.0536 | †      |
    {'G1': (0.0354, ''), 'G2': (0.0528, '†'), 'G3': (0.0536, '†')}
    ['', '†', '†']
    
    
    E pr@10% gbtv3
    |    | src   | grp   | model   | metric   |    val | flag   |
    |---:|:------|:------|:--------|:---------|-------:|:-------|
    | 13 | E     | G1    | gbtv3   | pr@10%   | 0.0454 |        |
    {'G1': (0.0454, ''), 'G2': (nan, ''), 'G3': (nan, '')}
    ['', '', '']
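The answer filters df with a boolean mask for every container. A minimal alternative sketch, assuming each (src, metric, model, grp) combination appears at most once: build a dictionary lookup once with set_index(...).to_dict(), and let missing combinations fall back to ''. This also removes the need for the nan check, because a missing bar has no matching row. The helper container_labels is hypothetical, and the DataFrame below is only a subset of the OP's rows.

```python
import pandas as pd

# a subset of the OP's DataFrame, enough to show the lookup
df = pd.DataFrame([
    ['C', 'G1', 'gbt',   'auc', 0.7999, "†"],
    ['C', 'G1', 'gbtv2', 'auc', 0.8199, "*"],
    ['C', 'G2', 'gbt',   'auc', 0.7554, ""],
    ['C', 'G3', 'gbt',   'auc', 0.7442, ""],
], columns=["src", "grp", "model", "metric", "val", "flag"])

# build the lookup once: (src, metric, model, grp) -> flag
# assumes the combination is unique; duplicates would silently keep the last row
flag_lookup = df.set_index(["src", "metric", "model", "grp"])["flag"].to_dict()


def container_labels(lookup, src, metric, model, x_levels):
    """One label per x-axis level; combinations with no row (no bar) get ''."""
    return [lookup.get((src, metric, model, grp), "") for grp in x_levels]


grp_unique = df.grp.unique()
print(container_labels(flag_lookup, "C", "auc", "gbtv2", grp_unique))  # ['*', '', '']
print(container_labels(flag_lookup, "C", "auc", "gbt", grp_unique))    # ['†', '', '']
```

Inside the answer's loop, labels = container_labels(flag_lookup, src, metric, c.get_label(), grp_unique) could stand in for both the filtering and the list comprehension. Similarly, because the grid has both row and col facets, iterating over g.axes_dict.items() yields ((src, metric), ax) pairs, which could replace parsing the title.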