Search code examples
r rpart

rpart: create a table that indicates whether each observation belongs to each node


The following figure shows what I want to do:

  1. Grow a tree with rpart for some dataset
  2. Create a table with one row per observation in the original data set and one column per node in the tree, plus an id column. Each node column should take the value 1 if the observation belongs to that node and 0 otherwise.

[image: the desired output table — not available]

This is some code that I wrote:

library(rpart)
library(rattle)

# Grow a regression tree for Age on the kyphosis data and plot it.
data <- kyphosis
fit <- rpart(Age ~ Number + Start, data = kyphosis)
fancyRpartPlot(fit)

# Node numbers of every node in the tree (the row names of fit$frame).
nodeNumbers <- as.numeric(rownames(fit$frame))

# Split conditions on the way from the root to each node
# (path.rpart() also prints these paths by default).
paths <- path.rpart(fit, nodeNumbers)

# Add one indicator column per node, named gp<node number>.
for (i in seq_along(nodeNumbers)) {
  nodeNumber <- nodeNumbers[i]
  col <- paste0('gp', nodeNumber)
  data[, col] <- NA
  path <- paths[[i]]
  if (length(path) == 1) {
    # i.e. we're at the root: every observation belongs to it
    data[, col] <- 1
  } else {
    # Unsolved part: membership for non-root nodes still has to be
    # derived from the split conditions in `path`.
    print('help')
  }
}
data

Is there a package out there that does what I need? The only way I can think of is some regular-expression magic on the `paths` object, but my guess (and hope) is that there is an easier way.


Solution

  • Is there a package out there to do what I need?

    AFAIK, no, but the following works in rpart version 4.1.13:

    # Build the node-membership matrix the OP wants from leaf indices.
    #
    # object: a fitted rpart tree; only row.names(object$frame), i.e. the
    #   node numbers, are used.
    # where:  for each observation, the row of object$frame that holds its
    #   leaf (e.g. fit$where, or the output of get_where()).
    # Returns a logical matrix with one row per element of `where` and one
    #   column per node, named "GP<node number>". An entry is TRUE when the
    #   observation's leaf is that node or lies below it. Rows for NA leaf
    #   indices are NA. Multiply by 1L to get 1/0 values.
    #
    # rpart numbers the children of node k as 2k and 2k + 1, so the
    # ancestors of node m are m, m %/% 2, m %/% 4, ..., 1. Walking that chain
    # avoids the unexported rpart:::descendants() and the environment() hack.
    get_nodes <- function(object, where) {
      rn <- row.names(object$frame)
      node_ids <- as.numeric(rn)
      if (any(where < 1 | where > length(node_ids), na.rm = TRUE)) {
        stop("`where` must index rows of object$frame", call. = FALSE)
      }
      leaf_ids <- node_ids[where]
      o <- matrix(FALSE, nrow = length(leaf_ids), ncol = length(node_ids))
      anc <- leaf_ids
      # Each pass marks the current ancestor of every observation, then moves
      # one level up; it ends once every chain has passed the root (0).
      while (any(anc >= 1, na.rm = TRUE)) {
        col <- match(anc, node_ids)
        hit <- which(!is.na(col))
        o[cbind(hit, col[hit])] <- TRUE
        anc <- anc %/% 2
      }
      o[is.na(leaf_ids), ] <- NA
      colnames(o) <- paste0("GP", rn)
      o
    }
    
    # Membership matrix for the training observations (leaves in fit$where);
    # show the first nine rows
    nodes <- get_nodes(object = fit, where = fit$where)
    head(nodes, n = 9L)
    #R       GP1   GP2   GP3   GP6   GP7  GP14  GP15
    #R [1,] TRUE FALSE  TRUE FALSE  TRUE  TRUE FALSE
    #R [2,] TRUE FALSE  TRUE FALSE  TRUE FALSE  TRUE
    #R [3,] TRUE FALSE  TRUE FALSE  TRUE  TRUE FALSE
    #R [4,] TRUE  TRUE FALSE FALSE FALSE FALSE FALSE
    #R [5,] TRUE FALSE  TRUE FALSE  TRUE FALSE  TRUE
    #R [6,] TRUE FALSE  TRUE  TRUE FALSE FALSE FALSE
    #R [7,] TRUE FALSE  TRUE  TRUE FALSE FALSE FALSE
    #R [8,] TRUE FALSE  TRUE  TRUE FALSE FALSE FALSE
    #R [9,] TRUE FALSE  TRUE  TRUE FALSE FALSE FALSE
    
    # The same nine observations in the original data, for comparison
    head(data, n = 9L)
    #R   Kyphosis Age Number Start
    #R 1   absent  71      3     5
    #R 2   absent 158      3    14
    #R 3  present 128      4     5
    #R 4   absent   2      5     1
    #R 5   absent   1      4    15
    #R 6   absent   1      2    16
    #R 7   absent  61      2    17
    #R 8   absent  37      3    16
    #R 9   absent 113      2    16
    

    Here is the full code. It fits the model and defines a function that finds the terminal leaf of each observation in a (possibly new) data set. It then defines and uses the function above:

    # Same setup as in the question: a regression tree for Age, plotted
    library(rpart)
    library(rattle)
    fit <- rpart(Age ~ Number + Start, data = kyphosis)
    data <- kyphosis
    fancyRpartPlot(fit)
    

    [image: the tree drawn by fancyRpartPlot(fit) — not available]

    # For each row of `newdata`, return the index of the row of object$frame
    # (the leaf) that the observation is routed to, as predict.rpart() does.
    #
    # object:    a fitted rpart tree.
    # newdata:   a data frame. If it already carries a "terms" attribute it is
    #            used as-is; otherwise a model frame is built from it.
    # na.action: passed to model.frame(); the default na.pass keeps rows with
    #            missing predictors.
    get_where <- function(object, newdata, na.action = na.pass) {
      needs_frame <- is.null(attr(newdata, "terms"))
      if (needs_frame) {
        rhs_terms <- delete.response(object$terms)
        newdata <- model.frame(rhs_terms, newdata,
                               na.action = na.action,
                               xlev = attr(object, "xlevels"))
        data_classes <- attr(rhs_terms, "dataClasses")
        # Check the variable classes against those recorded at fit time
        if (!is.null(data_classes)) {
          .checkMFClasses(data_classes, newdata, TRUE)
        }
      }
      pred.rpart(object, rpart.matrix(newdata))
    }
    # pred.rpart() and rpart.matrix() are rpart internals, so evaluate the
    # function inside the rpart namespace.
    environment(get_where) <- environment(rpart)
    
    # Leaf index for every row of the data; the fitted values of those leaves
    # must reproduce predict()'s output
    where <- get_where(object = fit, newdata = data)
    stopifnot(
      isTRUE(all.equal(fit$frame$yval[where],
                       unname(predict(fit, newdata = data))))
    )
    
    # Build the node-membership matrix the OP wants from leaf indices.
    #
    # object: a fitted rpart tree; only row.names(object$frame), i.e. the
    #   node numbers, are used.
    # where:  for each observation, the row of object$frame that holds its
    #   leaf (e.g. fit$where, or the output of get_where()).
    # Returns a logical matrix with one row per element of `where` and one
    #   column per node, named "GP<node number>". An entry is TRUE when the
    #   observation's leaf is that node or lies below it. Rows for NA leaf
    #   indices are NA. Multiply by 1L to get 1/0 values.
    #
    # rpart numbers the children of node k as 2k and 2k + 1, so the
    # ancestors of node m are m, m %/% 2, m %/% 4, ..., 1. Walking that chain
    # avoids the unexported rpart:::descendants() and the environment() hack.
    get_nodes <- function(object, where) {
      rn <- row.names(object$frame)
      node_ids <- as.numeric(rn)
      if (any(where < 1 | where > length(node_ids), na.rm = TRUE)) {
        stop("`where` must index rows of object$frame", call. = FALSE)
      }
      leaf_ids <- node_ids[where]
      o <- matrix(FALSE, nrow = length(leaf_ids), ncol = length(node_ids))
      anc <- leaf_ids
      # Each pass marks the current ancestor of every observation, then moves
      # one level up; it ends once every chain has passed the root (0).
      while (any(anc >= 1, na.rm = TRUE)) {
        col <- match(anc, node_ids)
        hit <- which(!is.na(col))
        o[cbind(hit, col[hit])] <- TRUE
        anc <- anc %/% 2
      }
      o[is.na(leaf_ids), ] <- NA
      colnames(o) <- paste0("GP", rn)
      o
    }
    
    # Membership matrix for the leaves found by get_where(); first nine rows
    nodes <- get_nodes(object = fit, where = where)
    head(nodes, n = 9L)
    #R       GP1   GP2   GP3   GP6   GP7  GP14  GP15
    #R [1,] TRUE FALSE  TRUE FALSE  TRUE  TRUE FALSE
    #R [2,] TRUE FALSE  TRUE FALSE  TRUE FALSE  TRUE
    #R [3,] TRUE FALSE  TRUE FALSE  TRUE  TRUE FALSE
    #R [4,] TRUE  TRUE FALSE FALSE FALSE FALSE FALSE
    #R [5,] TRUE FALSE  TRUE FALSE  TRUE FALSE  TRUE
    #R [6,] TRUE FALSE  TRUE  TRUE FALSE FALSE FALSE
    #R [7,] TRUE FALSE  TRUE  TRUE FALSE FALSE FALSE
    #R [8,] TRUE FALSE  TRUE  TRUE FALSE FALSE FALSE
    #R [9,] TRUE FALSE  TRUE  TRUE FALSE FALSE FALSE
    
    # The same nine observations in the original data, for comparison
    head(data, n = 9L)
    #R   Kyphosis Age Number Start
    #R 1   absent  71      3     5
    #R 2   absent 158      3    14
    #R 3  present 128      4     5
    #R 4   absent   2      5     1
    #R 5   absent   1      4    15
    #R 6   absent   1      2    16
    #R 7   absent  61      2    17
    #R 8   absent  37      3    16
    #R 9   absent 113      2    16
    

    The code is adapted from rpart:::predict.rpart and rpart::path.rpart. You can, of course, merge the get_where and get_nodes functions if you want.