Search code examples
r parallel-processing dplyr multidplyr

multidplyr: trial custom function


I'm trying to learn how to run a custom function on a cluster through do() with multidplyr. Consider this simple, self-contained example: I want to apply my custom function myWxTest to each destination in common_dest (destinations with at least 365 flights) in the flights dataset:

library(dplyr)
library(multidplyr)
library(nycflights13)
library(quantreg)

# Median (tau = 0.5) quantile regression of arr_time on the departure and
# schedule columns for one group of flights.
#
# x: data frame that must expose dep_time, dep_delay, sched_dep_time,
#    sched_arr_time and arr_time (checked up front).
# Returns a named vector with one slot per model term. Every slot stays NA when
# the group has 5 or fewer rows; a term absent from coef() also stays NA.
myWxTest <- function(x) {
  # One stopifnot() evaluates the checks in order and stops at the first
  # column that is missing.
  stopifnot(
    !is.null(x$dep_time),
    !is.null(x$dep_delay),
    !is.null(x$sched_dep_time),
    !is.null(x$sched_arr_time),
    !is.null(x$arr_time)
  )

  term_names <- c(
    "(Intercept)", "dep_time", "dep_delay", "sched_dep_time", "sched_arr_time"
  )
  estimates <- setNames(rep(NA, length(term_names)), term_names)

  # Too few observations to fit: hand back the all-NA template.
  if (length(x$arr_time) <= 5) {
    return(estimates)
  }

  fit <- quantreg::rq(
    arr_time ~ dep_time + dep_delay + sched_dep_time + sched_arr_time,
    data = x,
    tau = .5
  )
  fitted_coefs <- coef(fit)
  estimates[names(fitted_coefs)] <- fitted_coefs
  estimates
}

# Keep only the flights whose destination has at least 365 flights in the
# data, then add each flight's day of the year as `yday`.
common_dest <- flights %>%
  count(dest) %>%
  filter(n >= 365) %>%
  # Join key spelled out so the filter does not depend on which column names
  # the two tables happen to share.
  semi_join(flights, ., by = "dest") %>%
  mutate(yday = lubridate::yday(ISOdate(year, month, day)))


# Create a cluster of size 2 and register it as the default cluster.
cluster <- create_cluster(2)
set_default_cluster(cluster)

# Partition the rows of common_dest by dest across that cluster.
by_dest <- partition(common_dest, dest, cluster = cluster)

# Make quantreg available on the nodes holding the partitions.
cluster_library(by_dest, "quantreg")

So far so good (but I'm just reproducing the examples from the vignette). Now, I have to send my custom function to each node:

# NOTE: this invokes myWxTest on each node with no arguments, which is what
# raises the 'argument "x" is missing' error quoted below.
cluster_call(cluster, myWxTest)

But I get:

Error in checkForRemoteErrors(lapply(cl, recvResult)) : 
  2 nodes produced errors; first error: argument "x" is missing, with no default

Eventually, I want to apply myWxTest to each subgroup:

# Intended final step: apply myWxTest to every dest group of by_dest.
# Inside do(), `.` is presumably the current group's rows (dplyr convention;
# confirm against the do() documentation).
models <- by_dest %>% 
          do(myWxTest(.))

Solution

  • I got it running with a couple of tweaks:

    library(dplyr)
    library(multidplyr)
    library(nycflights13)
    library(quantreg)
    
    # Median (tau = 0.5) quantile regression of arr_time on the departure and
    # schedule columns for one group of flights.
    #
    # x: data frame that must expose dep_time, dep_delay, sched_dep_time,
    #    sched_arr_time and arr_time (checked up front).
    # Returns a one-column data frame (column `out_mat`) with one row per model
    # term, so the result can be used inside do(). All rows are NA when the
    # group has 5 or fewer observations.
    myWxTest <- function(x) {
        # One stopifnot() evaluates the checks in order and stops at the first
        # column that is missing.
        stopifnot(
            !is.null(x$dep_time),
            !is.null(x$dep_delay),
            !is.null(x$sched_dep_time),
            !is.null(x$sched_arr_time),
            !is.null(x$arr_time)
        )

        # Named NA vector (not a matrix) with one slot per model term.
        out_mat <- setNames(
            rep(NA, 5),
            c("(Intercept)", "dep_time", "dep_delay", "sched_dep_time", "sched_arr_time")
        )

        if (length(x$arr_time) > 5) {
            fit <- quantreg::rq(
                arr_time ~ dep_time + dep_delay + sched_dep_time + sched_arr_time,
                data = x,
                tau = .5
            )
            fitted_coefs <- coef(fit)
            out_mat[names(fitted_coefs)] <- fitted_coefs
        }

        # Keep the local named `out_mat`: the result column takes its label
        # from this variable's name.
        as.data.frame(out_mat, stringsAsFactors = FALSE)
    }
    
    # Keep only the flights whose destination has at least 365 flights in the
    # data, then add each flight's day of the year as `yday`.
    common_dest <- flights %>%
        count(dest) %>%
        filter(n >= 365) %>%
        # Join key spelled out so the filter does not depend on which column
        # names the two tables happen to share.
        semi_join(flights, ., by = "dest") %>%
        mutate(yday = lubridate::yday(ISOdate(year, month, day)))
    
    # Partition the rows by dest. No cluster is passed, so presumably
    # multidplyr falls back to its default cluster (confirm).
    by_dest <- partition(common_dest, dest)

    # Each node needs quantreg and its own copy of myWxTest before do() runs.
    cluster_library(by_dest, "quantreg")
    cluster_copy(by_dest, myWxTest)

    # Run myWxTest per dest group on the nodes, then collect the results back
    # into the local session.
    models <- by_dest %>%
        do(myWxTest(.)) %>%
        collect()
    

    ...which returns a local data.frame:

    models
    #> Source: local data frame [390 x 2]
    #> Groups: dest [78]
    #> 
    #>     dest     out_mat
    #>    <chr>       <dbl>
    #> 1    CAK 156.5248953
    #> 2    CAK   0.9904261
    #> 3    CAK  -0.0767928
    #> 4    CAK  -0.3523211
    #> 5    CAK   0.3220386
    #> 6    DCA  74.5959035
    #> 7    DCA   0.2751917
    #> 8    DCA   1.0712483
    #> 9    DCA   0.2874165
    #> 10   DCA   0.4344960
    #> # ... with 380 more rows