python matplotlib plot shap beeswarm

Is there a way to customize the feature order in a SHAP beeswarm plot?


I'm wondering whether there's a way to change the order in which features are displayed in a SHAP beeswarm plot. The docs describe "transforms" such as shap_values.abs or shap_values.abs.mean(0) that change how the ordering is calculated, but what I actually want is to pass in a list of feature names or indices and have the plot use exactly that order.

From the docs:

shap.plots.beeswarm(shap_values, order=shap_values.abs)

This is the resulting plot
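For reference, the abs.mean(0) transform amounts to scoring each feature by its mean absolute SHAP value and sorting those scores in descending order. Below is a minimal NumPy sketch of that computation. It assumes the SHAP values are available as a plain (n_samples, n_features) array (for example shap_values.values). The function name mean_abs_order and the toy numbers are made up for illustration.

```python
import numpy as np


def mean_abs_order(values: np.ndarray) -> np.ndarray:
    """Feature indices sorted by mean |SHAP value|, largest first.

    Hypothetical helper sketching what an abs.mean(0) ordering computes,
    assuming `values` is an (n_samples, n_features) array.
    """
    importance = np.abs(values).mean(axis=0)  # one score per feature (column)
    return np.argsort(importance)[::-1]       # indices in descending importance


# Toy SHAP matrix: 3 samples x 4 features (made-up numbers for illustration)
toy = np.array([
    [0.1, -2.0, 0.5, 0.0],
    [-0.3, 1.0, -0.5, 0.2],
    [0.2, -1.5, 0.4, -0.1],
])
print(mean_abs_order(toy))  # [1 2 0 3]
```

Because the result is just an array of column indices, it should be the same kind of value that can be passed to order directly, so any custom ranking can be supplied the same way.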


Solution

  • This is the default ordering (by mean absolute SHAP value), written out explicitly:

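    import xgboost
    import shap
    
    # Train a model on the Adult census dataset bundled with SHAP
    X, y = shap.datasets.adult()
    model = xgboost.XGBClassifier().fit(X, y)
    
    # Compute SHAP values for every row
    explainer = shap.Explainer(model, X)
    shap_values = explainer(X)
    
    # This order is the default: sort features by mean |SHAP value|
    shap.plots.beeswarm(shap_values, max_display=12, order=shap.Explanation.abs.mean(0))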
    


    Then, if you want to define the order of the features manually, map the column names to their column indices and pass the resulting list as order:

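    import matplotlib.pyplot as plt
    
    # Desired top-to-bottom order of features, by column name
    order = [
        "Country",
        "Workclass",
        "Education-Num",
        "Marital Status",
        "Occupation",
        "Relationship",
        "Race",
        "Sex",
        "Capital Gain",
        "Capital Loss",
        "Hours per week",
        "Age",
    ]
    # Map each column name to its position in X
    col2num = {col: i for i, col in enumerate(X.columns)}
    
    # Convert the names to column indices, which beeswarm accepts as order
    order = list(map(col2num.get, order))
    
    # Skip SHAP's own color bar and add a plain matplotlib one before showing
    shap.plots.beeswarm(shap_values, max_display=12, show=False, color_bar=False, order=order)
    plt.colorbar()
    plt.show()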
    

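One pitfall with map(col2num.get, order) is that a misspelled feature name silently turns into None instead of raising an error at that point. Below is a minimal sketch of a hypothetical helper, names_to_order (it is not part of SHAP), that raises on unknown or repeated names. It also appends any features you did not list, keeping their original column order, so a partial list still covers every column.

```python
from typing import Iterable, List, Sequence


def names_to_order(columns: Sequence[str], wanted: Iterable[str]) -> List[int]:
    """Turn feature names into column indices usable as `order`.

    Hypothetical helper (not a SHAP API): raises on unknown or repeated
    names, then appends unlisted columns in their original order.
    """
    col2num = {col: i for i, col in enumerate(columns)}
    order: List[int] = []
    for name in wanted:
        if name not in col2num:
            raise KeyError(f"unknown feature name: {name!r}")
        if col2num[name] in order:
            raise ValueError(f"feature listed twice: {name!r}")
        order.append(col2num[name])
    # Any feature not named explicitly goes after the listed ones.
    order.extend(i for i in range(len(columns)) if i not in order)
    return order


cols = ["Age", "Workclass", "Education-Num", "Country"]
print(names_to_order(cols, ["Country", "Age"]))  # [3, 0, 1, 2]
```

With the objects from the answer above, it would be used as order=names_to_order(X.columns, ["Country", "Workclass"]) in the same shap.plots.beeswarm call.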