
cforest prints empty tree


I'm trying to use the cforest function (R, party package).

This is what I do to construct the forest:

library("party")
set.seed(42)
readingSkills.cf <- cforest(score ~ ., data = readingSkills, 
                         control = cforest_unbiased(mtry = 2, ntree = 50))

Then, to print the first tree, I do:

party:::prettytree(readingSkills.cf@ensemble[[1]],names(readingSkills.cf@data@get("input")))

The result looks like this:

1) shoeSize <= 28.29018; criterion = 1, statistic = 89.711
  2) age <= 6; criterion = 1, statistic = 48.324
    3) age <= 5; criterion = 0.997, statistic = 8.917
      4)*  weights = 0 
    3) age > 5
      5)*  weights = 0 
  2) age > 6
    6) age <= 7; criterion = 1, statistic = 13.387
      7) shoeSize <= 26.66743; criterion = 0.214, statistic = 0.073
        8)*  weights = 0 
      7) shoeSize > 26.66743
        9)*  weights = 0 
    6) age > 7
      10)*  weights = 0 
1) shoeSize > 28.29018
  11) age <= 9; criterion = 1, statistic = 36.836
    12) nativeSpeaker == {}; criterion = 0.998, statistic = 9.347
      13)*  weights = 0 
    12) nativeSpeaker == {}
      14)*  weights = 0 
  11) age > 9
    15) nativeSpeaker == {}; criterion = 1, statistic = 19.124
      16) age <= 10; criterion = 1, statistic = 18.441
        17)*  weights = 0 
      16) age > 10
        18)*  weights = 0 
    15) nativeSpeaker == {}
      19)*  weights = 0 

Why is it empty (the weights in every node are equal to zero)?


Solution

  • Short answer: the case weights (weights) in each node are NULL, i.e. they are not stored. The prettytree function prints weights = 0 because sum(NULL) equals 0 in R.
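A minimal base-R sketch isolates this mechanism. The node below is a hypothetical plain list, not a real party object. Its weights element is NULL, as in the cforest trees discussed here:

```r
# A hypothetical node stored as a plain list; list() keeps the
# weights element even though its value is NULL
node <- list(nodeID = 4L, weights = NULL, terminal = TRUE)

is.null(node$weights)  # TRUE: no case weights are stored
sum(node$weights)      # 0: sum() over an empty (NULL) input is zero
```

So a printed "weights = 0" cannot tell an empty node apart from a node whose weights were simply never stored.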


    Consider the following ctree example:

    library("party")
    x <- ctree(Species ~ ., data=iris)
    plot(x, type="simple")
    

    [Image: ctree plot]

    For the resulting object x (of class BinaryTree), the case weights are stored in each node:

    R> sum(x@tree$left$weights)
    [1] 50
    R> sum(x@tree$right$weights)
    [1] 100
    R> sum(x@tree$right$left$weights)
    [1] 54
    R> sum(x@tree$right$right$weights)
    [1] 46
    
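The lookups above follow the tree by hand. A small recursive helper can collect the weight sums of all terminal nodes in one call. This is a sketch that relies only on the node fields visible in this answer (nodeID, terminal, weights, left, right). To keep it runnable without party, it uses a hand-built hypothetical tree, and terminal_weight_sums is a name made up for this example. With party loaded, you could try terminal_weight_sums(x@tree) instead.

```r
# Hypothetical stand-in for a party tree: each node is a plain list with
# nodeID, terminal, weights (one case weight per observation) and, for
# inner nodes, left and right children
leaf  <- function(id, w) list(nodeID = id, terminal = TRUE, weights = w)
inner <- function(id, left, right)
  list(nodeID = id, terminal = FALSE,
       weights = left$weights + right$weights, left = left, right = right)

tree <- inner(1L,
              leaf(2L, c(1, 1, 1, 0, 0, 0)),
              inner(3L,
                    leaf(4L, c(0, 0, 0, 1, 0, 0)),
                    leaf(5L, c(0, 0, 0, 0, 1, 1))))

# Return the summed case weights of every terminal node, named by nodeID
terminal_weight_sums <- function(node) {
  if (isTRUE(node$terminal)) {
    return(setNames(sum(node$weights), node$nodeID))
  }
  c(terminal_weight_sums(node$left), terminal_weight_sums(node$right))
}

terminal_weight_sums(tree)  # node 2: 3, node 4: 1, node 5: 2
```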

    Now let's take a closer look at cforest:

    y <- cforest(Species ~ ., data=iris, control=cforest_control(mtry=2))
    tr <- party:::prettytree(y@ensemble[[1]], names(y@data@get("input")))
    plot(new("BinaryTree", tree=tr, data=y@data, responses=y@responses))
    

    [Image: cforest tree]

    The case weights are not stored in the tree ensemble. To see this, open the print method for terminal nodes for editing:

    fixInNamespace("print.TerminalNode", "party")
    

    Change the print method to:

    function (x, n = 1, ...)
    {
        print(names(x))     # debugging: list the components of the node
        print(x$weights)    # debugging: show the raw case weights
        cat(paste(paste(rep(" ", n - 1), collapse = ""), x$nodeID,
            ")* ", sep = "", collapse = ""), "weights =", sum(x$weights),
            "\n")
    }
    

    Now we can see that weights is NULL in every terminal node:

    R> tr
    1) Petal.Width <= 0.4; criterion = 10.641, statistic = 10.641
     [1] "nodeID"     "weights"    "criterion"  "terminal"   "psplit"    
     [6] "ssplits"    "prediction" "left"       "right"      NA          
    NULL
      2)*  weights = 0 
    1) Petal.Width > 0.4
      3) Petal.Width <= 1.6; criterion = 8.629, statistic = 8.629
     [1] "nodeID"     "weights"    "criterion"  "terminal"   "psplit"    
     [6] "ssplits"    "prediction" "left"       "right"      NA          
    NULL
        4)*  weights = 0 
      3) Petal.Width > 1.6
     [1] "nodeID"     "weights"    "criterion"  "terminal"   "psplit"    
     [6] "ssplits"    "prediction" "left"       "right"      NA          
    NULL
        5)*  weights = 0 
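The modified print method only runs for terminal nodes. A recursive walk can check every node, inner nodes included. This is again a hypothetical sketch on a hand-built list that mimics an ensemble tree with NULL weights, and all_weights_null is an illustrative name. Passing tr from above should work only if prettytree names the fields of every nested node the same way. That is an assumption based on the names printed above, not something verified here.

```r
# Hypothetical nested-list tree mimicking an ensemble tree in which the
# weights element is present but NULL in every node
leaf <- function(id) list(nodeID = id, weights = NULL, terminal = TRUE)
tree <- list(nodeID = 1L, weights = NULL, terminal = FALSE,
             left = leaf(2L),
             right = list(nodeID = 3L, weights = NULL, terminal = FALSE,
                          left = leaf(4L), right = leaf(5L)))

# TRUE if weights is NULL in this node and in every node below it
all_weights_null <- function(node) {
  if (!is.null(node$weights)) return(FALSE)
  if (isTRUE(node$terminal)) return(TRUE)
  all_weights_null(node$left) && all_weights_null(node$right)
}

all_weights_null(tree)  # TRUE
```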
    

    Update: this is a hack to display the sums of the case weights:

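    # Recursively walk the tree; in every terminal node, copy element 9
    # of the node list into both weights and weights_ so that the
    # printed/plotted tree shows a value there
    update_tree <- function(x) {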
      if(!x$terminal) {
        x$left <- update_tree(x$left)
        x$right <- update_tree(x$right)
      } else {
        x$weights <- x[[9]]
        x$weights_ <- x[[9]]
      }
      x
    }
    tr_weights <- update_tree(tr)
    plot(new("BinaryTree", tree=tr_weights, data=y@data, responses=y@responses))
    

    [Image: cforest tree with case weights]
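The recursion in update_tree can also be written as a helper that applies an arbitrary function to every terminal node. update_tree does the same thing with a fixed x[[9]] assignment. The sketch below is hypothetical: map_terminal and the sums lookup are made up for illustration, and the example runs on a hand-built list rather than a party tree.

```r
# Apply f to every terminal node of a nested-list tree and return the
# modified tree; inner nodes are rebuilt with their updated children
map_terminal <- function(node, f) {
  if (isTRUE(node$terminal)) return(f(node))
  node$left <- map_terminal(node$left, f)
  node$right <- map_terminal(node$right, f)
  node
}

# Hypothetical tree and hypothetical weight sums keyed by nodeID
leaf <- function(id) list(nodeID = id, weights = NULL, terminal = TRUE)
tree <- list(nodeID = 1L, weights = NULL, terminal = FALSE,
             left = leaf(2L), right = leaf(3L))
sums <- c("2" = 40, "3" = 60)

filled <- map_terminal(tree, function(leaf_node) {
  leaf_node$weights <- sums[[as.character(leaf_node$nodeID)]]
  leaf_node
})
filled$left$weights   # 40
filled$right$weights  # 60
```

Passing f in as an argument keeps the tree walk separate from what gets written into each node. That way the positional x[[9]] access could be swapped out without touching the recursion.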