
Adding information to a tree plot - rpart


I want to add some information to my tree plot. Say, for instance, that I have a data frame like this:

library(rpart)
library(rpart.plot)
set.seed(1)
mydb <- data.frame(results = rnorm(1000, 0, 1), expo = runif(1000),
                   var1 = sample(LETTERS[1:4], 1000, replace = TRUE),
                   var2 = sample(LETTERS[5:6], 1000, replace = TRUE),
                   var3 = sample(LETTERS[20:25], 1000, replace = TRUE))

I can fit a tree, prune it and plot it:

mytree <- rpart(results ~ var1 + var2 + var3, data = mydb, cp = 0)
pfit <- prune(mytree, cp = mytree$cptable[4, "CP"])
prp(pfit, type = 1, extra = 100, fallen.leaves = FALSE,
    shadow.col = "darkgray", box.col = rgb(0.8, 0.9, 0.8))

The result looks like this:

[Image: the pruned tree drawn by prp]

That is fine, but suppose I also want to know the average exposure (expo) in each leaf.

I know I can add information to a prp plot by passing a function as its node.fun argument, for instance the weight of each node:

node.fun1 <- function(x, labs, digits, varlen)
{
  ## node.fun must return one label per row of x$frame (one per node)
  paste("Weight \n", x$frame$wt)
}

prp(pfit, type = 1, extra = 100, fallen.leaves = FALSE,
    shadow.col = "darkgray", box.col = rgb(0.8, 0.9, 0.8),
    node.fun = node.fun1)

[Image: the same tree with each node labelled with its weight]

But this only works for quantities that are already stored in x$frame, the per-node table returned by rpart.
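To see exactly what node.fun can read without extra work, it helps to list the columns of x$frame. The following is a minimal sketch, assuming the same data as above (stringsAsFactors = TRUE and the name node.fun2 are my own additions, not part of the question). It keeps prp's default label, which prp passes in as labs, and appends two columns that rpart already stores for every node.

```r
library(rpart)
library(rpart.plot)

## same data as above; stringsAsFactors = TRUE is an added assumption
set.seed(1)
mydb <- data.frame(results = rnorm(1000, 0, 1), expo = runif(1000),
                   var1 = sample(LETTERS[1:4], 1000, replace = TRUE),
                   var2 = sample(LETTERS[5:6], 1000, replace = TRUE),
                   var3 = sample(LETTERS[20:25], 1000, replace = TRUE),
                   stringsAsFactors = TRUE)
mytree <- rpart(results ~ var1 + var2 + var3, data = mydb, cp = 0)
pfit <- prune(mytree, cp = mytree$cptable[4, "CP"])

## per-node columns node.fun can use directly (one row per node)
print(names(pfit$frame))

## hypothetical node.fun: default label (labs) plus n and deviance from x$frame
node.fun2 <- function(x, labs, digits, varlen) {
  paste0(labs, "\nn = ", x$frame$n, "\ndev = ", round(x$frame$dev, 1))
}

prp(pfit, type = 1, extra = 100, node.fun = node.fun2)
```

Anything not in that list, such as the mean of expo, has to be computed from the training data itself.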

My question:

How can I add custom information to the plot, such as the average exposure, or the result of any other function that computes a custom indicator for each node and adds it to the frame table?


Solution

  • This is really nice; I didn't know this was an option.

    Most of the work is in getting the subset of the original data that reaches each node. This is easy for terminal nodes, but I didn't find a straightforward way to identify the rows of data that pass through every node, not just the leaves. If someone knows an easier way, I would love to hear it.
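    For the terminal nodes alone, x$where is enough: it gives, for each training row, the row of x$frame of the leaf that row ends up in, so a grouped mean does the job. A minimal sketch, assuming the question's data (stringsAsFactors = TRUE and the name leaf_mean are my additions):

    ```r
    library(rpart)

    set.seed(1)
    mydb <- data.frame(results = rnorm(1000, 0, 1), expo = runif(1000),
                       var1 = sample(LETTERS[1:4], 1000, replace = TRUE),
                       var2 = sample(LETTERS[5:6], 1000, replace = TRUE),
                       var3 = sample(LETTERS[20:25], 1000, replace = TRUE),
                       stringsAsFactors = TRUE)
    mytree <- rpart(results ~ var1 + var2 + var3, data = mydb, cp = 0)
    pfit <- prune(mytree, cp = mytree$cptable[4, "CP"])

    ## pfit$where: for each training row, the row index in pfit$frame of its leaf
    leaf_mean <- tapply(mydb$expo, pfit$where, mean)

    ## relabel from frame row indices to rpart node numbers
    names(leaf_mean) <- rownames(pfit$frame)[as.integer(names(leaf_mean))]
    print(round(leaf_mean, 3))
    ```

    Internal nodes are the harder part, because each observation only records its leaf.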

    library('rpart.plot')
    set.seed(1)
    mydb<-data.frame(results=rnorm(1000,0,1),expo=runif(1000),var1=sample(LETTERS[1:4],1000,replace=T),
                     var2=sample(LETTERS[5:6],1000,replace=T),var3=sample(LETTERS[20:25], 1000,replace=T))
    mytree<-rpart(results~var1+var2+var3,data=mydb,cp=0)
    pfit<- prune(mytree, cp=mytree$cptable[4,"CP"])
    
    rpart.plot(pfit)
    

    [Image: the pruned tree drawn by rpart.plot with its default labels]

    Define a new function to pass as node.fun. Its first argument, x, is the fitted rpart object (I didn't look into the other arguments, labs, digits and varlen, but the rpart.plot vignette should be helpful).

    For every row of x$frame (that is, every node) we need the data used at that node in order to calculate summary statistics. Unfortunately, x$where only tells us, for each observation, the terminal node it ends up in. So for each node number we use subset.rpart, defined below, to get the underlying data, and then compute whatever we want from it.

    f <- function(x, labs, digits, varlen) {
      ## node numbers, one per row of x$frame
      nodes <- as.integer(rownames(x$frame))
      ## 3 x (number of nodes) matrix: mean expo, n, and % of all observations
      z <- sapply(nodes, function(y) {
        data <- subset.rpart(x, y)
        c(mean = mean(data$expo), nrow(data), nrow(data) / length(x$where) * 100)
      })
      sprintf('Mean expo: %.2f\nn=%.0f (%.0f%%)', z[1, ], z[2, ], z[3, ])
    }
    
    prp(pfit, type=1, extra=100, fallen.leaves=FALSE,
        shadow.col="darkgray", box.col=rgb(0.8,0.9,0.8),
        node.fun = f)
    

    [Image: the pruned tree with each node showing its mean expo, n and percentage of observations]

    The real work is done by subset.rpart, which takes a node number and returns the rows of the original data that reach that node.

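    subset.rpart <- function(tree, node = 1L) {
      ## returns the subset of tree$call$data used at any node
      ## (assumes no rows were dropped by na.action, so tree$where lines up with data)
      data <- eval(tree$call$data, parent.frame(1L))
      ## root-to-node path for every node in the tree
      wh <- sapply(as.integer(rownames(tree$frame)), parent)
      ## keep paths through `node`: its ancestors, itself and its descendants
      wh <- unique(unlist(wh[sapply(wh, function(x) node %in% x)]))
      ## ancestors have smaller numbers, so wh >= node keeps node + descendants;
      ## compare node numbers as integers and keep rows whose leaf is among them
      data[as.integer(rownames(tree$frame))[tree$where] %in% wh[wh >= node], ]
    }
    
    parent <- function(x) {
      ## returns the path of node numbers from the root (1) down to x;
      ## in rpart's numbering the parent of node x is x %/% 2
      if (x[1] != 1)
        c(Recall(x %/% 2L), x) else x
    }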
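    Since rpart numbers the children of node n as 2n and 2n + 1 (the same fact parent relies on), membership can also be tested arithmetically: a row reaches node n exactly when dropping trailing binary digits from its leaf's number gives n. Below is a vectorised sketch of that alternative. It assumes, as subset.rpart does, that no rows were dropped by na.action, and it hard-codes mydb$expo for brevity; rows_in_node and node_fun_expo are hypothetical names, not rpart functions.

    ```r
    library(rpart)
    library(rpart.plot)

    set.seed(1)
    mydb <- data.frame(results = rnorm(1000, 0, 1), expo = runif(1000),
                       var1 = sample(LETTERS[1:4], 1000, replace = TRUE),
                       var2 = sample(LETTERS[5:6], 1000, replace = TRUE),
                       var3 = sample(LETTERS[20:25], 1000, replace = TRUE),
                       stringsAsFactors = TRUE)
    mytree <- rpart(results ~ var1 + var2 + var3, data = mydb, cp = 0)
    pfit <- prune(mytree, cp = mytree$cptable[4, "CP"])

    ## hypothetical helper: TRUE for the training rows that pass through `node`
    rows_in_node <- function(tree, node) {
      ## node number of the leaf each training row ends up in
      leaf <- as.integer(rownames(tree$frame))[tree$where]
      ## levels between `node` and each leaf (negative if the leaf is shallower)
      gap <- floor(log2(leaf)) - floor(log2(node))
      ## `node` is the leaf or one of its ancestors iff the leaf number,
      ## shifted right by `gap` binary digits, equals `node`
      gap >= 0 & leaf %/% 2^gap == node
    }

    ## hypothetical node.fun built on the helper (reads expo from mydb directly)
    node_fun_expo <- function(x, labs, digits, varlen) {
      nodes <- as.integer(rownames(x$frame))
      m <- sapply(nodes, function(n) mean(mydb$expo[rows_in_node(x, n)]))
      sprintf("Mean expo: %.2f\nn=%d (%.0f%%)", m, x$frame$n,
              100 * x$frame$n / length(x$where))
    }

    prp(pfit, type = 1, extra = 100, fallen.leaves = FALSE,
        shadow.col = "darkgray", box.col = rgb(0.8, 0.9, 0.8),
        node.fun = node_fun_expo)
    ```

    This avoids rebuilding every root-to-node path on each call, but it returns a logical index rather than the data itself; mydb[rows_in_node(pfit, 2), ] gives the rows for node 2 as a data frame.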
    

    Tests

    ## tests
    dim(subset.rpart(pfit, 1)) == dim(mydb)
    # [1] TRUE TRUE
    
    ## terminal nodes
    nodes <- as.integer(rownames(pfit$frame[pfit$frame$var %in% '<leaf>', ]))
    sum(sapply(nodes, function(x) nrow(subset.rpart(pfit, x)))) == nrow(mydb)
    # [1] TRUE