Tags: python, scipy, hierarchical-clustering, shap

How can I view the clustering of specific rows within shap and scipy?


I have a genetic dataset of 600 rows (genes) by 11 features, which I use in a machine learning regression/classification model to predict disease-causing genes.

I am trying to view the hierarchical clustering of rows that the shap package performs. Specifically, I am running the shap heatmap, shap.plots.heatmap(shap_values, max_display=11), and want to see the data for the rows (genes) as they are clustered in this plot (example plot: https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/heatmap.html).

To reproduce the clustering that shap performs inside shap.plots.heatmap(), I run:

    explainer = shap.Explainer(model, X)
    shap_values = explainer(X)
    import scipy.cluster
    D = scipy.spatial.distance.pdist(shap_values[:,:-1], 'sqeuclidean')
    clustOrder = scipy.cluster.hierarchy.leaves_list(scipy.cluster.hierarchy.complete(D))

However, this raises an error on the fourth line, where the D object is created:

    ValueError: setting an array element with a sequence.

How do I change shap_values[:,:-1] to fix this error?

For reference, shap_values.shape is 600x11 and type(shap_values) is shap._explanation.Explanation. Unfortunately, I can't share example data in this case. I've read similar Stack Overflow questions about this error, but I can't work out how to apply their fixes to shap_values specifically.

How can I fix the error so that I can run the hierarchical clustering? Is it because shap_values isn't a 2D array of scalars unless the classification is binary?
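On the dimensionality question: for single-output models the values inside a shap Explanation are usually a 2D array (n_samples, n_features), while multi-output models (for example multiclass) typically give a 3D array (n_samples, n_features, n_outputs). pdist needs a plain 2D float matrix either way. The helper below is a minimal sketch under those assumptions; `shap_matrix_2d` and its `output_index` parameter are hypothetical names, and the arrays are random stand-ins rather than real SHAP output.

```python
import numpy as np


def shap_matrix_2d(values, output_index=None):
    """Return a float (n_samples, n_features) matrix that pdist can consume.

    `values` is assumed to be the ndarray held in Explanation.values.
    `output_index` (hypothetical) selects one model output when the array is
    3D, i.e. shaped (n_samples, n_features, n_outputs).
    """
    arr = np.asarray(values, dtype=float)
    if arr.ndim == 3:
        if output_index is None:
            raise ValueError("3D SHAP values: pass output_index to pick one output")
        arr = arr[:, :, output_index]
    if arr.ndim != 2:
        raise ValueError(f"expected a 2D matrix, got shape {arr.shape}")
    return arr


# Random stand-ins, not real SHAP values
rng = np.random.default_rng(0)
single_output = rng.normal(size=(600, 11))
multi_output = rng.normal(size=(600, 11, 3))

print(shap_matrix_2d(single_output).shape)                 # (600, 11)
print(shap_matrix_2d(multi_output, output_index=1).shape)  # (600, 11)
```

With a 600x11 shape, as reported above, the values are already 2D, so the error comes from the object type rather than from the dimensionality.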

I've seen a notebook (https://github.com/suinleelab/treeexplainer-study/blob/master/notebooks/mortality/NHANES%20I%20Analysis.ipynb) that runs the same code on shap_values, and there it works.
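One plausible reason the `[:,:-1]` slice works in that notebook is that it operates on a plain contribution matrix with one extra trailing column for the bias (expected value), as returned for example by XGBoost's `predict(..., pred_contribs=True)`. This is an assumption about the notebook, not something verified here. A shap Explanation built with `shap.Explainer` holds one column per feature, which matches the 600x11 shape above, so the same slice would silently drop the 11th feature. The sketch below uses a hypothetical helper, `drop_bias_column_if_present`, that strips the last column only when the column count shows it is extra; the data are random stand-ins.

```python
import numpy as np


def drop_bias_column_if_present(contribs, n_features):
    """Remove a trailing bias column only if the matrix has one extra column.

    Assumption: some tools append the bias as a last column, giving
    (n_samples, n_features + 1); shap's Explanation.values is assumed to be
    (n_samples, n_features) with no bias column.
    """
    contribs = np.asarray(contribs, dtype=float)
    if contribs.ndim != 2:
        raise ValueError(f"expected a 2D matrix, got shape {contribs.shape}")
    n_cols = contribs.shape[1]
    if n_cols == n_features + 1:
        return contribs[:, :-1]  # last column treated as the bias term
    if n_cols == n_features:
        return contribs          # already one column per feature: keep them all
    raise ValueError(f"expected {n_features} or {n_features + 1} columns, got {n_cols}")


rng = np.random.default_rng(1)
per_feature = rng.normal(size=(600, 11))    # like Explanation.values
bias = np.full((600, 1), 0.3)               # synthetic constant bias column
with_bias = np.hstack([per_feature, bias])  # like an (n, M + 1) contribution matrix

print(drop_bias_column_if_present(per_feature, 11).shape)  # (600, 11)
print(drop_bias_column_if_present(with_bias, 11).shape)    # (600, 11)
print(per_feature[:, :-1].shape)  # (600, 10): blind slicing loses a feature
```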


Solution

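  • As pointed out by @user12750353, scipy.spatial.distance.pdist expects an ndarray, but shap_values is a shap._explanation.Explanation object.

    The Explanation object has an attribute, values, that holds the SHAP values as an ndarray, and that array can be passed to scipy.spatial.distance.pdist. Since shap_values.shape is 600x11 with one column per feature, the [:,:-1] slice would drop the 11th feature, so the full array is used below to cluster on all features:

    import scipy.cluster.hierarchy
    import scipy.spatial.distance
    import shap

    explainer = shap.Explainer(model, X)
    shap_values = explainer(X)
    # .values is a plain (600, 11) ndarray: one column per feature, no bias column
    D = scipy.spatial.distance.pdist(shap_values.values, 'sqeuclidean')
    # complete linkage; leaves_list returns the row (gene) indices in dendrogram order
    clustOrder = scipy.cluster.hierarchy.leaves_list(scipy.cluster.hierarchy.complete(D))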
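To see which genes end up next to each other, the leaf order can be used to index the gene names, and the same linkage can be cut into flat clusters with `scipy.cluster.hierarchy.fcluster`. This is a sketch under assumptions: `values` is a random stand-in for `shap_values.values`, `gene_names` and `cluster_rows` are hypothetical names, and the cut into 4 clusters is arbitrary. The `optimal` flag applies `scipy.cluster.hierarchy.optimal_leaf_ordering`; whether your shap version's heatmap does the same is an assumption to check against its source. Leaf ordering only flips branches, so it changes the displayed row order but not which rows are grouped together.

```python
import numpy as np
from scipy.cluster import hierarchy
from scipy.spatial import distance


def cluster_rows(values, optimal=False):
    """Complete-linkage clustering of rows on squared Euclidean distance.

    Returns (order, Z): `order` holds the row indices in dendrogram leaf order
    and `Z` is the linkage matrix. With optimal=True the leaves are reordered
    by scipy's optimal_leaf_ordering (same tree, different leaf order).
    """
    D = distance.pdist(values, "sqeuclidean")
    Z = hierarchy.complete(D)
    if optimal:
        Z = hierarchy.optimal_leaf_ordering(Z, D)
    return hierarchy.leaves_list(Z), Z


# Random stand-in for shap_values.values: 600 genes x 11 features
rng = np.random.default_rng(42)
values = rng.normal(size=(600, 11))
gene_names = np.array([f"gene_{i}" for i in range(600)])  # hypothetical gene IDs

order, Z = cluster_rows(values)
ordered_genes = gene_names[order]  # gene names in dendrogram row order

# Cut the tree into at most 4 flat clusters (arbitrary choice for illustration)
labels = hierarchy.fcluster(Z, t=4, criterion="maxclust")

print(ordered_genes[:5])
for k in np.unique(labels):
    members = gene_names[labels == k]
    print(f"cluster {k}: {len(members)} genes, e.g. {members[:3]}")
```

With real data, `values` would be `shap_values.values` and `gene_names` would come from the index of `X` (for example a pandas DataFrame index), assuming its rows line up with the SHAP rows.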