Tags: python, machine-learning, classification, shap, xgbclassifier

How to plot SHAP summary plots for all classes in multiclass classification


I am using XGBoost with SHAP to analyze feature importance in a multiclass classification problem and need help plotting the SHAP summary plots for all classes at once. Currently, I can only generate plots one class at a time.

SHAP version: 0.45.0
Python version: 3.10.12

Here is my code:

import xgboost as xgb
import shap
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score

# Generate synthetic data
X, y = make_classification(n_samples=500, n_features=20, n_informative=4, n_classes=6, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# Train an XGBoost model for multiclass classification
model = xgb.XGBClassifier(objective="multi:softprob", random_state=42)
model.fit(X_train, y_train)

I then tried to plot the SHAP values:

# Create a SHAP TreeExplainer
explainer = shap.TreeExplainer(model)

# Calculate SHAP values for the test set
shap_values = explainer.shap_values(X_test)

# Attempt to plot summary for all classes
shap.summary_plot(shap_values, X_test, plot_type="bar")

I got this interaction plot instead:

[Image: summary_plot output rendered as an interaction plot]
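Before reshaping anything, it helps to confirm which layout your SHAP version actually returned, because `summary_plot` decides what kind of chart to draw from the input's type and shape (a list means per-class values; a 3D array is read as interaction values). The sketch below is a hypothetical helper, `describe_shap_layout`, that only inspects shapes; it assumes the SHAP 0.45 convention of a single (samples, features, classes) array for multiclass tree models, and the zero arrays in the checks are stand-ins with this example's dimensions (125 test rows, 20 features, 6 classes), not real SHAP output.

```python
import numpy as np


def describe_shap_layout(shap_values, n_features, n_classes):
    """Hypothetical helper: report how a SHAP output is laid out.

    Assumes the SHAP 0.45 convention for multiclass tree models, where
    explainer.shap_values() returns one array of shape
    (n_samples, n_features, n_classes).
    """
    if isinstance(shap_values, list):
        # Older SHAP releases returned one (n_samples, n_features) array per class
        return "list of per-class arrays"
    arr = np.asarray(shap_values)
    if arr.ndim == 2:
        return "single-output array"
    # If n_features == n_classes the two 3D cases are indistinguishable by shape;
    # the per-class interpretation is checked first here.
    if arr.ndim == 3 and arr.shape[1:] == (n_features, n_classes):
        # summary_plot would misread this layout as interaction values
        return "3D (samples, features, classes) array"
    if arr.ndim == 3 and arr.shape[1:] == (n_features, n_features):
        return "3D interaction-value array"
    return "unrecognised layout"


# Stand-in with the same shape as explainer.shap_values(X_test) in this example
print(describe_shap_layout(np.zeros((125, 20, 6)), n_features=20, n_classes=6))
```

With the question's objects, the call would be `describe_shap_layout(shap_values, X_test.shape[1], len(model.classes_))`.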

I remedied the problem with help from this post:

shap.summary_plot(shap_values[:,:,0], X_test, plot_type="bar")

which gives a normal bar plot for class 0:

[Image: bar summary plot for class 0]

I can then do the same with classes 1, 2, 3, etc.

So the question is: how can I make a single summary plot for all the classes, i.e., one plot showing each feature's contribution to every class?


Solution

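  • The issue is that, in SHAP 0.45, explainer.shap_values(X_test) returns a 3D NumPy array (not a DataFrame) of shape (rows, features, classes), and summary_plot treats any 3D array as SHAP interaction values, which is why you got an interaction plot. To draw the multiclass bar plot, summary_plot(shap_values, ...) needs shap_values to be a list of 2D (rows, features) arrays, one per class, so the list's length equals the number of classes.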

    For my own purposes, I used the following function, which converts your shap_values into the format you need:

    def shap_values_to_list(shap_values, model):
        # Split the (rows, features, classes) array into one (rows, features) array per class
        shap_as_list = []
        for i in range(len(model.classes_)):
            shap_as_list.append(shap_values[:, :, i])
        return shap_as_list
    

    Then you can do:

    shap_as_list = shap_values_to_list(shap_values, model)
    shap.summary_plot(shap_as_list, X_test, plot_type="bar")
    

    You can also pass feature_names and class_names to summary_plot if needed. With my own data, this took me from the same kind of interaction plot you got to the following:

    [Image: shap.summary_plot bar output using shap_values converted to a list of per-class arrays]
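If you would rather not depend on the list format (the return format of shap_values has changed between SHAP versions), a comparable all-classes chart can be built directly from the 3D array. This is a swapped-in manual approach using pandas and matplotlib, not part of SHAP's API: it averages |SHAP| over rows for each feature and class and stacks the per-class bars, which, to my understanding, is what the multiclass bar summary plot shows. The function names `mean_abs_shap_by_class` and `plot_stacked_importance` and the output filename are hypothetical.

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen so the chart can be saved from scripts
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def mean_abs_shap_by_class(shap_values, feature_names, class_names):
    """Mean |SHAP| per feature (rows) and class (columns).

    Assumes shap_values has shape (n_samples, n_features, n_classes).
    """
    importance = np.abs(shap_values).mean(axis=0)  # -> (n_features, n_classes)
    df = pd.DataFrame(importance, index=feature_names, columns=class_names)
    # Order features by their total contribution across all classes
    return df.loc[df.sum(axis=1).sort_values(ascending=False).index]


def plot_stacked_importance(df, max_display=10, path="shap_all_classes.png"):
    """Save a horizontal stacked bar chart of the top features."""
    top = df.head(max_display).iloc[::-1]  # reverse so the largest bar is at the top
    ax = top.plot.barh(stacked=True, figsize=(8, 0.4 * len(top) + 1.5))
    ax.set_xlabel("mean(|SHAP value|)")
    ax.legend(title="Class")
    plt.savefig(path, bbox_inches="tight")
    plt.close(ax.figure)
    return path
```

With the question's objects, something like `df = mean_abs_shap_by_class(shap_values, [f"Feature {i}" for i in range(X_test.shape[1])], model.classes_)` followed by `plot_stacked_importance(df)` would give one chart for all six classes, and `df` itself doubles as a table of per-class importances.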