
How to apply different kinds of error regions to some subplots of a facet grid


I have a Seaborn facet grid with 4 rows and 2 columns. I want to map line plots onto the grid from a melted dataframe (with 4 variables and 2 categories) and show error bars in the top three rows but not in the last one (that row is based on boolean data, so error bars are not appropriate).

Using an example dataset:

import seaborn as sns

fmri = sns.load_dataset('fmri')
gg = sns.FacetGrid(data=fmri, col="region", row="event")
gg.map(sns.lineplot, "timepoint", "signal", errorbar="sd")

(Figure: basic facet plot)

Is there any way to remove the error bars from the bottom row (or not plot them there in the first place)?


Solution

  • Removing the error region

    You could remove the error regions afterward by looping through the axes:

    import seaborn as sns
    
    fmri = sns.load_dataset('fmri')
    gg = sns.FacetGrid(data=fmri, col="region", row="event")
    gg.map(sns.lineplot, "timepoint", "signal", errorbar="sd")
    
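    # with both row= and col= facets, axes_dict keys are (row_name, col_name) tuples
    for (row, col), ax in gg.axes_dict.items():
        if row == 'cue':  # 'cue' is the bottom row of this grid
            # the error band drawn by lineplot is the first (here the only) collection
            ax.collections[0].remove()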
    

    (Figure: seaborn lineplot mapped with and without the error region)

  • Custom plot function

    Alternatively, you could map a custom function onto the facet grid. Inside that function, a test can check whether the y-values passed in for that subplot look like booleans. Since seaborn converts all values to the same type (floats), counting the unique y-values is one way to tell: boolean-like data has at most two distinct values.
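The unique-value test can be tried on its own before wiring it into the grid. `looks_boolean` is a hypothetical helper name used only for this sketch; it counts distinct values with pandas' `Series.nunique()`, so a continuous column that happens to hold only two values would also be classified as boolean-like.

```python
import pandas as pd


def looks_boolean(y: pd.Series) -> bool:
    """Return True when the series holds at most two distinct values.

    Heuristic only: once booleans are stored as floats (0.0 / 1.0),
    a column with two or fewer unique values is treated as boolean-like.
    """
    return y.nunique() <= 2


signal = pd.Series([0.12, -0.03, 0.45, 0.08])               # continuous values
flag = pd.Series([True, False, True, True]).astype(float)   # stored as 1.0 / 0.0

print(looks_boolean(signal))  # False
print(looks_boolean(flag))    # True
```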

    The custom function receives three arguments:

    • a pandas Series with the x-values (the dataframe column, restricted to the rows of that subplot)
    • a similar Series with the corresponding y-values
    • a color (passed as the color keyword)

    import seaborn as sns
    
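    def custom_lineplot(x, y, color):
        # more than two distinct y-values: continuous data, draw the 'sd' band;
        # otherwise treat the data as boolean and draw no error region
        errorbar = 'sd' if len(y.unique()) > 2 else None
        sns.lineplot(x=x, y=y, color=color, errorbar=errorbar)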
    
    fmri = sns.load_dataset('fmri')
    gg = sns.FacetGrid(data=fmri, col="region", row="event")
    gg.map(custom_lineplot, "timepoint", "signal")