r tidyeval non-standard-evaluation

Use NSE to construct a formula


I am trying to construct a formula using NSE so that I can easily pipe in columns. The following is my desired use case:

df %>% make_formula(col1, col2, col3)

[1] "col1 ~ col2 + col3"

I first wrote this function:

varstring <- function(...) {
 as.character(match.call()[-1])
}

This works great with either single objects or multiple objects:

varstring(col)

[1] "col"

varstring(col1, col2, col3)

[1] "col1" "col2" "col3"

Next, I wrote the function that builds the formula:

formula <- function(df, col, ...) {
 group <- varstring(col)
 vars <- varstring(...)

 paste(group,"~", paste(vars, collapse = " + "), sep = " ")
}

However, the function call formula(df, col, col1, col2, col3) produces [1] "col ~ ..1 + ..2 + ..3".

I understand that the function is literally evaluating varstring(col) and varstring(...) instead of substituting in the user-supplied objects as I would like it to, but I cannot figure out how to make this work as intended.
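The `..1`, `..2`, `..3` placeholders appear because the arguments are captured one level too deep: by the time `varstring(...)` runs, it only sees the forwarded dots, not what the user typed. A minimal base R sketch, assuming a string result is acceptable, captures the expressions in the function the user actually called. The name `make_formula_chr` is hypothetical, and `df` is accepted only so the function can sit at the end of a pipe; it is not used.

```r
# Hypothetical helper: builds the formula as a string, as in the question.
make_formula_chr <- function(df, lhs, ...) {
  # Capture the unevaluated arguments here, in the outer function,
  # rather than forwarding them to a second capturing function.
  lhs_chr <- deparse(substitute(lhs))
  rhs_exprs <- as.list(substitute(list(...)))[-1]  # drop the `list` head
  rhs_chr <- vapply(rhs_exprs, deparse, character(1))

  paste(lhs_chr, "~", paste(rhs_chr, collapse = " + "))
}

make_formula_chr(mtcars, disp, cyl, am, drat)
#> [1] "disp ~ cyl + am + drat"
```

Because the result is a plain string, it still has to be parsed before a modelling function can use it, which is one reason to prefer building the formula from expressions.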


Solution

  • You can join an arbitrary number of arguments with a binary function by using reduce():

    library(rlang)  # provides ensym(), ensyms(), new_formula(), caller_env()

    make_formula <- function(lhs, ..., op = "+") {
      lhs <- ensym(lhs)
      args <- ensyms(...)
    
      n <- length(args)
    
      if (n == 0) {
        rhs <- 1
      } else if (n == 1) {
        rhs <- args[[1]]
      } else {
        rhs <- purrr::reduce(args, function(out, new) call(op, out, new))
      }
    
      # Don't forget to forward the caller environment
      new_formula(lhs, rhs, env = caller_env())
    }
    
    make_formula(disp)
    #> disp ~ 1
    
    make_formula(disp, cyl)
    #> disp ~ cyl
    
    make_formula(disp, cyl, am, drat)
    #> disp ~ cyl + am + drat
    
    make_formula(disp, cyl, am, drat, op = "*")
    #> disp ~ cyl * am * drat
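
As a usage sketch, the resulting formula can be passed straight to a modelling function. The block below restates the helper compactly so that it runs on its own; it swaps `purrr::reduce()` for base `Reduce()` to avoid an extra dependency, and the choice of `lm()` with the built-in `mtcars` data is only an illustration.

```r
library(rlang)

# Compact restatement of make_formula(), using base Reduce() (an assumption
# made here to keep the example dependency-light).
make_formula <- function(lhs, ..., op = "+") {
  lhs <- ensym(lhs)
  args <- ensyms(...)
  rhs <- if (length(args) == 0) {
    1
  } else {
    Reduce(function(out, new) call(op, out, new), args)
  }
  new_formula(lhs, rhs, env = caller_env())
}

f <- make_formula(mpg, wt, hp)
fit <- lm(f, data = mtcars)

names(coef(fit))
#> [1] "(Intercept)" "wt"          "hp"
```

Forwarding the caller environment matters when the formula refers to objects that live outside the data frame, since modelling functions look those up in the formula's environment.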
    

    One big advantage of working with expressions is that it is robust against code injection of the Little Bobby Tables kind (https://xkcd.com/327/):

    # User inputs are always interpreted as symbols (variable name)
    make_formula(disp, `I(file.remove('~'))`)
    #> disp ~ `I(file.remove('~'))`
    
    # With `paste()` + `parse()`, which is how `reformulate()` works, user inputs are interpreted as arbitrary code
    reformulate(c("foo", "I(file.remove('~'))"))
    #> ~foo + I(file.remove("~"))
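
The difference can be checked without running anything dangerous by inspecting the right-hand side of each formula. The sketch below uses a harmless `stop('boom')` payload in place of `file.remove()`; the variable names `sym_f` and `str_f` are illustrative only. Nothing here evaluates the payload, it only looks at the structure of the formula objects.

```r
library(rlang)

payload <- "I(stop('boom'))"

# Symbol-based construction: the whole string becomes one variable name
sym_f <- new_formula(quote(disp), sym(payload))

# String-based construction: the string is parsed into a function call
str_f <- reformulate(payload, response = "disp")

is.symbol(sym_f[[3]])
#> [1] TRUE

is.call(str_f[[3]])
#> [1] TRUE
```

With the symbol version, a modelling function would look for a column literally named `I(stop('boom'))` and fail to find it; with the parsed version, the call would be evaluated when the model frame is built.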