I have created a constparty decision tree (with customized split rules) and printed the result. It looks like this:
Fitted party:
[1] root
| [2] value.a < 1651: 0.067 (n = 1419, err = 88.6)
| [3] value.a >= 1651: 0.571 (n = 7, err = 1.7)
I am trying to extract the terminal node information (the yval: 0.067 and 0.571; the n in each node: 1419 and 7; and the err: 88.6 and 1.7) and put it into a list together with the corresponding node IDs (2 and 3) so that I can use it later.
I have been looking into the partykit functions for a while and could not find one that extracts the information I just listed.
Could someone help me please? Thank you!
As usual, there are several approaches to obtain the information you are looking for. The technical way of extracting the info stored in a particular node is to use nodeapply(object, ids, info_node), where info_node returns a list of information stored in the respective node.
However, nothing is stored in the terminal nodes of constparty objects. Instead, the whole distribution of the response by fitted node is stored and can be extracted with fitted(object). This returns a data frame with the observed (response), the (fitted) node, and the observation (weights) (if any). You can then easily use tapply() or aggregate() or something similar to compute node-wise means etc.
Alternatively, you can convert the constparty object to a simpleparty object, which stores the printed information in the nodes, and extract it from there.
A worked example for both strategies is a simple regression tree for the cars data:
library("partykit")
data("cars", package = "datasets")
ct <- ctree(dist ~ speed, data = cars)
Then you can easily compute node-wise means by
with(fitted(ct), tapply(`(response)`, `(fitted)`, mean))
## 3 4 5
## 18.20000 39.75000 65.26316
Of course, you can replace mean by any other summary statistic you are interested in.
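To get all three quantities from the question (yval, n, and err) in one list keyed by node ID, the same fitted() data frame can be split by node. The following is only a sketch: the name node_stats is made up for illustration, and it assumes an unweighted tree with a single numeric response, where err is taken to be the within-node sum of squared deviations from the node mean.

```r
library("partykit")
data("cars", package = "datasets")
ct <- ctree(dist ~ speed, data = cars)

# Data frame with the fitted node ID and the observed response
fit <- fitted(ct)

# Split the response by terminal node and summarise each group.
# Assumption: no case weights, so plain mean/length/sum are appropriate.
node_stats <- lapply(
  split(fit[["(response)"]], fit[["(fitted)"]]),
  function(y) list(
    yval = mean(y),               # node-wise prediction
    n    = length(y),             # number of observations in the node
    err  = sum((y - mean(y))^2)   # sum of squared deviations
  )
)

# The list is named by node ID, so a single node can be looked up directly
node_stats[["3"]]
```

If case weights were used when fitting the tree, the summaries would have to be replaced by weighted versions based on the (weights) column.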
The nodeapply() output for the simpleparty object can be obtained by:
nodeapply(as.simpleparty(ct), ids = nodeids(ct, terminal = TRUE), info_node)
## $`3`
## $`3`$prediction
## [1] 18.2
##
## $`3`$n
## n
## 15
##
## $`3`$error
## [1] 1176.4
##
## $`3`$distribution
## NULL
##
## $`3`$p.value
## NULL
##
##
## $`4`
## $`4`$prediction
## [1] 39.75
## ...
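If a table is more convenient than a nested list, the nodeapply() result can be flattened into a data frame with one row per terminal node. This is a sketch rather than a built-in partykit feature: node_tab is a hypothetical name, and it assumes a regression tree in which prediction, n, and error are each a single number per node, as in the output above.

```r
library("partykit")
data("cars", package = "datasets")
ct <- ctree(dist ~ speed, data = cars)

# Named list of node info, one element per terminal node ID
info <- nodeapply(as.simpleparty(ct),
                  ids = nodeids(ct, terminal = TRUE),
                  info_node)

# Pull the scalar entries out of each node and bind them into columns.
# as.numeric() drops the name attached to the n entry.
node_tab <- data.frame(
  node       = as.integer(names(info)),
  prediction = vapply(info, function(x) as.numeric(x$prediction), numeric(1)),
  n          = vapply(info, function(x) as.numeric(x$n), numeric(1)),
  error      = vapply(info, function(x) as.numeric(x$error), numeric(1)),
  row.names  = NULL
)

node_tab
```

Individual nodes can then be looked up with ordinary subsetting, for example node_tab[node_tab$node == 3, ].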