Interpretation of iml's Accumulated Local Effect (ALE) values in a classification task


I'm working with the ALE implementation provided by the iml package in R. This package is accompanied by the usual documentation, a vignette and even a very nice book.

I've studied all three of them, trying to figure out the exact interpretation of the resulting ALE values in a classification task. I do have a high-level understanding, which is: increasing ALE values when moving from one feature value to a neighbouring one mean that the probability of the model predicting a specific class has increased.

What I cannot figure out is what exactly the ALE value is in different scenarios, and whether it uses probabilities extracted from the model or whether iml extracts them via the type argument. I puzzled together three different classification models based on the iml documentation and the caret documentation:

  1. A caret based RF which does not explicitly extract class probabilities
  2. A caret based RF which does explicitly extract class probabilities
  3. An rpart based DT

I then plotted them using iml in two ways:

  1. Using type="prob" (which is what I probably should be doing for a classification model as per iml's documentation: "The classic use case is to say type="prob" for classification models")
  2. Not using type="prob" but "raw" (where possible) or "class"
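One way to see what each type value actually returns, before iml is involved, is to call predict() on the model directly. Below is a minimal sketch for the rpart tree only, so it needs nothing beyond the recommended rpart package; for the caret models the analogous calls would be predict(rf_probs, newdata = ..., type = "prob") and type = "raw".

```r
library(rpart)

data(iris)
dt <- rpart(Species ~ ., data = iris)

newdata <- iris[c(1, 51, 101), ]  # one flower of each species

# type = "prob": a numeric matrix, one column per class, each row sums to 1
p <- predict(dt, newdata = newdata, type = "prob")
print(p)

# type = "class": a factor holding a single predicted label per row
cl <- predict(dt, newdata = newdata, type = "class")
print(cl)
```

The difference in shape (three probability columns versus one label per row) is what any downstream tool has to turn into something it can average.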

This gives me six plots (see code below), only two of which are qualitatively and quantitatively identical (#3 and #6), another one which is at least qualitatively identical to them (#4), and another three which are only qualitatively somewhat similar (#1, #2 and #5):

[Figure: ALE for Petal.Width, plots #1 to #6]

I've learned one is not supposed to ask multiple questions in one post, but these are so closely related that I feel it would be confusing creating a separate post for each of them.

  • Are #1 and #2 only slightly different because of the randomness in the two RF models, i.e. does it really play no role for iml's ALE whether or not I extract probabilities from the caret-based RF models?
  • If yes, why are #4 and #5 so different, given that they too differ only in the randomness of the same two RF models as #1 and #2?
  • Why does #3 look similar to #4 and #6 even though one is of type "prob" (which is probably correct), one of type "raw" and one of type "class" (both of which are probably not correct)?

What is really put on the y-axis? How does it depend on whether or not I'm extracting probabilities via caret? Why does the "wrong" type not always make a difference?

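library(rpart)
library(caret)
library(iml)
library(ggplot2)  # ggtitle() comes from ggplot2 (caret also attaches it)

set.seed(1)  # make the random forest fits reproducible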

data(iris)
TrainData <- iris[,1:4]
TrainClasses <- iris[,5]

## Train three different models
# RF w/o extracting probabilities
cntrl_noprobs<-trainControl(method = "oob", number=5, sampling="up", search="grid", verboseIter=TRUE, savePredictions=TRUE, classProbs=FALSE)
rf_noprobs <- caret::train(TrainData, TrainClasses, method="rf", ntree=100, metric="Kappa", trControl=cntrl_noprobs)
# RF w/ extracting probabilities
cntrl_probs<-trainControl(method = "oob", number=5, sampling="up", search="grid", verboseIter=TRUE, savePredictions=TRUE, classProbs=TRUE)
rf_probs <- caret::train(TrainData, TrainClasses, method="rf", ntree=100, metric="Kappa", trControl=cntrl_probs)
# DT
dt <- rpart(Species ~ ., data = iris)

## Create ALE plots with type="prob"
mod_rf_noprobs_prob <- Predictor$new(rf_noprobs, data = iris, type = "prob")
plot(FeatureEffect$new(mod_rf_noprobs_prob, feature = "Petal.Width")) + ggtitle("caret classProbs=FALSE | iml type=prob")
mod_rf_probs_prob <- Predictor$new(rf_probs, data = iris, type = "prob")
plot(FeatureEffect$new(mod_rf_probs_prob, feature = "Petal.Width")) + ggtitle("caret classProbs=TRUE | iml type=prob")
mod_dt_prob <- Predictor$new(dt, data = iris, type = "prob")
plot(FeatureEffect$new(mod_dt_prob, feature = "Petal.Width")) + ggtitle("rpart | iml type=prob")

## Create ALE plots with type="raw" or "class"
mod_rf_noprobs_raw <- Predictor$new(rf_noprobs, data = iris, type = "raw")
plot(FeatureEffect$new(mod_rf_noprobs_raw, feature = "Petal.Width")) + ggtitle("caret classProbs=FALSE | iml type=raw")
mod_rf_probs_raw <- Predictor$new(rf_probs, data = iris, type = "raw")
plot(FeatureEffect$new(mod_rf_probs_raw, feature = "Petal.Width")) + ggtitle("caret classProbs=TRUE | iml type=raw")
mod_dt_class <- Predictor$new(dt, data = iris, type = "class")
plot(FeatureEffect$new(mod_dt_class, feature = "Petal.Width")) + ggtitle("rpart | iml type=class")
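To make the y-axis question concrete, here is a minimal from-scratch sketch of first-order ALE applied to the rpart tree's class probabilities. It follows the general recipe (split the feature at quantiles, average the change in prediction across each interval, accumulate, then center). It is not iml's implementation: the quantile grid, the number of intervals and the centering convention are assumptions, so the numbers need not match iml's plots exactly, and the helper name ale_prob is hypothetical.

```r
library(rpart)

data(iris)
dt <- rpart(Species ~ ., data = iris)

# Hypothetical helper: first-order ALE for a model returning class probabilities.
# pred_fun(newdata) must return an n x C numeric matrix (one column per class).
ale_prob <- function(pred_fun, data, feature, n_intervals = 20) {
  x <- data[[feature]]
  # Interval boundaries at observed quantiles; ties collapse duplicate breaks
  z <- unique(quantile(x, probs = seq(0, 1, length.out = n_intervals + 1),
                       type = 1, names = FALSE))
  # Interval index k of each observation: interval k is (z[k], z[k + 1]]
  k <- as.integer(cut(x, breaks = z, include.lowest = TRUE))
  n_k <- as.vector(table(k))
  lower <- data; lower[[feature]] <- z[k]
  upper <- data; upper[[feature]] <- z[k + 1]
  # Local effect: change in predicted probability across each observation's interval
  diffs <- pred_fun(upper) - pred_fun(lower)
  # Average the local effects within each interval, then accumulate
  local <- rowsum(diffs, group = k, reorder = TRUE) / n_k
  ale <- rbind(0, apply(local, 2, cumsum))
  rownames(ale) <- NULL
  # Center so that the count-weighted average effect is 0
  mid <- (ale[-1, , drop = FALSE] + ale[-nrow(ale), , drop = FALSE]) / 2
  ale <- sweep(ale, 2, colSums(mid * n_k) / sum(n_k))
  data.frame(grid = z, ale, check.names = FALSE)
}

res <- ale_prob(function(d) predict(dt, newdata = d, type = "prob"),
                iris, "Petal.Width")
head(res)
```

Because each row of class probabilities sums to 1, the per-class local differences sum to 0, so in this sketch the three ALE columns sum to 0 at every grid point. Read loosely, a value of 0.2 for a class at some Petal.Width means that, at that value, the feature pushes the predicted probability of that class about 0.2 above its centered average.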

Solution

  • Before any comments, please note that analyses of a model trained on the entire dataset are not valid, because such a model will overfit massively unless it is bootstrapped. For small datasets like iris, you should carry out full model bootstrapping (a vignette for my ale package explains this in detail: https://cran.r-project.org/web/packages/ale/vignettes/ale-statistics.html). But I will try to interpret the results you show as if they were valid (even though they are not).
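A much simpler illustration of why a model scored on its own training data is optimistic: compare the tree's apparent (in-sample) accuracy with its average out-of-bag accuracy over bootstrap refits. This is only a sketch under that assumption; it is not the full-model bootstrap of ALE described in the ale vignette, and B = 50 resamples is an arbitrary choice.

```r
library(rpart)

set.seed(1)
data(iris)

# Apparent accuracy: fit on all 150 rows and score on the same rows
fit_all <- rpart(Species ~ ., data = iris)
apparent <- mean(predict(fit_all, iris, type = "class") == iris$Species)

# Out-of-bag accuracy: refit on bootstrap resamples, score on the rows left out
B <- 50
oob_acc <- numeric(B)
for (b in seq_len(B)) {
  idx <- sample(nrow(iris), replace = TRUE)
  oob <- setdiff(seq_len(nrow(iris)), idx)
  fit <- rpart(Species ~ ., data = iris[idx, ])
  oob_acc[b] <- mean(predict(fit, iris[oob, ], type = "class") == iris$Species[oob])
}

c(apparent = apparent, bootstrap_oob = mean(oob_acc))
```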

    I am not an expert in either the caret or the iml packages, so I cannot answer definitively, but I can try to answer based on my general knowledge of ALE and predictive modeling. I've spent some time examining and comparing these graphs and what I offer here is my educated guess.

    First, you can readily see that there are two different shapes among the six diagrams: 1, 2, and 5 have the same shape; and 3, 4, and 6 have the same shape.

    Examining group 3-4-6, I can clearly see what is going on. In this group, the ALE for setosa is consistently 0. When ALE values are all exactly 0, this always means that the value was not used at all by the model. (Not by ALE, but by the model: ALE only describes a model, not the data directly.) In the context of this analysis, it means that these multinomial classification models set setosa to 0 as the base reference group and that the versicolor and virginica values for these models are given relative to setosa. This is equivalent to using dummy encoding for the classes: with n classes, there will be n - 1 dummy values and one of the classes is considered the reference class.

    In contrast, group 1-2-5 has ALE values for all three classes, which indicates that it gives values for each of them. This is equivalent to creating one-hot encoding for the classes: with n classes, there will be n binary variables, one for each class. So, in these models, each iris class has its own distinct ALE.
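The two encodings used in this analogy can be shown directly with base R's model.matrix() on the iris class labels. This only illustrates the encodings themselves; it does not show what caret or iml do internally.

```r
data(iris)
y <- iris$Species

# One-hot encoding: one 0/1 column per class (n classes -> n columns)
one_hot <- model.matrix(~ y - 1)

# Dummy (treatment) encoding: the first level, setosa, is the reference class,
# so only the other n - 1 classes get a column (the intercept is dropped here)
dummy <- model.matrix(~ y)[, -1]

colnames(one_hot)  # "ysetosa" "yversicolor" "yvirginica"
colnames(dummy)    # "yversicolor" "yvirginica"
```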

    This means that the two different shapes most likely represent the same relationship across all six models. The only difference is the distortion caused because the 1-2-5 group models each class individually whereas the 3-4-6 group models versicolor and virginica relative to setosa.

    So, why do 3, 4, and 6 use dummy encoding but 1, 2, and 5 use one-hot encoding? I am not sure, but here is a guess. As far as I can tell, it seems like when probabilities are requested for the randomForest objects (caret method="rf"), then probabilities will be calculated for each class. This is the case for group 1-2-5:

    • Plot 1: probabilities come from "iml type=prob".
    • Plot 2: probabilities come from both "caret classProbs=TRUE" and "iml type=prob".
    • Plot 5: probabilities come from "caret classProbs=TRUE".

    I am guessing, then, that the decision tree from rpart always uses dummy encoding, which is why 3 and 6 look identical. For the random forests, only Plot 4 has no probabilities available at all ("caret classProbs=FALSE" and "iml type=raw"), so it defaults to dummy encoding.

    There is one more loose end: the different scales of the plots. It only makes sense to look at scales within groups.

    For the 1-2-5 group, you should understand that when you ask for probabilities (whether with the caret specification trainControl(classProbs=TRUE) or the iml specification Predictor$new(type = "prob")), you will get a probability between 0 and 1. Similarly, when you ask for classes (with the iml specification Predictor$new(type = "raw")), you will get a 1 response for TRUE and 0 for FALSE. ALE then averages these TRUE and FALSE values, that is, the 1 and 0 values. These averages will be very similar to the probabilities, though they are not actually the same thing. So, Plots 1 and 2 are essentially the same thing (allowing for random variation): they represent probabilities ("iml type=prob"), whereas Plot 5 is the average of 1 and 0 raw class predictions ("iml type=raw").
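The distinction between averaging probabilities and averaging 0/1 class indicators can be seen without ALE by comparing plain per-class averages. A minimal sketch, assuming the rpart tree from the question as a stand-in (the same idea applies to the random forests):

```r
library(rpart)

data(iris)
dt <- rpart(Species ~ ., data = iris)

prob <- predict(dt, newdata = iris, type = "prob")
cls  <- predict(dt, newdata = iris, type = "class")

# 0/1 indicator per class for the predicted label (one-hot encoding)
ind <- model.matrix(~ cls - 1)
colnames(ind) <- levels(cls)

# Both are per-class averages that sum to 1, but they are generally not the
# same numbers: averaging hard labels discards how confident each prediction was
avg_prob <- colMeans(prob)
avg_ind  <- colMeans(ind)
rbind(avg_prob, avg_ind)
```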

    For the 3-4-6 group, you have completely different models. The plots for the rpart decision tree (3 and 6) are essentially the same thing; note that both come from the same fitted dt object, so there is no random variation between them. However, Plot 4 is a random forest; even though its shape is the same, its predictions are different, so it is not surprising that the scale is different.

    I know that there is a lot of uncertainty in my answer, but I hope it makes sense.