I'd like to create a linear regression model that shows a positive correlation between BMI and disease risk (a quantitative measure of disease progression one year after baseline).
The dataset is the same one that ships with scikit-learn:
from sklearn.datasets import load_diabetes
And this is the URL (https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt)
I've imported the whole table using read_csv(args) and called it 'data'
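For reference, here is a minimal sketch of that loading step. It assumes the file at the URL above is tab-separated with a header row containing the `BMI` and `Y` columns that the code below relies on. The helper name `load_diabetes_table` is hypothetical, not part of pandas or scikit-learn.

```python
import pandas as pd

DIABETES_URL = "https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt"


def load_diabetes_table(source=DIABETES_URL):
    """Read the diabetes table into a DataFrame.

    Hypothetical helper: assumes a tab-separated file whose header row
    contains (at least) the 'BMI' and 'Y' columns used in the question.
    `source` may be a URL, a local path, or any file-like object.
    """
    data = pd.read_csv(source, sep="\t")
    missing = {"BMI", "Y"} - set(data.columns)
    if missing:
        raise ValueError(f"expected columns not found: {sorted(missing)}")
    return data


# Usage (needs network access for the default URL):
# data = load_diabetes_table()
```

Checking for the expected columns up front gives a clear message if the separator guess is wrong, instead of a `KeyError` later on.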
df = DataFrame({'BMI': data['BMI'], 'Target': data['Y']}).sort_values('BMI')
df.plot.scatter('BMI', 'Target')
model = LinearRegression(fit_intercept=True)
model.fit(data[['BMI']], data['Y'])
x_test = np.linspace(data['BMI'].min(), data['BMI'].max())
y_pred = model.predict(x_test[:, np.newaxis])
df.plot(x_test, y_pred, linestyle=":", color="red")
When I run this I get a long error message that I don't understand. Why does this happen?
I think what you want is:
import pandas as pd
from sklearn.linear_model import LinearRegression
import numpy as np
from matplotlib import pyplot as plt
[...]
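df = pd.DataFrame({'BMI': data['BMI'], 'Target': data['Y']}).sort_values('BMI')
# Fit Y ~ BMI; sklearn expects a 2-D feature matrix, hence data[['BMI']]
model = LinearRegression(fit_intercept=True)
model.fit(data[['BMI']], data['Y'])
# Evaluate the fitted line on 50 evenly spaced BMI values
x_test = np.linspace(data['BMI'].min(), data['BMI'].max())
# Predict with a DataFrame that keeps the 'BMI' column name (avoids sklearn's feature-name warning)
y_pred = model.predict(pd.DataFrame({'BMI': x_test}))
# pyplot functions accept plain arrays, unlike DataFrame.plot
plt.scatter(df['BMI'].values, df['Target'].values)
plt.plot(x_test, y_pred, linestyle="-", color="red")
plt.show()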
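Since the goal is to show a positive relationship, it can help to look at the fitted numbers as well as the plot. The sketch below uses synthetic, made-up data in place of the diabetes table so it runs offline; with the real data you would read `model.coef_` and `model.intercept_` from the model fitted above. The helper name `summarize_fit` is hypothetical.

```python
import numpy as np
from sklearn.linear_model import LinearRegression


def summarize_fit(x, y):
    """Fit y ~ x by ordinary least squares; return slope, intercept and Pearson r."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    model = LinearRegression(fit_intercept=True)
    model.fit(x.reshape(-1, 1), y)  # sklearn expects a 2-D feature matrix
    slope = float(model.coef_[0])
    intercept = float(model.intercept_)
    r = float(np.corrcoef(x, y)[0, 1])  # Pearson correlation coefficient
    return slope, intercept, r


# Synthetic stand-in data (NOT the real diabetes values).
rng = np.random.default_rng(0)
bmi = rng.uniform(18, 42, size=200)
target = 10 * bmi - 100 + rng.normal(0, 40, size=200)
slope, intercept, r = summarize_fit(bmi, target)
print(f"slope={slope:.2f}, intercept={intercept:.2f}, r={r:.2f}")
```

For a single predictor the least-squares slope equals r times the ratio of the standard deviations of y and x, so a positive slope and a positive r always go together.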
Your version with df.plot(x_test, y_pred, ...) is what raises the error. DataFrame.plot only plots data from the DataFrame it is called on: its x and y arguments are column labels (or positions) in that DataFrame, not arrays of values. It is not a general-purpose plotting function like pyplot.plot(x, y), which accepts arbitrary arrays.
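If you would rather keep the pandas plotting calls, one option is to reuse the matplotlib Axes that `df.plot.scatter` returns and draw the fitted line on it with `Axes.plot`, which, unlike `DataFrame.plot`, accepts plain arrays. This is a sketch with synthetic stand-in data instead of the diabetes table; the `Agg` backend line is only there so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove this line to get a window
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

# Synthetic stand-in for the diabetes table (NOT real data).
rng = np.random.default_rng(1)
bmi = rng.uniform(18, 42, size=100)
df = pd.DataFrame({"BMI": bmi, "Target": 10 * bmi - 100 + rng.normal(0, 40, size=100)})

model = LinearRegression(fit_intercept=True)
model.fit(df[["BMI"]], df["Target"])

# Evaluate the fitted line on 50 evenly spaced BMI values, keeping the
# column name so sklearn does not warn about missing feature names.
x_test = pd.DataFrame({"BMI": np.linspace(df["BMI"].min(), df["BMI"].max())})
y_pred = model.predict(x_test)

# DataFrame.plot.scatter takes column labels and returns a matplotlib Axes...
ax = df.plot.scatter("BMI", "Target")
# ...and Axes.plot takes arrays, so the line lands on the same axes.
ax.plot(x_test["BMI"], y_pred, linestyle=":", color="red")
# plt.show()  # display the figure in an interactive session
```

An alternative that stays entirely within pandas is to put `x_test` and `y_pred` into their own DataFrame and call its `.plot(x=..., y=..., ax=ax)` with column names, which is exactly the usage `DataFrame.plot` is designed for.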