Tags: python, matplotlib, seaborn

How to prevent pyplot.errorbar from shifting x-axis of seaborn barplot


I want to plot data using a Seaborn barplot, but I only have the mean and standard deviation. I use pyplot.errorbar to add error bars to my plot; however, it shifts my x-axis slightly (see the red star in the plot below). How do I prevent this from happening?

[Plots: the barplot without and with error bars; a red star marks the shifted x-axis]

Code to reproduce:

import seaborn as sns
import matplotlib.pyplot as plt

### loading example data ###
health = sns.load_dataset('healthexp')

# one row per country with the mean and standard deviation of life expectancy
health_summary = health.groupby(['Country']).Life_Expectancy.agg(['mean', 'std']).reset_index()


### barplot without errorbars ###
p = sns.barplot(health_summary, x='Country', y='mean', errorbar=None)

plt.show()


### barplot with errorbars ###
p = sns.barplot(health_summary, x='Country', y='mean', errorbar=None)

p.errorbar(x=health_summary['Country'], y=health_summary['mean'], yerr=health_summary['std'], fmt="none", c="k")

plt.show()

Solution

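  • Calling errorbar makes matplotlib autoscale the x-axis again, which replaces the x-limits seaborn set for the categorical bars. Save the x-limits with ax.get_xlim() before calling errorbar and restore them with ax.set_xlim() afterwards.

    A small context manager makes this reusable: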

    from contextlib import contextmanager
    
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    
    @contextmanager
    def fixed_xlim(ax):
        """Restore the x-limits of `ax` after the `with` block, even on error."""
        xlims = ax.get_xlim()
        try:
            yield
        finally:
            # set_xlim also switches x autoscaling off (auto=False by default)
            ax.set_xlim(xlims)
    
    
    health = sns.load_dataset("healthexp")
    health_summary = (
        health.groupby(["Country"]).Life_Expectancy.agg(["mean", "std"]).reset_index()
    )
    
    # top: plain barplot; bottom: the same barplot plus error bars, for comparison
    fig, (ax1, ax2) = plt.subplots(nrows=2)
    sns.barplot(health_summary, x="Country", y="mean", errorbar=None, ax=ax1)
    sns.barplot(health_summary, x="Country", y="mean", errorbar=None, ax=ax2)
    
    with fixed_xlim(ax2):
        ax2.errorbar(
            x=health_summary["Country"],
            y=health_summary["mean"],
            yerr=health_summary["std"],
            fmt="none",
            c="k",
        )
    
    plt.tight_layout()
    plt.show()
    

    [Result: the two barplots; the lower one, with error bars, keeps the same x-axis as the upper one]