Tags: python, matplotlib, plot, statistics, statsmodels

How to plot regression results using statsmodels with single categorical (3 levels) independent variable?


I have a numerical dependent variable Y and one categorical independent variable X with 3 levels (x1, x2 and x3).

Y corresponds to the measurement of a sensor and X to three measurement conditions. Let's say I measured luminance (Y) under 3 different conditions (X: x1, x2 and x3).

I'm using the statsmodels Python library to perform a regression (how the measurement conditions affect luminance):

res = smf.ols(formula='Y ~ C(X)', data=df_cont).fit()
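To see what this formula actually fits, you can inspect the columns of the design matrix it builds. Below is a minimal sketch on simulated data. The df_cont frame, the seed and the level means are made-up stand-ins for the real measurements.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

# Hypothetical stand-in for df_cont: 3 conditions, 20 luminance readings each
rng = np.random.default_rng(0)
df_cont = pd.DataFrame({
    "X": np.repeat(["x1", "x2", "x3"], 20),
    "Y": rng.normal(np.repeat([0.0, 1.5, 2.5], 20), 1.0),
})

res = smf.ols(formula="Y ~ C(X)", data=df_cont).fit()

# Design-matrix columns generated by the formula: an intercept plus one
# 0/1 indicator per non-reference level (x1 is the reference level)
print(res.model.exog_names)
print(res.model.exog[:3])  # first rows all belong to x1: [1, 0, 0]
```

With a single categorical predictor the model contains no slope along X. It has an intercept plus two indicator columns.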

Now I need to plot the regression results (linear fit) and the "raw" data on the same plot. The plot I have in mind is something like this mock example:

[Image: mock example plot]

I've tried the statsmodels plot_fit and abline_plot functions but have not managed to make them work. I've tried to follow this question, but I'm still not able to do it.

Any idea on how to accomplish this would be very welcome!


Solution

  • When you fit a linear model like this, you are estimating a mean for every category; it is not a slope-and-intercept line fitted through all your data points. For example:

    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    import statsmodels.api as sm
    import numpy as np
    import statsmodels.formula.api as smf
    
    df = pd.DataFrame({'Y':np.random.normal(np.repeat([0,1.5,2.5],20),1,60),
                      'X':np.repeat(['x1','x2','x3'],20)})
    
    df['X'] = pd.Categorical(df['X'],categories=['x1','x2','x3'])
    
    res = smf.ols(formula= "Y ~ X",data=df).fit()
    res.summary()
    
                    coef  std err        t    P>|t|   [0.025   0.975]
    Intercept    -0.0418    0.233   -0.180    0.858   -0.508    0.424
    X[T.x2]       1.3507    0.329    4.102    0.000    0.691    2.010
    X[T.x3]       2.5947    0.329    7.880    0.000    1.935    3.254
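With the default treatment coding used by "Y ~ X", the coefficients are not three separate means. Intercept is the mean of the reference level x1, and each X[T.…] term is the difference from x1. A quick check is sketched below on seeded simulated data; the seed and the level means are arbitrary choices for this sketch.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

levels = ["x1", "x2", "x3"]
rng = np.random.default_rng(1)
df = pd.DataFrame({
    "Y": rng.normal(np.repeat([0.0, 1.5, 2.5], 20), 1.0),
    "X": pd.Categorical(np.repeat(levels, 20), categories=levels),
})
res = smf.ols("Y ~ X", data=df).fit()

# Sample mean of Y within each level
group_means = df.groupby("X", observed=True)["Y"].mean()

# Rebuild the level means from the treatment-coded coefficients
params = res.params
from_params = pd.Series(
    [params["Intercept"],
     params["Intercept"] + params["X[T.x2]"],
     params["Intercept"] + params["X[T.x3]"]],
    index=levels,
)
print(pd.DataFrame({"groupby_mean": group_means.to_numpy(),
                    "from_params": from_params.to_numpy()}, index=levels))
```

The two columns agree, so plotting res.params directly as "means" is only right for the intercept.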
    

    To plot these results, you can do:

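    fig, ax = plt.subplots()
    sns.scatterplot(data=df, x="X", y="Y", ax=ax)
    # res.params are the intercept (mean of x1) and the differences x2-x1, x3-x1,
    # so ask the model for the fitted mean and its 95% CI at each level instead
    pred = res.get_prediction(pd.DataFrame({'X': ['x1', 'x2', 'x3']})).summary_frame(alpha=0.05)
    xpos = np.arange(len(pred)) + 0.1  # shift slightly right of the raw points
    ax.scatter(x=xpos, y=pred['mean'], color="#FE9898")
    ax.vlines(x=xpos,
              ymin=pred['mean_ci_lower'],
              ymax=pred['mean_ci_upper'],
              color="#FE9898")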
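Another option is to fit the model without an intercept ("Y ~ 0 + X"). Then every coefficient is a level mean, so params and conf_int() can be plotted directly as means with their confidence intervals. The sketch below does this on freshly simulated data. It uses plain matplotlib with a small horizontal jitter instead of seaborn, and the seed, colours and offsets are illustrative choices.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.formula.api as smf

levels = ["x1", "x2", "x3"]
rng = np.random.default_rng(2)
df = pd.DataFrame({
    "Y": rng.normal(np.repeat([0.0, 1.5, 2.5], 20), 1.0),
    "X": pd.Categorical(np.repeat(levels, 20), categories=levels),
})

# "0 + X" drops the intercept, so each level gets its own indicator
# and each coefficient is that level's mean
res_means = smf.ols("Y ~ 0 + X", data=df).fit()
means = res_means.params        # index: X[x1], X[x2], X[x3]
ci = res_means.conf_int()       # 95% CI per mean, columns 0 (lower) and 1 (upper)

fig, ax = plt.subplots()
# Raw data at positions 0, 1, 2 with a little horizontal jitter
codes = df["X"].cat.codes.to_numpy()
ax.scatter(codes + rng.uniform(-0.05, 0.05, len(df)), df["Y"], alpha=0.6, label="raw data")
# Estimated means with asymmetric error bars built from the CI bounds
pos = np.arange(len(levels)) + 0.15
m = means.to_numpy()
ax.errorbar(pos, m, yerr=[m - ci[0].to_numpy(), ci[1].to_numpy() - m],
            fmt="o", color="#FE9898", capsize=4, label="estimated mean (95% CI)")
ax.set_xticks(range(len(levels)))
ax.set_xticklabels(levels)
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.legend()
```

This is the same model as "Y ~ X", only parameterised differently, so the fitted means are identical. The coefficient t-tests now test "mean = 0" rather than "difference from x1 = 0".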
    


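    If you really want a line through the categories, bear in mind that it does not come from the regression above. regplot fits a separate straight line of Y on the integer category codes (0, 1, 2), which treats the three conditions as equally spaced numeric values: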

    sns.regplot(x = df['X'].cat.codes,y = df['Y'],ax=ax,scatter=False,color="#628395")
    fig
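The difference between the two fits can be checked numerically. The sketch below uses seeded simulated data with deliberately unevenly spaced level means, and all names are illustrative. It fits a straight line of Y on the category codes, the kind of simple linear regression regplot draws, next to the categorical model. It then compares their fitted values at each level.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

levels = ["x1", "x2", "x3"]
rng = np.random.default_rng(3)
df = pd.DataFrame({
    "Y": rng.normal(np.repeat([0.0, 2.0, 2.5], 20), 1.0),
    "X": pd.Categorical(np.repeat(levels, 20), categories=levels),
})
# Integer codes 0, 1, 2: the numeric x a line through the levels relies on
df["code"] = df["X"].cat.codes

# Straight-line fit on the codes (a simple linear regression)
line = smf.ols("Y ~ code", data=df).fit()
# Categorical model: one fitted mean per level
cat = smf.ols("Y ~ C(X)", data=df).fit()

grid = pd.DataFrame({"code": [0, 1, 2], "X": levels})
print(pd.DataFrame({
    "line_fit": line.predict(grid).to_numpy(),
    "category_means": cat.predict(grid).to_numpy(),
}, index=levels))
```

The categorical model reproduces the level means exactly. The line can only follow them when the means happen to be equally spaced.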
    
