I want to write a generalized weighted_summarise() function that automatically parses and transforms user-invoked function calls of the form
data %>% weighted_summarise(weights, a = sum(b), c = mean(d))
into an actual call that delegates to dplyr::summarise:
data %>% dplyr::summarise(a = sum(weights * b), c = mean(weights * d))
Here, a and c are new columns to be created in the summarised result, and b, d and weights are existing columns in data.
Ideally, I want to call my function exactly as I would a "native" dplyr::summarise, but with an extra weights argument that gets injected into each aggregation function.
weighted_summarise <- function(data, weights, ...) {
  data %>% dplyr::summarise(
    # how to manipulate the ... and inject the weights in each name-value pair?
  )
}
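For reference, rlang::enexprs() captures the arguments in ... as a named list of unevaluated calls, and that list is the AST that would need to be walked. A minimal look at its structure, assuming the rlang package is installed (capture_dots() is a hypothetical helper used only for illustration):

```r
# Assumes the rlang package is installed; capture_dots() is a hypothetical helper.
capture_dots <- function(...) {
  # enexprs() returns the ... arguments as a named list of unevaluated expressions
  rlang::enexprs(...)
}

dots <- capture_dots(a = sum(b), c = mean(d))
names(dots)   # the new column names: "a" and "c"
dots$a        # the unevaluated call sum(b)
dots$a[[1]]   # element 1 of a call is the function: sum
dots$a[[2]]   # element 2 is its first argument: b
```

Because each value is an ordinary R call object, it can be indexed and modified with [[ like a list.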
Question: How can I manipulate the ellipsis so that weights is injected into every name-value pair in the appropriate place? Ideally I would capture the AST, walk it, and manipulate it systematically.
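As an illustration of the walking step on its own, here is a minimal base-R sketch. inject_weights() is a hypothetical helper: it recursively visits a call and, for every call to one of the listed aggregation functions (assumed here to be just sum and mean), wraps the first argument x in w * x.

```r
# Hypothetical helper (not from any package): walk an expression recursively
# and rewrite every call to one of `funs` so its first argument x becomes w * x.
# Base R only; empty arguments such as the one in x[, 1] are not handled.
inject_weights <- function(expr, w, funs = c("sum", "mean")) {
  # symbols and constants are leaves of the tree: return them unchanged
  if (!is.call(expr)) {
    return(expr)
  }
  # rewrite the arguments first (element 1 is the function being called)
  if (length(expr) > 1) {
    for (i in 2:length(expr)) {
      expr[[i]] <- inject_weights(expr[[i]], w, funs)
    }
  }
  # then, if this call is one of the aggregations, wrap its first argument
  fn <- expr[[1]]
  if (is.name(fn) && as.character(fn) %in% funs && length(expr) > 1) {
    expr[[2]] <- call("*", w, expr[[2]])
  }
  expr
}

inject_weights(quote(sum(b)), quote(weights))            # sum(weights * b)
inject_weights(quote(round(sum(b), 2)), quote(weights))  # round(sum(weights * b), 2)
inject_weights(quote(mean(d) + 1), quote(w))             # mean(w * d) + 1
```

Because it works on the tree rather than on text, the weights end up inside the aggregation even when it is nested in another call. Inside weighted_summarise(), such a walker could be applied to each element of rlang::enexprs(...) before splicing the results into dplyr::summarise() with !!!.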
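Here is one option that interpolates weights into the expressions passed in ...: deparse each expression to a string, insert the weights column name after its opening parenthesis, paste everything into the text of a single summarise() call, then parse and evaluate that text.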
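weighted_summarise <- function(data, weights, ...) {
  # name of the weights column as a string, e.g. "weights"
  weights <- rlang::as_string(rlang::ensym(weights))
  # deparse each expression in ... and insert "weights*" after its first "(",
  # so that sum(b) becomes "sum(weights*b)"
  v1 <- purrr::map_chr(rlang::enexprs(...),
    ~ stringr::str_replace(rlang::as_label(.x), "\\(",
        stringr::str_c("(", weights, "*")))
  # build "dplyr::summarise(data, a=sum(weights*b), c=mean(weights*d))" and
  # evaluate it in this function's environment; dplyr:: avoids needing dplyr attached
  eval(rlang::parse_expr(stringr::str_c("dplyr::summarise(data, ",
    stringr::str_c(names(v1), v1, sep = "=", collapse = ", "), ")")))
}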
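To see what the string step produces, here is a base-R stand-in for it (add_weights_text() is a hypothetical name, and sub() with fixed = TRUE is swapped in for stringr::str_replace(); both replace only the first match). It also shows a limitation of editing text: when the aggregation is wrapped in another call, the weights land in the outer call instead.

```r
# Base-R stand-in for the str_replace() step (add_weights_text() is a
# hypothetical name): insert "<w>*" after the first "(" of the deparsed text.
add_weights_text <- function(expr_text, w = "weights") {
  sub("(", paste0("(", w, "*"), expr_text, fixed = TRUE)
}

txt <- add_weights_text(deparse(quote(sum(b))))
txt               # "sum(weights*b)"
str2lang(txt)     # parsed back into the call sum(weights * b)

# Only the first "(" is edited, so a wrapped aggregation is weighted outside it
add_weights_text("round(sum(b), 2)")   # "round(weights*sum(b), 2)"
```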
Testing:
> data %>%
weighted_summarise(weights, a = sum(b), c = mean(d))
# A tibble: 1 × 2
a c
<dbl> <dbl>
1 -2.95 1.13
# testing with the original summarise code outside the function
> data %>%
dplyr::summarise(a = sum(weights * b), c = mean(weights * d))
# A tibble: 1 × 2
a c
<dbl> <dbl>
1 -2.95 1.13
Data used in the examples:
data <- structure(list(b = c(-0.545880758366027, 0.536585304107612, 0.419623148618683,
-0.583627199210279, 0.847460017311944, 0.266021979364892, 0.444585270360416,
-0.466495123565759, -0.848370043948898, 0.00231194241576697),
d = c(-1.31690812429962, 0.598269112694685, -0.7622143703459,
-1.42909030324076, 0.332244449013422, -0.469060687608488,
-0.334986793584065, 1.53625215550584, 0.609994533253692,
0.51633569843567), weights = 1:10), class = c("tbl_df", "tbl",
"data.frame"), row.names = c(NA, -10L))
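A string-free variant of the same idea is sketched below, under stated assumptions: weighted_summarise2() is a hypothetical name, dplyr and rlang are installed, and every value in ... is a call whose first argument is the column to weight (so something like n() would fail). It edits each captured quosure's expression directly and splices the results into dplyr::summarise() with !!!, which avoids the deparse/parse round trip.

```r
# A string-free sketch (weighted_summarise2() is a hypothetical name). Assumes
# dplyr and rlang are installed and that every value in ... is a call whose
# first argument is the column to weight, e.g. sum(b); n() would fail.
weighted_summarise2 <- function(data, weights, ...) {
  w <- rlang::ensym(weights)    # the weights column as a symbol
  dots <- rlang::enquos(...)    # name-value pairs as quosures (expression + env)
  dots <- lapply(dots, function(q) {
    e <- rlang::quo_get_expr(q)
    e[[2]] <- call("*", w, e[[2]])   # sum(b) -> sum(weights * b)
    rlang::quo_set_expr(q, e)        # keep the quosure's original environment
  })
  # splice the rewritten, still-named quosures into summarise()
  dplyr::summarise(data, !!!dots)
}
```

With dplyr attached, data %>% weighted_summarise2(weights, a = sum(b), c = mean(d)) evaluates the same sum(weights * b) and mean(weights * d) as the direct dplyr::summarise() call in the testing output. Keeping the quosure environments means any local variables referenced in the expressions still resolve where the user wrote them.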