Tags: python, machine-learning, scikit-learn, linear-regression, shap

SHAP Partial Dependence Plot Misalignment with Train-Test Split in Linear Regression


I'm experiencing an issue with SHAP's partial dependence plot when using a train-test split for a linear regression model in Python. When I calculate SHAP values and plot the partial dependence for the first observation in my test set, the alignment of the data point and the baseline seems off.

Here's a simplified version of my code:

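import shap
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
import pandas as pd
import matplotlib.pyplot as plt
import requests
import zipfile           # needed to open the downloaded archive
from io import BytesIO   # wraps the downloaded bytes as a file object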

def load_data() -> pd.DataFrame:
    """
    Loads and returns the dataset from the given URL as a Pandas DataFrame.

    Returns:
        pd.DataFrame: The loaded dataset.
    """
    url = "https://archive.ics.uci.edu/static/public/165/concrete+compressive+strength.zip"

    r = requests.get(url, timeout=60)
    r.raise_for_status()  # raise an informative HTTPError instead of a generic Exception

    # The archive contains the dataset as an Excel file
    with zipfile.ZipFile(BytesIO(r.content)) as thezip:
        with thezip.open("Concrete_Data.xls") as thefile:
            return pd.read_excel(thefile, header=0)

df = load_data()

df = df.rename(
    columns={
        'Cement (component 1)(kg in a m^3 mixture)':'cement',
        'Blast Furnace Slag (component 2)(kg in a m^3 mixture)':'blast',
        'Fly Ash (component 3)(kg in a m^3 mixture)':'ash',
        'Water  (component 4)(kg in a m^3 mixture)':'water',
        'Superplasticizer (component 5)(kg in a m^3 mixture)':'superplasticizer',
        'Coarse Aggregate  (component 6)(kg in a m^3 mixture)':'coarse',
        'Fine Aggregate (component 7)(kg in a m^3 mixture)':'fine',
        'Age (day)':'age',
        'Concrete compressive strength(MPa, megapascals) ': 'strength'
    }
)
df = df.drop_duplicates()
X = df.drop(['strength'], axis=1) 
y = df['strength']

# Split the data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train the model
model = LinearRegression()
model.fit(X_train, y_train)

# Initialize SHAP explainer and calculate values for the test set
explainer = shap.Explainer(model.predict, X_train)
shap_values = explainer(X_test)

# Plot partial dependence for the first test observation
idx = 0
shap.partial_dependence_plot(
    "cement", model.predict, X_test,
    model_expected_value=True, feature_expected_value=True, ice=False,
    shap_values=shap_values[idx:idx+1,:],
    show=False,  # keep the figure open so savefig below does not write an empty image
)

# Save the plot
plt.tight_layout()
plt.savefig('shap_dependence_plot.png', dpi=300)

However, when I generate the plot, the data point (black dot) does not line up with the blue partial dependence line for the feature of interest; it appears to be shifted along the y-axis. Here is the output plot for reference:

[Image: resulting partial dependence plot for cement, explainer built on X_train]

The plot seems correct when I initialize the SHAP explainer with the entire dataset X instead of just X_train:

explainer = shap.Explainer(model.predict, X)  # background is now the full dataset X
shap_values = explainer(X_test)

idx = 0
shap.partial_dependence_plot(
    "cement", model.predict, X_test,
    model_expected_value=True, feature_expected_value=True, ice=False,
    shap_values=shap_values[idx:idx+1,:]
)

Result:

[Image: resulting partial dependence plot for cement, explainer built on X]
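For a linear model, the size of the problem can be worked out by hand. Assuming exact SHAP values with an independent background, the SHAP value of cement is w * (x - mean of cement over the background). Assuming the plot's expected-value line is the mean prediction over the plotted data, the distance from that line to the partial dependence curve at x is w * (x - mean of cement over the plotted data). The x terms cancel, so the two differ by w * (mean over plotted data - mean over background) for every observation. A small pure-Python sketch with hypothetical numbers (not the concrete data) compares the two backgrounds tried above:

```python
def mean(values):
    return sum(values) / len(values)


# Hypothetical numbers, not taken from the concrete dataset.
w_cement = 0.5                               # coefficient of "cement" in the linear model
train_cement = [100.0, 200.0, 300.0, 400.0]  # cement column of X_train
test_cement = [200.0, 400.0]                 # cement column of X_test (the plotted data)
all_cement = train_cement + test_cement      # cement column of X = X_train + X_test


def shap_pd_mismatch(background, plotted):
    """SHAP value minus (PD curve - expected-value line) for a linear model.

    SHAP value:            w * (x - mean(background))
    PD curve - E[f] at x:  w * (x - mean(plotted))
    The x terms cancel, so the mismatch is the same for every observation.
    """
    return w_cement * (mean(plotted) - mean(background))


mismatch_train = shap_pd_mismatch(train_cement, test_cement)  # background X_train
mismatch_all = shap_pd_mismatch(all_cement, test_cement)      # background X
print(mismatch_train, mismatch_all)
# Using X as background only scales the mismatch by len(X_train) / len(X).
print(mismatch_all / mismatch_train, len(train_cement) / len(all_cement))
```

Under these assumptions, switching the background from X_train to X shrinks the mismatch by a factor of len(X_train) / len(X) (roughly 0.8 with test_size=0.2), because the mean of X is a weighted average of the train and test means; it does not remove it in general.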

Can someone explain why this misalignment occurs and how to correct the partial dependence plot when using a train-test split?

Any insights or suggestions would be greatly appreciated!


Solution

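  • Starting from where you define shap_values, be consistent about what you use as the background dataset and what you pass to the plot. With model_expected_value=True, partial_dependence_plot computes the expected-value line and the blue partial dependence curve over the data you give it (X_test in the question), whereas the SHAP values are measured relative to the explainer's background (X_train). A SHAP value only equals the distance between the expected-value line and the partial dependence curve at the observation's feature value when both are computed over the background, so with two different datasets the dot lands off the curve. Using the same data for both fixes this: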

    # Initialize SHAP explainer and calculate values for the test set
    explainer = shap.Explainer(model.predict, X_test)
    shap_values = explainer(X_test)
    
    # Plot partial dependence for the first test observation
    idx = 0
    shap.partial_dependence_plot(
        "cement", model.predict, X_test,
        model_expected_value=True, feature_expected_value=True, ice=False,
        shap_values=shap_values[idx:idx+1,:]
    )
    

    [Image: resulting plot, X_test used as both background and plotted data]

    or, if you prefer, use the training set for both:

    # Initialize SHAP explainer and calculate values for the train set
    explainer = shap.Explainer(model.predict, X_train)
    shap_values = explainer(X_train)
    
    # Plot partial dependence for the first train observation
    idx = 0
    shap.partial_dependence_plot(
        "cement", model.predict, X_train,
        model_expected_value=True, feature_expected_value=True, ice=False,
        shap_values=shap_values[idx:idx+1,:]
    )
    

    [Image: resulting plot, X_train used as both background and plotted data]
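The reasoning above can be checked with a small pure-Python sketch. It assumes a hypothetical two-feature linear model (made-up coefficients, intercept and rows, not the concrete data) and uses the fact that, for a linear model with an independent background, the exact SHAP value of feature j is w_j * (x_j - mean of x_j over the background). The hypothetical helper mismatch compares that SHAP value with the distance between the expected-value line and the partial dependence curve, both computed over the plotted data. It is nonzero for the question's setup and zero for both arrangements above. It is also zero for a third arrangement: explaining a test row against X_train while passing X_train to the plot.

```python
WEIGHTS = (0.5, -0.25)  # hypothetical coefficients for (cement, water)
BIAS = 30.0             # hypothetical intercept


def predict(row):
    """Hypothetical linear model: BIAS + sum_j WEIGHTS[j] * row[j]."""
    return BIAS + sum(w * v for w, v in zip(WEIGHTS, row))


def mean(values):
    return sum(values) / len(values)


def mismatch(background, plot_data, row, j=0):
    """SHAP value of feature j minus (PD curve at row[j] - expected-value line).

    0 means the dot sits exactly on the partial dependence curve.
    """
    # Exact SHAP value of feature j for a linear model, relative to the background.
    shap_value = WEIGHTS[j] * (row[j] - mean([r[j] for r in background]))
    # Expected-value line: mean prediction over the plotted data.
    expected = mean([predict(r) for r in plot_data])
    # PD curve at row[j]: mean prediction over the plotted data with feature j fixed.
    pd_value = mean([predict(r[:j] + (row[j],) + r[j + 1:]) for r in plot_data])
    return shap_value - (pd_value - expected)


# Hypothetical (cement, water) rows.
X_train = [(100, 200), (200, 180), (300, 160), (400, 140)]
X_test = [(200, 190), (400, 170)]

results = {
    "question: background X_train, plot X_test": mismatch(X_train, X_test, X_test[0]),
    "answer: background X_test, plot X_test": mismatch(X_test, X_test, X_test[0]),
    "answer: background X_train, plot X_train": mismatch(X_train, X_train, X_train[0]),
    "keep split: background X_train, explain X_test[0], plot X_train":
        mismatch(X_train, X_train, X_test[0]),
}
for name, value in results.items():
    print(f"{name}: mismatch = {value}")
```

In the question's code, the third arrangement would mean keeping explainer = shap.Explainer(model.predict, X_train) and shap_values = explainer(X_test) but passing X_train as the data argument of shap.partial_dependence_plot. Treat this as an untested suggestion. Also note that shap may subsample a large background dataset (see the max_samples argument of shap.maskers.Independent), in which case a small residual offset can remain even when the same data is used for both.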