I want to tune an xgboost learner and set the parameter `early_stopping_rounds` to 10% of the parameter `nrounds` (whichever value is generated each time).
This should be a simple thing to do in general (i.e. tuning one parameter relative to another), but I can't make it work; see the example below:
library(mlr3verse)
#> Loading required package: mlr3
# Attempt: tune `nrounds` and derive `early_stopping_rounds` as 10% of it by
# giving `early_stopping_rounds` its own nested search space with a trafo.
# This call fails with the assertion error shown in the output below.
learner = lrn('surv.xgboost', nrounds = to_tune(50, 5000),
early_stopping_rounds = to_tune(ps(
a = p_int(10,5000), # had to put something in here, `early_stopping_rounds` also doesn't work
.extra_trafo = function(x, param_set) {
# NOTE(review): `x` presumably holds only the parameters of this inner ps()
# (i.e. `a`), so `x$nrounds` would be NULL and the product numeric(0), which
# matches the "Bad value: numeric(0)" reported in the error output.
list(early_stopping_rounds = ceiling(0.1 * x$nrounds))
}, .allow_dangling_dependencies = TRUE)))
#> Error in self$assert(xs): Assertion on 'xs' failed: early_stopping_rounds: tune token invalid: to_tune(ps(a = p_int(10, 5000), .extra_trafo = function(x, param_set) { list(early_stopping_rounds = ceiling(0.1 * x$nrounds)) }, .allow_dangling_dependencies = TRUE)) generates points that are not compatible with param early_stopping_rounds.
#> Bad value:
#> numeric(0)
#> Parameter:
#> id class lower upper levels default
#> 1: early_stopping_rounds ParamInt 1 Inf .
# this works though:
# Search space with two integer parameters. The extra trafo discards the
# sampled `z` and replaces it with twice the sampled `x`, so `z` is always
# derived from `x` within the same configuration.
pam <- ps(
  z = p_int(lower = -3, upper = 3),
  x = p_int(lower = 0, upper = 10),
  .extra_trafo = function(x, param_set) {
    x[["z"]] <- 2 * x[["x"]]
    x
  }
)
# Draw five random configurations, apply the trafo to each via transpose(),
# and stack the resulting lists into one tibble.
dplyr::bind_rows(
  generate_design_random(pam, n = 5)$transpose()
)
#> # A tibble: 5 × 2
#> z x
#> <dbl> <int>
#> 1 2 1
#> 2 14 7
#> 3 8 4
#> 4 12 6
#> 5 20 10
Created on 2022-08-29 by the reprex package (v2.0.1)
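Your solution does not work because the trafo reads `x$nrounds`, but `nrounds` does not exist in the parameter set that the trafo receives.
You can use the following as a workaround: tune only `nrounds` in the search space and derive `early_stopping_rounds` from it in the `.extra_trafo`.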
library(mlr3verse)
#> Loading required package: mlr3
# Only `nrounds` is part of the search space. The extra trafo runs on each
# sampled configuration and adds `early_stopping_rounds` as 10% of the sampled
# `nrounds` (rounded up, stored as an integer) before it reaches the learner.
search_space <- ps(
  nrounds = p_int(lower = 50, upper = 5000),
  .extra_trafo = function(x, param_set) {
    stop_after <- ceiling(0.1 * x$nrounds)
    x[["early_stopping_rounds"]] <- as.integer(stop_after)
    x
  }
)

# Tuning setup: random search, stopped after 10 evaluations, scored by
# classification error on a holdout split.
task <- tsk("iris")
learner <- lrn("classif.xgboost")
tuner <- tnr("random_search")
terminator <- trm("evals", n_evals = 10)

at <- AutoTuner$new(
  learner = learner,
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  search_space = search_space,
  terminator = terminator,
  tuner = tuner
)

# Run the tuning on the task.
at$train(task)
#> INFO [13:12:50.316] [bbotk] Starting to optimize 1 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=10, k=0]'
#> INFO [13:12:50.351] [bbotk] Evaluating 1 configuration(s)
#> INFO [13:12:50.406] [mlr3] Running benchmark with 1 resampling iterations
#> INFO [13:12:50.441] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1)
#> INFO [13:12:51.837] [mlr3] Finished benchmark
#> INFO [13:12:51.865] [bbotk] Result of batch 1:
#> INFO [13:12:51.867] [bbotk] nrounds classif.ce warnings errors runtime_learners
#> INFO [13:12:51.867] [bbotk] 3497 0 0 0 1.387
#> INFO [13:12:51.867] [bbotk] uhash
#> INFO [13:12:51.867] [bbotk] 8a8e7d03-3166-4c03-8e06-78fe9f4e8a35
#> INFO [13:12:51.870] [bbotk] Evaluating 1 configuration(s)
#> INFO [13:12:51.918] [mlr3] Running benchmark with 1 resampling iterations
#> INFO [13:12:51.926] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1)
#> INFO [13:12:53.650] [mlr3] Finished benchmark
#> INFO [13:12:53.680] [bbotk] Result of batch 2:
#> INFO [13:12:53.681] [bbotk] nrounds classif.ce warnings errors runtime_learners
#> INFO [13:12:53.681] [bbotk] 4197 0 0 0 1.718
#> INFO [13:12:53.681] [bbotk] uhash
#> INFO [13:12:53.681] [bbotk] 85c94228-4419-4e7e-8f4b-6e289a2d2900
#> INFO [13:12:53.684] [bbotk] Evaluating 1 configuration(s)
#> INFO [13:12:53.725] [mlr3] Running benchmark with 1 resampling iterations
#> INFO [13:12:53.730] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1)
#> INFO [13:12:54.648] [mlr3] Finished benchmark
#> INFO [13:12:54.683] [bbotk] Result of batch 3:
#> INFO [13:12:54.685] [bbotk] nrounds classif.ce warnings errors runtime_learners
#> INFO [13:12:54.685] [bbotk] 2199 0 0 0 0.911
#> INFO [13:12:54.685] [bbotk] uhash
#> INFO [13:12:54.685] [bbotk] cd33357f-13bf-4851-8da3-f3c1b58755a6
#> INFO [13:12:54.687] [bbotk] Evaluating 1 configuration(s)
#> INFO [13:12:54.727] [mlr3] Running benchmark with 1 resampling iterations
#> INFO [13:12:54.732] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1)
#> INFO [13:12:56.651] [mlr3] Finished benchmark
#> INFO [13:12:56.679] [bbotk] Result of batch 4:
#> INFO [13:12:56.681] [bbotk] nrounds classif.ce warnings errors runtime_learners
#> INFO [13:12:56.681] [bbotk] 4679 0 0 0 1.909
#> INFO [13:12:56.681] [bbotk] uhash
#> INFO [13:12:56.681] [bbotk] 4efe832d-9163-4447-9e4c-5a41190de74c
#> INFO [13:12:56.684] [bbotk] Evaluating 1 configuration(s)
#> INFO [13:12:56.722] [mlr3] Running benchmark with 1 resampling iterations
#> INFO [13:12:56.727] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1)
#> INFO [13:12:57.850] [mlr3] Finished benchmark
#> INFO [13:12:57.875] [bbotk] Result of batch 5:
#> INFO [13:12:57.877] [bbotk] nrounds classif.ce warnings errors runtime_learners
#> INFO [13:12:57.877] [bbotk] 2422 0 0 0 1.116
#> INFO [13:12:57.877] [bbotk] uhash
#> INFO [13:12:57.877] [bbotk] 8db417a2-0b6e-4844-9c07-4c83e899964e
#> INFO [13:12:57.880] [bbotk] Evaluating 1 configuration(s)
#> INFO [13:12:57.915] [mlr3] Running benchmark with 1 resampling iterations
#> INFO [13:12:57.920] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1)
#> INFO [13:12:59.769] [mlr3] Finished benchmark
#> INFO [13:12:59.794] [bbotk] Result of batch 6:
#> INFO [13:12:59.795] [bbotk] nrounds classif.ce warnings errors runtime_learners
#> INFO [13:12:59.795] [bbotk] 4721 0 0 0 1.843
#> INFO [13:12:59.795] [bbotk] uhash
#> INFO [13:12:59.795] [bbotk] d37d1ec0-bd89-408b-9c29-ecf657a9bbb5
#> INFO [13:12:59.798] [bbotk] Evaluating 1 configuration(s)
#> INFO [13:12:59.833] [mlr3] Running benchmark with 1 resampling iterations
#> INFO [13:12:59.838] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1)
#> INFO [13:13:00.336] [mlr3] Finished benchmark
#> INFO [13:13:00.369] [bbotk] Result of batch 7:
#> INFO [13:13:00.371] [bbotk] nrounds classif.ce warnings errors runtime_learners
#> INFO [13:13:00.371] [bbotk] 1323 0 0 0 0.491
#> INFO [13:13:00.371] [bbotk] uhash
#> INFO [13:13:00.371] [bbotk] 89f100b9-2f9e-4c47-8734-9165dc215277
#> INFO [13:13:00.374] [bbotk] Evaluating 1 configuration(s)
#> INFO [13:13:00.412] [mlr3] Running benchmark with 1 resampling iterations
#> INFO [13:13:00.417] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1)
#> INFO [13:13:01.706] [mlr3] Finished benchmark
#> INFO [13:13:01.736] [bbotk] Result of batch 8:
#> INFO [13:13:01.737] [bbotk] nrounds classif.ce warnings errors runtime_learners
#> INFO [13:13:01.737] [bbotk] 3424 0 0 0 1.282
#> INFO [13:13:01.737] [bbotk] uhash
#> INFO [13:13:01.737] [bbotk] 9f754641-fa5f-420a-b09a-32fe7512bb9b
#> INFO [13:13:01.740] [bbotk] Evaluating 1 configuration(s)
#> INFO [13:13:01.784] [mlr3] Running benchmark with 1 resampling iterations
#> INFO [13:13:01.789] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1)
#> INFO [13:13:03.160] [mlr3] Finished benchmark
#> INFO [13:13:03.189] [bbotk] Result of batch 9:
#> INFO [13:13:03.191] [bbotk] nrounds classif.ce warnings errors runtime_learners
#> INFO [13:13:03.191] [bbotk] 3432 0 0 0 1.365
#> INFO [13:13:03.191] [bbotk] uhash
#> INFO [13:13:03.191] [bbotk] 47cfe02f-fd4e-4382-9343-b4c4ac274d91
#> INFO [13:13:03.194] [bbotk] Evaluating 1 configuration(s)
#> INFO [13:13:03.232] [mlr3] Running benchmark with 1 resampling iterations
#> INFO [13:13:03.237] [mlr3] Applying learner 'classif.xgboost' on task 'iris' (iter 1/1)
#> INFO [13:13:04.387] [mlr3] Finished benchmark
#> INFO [13:13:04.413] [bbotk] Result of batch 10:
#> INFO [13:13:04.415] [bbotk] nrounds classif.ce warnings errors runtime_learners
#> INFO [13:13:04.415] [bbotk] 2991 0 0 0 1.142
#> INFO [13:13:04.415] [bbotk] uhash
#> INFO [13:13:04.415] [bbotk] a1b9d503-0dae-4c5d-ba50-ffd27a754032
#> INFO [13:13:04.421] [bbotk] Finished optimizing after 10 evaluation(s)
#> INFO [13:13:04.422] [bbotk] Result:
#> INFO [13:13:04.423] [bbotk] nrounds learner_param_vals x_domain classif.ce
#> INFO [13:13:04.423] [bbotk] 3497 <list[4]> <list[2]> 0
Created on 2022-08-29 by the reprex package (v2.0.1)