python  python-3.x  matplotlib  machine-learning  linear-regression

Why does the plot of targets vs. predictions not display properly for my linear regression model?


I am using a multiple linear regression model to estimate medical charges for smokers. I use only the 'age', 'bmi', and 'children' features to estimate 'charges'. Here is my code:

import pandas as pd
import numpy as np
import plotly.express as px
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error as rmse

Read the data from the GitHub repo

smoker_df = pd.read_csv('https://raw.githubusercontent.com/stedy/Machine-Learning-with-R-datasets/master/insurance.csv')

Create inputs and targets

inputs  = smoker_df[['age', 'bmi', 'children']]
targets = smoker_df['charges']

Create and train the model

model6 = LinearRegression().fit(inputs, targets)

Generate predictions

predictions = model6.predict(inputs)

Compute loss to evaluate the model

# mean_squared_error returns the MSE; take the square root to get the RMSE
loss = np.sqrt(rmse(targets, predictions))
print('Loss:', loss)

Visualization of Prediction and Targets :

fig, ax = plt.subplots(figsize=(7, 3.5))

ax.plot(predictions, targets, color='k', label='Regression model')
ax.set_ylabel('predictions', fontsize=14)
ax.set_xlabel('targets', fontsize=14)
ax.legend(facecolor='white', fontsize=11)

This is not a good visualization. How can I improve it so that I get some insight from it, and how can I visualize more than three input features against a single target?



Solution

  • You can use a scatter plot to compare your predictions with the observed values:

    fig, ax = plt.subplots(figsize=(7, 3.5))
    
    # one point per row: x is the model's prediction, y is the observed charge
    ax.scatter(predictions, targets, label='observations')
    ax.set_xlabel('prediction', fontsize=14)
    ax.set_ylabel('charges', fontsize=14)
    ax.legend(facecolor='white', fontsize=11)
    


    You can see that some of your predictions are far off. This is because the model leaves out other variables, such as smoker:

    import seaborn as sns
    sns.scatterplot(data=smoker_df, x="age", y="charges", hue="smoker")
    


    You can also check how each of your input features relates to the target:

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    
    # one panel per input feature, each plotted against the target
    for i, x in enumerate(inputs.columns):
        ax[i].scatter(inputs[x], targets, label=x)
        ax[i].set_xlabel(x, fontsize=14)
        ax[i].set_ylabel('charges', fontsize=14)
        ax[i].legend(facecolor='white', fontsize=11)
    
    plt.tight_layout()
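
To put numbers on what the per-feature panels show, here is a minimal sketch that ranks features by their Pearson correlation with the target. The data frame is synthetic (made-up values, not the insurance data) so that the snippet runs offline, and `correlation_with_target` is a hypothetical helper name. With the real data you would pass `smoker_df` instead of `demo`.

```python
import numpy as np
import pandas as pd


def correlation_with_target(df, target):
    """Pearson correlation of every numeric column with `target`,
    ordered from strongest to weakest absolute correlation."""
    numeric = df.select_dtypes(include="number")
    corr = numeric.corr()[target].drop(target)
    order = corr.abs().sort_values(ascending=False).index
    return corr.reindex(order)


# Synthetic stand-in for the insurance data (assumption: values are made up).
rng = np.random.default_rng(0)
n = 200
demo = pd.DataFrame({
    "age": rng.integers(18, 65, n),
    "bmi": rng.normal(30, 6, n),
    "children": rng.integers(0, 5, n),
})
demo["charges"] = 250 * demo["age"] + 300 * demo["bmi"] + rng.normal(0, 3000, n)

ranked = correlation_with_target(demo, "charges")
print(ranked)
```

A correlation near zero only says there is no linear relationship with that feature on its own, so it is worth reading these numbers next to the scatter plots rather than instead of them.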
    

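The predicted-vs-observed plot works for any number of input features, because it only needs the model output and the target. Below is a minimal sketch that adds the smoker column as a 0/1 feature, compares the RMSE of both models, and draws each scatter with a diagonal "perfect prediction" line. The data is synthetic (made-up values, not the insurance data), and `fit_and_score` and `smoker_code` are hypothetical names. It assumes `smoker` holds 'yes'/'no' strings.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error


def fit_and_score(df, features, target="charges"):
    """Fit a linear model on `features` and return (predictions, in-sample RMSE)."""
    model = LinearRegression().fit(df[features], df[target])
    preds = model.predict(df[features])
    rmse = float(np.sqrt(mean_squared_error(df[target], preds)))
    return preds, rmse


# Synthetic stand-in for the insurance data (assumption: values are made up).
rng = np.random.default_rng(0)
n = 300
demo = pd.DataFrame({
    "age": rng.integers(18, 65, n),
    "bmi": rng.normal(30, 6, n),
    "children": rng.integers(0, 5, n),
    "smoker": rng.choice(["yes", "no"], n, p=[0.2, 0.8]),
})
# Encode the categorical column as 0/1 so the linear model can use it.
demo["smoker_code"] = (demo["smoker"] == "yes").astype(int)
demo["charges"] = (250 * demo["age"] + 300 * demo["bmi"]
                   + 20000 * demo["smoker_code"] + rng.normal(0, 2000, n))

base_preds, base_rmse = fit_and_score(demo, ["age", "bmi", "children"])
full_preds, full_rmse = fit_and_score(demo, ["age", "bmi", "children", "smoker_code"])
print(f"RMSE without smoker: {base_rmse:.0f}")
print(f"RMSE with smoker:    {full_rmse:.0f}")

# Observed vs. predicted for both models; points on the diagonal are exact.
fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharex=True, sharey=True)
lims = [demo["charges"].min(), demo["charges"].max()]
panels = zip(axes, [base_preds, full_preds], ["age, bmi, children", "+ smoker"])
for ax, preds, title in panels:
    ax.scatter(demo["charges"], preds, s=10, alpha=0.6)
    ax.plot(lims, lims, color="k", linewidth=1, label="perfect prediction")
    ax.set_xlabel("observed charges")
    ax.set_ylabel("predicted charges")
    ax.set_title(title)
    ax.legend()
plt.tight_layout()
```

The RMSE here is computed on the training data, where adding a feature can never make it worse, so a held-out split would give a fairer comparison on the real dataset.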