Search code examples
Tags: r, ggplot2, annotate, survival, r-forestplot

How to add arrows to forest plot in survminer (ggforest)


I was wondering whether there is a way to add two arrows to the x-axis of a forest plot, similar to the example in "How to add arrows to a forest plot?" — however, the code from that answer does not work with ggforest.

Here is the sample code:

# Sample data and models from the question (survival + survminer).
library(survival)
library(survminer)

# Small model fitted on the colon data exactly as shipped (integer codes).
# NOTE(review): survival::colon is documented as holding two rows per patient
# (etype 1 = recurrence, etype 2 = death); a real analysis would likely
# subset on etype first. Kept as-is to match the question.
model <- coxph(
  Surv(time, status) ~ sex + rx + adhere,
  data = colon
)
ggforest(model)

# Turn the integer-coded covariates into labelled factors so that ggforest()
# draws one row per level.
colon$sex <- factor(colon$sex, labels = c("female", "male"))
colon$differ <- factor(colon$differ, labels = c("well", "moderate", "poor"))
colon$extent <- factor(
  colon$extent,
  labels = c("submuc.", "muscle", "serosa", "contig.")
)

bigmodel <- coxph(
  Surv(time, status) ~ sex + rx + adhere + differ + extent + node4,
  data = colon
)
ggforest(bigmodel)

I'd like two arrows underneath the x-axis, pointing in opposite directions, annotated with "increased risk" and "decreased risk".

Forest plot image example


Solution

  • The ggforest() function from the survminer package builds the plot with ggplot(), but then converts it into a grob (gtable) object before returning it. To position the arrows relative to the vertical reference line at 1, you have to add them before that conversion.

    I modified the function to support this; usage examples follow:

    ggforest(bigmodel)
    ggforest2(bigmodel) # behaves like normal ggforest
    
    # basic usage: arrow.labels are given as c(left, right). Left of the dotted
    # line at HR = 1 the hazard is lower, so "decreased risk" comes first.
    ggforest2(bigmodel, arrow = TRUE, arrow.labels = c("decreased risk", "increased risk"))
    
    # change arrow colour & appearance (arrow() controls the arrow heads)
    ggforest2(bigmodel, arrow = TRUE, arrow.labels = c("decreased risk", "increased risk"),
              arrow.colour = "blue", 
              arrow.specification = arrow(angle = 20, length = unit(0.1, "inches")))
    
    # different arrow colours, given as c(left, right) like the labels
    ggforest2(bigmodel, arrow = TRUE, arrow.labels = c("decreased risk", "increased risk"),
              arrow.colour = c("forestgreen", "firebrick"))
    

    example 1

    example 2

    example 3

    Code for ggforest2():

    # ggforest2(): variant of survminer::ggforest() that can optionally draw two
    # labelled arrows beneath the x-axis, pointing away from the line at HR = 1.
    #
    # Arguments are those of survminer::ggforest(), plus:
    #   arrow               - logical; arrows are drawn only when TRUE. With the
    #                         default FALSE the plot is built as in ggforest().
    #   arrow.labels        - labels for the left (HR < 1) and right (HR > 1)
    #                         arrows, in that order; length 1 or 2.
    #   arrow.specification - arrow-head specification passed to the segment
    #                         layer (an arrow() object).
    #   arrow.colour        - colour(s) for arrows and labels; length 1 or 2
    #                         (left, right).
    # Returns the result of ggpubr::as_ggplot() applied to the finished gtable.
    ggforest2 <- function(model, data = NULL, main = "Hazard ratio",
                          cpositions = c(0.02, 0.22, 0.4),
                          fontsize = 0.7, refLabel = "reference", noDigits = 2,
    
                          # new parameters with some default values; function's behaviour
                          # does not differ from ggforest() unless arrow = TRUE.
                          # `arrow()` in the default below still finds the arrow() function:
                          # R skips non-function bindings (the logical `arrow` argument)
                          # when resolving a function call.
                          arrow = FALSE, arrow.labels = c("left", "right"),
                          arrow.specification = arrow(), arrow.colour = "black") {
    
      # NULL bindings for columns used inside aes(); presumably to silence
      # R CMD check "no visible binding" notes.
      conf.high <- conf.low <- estimate <- NULL
      # inherits() also accepts fits whose class vector has several entries
      # (e.g. penalized coxph fits); class(model) == "coxph" would not.
      stopifnot(inherits(model, "coxph"))
      data <- survminer:::.get_data(model, data = data)
    
      # Keep only model terms that also appear in the ANOVA table
      # (backticks stripped so non-syntactic names still match).
      terms <- attr(model$terms, "dataClasses")[-1]
      terms <- terms[intersect(names(terms),
                               gsub(rownames(anova(model))[-1], pattern = "`", replacement = ""))]
    
      # One row per level for factor terms, a single row for any other term.
      allTerms <- lapply(seq_along(terms), function(i) {
        var <- names(terms)[i]
        if (terms[i] == "factor") {
          # [[ ]] yields a plain vector for data.frames, tibbles and data.tables.
          adf <- as.data.frame(table(data[[var]]))
          cbind(var = var, adf, pos = seq_len(nrow(adf)))
        } else {
          data.frame(var = var, Var1 = "", Freq = nrow(data), pos = 1)
        }
      })
      allTermsDF <- do.call(rbind, allTerms)
      colnames(allTermsDF) <- c("var", "level", "N", "pos")
      # Row keys like "sexmale", matched against coefficient names below.
      inds <- paste0(allTermsDF$var, allTermsDF$level)
    
      coef <- as.data.frame(broom::tidy(model))
      gmodel <- broom::glance(model)
      rownames(coef) <- gsub(coef$term, pattern = "`", replacement = "")
      toShow <- cbind(allTermsDF, coef[inds, ])[, c("var", "level", "N", "p.value", "estimate",
                                                    "conf.low", "conf.high", "pos")]
    
      # Columns 5:7 (estimate, conf.low, conf.high) are exponentiated for
      # display; NAs (reference levels) become 0 first, i.e. display as 1.
      toShowExp <- toShow[, 5:7]
      toShowExp[is.na(toShowExp)] <- 0
      toShowExp <- format(exp(toShowExp), digits = noDigits)
      toShowExpClean <- data.frame(toShow, pvalue = signif(toShow[, 4], noDigits + 1), toShowExp)
      toShowExpClean$stars <- paste0(round(toShowExpClean$p.value, noDigits + 1), " ",
                                     ifelse(toShowExpClean$p.value < 0.05, "*", ""),
                                     ifelse(toShowExpClean$p.value < 0.01, "*", ""),
                                     ifelse(toShowExpClean$p.value < 0.001, "*", ""))
      toShowExpClean$ci <- paste0("(", toShowExpClean[, "conf.low.1"],
                                  " - ", toShowExpClean[, "conf.high.1"], ")")
    
      # Reference levels have no estimate: show refLabel, blank p-value and CI.
      isRef <- is.na(toShowExpClean$estimate)
      toShowExpClean$estimate.1[isRef] <- refLabel
      toShowExpClean$stars[which(toShowExpClean$p.value < 0.001)] <- "<0.001 ***"
      toShowExpClean$stars[isRef] <- ""
      toShowExpClean$ci[isRef] <- ""
      toShowExpClean$estimate[isRef] <- 0
      toShowExpClean$var <- as.character(toShowExpClean$var)
      toShowExpClean$var[duplicated(toShowExpClean$var)] <- ""
      toShowExpClean$N <- paste0("(N=", toShowExpClean$N, ")")
      # Reverse the rows; row index becomes the vertical position after coord_flip().
      toShowExpClean <- toShowExpClean[rev(seq_len(nrow(toShowExpClean))), ]
    
      rangeb <- range(toShowExpClean$conf.low, toShowExpClean$conf.high,
                      na.rm = TRUE)
      # NOTE(review): axisTicks(log = TRUE) works in log10 units while rangeb
      # is on the natural-log scale; /2 only approximates that. Kept as in ggforest().
      breaks <- axisTicks(rangeb/2, log = TRUE, nint = 7)
      # Widen the plotting range: extra room on the left for the text columns,
      # a little on the right for the p-values.
      rangeplot <- rangeb
      rangeplot[1] <- rangeplot[1] - diff(rangeb)
      rangeplot[2] <- rangeplot[2] + 0.15 * diff(rangeb)
      width <- diff(rangeplot)
      y_variable <- rangeplot[1] + cpositions[1] * width
      y_nlevel <- rangeplot[1] + cpositions[2] * width
      y_cistring <- rangeplot[1] + cpositions[3] * width
      y_stars <- rangeb[2]
      x_annotate <- seq_len(nrow(toShowExpClean))
      annot_size_mm <- fontsize *
        as.numeric(grid::convertX(grid::unit(theme_get()$text$size, "pt"), "mm"))
    
      # modified code from here onwards: build the ggplot before gtable conversion
      p <- ggplot(toShowExpClean, aes(seq_along(var), exp(estimate))) +
        geom_rect(aes(xmin = seq_along(var) - 0.5,
                      xmax = seq_along(var) + 0.5,
                      ymin = exp(rangeplot[1]),
                      ymax = exp(rangeplot[2]),
                      fill = ordered(seq_along(var) %% 2 + 1))) +
        geom_point(pch = 15, size = 4) +
        geom_errorbar(aes(ymin = exp(conf.low), ymax = exp(conf.high)),
                      width = 0.15) +
        geom_hline(yintercept = 1, linetype = 3) +
    
        annotate(geom = "text", x = x_annotate, y = exp(y_variable),
                 label = toShowExpClean$var, fontface = "bold", hjust = 0,
                 size = annot_size_mm) +
        annotate(geom = "text", x = x_annotate, y = exp(y_nlevel),
                 hjust = 0, label = toShowExpClean$level,
                 vjust = -0.1, size = annot_size_mm) +
        annotate(geom = "text", x = x_annotate, y = exp(y_nlevel),
                 label = toShowExpClean$N,
                 fontface = "italic", hjust = 0,
                 vjust = ifelse(toShowExpClean$level == "", 0.5, 1.1),
                 size = annot_size_mm) +
        # Reference rows are centred vertically; compare against refLabel (not
        # the literal "reference") so a custom refLabel is aligned too.
        annotate(geom = "text", x = x_annotate, y = exp(y_cistring),
                 label = toShowExpClean$estimate.1,
                 size = annot_size_mm,
                 vjust = ifelse(toShowExpClean$estimate.1 == refLabel, 0.5, -0.1)) +
        annotate(geom = "text", x = x_annotate, y = exp(y_cistring),
                 label = toShowExpClean$ci,
                 size = annot_size_mm, vjust = 1.1, fontface = "italic") +
        annotate(geom = "text", x = x_annotate, y = exp(y_stars),
                 label = toShowExpClean$stars, size = annot_size_mm,
                 hjust = -0.2, fontface = "italic") +
        annotate(geom = "text", x = 0.5, y = exp(y_variable),
                 label = paste0("# Events: ",
                                gmodel$nevent, "; Global p-value (Log-Rank): ",
                                format.pval(gmodel$p.value.log, eps = ".001"), " \nAIC: ",
                                round(gmodel$AIC, 2), "; Concordance Index: ",
                                round(gmodel$concordance, 2)),
                 size = annot_size_mm, hjust = 0, vjust = 1.2,
                 fontface = "italic") +
    
        scale_y_log10(labels = sprintf("%g", breaks),
                      expand = c(0.02, 0.02), breaks = breaks) +
        scale_fill_manual(values = c("#FFFFFF33", "#00000033"), guide = "none") +
        labs(title = main, x = "", y = "") +
    
        # clip = "off" lets the arrows drawn below the panel stay visible.
        coord_flip(ylim = exp(rangeplot),
                   xlim = c(0.5, nrow(toShowExpClean) + 0.5),
                   clip = "off") +
    
        theme_light() +
        theme(panel.grid.minor = element_blank(),
              panel.grid.major.y = element_blank(),
              legend.position = "none",
              panel.border = element_blank(),
              axis.title.y = element_blank(),
              axis.text.y = element_blank(),
              axis.ticks.y = element_blank(),
              plot.title = element_text(hjust = 0.5))
    
      if (isTRUE(arrow)) {
        # annotate() only accepts parameters of length 1 or 2 here (two arrows).
        stopifnot(length(arrow.labels) %in% 1:2, length(arrow.colour) %in% 1:2)
    
        # Arrows run from half to the full smallest |bound| of the log-HR range,
        # on each side of 1. NOTE(review): if the CI range touches 0 (HR = 1)
        # the arrows collapse to zero length.
        arrowSpan <- min(abs(rangeb)) * c(-1, 1)
        range.arrow.outer <- exp(arrowSpan)
        range.arrow.inner <- exp(arrowSpan / 2)
    
        # Build once. After coord_flip() the panel's y.range is the vertical
        # (row-index) extent; place the arrows 5% of it below the bottom row.
        rowRange <- ggplot_build(p)$layout$panel_params[[1]]$y.range
        arrow.y <- rowRange[1] - 0.05 * diff(rowRange)
    
        p <- p +
          annotate("segment",
                   x = arrow.y, xend = arrow.y,
                   y = range.arrow.inner,
                   yend = range.arrow.outer,
                   arrow = arrow.specification, color = arrow.colour) +
          annotate("text",
                   x = arrow.y, y = range.arrow.inner,
                   label = arrow.labels,
                   hjust = 0.5, vjust = -0.5, size = annot_size_mm,
                   color = arrow.colour) +
          # extra bottom margin leaves room for the arrows below the axis
          theme(plot.margin = margin(5.5, 5.5, 20, 5.5, "pt"))
      }
    
      # this part is unchanged: convert to a gtable and un-clip the panel
      gt <- ggplot_gtable(ggplot_build(p))
      gt$layout$clip[gt$layout$name == "panel"] <- "off"
      ggpubr::as_ggplot(gt)
    }