How to inject weight into list of dplyr summarise name-value pairs?

I want to write a generalized weighted_summarise() function that will automatically parse and transform user-invoked function calls of the form:

data %>% weighted_summarise(weights, a = sum(b), c = mean(d))

into an actual call that delegates to dplyr::summarise:

data %>% dplyr::summarise(a = sum(weights * b), c = mean(weights * d))

Here, a and c are new columns to be created inside the reduced data, and b, d and weights are existing columns in data.

Ideally, I want to call my function exactly as I would a "native" dplyr::summarise, but with an extra weights argument that gets sprinkled into each aggregation function.

weighted_summarise <- function(data, weights, ...) {
   data %>% dplyr::summarise(
       # how to manipulate the ... and inject the weights in each name-value pair?
   )
}

Question: How can I manipulate the ellipsis so that the weights are injected into every name-value pair in the appropriate place? I want to capture the AST, walk it, and manipulate it systematically.


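  • One option is to interpolate 'weights' into the expressions passed in ... by converting each expression to a string, inserting the weight after its first opening parenthesis, and then parsing and evaluating the combined string. Note that only the first "(" of each expression is rewritten.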

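    weighted_summarise <- function(data, weights, ...) {
      # capture the weights column name as a string
      weights <- rlang::as_string(rlang::ensym(weights))
      # deparse each expression and insert "weights*" after its first "("
      v1 <- purrr::map_chr(rlang::enexprs(...),
        ~ stringr::str_replace(rlang::as_label(.x), "\\(",
            function(x) stringr::str_c("(", weights, "*")))
      # rebuild the summarise() call as a string, then parse and evaluate it
      eval(rlang::parse_expr(stringr::str_c("data %>% dplyr::summarise(",
        stringr::str_c(names(v1), v1, sep = "=", collapse = ", "), ")")))
    }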


    > data %>%
         weighted_summarise(weights, a = sum(b), c = mean(d))
    # A tibble: 1 × 2
          a     c
      <dbl> <dbl>
    1 -2.95  1.13
    # testing with the original summarise code outside the function
    > data %>% 
        dplyr::summarise(a = sum(weights * b), c = mean(weights * d))
    # A tibble: 1 × 2
          a     c
      <dbl> <dbl>
    1 -2.95  1.13


    # data used in the examples above
    data <- structure(list(b = c(-0.545880758366027, 0.536585304107612, 0.419623148618683, 
    -0.583627199210279, 0.847460017311944, 0.266021979364892, 0.444585270360416, 
    -0.466495123565759, -0.848370043948898, 0.00231194241576697), 
        d = c(-1.31690812429962, 0.598269112694685, -0.7622143703459, 
        -1.42909030324076, 0.332244449013422, -0.469060687608488, 
        -0.334986793584065, 1.53625215550584, 0.609994533253692, 
        0.51633569843567), weights = 1:10), class = c("tbl_df", "tbl", 
    "data.frame"), row.names = c(NA, -10L))
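
Since the question asks about walking the AST rather than editing strings, here is a minimal sketch of that alternative. It is not the method used in the answer above. It assumes each expression in ... is a single call whose first argument should be multiplied by the weights column; the helper name inject_weight and the toy data frame are hypothetical and only for illustration.

```r
# Hypothetical helper: multiply the first argument of a captured call by `w`.
inject_weight <- function(quo, w) {
  ex <- rlang::quo_get_expr(quo)
  # Only rewrite calls with at least one argument, e.g. sum(b) or mean(d).
  if (rlang::is_call(ex) && length(ex) >= 2) {
    # Build the call `w * <first argument>` and put it back in place.
    ex[[2]] <- rlang::call2("*", w, ex[[2]])
  }
  rlang::quo_set_expr(quo, ex)
}

weighted_summarise <- function(data, weights, ...) {
  w <- rlang::ensym(weights)          # weights column as a symbol
  quos <- rlang::enquos(...)          # named list of captured expressions
  quos <- lapply(quos, inject_weight, w = w)
  dplyr::summarise(data, !!!quos)     # splice the rewritten expressions
}

# Toy data (assumed, not from the thread) to try the function on.
toy <- data.frame(b = c(1, 2, 3), d = c(4, 5, 6), weights = c(1, 2, 3))
res <- weighted_summarise(toy, weights, a = sum(b), c = mean(d))
# Should match:
# dplyr::summarise(toy, a = sum(weights * b), c = mean(weights * d))
```

Because the rewrite operates on the call object, extra arguments such as na.rm = TRUE are left untouched. Nested calls (for example sum(b) / sum(d)) would need a recursive walk, which this sketch does not attempt.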