I want to add some information to my tree plot. Say, for instance, I have a data set like this:
library(rpart)
library(rpart.plot)
set.seed(1)
mydb <- data.frame(results = rnorm(1000, 0, 1), expo = runif(1000),
                   var1 = sample(LETTERS[1:4], 1000, replace = TRUE),
                   var2 = sample(LETTERS[5:6], 1000, replace = TRUE),
                   var3 = sample(LETTERS[20:25], 1000, replace = TRUE))
I can fit a tree:
mytree<-rpart(results~var1+var2+var3,data=mydb,cp=0)
pfit<- prune(mytree, cp=mytree$cptable[4,"CP"])
prp(pfit,type=1,extra=100,fallen.leaves=F,shadow.col="darkgray",box.col=rgb(0.8,0.9,0.8))
That works fine, but suppose I want to know the average exposure (expo) for each leaf.
I know I can add some information to prp through a node function, for instance the weight of each node:
node.fun1 <- function(x, labs, digits, varlen)
{
  paste("Weight \n", x$frame$wt)
}
prp(pfit,type=1,extra=100,fallen.leaves=F,shadow.col="darkgray",box.col=rgb(0.8,0.9,0.8),node.fun = node.fun1)
But this only works for values that are already stored in frame, the table returned by the rpart function.
How can I add custom information to the plot, such as the average exposure, or the result of any other function that calculates a custom indicator and adds it to frame?
This is really nice, I didn't know this was an option.
All the work seems to be in getting the subset of the original data used at each node. This is easy for terminal nodes, but I didn't find a straightforward way of identifying the rows of data that were used at every node, not just the leaves. If someone knows an easier way, I would love to hear it.
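To illustrate why the terminal nodes are the easy case, here is a minimal sketch that computes the mean exposure per leaf directly from the where component. It assumes the same kind of simulated data as in the question; the tree is limited with maxdepth = 3 instead of pruning, and the names fit and leaf_means are only for illustration.

```r
library(rpart)

set.seed(1)
mydb <- data.frame(results = rnorm(1000), expo = runif(1000),
                   var1 = sample(LETTERS[1:4], 1000, replace = TRUE),
                   var2 = sample(LETTERS[5:6], 1000, replace = TRUE),
                   var3 = sample(LETTERS[20:25], 1000, replace = TRUE),
                   stringsAsFactors = TRUE)

# Assumption: a shallow tree (maxdepth = 3) stands in for the pruned tree
fit <- rpart(results ~ var1 + var2 + var3, data = mydb, cp = 0, maxdepth = 3)

# fit$where holds, for each observation, the row of fit$frame of its leaf
leaf_means <- tapply(mydb$expo, fit$where, mean)

# Relabel the result with node numbers (the row names of fit$frame)
names(leaf_means) <- rownames(fit$frame)[as.integer(names(leaf_means))]
leaf_means
```

This gives one value per leaf only; the internal nodes need the extra work described next.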
library('rpart.plot')
set.seed(1)
mydb <- data.frame(results = rnorm(1000, 0, 1), expo = runif(1000),
                   var1 = sample(LETTERS[1:4], 1000, replace = TRUE),
                   var2 = sample(LETTERS[5:6], 1000, replace = TRUE),
                   var3 = sample(LETTERS[20:25], 1000, replace = TRUE))
mytree<-rpart(results~var1+var2+var3,data=mydb,cp=0)
pfit<- prune(mytree, cp=mytree$cptable[4,"CP"])
rpart.plot(pfit)
Define a new function that takes x, the result of fitting rpart (I didn't look into the other arguments, but the vignette should be helpful). For every row of x$frame we need the data that were used to calculate the summary statistics. Unfortunately, x$where only tells us the terminal node in which each observation lies. So for each node number, we use subset.rpart (defined further down; it has to be defined before the plot is drawn) to get the underlying data, and then do whatever we want with it:
f <- function(x, labs, digits, varlen) {
  nodes <- as.integer(rownames(x$frame))
  z <- sapply(nodes, function(y) {
    data <- subset.rpart(x, y)
    c(mean = mean(data$expo), nrow(data), nrow(data) / length(x$where) * 100)
  })
  sprintf('Mean expo: %.2f\nn=%.0f (%.0f%%)', z[1, ], z[2, ], z[3, ])
}
prp(pfit, type=1, extra=100, fallen.leaves=FALSE,
    shadow.col="darkgray", box.col=rgb(0.8,0.9,0.8),
    node.fun = f)
The work is done by subset.rpart, which takes a node number and returns the subset of data used at that node.
subset.rpart <- function(tree, node = 1L) {
  ## returns subset of tree$call$data used on any node
  data <- eval(tree$call$data, parent.frame(1L))
  wh <- sapply(as.integer(rownames(tree$frame)), parent)
  wh <- unique(unlist(wh[sapply(wh, function(x) node %in% x)]))
  data[rownames(tree$frame)[tree$where] %in% wh[wh >= node], ]
}
parent <- function(x) {
  ## returns vector of parent nodes
  if (x[1] != 1)
    c(Recall(if (x %% 2 == 0L) x / 2 else (x - 1) / 2), x) else x
}
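The parent helper relies on how rpart numbers its nodes: the children of node n are 2n and 2n + 1, so integer division by 2 walks back up to the root. As a small illustration, here is an iterative sketch of the same idea. The name node_path is hypothetical and not part of rpart or rpart.plot.

```r
# Hypothetical helper (not from rpart): path from the root down to `node`,
# using the convention that the parent of node n is n %/% 2
node_path <- function(node) {
  path <- node
  while (node > 1) {
    node <- node %/% 2      # step up to the parent
    path <- c(node, path)   # prepend so the root comes first
  }
  path
}

node_path(13)
# [1]  1  3  6 13
```

Every node on the path of a leaf contains that leaf's observations, which is what subset.rpart exploits.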
Tests
## tests
dim(subset.rpart(pfit, 1)) == dim(mydb)
# [1] TRUE TRUE
## terminal nodes
nodes <- as.integer(rownames(pfit$frame[pfit$frame$var %in% '<leaf>', ]))
sum(sapply(nodes, function(x) nrow(subset.rpart(pfit, x)))) == nrow(mydb)
# [1] TRUE
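One more check that may be worth running is a comparison against the n column of frame, which stores the number of observations at every node, internal ones included. The sketch below is self-contained and does not call subset.rpart; it assumes simulated data like the question's, a tree limited with maxdepth = 3 instead of pruning, and a hypothetical helper in_node that tests whether a leaf lies under a given node.

```r
library(rpart)

set.seed(1)
mydb <- data.frame(results = rnorm(1000), expo = runif(1000),
                   var1 = sample(LETTERS[1:4], 1000, replace = TRUE),
                   var2 = sample(LETTERS[5:6], 1000, replace = TRUE),
                   var3 = sample(LETTERS[20:25], 1000, replace = TRUE),
                   stringsAsFactors = TRUE)
fit <- rpart(results ~ var1 + var2 + var3, data = mydb, cp = 0, maxdepth = 3)

node_ids    <- as.integer(rownames(fit$frame))
leaf_of_row <- node_ids[fit$where]   # leaf node number of each observation

# Hypothetical helper: TRUE if `leaf` is `node` itself or one of its descendants
in_node <- function(leaf, node) {
  while (leaf > node) leaf <- leaf %/% 2   # climb towards the root
  leaf == node
}

# Count the observations under every node and compare with frame$n
counts <- sapply(node_ids, function(nd)
  sum(vapply(leaf_of_row, in_node, logical(1), node = nd)))
all(counts == fit$frame$n)

# The same membership test gives the mean exposure at every node
mean_expo <- sapply(node_ids, function(nd)
  mean(mydb$expo[vapply(leaf_of_row, in_node, logical(1), node = nd)]))
```

If the counts agree with frame$n, the per-node subsets line up with what rpart itself recorded, which is a stronger check than looking at the root and the leaves alone.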