Search code examples
mlr3

Setting `early_stopping_rounds` in xgboost learner using mlr3


I want to tune an xgboost learner and set the parameter `early_stopping_rounds` to 10% of the parameter `nrounds` (whatever value is sampled for `nrounds` in each evaluation). This should be simple in general (i.e., tuning one parameter relative to another), but I can't get it to work; see the example below:

library(mlr3verse)
#> Loading required package: mlr3

# Failing attempt: tune `nrounds` directly and derive `early_stopping_rounds`
# from it through a nested param set passed to `to_tune()`.
# Why it fails: the nested `ps()` defines only `a`, so the `x` handed to its
# `.extra_trafo` contains no `nrounds` element. `x$nrounds` is therefore NULL,
# `ceiling(0.1 * NULL)` is `numeric(0)`, and that is the "Bad value:
# numeric(0)" rejected for `early_stopping_rounds` in the error below.
learner = lrn('surv.xgboost', nrounds = to_tune(50, 5000),
  early_stopping_rounds = to_tune(ps(
    a = p_int(10,5000), # had to put something in here, `early_stopping_rounds` also doesn't work
    .extra_trafo = function(x, param_set) {
      list(early_stopping_rounds = ceiling(0.1 * x$nrounds))
  }, .allow_dangling_dependencies = TRUE)))
#> Error in self$assert(xs): Assertion on 'xs' failed: early_stopping_rounds: tune token invalid: to_tune(ps(a = p_int(10, 5000), .extra_trafo = function(x, param_set) {     list(early_stopping_rounds = ceiling(0.1 * x$nrounds)) }, .allow_dangling_dependencies = TRUE)) generates points that are not compatible with param early_stopping_rounds.
#> Bad value:
#> numeric(0)
#> Parameter:
#>                       id    class lower upper levels default
#> 1: early_stopping_rounds ParamInt     1   Inf               .

# this works though:
# Here `z` and `x` are defined in the same param set, so the trafo can read
# `x$x` and overwrite `z` with 2 * x. The sampled design printed below shows
# z == 2 * x in every row (z becomes double because of the multiplication).
pam = ps(z = p_int(-3,3), x = p_int(0,10),
  .extra_trafo = function(x, param_set) {
    x$z = 2*(x$x) # overwrite z as 2*x
    x
  })

# Draw 5 random configurations; `$transpose()` returns them as a list of
# per-configuration lists (with the trafo applied), which are then stacked.
dplyr::bind_rows(generate_design_random(pam, 5)$transpose())
#> # A tibble: 5 × 2
#>       z     x
#>   <dbl> <int>
#> 1     2     1
#> 2    14     7
#> 3     8     4
#> 4    12     6
#> 5    20    10

Created on 2022-08-29 by the reprex package (v2.0.1)


Solution

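  • Your code does not work because the `.extra_trafo` of the param set passed to `to_tune()` only sees that param set's own parameters (here just `a`). `x$nrounds` does not exist there, so it evaluates to `NULL`, and `ceiling(0.1 * NULL)` yields `numeric(0)`, which is exactly the bad value reported in the error.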

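    As a workaround, tune `nrounds` through a custom search space and compute `early_stopping_rounds` from it in that search space's `.extra_trafo`: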

    library(mlr3verse)
    #> Loading required package: mlr3
    
    # Search space with a single tuned parameter, `nrounds`. Its `.extra_trafo`
    # runs on every sampled configuration and adds `early_stopping_rounds` as
    # 10% of the sampled `nrounds`, rounded up and coerced to integer, so both
    # values are handed to the learner together.
    # NOTE(review): mlr3learners' xgboost presumably only applies
    # `early_stopping_rounds` when a monitoring set is configured on the
    # learner (`early_stopping_set` in older releases, `validate` in newer
    # ones). Runtime in the log below grows roughly linearly with `nrounds`,
    # which suggests every round ran. Confirm against the installed version.
    search_space <- ps(
      nrounds = p_int(lower = 50, upper = 5000),
      .extra_trafo = function(x, param_set) {
        n_rounds <- x[["nrounds"]]
        x[["early_stopping_rounds"]] <- as.integer(ceiling(0.1 * n_rounds))
        x
      }
    )
    
    # Tuning components: random search limited to 10 evaluations of xgboost
    # on the iris task.
    tuner <- tnr("random_search")
    terminator <- trm("evals", n_evals = 10)
    learner <- lrn("classif.xgboost")
    task <- tsk("iris")
    
    # Each configuration is scored on a holdout split by classification error.
    at <- AutoTuner$new(
      learner = learner,
      resampling = rsmp("holdout"),
      measure = msr("classif.ce"),
      search_space = search_space,
      terminator = terminator,
      tuner = tuner
    )
    
    at$train(task)
    #> INFO  [13:12:50.316] [bbotk] Starting to optimize 1 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=10, k=0]' 
    #> INFO  [13:12:50.351] [bbotk] Evaluating 1 configuration(s) 
    #> INFO  [13:12:50.406] [mlr3] Running benchmark with 1 resampling iterations 
    #> INFO  [13:12:50.441] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1) 
    #> INFO  [13:12:51.837] [mlr3] Finished benchmark 
    #> INFO  [13:12:51.865] [bbotk] Result of batch 1: 
    #> INFO  [13:12:51.867] [bbotk]  nrounds classif.ce warnings errors runtime_learners 
    #> INFO  [13:12:51.867] [bbotk]     3497          0        0      0            1.387 
    #> INFO  [13:12:51.867] [bbotk]                                 uhash 
    #> INFO  [13:12:51.867] [bbotk]  8a8e7d03-3166-4c03-8e06-78fe9f4e8a35 
    #> INFO  [13:12:51.870] [bbotk] Evaluating 1 configuration(s) 
    #> INFO  [13:12:51.918] [mlr3] Running benchmark with 1 resampling iterations 
    #> INFO  [13:12:51.926] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1) 
    #> INFO  [13:12:53.650] [mlr3] Finished benchmark 
    #> INFO  [13:12:53.680] [bbotk] Result of batch 2: 
    #> INFO  [13:12:53.681] [bbotk]  nrounds classif.ce warnings errors runtime_learners 
    #> INFO  [13:12:53.681] [bbotk]     4197          0        0      0            1.718 
    #> INFO  [13:12:53.681] [bbotk]                                 uhash 
    #> INFO  [13:12:53.681] [bbotk]  85c94228-4419-4e7e-8f4b-6e289a2d2900 
    #> INFO  [13:12:53.684] [bbotk] Evaluating 1 configuration(s) 
    #> INFO  [13:12:53.725] [mlr3] Running benchmark with 1 resampling iterations 
    #> INFO  [13:12:53.730] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1) 
    #> INFO  [13:12:54.648] [mlr3] Finished benchmark 
    #> INFO  [13:12:54.683] [bbotk] Result of batch 3: 
    #> INFO  [13:12:54.685] [bbotk]  nrounds classif.ce warnings errors runtime_learners 
    #> INFO  [13:12:54.685] [bbotk]     2199          0        0      0            0.911 
    #> INFO  [13:12:54.685] [bbotk]                                 uhash 
    #> INFO  [13:12:54.685] [bbotk]  cd33357f-13bf-4851-8da3-f3c1b58755a6 
    #> INFO  [13:12:54.687] [bbotk] Evaluating 1 configuration(s) 
    #> INFO  [13:12:54.727] [mlr3] Running benchmark with 1 resampling iterations 
    #> INFO  [13:12:54.732] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1) 
    #> INFO  [13:12:56.651] [mlr3] Finished benchmark 
    #> INFO  [13:12:56.679] [bbotk] Result of batch 4: 
    #> INFO  [13:12:56.681] [bbotk]  nrounds classif.ce warnings errors runtime_learners 
    #> INFO  [13:12:56.681] [bbotk]     4679          0        0      0            1.909 
    #> INFO  [13:12:56.681] [bbotk]                                 uhash 
    #> INFO  [13:12:56.681] [bbotk]  4efe832d-9163-4447-9e4c-5a41190de74c 
    #> INFO  [13:12:56.684] [bbotk] Evaluating 1 configuration(s) 
    #> INFO  [13:12:56.722] [mlr3] Running benchmark with 1 resampling iterations 
    #> INFO  [13:12:56.727] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1) 
    #> INFO  [13:12:57.850] [mlr3] Finished benchmark 
    #> INFO  [13:12:57.875] [bbotk] Result of batch 5: 
    #> INFO  [13:12:57.877] [bbotk]  nrounds classif.ce warnings errors runtime_learners 
    #> INFO  [13:12:57.877] [bbotk]     2422          0        0      0            1.116 
    #> INFO  [13:12:57.877] [bbotk]                                 uhash 
    #> INFO  [13:12:57.877] [bbotk]  8db417a2-0b6e-4844-9c07-4c83e899964e 
    #> INFO  [13:12:57.880] [bbotk] Evaluating 1 configuration(s) 
    #> INFO  [13:12:57.915] [mlr3] Running benchmark with 1 resampling iterations 
    #> INFO  [13:12:57.920] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1) 
    #> INFO  [13:12:59.769] [mlr3] Finished benchmark 
    #> INFO  [13:12:59.794] [bbotk] Result of batch 6: 
    #> INFO  [13:12:59.795] [bbotk]  nrounds classif.ce warnings errors runtime_learners 
    #> INFO  [13:12:59.795] [bbotk]     4721          0        0      0            1.843 
    #> INFO  [13:12:59.795] [bbotk]                                 uhash 
    #> INFO  [13:12:59.795] [bbotk]  d37d1ec0-bd89-408b-9c29-ecf657a9bbb5 
    #> INFO  [13:12:59.798] [bbotk] Evaluating 1 configuration(s) 
    #> INFO  [13:12:59.833] [mlr3] Running benchmark with 1 resampling iterations 
    #> INFO  [13:12:59.838] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1) 
    #> INFO  [13:13:00.336] [mlr3] Finished benchmark 
    #> INFO  [13:13:00.369] [bbotk] Result of batch 7: 
    #> INFO  [13:13:00.371] [bbotk]  nrounds classif.ce warnings errors runtime_learners 
    #> INFO  [13:13:00.371] [bbotk]     1323          0        0      0            0.491 
    #> INFO  [13:13:00.371] [bbotk]                                 uhash 
    #> INFO  [13:13:00.371] [bbotk]  89f100b9-2f9e-4c47-8734-9165dc215277 
    #> INFO  [13:13:00.374] [bbotk] Evaluating 1 configuration(s) 
    #> INFO  [13:13:00.412] [mlr3] Running benchmark with 1 resampling iterations 
    #> INFO  [13:13:00.417] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1) 
    #> INFO  [13:13:01.706] [mlr3] Finished benchmark 
    #> INFO  [13:13:01.736] [bbotk] Result of batch 8: 
    #> INFO  [13:13:01.737] [bbotk]  nrounds classif.ce warnings errors runtime_learners 
    #> INFO  [13:13:01.737] [bbotk]     3424          0        0      0            1.282 
    #> INFO  [13:13:01.737] [bbotk]                                 uhash 
    #> INFO  [13:13:01.737] [bbotk]  9f754641-fa5f-420a-b09a-32fe7512bb9b 
    #> INFO  [13:13:01.740] [bbotk] Evaluating 1 configuration(s) 
    #> INFO  [13:13:01.784] [mlr3] Running benchmark with 1 resampling iterations 
    #> INFO  [13:13:01.789] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1) 
    #> INFO  [13:13:03.160] [mlr3] Finished benchmark 
    #> INFO  [13:13:03.189] [bbotk] Result of batch 9: 
    #> INFO  [13:13:03.191] [bbotk]  nrounds classif.ce warnings errors runtime_learners 
    #> INFO  [13:13:03.191] [bbotk]     3432          0        0      0            1.365 
    #> INFO  [13:13:03.191] [bbotk]                                 uhash 
    #> INFO  [13:13:03.191] [bbotk]  47cfe02f-fd4e-4382-9343-b4c4ac274d91 
    #> INFO  [13:13:03.194] [bbotk] Evaluating 1 configuration(s) 
    #> INFO  [13:13:03.232] [mlr3] Running benchmark with 1 resampling iterations 
    #> INFO  [13:13:03.237] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1) 
    #> INFO  [13:13:04.387] [mlr3] Finished benchmark 
    #> INFO  [13:13:04.413] [bbotk] Result of batch 10: 
    #> INFO  [13:13:04.415] [bbotk]  nrounds classif.ce warnings errors runtime_learners 
    #> INFO  [13:13:04.415] [bbotk]     2991          0        0      0            1.142 
    #> INFO  [13:13:04.415] [bbotk]                                 uhash 
    #> INFO  [13:13:04.415] [bbotk]  a1b9d503-0dae-4c5d-ba50-ffd27a754032 
    #> INFO  [13:13:04.421] [bbotk] Finished optimizing after 10 evaluation(s) 
    #> INFO  [13:13:04.422] [bbotk] Result: 
    #> INFO  [13:13:04.423] [bbotk]  nrounds learner_param_vals  x_domain classif.ce 
    #> INFO  [13:13:04.423] [bbotk]     3497          <list[4]> <list[2]>          0
    

    Created on 2022-08-29 by the reprex package (v2.0.1)