Tags: python, scikit-learn, linear-regression, statsmodels

How to make predictions on test data with statsmodels' OLS, as with sklearn's LinearRegression


I was running a linear regression with statsmodels.api and wanted to do the same things I can do with sklearn. However, I can't find a way to apply the fitted model to the test data and compute the R-squared and other metrics.

This is what I get with sklearn, but I can't find a way to replicate it with statsmodels:

# import libraries
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.datasets import make_regression

# Create sample
X_R1, y_R1 = make_regression(n_samples=100, n_features=1, n_informative=1, bias=150.0, noise=30, random_state=0)

# split train / test
X_train, X_test, y_train, y_test = train_test_split(X_R1, y_R1, random_state=1)

# Fit the model
linreg = LinearRegression().fit(X_train, y_train)

# Print the desired information
print('linear model coeff (w): {}'.format(linreg.coef_))
print('linear model intercept (b): {:.3f}'.format(linreg.intercept_))
print('R-squared score (training): {:.3f}'.format(linreg.score(X_train, y_train)))
print('R-squared score (test): {:.3f}'.format(linreg.score(X_test, y_test)))

The output (originally posted as a screenshot) is the printed coefficient, intercept, and training and test R-squared scores.

Now the same fit using statsmodels:

import statsmodels.api as sm

# Add a column of ones so that OLS estimates an intercept
X2 = sm.add_constant(X_train)
est = sm.OLS(y_train, X2)   # the (unfitted) OLS model
est2 = est.fit()            # the fitted results object
print(est2.summary())

The output of the second script is more complete, so I would like to use it, but I still need to apply the model to the test data.


Solution

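  • It's easy. You just need the predict method of the fitted results object returned by OLS(...).fit(), applied to test data that has the same constant column added.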

    Use this:

    import statsmodels.api as sm
    
    X2 = sm.add_constant(X_train)
    est = sm.OLS(y_train, X2).fit() # fit() returns an OLSResults object, not the OLS model
    
    X_test2 = sm.add_constant(X_test) # add the constant to the test data as well
    y_test_predicted = est.predict(X_test2) # use the predict method of the results object
    

    The OLS model class and its methods are documented here: https://www.statsmodels.org/stable/generated/statsmodels.regression.linear_model.OLS.html#statsmodels.regression.linear_model.OLS