
Further split an existing tree with partykit


For a given partynode (or the corresponding party object), I want to add a new split.

library(partykit)

sp_o <- partysplit(1L, index = 1:3)
sp_h <- partysplit(3L, breaks = 75)

node <- partynode(1L, split = sp_o, kids = list(
  partynode(2L, split = sp_h, kids = list(
    partynode(3L),
    partynode(4L))),
  partynode(5L)))

print(node)

##[1] root
##|   [2] V1 in (-Inf,1]
##|   |   [3] V3 <= 75 *
##|   |   [4] V3 > 75 *
##|   [5] V1 in (1,2] *
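The ids of the nodes in this tree, and which of them are terminal (i.e., candidates for a new split), can be listed with partykit's `nodeids()`. Passing `terminal = TRUE` keeps only the terminal nodes. A small sketch that rebuilds the tree from the question:

```r
library(partykit)

# Tree from the question
sp_o <- partysplit(1L, index = 1:3)
sp_h <- partysplit(3L, breaks = 75)
node <- partynode(1L, split = sp_o, kids = list(
  partynode(2L, split = sp_h, kids = list(partynode(3L), partynode(4L))),
  partynode(5L)))

all_ids  <- nodeids(node)                   # every node id, depth-first
term_ids <- nodeids(node, terminal = TRUE)  # terminal nodes only
print(all_ids)
print(term_ids)
```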

Now I want to add a new split to a terminal node, e.g., node 3. I found the following way to do this:

# New split
sp_w <- partysplit(4L, index = 1:2)

# Access terminal node 3 and add the new split
node$kids[[1]]$kids[[1]] <- partynode(3L, split = sp_w, kids = list(
  partynode(4L),
  partynode(5L)))

# Rebuild the partynode so that the node ids are renumbered consecutively
node <- as.partynode(node)

print(node)

##[1] root
##|   [2] V1 in (-Inf,1]
##|   |   [3] V3 <= 75
##|   |   |   [4] V4 <= 1 *
##|   |   |   [5] V4 > 1 *
##|   |   [6] V3 > 75 *
##|   [7] V1 in (1,2] *

Is there a way to do this that avoids calling $kids[[i]] repeatedly? Assume that I know the vector of indices that leads to the respective node (here: c(1,1)). I want this to work for partynodes of arbitrary complexity, using either a vector of indices that gives the position of the desired terminal node (e.g., c(1,1,2)) or, if possible, just the terminal node id.
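A minimal sketch of the index-vector idea, assuming only the `$kids` assignment and the `as.partynode()` renumbering that the question already relies on. `replace_at()` and `path_to_id()` are hypothetical helpers, not partykit functions: the first walks a vector of kid positions recursively instead of spelling out `$kids[[i]]` at every level, and the second looks up that vector for a given node id, so a terminal node can also be addressed by its id.

```r
library(partykit)

# Hypothetical helper (not part of partykit): replace the subtree reached by
# following `path` (positions in the successive `kids` lists) with `new`.
replace_at <- function(node, path, new) {
  if (length(path) == 0L) return(new)
  k <- path[[1L]]
  node$kids[[k]] <- replace_at(node$kids[[k]], path[-1L], new)
  node
}

# Hypothetical helper: kid positions leading from `node` to the node whose
# id is `id`, or NULL if no such node exists in this subtree.
path_to_id <- function(node, id) {
  if (node$id == id) return(integer(0))
  for (k in seq_along(node$kids)) {
    p <- path_to_id(node$kids[[k]], id)
    if (!is.null(p)) return(c(k, p))
  }
  NULL
}

# Objects from the question
sp_o <- partysplit(1L, index = 1:3)
sp_h <- partysplit(3L, breaks = 75)
sp_w <- partysplit(4L, index = 1:2)
node <- partynode(1L, split = sp_o, kids = list(
  partynode(2L, split = sp_h, kids = list(partynode(3L), partynode(4L))),
  partynode(5L)))

# New subtree; its ids are placeholders until as.partynode() renumbers
new <- partynode(3L, split = sp_w, kids = list(partynode(4L), partynode(5L)))

# By index vector, as in the question (c(1, 1) reaches node 3) ...
node_a <- as.partynode(replace_at(node, c(1L, 1L), new))

# ... or by node id, looking up the path first
p <- path_to_id(node, 3L)
stopifnot(!is.null(p))  # guard: an empty result would replace the whole tree
node_b <- as.partynode(replace_at(node, p, new))
print(node_b)
```

The guard matters because `replace_at()` treats an empty path as "the root", so a missing id must be caught before calling it.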


Solution

  • You can do this by converting between list objects and partynode objects:

    • Convert original partynode to list

      li <- as.list(node)
      
    • Id of the node to split and the current maximum node id

      i <- 3L
      n <- length(li)
      
    • Set up the new partynode (with two new terminal kids) and convert it to a list

      node2 <- partynode(i, split = sp_w, kids = list(partynode(n + 1L), partynode(n + 2L)))
      li2 <- as.list(node2)
      
    • Replace node i with the root of the new partynode and append its remaining (new terminal) nodes

      li[[i]] <- li2[[1L]]
      li <- c(li, li2[-1L])
      
    • Convert back to partynode

      node <- as.partynode(li)
      print(node)
      ## [1] root
      ## |   [2] V1 in (-Inf,1]
      ## |   |   [3] V3 <= 75
      ## |   |   |   [4] V4 <= 1 *
      ## |   |   |   [5] V4 > 1 *
      ## |   |   [6] V3 > 75 *
      ## |   [7] V1 in (1,2] *
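The steps above can be wrapped in a small function so that only the node id and the new split are needed. `add_split()` is a hypothetical name, not a partykit function; like the steps it wraps, the sketch assumes that node ids run from 1 to `length(li)`, and it gives the split node `nkids` new terminal children (two by default, matching `sp_w`).

```r
library(partykit)

# Hypothetical helper wrapping the list-based steps: split node `id` of
# `node` with `split`, adding `nkids` new terminal kids.
add_split <- function(node, id, split, nkids = 2L) {
  id <- as.integer(id)                 # partynode() needs integer ids
  li <- as.list(node)                  # original partynode as a list
  n <- length(li)                      # maximum node id
  node2 <- partynode(id, split = split,
                     kids = lapply(n + seq_len(nkids), partynode))
  li2 <- as.list(node2)                # new partynode as a list
  li[[id]] <- li2[[1L]]                # replace node `id` with the new root
  as.partynode(c(li, li2[-1L]))        # append the rest and convert back
}

# Objects from the question
sp_o <- partysplit(1L, index = 1:3)
sp_h <- partysplit(3L, breaks = 75)
sp_w <- partysplit(4L, index = 1:2)
node <- partynode(1L, split = sp_o, kids = list(
  partynode(2L, split = sp_h, kids = list(partynode(3L), partynode(4L))),
  partynode(5L)))

node <- add_split(node, 3L, sp_w)
print(node)
```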