Tags: data-visualization, linear-regression, constants, statsmodels, traceback

Simple logistic regression with Statsmodels: Adding an intercept and visualizing the logistic regression equation


Using Statsmodels, I am trying to fit a simple logistic regression model that predicts whether a person smokes (Smoke) from their height (Hgt).

I suspect the logistic regression model needs an intercept, but I am not sure how to add one with the add_constant() function. I also do not understand why the error below is raised.

This is the dataset, Pulse.CSV: https://drive.google.com/file/d/1FdUK9p4Dub4NXsc-zHrYI-AGEEBkX98V/view?usp=sharing

The full code and output are in this PDF file: https://drive.google.com/file/d/1kHlrAjiU7QvFXF2a7tlTSFPgfpq9bOXJ/view?usp=sharing

import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
raw_data = pd.read_csv('Pulse.csv')
raw_data
x1 = raw_data['Hgt']
y = raw_data['Smoke'] 
reg_log = sm.Logit(y,x1,missing='Drop')
results_log = reg_log.fit()
def f(x,b0,b1):
    return np.array(np.exp(b0+x*b1) / (1 + np.exp(b0+x*b1)))
f_sorted = np.sort(f(x1,results_log.params[0],results_log.params[1]))
x_sorted = np.sort(np.array(x1))
plt.scatter(x1,y,color='C0')
plt.xlabel('Hgt', fontsize = 20)
plt.ylabel('Smoked', fontsize = 20)
plt.plot(x_sorted,f_sorted,color='C8')
plt.show()

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/opt/anaconda3/lib/python3.7/site-packages/pandas/core/indexes/base.py in get_value(self, series, key)
   4729         try:
-> 4730             return self._engine.get_value(s, k, tz=getattr(series.dtype, "tz", None))
   4731         except KeyError as e1:
((( Truncated for brevity )))
IndexError: index out of bounds

Solution

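  • Statsmodels does not add an intercept by default when you pass your data directly to a model such as sm.Logit (the formula interface in statsmodels.formula.api does add one). If you need an intercept, add a constant column yourself with sm.add_constant().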

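    import numpy as np
    import pandas as pd
    import statsmodels.api as sm
    import matplotlib.pyplot as plt
    import seaborn as sns
    sns.set()
    raw_data = pd.read_csv('Pulse.csv')
    raw_data
    x1 = raw_data['Hgt']
    y = raw_data['Smoke']
    
    # add_constant prepends a column of ones named 'const' that carries the intercept
    x1 = sm.add_constant(x1)
    
    # the documented values for missing are lowercase: 'none', 'drop', 'raise'
    reg_log = sm.Logit(y,x1,missing='drop')
    results_log = reg_log.fit()
    
    results_log.summary()
    
    # logistic function: P(Smoke = 1 | Hgt = x) for intercept b0 and slope b1
    def f(x,b0,b1):
        return np.array(np.exp(b0+x*b1) / (1 + np.exp(b0+x*b1)))
    
    # params is indexed by column name, so look the coefficients up by label
    b0 = results_log.params['const']
    b1 = results_log.params['Hgt']
    
    # sort the heights first, then evaluate the curve on them; use only the
    # 'Hgt' column, not the constant column added above
    x_sorted = np.sort(np.array(x1['Hgt']))
    f_sorted = f(x_sorted,b0,b1)
    
    plt.scatter(x1['Hgt'],y,color='C0')
    
    plt.xlabel('Hgt', fontsize = 20)
    plt.ylabel('Smoked', fontsize = 20)
    plt.plot(x_sorted,f_sorted,color='C8')
    plt.show()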
    

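    This also resolves the error. Without the intercept the model has a single coefficient (Hgt), so results_log.params contains only one entry and results_log.params[1] fails with the KeyError/IndexError shown in the traceback.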