
Why is my linear regression line so short?


For a student project, I have to plot the knob position rc as a function of the corresponding wavelength in nanometres. When I do a linear regression, the fitted line is very short. How can I change this? And how could I add error bars?

It looks like this:

[image: current plot, with a very short regression line]

But I want it to look more like this:

[image: desired result]

This is the data:

rc      nm
21.22   728.1
13.62   587.6
11.51   504.7
11.36   501.6
11.18   492.2
10.62   471.3
9.94    447.1
8.47    388.9

This is my code for a scatter plot:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from scipy import stats


df = pd.read_csv('datalin.csv')

X = df.drop(['rc'], axis=1)
y = df['rc']

df.head()

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

regressor = LinearRegression()
regressor.fit(X_train, y_train)

y_pred = regressor.predict(X_test)

plt.scatter(X_train, y_train,color='g')
plt.plot(X_test, y_pred,color='k')

plt.show()

I tried using LinearRegression from sklearn.linear_model.


Solution

  • The turquoise (desired) plot doesn't make a tonne of sense, since the scatter points don't add any new information: they're all on the line of regression. Either don't show them at all, or show them as the original (experimental) data.

    sklearn and scipy are not needed as dependencies; numpy can do the fit:

    import matplotlib.pyplot as plt
    import numpy as np
    
    rc, wavelength = np.array((
        (21.22,  728.1),
        (13.62,  587.6),
        (11.51,  504.7),
        (11.36,  501.6),
        (11.18,  492.2),
        (10.62,  471.3),
        (9.94 ,  447.1),
        (8.47 ,  388.9),
    )).T
    
    
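    # Least-squares fit of rc = m*λ + b; the design matrix has columns [λ, 1].
    (m, b), (residual,), rank, (m_sing, b_sing) = np.linalg.lstsq(
        a=np.stack((wavelength, np.ones_like(wavelength)), axis=1),
        b=rc, rcond=None,
    )
    # The intercept is negative for this data, so it is printed as "- |b|".
    title = f'rc ~ {m:.5f}λ - {-b:.3f}'
    print(title)
    rc_approx = m*wavelength + b  # fitted values at the measured wavelengths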
    
    fig: plt.Figure
    ax: plt.Axes
    fig, ax = plt.subplots()
    ax.scatter(wavelength, rc, c='#65dbc4')
    ax.errorbar(
        wavelength, rc_approx,
        yerr=np.abs(rc - rc_approx),
        c='#65dbc4', capsize=6,
    )
    ax.set_title(title)
    ax.set_xlabel('λ (nm)')
    ax.set_ylabel('Knob position r_c (#)')
    plt.show()
    

    [image: fit]

    But really, error bars aren't needed here since there are so few data points, and the highest point is clearly an outlier. Leaving it out of the fit:

    import matplotlib.pyplot as plt
    import numpy as np
    
    rc, wavelength = np.array((
        (21.22,  728.1),
        (13.62,  587.6),
        (11.51,  504.7),
        (11.36,  501.6),
        (11.18,  492.2),
        (10.62,  471.3),
        (9.94 ,  447.1),
        (8.47 ,  388.9),
    )).T
    
    
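    # Fit on every point except the first, (21.22, 728.1);
    # the line below is still drawn over all wavelengths.
    (m, b), (residual,), rank, (m_sing, b_sing) = np.linalg.lstsq(
        a=np.stack((wavelength[1:], np.ones_like(wavelength[1:])), axis=1),
        b=rc[1:],
        rcond=None,
    )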
    title = f'rc ~ {m:.5f}λ - {-b:.3f}'
    print(title)
    rc_approx = m*wavelength + b
    
    fig: plt.Figure
    ax: plt.Axes
    fig, ax = plt.subplots()
    ax.scatter(wavelength, rc, marker='+')
    ax.plot(wavelength, rc_approx)
    ax.set_title(title)
    ax.set_xlabel('λ (nm)')
    ax.set_ylabel('Knob position r_c (#)')
    plt.show()
    

    [image: better fit]
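
As for why the line in the question is so short: `plt.plot(X_test, y_pred)` draws the line only through the test set, and with eight rows `test_size=0.2` leaves just two test points, so the line runs between those two wavelengths and no further. With this little data there is not much to gain from a train/test split at all. Below is a minimal sketch of one way around it, assuming it is acceptable to fit on every measurement and to swap `np.polyfit` in for `LinearRegression`; the names `n_test`, `grid`, `line` and `plot_fit` are illustrative and not from the original code.

```python
import math

import numpy as np

# Data from the question: knob position rc vs. wavelength in nm.
rc = np.array([21.22, 13.62, 11.51, 11.36, 11.18, 10.62, 9.94, 8.47])
wavelength = np.array([728.1, 587.6, 504.7, 501.6, 492.2, 471.3, 447.1, 388.9])

# A fractional test_size is rounded up to whole samples,
# so 20 % of 8 rows means 2 test points (and a 2-point line).
n_test = math.ceil(0.2 * len(rc))

# Fit a straight line on all measurements instead of a training subset.
m, b = np.polyfit(wavelength, rc, deg=1)

# Evaluate the fitted line on a dense grid covering the full data range.
grid = np.linspace(wavelength.min(), wavelength.max(), 100)
line = m * grid + b


def plot_fit():
    """Draw the measurements and the full-length regression line."""
    import matplotlib.pyplot as plt  # imported lazily; only needed for plotting

    fig, ax = plt.subplots()
    ax.scatter(wavelength, rc, marker='+')
    ax.plot(grid, line, color='k')
    ax.set_xlabel('λ (nm)')
    ax.set_ylabel('Knob position r_c (#)')
    plt.show()
```

Calling `plot_fit()` shows the line spanning the whole wavelength range. The same idea should carry over to `LinearRegression`: fit on all rows, then call `predict` on a range of wavelengths rather than on `X_test`.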