python, pandas, machine-learning, jupyter-notebook, sklearn-pandas

Plotting a simple linear regression model goes wrong


I'd like to create a linear regression model that shows a positive correlation between BMI and Disease risk (a quantitative measure of disease one year after baseline).

The dataset is the same diabetes dataset that scikit-learn exposes through sklearn.datasets.load_diabetes.

It is available at this URL: https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt

I've imported the whole table using read_csv(args) and called it 'data'. This is my code:
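For readers who want to reproduce the setup, here is a minimal sketch of the loading step. It assumes the file at the URL above is tab-separated with a header row; the helper name load_diabetes_table and the small inline sample are hypothetical and only illustrate the call without needing network access.

```python
import io

import pandas as pd

DIABETES_URL = "https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt"


def load_diabetes_table(source=DIABETES_URL):
    """Read the table from a URL, a path, or a file-like object.

    Assumption: the file is tab-separated and its first line holds the column names.
    """
    return pd.read_csv(source, sep="\t")


# Hypothetical stand-in with only a few columns, so the sketch runs offline.
sample = io.StringIO("AGE\tSEX\tBMI\tY\n59\t2\t32.1\t151\n48\t1\t21.6\t75\n")
data = load_diabetes_table(sample)
print(data[["BMI", "Y"]])
```

Calling load_diabetes_table() with no argument would read from the URL instead of the inline sample.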

    df = DataFrame({'BMI': data['BMI'], 'Target': data['Y']}).sort_values('BMI')

    df.plot.scatter('BMI', 'Target')

    model = LinearRegression(fit_intercept=True)
    model.fit(data[['BMI']], data['Y'])

    x_test = np.linspace(data['BMI'].min(), data['BMI'].max())
    y_pred = model.predict(x_test[:, np.newaxis])

    df.plot(x_test, y_pred, linestyle=":", color="red")

When I run this, I get a long error message that I don't understand. Why does this happen?

Error Message


Solution

  • I think what you want is:

    import pandas as pd
    from sklearn.linear_model import LinearRegression
    import numpy as np
    from matplotlib import pyplot as plt
    
    [...]
    
    df = pd.DataFrame({'BMI': data['BMI'], 'Target': data['Y']}).sort_values('BMI')
    
    model = LinearRegression(fit_intercept=True)
    model.fit(data[['BMI']], data['Y'])
    
    x_test = np.linspace(data['BMI'].min(), data['BMI'].max())
    y_pred = model.predict(x_test[:, np.newaxis])
    
    plt.scatter(df['BMI'].values, df['Target'].values)
    plt.plot(x_test, y_pred, linestyle="-", color="red")
    plt.show()
    

    which gives a scatter plot of BMI against Target with the fitted regression line drawn in red.

    Your call to df.plot(x_test, y_pred) raises the error because the plot method of a pandas DataFrame only plots data from the DataFrame it is called on: its x and y arguments are expected to be column labels (or positions) in df, not arbitrary arrays. It is not a general-purpose plotting function like pyplot.plot(x, y).
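If you would rather keep the pandas scatter plot from the question, one option is to draw the line on the Axes object that df.plot.scatter returns. The following is a minimal sketch, not the only way to do it: the function name plot_bmi_regression and the synthetic stand-in data are hypothetical, and predicting from a DataFrame with a 'BMI' column is a choice made here to match the column name used in fit.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression


def plot_bmi_regression(data):
    """Scatter BMI against Y and overlay the fitted line on the same Axes."""
    df = pd.DataFrame({'BMI': data['BMI'], 'Target': data['Y']}).sort_values('BMI')

    model = LinearRegression(fit_intercept=True)
    model.fit(data[['BMI']], data['Y'])

    x_test = np.linspace(data['BMI'].min(), data['BMI'].max())
    # Predict from a DataFrame that uses the same column name as in fit.
    y_pred = model.predict(pd.DataFrame({'BMI': x_test}))

    # df.plot.scatter returns a matplotlib Axes; arbitrary arrays can be
    # drawn on it with the general-purpose Axes.plot method.
    ax = df.plot.scatter('BMI', 'Target')
    ax.plot(x_test, y_pred, linestyle=":", color="red")
    return ax, model


# Hypothetical synthetic data standing in for the real table.
rng = np.random.default_rng(0)
bmi = rng.uniform(18, 42, size=100)
data = pd.DataFrame({'BMI': bmi, 'Y': 10 * bmi - 100 + rng.normal(0, 20, size=100)})
ax, model = plot_bmi_regression(data)
```

Outside a Jupyter notebook, call matplotlib.pyplot.show() afterwards to display the figure.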