Tags: r, ggplot2, shap

I want to color the points of a shapviz dependence plot according to the x variable


I have been trying to give the points a different color for each class of the variable. I've tried scale_x_discrete() together with changing the colors, and I've tried scale_color_manual(), but nothing works. When I pass a vector of colors to the color argument of sv_dependence() (color = ...), I get this error:

Error in ggplot2::geom_jitter():
! Problem while setting up geom aesthetics.
ℹ Error occurred in the 1st layer.
Caused by error in check_aesthetics():
! Aesthetics must be either length 1 or the same as the data (1000).
✖ Fix the following mappings: colour.
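The error comes from ggplot2 rather than from shapviz: a color given as a fixed parameter (not mapped with aes()) must have length 1 or exactly one value per data row. Judging by the error message, sv_dependence() passes color on to geom_jitter() as such a fixed value, so 7 colors for 1000 points fail. The sketch below reproduces this with plain ggplot2 on simulated data; the data frame `df`, its columns `cls` and `shap`, and the palette `pal` are illustrative assumptions, not the asker's objects. It then contrasts the failing call with mapping color to the class variable.

```r
library(ggplot2)

# Hypothetical stand-in for the question's data: 1000 rows,
# a 7-class variable and a numeric "SHAP value"
set.seed(1)
df <- data.frame(
  cls  = factor(sample(1:7, 1000, replace = TRUE)),
  shap = rnorm(1000)
)
pal <- c("red", "orange", "gold", "green", "blue", "purple", "grey40")

# A fixed (unmapped) colour must have length 1 or nrow(df).
# Seven colours for 1000 points fails when the plot is built.
p_bad <- ggplot(df, aes(cls, shap)) + geom_jitter(color = pal)
res <- tryCatch(ggplot_build(p_bad), error = function(e) e)
print(conditionMessage(res))

# Mapping colour to the class variable gives one colour per class,
# and scale_color_manual() chooses which colour each class gets.
p_ok <- ggplot(df, aes(cls, shap, color = cls)) +
  geom_jitter(width = 0.2) +
  scale_color_manual(values = pal)
p_ok
```

In short, per-class colors have to come from a mapping (aes(color = ...)), not from a vector of fixed colors.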

I have spent several hours on this without success. Can someone help me, please?

[Image: Shap_dependence output]

I am using the following libraries. My current code is below, but it produces a single color for the whole plot:

library(kernelshap)
library(shapviz)
library(ggplot2)

# COS: labels for the 7 classes of the variable "cos"
names_cos <- c('ART', 'AGRI', 'PAST', 'HDW', 'EUC', 'RES', 'SHL')

# `sv` is my shapviz object (created earlier, not shown)
dep_cos <- sv_dependence(sv, v = "cos", color_var = NULL, color = "gray30") +
  scale_x_discrete(labels = names_cos, limits = factor(1:7)) +
  labs(y = "SHAP value", x = "LCU", tag = "(b)") +
  theme_bw() +
  theme(plot.tag = element_text(),
        plot.tag.position = c(0.95, 0.1))

# Put a horizontal reference line at 0 underneath the points
dep_cos$layers <- c(
  list(geom_hline(yintercept = 0, color = "black")),
  dep_cos$layers
)
dep_cos

I expected a different color for each class on the x axis (one for ART, one for AGRI, and so on).


Solution

  • Coloring in SHAP dependence plots is meant to visualize interaction effects. The one variable it makes no sense to study interactions with is the x-axis variable itself, which is why sv_dependence() suppresses v as the color_var. In your situation this may feel limiting, but it makes it easy to plot multiple color variables or multiple x variables at once. The color argument, on the other hand, only sets one fixed color for all points, so passing a vector of colors to it fails with the length error above.

    Since sv_dependence() is, in its basic form, just a scatter plot, you can easily build your own version:

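    library(xgboost)
    library(shapviz)
    library(ggplot2)
    
    # Fit a small XGBoost regression of Sepal.Length on the other iris columns
    # (data.matrix() turns the Species factor into integer codes)
    X_pred <- data.matrix(iris[, -1])
    dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1])
    fit <- xgboost::xgb.train(data = dtrain, nrounds = 10)
    
    # Use the original data frame (Species as a factor) as feature values
    shap_obj <- shapviz(fit, X_pred = X_pred, X = iris)
    
    # SHAP values of Species plotted against Species, colored by Species
    df <- data.frame(Species = shap_obj$X$Species, SHAP_value = shap_obj$S[, "Species"])
    ggplot(df, aes(Species, SHAP_value, color = Species)) +
      geom_jitter()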
    

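Applied to the question's setup, the same idea might look like the following minimal sketch. It assumes the x variable "cos" is coded 1 to 7 in the same order as names_cos. Because the asker's sv is not available, a hypothetical shapviz object is simulated from random numbers using shapviz()'s matrix method; the extra feature "other" and the colors in `cols_cos` are arbitrary illustrative choices. The zero line is added as the first layer directly, instead of editing $layers afterwards.

```r
library(shapviz)
library(ggplot2)

# Hypothetical stand-in for the asker's `sv`: random SHAP values and a
# 7-class feature "cos" coded 1-7 (not real data)
set.seed(42)
n <- 1000
X <- data.frame(cos = sample(1:7, n, replace = TRUE), other = rnorm(n))
S <- cbind(cos = rnorm(n), other = rnorm(n))
sv <- shapviz(S, X = X)

names_cos <- c("ART", "AGRI", "PAST", "HDW", "EUC", "RES", "SHL")
# One colour per class, named by class label (colours chosen arbitrarily)
cols_cos <- setNames(
  c("#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#A65628", "#999999"),
  names_cos
)

# Build the plot data by hand: class labels on x, SHAP values of "cos" on y
df <- data.frame(
  LCU = factor(sv$X$cos, levels = 1:7, labels = names_cos),
  SHAP_value = sv$S[, "cos"]
)

dep_cos <- ggplot(df, aes(LCU, SHAP_value, color = LCU)) +
  geom_hline(yintercept = 0, color = "black") +  # first layer, so drawn below the points
  geom_jitter(width = 0.2, show.legend = FALSE) +
  scale_color_manual(values = cols_cos) +
  labs(y = "SHAP value", x = "LCU", tag = "(b)") +
  theme_bw() +
  theme(plot.tag.position = c(0.95, 0.1))
dep_cos
```

The legend is hidden because the x axis already names each class; drop show.legend = FALSE if you want it back.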