Testing part of a function when an argument is a column name


I have the following code and function:

library(tidyverse)

set.seed(12)
data_table <- tibble(
  a = sample(100),
  b = rep(c("group_1", "group_2", "group_3", "group_4"), each = 25),
  c = rep(c("group_A", "group_B", "group_C", "group_D"), times = 25),
)

mean_fct <- function(data, grouped_var){
  data %>% 
    group_by({{grouped_var}}) %>% 
    mutate(mean_a = mean(a)) %>% 
    ungroup() %>% 
    distinct(mean_a, {{grouped_var}})
}


grouped_by_number <- mean_fct(data_table, b)
grouped_by_letter <- mean_fct(data_table, c)

Now I want to test one specific part of the function on its own by supplying its arguments by hand, i.e. I want to run:

data <- data_table
grouped_var <-  b

data %>% 
  group_by({{grouped_var}}) %>% 
  mutate(mean_a = mean(a))

This doesn't work; it fails with "Error: object 'b' not found".

I understand why it fails: the argument is a column name of the tibble, not an object in my environment. But how can I then test such a specific part of the function, i.e. how can I temporarily store the argument?

Thanks a lot for any help. I really like dplyr answers. The code above is just an example. The question is about testing the function, not about the usefulness of that code.


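Solution

The {{ }} operator only captures an argument passed to a function. At the top level, grouped_var <- b is evaluated immediately, so R looks for an object named b and fails. Capture the column name explicitly and unquote it instead: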

    library(tidyverse)
    
    set.seed(12)
    data_table <- tibble(
      a = sample(100),
      b = rep(c("group_1", "group_2", "group_3", "group_4"), each = 25),
      c = rep(c("group_A", "group_B", "group_C", "group_D"), times = 25),
    )
    
    mean_fct <- function(data, grouped_var){
      data %>% 
        group_by({{grouped_var}}) %>% 
        mutate(mean_a = mean(a)) %>% 
        ungroup() %>% 
        distinct(mean_a, {{grouped_var}})
    }
    
    # Define the test data and argument. quo() stores the expression b
    # (with its environment) without evaluating it.
    data <- data_table
    grouped_var <- quo(b)
    
    # Run the step under test: !! (bang-bang) unquotes the quosure,
    # standing in for {{grouped_var}} inside the function.
    data %>% 
      group_by(!!grouped_var) %>%
      mutate(mean_a = mean(a))
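
If the column name is easier to keep as a string while testing, a symbol built with rlang::sym() can be unquoted the same way. The following is only a sketch under the assumptions of the question (same data_table, same column b); the names by_quo and by_sym are made up for illustration. It runs the grouped step both ways and checks that they agree:

```r
library(tidyverse)

set.seed(12)
data_table <- tibble(
  a = sample(100),
  b = rep(c("group_1", "group_2", "group_3", "group_4"), each = 25),
  c = rep(c("group_A", "group_B", "group_C", "group_D"), times = 25)
)

data <- data_table

# Variant 1: capture the bare column name as a quosure
grouped_var <- quo(b)
by_quo <- data %>%
  group_by(!!grouped_var) %>%
  mutate(mean_a = mean(a))

# Variant 2: build a symbol from a string (handy when the name is stored as text)
grouped_var <- rlang::sym("b")
by_sym <- data %>%
  group_by(!!grouped_var) %>%
  mutate(mean_a = mean(a))

# Both variants group by column b and compute the same group means
group_vars(by_quo)
group_vars(by_sym)
identical(by_quo$mean_a, by_sym$mean_a)
```

Either variant lets you step through the function body line by line, as long as every {{grouped_var}} is replaced by !!grouped_var while testing.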