Search code examples
Tags: r, tidyverse, purrr

Iteratively apply a function to a dataset via purrr where the function arguments are stored in a tibble


I want to iteratively apply a function (myFun) to a dataset (df) via {purrr} where the function arguments are stored in a tibble (var_question). This is a simple toy example of the use case.

library(tidyverse)

# One row per plot to draw: the column to plot (`var`) and the question text
# used as its title. Built column-wise; the result is a 2x2 character tibble.
var_question <- tibble(
  var      = c("var1", "var2"),
  question = c("Q1", "Q2")
)

# Toy survey data: two factor demographics (demo1, demo2), one factor
# response (var1) and two numeric responses (var2, var3).
# The columns are generated in this order from the seeded RNG; reordering
# them would change the sampled values.
set.seed(2)

df <- tibble(
  demo1 = factor(sample(c("a", "b", "c"), 100, replace = TRUE)),
  demo2 = factor(sample(c("d", "e", "f"), 100, replace = TRUE)),
  var1  = factor(sample(c("Yes", "No"), 100, replace = TRUE)),
  var2  = runif(100),
  var3  = runif(100)
)

# function to plot selected vars by demo
# different approach for factor vs numeric variables
  
  # .data    data frame holding the column `var` plus the demographic
  #          columns `demo1` and `demo2` used for the breakdowns.
  # var      bare column name to plot (e.g. var1); a string also works.
  # question title text, wrapped at 80 characters.
  # Returns, for a factor `var`, a patchwork of three stacked lollipop panels
  # of level shares (overall, by demo1, by demo2); for any other `var`, a
  # single ggplot of means overall and per demographic level.
  myFun <- function(.data, var, question) {
    # Capture the bare column name once and keep it as a string; as_name()
    # also accepts a string argument, so myFun(df, "var1", ...) works too.
    var_name <- rlang::as_name(rlang::enquo(var))
    if (!var_name %in% names(.data)) {
      # Fail with a readable message instead of an empty class lookup
      # reaching `if` and erroring with "argument is of length zero".
      stop("Column `", var_name, "` not found in `.data`.", call. = FALSE)
    }
    var_sym <- rlang::sym(var_name)

    # is.factor() is TRUE for ordered factors too; their two-element class
    # vector broke a summarise_all(class) / `== "factor"` comparison.
    if (is.factor(.data[[var_name]])) {

      # Layers shared by all three panels: percent axis, flip, theme.
      share_style <- list(
        scale_y_continuous(labels = scales::label_percent(),
                           limits = c(0, 1)),
        coord_flip(),
        theme_bw(),
        labs(x = NULL,
             y = NULL,
             subtitle = paste0("N = ", nrow(.data))),
        theme(legend.position = "bottom",
              panel.grid.major.y = element_blank())
      )

      # Share of each `var` level within every level of the column `demo`,
      # coloured by demo level and dodged side by side.
      demo_panel <- function(demo) {
        demo_sym <- rlang::sym(demo)
        .data %>%
          count(!!demo_sym, !!var_sym, .drop = FALSE) %>%
          group_by(!!demo_sym, .drop = FALSE) %>%
          mutate(p = n / sum(n)) %>%
          ungroup() %>%
          ggplot(aes(x = !!var_sym,
                     y = p,
                     ymin = 0,
                     ymax = p,
                     label = paste0(round(p * 100, 0), "%"),
                     color = !!demo_sym)) +
          geom_linerange(position = position_dodge(width = .5),
                         size = .8) +
          geom_point(size = 3, position = position_dodge(width = .5)) +
          geom_text(position = position_dodge(width = .5),
                    hjust = -0.5,
                    show.legend = FALSE) +
          share_style
      }

      # Share of each `var` level across the whole data set.
      overall <- .data %>%
        count(!!var_sym, .drop = FALSE) %>%
        mutate(p = n / sum(n)) %>%
        ggplot(aes(x = !!var_sym,
                   y = p,
                   ymin = 0,
                   ymax = p,
                   label = paste0(round(p * 100, 0), "%"))) +
        geom_linerange(position = position_dodge(width = 0),
                       size = .8, color = "#009936") +
        geom_point(size = 3, position = position_dodge(width = 0),
                   color = "#009936") +
        geom_text(position = position_dodge(width = 0),
                  hjust = -0.5,
                  show.legend = FALSE) +
        share_style

      # Qualified patchwork calls load the package on demand: tidyverse does
      # not attach patchwork, so bare plot_layout()/plot_annotation() only
      # worked after a separate library(patchwork).
      patchwork::wrap_plots(overall,
                            demo_panel("demo1"),
                            demo_panel("demo2"),
                            ncol = 1) +
        patchwork::plot_annotation(
          title = stringr::str_wrap(question, 80),
          theme = theme(plot.title = element_text(size = 16,
                                                  face = "bold"))
        )

    } else {

      # Mean of `var` within each level of the column named `demo`.
      demo_means <- function(demo) {
        .data %>%
          group_by(!!rlang::sym(demo), .drop = FALSE) %>%
          summarize(mean = mean(!!var_sym, na.rm = TRUE)) %>%
          rename(group = all_of(demo))
      }
      by_demo1 <- demo_means("demo1")
      by_demo2 <- demo_means("demo2")

      overall <- .data %>%
        summarize(mean = mean(!!var_sym, na.rm = TRUE)) %>%
        mutate(group = "overall")

      # Axis order comes from the data (overall, then each demo's levels in
      # grouping order) rather than a hard-coded a-f list, which turned any
      # other demographic level into NA.
      group_levels <- unique(c("overall",
                               as.character(by_demo1$group),
                               as.character(by_demo2$group)))

      overall %>%
        bind_rows(by_demo1) %>%
        bind_rows(by_demo2) %>%
        mutate(group = factor(group, levels = group_levels)) %>%
        ggplot(aes(x = fct_rev(group),
                   y = mean,
                   ymin = 0,
                   ymax = mean,
                   label = round(mean, 1))) +
        geom_linerange(position = position_dodge(width = 0),
                       size = .8, color = "#009936") +
        geom_point(size = 3, position = position_dodge(width = 0),
                   color = "#009936") +
        geom_text(position = position_dodge(width = 0),
                  hjust = -0.5,
                  show.legend = FALSE) +
        coord_flip() +
        theme_bw() +
        labs(title = stringr::str_wrap(question, 80),
             x = NULL,
             y = NULL) +
        theme(legend.position = "bottom",
              panel.grid.major.y = element_blank(),
              plot.title = element_text(size = 16, face = "bold"))
    }
  }
  

Without iteration, I would run the following:

# Without iteration: call myFun() once per variable, passing the column as a
# bare name (positional arguments: .data, var, question).
myFun(df, var1, "Q1")
myFun(df, var2, "Q2")

Solution

  • pmap() works as @margusl suggested, but I had to change how the function refers to variable names. In my original function, the user passes the variable as a bare column name, for example:

    myFun(.data = df, var = var2, question = "Q2")

    The bare name var2 was then captured inside the function and injected into the dplyr/ggplot2 calls, like this:

    var_ <- rlang::enquo(var)   # later used as !!var_

    But when pmap() supplied the variable name as a character string from the var column of var_question (e.g. "var2"), this didn't work: enquo() is meant for bare column names, not strings.

    So I rewrote the function as myFun2(), which takes the variable name as a string and refers to the column inside the function with !!rlang::sym(var).

      # .data    data frame holding the column `var` plus the demographic
      #          columns `demo1` and `demo2` used for the breakdowns.
      # var      column name as a single string (e.g. "var2"), so the
      #          function can be driven by purrr::pmap() over a tibble.
      # question title text, wrapped at 80 characters.
      # Returns, for a factor `var`, a patchwork of three stacked lollipop
      # panels of level shares (overall, by demo1, by demo2); for any other
      # `var`, a single ggplot of means overall and per demographic level.
      myFun2 <- function(.data, var, question) {
        # Fail with a readable message instead of an empty class lookup
        # reaching `if` and erroring with "argument is of length zero".
        if (!is.character(var) || length(var) != 1L ||
            !var %in% names(.data)) {
          stop("`var` must be a single column name of `.data`.",
               call. = FALSE)
        }
        var_sym <- rlang::sym(var)

        # is.factor() is TRUE for ordered factors too; their two-element
        # class vector broke a summarise_all(class) / `== "factor"` check.
        if (is.factor(.data[[var]])) {

          # Layers shared by all three panels: percent axis, flip, theme.
          share_style <- list(
            scale_y_continuous(labels = scales::label_percent(),
                               limits = c(0, 1)),
            coord_flip(),
            theme_bw(),
            labs(x = NULL,
                 y = NULL,
                 subtitle = paste0("N = ", nrow(.data))),
            theme(legend.position = "bottom",
                  panel.grid.major.y = element_blank())
          )

          # Share of each `var` level within every level of the column
          # `demo`, coloured by demo level and dodged side by side.
          demo_panel <- function(demo) {
            demo_sym <- rlang::sym(demo)
            .data %>%
              count(!!demo_sym, !!var_sym, .drop = FALSE) %>%
              group_by(!!demo_sym, .drop = FALSE) %>%
              mutate(p = n / sum(n)) %>%
              ungroup() %>%
              ggplot(aes(x = !!var_sym,
                         y = p,
                         ymin = 0,
                         ymax = p,
                         label = paste0(round(p * 100, 0), "%"),
                         color = !!demo_sym)) +
              geom_linerange(position = position_dodge(width = .5),
                             size = .8) +
              geom_point(size = 3, position = position_dodge(width = .5)) +
              geom_text(position = position_dodge(width = .5),
                        hjust = -0.5,
                        show.legend = FALSE) +
              share_style
          }

          # Share of each `var` level across the whole data set.
          overall <- .data %>%
            count(!!var_sym, .drop = FALSE) %>%
            mutate(p = n / sum(n)) %>%
            ggplot(aes(x = !!var_sym,
                       y = p,
                       ymin = 0,
                       ymax = p,
                       label = paste0(round(p * 100, 0), "%"))) +
            geom_linerange(position = position_dodge(width = 0),
                           size = .8, color = "#009936") +
            geom_point(size = 3, position = position_dodge(width = 0),
                       color = "#009936") +
            geom_text(position = position_dodge(width = 0),
                      hjust = -0.5,
                      show.legend = FALSE) +
            share_style

          # Qualified patchwork calls load the package on demand: tidyverse
          # does not attach patchwork, so bare plot_layout() /
          # plot_annotation() only worked after a separate library(patchwork).
          patchwork::wrap_plots(overall,
                                demo_panel("demo1"),
                                demo_panel("demo2"),
                                ncol = 1) +
            patchwork::plot_annotation(
              title = stringr::str_wrap(question, 80),
              theme = theme(plot.title = element_text(size = 16,
                                                      face = "bold"))
            )

        } else {

          # Mean of `var` within each level of the column named `demo`.
          demo_means <- function(demo) {
            .data %>%
              group_by(!!rlang::sym(demo), .drop = FALSE) %>%
              summarize(mean = mean(!!var_sym, na.rm = TRUE)) %>%
              rename(group = all_of(demo))
          }
          by_demo1 <- demo_means("demo1")
          by_demo2 <- demo_means("demo2")

          overall <- .data %>%
            summarize(mean = mean(!!var_sym, na.rm = TRUE)) %>%
            mutate(group = "overall")

          # Axis order comes from the data (overall, then each demo's levels
          # in grouping order) rather than a hard-coded a-f list, which
          # turned any other demographic level into NA.
          group_levels <- unique(c("overall",
                                   as.character(by_demo1$group),
                                   as.character(by_demo2$group)))

          overall %>%
            bind_rows(by_demo1) %>%
            bind_rows(by_demo2) %>%
            mutate(group = factor(group, levels = group_levels)) %>%
            ggplot(aes(x = fct_rev(group),
                       y = mean,
                       ymin = 0,
                       ymax = mean,
                       label = round(mean, 1))) +
            geom_linerange(position = position_dodge(width = 0),
                           size = .8, color = "#009936") +
            geom_point(size = 3, position = position_dodge(width = 0),
                       color = "#009936") +
            geom_text(position = position_dodge(width = 0),
                      hjust = -0.5,
                      show.legend = FALSE) +
            coord_flip() +
            theme_bw() +
            labs(title = stringr::str_wrap(question, 80),
                 x = NULL,
                 y = NULL) +
            theme(legend.position = "bottom",
                  panel.grid.major.y = element_blank(),
                  plot.title = element_text(size = 16, face = "bold"))
        }
      }
    

    Finally, here's the working solution:

    pmap(var_question, function(var, question) myFun2(df, var, question))