
Getting the observations in an rpart node (i.e., CART)


I would like to inspect all the observations that reached some node in an rpart decision tree. For example, in the following code:

library(rpart)
fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)
fit

n= 81 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 81 17 absent (0.79012346 0.20987654)  
   2) Start>=8.5 62  6 absent (0.90322581 0.09677419)  
     4) Start>=14.5 29  0 absent (1.00000000 0.00000000) *
     5) Start< 14.5 33  6 absent (0.81818182 0.18181818)  
      10) Age< 55 12  0 absent (1.00000000 0.00000000) *
      11) Age>=55 21  6 absent (0.71428571 0.28571429)  
        22) Age>=111 14  2 absent (0.85714286 0.14285714) *
        23) Age< 111 7  3 present (0.42857143 0.57142857) *
   3) Start< 8.5 19  8 present (0.42105263 0.57894737) *

I would like to see all the observations in node (5), i.e., the 33 observations for which Start >= 8.5 & Start < 14.5. Obviously I could extract them manually, but I would like a function such as get_node_date, so that I could just run get_node_date(5) and get the relevant observations.

Any suggestions on how to go about this?
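The numbers on the left of the printout are rpart node ids, which follow a heap-style numbering: the children of node k are 2k and 2k + 1, so the parent of node k is k %/% 2. As a minimal sketch (the helper `is_under` is hypothetical, not an rpart function), this is enough to tell which nodes lie under a given node; the ids below are copied from the printout above:

```r
# Node ids as they appear in the printout above; for a fitted tree the same
# ids are available as as.integer(rownames(fit$frame)).
node_ids <- c(1, 2, 4, 5, 10, 11, 22, 23, 3)

# Hypothetical helper (not part of rpart): TRUE for each id in `ids` that is
# `node` itself or one of its descendants. The parent of node k is k %/% 2,
# so each id is walked up towards the root until it is no larger than `node`.
is_under <- function(ids, node) {
  vapply(ids, function(k) {
    while (k > node) k <- k %/% 2
    k == node
  }, logical(1))
}

node_ids[is_under(node_ids, 5)]
```

For this tree the call returns 5, 10, 11, 22 and 23, i.e. node 5 and everything below it.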


Solution

  • There seems to be no built-in function that extracts the observations from a specific node. I would solve it as follows: first determine which rule(s) lead to the node you are interested in; you can use path.rpart for this. Then apply the rule(s) one after the other to extract the observations.

    Wrapped up as a function:

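    library(rpart)  # rpart(), path.rpart() and the kyphosis data

    get_node_date <- function(tree = fit, node = 5, data = kyphosis){
      # path.rpart() prints the rules leading to `node` and returns them as a list
      rule <- path.rpart(tree, node)
      # drop "root", then split e.g. "Start>=8.5" into c("Start", ">=", "8.5")
      rule_2 <- sapply(rule[[1]][-1], function(x) strsplit(x, '(?<=[><=])(?=[^><=])|(?<=[^><=])(?=[><=])', perl = TRUE))
      # evaluate each rule, e.g. `>=`(data[, "Start"], 8.5), and keep rows satisfying all of them
      ind <- apply(do.call(cbind, lapply(rule_2, function(x) eval(call(x[2], data[, x[1]], as.numeric(x[3]))))), 1, all)
      data[ind, ]
    }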
    

    Calling it with the defaults (node 5) prints the path and then returns the observations:

    get_node_date()
    
     node number: 5 
       root
       Start>=8.5
       Start< 14.5
       Kyphosis Age Number Start
    2    absent 158      3    14
    10  present  59      6    12
    11  present  82      5    14
    14   absent   1      4    12
    18   absent 175      5    13
    20   absent  27      4     9
    23  present  96      3    12
    26   absent   9      5    13
    28   absent 100      3    14
    32   absent 125      2    11
    33   absent 130      5    13
    35   absent 140      5    11
    37   absent   1      3     9
    39   absent  20      6     9
    40  present  91      5    12
    42   absent  35      3    13
    46  present 139      3    10
    48   absent 131      5    13
    50   absent 177      2    14
    51   absent  68      5    10
    57   absent   2      3    13
    59   absent  51      7     9
    60   absent 102      3    13
    66   absent  17      4    10
    68   absent 159      4    13
    69   absent  18      4    11
    71   absent 158      5    14
    72   absent 127      4    12
    74   absent 206      4    10
    77  present 157      3    13
    78   absent  26      7    13
    79   absent 120      2    13
    81   absent  36      4    13
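The rule-parsing approach above assumes numeric splits printed as `variable operator value`, and it ignores surrogate splits, so rows with missing predictor values may not be routed the way rpart routes them. A sketch of an alternative, assuming `data` is exactly the data frame the tree was fitted on with no rows dropped, uses the fitted object's `where` component (the row of `frame` holding the terminal node each training observation falls into) together with the heap-style node numbering (children of node k are 2k and 2k + 1). The function name `get_node_obs` is hypothetical:

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Start, data = kyphosis)

# Hypothetical helper: returns the training rows that pass through `node`.
# Assumes `data` is the exact data frame used to fit `tree` (no rows dropped).
get_node_obs <- function(tree, node, data) {
  stopifnot(length(tree$where) == nrow(data))
  # row names of tree$frame are the node ids shown by print(tree)
  node_ids <- as.integer(rownames(tree$frame))
  # tree$where indexes tree$frame, giving each row's terminal node id
  leaf <- node_ids[tree$where]
  # a leaf lies under `node` if walking up via parent = k %/% 2 reaches `node`
  under <- vapply(leaf, function(k) {
    while (k > node) k <- k %/% 2L
    k == node
  }, logical(1))
  data[under, , drop = FALSE]
}

obs5 <- get_node_obs(fit, 5, kyphosis)
nrow(obs5)
```

For the tree above, `nrow(obs5)` should be 33, matching the n shown for node 5. Because it relies on where the training rows ended up, this only covers the observations the tree was fitted on.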