
Update Estimates in Party / Partykit model with averages from unseen holdout data


I want to fit a decision tree on a subsample of my data (using evtree, which has a VERY LONG run time on large datasets).

I then want to take this model and replace the terminal node estimates with estimates computed from holdout data. This is analogous to the concept of "honesty" in the GRF package, where the bias that sampling introduces during model construction is countered by estimating on holdout data. The end result would be a final model that is generally less biased, is faster to fit (smaller training input), and has lower variance. Ideally, I'd be able to take the new model and predict on new data with it.

library(partykit)
library(dplyr)  # provides %>%, group_by(), and summarize() used below
mtcars
set.seed(12)
train = sample(nrow(mtcars), nrow(mtcars)/1.5)
sample_tree = ctree(mpg ~ ., data = mtcars[train, ])
sample_tree %>% as.simpleparty

# Fitted party:
# [1] root
# |   [2] cyl <= 6: 23.755 (n = 11, err = 224.8)
# |   [3] cyl > 6: 15.380 (n = 10, err = 42.1)

data.frame(node = predict(sample_tree, newdata = mtcars[-train, ], type = 'node'),
           prediction = mtcars[-train, ]$mpg) %>%
group_by(node) %>%
summarize(mpg = mean(prediction)) %>% as.list

 # $node
 # [1] 2 3
 # $mpg
 # [1] 24.31429 14.40000

In this case, I'd update the estimates for node ids 2 and 3 in the tree to 24.31429 and 14.40000, respectively.
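
For reference, the per-node holdout averages above can also be computed without dplyr. This is only a minimal base-R sketch: the names holdout, node_id, holdout_means, and honest are illustrative and not part of partykit. It also lists terminal nodes that receive no holdout rows, with NA as their estimate. No output is shown because the rows drawn by sample() can differ between R versions.

```r
library(partykit)

set.seed(12)
train <- sample(nrow(mtcars), nrow(mtcars) / 1.5)
holdout <- mtcars[-train, ]

# Fit the tree on the training rows only
sample_tree <- ctree(mpg ~ ., data = mtcars[train, ])

# Route every holdout row to its terminal node
node_id <- predict(sample_tree, newdata = holdout, type = "node")

# Mean holdout response within each terminal node that received rows
holdout_means <- tapply(holdout$mpg, node_id, mean)

# Align with all terminal nodes of the tree; nodes without holdout rows get NA
terminal_ids <- nodeids(sample_tree, terminal = TRUE)
honest <- as.vector(holdout_means[as.character(terminal_ids)])
names(honest) <- terminal_ids
honest
```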

Things I've tried: ChatGPT 1000x, a lot of googling, jumping through hoops to figure out how to get terminal node values, etc.


Edit 2: this seems to work, but I don't 100% understand why. Proceed with caution.

Adapted from Achim Zeileis's answer

library(evtree)
library(ggplot2)  # provides the diamonds data
set.seed(123)
train = sample(nrow(diamonds), nrow(diamonds)/20)
diamonds_evtree = evtree(price ~ ., data = (diamonds %>% select(any_of(c("carat", "depth", "table", "price"))))[train, ],
                         maxdepth = 3L, niterations = 101)

diamonds_ctree = ctree(price ~ ., data = (diamonds %>% select(any_of(c("depth", "table", "price", "x", "y", "z"))))[train, ])

refit_constparty(as.constparty(diamonds_evtree), diamonds[-train,]) #fails
refit_constparty(diamonds_ctree, diamonds[-train,]) #works

as.constparty(diamonds_evtree)


refit_simpleparty <- function(object, newdata) {
  stopifnot(inherits(object, "constparty") || inherits(object, "simpleparty"))
  if(any(abs(object$fitted[["(weights)"]] - 1) > 0)) {
    stop("weights not implemented yet")
  }
  d <- model.frame(terms(object), data = newdata)
  ret <- party(object$node,
               data = d,
               fitted = data.frame(
                 "(fitted)" = fitted_node(object$node, d),
                 "(response)" = d[[1L]],
                 "(weights)" = 1L,
                 check.names = FALSE),
               terms = terms(object))
  as.simpleparty(ret)
}

# works with "arbitrary data"
refit_simpleparty(diamonds_ctree %>% as.simpleparty, newdata = diamonds)

Solution

  • This can be accomplished by setting up a new party() with the new data and fitted values and subsequently coercing to constparty. See vignette("constparty", package = "partykit") for more details and worked examples.

    I have written a short function that encapsulates the necessary steps:

    refit_constparty <- function(object, newdata) {
      stopifnot(inherits(object, "constparty"))
      if(any(abs(object$fitted[["(weights)"]] - 1) > 0)) {
        stop("weights not implemented yet")
      }
      d <- model.frame(terms(object), data = newdata)
      y <- names(d)[1L]
      d <- d[, names(object$data), drop = FALSE]
      ret <- party(object$node,
        data = d,
        fitted = data.frame(
          "(fitted)" = fitted_node(object$node, d),
          "(response)" = d[[y]],
          "(weights)" = 1L,
          check.names = FALSE),
        terms = terms(object))
      as.constparty(ret)
    }
    

    Note that calling model.frame() is important because it may need to re-order and transform the variables (e.g., setting up factors or logs on the fly).

    For your data split I obtain the following:

    refit_constparty(sample_tree, mtcars[-train,])
    ## Model formula:
    ## mpg ~ cyl + disp + hp + drat + wt + qsec + vs + am + gear + carb
    ## 
    ## Fitted party:
    ## [1] root
    ## |   [2] wt <= 2.32: NA (n = 0, err = NA)
    ## |   [3] wt > 2.32: 17.664 (n = 11, err = 135.8)
    ## 
    ## Number of inner nodes:    1
    ## Number of terminal nodes: 2
    

    In Node 2 the fitted value is NA because there are no observations. (Maybe I did something wrong but I could not replicate the fitted values you show above.)
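
To see which terminal nodes end up empty after refitting, and to predict from the refitted tree, something like the following could be used. It is only a sketch: it repeats refit_constparty() from above so that it runs on its own, and holdout, refit, n_holdout, empty_nodes, and pred are illustrative names. No output is shown because the rows drawn by sample(), and hence the tree, can differ between R versions.

```r
library(partykit)

# Same function as in the answer above, repeated so this block is self-contained
refit_constparty <- function(object, newdata) {
  stopifnot(inherits(object, "constparty"))
  if(any(abs(object$fitted[["(weights)"]] - 1) > 0)) {
    stop("weights not implemented yet")
  }
  d <- model.frame(terms(object), data = newdata)
  y <- names(d)[1L]
  d <- d[, names(object$data), drop = FALSE]
  ret <- party(object$node,
    data = d,
    fitted = data.frame(
      "(fitted)" = fitted_node(object$node, d),
      "(response)" = d[[y]],
      "(weights)" = 1L,
      check.names = FALSE),
    terms = terms(object))
  as.constparty(ret)
}

set.seed(12)
train <- sample(nrow(mtcars), nrow(mtcars) / 1.5)
holdout <- mtcars[-train, ]
sample_tree <- ctree(mpg ~ ., data = mtcars[train, ])

# Keep the tree structure, but re-estimate the node values from the holdout rows
refit <- refit_constparty(sample_tree, holdout)

# Holdout rows per terminal node; a count of 0 explains an NA estimate
n_holdout <- table(factor(refit$fitted[["(fitted)"]],
                          levels = nodeids(refit, terminal = TRUE)))
empty_nodes <- names(n_holdout)[n_holdout == 0]

# Predictions from the refitted tree are the holdout means of the matching nodes
pred <- predict(refit, newdata = holdout)
```

If empty_nodes is non-empty, one option is to keep the original training estimate for those nodes, or to fit a shallower tree so that every node receives holdout rows.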