Tags: python, matplotlib, xgboost, shap

Subplots for shap.plots.bar() plots?


I'd like to draw several shap.plots.bar (https://github.com/slundberg/shap) figures as subplots of one figure, something like this:

import matplotlib.pyplot as plt
import shap
import xgboost

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(8, 20))
for i, (X, y) in enumerate([(x1, y1), (x2, y2)]):
  model = xgboost.XGBRegressor().fit(X, y)
  explainer = shap.Explainer(model)
  shap_values = explainer(X, check_additivity=False)
  shap.plots.bar(shap_values, max_display=6, show=False)  # ax=ax[i] ??
plt.show()

However, shap.plots.bar has no ax parameter, unlike some other plotting functions such as shap.dependence_plot(..., ax=ax[0, 0], show=False). Is there a way to put several bar plots into the subplots of one figure?


Solution

  • Looking at the source code, the function does not create its own figure; it draws with pyplot calls on the current axes. So you can create the figure yourself and then make the desired axes current with plt.sca.

    Here is how you'd do it using the bar plot sample code from the documentation.

    import xgboost
    import shap
    import matplotlib.pyplot as plt
    
    X, y = shap.datasets.adult()
    model = xgboost.XGBClassifier().fit(X, y)
    
    explainer = shap.Explainer(model, X)
    shap_values = explainer(X)
    
    
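    fig, (ax1, ax2) = plt.subplots(2, 1)
    
    plt.sca(ax1)  # make the top axes current
    shap.plots.bar(shap_values.abs.max(0), show=False)  # max |SHAP| per feature
    
    plt.sca(ax2)  # make the bottom axes current
    shap.plots.bar(shap_values, show=False)  # mean |SHAP| per feature
    
    fig.tight_layout()
    plt.show()  # show=False above keeps shap from calling plt.show() too early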
    

    If you really want to have the ax argument, you'll have to edit the source code to add that option.
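If you'd like something that reads like an ax argument without editing shap, one option is a small context manager that makes a given axes current for the duration of a with block and then restores the previously current axes. This is only a sketch: on_axes and draw_on_current_axes are hypothetical names, and draw_on_current_axes is a stand-in for shap.plots.bar(..., show=False), which (like the stand-in) draws with pyplot calls on the current axes. Because shap.plots.bar may also resize the current figure, the sketch re-applies the figure size after plotting.

```python
from contextlib import contextmanager

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless; drop to get a window
import matplotlib.pyplot as plt


@contextmanager
def on_axes(ax):
    """Hypothetical helper: temporarily make `ax` the current axes."""
    previous = plt.gca()  # axes that were current before entering the block
    plt.sca(ax)
    try:
        yield ax
    finally:
        # Restore the previous axes, even if the plotting call raised.
        plt.sca(previous)


def draw_on_current_axes(values, labels):
    """Hypothetical stand-in for shap.plots.bar(..., show=False):
    it only uses pyplot functions, which act on the current axes."""
    plt.barh(range(len(values)), values)
    plt.yticks(range(len(values)), labels)


fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
datasets = [([0.5, 0.3, 0.1], ["a", "b", "c"]),
            ([0.2, 0.7], ["d", "e"])]

for ax, (values, labels) in zip(axes, datasets):
    with on_axes(ax):
        # With real SHAP values this line would instead be, e.g.:
        # shap.plots.bar(shap_values, max_display=6, show=False)
        draw_on_current_axes(values, labels)

# Re-apply the intended size in case a plotting call changed it.
fig.set_size_inches(8, 4)
fig.tight_layout()
# plt.show() or fig.savefig(...) would go here.
```

The try/finally means later pyplot calls go back to whichever axes were current before the with block, rather than silently landing on the last panel you drew into.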