referencing rpart terminal nodes in r


I am new to R (and rpart). I have vehicle model data (~400 models). I am using rpart to group these into a smaller number (say 5-10 groups) that have similar vehicle repair costs. I have successfully run rpart and have these groupings.

fit <- rpart(repairs ~ model, data=data, method='anova', control=rpart.control(minsplit=2,minbucket=1,cp=.0005))    

Assume each terminal node has roughly 40-80 models in it. Is there a way for me to create a formula that refers to the values in each terminal node? Assuming data$model contains all of the model names (and is the independent variable), I am trying to do something like:

data$modelgroup <- data$model
data$modelgroup[data$modelgroup %in% terminal node 1] <- 'Group1'
data$modelgroup[data$modelgroup %in% terminal node 2] <- 'Group2'
and so on for the rest of the groups

Also, if there were a way to do this without having to have a line of code for each group, that would be good.

I know I can print the tree and manually copy the text from the terminal nodes and accomplish it that way, but that is very inefficient.

Thanks in advance for your assistance!

Per the request below, I added a reproducible example below.

data <- read.csv("rpart_example.csv")
data

data[,1:2]

   Model Amount
1      a      1
2      a      1
3      a      1
4      b      1
5      b      1
6      b      1
7      c      2
8      c      2
9      c      2
10     d      2
11     d      2
12     d      2
13     e      3
14     e      3
15     e      3
16     f      4
17     f      4
18     f      4

fit <- rpart(Amount ~ Model, data=data, method='anova', 
          control=rpart.control(minsplit=2,minbucket=1,cp=.0005))
print(fit)

n= 18 

node), split, n, deviance, yval
* denotes terminal node

1) root 18 20.5 2.166667  
  2) Model=a,b,c,d 12  3.0 1.500000  
    4) Model=a,b 6  0.0 1.000000 *
    5) Model=c,d 6  0.0 2.000000 *
  3) Model=e,f 6  1.5 3.500000  
    6) Model=e 3  0.0 3.000000 *
    7) Model=f 3  0.0 4.000000 *

# create a variable modelgroup that groups models per terminal nodes from rpart     

# I can do this manually as below
# is there a way for me to automate this assignment?

data$modelgroup <- as.character(data$Model)

# per rpart output, a&b are grouped into one terminal node
data$modelgroup[data$modelgroup %in% c('a','b')] <- 'Group1'    

# per rpart output, c&d are grouped into the second terminal node
data$modelgroup[data$modelgroup %in% c('c','d')] <- 'Group2'

# per rpart, e is the third terminal node
data$modelgroup[data$modelgroup == 'e'] <- 'Group3'

# per rpart, f is the fourth terminal node
data$modelgroup[data$modelgroup == 'f'] <- 'Group4'

Solution

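  • In rpart objects the information you are looking for is already stored in the $where element. For each observation it gives the row of fit$frame that holds the terminal node the observation falls into. These are row indices rather than the node numbers shown by print(fit); the printed numbers can be recovered with rownames(fit$frame)[fit$where].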

    table(fit$where, data$modelgroup)
    ##     Group1 Group2 Group3 Group4
    ##   3      6      0      0      0
    ##   4      0      6      0      0
    ##   6      0      0      3      0
    ##   7      0      0      0      3
    

    Of course you could also switch the node IDs (3, 4, 6, 7) to a factor or character variable, e.g., factor(fit$where, levels = c(3, 4, 6, 7), labels = paste0("Group", 1:4)) or something along those lines.
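
A minimal sketch of that relabelling, assuming the toy data from the question is rebuilt inline (the data.frame() call stands in for rpart_example.csv, and leaf_ids is only an illustrative name). The group numbers follow the sorted leaf ids, so no separate line of code per group is needed:

```r
library(rpart)

# Rebuild the toy data from the question (18 rows, 6 models).
# Assumption: this replaces read.csv("rpart_example.csv").
data <- data.frame(
  Model  = factor(rep(c("a", "b", "c", "d", "e", "f"), each = 3)),
  Amount = rep(c(1, 1, 2, 2, 3, 4), each = 3)
)

fit <- rpart(Amount ~ Model, data = data, method = "anova",
             control = rpart.control(minsplit = 2, minbucket = 1, cp = 0.0005))

# fit$where holds, for each training row, the row of fit$frame
# that corresponds to its terminal node.
leaf_ids <- sort(unique(fit$where))

# Label the leaves Group1, Group2, ... in order of their leaf id.
data$modelgroup <- factor(fit$where, levels = leaf_ids,
                          labels = paste0("Group", seq_along(leaf_ids)))

# One row per model, showing the group it was assigned to.
unique(data[, c("Model", "modelgroup")])
```

Because the labels are generated from leaf_ids, the same lines should work unchanged if the tree ends up with 5-10 terminal nodes instead of 4.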

    If you want to do this on new data with a simple and unified interface, you can convert your rpart object to a party object in package partykit:

    library("partykit")
    fit2 <- as.party(fit)
    

    Unified print(fit2) and plot(fit2) methods are available, as well as predict(fit2, ...) with different types:

    table(predict(fit2, newdata = data, type = "node"), data$modelgroup)
    ##     Group1 Group2 Group3 Group4
    ##   3      6      0      0      0
    ##   4      0      6      0      0
    ##   6      0      0      3      0
    ##   7      0      0      0      3
    

    This returns the same result as above but could easily be applied to other newdata as well.
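
As a rough sketch of applying the converted tree to new observations (assuming partykit is installed; newdata, term_ids, and the GroupN labels are illustrative names and not part of the original answer), the node ids can be predicted and then relabelled in one step:

```r
library(rpart)
library(partykit)

# Toy training data from the question, rebuilt inline (assumption:
# this replaces read.csv("rpart_example.csv")).
data <- data.frame(
  Model  = factor(rep(c("a", "b", "c", "d", "e", "f"), each = 3)),
  Amount = rep(c(1, 1, 2, 2, 3, 4), each = 3)
)

fit <- rpart(Amount ~ Model, data = data, method = "anova",
             control = rpart.control(minsplit = 2, minbucket = 1, cp = 0.0005))
fit2 <- as.party(fit)

# Hypothetical new observations; the factor levels are taken from the
# training data so the split on Model can be evaluated consistently.
newdata <- data.frame(Model = factor(c("b", "f", "c"),
                                     levels = levels(data$Model)))

# Terminal node id for each new observation.
nodes <- predict(fit2, newdata = newdata, type = "node")

# Map the terminal node ids to Group1, Group2, ... without hardcoding them.
term_ids <- nodeids(fit2, terminal = TRUE)
newdata$modelgroup <- factor(nodes, levels = term_ids,
                             labels = paste0("Group", seq_along(term_ids)))
newdata
```

Taking the levels from nodeids(fit2, terminal = TRUE) rather than from the predicted values keeps the group labels stable even when the new data happens not to reach every terminal node.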