Tags: r, tree, random-forest, party, ensemble-learning

Variable importance for a single tree in randomForest, randomForestSRC or cforest?


I am trying to find a way in R to calculate variable importance for a single tree of a random forest or a conditional random forest.
A good starting point is the rpart:::importance function, which calculates a measure of variable importance for rpart trees:

> library(rpart) 
> rp <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
> rpart:::importance(rp)
   Start      Age   Number 
8.198442 3.101801 1.521863
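The same scores can also be read from the fitted object without calling an unexported function. A minimal sketch, assuming (as in current rpart versions) that `rpart()` stores them in the `variable.importance` component; the percentage rescaling at the end is only for easier reading.

```r
library(rpart)

rp <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# Raw importance scores saved by rpart() in the fitted object
vi <- rp$variable.importance
print(vi)

# The same scores rescaled to percentages of their total
print(round(100 * vi / sum(vi)))
```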

The randomForest::getTree command is the standard tool to extract the structure of a tree from a randomForest object, but it returns a data.frame:

library(randomForest)
rf <- randomForest(Kyphosis ~ Age + Number + Start, data = kyphosis)
tree1 <- getTree(rf, k=1, labelVar=TRUE)
str(tree1)

'data.frame':   29 obs. of  6 variables:
$ left daughter : num  2 4 6 8 10 12 0 0 14 16 ...
$ right daughter: num  3 5 7 9 11 13 0 0 15 17 ...
$ split var     : Factor w/ 3 levels "Age","Number",..: 2 3 1 2 3 3 NA NA 3 1 ...
$ split point   : num  5.5 8.5 78 3.5 14.5 7.5 0 0 3.5 75 ...
$ status        : num  1 1 1 1 1 1 -1 -1 1 1 ...
$ prediction    : chr  NA NA NA NA ...
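Because getTree() exposes the daughters and the split variable of every node, simple split-based summaries of a single tree can be read off the data.frame directly. The sketch below computes, for tree 1, how many times each variable is used to split and its minimal depth (depth of its first split, root = 0), similar in spirit to the minimal-depth idea used by randomForestSRC. The helper `node_depth` is hypothetical, and it assumes terminal nodes have both daughters coded 0 and `split var` equal to NA, as in the str() output above. These are crude proxies, not the split-improvement score that rpart reports.

```r
library(randomForest)
library(rpart)  # provides the kyphosis data

set.seed(1)
rf <- randomForest(Kyphosis ~ Age + Number + Start, data = kyphosis)
tree1 <- getTree(rf, k = 1, labelVar = TRUE)

# Hypothetical helper: depth of every node (root = 0), breadth-first from the root
node_depth <- function(tree) {
  depth <- rep(NA_integer_, nrow(tree))
  depth[1] <- 0L
  queue <- 1L
  while (length(queue) > 0) {
    i <- queue[1]
    queue <- queue[-1]
    kids <- c(tree[i, "left daughter"], tree[i, "right daughter"])
    kids <- kids[kids > 0]            # terminal nodes have daughters coded 0
    depth[kids] <- depth[i] + 1L
    queue <- c(queue, kids)
  }
  depth
}

# Split variable as a factor over all predictors (NA for terminal nodes)
vars <- rownames(rf$importance)
split_var <- factor(as.character(tree1[["split var"]]), levels = vars)
internal <- !is.na(split_var)

# Number of splits on each variable in this tree (0 if unused)
split_count <- table(split_var)

# Minimal depth: depth of the first split on each variable (NA if unused)
minimal_depth <- tapply(node_depth(tree1)[internal], split_var[internal], min)

print(split_count)
print(minimal_depth)
```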

One solution would be an as.rpart() function that coerces tree1 to an rpart object. Unfortunately, I am not aware of such a function in any R package.
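Without coercing anything, the per-tree permutation idea that the solution applies to cforest can be sketched for randomForest too: grow the forest with `keep.inbag = TRUE`, take the out-of-bag (OOB) rows of tree k from `rf$inbag`, and get tree k's own predictions from `predict(..., predict.all = TRUE)$individual`. This is a minimal sketch assuming a classification forest. The helper `tree_perm_importance` is hypothetical (not part of randomForest), and its score is the increase in one tree's OOB misclassification rate after shuffling a variable, a permutation measure rather than rpart's split-improvement measure.

```r
library(randomForest)
library(rpart)  # provides the kyphosis data

set.seed(1)
# keep.inbag = TRUE records how often each row was drawn for each tree
rf <- randomForest(Kyphosis ~ Age + Number + Start, data = kyphosis,
                   ntree = 50, keep.inbag = TRUE)

# Hypothetical helper: OOB permutation importance of each predictor in tree k
tree_perm_importance <- function(rf, data, response, k) {
  oob <- rf$inbag[, k] == 0                   # rows not used to grow tree k
  truth <- as.character(data[[response]][oob])
  tree_error <- function(newdata) {
    # column k of $individual holds tree k's predicted class labels
    pred <- predict(rf, newdata[oob, , drop = FALSE],
                    predict.all = TRUE)$individual[, k]
    mean(pred != truth)
  }
  base_error <- tree_error(data)
  vars <- setdiff(names(data), response)
  sapply(vars, function(v) {
    permuted <- data
    x <- permuted[[v]][oob]
    permuted[[v]][oob] <- x[sample.int(length(x))]  # shuffle v among OOB rows
    tree_error(permuted) - base_error
  })
}

vi_tree1 <- tree_perm_importance(rf, kyphosis, "Kyphosis", k = 1)
print(vi_tree1)

# One row per tree, one column per predictor
vi_all <- t(sapply(seq_len(rf$ntree), function(k)
  tree_perm_importance(rf, kyphosis, "Kyphosis", k)))
```

Predicting with the whole forest only to keep one column is wasteful, but it relies only on documented arguments of predict().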

With the party package I ran into a similar problem: the varimp function works with cforest objects but not with a single tree.

library(party) 
cf <- cforest(Kyphosis ~ Age + Number + Start, data = kyphosis) 
ct <- party:::prettytree(cf@ensemble[[1]], names(cf@data@get("input"))) 
tree2 <- new("BinaryTree") 
tree2@tree <- ct 
tree2@data <- cf@data 
tree2@responses <- cf@responses 
tree2@weights <- cf@initweights
varimp(tree2)

Error in varimp(tree2) : 
   no slot of name "initweights" for this object of class "BinaryTree"

Any help is appreciated.


Solution

  • Starting from @Alex's suggestion, I worked on party:::varimp. This function calculates the standard (mean decrease in accuracy) and the conditional variable importance (VI) for cforest, and it can easily be modified to calculate VI for each single tree of the forest.

    library(party)  # cforest(), cforest_unbiased() and the readingSkills data
    set.seed(12345)
    y <- cforest(score ~ ., data = readingSkills,
           control = cforest_unbiased(mtry = 2, ntree = 10))
    
    varimp_ctrees <- function (object, mincriterion = 0, conditional = FALSE,
    threshold = 0.2, nperm = 1, OOB = TRUE, pre1.0_0 = conditional) {
        response <- object@responses
        if (length(response@variables) == 1 && inherits(response@variables[[1]], 
            "Surv")) 
            return(party:::varimpsurv(object, mincriterion, conditional, 
                threshold, nperm, OOB, pre1.0_0))
        input <- object@data@get("input")
        xnames <- colnames(input)
        inp <- party:::initVariableFrame(input, trafo = NULL)
        y <- object@responses@variables[[1]]
        if (length(response@variables) != 1) 
            stop("cannot compute variable importance measure for multivariate response")
        if (conditional || pre1.0_0) {
            if (!all(complete.cases(inp@variables))) 
                stop("cannot compute variable importance measure with missing values")
        }
        CLASS <- all(response@is_nominal)
        ORDERED <- all(response@is_ordinal)
        if (CLASS) {
            error <- function(x, oob) mean((levels(y)[sapply(x, which.max)] != 
                y)[oob])
        } else {
            if (ORDERED) {
                error <- function(x, oob) mean((sapply(x, which.max) != 
                    y)[oob])
            } else {
                error <- function(x, oob) mean((unlist(x) - y)[oob]^2)
            }
        }
        w <- object@initweights
        if (max(abs(w - 1)) > sqrt(.Machine$double.eps)) 
            warning(sQuote("varimp"), " with non-unity weights might give misleading results")
        perror <- matrix(0, nrow = nperm * length(object@ensemble), 
            ncol = length(xnames))
        colnames(perror) <- xnames
        for (b in 1:length(object@ensemble)) {
            tree <- object@ensemble[[b]]
            if (OOB) {
                oob <- object@weights[[b]] == 0
            } else {
                oob <- rep(TRUE, length(y))
            }
            p <- .Call("R_predict", tree, inp, mincriterion, -1L, 
                PACKAGE = "party")
            eoob <- error(p, oob)
            for (j in unique(party:::varIDs(tree))) {
                for (per in 1:nperm) {
                    if (conditional || pre1.0_0) {
                      tmp <- inp
                      ccl <- party:::create_cond_list(conditional, threshold, 
                        xnames[j], input)
                      if (is.null(ccl)) {
                        perm <- sample(which(oob))
                      }  else {
                        perm <- party:::conditional_perm(ccl, xnames, input, 
                          tree, oob)
                      }
                      tmp@variables[[j]][which(oob)] <- tmp@variables[[j]][perm]
                      p <- .Call("R_predict", tree, tmp, mincriterion, 
                        -1L, PACKAGE = "party")
                    } else {
                      p <- .Call("R_predict", tree, inp, mincriterion, 
                        as.integer(j), PACKAGE = "party")
                    }
                    perror[(per + (b - 1) * nperm), j] <- (error(p, 
                      oob) - eoob)
                }
            }
        }
        perror <- as.data.frame(perror)
        return(list(MeanDecreaseAccuracy = colMeans(perror), VIMcTrees=perror))
    }
    

    VIMcTrees is a data frame with one row per tree of the forest (nperm rows per tree when nperm > 1) and one column per explanatory variable. Its (i, j) element is the VI of the j-th variable in the i-th tree; a variable that is never used for splitting in a tree gets 0.

    varimp_ctrees(y)$VIMcTrees
    
       nativeSpeaker       age  shoeSize
    1       4.853855  30.06969 52.271824
    2      15.740311  70.55825  5.409772
    3      17.022082 113.86020  0.000000
    4      22.003119  19.62134 50.634286
    5       6.070659  28.58817 47.049866
    6      16.508634 105.50321  2.302387
    7      11.487349  31.80002 46.147677
    8      19.250631  27.78282 43.589832
    9      19.669478  98.73722  0.483079
    10     11.748669  85.95768  5.812538
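Once the per-tree values are available, they can be summarised like any data frame. The sketch below copies the ten rows printed above (so it runs without refitting the forest) and checks which variable dominates each tree, how much each variable's VI varies across trees, and the per-variable average that MeanDecreaseAccuracy reports when nperm = 1. Since score is numeric here, each entry is an increase in OOB mean squared error rather than a drop in accuracy. readingSkills is a toy data set built around a spurious link between shoe size and reading skills, which may explain why age and shoeSize seem to take turns as the top variable; that reading is a guess, not something the output proves.

```r
# Per-tree VI values copied from the varimp_ctrees(y)$VIMcTrees output above
vi <- data.frame(
  nativeSpeaker = c(4.853855, 15.740311, 17.022082, 22.003119, 6.070659,
                    16.508634, 11.487349, 19.250631, 19.669478, 11.748669),
  age           = c(30.06969, 70.55825, 113.86020, 19.62134, 28.58817,
                    105.50321, 31.80002, 27.78282, 98.73722, 85.95768),
  shoeSize      = c(52.271824, 5.409772, 0.000000, 50.634286, 47.049866,
                    2.302387, 46.147677, 43.589832, 0.483079, 5.812538)
)

# Most important variable in each tree
top <- colnames(vi)[apply(vi, 1, which.max)]
print(table(top))

# Spread of each variable's importance across trees
print(sapply(vi, sd))

# Average over trees, i.e. the forest-level score
print(colMeans(vi))
```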