Tags: r, random-forest, multiclass-classification, shap

Approximate SHAP values for a multiclass classification problem using randomForest


I would like to use the fastshap package to obtain SHAP value plots for every category of my outcome in a multiclass classification problem, using a random forest classifier. I could only find fragments of code online, with no explanation of how to obtain the SHAP values in this case from start to finish. Here is the code I have so far (my y has 5 classes; here I am trying to obtain SHAP values for class 3):

    library(randomForest)
    library(fastshap)
    
    set.seed(42)
    sample <- sample.int(n = nrow(ITA), size = floor(.75 * nrow(ITA)), replace = FALSE)
    train <- ITA[sample, ]
    test <- ITA[-sample, ]
    
    set.seed(42)
    rftrain <- randomForest(y ~ ., data = train, ntree = 500, importance = TRUE)
    
    p_function_3 <- function(object, newdata)
      caret::predict.train(object,
                           newdata = newdata,
                           type = "prob")[, 3]
    
    shap_values_G <- fastshap::explain(rftrain,
                                       X = train,
                                       pred_wrapper = p_function_3,
                                       nsim = 50,
                                       newdata = train[which(y == 3), ])

I took the code largely from an example I found online and tried to adapt it (I am not an expert R user), but it does not work. Can you please help me correct it? Thanks!


Solution

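  • Here is a working example using a different dataset (iris); the logic should carry over to your data. The key differences from your code: the prediction wrapper calls predict(), which dispatches to predict.randomForest, instead of caret::predict.train (that method is meant for models fitted with caret::train, not for randomForest objects), and the rows to explain are selected with train$Species == "virginica" rather than with a bare y, which is not visible outside the data frame.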

    library(randomForest)
    library(fastshap)
    
    set.seed(42) 
    
    ix <- sample(nrow(iris), 0.75 * nrow(iris))
    train <- iris[ix, ]
    test <- iris[-ix, ]
    
    xvars <- c("Sepal.Width", "Sepal.Length")
    yvar <- "Species"
    fit <- randomForest(reformulate(xvars, yvar), data = train, ntree = 500) 
    
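    # Return the predicted probability of the third class ("virginica")
    # as a numeric vector, which is what fastshap expects
    pred_3 <- function(model, newdata) {
      predict(model, newdata = newdata, type = "prob")[, "virginica"]
    }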
    
    shap_values_3 <- fastshap::explain(
      fit, 
      X = train,             # Reference data
      feature_names = xvars,
      pred_wrapper = pred_3, 
      nsim = 50,
      newdata = train[train$Species == "virginica", ] # Rows to explain
    )
    
    head(shap_values_3)
    
    # Sepal.Width Sepal.Length
    # <dbl>        <dbl>
    # 1      0.101        0.381 
    # 2      0.159       -0.0109
    # 3      0.0736      -0.0285
    # 4      0.0564       0.161 
    # 5      0.0649       0.594 
    # 6      0.232        0.0305
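Since the question asks for SHAP values for every category of the outcome, the example above can be extended by looping over the class levels. The following is a minimal sketch under the same iris setup, not an official fastshap workflow: `make_pred`, `shap_by_class` and `shap_importance` are illustrative names, and summarising each class by its mean absolute SHAP value and plotting it with base `barplot()` is just one possible choice. It assumes the object returned by `explain()` can be passed through `abs()` and `colMeans()`, which should hold whether your fastshap version returns a tibble (older releases) or a matrix (newer ones).

```r
library(randomForest)
library(fastshap)

set.seed(42)

# Same setup as the example above: iris, two predictors, 75% training split
ix <- sample(nrow(iris), 0.75 * nrow(iris))
train <- iris[ix, ]

xvars <- c("Sepal.Width", "Sepal.Length")
yvar <- "Species"
fit <- randomForest(reformulate(xvars, yvar), data = train, ntree = 500)

# Hypothetical helper: builds a prediction wrapper that returns the
# predicted probability of class `cls` as a numeric vector
make_pred <- function(cls) {
  force(cls)
  function(model, newdata) {
    predict(model, newdata = newdata, type = "prob")[, cls]
  }
}

classes <- levels(train[[yvar]])

# For each class, explain the training rows that belong to that class,
# using only the predictor columns as reference data
shap_by_class <- lapply(classes, function(cls) {
  fastshap::explain(
    fit,
    X = train[, xvars],
    feature_names = xvars,
    pred_wrapper = make_pred(cls),
    nsim = 50,
    newdata = train[train[[yvar]] == cls, xvars]
  )
})
names(shap_by_class) <- classes

# Mean absolute SHAP value per feature: one row per class, one column per feature
shap_importance <- t(sapply(shap_by_class, function(s) colMeans(abs(s))))
print(shap_importance)

# Grouped bar chart: one group per feature, one bar per class
barplot(shap_importance, beside = TRUE,
        legend.text = rownames(shap_importance),
        ylab = "mean |SHAP value|")
```

To adapt this to your data, replace iris with ITA, set yvar to "y", and use setdiff(names(ITA), "y") for xvars; with 5 classes, shap_by_class will then hold one set of SHAP values per class.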