
Get id/name of rpart model nodes


How can I get the ID (or name) of the terminal node of an rpart model for every row? For a classification tree, predict.rpart can return only the predicted class (as a number or a factor), the class probabilities, or a combination of these (using type = "matrix").

I would like to do something like:

library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)
plot(fit) # there are 5 terminal nodes
predict(fit, type = "node_id")   # desired: terminal-node ID for each row (e.g. 1-5); "node_id" is not a valid type, so this fails

Solution

  • For that model there were 4 splits, yielding 5 "terminal nodes", or in rpart's terminology, <leaf>s. I do not see why there should be 5 predictions for anything: predictions are made for individual cases, and each leaf is the result of a variable number of splits applied on the way to it. What you may want is the leaf that each row of the original dataset ended up in, or the number of rows that ended up in each leaf. These are ways of getting those:

    # For each training row, the row number in fit$frame of the leaf it fell into
    fit$where
    
    # Number of training cases in each leaf (labelled by fit$frame row number)
    table(fit$where)
     3  5  7  8  9 
    29 12 14  7 19 
    

    To assemble the labels(fit) that apply to a particular leaf, you would need to traverse the tree and collect the labels of every split applied on the way to that leaf. You probably want to look at:

    ?print.rpart    
    ?rpart.object
    ?text.rpart
    ?labels.rpart
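Note that table(fit$where) is labelled by row numbers of fit$frame (3, 5, 7, 8, 9), not by the node numbers that print(fit) shows; those node numbers are the row names of fit$frame. A minimal sketch of turning fit$where into node numbers, or into consecutive IDs 1..5 as the question asks, assuming the kyphosis model from the question (node_numbers, leaf_node, leaf_id, fit_nodes and new_leaf_node are illustrative names). The last part is a workaround for new data that relies on predict.rpart returning frame$yval when type = "vector"; treat it as an assumption about rpart's internals, not a documented interface.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# rpart's node numbers (as printed by print(fit)) are the row names of fit$frame.
node_numbers <- as.integer(row.names(fit$frame))

# fit$where gives, for each training row, the ROW of fit$frame holding its
# leaf; indexing node_numbers with it yields that leaf's node number.
leaf_node <- node_numbers[fit$where]
table(leaf_node)

# Consecutive leaf IDs 1..k, numbered in fit$frame order.
leaf_id <- as.integer(factor(fit$where))

# Workaround for NEW data (assumption, not a documented API): copy the tree,
# overwrite the fitted values frame$yval with node numbers, then predict with
# type = "vector", which returns frame$yval for the leaf each row reaches.
fit_nodes <- fit
fit_nodes$frame$yval <- node_numbers
new_leaf_node <- predict(fit_nodes, newdata = kyphosis, type = "vector")
```

On the training data, new_leaf_node should agree with leaf_node; for other data frames, pass them as newdata with the same predictor columns.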
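Rather than traversing the tree by hand, rpart's path.rpart() can collect the split labels from the root down to any node. A minimal sketch, assuming the kyphosis model from the question; leaves, rules, rule_text and row_rule are illustrative names, and joining the splits with " & " is just one readable choice.

```r
library(rpart)

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# Node numbers of the leaves: fit$frame rows whose split variable is "<leaf>".
leaves <- row.names(fit$frame)[fit$frame$var == "<leaf>"]

# path.rpart() returns (invisibly) a list named by node number; each element
# holds the split labels from the root down to that node, starting with "root".
rules <- path.rpart(fit, nodes = as.integer(leaves), print.it = FALSE)

# Drop the leading "root" entry and join the remaining splits into one rule.
rule_text <- vapply(rules, function(p) paste(p[-1], collapse = " & "),
                    character(1))

# One rule per training row, looked up through the leaf each row fell into.
row_rule <- rule_text[row.names(fit$frame)[fit$where]]
head(row_rule)
```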