
How to plot linear regression with Seaborn based on a prediction of a target variable?


I'm learning the very basics of data science and started with regression analysis, so I decided to build a linear regression model to examine the linear relationship between two variables (chemical_1 and chemical_2) from this dataset.

I made chemical_1 the predictor (independent variable) and chemical_2 the target (dependent variable), then used scipy.stats.linregress to calculate a regression line.

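from scipy import stats

# df is the dataset loaded as a pandas DataFrame with columns 'chemical_1' and 'chemical_2'
X = df['chemical_1']  # predictor
Y = df['chemical_2']  # target

slope, intercept, r_value, p_value, slope_std_error = stats.linregress(X, Y)
predict_y = slope * X + intercept  # fitted values on the regression line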
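A minimal self-contained sketch of the step above, assuming a small hypothetical DataFrame in place of the real dataset (the values are generated for illustration only). It also cross-checks the linregress slope and intercept against numpy.polyfit, which fits the same least-squares line:

```python
import numpy as np
import pandas as pd
from scipy import stats

# Hypothetical stand-in for the real dataset: chemical_2 is roughly 2 * chemical_1 + 1
rng = np.random.default_rng(0)
chemical_1 = np.linspace(0, 10, 50)
df = pd.DataFrame({
    'chemical_1': chemical_1,
    'chemical_2': 2 * chemical_1 + 1 + rng.normal(scale=0.5, size=chemical_1.size),
})

X = df['chemical_1']
Y = df['chemical_2']

# Ordinary least-squares fit of Y on X
result = stats.linregress(X, Y)
predict_y = result.slope * X + result.intercept

# numpy.polyfit with deg=1 returns [slope, intercept] for the same least-squares line
poly_slope, poly_intercept = np.polyfit(X, Y, deg=1)
print(result.slope, result.intercept)
print(np.isclose(result.slope, poly_slope), np.isclose(result.intercept, poly_intercept))
```

Both fits minimize the same squared error, so the two pairs of coefficients should agree up to floating-point noise.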

I figured out how to plot the regression line with matplotlib.

import matplotlib.pyplot as plt

plt.plot(X, Y, 'o')     # observed data points
plt.plot(X, predict_y)  # fitted regression line
plt.show()

However, I want to plot the regression with Seaborn. The only option I have found so far is the following:

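import seaborn as sns

# One call: a second sns.set() call would reset the theme before applying its own settings
sns.set(color_codes=True, rc={'figure.figsize': (7, 7)})
sns.regplot(x=X, y=Y)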
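With its default settings (order=1, no lowess/robust/logistic options), regplot fits its own ordinary least-squares line, so the line it draws should coincide with the linregress line. A quick check, assuming a small hypothetical DataFrame in place of the real dataset and the non-interactive Agg backend so the sketch runs without a display:

```python
import matplotlib
matplotlib.use('Agg')  # assumption: headless backend, no window is opened
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats

# Hypothetical data standing in for the real dataset
rng = np.random.default_rng(1)
x = np.linspace(0, 10, 40)
df = pd.DataFrame({'chemical_1': x,
                   'chemical_2': 0.5 * x + 3 + rng.normal(scale=0.3, size=x.size)})
X, Y = df['chemical_1'], df['chemical_2']

slope, intercept, *_ = stats.linregress(X, Y)

fig, ax = plt.subplots()
# ci=None skips the bootstrapped confidence band, so the fitted line is the only Line2D
sns.regplot(x=X, y=Y, ci=None, ax=ax)

line = ax.lines[-1]
line_x, line_y = line.get_xdata(), line.get_ydata()
# Compare the drawn line with slope * x + intercept from linregress
print(np.allclose(line_y, slope * line_x + intercept))
```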

Is there a way to provide Seaborn with the regression line predict_y = slope * X + intercept in order to build a regression plot?

UPD: When I use the following solution, proposed by RPyStats, the Y-axis gets the chemical_1 label although it should be chemical_2.

import matplotlib.pyplot as plt
import seaborn as sns

sns.set(color_codes=True, rc={'figure.figsize': (8, 8)})  # before the figure is created
fig, ax = plt.subplots()
ax = sns.regplot(x=X, y=Y, ax=ax, line_kws={'label': '$y=%3.7s*x+%3.7s$' % (slope, intercept)})
ax.legend()
sns.regplot(x=X, y=Y, fit_reg=False, ax=ax)
sns.regplot(x=X, y=predict_y, scatter=False, ax=ax)
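The label comes from pandas: arithmetic between a Series and scalars keeps the Series name, so predict_y is named chemical_1, and regplot labels its axes with the names of the Series it receives. A pandas-only sketch, assuming a tiny hypothetical DataFrame and illustrative (not fitted) coefficients:

```python
import pandas as pd

# Hypothetical data; only the column names matter for this demonstration
df = pd.DataFrame({'chemical_1': [1.0, 2.0, 3.0], 'chemical_2': [3.1, 4.9, 7.2]})
X = df['chemical_1']
slope, intercept = 2.0, 1.0  # illustrative values, not a real fit

predict_y = slope * X + intercept
print(predict_y.name)  # chemical_1 (inherited from X)

# Renaming the prediction gives regplot the intended y-axis label
predict_y = predict_y.rename('chemical_2')
print(predict_y.name)  # chemical_2
```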



Solution

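  • Passing the same Axes object (ax) to each regplot call lets you overlay your predicted Y values on the scatter plot. Because predict_y is computed from X, it inherits X's name (chemical_1), and regplot labels the axes with the Series names; renaming predict_y to chemical_2 gives the Y-axis the correct label. Does this answer your question?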

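    import matplotlib.pyplot as plt
    import seaborn as sns
    # predict_y inherits the name 'chemical_1' from X, and regplot uses it as the y-axis label
    print(predict_y.name)
    predict_y = predict_y.rename('chemical_2')
    # Apply the theme before creating the figure so the figure size takes effect
    sns.set(color_codes=True, rc={'figure.figsize': (7, 7)})
    fig, ax = plt.subplots()
    # Observed data: scatter only, no fitted line
    sns.regplot(x=X, y=Y, fit_reg=False, ax=ax, scatter_kws={"color": "green"})
    # Predicted values: line only, so style it through line_kws (scatter_kws is ignored here)
    sns.regplot(x=X, y=predict_y, scatter=False, ax=ax, line_kws={"color": "green"})
    plt.show()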
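A sketch that combines this overlay with the equation legend from the question, assuming hypothetical data and the Agg backend. Instead of renaming, it sets the Y-axis label explicitly after plotting. With scatter=False, regplot applies its own label argument to the line (the same slot line_kws['label'] would occupy), so the equation is passed through label=:

```python
import matplotlib
matplotlib.use('Agg')  # assumption: headless backend, no window is opened
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats

# Hypothetical data standing in for the real dataset
rng = np.random.default_rng(3)
x = np.linspace(0, 5, 25)
df = pd.DataFrame({'chemical_1': x,
                   'chemical_2': 1.2 * x + 0.4 + rng.normal(scale=0.2, size=x.size)})
X, Y = df['chemical_1'], df['chemical_2']

slope, intercept, *_ = stats.linregress(X, Y)
predict_y = slope * X + intercept

fig, ax = plt.subplots()
sns.regplot(x=X, y=Y, fit_reg=False, ax=ax)  # observed points only
# ci=None: the prediction is exactly linear, so a bootstrapped band would add nothing
sns.regplot(x=X, y=predict_y, scatter=False, ci=None, ax=ax,
            label=f'$y={slope:.3f}x{intercept:+.3f}$')
ax.legend()
# Overrides the 'chemical_1' label that regplot took from predict_y.name
ax.set_ylabel('chemical_2')
```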