I computed SHAP values for my multiclass random forest classification model.
Is there a way to get the following?
Instead of 8 separate plots (picture 1), one for each category of my Y variable, a single combined plot like the one in picture 2?
That is, only one global feature-importance plot that aggregates over all classes?
Here is my code:
library(randomForest)
library(kernelshap)
library(shapviz)

# Predictors to explain: every column except the first.
# NOTE(review): assumes the response `droughts_column` is the first column of
# train_data -- TODO confirm.
X <- train_data[, -1]

RF <- randomForest(
  droughts_column ~ .,
  data = train_data,
  ntree = 100,
  mtry = 17,
  importance = TRUE
)

# Background data used by Kernel SHAP to integrate out "switched off"
# features: a random subset of at most 200 training rows.
set.seed(1)
bg_X <- train_data[sample(nrow(train_data), min(200L, nrow(train_data))), ]

# `type = "prob"` is forwarded to predict.randomForest(), so predictions are
# numeric class probabilities instead of a factor. Without it kernelshap()
# stops with "Predictions must be numeric".
s <- kernelshap(RF, X = X, bg_X = bg_X, type = "prob")
sv <- shapviz(s)

# A single combined importance plot: mean |SHAP| per feature, with the
# contributions of all classes stacked into one bar (shapviz >= 0.9.2).
sv_importance(sv, kind = "bar", max_display = 10, bar_type = "stack")
Here is my data:
# Convert to a plain data.frame first: dput() on a data.table appends an
# `.internal.selfref = <pointer: ...>` attribute that cannot be parsed back.
dput(as.data.frame(train_data[1:20, 1:31]))
# First 20 rows and 31 columns of `train_data`, as printed by dput().
# The trailing `.internal.selfref = <pointer: ...>` attribute that dput()
# emits for data.tables has been removed. It is not valid R syntax, so the
# expression could not be pasted back into R. If data.table later complains
# about an invalid selfref, `data.table::setDT(train_data)` re-initialises it.
# NOTE(review): the model formula uses `droughts_column` as the response, but
# the only factor column here is `min_column` (8 levels). It is presumably the
# response under another name -- confirm.
train_data <- structure(list(min_column = structure(c(4L, 5L, 4L, 7L, 7L, 5L,
8L, 5L, 7L, 5L, 6L, 4L, 4L, 7L, 7L, 4L, 1L, 5L, 8L, 8L), levels = c("PREC3",
"PPSTV", "PPSV", "PREC6", "SM3", "SV", "TWSA3", "VPD3"), class = "factor"),
    aridity_index = c(1.05540223890365, 1.42291223741047, 0.289312012765131,
    2.08955832504966, 1.29389651632679, 0.327431291000845, 3.31130880723198,
    0.435962678062337, 0.720398000663778, 0.317880591943831,
    1.25169086405035, 0.686319667529413, 1.40184251502531, 1.35417639189355,
    0.640902793146427, 1.38923272127388, 0.48987851172821, 0.559316314968663,
    0.317328910956498, 0.34518109028685), mean_t2m = c(25.8957153320313,
    11.9346964518229, 29.8120997111003, 27.2445922851563, 26.2609522501628,
    27.6344212849935, 25.1281499226888, 19.8798919677735, 17.526274617513,
    31.3262395222982, 20.2956227620443, 23.411154683431, 20.2042404174805,
    23.607448832194, 17.7585464477539, 25.0856745402018, 23.8563552856446,
    22.4191426595052, 22.3183787027995, 23.6592041015625), sd_t2m = c(0.748837584496715,
    1.69355443757821, 1.94353385216372, 0.393817169978326, 0.437997449158287,
    1.20966419549391, 0.30109030223568, 1.46221521450813, 1.20757543363649,
    0.384620838283781, 0.951142946899067, 0.660012290331394,
    1.05941750720462, 0.623493914695329, 1.66694863819749, 0.828680288883452,
    0.9433044402992, 1.24476846679161, 1.37801589084397, 0.746930915488626
    ), mean_vpd = c(10.6565330028534, 3.94749279816945, 29.6649556954702,
    8.2109356323878, 7.93534930547078, 17.9001136620839, 4.93418568372726,
    14.3897955814997, 6.24165391921997, 31.4230984052022, 8.02857120831808,
    8.20808569590251, 7.72910133997599, 8.53423659006754, 8.67752106984456,
    5.9931894938151, 11.9273592233658, 16.4530976613363, 15.9718067646027,
    14.2948899269104), sd_vpd = c(1.73944337077899, 0.755394420255983,
    8.05402389935442, 1.10117851181646, 0.948786864039974, 3.09593279514218,
    0.610038282171325, 3.45080704918886, 1.19894214225847, 3.17718368822484,
    1.50526167566968, 1.25732973567913, 1.22045801441578, 1.21170835192973,
    2.30159243776068, 1.483610883746, 2.40324777669803, 2.78005896417215,
    4.37614864431736, 1.7897095039182), mean_prec = c(92.4068752924601,
    84.150003751119, 47.7314598560333, 118.624369303385, 259.424163818359,
    29.1541662439704, 222.893957773844, 33.1908336480459, 107.695629437764,
    7.0000001937151, 133.085209528605, 118.405418713888, 113.004998207092,
    109.223124663035, 45.3147913614909, 241.841251373291, 116.099378267924,
    47.0570815404256, 32.3000017683953, 18.1831246614456), sd_prec = c(25.2230387606066,
    23.3214254568287, 54.2045520974117, 54.0742759998264, 84.2605910448614,
    44.5100269137947, 61.7197090730034, 27.3946301979818, 39.1184086642315,
    5.5819938843692, 118.01624202547, 66.4338077887344, 48.0153506259032,
    62.4349963152462, 35.9822947596737, 177.032327827943, 40.7720457797305,
    41.9353214009905, 48.2970552400401, 11.0728449402471), mean_sm = c(0.280448893706004,
    0.202963148554166, 0.138948903108637, 0.352660410106182,
    0.278219901025295, 0.173655264079571, 0.311909670631091,
    0.220840279012918, 0.26629921918114, 0.145378011589249, 0.289096682022015,
    0.238765094429255, 0.293268837034702, 0.352709278464317,
    0.245920985937119, 0.299343595902125, 0.243301281084617,
    0.14155216080447, 0.188165209566553, 0.230367512752612),
    sd_sm = c(0.0120999081770558, 0.0106934897216381, 0.0449606306723597,
    0.00935333257513591, 0.0170730820627545, 0.0287503551492261,
    0.00551166949866713, 0.0263233458019688, 0.0224377766941138,
    0.0231998024046729, 0.0261685416345269, 0.0252500867384453,
    0.0133047644235911, 0.0121125656223249, 0.020165065621216,
    0.0105683334975659, 0.0268471729517476, 0.0209773792143915,
    0.0484241248619342, 0.0123767393717239), mean_snr = c(12.0064274470011,
    6.76889745394389, 9.40070144335429, 12.9343137741089, 12.6970313390096,
    9.85343869527181, 11.4195901552836, 12.5402154922485, 9.32121702035268,
    7.51378111044566, 9.33259244759878, 13.1668708324432, 10.5996281305949,
    8.06930311520894, 8.90324242909749, 11.8287103176117, 12.5012144247691,
    16.2886624336243, 8.91103057066599, 7.04941769440969), sd_snr = c(1.14155839883206,
    0.892173051553268, 0.654157370607644, 0.827404584595355,
    0.790501626018555, 0.403225009907126, 0.515157191337593,
    0.378115825721952, 1.05029652548267, 0.497559161125383, 0.667521104274022,
    0.719427400607346, 0.58177946280854, 0.736912015197759, 0.751189102845447,
    1.54676429427932, 0.579172898464673, 0.581649429155848, 0.757825017522055,
    0.66237735678821), grass.man = c(16.052395277553, 0, 0, 0.66145833317811,
    1.43209873843522, 15.3672170920507, 0.00144675930237593,
    70.6226110458368, 0.486689811494829, 12.669270892938, 6.96614590287209,
    35.5963554183634, 4.59172460436825, 77.6109197404646, 97.3732637829251,
    29.6277959677904, 31.8275464773165, 89.6716828346252, 1.30208332091568,
    0.0188078705105002), grass.nat = c(69.4744475682573, 49.188658979204,
    16.0431131124497, 50.8833920293391, 19.497588634491, 60.9923440615334,
    0.12577160423906, 22.4525263309482, 53.1556712786355, 80.8637145360311,
    91.4583339691162, 38.05237253507, 13.8398921489716, 8.84162787596396,
    2.06963733335332, 17.9012338386641, 37.8492470979698, 5.67679386999876,
    86.2560768127439, 82.8998835881549), slope = c(2.16054304838179,
    1.97955802321434, 0.459141999483106, 4.04891811609252, 1.37888130545614,
    2.56013614609837, 2.42401843070985, 5.54462247371674, 2.90989639997484,
    0.852418287396432, 1.59599373936653, 2.52898863553996, 6.2659276509285,
    1.76413418650627, 0.292095533311367, 1.12207695484161, 1.18567454397679,
    0.569621352553369, 1.84543817639354, 3.72511589765552), elevation = c(359.97704925537,
    398.722771911621, 426.586759948729, 113.829170227046, 221.654676361083,
    118.164677486419, 193.076432800294, 1278.94374267578, 366.954654235841,
    310.976059875488, 115.406995010376, 824.692267456063, 452.296558532715,
    24.027112817764, 120.484393692017, 328.746539916992, 1225.43281127929,
    187.192862854004, 1294.85421752929, 1099.986953125), irrigated_pct = c(1.25787228625548,
    0, 0, 0.00231511534741727, 0.000534317362681038, 10.66306531353,
    0, 5.93849432875218, 0.00258188543729491, 0, 0.132554369636341,
    0, 0.0692080791329039, 34.1391748461725, 0.201742927556814,
    0.0077405653046329, 0.00480382097959415, 0.00376549045024443,
    0.0321618323922187, 0.0455722879171372), T_SILT = c(6.45358939186102,
    23.0251486053956, 5.855814662399, 33.130658436214, 5.71982167352536,
    22.0589468068892, 13.9091220850479, 36.8535665294925, 33.0751028806584,
    22.6529492455417, 44.4293552812071, 10.7534293552813, 39,
    36.1208276177412, 40.1746684956561, 14.0314357567444, 29.253886602652,
    35.6460143270843, 14.7945816186557, 28.5586419753087), T_SAND = c(77.6197226032616,
    65.1125971650663, 83.4992379210485, 31.9764136564551, 71.943072702332,
    27.0651958542905, 55.8758573388203, 39.8782578875171, 43.9490169181527,
    61.8762002743486, 23.377914951989, 19.1499771376313, 41,
    36.0430574607529, 30.3878600823045, 60.5447721383936, 47.3446502057613,
    26.8314281359549, 71.0716735253772, 41.3868312757202), T_CLAY = c(15.9266880048773,
    10.3419067215364, 6.73929279073309, 34.8929279073309, 22.3371056241426,
    50.8758573388203, 30.2150205761318, 22.10219478738, 22.9758802011889,
    15.4708504801097, 32.1927297668039, 70.0965935070874, 20,
    19.7199740893156, 29.4374714220393, 25.423792104862, 23.4014631915867,
    37.5225575369608, 14.133744855967, 30.0545267489712), fraction_tree = c(8.75704383786053,
    25.7484571668838, 0, 41.6792055434643, 71.8004441261293,
    6.29989361806966, 98.8653546733872, 0.887232535285893, 44.7146998239333,
    0.101562502017867, 1.24826386972564, 17.082176297903, 79.1825804582072,
    2.98678628670882, 0.126157407784675, 49.9505215655604, 9.95341438055052,
    0.0438850315468813, 2.35069438815109, 0.080439814832063),
    fraction_shrubs = c(0.150779876418225, 4.97656261755356,
    0.00617283955216385, 1.91213345651827, 7.23668978611617,
    8.4761317869028, 0.00125385803403324, 0.0193571608397205,
    0.0925925932824588, 0.0286458334497496, 0.0703125018626452,
    9.25520811975072, 0, 0, 0.0070408952111995, 2.20023143870961,
    20.3544558684036, 1.56240353929916, 9.99826425313981, 16.9250582066212
    ), Artificial_Surfaces = c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0), CropLand = c(1, 0, 0, 0, 0,
    0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0), Grassland = c(0,
    0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1),
    Tree_Covered_Area = c(0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1,
    1, 0, 0, 1, 1, 0, 0, 0), Shrubs_Covered_Area = c(0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0), Herbaceous_aquatic_regularly_flooded = c(0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
    Mangroves = c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0), Sparse_vegetation = c(0, 0, 1, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), BareSoil = c(0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)), class = c("data.table",
"data.frame"), row.names = c(NA, -20L))
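As discussed by e-mail, summing the average absolute SHAP values on the probability scale might not be the best idea. Because the class probabilities always add up to 1, for each observation the SHAP values of a feature across all classes add up to 0, so the per-class contributions are not independent of each other.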
Still, you can do it like this:
library(ranger)
library(kernelshap)
library(tidyverse)

set.seed(1)
# Probability forest: kernelshap() then returns one SHAP matrix per class in
# `s$S` (named after the classes).
fit <- ranger(Species ~ ., data = iris, num.trees = 100, probability = TRUE)
s <- kernelshap(fit, X = iris[, -5], bg_X = iris)

# Mean absolute SHAP value per feature (rows) and class (columns).
# Alternative, with {shapviz} loaded:
# imp <- as.data.frame(sv_importance(shapviz(s), kind = "no"))
imp <- as.data.frame(lapply(s$S, \(x) colMeans(abs(x))))
#                  setosa  versicolor   virginica
# Sepal.Length 0.01696860 0.021523220 0.025105943
# Sepal.Width  0.00667117 0.005647546 0.003627264
# Petal.Length 0.20037883 0.196985104 0.198589750
# Petal.Width  0.22079294 0.203404371 0.191520664

# Long format: one row per (Variable, Class) pair with its mean |SHAP|.
imp_reshaped <- imp |>
  rownames_to_column(var = "Variable") |>
  pivot_longer(-Variable, names_to = "Class")

# Stacked bars: total bar length is the sum over classes; features are
# ordered by that total.
ggplot(
  imp_reshaped,
  aes(x = value, y = reorder(Variable, value, FUN = sum), fill = Class)
) +
  geom_col() +
  scale_fill_viridis_d(begin = 0.2, end = 0.8, option = "B") +
  labs(y = NULL, x = "Average absolute SHAP values")
Since version 0.9.2, {shapviz} can produce these plots out of the box. The default is a dodged bar plot; use `bar_type = "stack"` for stacked bars:
library(ranger)
library(kernelshap)
library(shapviz)  # provides shapviz() and sv_importance()

set.seed(1)
fit <- ranger(Species ~ ., data = iris, num.trees = 100, probability = TRUE)

# kernelshap() yields one SHAP matrix per class; shapviz() wraps them into a
# multi-class ("mshapviz") object.
sv <- kernelshap(fit, X = iris[, -5], bg_X = iris) |>
  shapviz()

# One combined importance plot (shapviz >= 0.9.2). The default
# bar_type = "dodge" places class bars side by side; "stack" stacks them.
sv_importance(sv, bar_type = "stack")