
How to Change Node's Color Based on Node's Level in CART Plot (rpart.plot) [R]


I want to color the nodes of a CART plot (rpart.plot) in R according to each node's level. The required plot looks like this:

enter image description here

I have gotten as far as the plot below. Two things remain: 1. Move the values of the target variable (Setosa, Versicolor, and Virginica) to the left side of the chart. 2. Color the nodes as in the required plot.

enter image description here


Solution

  • By "node's level", I assume you mean the class predicted at the node. If so, do it like this (see the rpart.plot package vignette Figure 1 bottom plot):

    library(rpart.plot)
    png("aswin.png")
    data(iris)
    tree <- rpart(Species~., data=iris)
    # may have to play with value of legend.x and legend.y for your plot
    rpart.plot(tree, type=1, extra=4, legend.x=-.25, legend.y=1.2)
    dev.off()
    

    which gives the following plot

    plot
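
If you would rather choose the color for each class yourself, one option is to index a color vector by the predicted class stored in `tree$frame$yval` and pass the result as `box.col`. This is only a minimal sketch: the color names and the variable names `class.colors` and `box.colors` are my own choices, not something from the question.

```r
library(rpart)
data(iris)
tree <- rpart(Species ~ ., data = iris)

# One color per class, in the order of levels(iris$Species).
# These particular colors are arbitrary; change them to taste.
class.colors <- c("pink", "palegreen3", "lightblue")

# For a classification tree, tree$frame$yval holds the index of the
# predicted class at each node, so it can index class.colors directly.
box.colors <- class.colors[tree$frame$yval]

# Plot only if rpart.plot is installed.
if (requireNamespace("rpart.plot", quietly = TRUE)) {
    rpart.plot::rpart.plot(tree, type = 1, extra = 4, box.col = box.colors)
}
```

The colors are given in the same order as the rows of `tree$frame`, so every node predicting the same class gets the same box color.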

    If by "node's level" you instead mean the depth of the node in the tree, then your first example figure is confusing: in that figure the leaf node on the left (1.00 .00 .00) is at depth 2, yet it has the same color as the other leaf nodes, which are at depth 3. Nevertheless, the following code will color a node by its depth in the tree:

    library(rpart.plot)
    data(iris)
    tree <- rpart(Species~., data=iris)
    # depth of a node given its rpart node number
    # (the root is node 1 at depth 1; the parent of node n is node n %/% 2)
    node.depth <- function(node.number)
    {
        depth <- 1
        while(node.number > 1) {
            node.number <- node.number %/% 2 # step up to the parent node
            depth <- depth + 1
        }
        depth
    }
    # node numbers in order they appear in tree$frame
    node.numbers <- as.numeric(row.names(tree$frame))
    # depth of each node in node.numbers
    node.depths <- vapply(node.numbers, node.depth, numeric(1))
    colors <- topo.colors(n=max(node.depths)) # change these colors to taste
    rpart.plot(tree, type=1, extra=4, 
               fallen.leaves=FALSE, nn=TRUE, # optional
               box.col=colors[node.depths])
    

    which gives the following plot

    plot
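
Because rpart numbers the root as 1 and the children of node n as 2n and 2n+1, the depth can also be computed without a loop. The following is a sketch under that assumption; `node.depth.vec` is a hypothetical helper name, and the example node numbers are illustrative rather than taken from a fitted tree.

```r
# Vectorized alternative to the loop: with the root counted as depth 1,
# a node numbered n sits at depth floor(log2(n)) + 1.
node.depth.vec <- function(node.numbers)
{
    floor(log2(node.numbers)) + 1
}

# Illustrative node numbers (in practice, take them from
# as.numeric(row.names(tree$frame))).
example.numbers <- c(1, 2, 3, 6, 7, 12, 13)
example.depths <- node.depth.vec(example.numbers)
print(example.depths)
```

The result can be used in place of `node.depths` in the `box.col=colors[node.depths]` argument above.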