I am trying to remove the text on the y axis of some of the bar plots. I have changed the `scales` and `shared_axis_labels` options of the `geom_node_plot` function, to no avail. Below is some code that illustrates the issue, and a plot of the labels that I want to remove.
library(ggparty)
library(tidyverse)
library(partykit)
# Conditional inference tree: Species explained by every other iris column.
ct <- ctree(formula = Species ~ ., data = iris)
# Share of each count within its panel: every element of `count` is divided
# by the sum of `count` over the rows that have the same `panel` value.
# `count` and `panel` are expected to have the same length.
panel_prop <- function(count, panel) {
  panel_totals <- tapply(count, panel, sum)
  count / panel_totals[as.character(panel)]
}
# Tree with a horizontal percentage bar chart drawn in the nodes and
# "Node <id>" / split-variable labels on the inner nodes.
ggparty(ct) +
  geom_edge() +
  geom_edge_label(colour = "gray9", size = 3) +
  geom_node_plot(
    scales = "fixed",
    shared_axis_labels = FALSE,
    gglist = list(
      aes(
        y = Species,
        # bar length: count divided by its panel total
        x = after_stat(panel_prop(count, PANEL)),
        fill = Species,
        label = after_stat(
          scales::percent(panel_prop(count, PANEL), accuracy = 1)
        )
      ),
      geom_bar(),
      # labels left-justified (hjust = 0) at the bar's x position
      geom_text(stat = "count", hjust = 0, size = 3),
      # clip = "off" lets labels draw outside the panel area
      coord_cartesian(clip = "off"),
      scale_x_continuous(
        labels = scales::percent_format(accuracy = 1),
        expand = expansion(mult = c(.05, .25))
      ),
      theme(axis.text.x = element_text(size = 5)),
      xlab(""),
      ylab("")
    )
  ) +
  geom_node_label(
    aes(),
    line_list = list(
      aes(label = paste("Node", id)),
      aes(label = splitvar),
      aes(label = "")
    ),
    line_gpar = list(
      list(size = 8, col = "black"),
      list(size = 6),
      list(size = 8)
    ),
    ids = "inner"
  )
I have a solution that works, but I think it's more than a bit overkill. I was sure there would be a relatively easy way to make that happen... but no, I didn't find one.
I tried to make this solution dynamic, but there is an endless number of things you can do with ggplot, so I'm sure it won't work in many situations that differ from your specific question.
There is an assumption that there is only one call to `geom_node_plot` (I'm not sure whether ggparty lets you have more than one...). Another assumption is that, if there are labels, they are percentages. While I wrote it to look specifically for `geom_text`, I didn't add support for `geom_label` (the whole you-can-have-one-but-not-the-other thing... and yes, another rabbit hole I was lost in for a while... I digress.)
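The function `fixer` takes in the graph you made, picks it apart and remakes it. The plot has to be built first. You can't just append `%>% fixer()` to the end of the `ggparty(...) + ...` chain, because `%>%` binds more tightly than `+`. `fixer()` would then receive only the last layer instead of the whole plot, and you will get an error. Assign the plot to a variable first, as below, or wrap the whole plot expression in parentheses before piping it.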
Unfortunately, the finished result is not a `ggplot` object; it's a `gTree`. Take a look and let me know if you have any questions.
Your originally coded plot is unchanged, except that it is assigned to `gg`.
library(ggparty) # install.packages("ggparty")
library(tidyverse)
library(partykit)
# fixer(): redraw a ggparty plot so the node bar charts share one y axis.
fixer <- function(gg) {
  # Takes a ggparty plot apart and redraws it in two steps.
  #   1. The tree is printed with every layer except the node-plot layer.
  #   2. The terminal-node bar charts are rebuilt as a single
  #      facet_wrap(nrow = 1) ggplot and drawn into the bottom two fifths of
  #      the page. With facet_wrap's default fixed scales, y-axis text is
  #      drawn only on the left-most panel.
  #
  # Args:
  #   gg: a ggparty plot with exactly one geom_node_plot() layer. Its
  #     `gglist` must be written inline as list(aes(...), geom_bar(), ...).
  #     geom_text() labels are redrawn as percentages.
  # Returns:
  #   A gTree captured with grid.grab(). The drawing is also rendered on the
  #   current graphics device as a side effect.
  # Side effects:
  #   Attaches scales and grid, so code after the call can use grid
  #   functions such as grid.draw() unqualified.
  library(scales)
  library(grid)

  # Name of the function called by `expr`, with any `pkg::` prefix removed.
  fn_name <- function(expr) {
    fn <- expr[[1]]
    if (is.call(fn) && is.name(fn[[1]]) &&
        as.character(fn[[1]]) %in% c("::", ":::")) {
      fn <- fn[[3]]
    }
    deparse1(fn)
  }

  # Named arguments of call `expr`, minus unnamed ones and those in `drop`.
  named_args <- function(expr, drop = character()) {
    args <- as.list(expr)[-1]
    arg_names <- names(args)
    if (is.null(arg_names)) {
      return(list())
    }
    args[nzchar(arg_names) & !arg_names %in% drop]
  }

  # The node-plot layer is the one whose recorded constructor call names a
  # function containing "plot" (also matches ggparty::geom_node_plot).
  # Layers without a recorded call are treated as tree layers.
  is_plot_layer <- vapply(gg$layers, function(layer) {
    ctor <- layer$constructor
    is.call(ctor) && any(str_detect(as.character(ctor[[1]]), "plot"))
  }, logical(1))
  plot_idx <- which(is_plot_layer)
  if (length(plot_idx) != 1L) {
    stop("fixer() expects exactly one geom_node_plot() layer, found ",
         length(plot_idx), ".", call. = FALSE)
  }

  # Background: the tree without its node plots.
  tree_plot <- gg
  tree_plot$layers[plot_idx] <- NULL

  # Leaves are the ids that are nobody's parent. Unlike selecting by
  # `level`, this keeps leaves at different depths of an unbalanced tree.
  leaf_data <- gg$data %>%
    filter(!id %in% parent) %>%
    select(starts_with("nodedata_"))
  if (ncol(leaf_data) == 0L) {
    stop("fixer() found no nodedata_ columns in the plot data.",
         call. = FALSE)
  }
  # Assumes the first nodedata_ column holds the response shown on the y
  # axis. Each cell is that leaf's vector of observations.
  responses <- leaf_data[[1]]
  grp_levels <- paste0("grp", seq_along(responses))
  leaf_props <- lapply(seq_along(responses), \(k) {
    props <- prop.table(table(unlist(responses[[k]])))
    tibble(grp = grp_levels[k], name = names(props),
           value = as.numeric(props))
  }) %>%
    list_rbind()
  # A factor keeps the facets in node order; plain strings would sort
  # "grp10" before "grp2".
  leaf_props$grp <- factor(leaf_props$grp, levels = grp_levels)

  # The recorded geom_node_plot(...) call. Its `gglist` argument is still the
  # unevaluated list(...) expression that was written in the plot code.
  node_plot_call <- gg$layers[[plot_idx]]$constructor
  gglist <- as.list(as.list(node_plot_call)[["gglist"]])

  # Layout tweaks applied when the gglist contains a theme() call.
  layout_theme <- list(
    panel.spacing.x = unit(1, "lines"),  # room for x-axis text
    strip.background = element_blank(),  # no facet strip background
    strip.text = element_blank(),        # no facet labels (grp1, ...)
    legend.position = "none",
    plot.margin = margin(0, 15, 15, 15)  # no top margin; tree sits above
  )

  bar_plot <- ggplot(leaf_props, aes(value, name, group = grp, fill = name))
  # Element 1 of gglist is the `list` symbol itself; components follow it.
  for (expr in gglist[-1]) {
    if (!is.call(expr)) {
      bar_plot <- bar_plot + eval(expr)  # e.g. a variable holding a layer
      next
    }
    component <- fn_name(expr)
    if (component == "aes") {
      next  # replaced by the mapping on leaf_props above
    } else if (component == "geom_bar") {
      # Proportions are precomputed, so draw them as-is.
      bar_plot <- bar_plot +
        do.call(geom_col, named_args(expr, c("stat", "mapping", "data")))
    } else if (component == "geom_text") {
      text_args <- list(
        data = leaf_props[leaf_props$value != 0, ],
        mapping = aes(value, name, group = grp,
                      label = scales::label_percent(accuracy = 1)(value))
      )
      user_args <- named_args(expr, c("stat", "aes", "mapping", "data"))
      bar_plot <- bar_plot + do.call(geom_text, append(text_args, user_args))
    } else if (component == "theme") {
      bar_plot <- bar_plot +
        do.call(theme, append(layout_theme, named_args(expr)))
    } else {
      bar_plot <- bar_plot + eval(expr)  # coords, scales, labels as written
    }
  }
  bar_plot <- bar_plot + facet_wrap(~grp, nrow = 1)

  # Draw the tree, then the bars in rows 4-5 of a 5-row layout.
  print(tree_plot + theme(plot.margin = margin(b = 0, l = 60, r = 0)))
  # Three viewports are pushed here and all three are popped below.
  pushViewport(current.viewport())
  pushViewport(viewport(layout = grid.layout(5, 1)))
  pushViewport(viewport(layout.pos.row = 4:5, layout.pos.col = 1))
  grid.draw(ggplotGrob(bar_plot))
  popViewport(3)
  grid.grab(wrap.grobs = TRUE)
}
#---------- with your code as in your question ---------
# Conditional inference tree: Species explained by every other iris column.
ct <- ctree(formula = Species ~ ., data = iris)
panel_prop <- function(count, panel) {
  # Divide each count by the total of its own panel. `count` and `panel`
  # are expected to be the same length.
  totals_by_panel <- tapply(count, panel, sum)
  row_totals <- totals_by_panel[as.character(panel)]
  count / row_totals
}
# original plot
# The plot from the question, stored in `gg` so it can be passed to fixer().
# fixer() reads the geom_node_plot() call recorded on the layer and inspects
# its `gglist` argument expression. Keep gglist written inline as list(...)
# here rather than passing in a pre-built list variable.
gg <- ggparty(ct) +
geom_edge() +
geom_edge_label(colour = "gray9", size = 3) +
# small ggplots drawn inside the nodes, built from the components in gglist
geom_node_plot(
scales = "fixed", shared_axis_labels = FALSE,
gglist = list(aes(
# bar length: count divided by its panel total (see panel_prop)
y = Species, x = after_stat(panel_prop(count, PANEL)),
fill = Species,
label = after_stat(scales::percent(panel_prop(count, PANEL),
accuracy = 1))),
geom_bar(),
# labels left-justified (hjust = 0) at the bar's x position
geom_text(stat = "count", hjust = 0, size = 3),
# clip = "off" lets labels draw outside the panel area
coord_cartesian(clip = "off"),
scale_x_continuous(labels = scales::percent_format(accuracy = 1),
expand = expansion(mult = c(.05, .25))),
theme(axis.text.x = element_text(size = 5)),
xlab(""), ylab(""))) +
# inner-node labels: "Node <id>", the split variable, then an empty line
geom_node_label(
aes(),
line_list = list(aes(label = paste("Node", id)),
aes(label = splitvar),
aes(label = "")),
line_gpar = list(list(size = 8, col = "black"),
list(size = 6), list(size = 8)),
ids = "inner")
#----------- apply modifications with fixer() -------------
# Build the combined drawing. fixer() also renders it on the current device,
# so it shows up in the plot pane without printing gg2.
gg2 <- fixer(gg)

# Redraw the captured gTree on a fresh page.
grid::grid.newpage()
grid::grid.draw(gg2)

# Same result in a single call, via cowplot.
cowplot::ggdraw(gg2)