Search code examples
parallel-processing, mlr

Error in a custom mlr measure that calls an external function when tuning in parallel


I defined a custom measure that transforms `prediction$data` with an external function before evaluating a standard measure such as rmse. Tuning the parameters without parallelization works fine, but as soon as I start a parallelized session the external function can no longer be found, although it is defined in the global environment.

library(compiler)
library(mlr)
library(parallelMap)
library(parallel)

# Define the inverse transformation applied to truth/response before scoring.
# x^2 is already vectorised, so it is NOT wrapped in Vectorize(): that would
# route every call through mapply(), which is slower and returns list()
# instead of numeric(0) for zero-length input.
inverse_fun <- function(x) {
  x^2
}
inverse_fun <- cmpfun(inverse_fun, options = list(suppressUndefined = TRUE))
# Redundant at top level (the assignment above already targets the global
# environment); kept so the definition stays global even if this code is
# evaluated inside a local scope.
assign("inverse_fun", inverse_fun, envir = .GlobalEnv)

tuning_criterion <- "rmse"

# Look up the built-in mlr measure by name. get() avoids eval(parse()), and
# mode = "list" skips any same-named *function* (e.g. another package's
# rmse()) that may sit earlier on the search path.
original_measure <- get(tuning_criterion, mode = "list")
stopifnot(inherits(original_measure, "Measure"))

# Measure function: map truth and response back to the original scale with
# inverse_fun(), then score them with original_measure.
# NOTE: inverse_fun and original_measure are free variables here. They are
# looked up in the global environment when the measure is evaluated, so any
# separate worker process that evaluates this measure needs its own copies.
transf_measure_fun <- function(task, model, pred, feats, extra.args) {
  pred$data$truth <- inverse_fun(pred$data$truth)
  pred$data$response <- inverse_fun(pred$data$response)
  original_measure$fun(task, model, pred, feats, extra.args)
}

transf_measure <- makeMeasure(
  id = "ii", name = "ccc",
  properties = original_measure$properties,
  minimize = original_measure$minimize,
  best = original_measure$best,
  worst = original_measure$worst,
  fun = transf_measure_fun
)
transf_measure <- setAggregation(transf_measure, original_measure$aggr)

# The same measure under several aggregations (test mean via the line above,
# plus test sd, train mean and train sd).
aggregated_measure <- list(
  transf_measure,
  setAggregation(transf_measure, test.sd),
  setAggregation(transf_measure, train.mean),
  setAggregation(transf_measure, train.sd)
)

# Fit a ksvm regression model and predict on the full Boston Housing task.
lrn.lm <- makeLearner("regr.ksvm")
mod.lm <- train(lrn.lm, bh.task)
task.pred.lm <- predict(mod.lm, task = bh.task)

# Apply inverse_fun by hand to a copy of the prediction, so that plain rmse
# on it can be compared with the custom measure on the untouched prediction.
inv_pred <- task.pred.lm
inv_pred$data[c("truth", "response")] <-
  lapply(inv_pred$data[c("truth", "response")], inverse_fun)

# Sanity check: the two printed values should match.
performance(task.pred.lm, transf_measure)
performance(inv_pred, rmse)

# Tuning setup: a 4 x 4 grid over the ksvm hyperparameters C and sigma,
# evaluated with 3-fold cross-validation.
discrete_ps <- makeParamSet(
  makeDiscreteParam("C", values = seq(0.5, 2, by = 0.5)),
  makeDiscreteParam("sigma", values = seq(0.5, 2, by = 0.5))
)
rdesc <- makeResampleDesc("CV", iters = 3L)
ctrl <- makeTuneControlGrid()

# Sequential tuning with the custom measure -- this works.
res <- tuneParams(
  lrn.lm,
  task = bh.task,
  resampling = rdesc,
  par.set = discrete_ps,
  control = ctrl,
  measures = transf_measure
)

# try with parallelization - doesn't work (in socket mode, see error below)
current_os <- Sys.info()[["sysname"]]  # detect OS
if (current_os == "Windows") {
  # Windows cannot fork, so use a socket cluster of separate R sessions.
  set.seed(1, "L'Ecuyer-CMRG")
  parallelStart(mode = "socket", cpus = detectCores(), show.info = FALSE)
  parallel::clusterSetRNGStream(iseed = 1)
} else if (current_os %in% c("Linux", "Darwin")) {
  # Linux and macOS ("Darwin") can fork the running R session.
  set.seed(1, "L'Ecuyer-CMRG")
  parallelStart(mode = "multicore", cpus = detectCores(), show.info = FALSE)
} else {
  cat('\n\n#### OS not recognized, check parallelization init\n\n')
}
# finally= guarantees the workers are shut down even when tuneParams() fails
# (as in the error reported below); otherwise parallelStop() is never reached
# and the cluster is left running.
res <- tryCatch(
  tuneParams(lrn.lm, task = bh.task, resampling = rdesc,
             par.set = discrete_ps, control = ctrl,
             measures = transf_measure),
  finally = parallelStop()
)

I get the following error:

Error in stopWithJobErrorMessages(inds, vcapply(result.list[inds], as.character)) : 
  Errors occurred in 16 slave jobs, displaying at most 10 of them:

00001: Error in inverse_fun(pred$data$truth) : 
  cannot find "inverse_fun"

I also tried to pass the function through `extra.args`, but that gives an error as well:

# Attempt: hand inverse_fun to the measure through extra.args instead of
# relying on a global variable.
original_measure = eval(parse(text=tuning_criterion))
# Measure function that reads the inverse transformation from
# extra.args$inv_fun and then scores with original_measure (which is still
# referenced as a free/global variable).
transf_measure_fun = function(task, model, pred, feats, extra.args){
  # transform back to original value
  pred$data$truth = extra.args$inv_fun(pred$data$truth)
  pred$data$response = extra.args$inv_fun(pred$data$response)
  return(original_measure$fun(task, model, pred, feats, extra.args))
}
transf_measure = makeMeasure(
  id = 'ii', name = 'ccc',
  properties = original_measure$properties,
  minimize = original_measure$minimize, best = original_measure$best, worst = original_measure$worst,
  # NOTE(review): this *calls* transf_measure_fun with only extra.args
  # supplied, so evaluating pred$data$truth inside it fails with
  # 'argument "pred" is missing' -- the error quoted below. Presumably the
  # intent was `fun = transf_measure_fun, extra.args = list(inv_fun =
  # inverse_fun)`; confirm against the makeMeasure documentation.
  fun = transf_measure_fun(extra.args = list(inv_fun = inverse_fun))
)

and I get `Error in FUN(X[[i]], ...) : argument "pred" is missing, with no default`.

Thanks in advance


Solution

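  • In socket mode the workers are fresh R sessions that do not share your global environment, so they cannot see `inverse_fun` or `original_measure`. You need to export these custom objects to the workers with `parallelMap::parallelExport()`: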

    library(mlr)
    #> Loading required package: ParamHelpers
    library(parallelMap)
    library(compiler)
    
    # Define the inverse transformation applied to truth/response before
    # scoring. x^2 is already vectorised, so it is NOT wrapped in
    # Vectorize(): that would route every call through mapply(), which is
    # slower and returns list() instead of numeric(0) for zero-length input.
    inverse_fun <- function(x) {
      x^2
    }
    inverse_fun <- cmpfun(inverse_fun, options = list(suppressUndefined = TRUE))
    # Redundant at top level (the assignment above already targets the
    # global environment); kept so the definition stays global even if this
    # code is evaluated inside a local scope.
    assign("inverse_fun", inverse_fun, envir = .GlobalEnv)
    
    tuning_criterion <- "rmse"

    # Look up the built-in mlr measure by name. get() avoids eval(parse()),
    # and mode = "list" skips any same-named *function* (e.g. another
    # package's rmse()) that may sit earlier on the search path.
    original_measure <- get(tuning_criterion, mode = "list")
    stopifnot(inherits(original_measure, "Measure"))

    # Measure function: map truth and response back to the original scale
    # with inverse_fun(), then score them with original_measure. Both are
    # free variables looked up in the global environment when the measure
    # runs, which is why socket workers need them exported with
    # parallelExport() before tuning.
    transf_measure_fun <- function(task, model, pred, feats, extra.args) {
      pred$data$truth <- inverse_fun(pred$data$truth)
      pred$data$response <- inverse_fun(pred$data$response)
      original_measure$fun(task, model, pred, feats, extra.args)
    }
    transf_measure <- makeMeasure(
      id = "ii", name = "ccc",
      properties = original_measure$properties,
      minimize = original_measure$minimize,
      best = original_measure$best,
      worst = original_measure$worst,
      fun = transf_measure_fun
    )

    transf_measure <- setAggregation(transf_measure, original_measure$aggr)
    
    # Learner and tuning setup: a 4 x 4 grid over the ksvm hyperparameters
    # C and sigma, scored by 3-fold cross-validation.
    discrete_ps <- makeParamSet(
      makeDiscreteParam("C", values = seq(0.5, 2, by = 0.5)),
      makeDiscreteParam("sigma", values = seq(0.5, 2, by = 0.5))
    )
    rdesc <- makeResampleDesc("CV", iters = 3L)
    ctrl <- makeTuneControlGrid()
    lrn.lm <- makeLearner("regr.ksvm")
    
    set.seed(1, "L'Ecuyer-CMRG")
    parallelStart(mode = "socket", cpus = 2, show.info = FALSE)
    # Socket workers are fresh R sessions: export every global object the
    # custom measure refers to, otherwise they fail with
    # 'cannot find "inverse_fun"'.
    parallelExport("inverse_fun", "original_measure")

    # finally= shuts the workers down even if tuning fails.
    res <- tryCatch(
      tuneParams(lrn.lm, task = bh.task, resampling = rdesc,
                 par.set = discrete_ps, control = ctrl,
                 measures = transf_measure),
      finally = parallelStop()
    )
    #> [Tune] Started tuning learner regr.ksvm for parameter set:
    #>           Type len Def      Constr Req Tunable Trafo
    #> C     discrete   -   - 0.5,1,1.5,2   -    TRUE     -
    #> sigma discrete   -   - 0.5,1,1.5,2   -    TRUE     -
    #> With control class: TuneControlGrid
    #> Imputation value: Inf
    #> [Tune] Result: C=2; sigma=0.5 : ii.test.rmse=270.8008465
    

    Created on 2019-10-08 by the reprex package (v0.3.0)

    Session info

    # Report the R version, platform and package versions used to produce
    # the output above.
    devtools::session_info()
    #> ─ Session info ──────────────────────────────────────────────────────────
    #>  setting  value                       
    #>  version  R version 3.6.1 (2019-07-05)
    #>  os       Arch Linux                  
    #>  system   x86_64, linux-gnu           
    #>  ui       X11                         
    #>  language (EN)                        
    #>  collate  en_US.UTF-8                 
    #>  ctype    en_US.UTF-8                 
    #>  tz       Europe/Berlin               
    #>  date     2019-10-08                  
    #> 
    #> ─ Packages ──────────────────────────────────────────────────────────────
    #>  ! package      * version     date       lib
    #>    assertthat     0.2.1       2019-03-21 [1]
    #>    backports      1.1.5       2019-10-02 [1]
    #>    BBmisc         1.11        2017-03-10 [1]
    #>    callr          3.3.2       2019-09-22 [1]
    #>    checkmate      1.9.4       2019-07-04 [1]
    #>    cli            1.1.0       2019-03-19 [1]
    #>    colorspace     1.4-1       2019-03-18 [1]
    #>    crayon         1.3.4       2017-09-16 [1]
    #>    data.table     1.12.4      2019-10-03 [1]
    #>    desc           1.2.0       2018-05-01 [1]
    #>    devtools       2.2.1       2019-09-24 [1]
    #>    digest         0.6.21      2019-09-20 [1]
    #>    dplyr          0.8.3       2019-07-04 [1]
    #>    ellipsis       0.3.0       2019-09-20 [1]
    #>    evaluate       0.14        2019-05-28 [1]
    #>    fastmatch      1.1-0       2017-01-28 [1]
    #>    fs             1.3.1       2019-05-06 [1]
    #>    ggplot2        3.2.1       2019-08-10 [1]
    #>    glue           1.3.1       2019-03-12 [1]
    #>    gtable         0.3.0       2019-03-25 [1]
    #>    highr          0.8         2019-03-20 [1]
    #>    htmltools      0.4.0       2019-10-04 [1]
    #>    kernlab        0.9-27      2018-08-10 [1]
    #>    knitr          1.25        2019-09-18 [1]
    #>    lattice        0.20-38     2018-11-04 [1]
    #>    lazyeval       0.2.2       2019-03-15 [1]
    #>    magrittr       1.5         2014-11-22 [1]
    #>    Matrix         1.2-17      2019-03-22 [1]
    #>    memoise        1.1.0       2017-04-21 [1]
    #>    mlr          * 2.15.0.9000 2019-10-08 [1]
    #>    munsell        0.5.0       2018-06-12 [1]
    #>    parallelMap  * 1.4         2019-05-17 [1]
    #>    ParamHelpers * 1.12        2019-01-18 [1]
    #>    pillar         1.4.2       2019-06-29 [1]
    #>    pkgbuild       1.0.5       2019-08-26 [1]
    #>    pkgconfig      2.0.3       2019-09-22 [1]
    #>    pkgload        1.0.2       2018-10-29 [1]
    #>    prettyunits    1.0.2       2015-07-13 [1]
    #>    processx       3.4.1       2019-07-18 [1]
    #>    ps             1.3.0       2018-12-21 [1]
    #>    purrr          0.3.2       2019-03-15 [1]
    #>    R6             2.4.0       2019-02-14 [1]
    #>    Rcpp           1.0.2       2019-07-25 [1]
    #>    remotes        2.1.0       2019-06-24 [1]
    #>    rlang          0.4.0       2019-06-25 [1]
    #>    rmarkdown      1.16        2019-10-01 [1]
    #>    rprojroot      1.3-2       2018-01-03 [1]
    #>    scales         1.0.0       2018-08-09 [1]
    #>    sessioninfo    1.1.1       2018-11-05 [1]
    #>    stringi        1.4.3       2019-03-12 [1]
    #>    stringr        1.4.0       2019-02-10 [1]
    #>  R survival       2.44-1.1    <NA>       [2]
    #>    testthat       2.2.1       2019-07-25 [1]
    #>    tibble         2.1.3       2019-06-06 [1]
    #>    tidyselect     0.2.5       2018-10-11 [1]
    #>    usethis        1.5.1.9000  2019-10-07 [1]
    #>    withr          2.1.2       2018-03-15 [1]
    #>    xfun           0.10        2019-10-01 [1]
    #>    XML            3.98-1.20   2019-06-06 [1]
    #>    yaml           2.2.0       2018-07-25 [1]
    #>  source                        
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.0)                
    #>  local                         
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  <NA>                          
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #>  Github (r-lib/usethis@3015465)
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.1)                
    #>  CRAN (R 3.6.0)                
    #>  CRAN (R 3.6.0)                
    #> 
    #> [1] /home/pjs/R/x86_64-pc-linux-gnu-library/3.6
    #> [2] /usr/lib/R/library
    #> 
    #>  R ── Package was removed from disk.