python matplotlib scikit-learn linear-regression polynomials

polynomial degree scatter graph points not fitting for linear regression


I am using sklearn's LinearRegression and PolynomialFeatures to fit a data set. The code is below. I am plotting the points using scatter, but they don't seem to align with the prediction values. I'm not sure what I am missing. I have tried changing the degree from 1 to 20, but it has no effect.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures

DEGREE = 5

X = np.array([276237,276617, 276997,  277377, 277757, 278137, 278517, 278897,  279277, 279657]).reshape(-1, 1)
y = np.array([6, 8, 2, 4, 0, 1, 7, 0, 1, 4])

poly_feat = PolynomialFeatures(degree=DEGREE)
X_poly = poly_feat.fit_transform(X)

lm = LinearRegression(fit_intercept = False)
lm.fit(X_poly, y)

fig=plt.figure()
ax=fig.add_axes([0,0,1,1])
ax.scatter(X, lm.predict(X_poly), color='r')
ax.set_xlabel('Total Amount')
ax.set_ylabel('Days to mine')
ax.plot(X,y)
plt.show()

Solution

  • I guess it is because you do not have enough data. You are fitting a degree-5 polynomial (6 coefficients) to only 10 data points, so the model doesn't train well. I made up some data and found that your code works well:

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.linear_model import LinearRegression
    from sklearn.preprocessing import PolynomialFeatures
    
    BLOCK_REWARD = 380  # not used in this snippet
    DEGREE = 5
    
    #X = np.array([276237,276617, 276997,  277377, 277757, 278137, 278517, 278897,  279277, 279657]).reshape(-1, 1)
    #y = np.array([6, 8, 2, 4, 0, 1, 7, 0, 1, 4])
    
    # New data
    n = 50
    X = np.linspace(-5, 5, n)
    y = X**5 - 3 * X**4 + 2 * X**3 + 4 * X**2 - X + 6 + 200*np.random.randn(n)
    X = X.reshape(-1, 1)
    
    # Everything below remains unchanged
    poly_feat = PolynomialFeatures(degree=DEGREE)
    X_poly = poly_feat.fit_transform(X)
    
    lm = LinearRegression(fit_intercept = False)
    lm.fit(X_poly, y)
    
    fig=plt.figure()
    ax=fig.add_axes([0,0,1,1])
    ax.scatter(X, lm.predict(X_poly), color='r')
    ax.set_xlabel('Total Amount')
    ax.set_ylabel('Days to mine')
    ax.plot(X,y)
    plt.show()
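
The "works well" observation above can also be checked numerically instead of by eye. The following is a minimal sketch, not part of the original answer: the helper `fit_poly_r2`, the seed `0`, and the use of `r2_score` are all assumptions made for illustration. It scores the made-up data, then scores the original 10 points twice, once as given and once with X standardized first. Standardizing is an extra assumption of this sketch, not something the answer proposes.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures, StandardScaler

DEGREE = 5


def fit_poly_r2(X, y, degree=DEGREE, standardize=False):
    """Hypothetical helper: fit a polynomial and return R^2 on the training data."""
    steps = [PolynomialFeatures(degree=degree),
             LinearRegression(fit_intercept=False)]
    if standardize:
        # Assumption of this sketch: rescale X to zero mean and unit variance
        # before expanding it into polynomial features.
        steps.insert(0, StandardScaler())
    model = make_pipeline(*steps)
    model.fit(X, y)
    return r2_score(y, model.predict(X))


# Made-up data as in the answer, but with a fixed seed (an assumption) so the
# score is reproducible.
rng = np.random.default_rng(0)
n = 50
x = np.linspace(-5, 5, n)
y_synth = (x**5 - 3 * x**4 + 2 * x**3 + 4 * x**2 - x + 6
           + 200 * rng.standard_normal(n))
r2_synth = fit_poly_r2(x.reshape(-1, 1), y_synth)

# Original data from the question, cast to float.
X_orig = np.array([276237, 276617, 276997, 277377, 277757,
                   278137, 278517, 278897, 279277, 279657],
                  dtype=float).reshape(-1, 1)
y_orig = np.array([6, 8, 2, 4, 0, 1, 7, 0, 1, 4])
r2_raw = fit_poly_r2(X_orig, y_orig)
r2_scaled = fit_poly_r2(X_orig, y_orig, standardize=True)

print(f"R^2, made-up data:              {r2_synth:.3f}")
print(f"R^2, original data as given:    {r2_raw:.3f}")
print(f"R^2, original data, X rescaled: {r2_scaled:.3f}")
```

Standardizing X is an affine change of variable, so in exact arithmetic both fits on the original points range over the same family of degree-5 polynomials and would give the same score. A gap between the last two numbers would therefore point to floating-point conditioning (X**5 is on the order of 1e27 here), which is a separate issue from the sample size.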