Search code examples
Tags: r, ggplot2, party

Remove axis.text.y from bar plots when generating a plot with the ggparty package


I am trying to remove the y-axis text on some of the bar plots. I have updated the 'scales' and 'shared_axis_labels' options of the 'geom_node_plot' function, to no avail. Below is some code to illustrate the issue, and a plot showing the labels that I want to remove.

library(ggparty)
library(tidyverse)
library(partykit)

# Conditional inference tree predicting Species from all other iris columns
ct <- ctree(formula = Species ~ ., data = iris)

# Share of each count within its own facet panel.
# count: numeric counts (one per bar); panel: the panel each count belongs to.
# Returns count divided by the total count of its panel.
panel_prop <- function(count, panel) {
  # Sum of `count` in every panel, indexed by panel name
  totals_by_panel <- tapply(count, panel, sum)
  # Look up each element's panel total and divide by it
  count / totals_by_panel[as.character(panel)]
}

# Tree plot: edges with split labels, inner-node labels, and a horizontal bar
# chart of class shares (labelled in percent) for each node plot.
ggparty(ct) +
  geom_edge() +
  geom_edge_label(colour = "gray9", size = 3) +
  geom_node_plot(
    scales = "fixed",
    shared_axis_labels = FALSE,
    gglist = list(
      aes(
        y = Species,
        x = after_stat(panel_prop(count, PANEL)),
        fill = Species,
        label = after_stat(
          scales::percent(panel_prop(count, PANEL), accuracy = 1)
        )
      ),
      geom_bar(),
      # text labels use the per-panel share computed in the label aesthetic
      geom_text(stat = "count", hjust = 0, size = 3),
      # allow the labels to extend past the panel edge
      coord_cartesian(clip = "off"),
      scale_x_continuous(
        labels = scales::percent_format(accuracy = 1),
        expand = expansion(mult = c(0.05, 0.25))
      ),
      theme(axis.text.x = element_text(size = 5)),
      xlab(""),
      ylab("")
    )
  ) +
  geom_node_label(
    aes(),
    line_list = list(
      aes(label = paste("Node", id)),
      aes(label = splitvar),
      aes(label = "")
    ),
    # one gpar list per entry of line_list
    line_gpar = list(
      list(size = 8, col = "black"),  # fontface = "bold" could be added here
      list(size = 6),
      list(size = 8)
    ),
    ids = "inner"
  )

enter image description here


Solution

  • I have a solution that works, but I think it's more than a bit overkill. I was sure there was a relatively easy way to make this happen, but I didn't find one.

    I tried to make this solution dynamic, but there is an endless number of things you can do with ggplot, so I'm sure it won't work in many situations that differ from your specific question.

    The solution assumes that there is only one call to geom_node_plot (I'm not sure whether ggparty lets you have more than one). It also assumes that any labels are percentages. The function looks specifically for geom_text; it does not handle geom_label (supporting one but not the other turned out to be a rabbit hole of its own).

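    The function fixer takes in the graph you made, picks it apart, and rebuilds it. The plot has to be assigned to a variable first. Chaining fixer onto the end of the plot code (e.g., ggparty(...) + ... %>% fixer()) gives an error: %>% binds more tightly than +, so fixer would receive only the last layer instead of the whole plot.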

    Unfortunately, the finished result is not a ggplot object; it's a gTree (a grid graphics object). Take a look and let me know if you have any questions.

    Your originally coded plot is unchanged, except in that it is assigned to gg.

    updated plot

    library(ggparty)  # install.packages("ggparty")
    library(tidyverse)
    library(partykit)
    
    # ---- helpers for fixer() ------------------------------------------------

    # Label each layer of a ggparty plot: "p" for the geom_node_plot() layer,
    # "np" for every other layer. Detection uses the function recorded in the
    # layer's `constructor` call.
    fixer_layer_kinds <- function(gg) {
      map_chr(gg$layers, \(layer) {
        fn <- as.list(layer$constructor)[[1]]
        # a namespaced `pkg::fun` deparses to 3 strings, hence any()
        if (any(str_detect(as.character(fn), "plot"))) "p" else "np"
      })
    }

    # Per-terminal-node class shares in long format: columns grp, name, value.
    # `grp` is a factor in node order so facets stay ordered for >= 10 nodes;
    # `name` keeps the order of the response's classes.
    fixer_terminal_props <- function(plot_data) {
      if (!"kids" %in% names(plot_data)) {
        stop("plot data has no `kids` column; is this a ggparty plot?",
             call. = FALSE)
      }
      # kids == 0 marks leaves at any depth; level == max(level) would miss
      # leaves that sit higher up in an unbalanced tree
      leaves <- plot_data[plot_data$kids == 0, , drop = FALSE]
      node_cols <- which(startsWith(names(leaves), "nodedata_"))
      if (length(node_cols) == 0) {
        stop("plot data has no nodedata_ columns", call. = FALSE)
      }
      # assumes the first nodedata column holds the response
      responses <- leaves[[node_cols[1]]]
      groups <- paste0("grp", seq_along(responses))
      props <- map2(responses, groups, \(values, grp) {
        values <- unlist(values)
        share <- table(values) / length(values)
        tibble(grp = grp, name = names(share), value = as.numeric(share))
      }) %>%
        list_rbind()
      props$grp <- factor(props$grp, levels = groups)
      props$name <- factor(props$name, levels = unique(props$name))
      props
    }

    # Remove the named arguments listed in `drop` from a list of call
    # arguments; unnamed arguments are always kept.
    fixer_drop_args <- function(args, drop) {
      if (is.null(names(args))) {
        return(args)
      }
      args[!names(args) %in% drop]
    }

    # Add one element of the original geom_node_plot() gglist to `plt`,
    # adapted to the pre-summarised data in `props`:
    #   aes(...)       -> skipped (the rebuilt plot has its own mapping)
    #   geom_bar(...)  -> geom_col() with the same extra arguments
    #   geom_text(...) -> percentage labels for the non-zero bars
    #   anything else  -> evaluated and added as written
    fixer_add_element <- function(plt, el, props) {
      if (!is.call(el)) {
        return(plt + eval(el))  # e.g. a theme object stored in a variable
      }
      fn <- as.character(el[[1]])
      fn <- fn[length(fn)]  # `pkg::fun` -> "fun"
      args <- as.list(el)[-1]
      switch(fn,
        aes = plt,
        geom_bar = plt +
          do.call(geom_col, fixer_drop_args(args, c("stat", "mapping"))),
        geom_text = {
          text_args <- list(
            data = props[props$value != 0, ],
            mapping = aes(value, name, group = grp,
                          label = scales::label_percent(accuracy = 1)(value))
          )
          extra <- fixer_drop_args(args, c("stat", "mapping", "data"))
          plt + do.call(geom_text, c(text_args, extra))
        },
        plt + eval(el)
      )
    }

    # ---- fixer() ------------------------------------------------------------

    # fixer(gg): redraw a ggparty plot so the terminal-node bar charts are
    # drawn as one facetted ggplot underneath the tree (one y axis shared by
    # the panels) instead of one plot, with its own y-axis text, per node.
    #
    # Assumptions:
    #   * `gg` is an already-built plot object containing exactly one
    #     geom_node_plot() layer for the terminal nodes (error otherwise);
    #   * the first "nodedata_" column of the plot data holds the response;
    #   * geom_text() labels, if requested, are percentages.
    #
    # Side effects: attaches scales and grid, and draws the result on the
    # current graphics device.
    # Returns (invisibly) the drawing captured as a gTree by grid.grab();
    # redraw it with grid::grid.draw() or cowplot::ggdraw().
    fixer <- function(gg) {
      # attached (not just loaded) because callers use grid functions such as
      # grid.newpage()/grid.draw() unqualified after calling fixer()
      library(scales)
      library(grid)

      kinds <- fixer_layer_kinds(gg)
      plot_layer <- which(kinds == "p")
      if (length(plot_layer) != 1) {
        stop("fixer() expects exactly one geom_node_plot() layer; found ",
             length(plot_layer), call. = FALSE)
      }

      # tree-only copy: every layer except the node plots
      tree_plot <- gg
      tree_plot$layers[plot_layer] <- NULL

      props <- fixer_terminal_props(gg$data)

      # unevaluated `gglist = list(...)` call; [-1] drops the `list` symbol
      constructor <- as.list(gg$layers[[plot_layer]]$constructor)
      gglist <- as.list(constructor[["gglist"]])[-1]

      # layout theme first, so theme() calls from gglist can still override it
      bars <- ggplot(props, aes(value, name, group = grp, fill = name)) +
        theme(panel.spacing.x = grid::unit(1, "lines"),  # room for x text
              strip.background = element_blank(),        # no facet strips
              strip.text = element_blank(),
              legend.position = "none",
              plot.margin = margin(0, 15, 15, 15))       # no top margin
      for (el in gglist) {
        bars <- fixer_add_element(bars, el, props)
      }
      bars <- bars + facet_wrap(~grp, nrow = 1)

      # draw the tree, then the bar charts in the bottom 2 of 5 layout rows
      print(tree_plot + theme(plot.margin = margin(b = 0, l = 60, r = 0)))
      pushViewport(current.viewport())
      pushViewport(viewport(layout = grid.layout(5, 1)))
      pushViewport(viewport(layout.pos.row = 4:5, layout.pos.col = 1))
      grid.draw(ggplotGrob(bars))
      popViewport(3)
      invisible(grid.grab(wrap.grobs = TRUE))
    }
    
    #---------- with your code as in your question ---------
    # Conditional inference tree predicting Species from all other columns
    ct <- ctree(formula = Species ~ ., data = iris)
    
    # Share of each count within its own facet panel.
    # count: numeric counts (one per bar); panel: the panel of each count.
    # Returns count divided by the total count of its panel.
    panel_prop <- function(count, panel) {
      # Sum of `count` in every panel, indexed by panel name
      totals_by_panel <- tapply(count, panel, sum)
      # Look up each element's panel total and divide by it
      count / totals_by_panel[as.character(panel)]
    }
    # original plot from the question, assigned to `gg` so fixer() can take it
    # apart. NOTE: fixer() reads the geom_node_plot() layer's constructor call
    # (its `gglist` argument) directly, so the calls below are deliberately
    # kept exactly as in the question.
    gg <- ggparty(ct) +
      geom_edge() +
      geom_edge_label(colour = "gray9", size = 3) +
      geom_node_plot(
        scales = "fixed", shared_axis_labels = FALSE,
        # elements fixer() rebuilds: aes() is skipped, geom_bar() becomes
        # geom_col(), geom_text() becomes percentage labels, and the
        # remaining calls are evaluated as written
        gglist = list(aes(
          y = Species, x = after_stat(panel_prop(count, PANEL)),
          fill = Species,
          label = after_stat(scales::percent(panel_prop(count, PANEL),
                                             accuracy = 1))),
          geom_bar(), 
          geom_text(stat = "count", hjust = 0, size = 3),
          coord_cartesian(clip = "off"),
          scale_x_continuous(labels = scales::percent_format(accuracy = 1),
                             expand = expansion(mult = c(.05, .25))),
          theme(axis.text.x = element_text(size = 5)),
          xlab(""), ylab(""))) +
      geom_node_label(
        aes(),
        # three label lines per node: node id, split variable, empty line
        line_list = list(aes(label = paste("Node", id)),
                         aes(label = splitvar),
                         aes(label = "")),  
        # one gpar list per entry of line_list
        line_gpar = list(list(size = 8, col = "black"),  
                         list(size = 6), list(size = 8)),  
        # labels are drawn only on the inner nodes
        ids = "inner") 
    
    #----------- apply modifications with fixer() -------------
    # fixer() draws the combined plot while it runs (so it appears without
    # printing gg2) and returns the drawing as a gTree
    gg2 <- fixer(gg)

    # to redraw it on a fresh page; namespace-qualified so this does not rely
    # on fixer() having attached grid
    grid::grid.newpage()
    grid::grid.draw(gg2)

    # alternatively -- this combines the two previous calls
    cowplot::ggdraw(gg2)