
How to add a filled confidence interval band using plotly express?


I am using plotly express to add a trendline. How can I also plot a confidence interval around it, like seaborn's regplot() does?


import plotly.express as px

df = px.data.tips()
fig = px.scatter(df, x="total_bill", y="tip", trendline="ols")  # trendline="ols" requires statsmodels
fig.show()



Solution

  • Plotly Express's trendline option fits the line but does not draw a confidence band around it. However, since you want something like seaborn's regplot, you can call regplot directly, extract the arrays holding the upper and lower bounds of its confidence intervals, and then draw the band in plotly.

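    The only catch is that, to extract the confidence intervals from regplot, you need to bin the x variable with its x_bins parameter, as explained in this answer. With x_bins, regplot estimates the mean tip and its 95% confidence interval within each bin and draws each interval as a vertical line, so the band built from those lines approximates, but is not identical to, the band regplot normally shades around its regression line. For the tips dataset, I used a bin width of 5 so that each bin contains enough points for a confidence interval to be computed; with a bin width of 1, sparser regions of the data get no interval at all.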
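To see why the bin width matters, here is a rough numpy-only sketch of the per-bin estimate regplot computes when x_bins is given: each point is assigned to its nearest bin center, and the mean of y in each bin gets a percentile-bootstrap confidence interval. The helper name binned_mean_ci, its n_boot and seed defaults, and the synthetic data are illustrative assumptions; seaborn's own implementation differs in details.

```python
import numpy as np


def binned_mean_ci(x, y, centers, ci=95, n_boot=1000, seed=0):
    """Hypothetical helper approximating seaborn's x_bins estimate.

    Returns (center, mean, ci_low, ci_high, n_points) for every bin
    that received at least one point.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    centers = np.ravel(centers).astype(float)
    # Assign each x value to its nearest bin center
    nearest = np.argmin(np.abs(np.subtract.outer(x, centers)), axis=1)
    rng = np.random.default_rng(seed)
    tail = (100 - ci) / 2
    rows = []
    for i, c in enumerate(centers):
        y_bin = y[nearest == i]
        if y_bin.size == 0:
            continue  # an empty bin produces no estimate at all
        # Resample the bin with replacement and record each resample's mean
        boots = rng.choice(y_bin, size=(n_boot, y_bin.size), replace=True).mean(axis=1)
        lo, hi = np.percentile(boots, [tail, 100 - tail])
        rows.append((c, y_bin.mean(), lo, hi, y_bin.size))
    return rows


# Synthetic data standing in for total_bill / tip (assumption)
rng = np.random.default_rng(1)
x_demo = rng.uniform(3, 50, 200)
y_demo = 0.1 * x_demo + rng.normal(0, 1, 200)
for c, m, lo, hi, n in binned_mean_ci(x_demo, y_demo, np.arange(x_demo.min(), x_demo.max(), 5)):
    print(f"x={c:5.1f}  n={n:3d}  mean={m:.2f}  95% CI=({lo:.2f}, {hi:.2f})")
```

A bin holding a single point yields a zero-width interval and a bin with no points yields nothing, which is why a small bin width leaves gaps in the band over sparse regions.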

    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    import plotly.express as px
    import plotly.graph_objects as go
    
    df = px.data.tips()
    fig = px.scatter(df, x="total_bill", y="tip", trendline="ols")
    
    ## bin x so that regplot draws one confidence-interval line per bin
    binwidth = 5
    x_min, x_max = df["total_bill"].min(), df["total_bill"].max()
    x_bins = np.arange(x_min, x_max, binwidth)  # bin centers
    ax = sns.regplot(x="total_bill", y="tip", x_bins=x_bins, data=df, x_ci=95, fit_reg=False)
    
    ## each line in ax.lines is a vertical CI segment at one bin center
    x = [line.get_xdata().min() for line in ax.lines]
    y_lower = [line.get_ydata().min() for line in ax.lines]
    y_upper = [line.get_ydata().max() for line in ax.lines]
    plt.close()  # the matplotlib figure is only needed to compute the intervals
    
    fig.add_trace(go.Scatter(
        x=x+x[::-1], # x, then x reversed
        y=y_upper+y_lower[::-1], # upper, then lower reversed
        fill='toself',
        fillcolor='rgba(0,100,80,0.2)',
        line=dict(color='rgba(255,255,255,0)'),
        hoverinfo="skip",
        showlegend=False
    ))
    
    fig.show()
    

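If you would rather not route through matplotlib, a different technique from the one above is to compute the pointwise confidence band of the OLS line yourself with the standard formula y_hat ± z * s * sqrt(1/n + (x0 - x_mean)^2 / Sxx), where s is the residual standard error and Sxx the sum of squared deviations of x. This is a minimal sketch: the helper name ols_confidence_band and the normal approximation z = 1.96 in place of the t quantile are assumptions, and because regplot bootstraps its band, the two will be close but not identical.

```python
import numpy as np


def ols_confidence_band(x, y, x_grid, z=1.96):
    """Hypothetical helper: pointwise CI for the mean of an OLS line y = a + b*x.

    z=1.96 is a normal approximation to the two-sided 95% t quantile.
    Returns (fitted values, lower bound, upper bound) on x_grid.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_grid = np.asarray(x_grid, dtype=float)
    n = x.size
    b, a = np.polyfit(x, y, 1)  # polyfit returns [slope, intercept]
    resid = y - (a + b * x)
    s = np.sqrt(resid @ resid / (n - 2))  # residual standard error
    sxx = np.sum((x - x.mean()) ** 2)
    y_hat = a + b * x_grid
    # Standard error of the fitted mean grows with distance from mean(x)
    se_fit = s * np.sqrt(1.0 / n + (x_grid - x.mean()) ** 2 / sxx)
    return y_hat, y_hat - z * se_fit, y_hat + z * se_fit


# Synthetic data standing in for total_bill / tip (assumption)
rng = np.random.default_rng(0)
x_demo = rng.uniform(3, 50, 244)
y_demo = 0.1 * x_demo + 1 + rng.normal(0, 1, 244)
x_grid = np.linspace(x_demo.min(), x_demo.max(), 50)
y_hat, lower, upper = ols_confidence_band(x_demo, y_demo, x_grid)
```

For the tips data, pass df["total_bill"] and df["tip"] as x and y, build x_grid with np.linspace over the bill range, and feed list(x_grid) + list(x_grid[::-1]) and list(upper) + list(lower[::-1]) to the same go.Scatter(fill='toself') trace used in the answer. If statsmodels is installed (plotly's trendline="ols" already requires it), OLS(...).fit().get_prediction(...).summary_frame(alpha=0.05) gives the exact t-based band in its mean_ci_lower and mean_ci_upper columns.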