Tags: python, dataframe, machine-learning, classification, shap

Plotting a SHAP waterfall for a row selected by its DataFrame index value


I am working on a binary classification problem using the random forest algorithm.

Currently, I am trying to explain the model's predictions using SHAP values.

So, I referred to this useful post and tried the following:

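from shap import TreeExplainer, Explanation
from shap.plots import waterfall

explainer = TreeExplainer(model)  # model: my fitted RandomForestClassifier
sv = explainer(ord_test_t)
exp = Explanation(sv.values[:,:,1],        # SHAP values for class 1
                  sv.base_values[:,1],
                  data=ord_test_t.values,
                  feature_names=ord_test_t.columns)
idx = 20
waterfall(exp[idx])                        # selects the row at position 20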
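In the snippet above, the slicing `[:, :, 1]` and `[:, 1]` implies that `sv.values` carries one slice per class. A small NumPy sketch with dummy arrays (the sizes are assumptions chosen to mirror a binary classifier, and the zeros are placeholders, not real SHAP values) shows what each slice leaves behind and why `exp[i]` picks a row by position:

```python
import numpy as np

n_rows, n_features, n_classes = 4, 3, 2  # hypothetical sizes for illustration

# Stand-ins for sv.values and sv.base_values from a binary classifier
values = np.zeros((n_rows, n_features, n_classes))
base_values = np.zeros((n_rows, n_classes))

class1_values = values[:, :, 1]  # per-feature contributions for class 1, one row per sample
class1_base = base_values[:, 1]  # one expected value per sample

# Indexing the first axis is purely positional, like exp[2]
one_row = class1_values[2]

print(class1_values.shape, class1_base.shape, one_row.shape)
```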

I like this approach because it displays the feature values alongside the waterfall plot, so I'd like to keep using it.

However, it doesn't let me get the waterfall plot for a specific row of ord_test_t (the test data) selected by its index label.

For example, suppose ord_test_t.index.tolist() returns [3, 5, 8, 9, ...].

Now, I want to plot the waterfall for ord_test_t.loc[[9]], the row whose index label is 9. But when I pass exp[9], it selects the row at position 9 (the tenth row), not the row labelled 9.
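The distinction here is pandas' label-based `loc` versus position-based `iloc`, and `exp[9]` behaves like `iloc`. A small pandas sketch, using a made-up frame whose index mimics the 3, 5, 8, 9 labels from the question:

```python
import pandas as pd

# Hypothetical test frame whose index labels are not 0..n-1
ord_test_t = pd.DataFrame({"f1": [10, 20, 30, 40], "f2": [1, 2, 3, 4]},
                          index=[3, 5, 8, 9])

by_label = ord_test_t.loc[[9]]      # the row whose index label is 9
by_position = ord_test_t.iloc[[1]]  # the second row, which has label 5

# Translate a label into its position; this is the number exp[...] expects
pos = ord_test_t.index.get_loc(9)

print(by_label)
print(pos)
```

Here `ord_test_t.iloc[[9]]` would raise an `IndexError`, since the frame has only four rows.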

When I try exp.iloc[[9]], it throws an error because the Explanation object has no iloc attribute.
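Because the `Explanation` is indexed by position, one possible workaround (an assumption, not what the accepted solution does) is to translate index labels into positions with `Index.get_indexer` and index `exp` with those. `rows_for_labels` is a hypothetical helper name, the frame must have a unique index, and a plain NumPy array stands in for `exp` so the sketch runs without shap:

```python
import numpy as np
import pandas as pd

def rows_for_labels(df: pd.DataFrame, labels):
    """Return the integer positions of `labels` in df.index (index must be unique)."""
    positions = df.index.get_indexer(labels)
    missing = [lab for lab, p in zip(labels, positions) if p == -1]
    if missing:
        raise KeyError(f"labels not in index: {missing}")
    # Plain Python ints, so they can be used directly as positional indices
    return [int(p) for p in positions]

# Hypothetical data: row i of stand_in_exp lines up with row i of ord_test_t
ord_test_t = pd.DataFrame({"f1": [10, 20, 30, 40]}, index=[3, 5, 8, 9])
stand_in_exp = np.array([100, 200, 300, 400])

pos = rows_for_labels(ord_test_t, [9])[0]
print(pos, stand_in_exp[pos])
```

With the real objects this would become something like `waterfall(exp[rows_for_labels(ord_test_t, [9])[0]])`, provided `exp` was built from `ord_test_t` in the same row order.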

Can anyone help me with this, please?


Solution

  • My suggestion is as follows:

    from sklearn.ensemble import RandomForestClassifier
    from sklearn.datasets import load_breast_cancer
    from shap import TreeExplainer, Explanation
    from shap.plots import waterfall
    
    import shap
    
    print(shap.__version__)
    
    X, y = load_breast_cancer(return_X_y=True, as_frame=True)
    
    idx = 9
    model = RandomForestClassifier(max_depth=5, n_estimators=100).fit(X, y)
    explainer = TreeExplainer(model)
    sv = explainer(X.loc[[idx]])    # corrected: select the row by index label with .loc, as a one-row df
    exp = Explanation(
        sv.values[:, :, 1],         # SHAP values for class 1 (the class to explain)
        sv.base_values[:, 1],
        data=X.loc[[idx]].values,   # corrected: feature values of the same row
        feature_names=X.columns,
    )
    waterfall(exp[0])               # exp holds only that one row, so it sits at position 0
    

    0.40.0
    


    Proof:

    model.predict_proba(X.loc[[idx]]) # same row, selected by label
    

    array([[0.95752656, 0.04247344]])
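The proof compares the plot against `predict_proba`. A complementary check relies on SHAP's local-accuracy property: the base value plus the sum of a row's SHAP values equals the model output being explained. Assuming that, for a scikit-learn random forest with TreeExplainer's default settings, that output is the class probability, the check fits in a small function. The numbers in the example call are made up to match the 0.04247344 above, not real SHAP values:

```python
import numpy as np

def satisfies_local_accuracy(base_value, shap_values, model_output, atol=1e-6):
    """Check that base_value + sum(shap_values) matches model_output within atol."""
    return bool(np.isclose(base_value + np.sum(shap_values), model_output, atol=atol))

# Made-up contributions that happen to sum from 0.6 down to 0.04247344
ok = satisfies_local_accuracy(0.6, [-0.3, -0.25, -0.00752656], 0.04247344)
print(ok)
```

With the objects from the answer, the call would be something like `satisfies_local_accuracy(exp.base_values[0], exp.values[0], model.predict_proba(X.loc[[idx]])[0, 1])`.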