
Shade area between multiple stratified Kaplan-Meier curves in R


I'm creating Kaplan-Meier curves stratified by treatment and stage. I'd like to shade the area between the curves that belong to the same stage.

For example, if I import data from the survival package:

    data(pbc, package = "survival")

And I generate Kaplan-Meier curves using survfit and ggsurvplot:

    p <- survminer::ggsurvplot(
      survival::survfit(survival::Surv(time = time/365, event = status == 2) ~ trt + stage,
                        data = pbc),
      palette = c("red", "blue", "green", "orange",
                  "red", "blue", "green", "orange"))

[Figure: the resulting Kaplan-Meier plot]

I'd like to shade the area between the treatment 1 and treatment 2 curves for stage 1 (the red lines), between those for stage 2 (the blue lines), and so on.

Any tips on how to do this?

Thank you!
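Both the repeated palette in the question and the `consecutive_id(group)` step in the solution below rely on the order of the survfit strata. A quick check needs only the survival package (a sketch; `fit` and `strata_names` are just illustrative variable names):

```r
library(survival)

data(pbc, package = "survival")

# Same model as in the question; survfit drops rows with missing trt or stage
fit <- survfit(Surv(time = time / 365, event = status == 2) ~ trt + stage,
               data = pbc)

# The first term (trt) varies slowest, so the first four strata are
# treatment 1 (stages 1 to 4) and the last four are treatment 2
strata_names <- names(fit$strata)
print(strata_names)
```

With `~ trt + stage`, a four-color palette repeated twice therefore gives both arms of a stage the same color; writing the formula as `~ stage + trt` would change the order and require a different palette.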


Solution

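  • One option is to use geom_polygon(). Because a polygon connects its vertices with straight lines, you first have to duplicate the observations where a step occurs (so the outline follows the steps) and then arrange the data in the right order: one treatment group sorted by time and the other in reverse order, so the vertices trace a closed outline around the area between the two curves.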

    Note: The "zero" step is, of course, getting the data for the survival curves, which I retrieve from the plot element of the ggsurvplot object via layer_data(). Alternatively, you could use the data stored in the data.survplot element. This saves some data-preparation steps but requires adding the colors via an additional fill scale.

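    library(tidyverse)  # ggplot2 + dplyr (>= 1.1.0 for consecutive_id() and .by)
    
    # Assumes `p` holds the ggsurvplot object from the question, e.g.
    # p <- survminer::ggsurvplot(...)
    # Extract the step-curve layer; within each colour (= stage), the first
    # group is treatment 1 and the second is treatment 2
    dat <- p$plot |>
      layer_data() |>
      arrange(colour, group) |>
      mutate(trt = consecutive_id(group), .by = colour) |>
      select(stage = colour, trt, time = x, surv = y)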
    
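    # For every drop in a curve, add a vertex at the drop time that carries the
    # previous survival value; lag() is computed per curve, not across curves
    dat_step <- dat |>
      arrange(stage, trt, time) |>
      mutate(surv_lag = lag(surv, default = 1), .by = c(stage, trt)) |>
      filter(surv_lag > surv) |>
      mutate(surv = surv_lag, .keep = "unused")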
    
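    # Within each stage, walk treatment 1 forward in time and treatment 2
    # backward, so the rows trace one closed outline for geom_polygon()
    dat <- bind_rows(dat, dat_step) |>
      split(~stage) |>
      lapply(\(x) {
        bind_rows(
          x |> filter(trt == 1) |> arrange(time, desc(surv)),
          x |> filter(trt == 2) |> arrange(desc(time), surv)
        )
      }) |>
      bind_rows()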
    
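    # Overlay the shaded areas; I() passes the colour names through as-is,
    # so no additional fill scale is needed
    p$plot +
      geom_polygon(
        data = dat,
        aes(time, surv, fill = I(stage)),
        alpha = .2
      )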
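The core of the polygon construction can be checked without survminer or the pbc data. The following base R sketch uses two made-up step curves (`arm1`, `arm2` and `add_step_vertices()` are hypothetical names, not part of any package) and builds the vertices the same way as above: duplicate each drop with the previous survival value, then order one arm forward in time and the other backward.

```r
# Hypothetical toy curves: each row is a time point and the survival
# value from that time on (right-continuous steps), starting at (0, 1)
arm1 <- data.frame(time = c(0, 1, 3, 5), surv = c(1, 0.8, 0.5, 0.4))
arm2 <- data.frame(time = c(0, 2, 4), surv = c(1, 0.9, 0.7))

# At every drop, add a vertex at the drop time holding the previous value,
# so straight polygon edges reproduce the horizontal-then-vertical steps
add_step_vertices <- function(d) {
  d <- d[order(d$time), ]
  prev <- c(1, head(d$surv, -1))  # survival just before each time point
  is_drop <- prev > d$surv
  drops <- data.frame(time = d$time[is_drop], surv = prev[is_drop])
  rbind(d, drops)
}

s1 <- add_step_vertices(arm1)
s2 <- add_step_vertices(arm2)

# Arm 1 forward in time (upper vertex first at a shared time), then arm 2
# backward (lower vertex first), giving one closed outline
poly <- rbind(
  s1[order(s1$time, -s1$surv), ],
  s2[order(-s2$time, s2$surv), ]
)
print(poly)
```

With ggplot2 loaded, `ggplot(poly, aes(time, surv)) + geom_polygon(alpha = 0.2) + geom_step(data = arm1) + geom_step(data = arm2)` should draw the band together with both curves. When the two curves end at different times, as here, the closing edge between their last points is a straight line, just as in the full solution.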