python scikit-learn random-forest shap

SHAP TreeExplainer for RandomForest multiclass: what is shap_values[i]?


I am trying to plot SHAP values. This is my code, where rnd_clf is a RandomForestClassifier:

import shap 
explainer = shap.TreeExplainer(rnd_clf) 
shap_values = explainer.shap_values(X) 
shap.summary_plot(shap_values[1], X) 

For a binary classifier, I understand that shap_values[0] holds the SHAP values for the negative class and shap_values[1] those for the positive class.

But what about a multiclass RandomForestClassifier? My rnd_clf classifies each sample as one of:

['Gusto','Kestrel 200 SCI Older Road Bike', 'Vilano Aluminum Road Bike 21 Speed Shimano', 'Fixie'].

How do I determine which index of shap_values[i] corresponds to which class of my output?


Solution

  • How do I determine which index of shap_values[i] corresponds to which class of my output?

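    shap_values[i] holds the SHAP values for the i-th class. Which class is the i-th one is really a question of the encoding scheme you use (LabelEncoder, pd.factorize, etc.); for a fitted scikit-learn classifier, it is the i-th entry of rnd_clf.classes_, which is also the column order of rnd_clf.predict_proba.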

    As a clue, check how LabelEncoder orders your labels:

    from pprint import pprint

    from sklearn.preprocessing import LabelEncoder
    
    labels = [
        "Gusto",
        "Kestrel 200 SCI Older Road Bike",
        "Vilano Aluminum Road Bike 21 Speed Shimano",
        "Fixie",
    ]
    le = LabelEncoder()
    y = le.fit_transform(labels)
    # Map each integer code back to its label (tolist() gives plain Python ints)
    encoding_scheme = dict(zip(y.tolist(), labels))
    pprint(encoding_scheme)
    

    {0: 'Fixie',
     1: 'Gusto',
     2: 'Kestrel 200 SCI Older Road Bike',
     3: 'Vilano Aluminum Road Bike 21 Speed Shimano'}
    

    So, e.g., shap_values[3] in this particular case corresponds to 'Vilano Aluminum Road Bike 21 Speed Shimano'.
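The same ordering can be read straight from the fitted model: scikit-learn classifiers store their labels sorted in `classes_`, so a forest fit directly on the string labels ends up with the LabelEncoder order shown above. A minimal sketch; `X_toy`, `y_toy` and `index_to_class` are hypothetical names, and the random toy data exists only to make the snippet runnable:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

labels = [
    "Gusto",
    "Kestrel 200 SCI Older Road Bike",
    "Vilano Aluminum Road Bike 21 Speed Shimano",
    "Fixie",
]

# Hypothetical toy data: 40 samples, 3 random features, 10 samples per label
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(40, 3))
y_toy = np.repeat(labels, 10)

rnd_clf = RandomForestClassifier(n_estimators=10, random_state=0)
rnd_clf.fit(X_toy, y_toy)

# classes_ holds the labels in sorted order; position i here is column i of
# predict_proba and index i of the per-class SHAP outputs
index_to_class = dict(enumerate(str(c) for c in rnd_clf.classes_))
print(index_to_class)
```

With your real model, `dict(enumerate(rnd_clf.classes_))` gives the index-to-name mapping to use with `shap_values[i]`.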

    To further understand how to interpret SHAP values, let's prepare a synthetic multiclass classification dataset with 100 features and 10 classes:

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import train_test_split
    from shap import TreeExplainer
    from shap import summary_plot
    
    X, y = make_classification(1000, 100, n_informative=8, n_classes=10)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
    print(X_train.shape)
    

    (750, 100)
    

    At this point we have a training set with 750 rows, 100 features, and 10 classes.

    Let's train a RandomForestClassifier and pass it to TreeExplainer:

    clf = RandomForestClassifier(n_estimators=100, max_depth=3)
    clf.fit(X_train, y_train)
    explainer = TreeExplainer(clf)
    shap_values = np.array(explainer.shap_values(X_train))
    print(shap_values.shape)
    

    (10, 750, 100)
    

    10 : number of classes. The SHAP values are organized into 10 arrays, one per class.
    750 : number of datapoints. SHAP values are local, so each datapoint gets its own.
    100 : number of features. Each feature gets one SHAP value.
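Be aware that this (n_classes, n_samples, n_features) layout depends on the SHAP version. Newer SHAP releases (from roughly 0.45 on) return a single NumPy array of shape (n_samples, n_features, n_classes) for multiclass models, in which case shap_values[3] would select the fourth datapoint rather than class 3. A small sketch of a version-agnostic conversion; `to_class_major` is a hypothetical helper, and the dummy arrays stand in for real explainer output:

```python
import numpy as np


def to_class_major(raw_shap_values):
    """Return SHAP values as an array of shape (n_classes, n_samples, n_features).

    Assumes raw_shap_values is either
      * a list with one (n_samples, n_features) array per class (older SHAP), or
      * a single (n_samples, n_features, n_classes) array (newer SHAP).
    """
    if isinstance(raw_shap_values, list):
        # Older layout: stack the per-class arrays along a new leading axis
        return np.stack(raw_shap_values, axis=0)
    arr = np.asarray(raw_shap_values)
    if arr.ndim != 3:
        raise ValueError(f"expected a 3-D array, got shape {arr.shape}")
    # Newer layout: move the trailing class axis to the front
    return np.moveaxis(arr, -1, 0)


# Dummy stand-ins for explainer.shap_values(X_train): 10 classes, 750 rows, 100 features
old_style = [np.full((750, 100), c, dtype=float) for c in range(10)]
new_style = np.stack(old_style, axis=-1)  # shape (750, 100, 10)

print(to_class_major(old_style).shape)  # (10, 750, 100)
print(to_class_major(new_style).shape)  # (10, 750, 100)
```

With real output you would write `shap_values = to_class_major(explainer.shap_values(X_train))`, after which the indexing used in the rest of this answer applies unchanged.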

    For example, for class 3 you'll have:

    print(shap_values[3].shape)
    

    (750, 100)
    

    750: one row of SHAP values per datapoint
    100: one SHAP value (contribution) per feature

    Finally, you can run a sanity check to make sure the model's actual predictions match those reconstructed from the SHAP values.
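The check rests on the additivity (local accuracy) property of SHAP values: for each datapoint $x$ and class $c$, the class's base value plus the sum of that class's per-feature SHAP values equals the model output for that class, which for this RandomForestClassifier is the predicted probability. In the notation below (introduced here only for illustration), $b_c$ is `explainer.expected_value[c]`, $\phi_{c,j}(x)$ is the SHAP value of feature $j$ for class $c$ at datapoint $x$ (an entry of `shap_values[c]`), $M = 100$ is the number of features and $K = 10$ the number of classes:

```latex
\hat{p}_c(x) \;=\; b_c \;+\; \sum_{j=0}^{M-1} \phi_{c,j}(x),
\qquad c = 0, \dots, K-1 .
```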

    To do so, we'll (1) swap the first two dimensions of shap_values, giving shape (750, 10, 100); (2) sum the SHAP values over all features for each datapoint and class; and (3) add the per-class base values (explainer.expected_value) to those sums:

    shap_values_ = shap_values.transpose((1,0,2))
    
    np.allclose(
        clf.predict_proba(X_train),
        shap_values_.sum(2) + explainer.expected_value
    )
    

    True
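The transpose only lines the axes up with predict_proba, whose output has shape (n_samples, n_classes). Summing over the feature axis first and transposing the resulting (10, 750) matrix afterwards gives the same numbers. A sketch on random dummy arrays that merely stand in for the real shap_values and explainer.expected_value:

```python
import numpy as np

# Dummy arrays standing in for the real objects (shapes as in this answer):
# shap_values: (n_classes, n_samples, n_features); expected_value: (n_classes,)
rng = np.random.default_rng(1)
shap_values = rng.normal(size=(10, 750, 100))
expected_value = rng.normal(size=10)

# Approach from the answer: move datapoints to the front, then sum over features
via_transpose = shap_values.transpose((1, 0, 2)).sum(2) + expected_value

# Equivalent: sum over the feature axis first, then transpose (10, 750) -> (750, 10)
via_sum_first = shap_values.sum(axis=2).T + expected_value

print(via_transpose.shape)                        # (750, 10)
print(np.allclose(via_transpose, via_sum_first))  # True
```

Either way, entry [n, c] is the reconstructed probability of class c for datapoint n, and the (10,) base values broadcast across all rows.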
    

    Then you can proceed to summary_plot, which ranks features by their SHAP values for one class at a time. For class 3 this is:

    summary_plot(shap_values[3],X_train)
    

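    The resulting plot is interpreted as follows (the exact feature numbers come from one particular run and will differ in yours, because make_classification and RandomForestClassifier are called here without a fixed random_state):

    • For class 3, the most influential features based on SHAP contributions are 44, 64, and 17.

    • For features 64 and 17, lower feature values tend to result in higher SHAP values (hence a higher probability of the class label).

    • Features 92, 6, and 53 are the least influential of the 20 displayed.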
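The ordering in the plot can be reproduced numerically: for a single class, summary_plot sorts features by their mean absolute SHAP value across datapoints and by default shows the top 20. A minimal sketch of that ranking; `top_features` is a hypothetical helper, and the dummy array stands in for `shap_values[3]`:

```python
import numpy as np


def top_features(class_shap_values, max_display=20):
    """Rank features for one class by mean |SHAP| over datapoints.

    class_shap_values: array of shape (n_samples, n_features), e.g. shap_values[3].
    Returns feature indices, most influential first.
    """
    importance = np.abs(class_shap_values).mean(axis=0)  # one score per feature
    order = np.argsort(importance)[::-1]                 # descending importance
    return order[:max_display]


# Dummy stand-in for shap_values[3]: feature k gets SHAP values of magnitude k
rng = np.random.default_rng(0)
dummy = rng.choice([-1.0, 1.0], size=(750, 100)) * np.arange(100)

ranking = top_features(dummy)
print(ranking[:3])  # [99 98 97]
```

With real data, `top_features(shap_values[3])[:3]` should list the three features at the top of the plot, and `[-3:]` the three least influential of the 20 displayed.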