Tags: python, machine-learning, shap

Can't display bar plot with SHAP


I'm new to SHAP and trying to use it on top of my RandomForestClassifier. Here's the code snippet after I already ran clf.fit(train_x, train_y):

explainer = shap.Explainer(clf)
shap_values = explainer(train_x.to_numpy()[0:5, :])
shap.summary_plot(shap_values, plot_type='bar')

Here's the resulting plot: [image not included]

Now, there are two problems with this. One is that it is not a bar plot even though I set the plot_type parameter. The other is that I seem to have lost my feature names somehow (and yes, they do exist on the dataframes when calling clf.fit()).

I tried replacing the last line with:

shap.summary_plot(shap_values, train_x.to_numpy()[0:5, :], plot_type='bar')

And that changed nothing. I also tried to replace it with the following to see if I could at least recover my feature names:

shap.summary_plot(shap_values, train_x.to_numpy()[0:5, :], feature_names=list(train_x.columns.values), plot_type='bar')

But that threw an error:

Traceback (most recent call last):
  File "sklearn_model_runs.py", line 41, in <module>
    main()
  File "sklearn_model_runs.py", line 38, in main
    shap.summary_plot(shap_values, train_x.to_numpy()[0:5, :], feature_names=list(train_x.columns.values), plot_type='bar')
  File "C:\Users\kapoo\anaconda3\envs\sci\lib\site-packages\shap\plots\_beeswarm.py", line 554, in summary_legacy
    feature_names=feature_names[sort_inds],
TypeError: only integer scalar arrays can be converted to a scalar index

I'm kind of at a loss at this point. I just tried it with 5 rows of the training set but want to use the whole thing once I get past this stumbling block. If it helps, the classifier had 5 labels and my SHAP version is 0.40.0.
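
A side note on the traceback above, for anyone landing here with the same message: it is the error NumPy raises when a plain Python list is indexed with a multi-element integer array, which is what `feature_names[sort_inds]` does when `feature_names` is a list. Here is a minimal sketch that reproduces it without SHAP. The feature names and sort order are made up for illustration. This only explains the message; it is not a fix for the plot itself.

```python
import numpy as np

# Hypothetical stand-ins for what the plotting code holds at the failing line.
feature_names = ["age", "income", "tenure"]  # a plain Python list
sort_inds = np.array([2, 0, 1])              # a NumPy integer array


def index_names(names, inds):
    """Try names[inds]; return the result or the error text."""
    try:
        return names[inds]
    except TypeError as exc:
        return f"TypeError: {exc}"


# A list cannot be indexed by an integer array, so this hits the except branch.
result_list = index_names(feature_names, sort_inds)

# A NumPy array of names supports this kind of (fancy) indexing.
result_array = np.array(feature_names)[sort_inds].tolist()

print(result_list)
print(result_array)
```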


Solution

  • Alright, here was the problem. Replace this:

    shap_values = explainer(train_x.to_numpy()[0:5, :])

    With this:

    shap_values = explainer.shap_values(train_x) # Use whole thing as dataframe

    Then you can pass the feature names when plotting:

    shap.summary_plot(shap_values, train_x, feature_names=list(train_x.columns.values), plot_type='bar')

    The documentation here should really be updated...