SHAP Waterfall diagram as a matplotlib subplot?


Hi, I would like to show two diagrams from the SHAP library side by side:

  1. Waterfall diagram: API reference https://shap.readthedocs.io/en/latest/generated/shap.plots.waterfall.html#shap.plots.waterfall

  2. Bar plot: API reference https://shap.readthedocs.io/en/latest/generated/shap.plots.bar.html

I am using the standard Matplotlib call fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 5)).

The whole code is as follows:

import matplotlib.pyplot as plt
def shap_diagrams(shapley_values, index=0):
    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 5))
    ax0 = shap.plots.waterfall(shapley_values[index], show= False)
    ax0.set_title('SHAP Waterfall Plot')
    plt.gca()
    shap.plots.bar(shapley_values,ax=ax1, show= False)
    ax1.set_title('SHAP Bar Plot')
    
    plt.show()

After calling the function with shap_diagrams(shap_values, 2), the diagrams are overlaid. Please advise.

I have tried different assignments, but where the waterfall axes is supposed to be there is an empty Axes object, and the waterfall diagram itself jumps over onto the properly rendered bar plot.


Solution

  • That's because you're adding new axes on top of the ones created by plt.subplots, for a total of four:

    [
        <Axes: >,
        <Axes: title={'center': 'SHAP Bar Plot'}, xlabel='mean(|SHAP value|)'>,
        <Axes: >,
        <Axes: title={'center': 'SHAP Waterfall Plot'}>
    ]
    

    And since shap.plots.waterfall doesn't have an ax parameter, you need to make ax0 the current axes with plt.sca before calling it:

    import matplotlib.pyplot as plt

    def shap_diagrams(shapley_values, index=0):
        fig, (ax0, ax1) = plt.subplots(1, 2)
        plt.sca(ax0)
    
        shap.plots.waterfall(shapley_values[index], show=False)
        ax0.set_title("SHAP Waterfall Plot")
    
        shap.plots.bar(shapley_values, ax=ax1, show=False)
        ax1.set_title("SHAP Bar Plot", pad=25)
    
        # to horizontally separate the two axes
        plt.subplots_adjust(wspace=1)
        # because waterfall seems to update the figsize
        fig.set_size_inches(10, 3)
    
        plt.show()
    
    shap_diagrams(shap_values, 2)
    

    Input used:

    import xgboost
    import shap
    
    X, y = shap.datasets.adult(n_points=2000)
    model = xgboost.XGBClassifier().fit(X, y)
    
    explainer = shap.Explainer(model, X)
    shap_values = explainer(X)
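
The plt.sca step in the answer can be tried without SHAP installed. The following is only a sketch of the mechanism: draw_on_current_axes is a hypothetical stand-in for any plotting function that has no ax parameter and draws on plt.gca(), and it is not SHAP's actual implementation. It assumes only Matplotlib, with the Agg backend chosen so that it runs without a display.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt


def draw_on_current_axes(values):
    """Hypothetical stand-in for a plotting function without an `ax` parameter."""
    ax = plt.gca()  # such functions draw on whatever axes is current
    ax.barh(range(len(values)), values)
    return ax


fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 3))

# Right after plt.subplots, the last axes created (ax1) is the current one,
# so a function relying on plt.gca() would draw on the right-hand panel.
current_before = plt.gca()

plt.sca(ax0)  # make the left-hand axes current before calling the function
used = draw_on_current_axes([0.3, -0.1, 0.2])

# A function that accepts an axes can be pointed at ax1 directly.
ax1.bar(["a", "b"], [1, 2])

fig.savefig("side_by_side.png")
```

Without the plt.sca(ax0) call, the horizontal bars would be drawn on ax1 together with the vertical ones, which resembles the overlay described in the question.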