
rpart package: median or geometric mean instead of mean


Is it possible, using the rpart package in R (or another package), to replace the mean as the estimator in each region of the tree with something else, such as the median or the geometric mean?

I believe my tree partitioning is strongly affected by extreme values, and I would like to build trees that use other estimators.

Thanks!


Solution

  • One of the usual tricks for right-skewed responses is to take logs. In many applications this makes the response distribution more symmetric, so you don't need to move away from the usual mean predictions.
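    A minimal sketch of the log trick, assuming the cars data used further down and a strictly positive response: fitting rpart to log(dist) makes each terminal node predict the mean of log(dist), and exponentiating that prediction gives the geometric mean of dist in the node, because exp(mean(log(y))) is the geometric mean. Note that the splits are then chosen on the log scale, so the tree can differ from one grown on dist itself.

```r
library("rpart")
data("cars", package = "datasets")

# Grow the tree on the log-transformed response (requires dist > 0).
rp_log <- rpart(log(dist) ~ speed, data = cars)

# Each terminal node predicts mean(log(dist)); exp() maps this back
# to the geometric mean of dist within that node.
nd <- data.frame(speed = c(10, 15, 20))
geo_pred <- exp(predict(rp_log, nd))
geo_pred
```

    The back-transformation is the only extra step: the tree itself is an ordinary anova rpart fit, just on a different scale.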

    Another way to change how the tree is learned is to use more robust scores, e.g., ranks. The ctree() function from the partykit package offers a nonparametric inference framework for this.
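    A minimal sketch of the rank idea, assuming partykit is installed: the hypothetical column dist_rank holds the ranks of dist, ctree() selects its splits on those ranks, and the median of the original dist is then computed in each terminal node. This is only one simple way to use robust scores; ctree() also accepts a ytrafo argument for transforming the response (see ?ctree).

```r
library("partykit")
data("cars", package = "datasets")

# Hypothetical helper column: ranks of the response, so that extreme
# distances cannot dominate the split selection.
cars$dist_rank <- rank(cars$dist)

# Conditional inference tree grown on the ranked response.
ct <- ctree(dist_rank ~ speed, data = cars)

# Terminal node of each training observation, then the median of the
# original (untransformed) distance within each node.
node <- predict(ct, newdata = cars, type = "node")
node_median <- tapply(cars$dist, node, median)

# New speeds: look up the terminal node, then report that node's median.
nd <- data.frame(speed = c(10, 15, 20))
nd_node <- predict(ct, newdata = nd, type = "node")
nd_median <- node_median[as.character(nd_node)]
nd_median
```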

    Finally, the partykit package also lets you compute predictions other than the mean in each terminal node. You can easily convert rpart trees to party trees via as.party(). A very simple example is to fit an rpart tree to the cars data:

    library("rpart")
    data("cars", package = "datasets")
    # Regression tree for stopping distance as a function of speed
    rp <- rpart(dist ~ speed, data = cars)
    

    and then convert it to a party object:

    library("partykit")
    pr <- as.party(rp)
    

    The tree structure remains unchanged, but you get enhanced plotting and prediction methods. The default plot methods yield:

    [Figure: default plots of the rpart tree and the corresponding party tree]

    Furthermore, the default predictions on both objects are the same.

    nd <- data.frame(speed = c(10, 15, 20))
    predict(rp, nd)
    ##        1        2        3 
    ## 18.20000 39.75000 65.26316 
    predict(pr, nd)
    ##        1        2        3 
    ## 18.20000 39.75000 65.26316 
    

    However, predict() on the party object also lets you pass a function via the FUN argument, which is then applied to the response in each terminal node. It must have the form function(y, w), where y is the response and w are the case weights. As we haven't used any weights here, we can simply ignore that argument and do:

    predict(pr, nd, FUN = function(y, w) mean(y))
    ##        1        2        3 
    ## 18.20000 39.75000 65.26316 
    predict(pr, nd, FUN = function(y, w) median(y))
    ##  1  2  3 
    ## 18 35 64 
    predict(pr, nd, FUN = function(y, w) quantile(y, 0.9))
    ##    1    2    3 
    ## 28.0 57.0 92.2 
    

    And so on... See the package vignettes for more details.
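    Since the question also asks about the geometric mean, the same FUN mechanism can compute it. A minimal sketch, assuming a strictly positive response: geo_mean is a hypothetical helper that uses the case weights when they are passed and falls back to equal weights if w is NULL (that check is a defensive assumption, not documented partykit behavior).

```r
library("rpart")
library("partykit")
data("cars", package = "datasets")

rp <- rpart(dist ~ speed, data = cars)
pr <- as.party(rp)
nd <- data.frame(speed = c(10, 15, 20))

# Hypothetical helper: weighted geometric mean of a positive response y
# with case weights w; equal weights are used if none are supplied.
geo_mean <- function(y, w) {
  if (is.null(w)) w <- rep(1, length(y))
  exp(weighted.mean(log(y), w))
}

gm <- predict(pr, nd, FUN = geo_mean)
am <- predict(pr, nd, FUN = function(y, w) mean(y))
gm
```

    Because the geometric mean never exceeds the arithmetic mean, gm should be at most am in every node, with the gap growing as the node's response is more spread out.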