
Creating a new Stat with a survival::survfit object fails (NAs removed from data in compute_group)


I want to create a new stat that calculates interval-censored survival with survival::survfit.formula. However, the data frame that arrives in the compute_group function seems to be wrong, and I can't find the reason.

If I create a data frame with exactly the same code "outside" the stat and plot it with geom_path (the geom I want to use for the stat), the result is correct (see Expected result). It seems as if survfit.formula() is creating NAs within compute_group(), but I don't understand why.

Setting or adding na.rm = TRUE/FALSE does not change anything.

Using Inf instead of NA for time2 does not help either.
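Both observations are consistent with how ggplot2's exported remove_missing() helper behaves (the Solution below shows where it is called): na.rm only decides whether a warning is printed, and with finite = TRUE an Inf is treated exactly like NA. A minimal sketch on toy data shaped like the question's (the values are made up):

```r
library(ggplot2)

# Toy data: NA and Inf in time2 both mark an open (right-censored) interval
df <- data.frame(time = c(4, 8, 15), time2 = c(10, NA, Inf))

# Checking both columns with finite = TRUE drops the NA row and the Inf row;
# na.rm = TRUE only suppresses the "Removed ... rows" warning
both <- remove_missing(df, na.rm = TRUE, vars = c("time", "time2"), finite = TRUE)
nrow(both)
#> [1] 1

# With na.rm = FALSE the same rows are removed, just with a warning
warned <- suppressWarnings(
  remove_missing(df, na.rm = FALSE, vars = c("time", "time2"), finite = TRUE)
)
nrow(warned)
#> [1] 1

# Checking only the time column keeps every row
time_only <- remove_missing(df, na.rm = TRUE, vars = "time", finite = TRUE)
nrow(time_only)
#> [1] 3
```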

library(ggplot2)
library(survival)
set.seed(42)
testdf <- data.frame(time = sample(30, replace = TRUE), time2 = c(20, 10, 10, 30, rep(NA, 26)))

fit_icens <-
  survival::survfit.formula(
    survival::Surv(time = time, time2 = time2, type = "interval2") ~ 1,
    data = testdf
  )
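For context on why the NA values in time2 must survive: with type = "interval2", Surv() reads each row as an interval [time, time2], and ?Surv states that an open end can be written as NA (or Inf). A small sketch with made-up values shows how the rows are coded in the status column (0 = right-censored, 1 = exact, 2 = left-censored, 3 = interval-censored):

```r
library(survival)

# Toy intervals: exact event, right-censored (NA in time2),
# left-censored (NA in time), and a genuine interval
s <- Surv(time  = c(5, 7, NA, 2),
          time2 = c(5, NA, 6, 9),
          type  = "interval2")

# The status column of an interval-censored Surv object
status <- unclass(s)[, "status"]
status
#> [1] 1 0 2 3
```

So dropping rows with NA in time2 silently throws away every right-censored observation.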

Expected result

path <- data.frame(time = fit_icens$time, surv = fit_icens$surv)

ggplot(path, aes(x = time, y = surv)) +
  geom_path() +
  coord_cartesian(ylim = c(0, 1))

Failing


StatIcen <- ggplot2::ggproto("StatIcen", Stat,
  required_aes = c("time", "time2"),
  compute_group = function(data, scales) {
    fit_icens <-
      survival::survfit.formula(
        survival::Surv(time = data$time, time2 = data$time2, type = "interval2") ~ 1,
        data = data
      )
    path <- data.frame(x = fit_icens$time, y = fit_icens$surv)
    path
  }
)

stat_icen <- function(mapping = NULL, data = NULL, geom = "path",
                      position = "identity", show.legend = NA,
                      inherit.aes = TRUE, ...) {
  layer(
    stat = StatIcen, data = data, mapping = mapping, geom = geom,
    position = position, show.legend = show.legend, inherit.aes = inherit.aes,
    params = list(...)
  )
}

ggplot(testdf, aes(time = time, time2 = time2)) +
  stat_icen()
#> Warning: Removed 26 rows containing non-finite values (stat_icen).

Created on 2020-05-04 by the reprex package (v0.3.0)
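The row loss behind this warning can be reproduced with a hypothetical toy Stat (StatRowCount is made up for illustration) whose compute_group only reports how many rows it received. Nothing but the inherited Stat$compute_layer is involved, and na.rm = TRUE is passed to show that it silences the warning without keeping the rows:

```r
library(ggplot2)

# Hypothetical toy Stat, for illustration only: compute_group() returns a
# single point whose y value is the number of rows it received
StatRowCount <- ggproto("StatRowCount", Stat,
  required_aes = c("x", "y"),
  compute_group = function(data, scales) {
    data.frame(x = 1, y = nrow(data))
  }
)

# Five rows, two of which have NA in y
df <- data.frame(x = 1:5, y = c(1, NA, 3, NA, 5))

p <- ggplot(df, aes(x = x, y = y)) +
  layer(
    stat = StatRowCount, geom = "point", position = "identity",
    # na.rm = TRUE only silences the "Removed ... rows" warning
    params = list(na.rm = TRUE)
  )

# The inherited Stat$compute_layer has already dropped the NA rows,
# so compute_group() only sees 3 of the 5 rows
layer_data(p)$y
#> [1] 3
```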


Solution

  • Great question, Tjebo; thanks for posting.

    As you have already figured out, the problem is that the NA values are being stripped out of your data before it is passed to compute_group. The Extending ggplot2 vignette doesn't mention this, but your data is first passed through the compute_layer member function of your ggproto object, which calls compute_panel once per panel, which in turn calls compute_group once per group. Since you haven't defined a compute_layer method, your StatIcen class inherits the one from its parent class, ggplot2::Stat.

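    If you look at the source code of ggplot2::Stat$compute_layer, you will see that this is where your NA values are stripped out, using the remove_missing function, which removes rows of your data frame that have missing values in any of the named columns (the stat's required aesthetics, here time and time2). It is called with finite = TRUE, so Inf values are dropped as well, which is why replacing NA with Inf didn't help; na.rm only controls whether a warning is printed about the removed rows. Presumably, you still want NA values removed if they appear in your time column, but not if they appear in time2. (Strictly, with type = "interval2" an NA in time means left-censored, so you may want to keep those rows too.)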

    So all I have done here is copy the source code of Stat$compute_layer, change the remove_missing call so that it only checks the time column, and make the result a member of StatIcen:

    StatIcen <- ggplot2::ggproto("StatIcen", Stat,
      required_aes = c("time", "time2"),
      compute_group = function(data, scales) {
        # Fit the interval-censored survival curve for this group;
        # NA in time2 marks a right-censored observation
        fit_icens <- survival::survfit.formula(
          survival::Surv(time = data$time, time2 = data$time2,
                         type = "interval2") ~ 1, data = data)
        data.frame(x = fit_icens$time, y = fit_icens$surv)
      },
      # Copy of ggplot2::Stat$compute_layer, except that remove_missing()
      # only checks the time column, so rows with NA in time2 are kept
      compute_layer = function(self, data, params, layout) {
        ggplot2:::check_required_aesthetics(self$required_aes,
          c(names(data), names(params)), ggplot2:::snake_class(self))
        data <- remove_missing(data, params$na.rm, "time",
                               ggplot2:::snake_class(self), finite = TRUE)
        # Keep only the parameters that compute_panel/compute_group accept
        params <- params[intersect(names(params), self$parameters())]
        args <- c(list(data = quote(data), scales = quote(scales)), params)
        # Run compute_panel() once per panel and combine the results
        ggplot2:::dapply(data, "PANEL", function(data) {
          scales <- layout$get_scales(data$PANEL[1])
          tryCatch(do.call(self$compute_panel, args),
                   error = function(e) {
            warning("Computation failed in `",
                    ggplot2:::snake_class(self),
                    "()`:\n", e$message, call. = FALSE)
            ggplot2:::new_data_frame()
          })
        })
      }
    )
    

    So now we get:

    ggplot(testdf, aes(time = time, time2 = time2)) + stat_icen()
    

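If you'd rather not depend on unexported ggplot2 internals (the ::: calls above can change between ggplot2 versions), the same idea can be sketched with exported functions and base R only. This is a variant, not the answer's method: it is shown on a hypothetical toy Stat (StatRowCountKeepNA), it assumes that splitting by PANEL and calling compute_panel() is all the inherited compute_layer needs to do, and it skips the required-aesthetics check and the tryCatch wrapper.

```r
library(ggplot2)

# Hypothetical toy Stat: compute_group() reports how many rows it received.
# compute_layer() is overridden so that only x is checked for missing values.
StatRowCountKeepNA <- ggproto("StatRowCountKeepNA", Stat,
  required_aes = c("x", "y"),
  compute_group = function(data, scales) {
    data.frame(x = 1, y = nrow(data))
  },
  compute_layer = function(self, data, params, layout) {
    # Drop rows only where x is missing or non-finite; NA in y is kept
    data <- remove_missing(data, params$na.rm, "x",
                           "stat_row_count_keep_na", finite = TRUE)
    # Keep only the parameters compute_panel()/compute_group() accept
    params <- params[intersect(names(params), self$parameters())]
    # Run compute_panel() once per panel with base R, then stack the results
    panels <- split(data, data$PANEL, drop = TRUE)
    results <- lapply(panels, function(panel_data) {
      scales <- layout$get_scales(panel_data$PANEL[1])
      do.call(self$compute_panel,
              c(list(data = panel_data, scales = scales), params))
    })
    do.call(rbind, results)
  }
)

# Same toy data: two rows have NA in y
df <- data.frame(x = 1:5, y = c(1, NA, 3, NA, 5))

p <- ggplot(df, aes(x = x, y = y)) +
  layer(stat = StatRowCountKeepNA, geom = "point", position = "identity",
        params = list(na.rm = FALSE))

# compute_group() now receives all 5 rows
layer_data(p)$y
#> [1] 5
```

Swapping "x" for "time" in the remove_missing() call and the toy compute_group for the survfit one would give a StatIcen without ::: calls.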