
How can I plot the results of Logit in statsmodels using matplotlib?


In this data set I have two categorical response values (0 and 1) and I want to fit the Logit model using statsmodels.

    X_incl_const = sm.add_constant(X)
    model = sm.Logit(y, X_incl_const)
    results = model.fit()
    results.summary()

When I try to plot the points and the fitted line using the code below:

    plt.scatter(X, y)
    plt.plot(X, model.predict(X))

I get the following error:

    ValueError                                Traceback (most recent call last)
    <ipython-input-16-d69741b1f0ad> in <module>
          1 plt.scatter(X, y)
    ----> 2 plt.plot(X, model.predict(X))
    
    ~\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py in predict(self, params, exog, linear)
        461             exog = self.exog
        462         if not linear:
    --> 463             return self.cdf(np.dot(exog, params))
        464         else:
        465             return np.dot(exog, params)
    
    <__array_function__ internals> in dot(*args, **kwargs)
    
    ValueError: shapes (518,2) and (518,) not aligned: 2 (dim 1) != 518 (dim 0)

How can I plot the line predicted by this model?


Solution

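  • The first argument of predict on the model object is the parameter vector, not the data (the traceback shows `predict(self, params, exog, linear)`). So `model.predict(X)` treats your X as the coefficients and multiplies the stored (518, 2) design matrix by a 518-element vector, which is exactly the shape error you see. Use the fitted results object `results` instead of `model`: its predict takes the new data, and that data must have the same columns (predictors) that were used in the fit, including the constant. Using an example dataset: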

    import matplotlib.pyplot as plt
    import pandas as pd
    import statsmodels.api as sm
    from sklearn.datasets import load_breast_cancer
    
    dat = load_breast_cancer()
    df = pd.DataFrame(dat.data,columns=dat.feature_names)
    df['target'] = dat.target
    X = df['mean radius']
    y = df['target']
    
    X_incl_const = sm.add_constant(X)
    model = sm.Logit(y, X_incl_const)
    results = model.fit()
    results.summary()
    

    The fit works. Calling predict on the model object, however, gives the same error you saw:

    model.predict(X)
    
    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    <ipython-input-180-2558e7096c7c> in <module>
    ----> 1 model.predict(X)
          2 
          3 
    
    ~/anaconda2/lib/python3.7/site-packages/statsmodels/discrete/discrete_model.py in predict(self, params, exog, linear)
        482             exog = self.exog
        483         if not linear:
    --> 484             return self.cdf(np.dot(exog, params))
        485         else:
        486             return np.dot(exog, params)
    
    <__array_function__ internals> in dot(*args, **kwargs)
    
    ValueError: shapes (569,2) and (569,) not aligned: 2 (dim 1) != 569 (dim 0)
    

    Calling predict on the fitted results, with the constant column added to X, works:

    plt.scatter(X,results.predict(sm.add_constant(X)))
    

    Or, to plot only the in-sample fitted values, call predict with no arguments:

    plt.scatter(X,results.predict())