Tags: r, data-mining, decision-tree

How to get weight and response prediction for all terminal nodes of a 'ctree' in R


I can already list the weights for all terminal nodes (code below). How can I extend it to get the response prediction as well as the weight for each terminal node ID?

Ideally the output would have one row per terminal node, showing the node ID, its weight, and its prediction.

Here is what I have so far to get the weights:

    nodes(airct, unique(where(airct)))

Thank you


Solution

  • A BinaryTree is a big S4 object, so it is sometimes difficult to extract data from it.

    However, the plot method for BinaryTree objects accepts an optional panel function of the form function(node) that draws the terminal nodes, so the node information is available while plotting.

    Here I use the plot function to extract that information and, better still, the gridExtra package to render each terminal node as a table.

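    library(party)
    library(gridExtra)

    set.seed(100)
    lls <- data.frame(N = gl(3, 50, labels = c("A", "B", "C")),
                      a = rnorm(150) + rep(c(1, 0, 150)),  # c(1, 0, 150) is recycled over the 150 rows
                      b = runif(150))
    pond <- sample(1:5, 150, replace = TRUE)  # case weights
    tt <- ctree(formula = N ~ a + b, data = lls, weights = pond)

    output.df <- data.frame()  # filled as a side effect of plotting

    # Panel function: called once per terminal node while the tree is drawn
    innerWeights <- function(node) {
      dat <- data.frame(ID = node$nodeID,
                        Weights = sum(node$weights),
                        Prediction = paste(round(node$prediction, 2), collapse = '  '))
      # Draw the node as a small table. The *.alpha arguments come from older
      # gridExtra releases and may not be accepted by current ones.
      grid.table(dat,
                 cols = c('ID', 'Weights', 'Prediction'),
                 h.even.alpha = 1,
                 h.odd.alpha = 1,
                 v.even.alpha = 0.5,
                 v.odd.alpha = 1)
      output.df <<- rbind(output.df, dat)  # note the use of <<-
    }

    plot(tt, type = 'simple', terminal_panel = innerWeights)

    output.df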
      ID Weights       Prediction
    1  4      24  0.42  0.5  0.08
    2  5      17 0.06  0.24  0.71
    3  6      24    0.08  0  0.92
    4  7     388 0.37  0.37  0.26
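
    If the plot itself is not needed, the same table can probably be built without a panel function, by reusing the nodes(..., unique(where(...))) call from the question and reading nodeID, weights and prediction from each returned node. The following is a minimal sketch under that assumption; term_ids, term_nodes and res are names chosen here for illustration, and it has not been checked against every version of party.

    ```r
    library(party)

    # Same toy data and weighted tree as in the plotting example
    set.seed(100)
    lls <- data.frame(N = gl(3, 50, labels = c("A", "B", "C")),
                      a = rnorm(150) + rep(c(1, 0, 150)),
                      b = runif(150))
    pond <- sample(1:5, 150, replace = TRUE)
    tt <- ctree(formula = N ~ a + b, data = lls, weights = pond)

    # IDs of the terminal nodes that the learning sample falls into
    term_ids <- sort(unique(where(tt)))

    # One list entry per terminal node
    term_nodes <- nodes(tt, term_ids)

    # One row per terminal node: ID, summed case weights, class probabilities
    res <- do.call(rbind, lapply(term_nodes, function(node) {
      data.frame(ID = node$nodeID,
                 Weights = sum(node$weights),
                 Prediction = paste(round(node$prediction, 2), collapse = "  "))
    }))

    res
    ```

    This avoids the global assignment with <<- and does not depend on gridExtra, at the cost of not drawing the tree.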
    
