
Number of decimal places on edges of a decision-tree plot with ggparty


I want to plot a decision tree (as estimated by the partykit package) using the powerful ggparty package. Everything works fine except for the number of decimal places shown for numeric split variables. How can I format breaks_label in geom_edge_label() so that, for example, > 75.33333 becomes > 75.3 in the plot below? Using round() does not work. I could work around this with the global option options(digits = 3), but I am wondering whether there is a more direct way.

library("ggparty") 
data("WeatherPlay", package = "partykit")

sp_o <- partysplit(1L, index = 1:3)
sp_h <- partysplit(3L, breaks = 75 + 1/3)
sp_w <- partysplit(4L, index = 1:2)
pn <- partynode(1L, split = sp_o, kids = list(
    partynode(2L, split = sp_h, kids = list(
        partynode(3L, info = "yes"),
        partynode(4L, info = "no"))),
    partynode(5L, info = "yes"),
    partynode(6L, split = sp_w, kids = list(
        partynode(7L, info = "yes"),
        partynode(8L, info = "no")))))
py <- party(pn, WeatherPlay)

ggparty(py) +
    geom_edge() +
    # geom_edge_label() +
    geom_edge_label(mapping = aes(label = paste(breaks_label))) +
    geom_node_splitvar() +
    geom_node_info()

Created on 2020-03-05 by the reprex package (v0.3.0)


Solution

  • Thanks for using ggparty!

    I don't think there is a straightforward solution for this in the current version, but I'll make sure to implement it in the future!

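    In general, you can work around a lot of limitations by applying the geoms to subsets of the nodes only. As you have already noticed, breaks_label is not stored as numeric but as character, with parsable text for the inequality sign in front of the number. Therefore you'll have to shorten the string with something like substr(). Below, the default edge labels are drawn for every node except 3 and 4 (the two children of the numeric humidity split), and truncated labels are drawn for those two.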

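    ggparty(py) +
      geom_edge() +
      # default edge labels for all nodes except 3 and 4
      geom_edge_label(id = -c(3, 4)) +
      # truncated labels for the edges leading to nodes 3 and 4 (numeric split)
      geom_edge_label(mapping = aes(label = paste(substr(breaks_label, start = 1, stop = 15))),
                      id = c(3, 4)) +
      geom_node_splitvar() +
      geom_node_info()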
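Truncating with substr() depends on the label having a fixed length. A minimal sketch of a more general variant, assuming the split point appears in the label as a plain decimal number, rounds every such number inside the string and leaves the surrounding text (such as the inequality sign) untouched. The helper name round_label_numbers() is hypothetical and not part of ggparty; it uses only base R.

```r
# Hypothetical helper (not part of ggparty): round every plain decimal number
# found inside a character label and leave the rest of the text unchanged.
round_label_numbers <- function(labels, digits = 1) {
  labels <- as.character(labels)
  # Matches numbers such as 75.3333333333333 or -0.25; integers and
  # scientific notation (e.g. 1e-05) are deliberately left as they are.
  pattern <- "-?[0-9]+\\.[0-9]+"
  matches <- gregexpr(pattern, labels)
  regmatches(labels, matches) <- lapply(
    regmatches(labels, matches),
    function(num) {
      if (length(num) == 0) return(num)  # label without a decimal number
      # round, then print with exactly `digits` decimal places
      formatC(round(as.numeric(num), digits), format = "f", digits = digits)
    }
  )
  labels
}

round_label_numbers(c("> 75.3333333333333", "<= 75.3333333333333"), digits = 1)
```

In the plot above, the second edge-label layer could then use mapping = aes(label = round_label_numbers(breaks_label, digits = 1)) instead of the substr() call. This assumes each breaks_label entry for the numeric split is a single string; it has not been tested against ggparty's internal label format.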
    

    I also modified one of the internal functions to add a rounding option, so you can source it from GitHub and use it. I haven't really tested it, though, so use it at your own risk ;)

    library(devtools)
    source_url("https://raw.githubusercontent.com/martin-borkovec/ggparty/martin/R/add_splitvar_breaks_index_new.R")
    
    rounded_labels <- add_splitvar_breaks_index_new(party_object = py,
                                                    plot_data = ggparty:::get_plot_data(py), 
                                                    round_digits = 2)
    
    ggparty(py) +
      geom_edge() +
      geom_edge_label(mapping = aes(label = unlist(rounded_labels)),
                      data = rounded_labels) +
      geom_node_splitvar() +
      geom_node_info()
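The general idea behind a rounding option like round_digits can be sketched in base R: round the numeric break first and only then paste it into the label, instead of post-processing the finished string. The helper make_split_labels() below is hypothetical and is not the code behind add_splitvar_breaks_index_new(); it only covers a binary split with a single break and assumes partysplit's default right = TRUE, where the first kid receives x <= break.

```r
# Hypothetical helper (not ggparty's implementation): build the two edge labels
# of a binary numeric split from the numeric break, rounding before pasting.
make_split_labels <- function(break_value, digits = 1, right = TRUE) {
  stopifnot(is.numeric(break_value), length(break_value) == 1)
  # round, then print with exactly `digits` decimal places
  shown <- formatC(round(break_value, digits), format = "f", digits = digits)
  # right = TRUE: intervals closed on the right, so the first kid is x <= break;
  # right = FALSE: the first kid is x < break and the second is x >= break
  signs <- if (right) c("<=", ">") else c("<", ">=")
  paste(signs, shown)
}

make_split_labels(75 + 1/3, digits = 1)
```

For the tree in the question, the unrounded break can be read from the split object with partykit's breaks_split(sp_h), which returns 75 + 1/3.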