R: How to check that a custom function is called within a specific function from a certain package


I want to create a function myfun that can only be used inside another function, in my case dplyr's mutate or summarise. I also do not want to rely on dplyr's internals (for example mask$...).

I came up with a quick and dirty workaround: a function search_calling_fn that walks the call stack and checks whether the name of any calling function matches a given pattern.

search_calling_fn <- function(pattern) {
  
  call_st <- lapply(sys.calls(), `[[`, 1)
  
  res <- any(unlist(lapply(call_st, function(x) grepl(pattern, x, perl = TRUE))))
  
  if (!res) {
    stop("`myfun()` must only be used inside dplyr::mutate or dplyr::summarise")
  } else {
    return()
  }
}

This works as expected, as the two examples below show (dplyr 1.0.0):

library(dplyr)

myfun <- function() {
  search_calling_fn("^mutate|^summarise")
  NULL
}

# as expected, throws no error
mtcars %>% 
  mutate(myfun())


myfun2 <- function() {
  search_calling_fn("^select")
  NULL
}

# as expected, throws an error
mtcars %>% 
  mutate(myfun2())

This approach has one loophole: myfun could be called from a function with a similar name that is not a dplyr function. I wonder how I can check which namespace a function on my call stack comes from. rlang has a function call_ns, but it only works if the function is called explicitly with package::.... Further, when using mutate, the call stack contains mutate_cols (an internal function) and mutate.data.frame (an S3 method), and both seem to make getting the namespace even more complicated.
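One possible base-R route, sketched here under assumptions: sys.function(i) returns the function object running in frame i, whether or not it is exported or an S3 method, and the top-level environment of that function's enclosure is normally its package namespace. The helper name frame_namespaces is made up for illustration.

```r
# Hypothetical helper: for every frame on the call stack (excluding its own),
# report the top-level environment in which that frame's function was defined.
frame_namespaces <- function() {
  n <- sys.nframe() - 1L  # number of frames below this helper
  vapply(seq_len(n), function(i) {
    env <- environment(sys.function(i))
    # primitives have no enclosing environment; treat them as base
    if (is.null(env)) "base" else environmentName(topenv(env))
  }, character(1))
}

f <- function() frame_namespaces()

# Called through lapply, the stack holds one frame owned by "base" (lapply)
# plus frames for the anonymous wrapper and f(), which are owned by wherever
# they were defined (the global environment in an interactive session).
lapply(1, function(i) f())[[1]]
```

Because the lookup goes through the function object rather than its name, a masking function of the same name would be reported under its own defining environment.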

On second thought, I wonder whether there is a better, more official approach to achieve the same outcome: only allow myfun to be called within dplyr's mutate or summarise.

The approach should work no matter how the function is called:

  1. mutate
  2. dplyr::mutate

Additional note

After discussing @r2evans's answer, I realized that a solution should pass the following test:

library(dplyr)

myfun <- function() {
  search_calling_fn("^mutate|^summarise")
  NULL
}

# an example for a function masking dplyr's mutate
mutate <- function(df, x) {
  NULL
}

# should throw an error but doesn't
mtcars %>% 
  mutate(myfun())

So the checking function should not only look at the call stack, but also try to see which package each function on the call stack comes from. Interestingly, RStudio's debugger shows the namespace for each function on the call stack, even for internal functions. I wonder how it does this, since environment(fun) only works for exported functions.
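A sketch of a check that would reject a masking function like the one in the test: it walks the stack and accepts a frame only if its call name matches the pattern and its function object (from sys.function) was defined in the expected namespace. assert_called_within and its arguments are hypothetical names. The demo uses base::sapply as a stand-in for dplyr::mutate so that it runs without extra packages; for dplyr the call would presumably be assert_called_within("dplyr", "^(mutate|summarise)"), which I have not tested.

```r
# Hypothetical helper: stop unless some frame on the call stack belongs to a
# function whose call name matches `pattern` and which was defined in `pkg`.
assert_called_within <- function(pkg, pattern) {
  n <- sys.nframe() - 1L  # ignore this helper's own frame
  ok <- FALSE
  for (i in seq_len(n)) {
    fn_name <- sys.call(i)[[1]]
    # strip an explicit pkg::fun or pkg:::fun prefix from the call head
    if (is.call(fn_name) &&
        (identical(fn_name[[1]], quote(`::`)) ||
         identical(fn_name[[1]], quote(`:::`)))) {
      fn_name <- fn_name[[3]]
    }
    if (!is.symbol(fn_name)) next  # e.g. anonymous functions
    env <- environment(sys.function(i))
    ns <- if (is.null(env)) "base" else environmentName(topenv(env))
    if (grepl(pattern, as.character(fn_name), perl = TRUE) && identical(ns, pkg)) {
      ok <- TRUE
      break
    }
  }
  if (!ok) {
    stop("must be called inside a function matching '", pattern,
         "' from package '", pkg, "'", call. = FALSE)
  }
  invisible(NULL)
}

myfun <- function() {
  assert_called_within("base", "^sapply$")
  "ok"
}

# passes: the sapply on the stack is base's
base::sapply(1, function(i) myfun())

# a masking function, analogous to the non-dplyr mutate in the test
sapply <- function(X, FUN) FUN(X)

# fails: a sapply is on the stack, but it was not defined in base
tryCatch(sapply(1, function(i) myfun()), error = function(e) conditionMessage(e))
```

The name test still relies on how the call was written, so an alias such as m <- base::sapply would be rejected by this sketch; comparing function objects directly would be an alternative design.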


Solution

  • Update: I'm going to "borrow" from rlang::trace_back, since it seems to have an elegant (and working) method for determining a full package::function for most of the call tree (some like %>% are not always fully-resolved).

    (If you're trying to reduce package bloat ... while it's unlikely you'd have dplyr and not purrr available, if you would prefer to do as much in base as possible, I've provided #==# equivalent base-R calls. It's certainly feasible to try to remove some of the rlang calls, but again ... if you're assuming dplyr, then you definitely have rlang around, in which case this should not be a problem.)

    EDIT (2022-02-25): the function below uses ::: functions in rlang, which (not surprisingly) no longer exist as of today, as a clear example of why using :::-funcs is inherently risky. This function no longer works. I'm not going to attempt to fix now (no immediate need/motivation). Cheers.

    search_calling_pkg <- function(pkgs, funcs) {
      # <borrowed from="rlang::trace_back">
      frames <- sys.frames()
      idx <- rlang:::trace_find_bottom(NULL, frames)
      frames <- frames[idx]
      parents <- sys.parents()[idx]
      calls <- as.list(sys.calls()[idx])
      calls <- purrr::map(calls, rlang:::call_fix_car)
      #==# calls <- lapply(calls, rlang:::call_fix_car)
      calls <- rlang:::add_pipe_pointer(calls, frames)
      calls <- purrr::map2(calls, seq_along(calls), rlang:::maybe_add_namespace)
      #==# calls <- Map(rlang:::maybe_add_namespace, calls, seq_along(calls))
      # </borrowed>
      calls_chr <- vapply(calls, function(cl) as.character(cl)[1], character(1))
      # A missing `pkgs` or `funcs` means "no constraint on that side"
      pkgres <- missing(pkgs) ||
        any(grepl(paste0("^(", paste(pkgs, collapse = "|"), ")::"), calls_chr))
      funcres <- missing(funcs) ||
        any(mapply(grepl, paste0("^", funcs, "$"), list(calls_chr)))
      if (!pkgres || !funcres) {
        stop("not correct")
      } else return()
    }
    

    The intention is that you can look for particular packages and/or particular functions. The funcs= argument can be fixed strings (matched as a whole, since each one is anchored with ^ and $), but since I thought you might want to match against any of the mutate* functions (etc.), you can also make it a regex. All functions need to be given as the full package::funcname, not just funcname (though you could certainly make it a regex :-).
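    To make the matching step concrete, here is a small standalone sketch of the mapply(grepl, ...) idiom used for funcs. The call names in calls_chr are invented stand-ins, not what rlang actually produces.

    ```r
    # Hypothetical fully qualified call names, standing in for `calls_chr`
    calls_chr <- c("dplyr::mutate", "dplyr::mutate_cols", "stats::filter")
    funcs <- c("dplyr::mutate", "dplyr::summari[sz]e.*")

    # One column of logicals per anchored pattern, one row per call name
    hits <- mapply(grepl, paste0("^", funcs, "$"), list(calls_chr))
    hits

    # TRUE as soon as any pattern matches any call name
    any(hits)
    ```

    Because of the anchors, the plain string "dplyr::mutate" matches only that exact name, while "dplyr::mutate.*" would also match "dplyr::mutate_cols". Keep in mind that a dot in a "fixed" name is still a regex wildcard.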

    myfun1 <- function() {
      search_calling_pkg(pkgs = "dplyr")
      NULL
    }
    myfun2 <- function() {
      search_calling_pkg(funcs = c("dplyr::mutate.*", "dplyr::summarize.*"))
      NULL
    }
    mutate <- function(df, x) { force(x); NULL; }
    
    mtcars[1:2,] %>% mutate(myfun1())
    # Error: not correct
    
    mtcars[1:2,] %>% dplyr::mutate(myfun1())
    #   mpg cyl disp  hp drat    wt  qsec vs am gear carb
    # 1  21   6  160 110  3.9 2.620 16.46  0  1    4    4
    # 2  21   6  160 110  3.9 2.875 17.02  0  1    4    4
    
    mtcars[1:2,] %>% mutate(myfun2())
    # Error: not correct
    
    mtcars[1:2,] %>% dplyr::mutate(myfun2())
    #   mpg cyl disp  hp drat    wt  qsec vs am gear carb
    # 1  21   6  160 110  3.9 2.620 16.46  0  1    4    4
    # 2  21   6  160 110  3.9 2.875 17.02  0  1    4    4
    

    And performance seems to be significantly better than my first answer (kept below), though it is still not a "zero hit":

    microbenchmark::microbenchmark(
      a = mtcars %>%
        dplyr::mutate(),
      b = mtcars %>%
        dplyr::mutate(myfun1())
    )
    # Unit: milliseconds
    #  expr    min     lq     mean  median      uq     max neval
    #     a 1.5965 1.7444 1.883837 1.82955 1.91655  3.0574   100
    #     b 3.4748 3.7335 4.187005 3.92580 4.18140 19.4343   100
    

    (This portion is kept for posterity, though note that getAnywhere will find dplyr::mutate even if the above non-dplyr mutate is defined and called.)

    Seeded by Rui's links, I suggest that looking for specific functions might very well miss new functions and/or otherwise-valid but differently-named functions. (I don't have a clear example.) From here, consider looking for particular packages instead of particular functions.

    search_calling_pkg <- function(pkgs) {
      call_st <- lapply(sys.calls(), `[[`, 1)
      res <- any(vapply(call_st, function(ca) any(pkgs %in% tryCatch(getAnywhere(as.character(ca)[1])$where, error=function(e) "")), logical(1)))
      if (!res) {
        stop("not called from packages")
      } else return()
    }
    myfun <- function() {
      search_calling_pkg("package:dplyr")
      NULL
    }
    

    Realize that this is not an inexpensive operation. I believe the majority of time spent in this is dealing with the calling tree, perhaps not something we can easily remedy.

    microbenchmark::microbenchmark(
      a = mtcars %>% mutate(),
      b = mtcars %>% mutate(myfun())
    )
    # Unit: milliseconds
    #  expr        min         lq       mean     median        uq        max neval
    #     a   1.872101   2.165801   2.531046   2.312051   2.72835   4.861202   100
    #     b 546.916301 571.909551 603.528225 589.995251 612.20240 798.707300   100
    

    If you believe it will be called infrequently and your function takes "a little time", then perhaps the half-second delay won't be that noticeable, but with this toy example the difference is palpable.
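    The reason getAnywhere cannot tell a masking function from the package version can be seen in a short sketch. It uses base::sapply as a stand-in for dplyr::mutate so that it runs without extra packages.

    ```r
    # A global function that masks base::sapply, analogous to the non-dplyr mutate
    sapply <- function(X, FUN) FUN(X)

    # getAnywhere() looks the *name* up in every place it can be found, so the
    # base versions are still reported next to any masking definition.
    where <- utils::getAnywhere("sapply")$where
    where

    # The package entry is present regardless of which sapply would be called
    "package:base" %in% where
    ```

    Since the lookup is by name only, the check says that some function of that name exists in the package, not that the function on the call stack is that one.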