
Nested resampling + LASSO (regr.cvglmnet) using mlr


I am trying to run nested resampling with `regr.cvglmnet`, using 10 iterations for both the inner and the outer loop. The mlr tutorial shows how to do this with a wrapper function (https://mlr-org.github.io/mlr/articles/tutorial/devel/nested_resampling.html).

Compared with the tutorial code, I changed only two things: 1) "regr.cvglmnet" instead of the support vector machine (ksvm), and 2) the number of iterations for the inner and outer loops.

When I create `lrn` with `makeTuneWrapper()`, I get the error shown below. Could someone explain it to me? I am completely new to coding and machine learning, so I might have done something pretty stupid in the code...

ps = makeParamSet(
  makeDiscreteParam("C", values = 2^(-12:12)),
  makeDiscreteParam("sigma", values = 2^(-12:12))
)
ctrl = makeTuneControlGrid()
inner = makeResampleDesc("Subsample", iters = 10)
lrn = makeTuneWrapper("regr.cvglmnet", resampling = inner, par.set = ps, 
                      control = ctrl, show.info = FALSE)

# Error in checkTunerParset(learner, par.set, measures, control) : 
# Can only tune parameters for which learner parameters exist: C,sigma

### Outer resampling loop
outer = makeResampleDesc("CV", iters = 10) 
r = resample(lrn, iris.task, resampling = outer, extract = getTuneResult, 
             show.info = FALSE)

Solution

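  • The error occurs because `C` and `sigma` are hyperparameters of the support vector machine (`ksvm`) used in the tutorial; `regr.cvglmnet` has no parameters with those names, so `makeTuneWrapper()` refuses to tune them. (Note also that `iris.task` is a classification task and cannot be used with a `regr.*` learner.)

    When using LASSO with glmnet, you only need to tune `s`, the parameter used when the model predicts on new data. During training, glmnet fits a whole path of `lambda` values; setting `lambda` yourself only changes that grid of fitted solutions, not which one is used for prediction. That choice is made by `s`: if `s` does not coincide with one of the fitted `lambda` values, glmnet by default interpolates linearly between the two neighbouring solutions (it refits only when `predict()` is called with `exact = TRUE`).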
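
To check which names a learner accepts before building a tuning wrapper, one can list its parameter ids. This is a small sketch, assuming mlr and glmnet are installed; `getParamSet()` comes from mlr and `getParamIds()` from ParamHelpers, which mlr attaches.

```r
library(mlr)  # also attaches ParamHelpers

# Ids of all hyperparameters each learner exposes; only these can be tuned
ids_cv <- getParamIds(getParamSet(makeLearner("regr.cvglmnet")))
ids_glmnet <- getParamIds(getParamSet(makeLearner("regr.glmnet")))

# C and sigma are ksvm parameters, so they are absent from the glmnet learner
c("C", "sigma") %in% ids_cv   # FALSE FALSE
# s, the penalty used at prediction time, is exposed by regr.glmnet
"s" %in% ids_glmnet           # TRUE
```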

    By default, `train()` therefore fits models for a whole sequence of `lambda` values, and the solution used for prediction is picked on that path via `s`. In effect, the tuning happens in the prediction step.
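
The following sketch calls glmnet directly on simulated data (the data and the path indices 10 and 11 are arbitrary choices for illustration) to show that a single fit serves predictions for any `s`, and that with glmnet's default `exact = FALSE` an `s` lying between two fitted `lambda` values yields coefficients linearly interpolated between the neighbouring solutions.

```r
library(glmnet)

# Simulated data (arbitrary): 100 observations, 5 predictors
set.seed(1)
x <- matrix(rnorm(100 * 5), nrow = 100)
y <- drop(x %*% c(3, -2, 0, 0, 1) + rnorm(100))

# A single call fits LASSO (alpha = 1) along a whole sequence of lambda values
fit <- glmnet(x, y, alpha = 1)

# s picks the point on that path used for prediction; no new fit is made here
pred_weak <- predict(fit, newx = x[1:5, ], s = min(fit$lambda))
pred_strong <- predict(fit, newx = x[1:5, ], s = max(fit$lambda))
# At the largest lambda all slopes are zero, so every prediction equals mean(y)

# An s halfway between two fitted lambdas gives the average of their coefficients
s_mid <- mean(fit$lambda[10:11])
b_mid <- as.matrix(coef(fit, s = s_mid))
b_avg <- (as.matrix(coef(fit, s = fit$lambda[10])) +
          as.matrix(coef(fit, s = fit$lambda[11]))) / 2
all.equal(b_mid, b_avg)
```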

    Good default ranges for `s` can be chosen as follows:

    1. Train the model with glmnet's defaults.
    2. Check the minimum and maximum of the fitted `lambda` values.
    3. Use these as the lower and upper bounds for `s`, which is then tuned with mlr.

    See also this discussion.

    library(mlr)
    #> Loading required package: ParamHelpers
    
    lrn_glmnet <- makeLearner("regr.glmnet",
                              alpha = 1,
                              intercept = FALSE)
    
    # check lambda
    glmnet_train = mlr::train(lrn_glmnet, bh.task)
    summary(glmnet_train$learner.model$lambda)
    #>    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
    #>   143.5   157.4   172.8   174.3   189.6   208.1
    
    # set limits
    ps_glmnet <- makeParamSet(makeNumericParam("s", lower = 140, upper = 208))
    
    # tune params in parallel using a grid search for simplicity
    tune.ctrl = makeTuneControlGrid()
    inner <- makeResampleDesc("CV", iters = 2)
    
    configureMlr(on.learner.error = "warn", on.error.dump = TRUE)
    library(parallelMap)
    parallelStart(mode = "multicore", level = "mlr.tuneParams", cpus = 4,
                  mc.set.seed = TRUE) # only parallelize the tuning
    #> Starting parallelization in mode=multicore with cpus=4.
    set.seed(12345)
    params_tuned_glmnet = tuneParams(lrn_glmnet, task = bh.task, resampling = inner,
                                     par.set = ps_glmnet, control = tune.ctrl,
                                     measures = list(rmse))
    #> [Tune] Started tuning learner regr.glmnet for parameter set:
    #>      Type len Def     Constr Req Tunable Trafo
    #> s numeric   -   - 140 to 208   -    TRUE     -
    #> With control class: TuneControlGrid
    #> Imputation value: Inf
    #> Mapping in parallel: mode = multicore; cpus = 4; elements = 10.
    #> [Tune] Result: s=140 : rmse.test.rmse=17.9803086
    parallelStop()
    #> Stopped parallelization. All cleaned up.
    
    # train the model on the whole dataset using the `s` value from the tuning
    # (setHyperPars keeps alpha = 1 and intercept = FALSE and adds the tuned s)
    lrn_glmnet_tuned <- setHyperPars(lrn_glmnet, par.vals = params_tuned_glmnet$x)
    glmnet_train_tuned = mlr::train(lrn_glmnet_tuned, bh.task)
    

    Created on 2018-07-03 by the reprex package (v0.2.0).

    devtools::session_info()
    #> Session info -------------------------------------------------------------
    #>  setting  value                       
    #>  version  R version 3.5.0 (2018-04-23)
    #>  system   x86_64, linux-gnu           
    #>  ui       X11                         
    #>  language (EN)                        
    #>  collate  en_US.UTF-8                 
    #>  tz       Europe/Berlin               
    #>  date     2018-07-03
    #> Packages -----------------------------------------------------------------
    #>  package      * version   date       source         
    #>  backports      1.1.2     2017-12-13 CRAN (R 3.5.0) 
    #>  base         * 3.5.0     2018-06-04 local          
    #>  BBmisc         1.11      2017-03-10 CRAN (R 3.5.0) 
    #>  bit            1.1-14    2018-05-29 cran (@1.1-14) 
    #>  bit64          0.9-7     2017-05-08 CRAN (R 3.5.0) 
    #>  blob           1.1.1     2018-03-25 CRAN (R 3.5.0) 
    #>  checkmate      1.8.5     2017-10-24 CRAN (R 3.5.0) 
    #>  codetools      0.2-15    2016-10-05 CRAN (R 3.5.0) 
    #>  colorspace     1.3-2     2016-12-14 CRAN (R 3.5.0) 
    #>  compiler       3.5.0     2018-06-04 local          
    #>  data.table     1.11.4    2018-05-27 CRAN (R 3.5.0) 
    #>  datasets     * 3.5.0     2018-06-04 local          
    #>  DBI            1.0.0     2018-05-02 cran (@1.0.0)  
    #>  devtools       1.13.6    2018-06-27 CRAN (R 3.5.0) 
    #>  digest         0.6.15    2018-01-28 CRAN (R 3.5.0) 
    #>  evaluate       0.10.1    2017-06-24 CRAN (R 3.5.0) 
    #>  fastmatch      1.1-0     2017-01-28 CRAN (R 3.5.0) 
    #>  foreach        1.4.4     2017-12-12 CRAN (R 3.5.0) 
    #>  ggplot2        2.2.1     2016-12-30 CRAN (R 3.5.0) 
    #>  git2r          0.21.0    2018-01-04 CRAN (R 3.5.0) 
    #>  glmnet         2.0-16    2018-04-02 CRAN (R 3.5.0) 
    #>  graphics     * 3.5.0     2018-06-04 local          
    #>  grDevices    * 3.5.0     2018-06-04 local          
    #>  grid           3.5.0     2018-06-04 local          
    #>  gtable         0.2.0     2016-02-26 CRAN (R 3.5.0) 
    #>  htmltools      0.3.6     2017-04-28 CRAN (R 3.5.0) 
    #>  iterators      1.0.9     2017-12-12 CRAN (R 3.5.0) 
    #>  knitr          1.20      2018-02-20 CRAN (R 3.5.0) 
    #>  lattice        0.20-35   2017-03-25 CRAN (R 3.5.0) 
    #>  lazyeval       0.2.1     2017-10-29 CRAN (R 3.5.0) 
    #>  magrittr       1.5       2014-11-22 CRAN (R 3.5.0) 
    #>  Matrix         1.2-14    2018-04-09 CRAN (R 3.5.0) 
    #>  memoise        1.1.0     2017-04-21 CRAN (R 3.5.0) 
    #>  memuse         4.0-0     2017-11-10 CRAN (R 3.5.0) 
    #>  methods      * 3.5.0     2018-06-04 local          
    #>  mlr          * 2.13      2018-07-01 local          
    #>  munsell        0.5.0     2018-06-12 CRAN (R 3.5.0) 
    #>  parallel       3.5.0     2018-06-04 local          
    #>  parallelMap  * 1.3       2015-06-10 CRAN (R 3.5.0) 
    #>  ParamHelpers * 1.11      2018-06-25 CRAN (R 3.5.0) 
    #>  pillar         1.2.3     2018-05-25 CRAN (R 3.5.0) 
    #>  plyr           1.8.4     2016-06-08 CRAN (R 3.5.0) 
    #>  Rcpp           0.12.17   2018-05-18 cran (@0.12.17)
    #>  rlang          0.2.1     2018-05-30 CRAN (R 3.5.0) 
    #>  rmarkdown      1.10      2018-06-11 CRAN (R 3.5.0) 
    #>  rprojroot      1.3-2     2018-01-03 CRAN (R 3.5.0) 
    #>  RSQLite        2.1.1     2018-05-06 cran (@2.1.1)  
    #>  scales         0.5.0     2017-08-24 CRAN (R 3.5.0) 
    #>  splines        3.5.0     2018-06-04 local          
    #>  stats        * 3.5.0     2018-06-04 local          
    #>  stringi        1.2.3     2018-06-12 CRAN (R 3.5.0) 
    #>  stringr        1.3.1     2018-05-10 CRAN (R 3.5.0) 
    #>  survival       2.42-3    2018-04-16 CRAN (R 3.5.0) 
    #>  tibble         1.4.2     2018-01-22 CRAN (R 3.5.0) 
    #>  tools          3.5.0     2018-06-04 local          
    #>  utils        * 3.5.0     2018-06-04 local          
    #>  withr          2.1.2     2018-03-15 CRAN (R 3.5.0) 
    #>  XML            3.98-1.11 2018-04-16 CRAN (R 3.5.0) 
    #>  yaml           2.1.19    2018-05-01 CRAN (R 3.5.0)
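
The answer tunes `s` with `tuneParams()` only. The nested resampling the question asks for can be sketched by wrapping the same kind of learner in `makeTuneWrapper()` and passing it to `resample()`. This is a sketch, assuming mlr 2.x with glmnet installed: it keeps glmnet's default intercept, searches `s` on a log scale between the smallest and largest `lambda` of a default fit (the three steps above), and uses 10-fold CV in both loops. Deriving the bounds from the full task leaks a little information into the outer estimate; fixing them in advance avoids that.

```r
library(mlr)

# LASSO learner; s (the predict-time penalty) is what gets tuned
lrn_glmnet <- makeLearner("regr.glmnet", alpha = 1)

# Steps 1-2: fit with glmnet defaults and read off the lambda range
lambdas <- mlr::train(lrn_glmnet, bh.task)$learner.model$lambda

# Step 3: search s between those bounds, on a log scale via trafo
ps <- makeParamSet(
  makeNumericParam("s", lower = log(min(lambdas)), upper = log(max(lambdas)),
                   trafo = function(x) exp(x))
)
ctrl <- makeTuneControlGrid(resolution = 10L)

# Inner loop tunes s; outer loop estimates the performance of the whole procedure
inner <- makeResampleDesc("CV", iters = 10L)
outer <- makeResampleDesc("CV", iters = 10L)
lrn_tuned <- makeTuneWrapper(lrn_glmnet, resampling = inner, par.set = ps,
                             control = ctrl, measures = rmse, show.info = FALSE)

set.seed(123)
r <- resample(lrn_tuned, bh.task, resampling = outer, measures = rmse,
              extract = getTuneResult, show.info = FALSE)

r$aggr                                                   # outer-loop RMSE estimate
s_per_fold <- sapply(r$extract, function(tr) tr$x$s)    # s chosen in each outer fold
s_per_fold
```

Each outer fold may select a different `s`; the spread of `s_per_fold` indicates how stable the tuning is.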