
export SHAP waterfall plot to dataframe

I am working on a binary classification problem using a random forest model and neural networks, and I am using SHAP to explain the model predictions. I followed the tutorial and wrote the code below to produce a waterfall plot:

import shap

row_to_show = 20
data_for_prediction = ord_test_t.iloc[row_to_show]  # use 1 row of data here; multiple rows also work
data_for_prediction_array = data_for_prediction.values.reshape(1, -1)
explainer = shap.TreeExplainer(rf_boruta)
# Calculate SHAP values for this row
shap_values = explainer.shap_values(data_for_prediction)
# Waterfall plot for class index 0 (base value and SHAP values of the same class)
shap.plots._waterfall.waterfall_legacy(explainer.expected_value[0], shap_values[0], ord_test_t.iloc[row_to_show])

This generated the following plot:

[Image: SHAP waterfall plot for row 20]

However, I want to export this to a DataFrame. How can I do it?

I expect the output to look like the table below, and I want to export it for the full dataframe (every row), not just a single row. Can you help me, please?

[Image: expected output table]


  • Let's do a small experiment:

    import pandas as pd
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.datasets import load_breast_cancer
    from shap import TreeExplainer
    # as_frame=True returns X as a DataFrame, so X.columns is available later
    X, y = load_breast_cancer(return_X_y=True, as_frame=True)
    model = RandomForestClassifier(max_depth=5, n_estimators=100).fit(X, y)
    explainer = TreeExplainer(model)

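    What is explainer here? If you run dir(explainer), you'll find it has several methods and attributes, among which is:

    expected_value

    which is of interest to you because it is the base value to which the SHAP values add up (for each row, base value + sum of that row's SHAP values = model output).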
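The additivity property above can be checked directly. The helper below is a hypothetical sketch (check_additivity is not part of SHAP) that only uses NumPy: it adds the base value to each row's sum of SHAP values and compares the result with the model output. The toy numbers are made up for illustration.

```python
import numpy as np


def check_additivity(base_value, shap_values, model_output, atol=1e-6):
    """Return True if base_value + row-wise sum of SHAP values matches model_output.

    base_value   : scalar, e.g. explainer.expected_value[1]
    shap_values  : array of shape (n_samples, n_features) for one class
    model_output : array of shape (n_samples,), e.g. predict_proba(X)[:, 1]
    """
    shap_values = np.asarray(shap_values, dtype=float)
    reconstructed = base_value + shap_values.sum(axis=1)
    return bool(np.allclose(reconstructed, np.asarray(model_output, dtype=float), atol=atol))


# Toy numbers (not real model output): base value 0.6, two samples, three features.
toy_sv = np.array([[0.1, -0.05, 0.2],
                   [-0.3, 0.0, 0.1]])
toy_pred = np.array([0.85, 0.4])
print(check_additivity(0.6, toy_sv, toy_pred))  # True
```

With the experiment's objects, a call such as check_additivity(explainer.expected_value[1], explainer.shap_values(X)[1], model.predict_proba(X)[:, 1]) should return True, assuming your SHAP version returns a per-class list and that TreeExplainer explains predicted probabilities for a scikit-learn random forest (my understanding of its default behaviour, worth verifying on your setup).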


    sv = explainer.shap_values(X)

    Inspecting sv gives a hint: it is a list of 2 arrays, most probably the SHAP values for class 0 and class 1 (in the order of model.classes_). They must be symmetric, because whatever moves the prediction towards 1 moves it by exactly the same amount, with the opposite sign, towards 0.
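Which layout sv has may depend on the SHAP version. The sketch below assumes that older releases return a list [class_0, class_1] and that newer releases may instead return one 3-D array of shape (n_samples, n_features, n_classes); please check this against your installed version. The helper names positive_class_shap and is_symmetric are hypothetical, and the example arrays are toy values.

```python
import numpy as np


def positive_class_shap(sv):
    """Return the (n_samples, n_features) SHAP matrix for class index 1.

    Accepts either a list [class_0, class_1] of 2-D arrays (older SHAP output)
    or a single 3-D array of shape (n_samples, n_features, n_classes)
    (the layout assumed here for newer SHAP releases).
    """
    if isinstance(sv, list):
        return np.asarray(sv[1])
    sv = np.asarray(sv)
    if sv.ndim == 3:
        return sv[:, :, 1]
    raise ValueError(f"Unexpected SHAP output shape: {sv.shape}")


def is_symmetric(sv, atol=1e-8):
    """For a binary probability model, class-0 SHAP values should equal minus class-1."""
    if isinstance(sv, list):
        sv0, sv1 = np.asarray(sv[0]), np.asarray(sv[1])
    else:
        sv = np.asarray(sv)
        sv0, sv1 = sv[:, :, 0], sv[:, :, 1]
    return bool(np.allclose(sv0, -sv1, atol=atol))


# Toy class-1 SHAP values (hypothetical), in both layouts.
class1 = np.array([[0.2, -0.1], [0.05, 0.3]])
as_list = [-class1, class1]
as_array = np.stack([-class1, class1], axis=-1)  # shape (2, 2, 2)

print(positive_class_shap(as_list).shape)                     # (2, 2)
print(np.array_equal(positive_class_shap(as_array), class1))  # True
print(is_symmetric(as_list), is_symmetric(as_array))          # True True
```

Using positive_class_shap(sv) in place of sv[1] keeps the rest of the answer working regardless of which of the two assumed layouts your SHAP version produces.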


    sv1 = sv[1]

    Now you have everything you need to pack it into the desired format:

    df = pd.DataFrame(sv1, columns=X.columns)        # one SHAP column per feature
    df.insert(0, 'bv', explainer.expected_value[1])  # base value for class 1 as first column
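Since the expected-output image is not available, here is a sketch under the assumption that you want a waterfall-like table: one row per feature with its value, its SHAP contribution, and a running total from the base value. waterfall_frame and all_waterfall_frames are hypothetical helpers, not SHAP functions; the second one stacks the per-row tables for a whole dataframe. The toy data is made up.

```python
import numpy as np
import pandas as pd


def waterfall_frame(base_value, shap_row, feature_values):
    """Tabulate one row's SHAP explanation (hypothetical helper, not a SHAP API).

    base_value     : scalar base value, e.g. explainer.expected_value[1]
    shap_row       : 1-D array of SHAP values for that row and class
    feature_values : pd.Series of the row's feature values (index = feature names)
    """
    df = pd.DataFrame({
        "feature": list(feature_values.index),
        "value": feature_values.to_numpy(),
        "shap": np.asarray(shap_row, dtype=float),
    })
    # Largest absolute contribution first, as at the top of the waterfall plot.
    order = df["shap"].abs().sort_values(ascending=False).index
    df = df.loc[order].reset_index(drop=True)
    # Running total from the base value in this table's order;
    # the last entry is base value + sum of the row's SHAP values.
    df["cumulative"] = base_value + df["shap"].cumsum()
    return df


def all_waterfall_frames(base_value, shap_matrix, X):
    """Stack waterfall_frame for every row of X, keyed by X's index."""
    shap_matrix = np.asarray(shap_matrix)
    frames = {
        idx: waterfall_frame(base_value, shap_matrix[i], X.iloc[i])
        for i, idx in enumerate(X.index)
    }
    return pd.concat(frames, names=["row", "rank"])


# Toy data (made up): 2 samples, 3 features, base value 0.4.
X_toy = pd.DataFrame({"age": [50, 30], "bmi": [22.0, 31.5], "bp": [120, 140]})
sv_toy = np.array([[0.05, -0.20, 0.10],
                   [-0.10, 0.30, 0.02]])

one = waterfall_frame(0.4, sv_toy[0], X_toy.iloc[0])
print(one)

everything = all_waterfall_frames(0.4, sv_toy, X_toy)
print(everything)
```

For the experiment above, waterfall_frame(explainer.expected_value[1], sv1[20], X.iloc[20]) would tabulate row 20, and all_waterfall_frames(explainer.expected_value[1], sv1, X) the whole dataset. The question's plot used expected_value[0] and shap_values[0], i.e. class 0; pass the class-0 values instead if the table should match that plot.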

    Q: How do I know?
    A: Read docs and source code.