Tags: r, dplyr, magrittr, nse

Adjust function to work with dplyr/magrittr


I have:

library(dplyr)  # provides data_frame() and the %>% pipe

df <- data_frame(
  a = 1:2,
  b = list(1:10, 4:40)
)

and

foo <- function(x) mean(unlist(x))

The following works as expected:

df$b %>% foo

However, I could not figure out how foo needs to be modified so that df %>% foo(b) works.


Solution

  • You can pass the ... argument directly to the vars() helper of summarise_at(), e.g.

    foo <- function(.tbl, ...){
        summarise_at(.tbl, 
                     vars(...), 
                     funs(mean(unlist(.))))
    }
    

    It works for a single variable, whether or not it is a list column:

    df %>% foo(b)
    ## # A tibble: 1 × 1
    ##          b
    ##      <dbl>
    ## 1 18.48936
    

    or multiple:

    df %>% foo(a, b)
    ## # A tibble: 1 × 2
    ##       a        b
    ##   <dbl>    <dbl>
    ## 1   1.5 18.48936
    

    To go further with NSE, check out lazyeval, which is the package dplyr uses to implement its NSE.

    Also note that the SE/NSE system of dplyr has just been rebuilt in the development version (not on CRAN yet, and not yet documented).
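
    As a hedged sketch for later dplyr releases: assuming dplyr 1.0.0 or newer, where across() is available and funs() is deprecated, the same idea can be written with summarise() and across(). The name foo_across is hypothetical, chosen only to avoid overwriting foo, and tibble() stands in for data_frame().

    ```r
    library(dplyr)

    df <- tibble(a = 1:2, b = list(1:10, 4:40))

    # Hypothetical variant of foo; assumes dplyr >= 1.0.0,
    # which is not the version used in the rest of this answer
    foo_across <- function(.tbl, ...) {
      # c(...) forwards the bare column names to tidyselect inside across()
      summarise(.tbl, across(c(...), ~ mean(unlist(.x))))
    }

    res <- df %>% foo_across(a, b)
    ```

    Because it goes through summarise(), this variant should also compute one row per group when the input has been passed through group_by().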


    Bonus points: Do it all in base R!

    foo <- function(.tbl, ...){
        # collect dots as character vector
        cols <- as.character(substitute(list(...))[-1])
        cls <- class(.tbl)
    
        # handle grouped tibbles properly
        if('grouped_df' %in% cls){
            cls <- cls[which(cls != 'grouped_df')]    # drop grouping
            res <- aggregate(.tbl[cols], 
                             .tbl[attr(.tbl, 'vars')], 
                             FUN = function(x){mean(unlist(x))})
        } else {
            res <- as.data.frame(lapply(.tbl[cols], function(x){mean(unlist(x))}))
        }
    
        class(res) <- cls    # keep class (tibble, etc.)
        res
    }
    

    This works with list columns, grouped data, and multiple columns or groups; it keeps the class (tibble, etc.) but drops the grouping:

    df %>% foo(a, b)
    ## # A tibble: 1 × 2
    ##       a        b
    ##   <dbl>    <dbl>
    ## 1   1.5 18.48936
    
    df %>% group_by(a) %>% foo(b)
    ## # A tibble: 2 × 2
    ##       a     b
    ##   <int> <dbl>
    ## 1     1   5.5
    ## 2     2  22.0
    
    mtcars %>% foo(mpg, hp)
    ##        mpg       hp
    ## 1 20.09062 146.6875
    
    mtcars %>% group_by(cyl, am) %>% foo(hp, mpg)
    ## # A tibble: 6 × 4
    ##     cyl    am        hp      mpg
    ##   <dbl> <dbl>     <dbl>    <dbl>
    ## 1     4     0  84.66667 22.90000
    ## 2     6     0 115.25000 19.12500
    ## 3     8     0 194.16667 15.05000
    ## 4     4     1  81.87500 28.07500
    ## 5     6     1 131.66667 20.56667
    ## 6     8     1 299.50000 15.40000
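
    The step in the base R version that turns the dots into column names can be looked at in isolation. The helper name capture_cols below is hypothetical and used only for illustration; the expression inside it is the one from foo above.

    ```r
    # Hypothetical helper that isolates the dots-capturing step of foo
    capture_cols <- function(...) {
      # substitute(list(...)) gives the unevaluated call list(a, b);
      # [-1] drops the leading `list` symbol, leaving only the arguments,
      # which as.character() then converts to strings
      as.character(substitute(list(...))[-1])
    }

    cols <- capture_cols(a, b)
    ```

    Here cols is the character vector c("a", "b"), which is why .tbl[cols] can select the columns without a or b ever being evaluated. The approach is meant for bare column names.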