How to transform an nd-array to plot feature importances using matplotlib


I am using SHAP to determine feature importances for an MLPClassifier.

(Note: I am using dummy data and a dummy model because my real data and model are proprietary.)

import shap, pandas as pd, numpy as np
from sklearn.neural_network import MLPClassifier
from sklearn import datasets
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
X = iris.data
y = iris.target

data = pd.DataFrame(data=np.c_[iris['data'], iris['target']],
                    columns=iris['feature_names'] + ['target'])
label = data['target']
data.drop('target', axis=1, inplace=True)
X_train, X_test, y_train, y_test = train_test_split(data, label, random_state=np.random.randint(1, 10), test_size=0.3)

mlp = MLPClassifier(max_iter=150).fit(X_train, y_train)
mlp.score(X_test, y_test)

explainer = shap.KernelExplainer(mlp.predict_proba, shap.kmeans(X_train, 5))
shap_values = explainer.shap_values(X_test)

# First plot
shap.summary_plot(shap_values[1], feature_names = X_test.columns, plot_type='bar')

# Second plot: raises a TypeError and leaves an empty plot
import matplotlib.pyplot as plt; plt.rcdefaults()
y_pos = np.arange(len(X_test.columns))
plt.bar(y_pos, shap_values[1], align='center', alpha=0.5)
plt.xticks(y_pos, X_test.columns)
plt.ylabel('SHAP Importance')
plt.title('MLP Feature Importances')

plt.show()

Following this guide, shap.summary_plot gives me a plot like this:

[Image: SHAP summary bar plot]

My actual dataset has about 10,000 features. What I want to do is take the top n features from that list and plot them with matplotlib, following this guide. However, I get an error and a blank graph:

[Image: empty matplotlib plot]

Full Traceback:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-42-ba03083152ca> in <module>
      1 import matplotlib.pyplot as plt; plt.rcdefaults()
      2 y_pos = np.arange(len(X_test.columns))
----> 3 plt.bar(y_pos, shap_values[1], align='center', alpha=0.5)
      4 plt.xticks(y_pos, X_test.columns)
      5 plt.ylabel('SHAP Importance')

c:\python367-64\lib\site-packages\matplotlib\pyplot.py in bar(x, height, width, bottom, align, data, **kwargs)
   2407     return gca().bar(
   2408         x, height, width=width, bottom=bottom, align=align,
-> 2409         **({"data": data} if data is not None else {}), **kwargs)
   2410 
   2411 

c:\python367-64\lib\site-packages\matplotlib\__init__.py in inner(ax, data, *args, **kwargs)
   1563     def inner(ax, *args, data=None, **kwargs):
   1564         if data is None:
-> 1565             return func(ax, *map(sanitize_sequence, args), **kwargs)
   1566 
   1567         bound = new_sig.bind(ax, *args, **kwargs)

c:\python367-64\lib\site-packages\matplotlib\axes\_axes.py in bar(self, x, height, width, bottom, align, **kwargs)
   2393                 edgecolor=e,
   2394                 linewidth=lw,
-> 2395                 label='_nolegend_',
   2396                 )
   2397             r.update(kwargs)

c:\python367-64\lib\site-packages\matplotlib\patches.py in __init__(self, xy, width, height, angle, **kwargs)
    725         """
    726 
--> 727         Patch.__init__(self, **kwargs)
    728 
    729         self._x0 = xy[0]

c:\python367-64\lib\site-packages\matplotlib\patches.py in __init__(self, edgecolor, facecolor, color, linewidth, linestyle, antialiased, hatch, fill, capstyle, joinstyle, **kwargs)
     87         self.set_fill(fill)
     88         self.set_linestyle(linestyle)
---> 89         self.set_linewidth(linewidth)
     90         self.set_antialiased(antialiased)
     91         self.set_hatch(hatch)

c:\python367-64\lib\site-packages\matplotlib\patches.py in set_linewidth(self, w)
    393                 w = mpl.rcParams['axes.linewidth']
    394 
--> 395         self._linewidth = float(w)
    396         # scale the dash pattern by the linewidth
    397         offset, ls = self._us_dashes

TypeError: only size-1 arrays can be converted to Python scalars

I am having trouble mapping each value in shap_values[1] to a specific feature in X_test.columns. How can I do this so that I can take the top n features and plot them in matplotlib?
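A quick way to see what goes wrong is to check shapes. The sketch below uses a random NumPy array as a stand-in for shap_values[1] (an assumption, so it runs without shap or a trained model) together with the iris column names. Each row is one test sample and each column one feature, so plt.bar receives a 2-D height where it expects one value per bar; reducing over the rows first yields exactly one number per column of X_test, in the same column order.

```python
import numpy as np
import pandas as pd

# Stand-in for shap_values[1]: 45 test samples x 4 features (random data, not real SHAP output)
rng = np.random.default_rng(0)
feature_names = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"]
class1_shap = rng.normal(size=(45, len(feature_names)))

print(class1_shap.shape)  # (45, 4): one row per sample, too many values for 4 bars

# Collapse the sample axis: mean absolute SHAP value per feature
per_feature = np.abs(class1_shap).mean(axis=0)
print(per_feature.shape)  # (4,): one height per bar

# Column i of the array belongs to feature_names[i], so the mapping is positional
importance = pd.Series(per_feature, index=feature_names)
print(importance.sort_values(ascending=False))
```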


Solution

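  • shap_values[1] is not one importance per feature: it holds one row of SHAP values per test sample for class 1, with shape (n_samples, n_features), which is why plt.bar fails (it needs one height per bar). Its columns are in the same order as X_test.columns, so averaging the absolute values over the rows gives one global importance per feature (the quantity shap.summary_plot shows with plot_type='bar'), and you can index it by the column names as a Series:

    # mean |SHAP| per feature, sorted from most to least important
    # (take abs() before mean(), otherwise positive and negative values cancel out)
    fi = (pd.Series(np.abs(shap_values[1]).mean(0), index=X_test.columns)
            .sort_values(ascending=False)
         )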
    
    # extract 5 top values and plot
    fi.head(5).plot.barh()
    

    Output:

    [Image: horizontal bar plot of the top 5 features]
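Since the question asks for matplotlib specifically, here is a sketch of the same top-n plot using only NumPy and matplotlib. It assumes, as in the question, that shap_values is a list with one (n_samples, n_features) array per class, which is what older shap releases return from KernelExplainer.shap_values for a multi-class model; newer releases may instead return a single array of shape (n_samples, n_features, n_classes), so the hypothetical helper class_shap accepts either layout. plot_top_n is likewise a hypothetical helper name, and random data stands in for real SHAP values.

```python
import matplotlib.pyplot as plt
import numpy as np


def class_shap(shap_values, class_idx):
    """Return the (n_samples, n_features) SHAP array for one class.

    Hypothetical helper: accepts a list of per-class arrays or a single
    (n_samples, n_features, n_classes) array; a 2-D array is returned as-is.
    """
    if isinstance(shap_values, list):
        return np.asarray(shap_values[class_idx])
    arr = np.asarray(shap_values)
    return arr[..., class_idx] if arr.ndim == 3 else arr


def plot_top_n(values, feature_names, n=5):
    """Horizontal bar plot of the n features with the largest mean |SHAP|."""
    importance = np.abs(values).mean(axis=0)   # one value per feature
    top = np.argsort(importance)[::-1][:n]     # column indices, largest first
    labels = np.asarray(feature_names)[top]    # map column indices to names
    positions = np.arange(len(top))

    fig, ax = plt.subplots()
    ax.barh(positions, importance[top], align="center", alpha=0.5)
    ax.set_yticks(positions)
    ax.set_yticklabels(labels)
    ax.invert_yaxis()                          # most important feature at the top
    ax.set_xlabel("mean(|SHAP value|)")
    ax.set_title("MLP Feature Importances")
    return fig, ax, labels, importance[top]


# Usage with random stand-in data: 3 classes, 45 samples, 10 features
rng = np.random.default_rng(0)
fake_shap = [rng.normal(size=(45, 10)) for _ in range(3)]
names = [f"feature_{i}" for i in range(10)]
fig, ax, top_names, top_vals = plot_top_n(class_shap(fake_shap, 1), names, n=5)
# call plt.show() to display the figure outside Jupyter
```

With the question's objects, the call would be plot_top_n(class_shap(shap_values, 1), X_test.columns, n=5). Using np.argsort on the column means keeps the feature-to-name mapping positional, which scales to 10,000 features without building a DataFrame.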