mlr3

How do I tune a random forest using the OOB error?


Instead of running a CV and training the random forest multiple times, I would like to use the OOB error as an unbiased estimate of the generalization error.
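
For context, here is how an OOB estimate is formed. Each model in a bagged ensemble is fit on a bootstrap sample. Each row is then predicted only by the models whose sample did not contain it, and the squared errors of those predictions are averaged. This needs a single fit of the ensemble instead of one fit per CV fold. The base-R sketch below uses lm() as a stand-in base learner, which is an assumption for illustration. ranger grows regression trees and reports the corresponding value itself as prediction.error. The function name oob_mse() is hypothetical.

```r
# Minimal OOB-error sketch for a bagged ensemble, base R only.
# Assumption: lm() stands in for ranger's regression trees; the OOB
# bookkeeping illustrates the idea, it is not ranger's implementation.
oob_mse = function(formula, data, n_models = 200, seed = 1) {
  set.seed(seed)
  n = nrow(data)
  y = model.response(model.frame(formula, data))
  pred_sum = numeric(n)  # running sum of OOB predictions per row
  pred_cnt = integer(n)  # number of models for which the row was OOB
  for (b in seq_len(n_models)) {
    in_bag = sample.int(n, n, replace = TRUE)  # bootstrap sample
    oob = setdiff(seq_len(n), in_bag)          # rows this model never saw
    if (length(oob) == 0) next
    fit = lm(formula, data = data[in_bag, , drop = FALSE])
    pred_sum[oob] = pred_sum[oob] + predict(fit, newdata = data[oob, , drop = FALSE])
    pred_cnt[oob] = pred_cnt[oob] + 1L
  }
  used = pred_cnt > 0  # rows that were OOB at least once
  mean((y[used] - pred_sum[used] / pred_cnt[used])^2)
}

oob_mse(mpg ~ cyl, mtcars)
```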

And with few data points (in the low thousands), does it make sense to use the OOB error instead of CV, given that only a few data points might be OOB?
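
On the concern about few data points: the number of trees that leave out a given observation depends mainly on the sampling scheme, not on n. With bootstrapping, all n draws miss a fixed row with probability (1 - 1/n)^n, which is about e^-1 ≈ 0.37. The small base-R calculation below assumes each tree draws about sample.fraction * n rows. That is a simplification of ranger's sampling, and the helper p_oob() is hypothetical.

```r
# Probability that a fixed observation is out-of-bag (OOB) for a single tree.
# Assumption: each tree draws about sample_fraction * n of the n rows,
# with or without replacement (a simplification of ranger's sampling).
p_oob = function(n, sample_fraction, replace) {
  m = round(sample_fraction * n)  # rows drawn per tree
  if (replace) {
    (1 - 1 / n)^m                 # every one of the m draws misses the row
  } else {
    (n - m) / n                   # the row is not among the m distinct rows drawn
  }
}

n = 2000         # "low thousands" of observations
num_trees = 500  # ranger's default num.trees
c(
  bootstrap = num_trees * p_oob(n, 1, TRUE),      # replace = TRUE, sample.fraction = 1
  subsample = num_trees * p_oob(n, 0.632, FALSE)  # replace = FALSE, sample.fraction = 0.632
)
```

For n = 2000 and 500 trees, both settings leave each row out of roughly 184 trees. Lowering sample.fraction to 0.1, the lower bound of the search space in the question, raises the per-tree OOB fraction to about 0.9.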

So far, the only thing I could find about this is the mlr issue thread https://github.com/mlr-org/mlr/issues/338. I think it suggests using a holdout split with almost all of the data in the training set.

I also found the insample resampling method https://mlr3.mlr-org.com/reference/mlr_resamplings_insample.html, which uses the same data for training and testing.

This is my code:

learner = as_learner(
  po("select", selector=selector_name(selection)) %>>% po("learner", learner=lrn("regr.ranger"))
)


sp = ps(
  regr.ranger.mtry.ratio = p_dbl(0, 1),
  regr.ranger.replace = p_fct(c(TRUE, FALSE)),
  regr.ranger.sample.fraction = p_dbl(0.1, 1),
  regr.ranger.num.trees = p_int(1, 2000)
)

at = auto_tuner(
  resampling = rsmp("insample"),
  method = "random_search",
  learner = learner,
  measure = msr("oob_error"),
  term_evals = 5,
  search_space=sp
)

learners = c(at)
resamplings = rsmp("cv", folds = 5)

design = benchmark_grid(task, learners, resamplings)
bmr = benchmark(design)

However, running the code above fails with the error: Error in learner$oob_error() : attempt to apply non-function
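
The error message itself is generic R behavior. The minimal base-R sketch below shows where it comes from, assuming only that mlr3 learners are R6 objects, which are environments. Looking up a method the object does not have yields NULL, and calling NULL fails with exactly this message. The object learner_like and the helper has_method() are hypothetical stand-ins, not part of mlr3.

```r
# R6 objects (such as mlr3 learners) are environments, and `$` on an
# environment returns NULL for a name that does not exist. Calling that
# NULL is what triggers "attempt to apply non-function".
learner_like = new.env()                   # hypothetical stand-in object
learner_like$train = function() "trained"  # it has a train() method ...
                                           # ... but no oob_error() method
is.null(learner_like$oob_error)            # TRUE: the field is missing

msg = tryCatch(learner_like$oob_error(), error = conditionMessage)
msg                                        # "attempt to apply non-function"

# A defensive check before calling an optional method:
has_method = function(obj, name) is.function(obj[[name]])
has_method(learner_like, "train")          # TRUE
has_method(learner_like, "oob_error")      # FALSE
```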


Solution

  • The problem is that the resulting GraphLearner no longer exposes the oob_error() method of the wrapped ranger learner. This is similar to the issue here:

    https://github.com/mlr-org/mlr3pipelines/issues/291

    Edit: Added a workaround.

    Note that this is a workaround rather than a proper fix.

    The idea is to write a custom measure, as mentioned in the comments. A tutorial on writing custom measures can be found in the mlr3 book.

    This custom measure only works in this specific case because it is tailored to the structure of this GraphLearner: it looks up the ranger model under the PipeOp id regr.ranger. For a different learner or pipeline, the measure would have to be adjusted.

    library(mlr3verse)
    #> Loading required package: mlr3
    
    task = tsk("mtcars")
    
    selection = c("mpg", "cyl")
    
    learner = as_learner(
      po("select", selector = selector_name(selection)) %>>% po("learner", learner = lrn("regr.ranger"))
    )
    
    
    sp = ps(
      regr.ranger.mtry.ratio = p_dbl(0, 1),
      regr.ranger.replace = p_fct(c(TRUE, FALSE)),
      regr.ranger.sample.fraction = p_dbl(0.1, 1),
      regr.ranger.num.trees = p_int(1, 2000)
    )
    
    
    MyMeasure = R6::R6Class(
      "MyMeasure",
      inherit = MeasureRegr,
      public = list(
        initialize = function() {
          super$initialize(
            id = "MyMeasure",
            range = c(0, Inf),  # an OOB mean squared error cannot be negative
            minimize = TRUE,
            predict_type = "response",
            properties = "requires_learner"
          )
        }
      ),
      private = list(
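        .score = function(prediction, learner, ...) {
          # learner is the trained GraphLearner; state$model has one entry per PipeOp id
          model = learner$state$model$regr.ranger
          if (is.null(model)) stop("Set store_models = TRUE.")
          # for regression, ranger's prediction.error is the OOB mean squared error
          model$model$prediction.error
        }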
      )
    )
    
    
    
    at = auto_tuner(
      resampling = rsmp("insample"),
      method = "random_search",
      learner = learner,
      measure = MyMeasure$new(),
      term_evals = 1,
      search_space = sp,
      store_models = TRUE
    )
    
    learners = c(at)
    resamplings = rsmp("cv", folds = 5)
    
    design = benchmark_grid(task, learners, resamplings)
    
    lgr::get_logger("mlr3")$set_threshold(NULL)
    lgr::get_logger("mlr3tuning")$set_threshold(NULL)
    
    
    bmr = benchmark(design)
    #> INFO  [23:28:45.638] [mlr3] Running benchmark with 5 resampling iterations
    #> INFO  [23:28:45.740] [mlr3] Applying learner 'select.regr.ranger.tuned' on task 'mtcars' (iter 1/5)
    #> INFO  [23:28:47.112] [bbotk] Starting to optimize 4 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=1, k=0]'
    #> INFO  [23:28:47.158] [bbotk] Evaluating 1 configuration(s)
    #> INFO  [23:28:47.201] [mlr3] Running benchmark with 1 resampling iterations
    #> INFO  [23:28:47.209] [mlr3] Applying learner 'select.regr.ranger' on task 'mtcars' (iter 1/1)
    #> INFO  [23:28:47.346] [mlr3] Finished benchmark
    #> INFO  [23:28:47.419] [bbotk] Result of batch 1:
    #> INFO  [23:28:47.424] [bbotk]  regr.ranger.mtry.ratio regr.ranger.replace regr.ranger.sample.fraction
    #> INFO  [23:28:47.424] [bbotk]               0.5708216                TRUE                   0.4830289
    #> INFO  [23:28:47.424] [bbotk]  regr.ranger.num.trees MyMeasure warnings errors runtime_learners
    #> INFO  [23:28:47.424] [bbotk]                   1209  11.39842        0      0            0.124
    #> INFO  [23:28:47.424] [bbotk]                                 uhash
    #> INFO  [23:28:47.424] [bbotk]  abfcaa2f-8b01-4821-8e8b-1d209fbe2229
    #> INFO  [23:28:47.444] [bbotk] Finished optimizing after 1 evaluation(s)
    #> INFO  [23:28:47.445] [bbotk] Result:
    #> INFO  [23:28:47.447] [bbotk]  regr.ranger.mtry.ratio regr.ranger.replace regr.ranger.sample.fraction
    #> INFO  [23:28:47.447] [bbotk]               0.5708216                TRUE                   0.4830289
    #> INFO  [23:28:47.447] [bbotk]  regr.ranger.num.trees learner_param_vals  x_domain MyMeasure
    #> INFO  [23:28:47.447] [bbotk]                   1209          <list[6]> <list[4]>  11.39842
    #> INFO  [23:28:47.616] [mlr3] Applying learner 'select.regr.ranger.tuned' on task 'mtcars' (iter 2/5)
    #> INFO  [23:28:47.733] [bbotk] Starting to optimize 4 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=1, k=0]'
    #> INFO  [23:28:47.758] [bbotk] Evaluating 1 configuration(s)
    #> INFO  [23:28:47.799] [mlr3] Running benchmark with 1 resampling iterations
    #> INFO  [23:28:47.807] [mlr3] Applying learner 'select.regr.ranger' on task 'mtcars' (iter 1/1)
    #> INFO  [23:28:47.900] [mlr3] Finished benchmark
    #> INFO  [23:28:47.969] [bbotk] Result of batch 1:
    #> INFO  [23:28:47.971] [bbotk]  regr.ranger.mtry.ratio regr.ranger.replace regr.ranger.sample.fraction
    #> INFO  [23:28:47.971] [bbotk]               0.9683787               FALSE                   0.4303312
    #> INFO  [23:28:47.971] [bbotk]  regr.ranger.num.trees MyMeasure warnings errors runtime_learners
    #> INFO  [23:28:47.971] [bbotk]                    112  9.594568        0      0            0.084
    #> INFO  [23:28:47.971] [bbotk]                                 uhash
    #> INFO  [23:28:47.971] [bbotk]  4bb2742b-49e2-4b02-adc4-ffaa70aef8d4
    #> INFO  [23:28:47.984] [bbotk] Finished optimizing after 1 evaluation(s)
    #> INFO  [23:28:47.984] [bbotk] Result:
    #> INFO  [23:28:47.986] [bbotk]  regr.ranger.mtry.ratio regr.ranger.replace regr.ranger.sample.fraction
    #> INFO  [23:28:47.986] [bbotk]               0.9683787               FALSE                   0.4303312
    #> INFO  [23:28:47.986] [bbotk]  regr.ranger.num.trees learner_param_vals  x_domain MyMeasure
    #> INFO  [23:28:47.986] [bbotk]                    112          <list[6]> <list[4]>  9.594568
    #> INFO  [23:28:48.116] [mlr3] Applying learner 'select.regr.ranger.tuned' on task 'mtcars' (iter 3/5)
    #> INFO  [23:28:48.241] [bbotk] Starting to optimize 4 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=1, k=0]'
    #> INFO  [23:28:48.266] [bbotk] Evaluating 1 configuration(s)
    #> INFO  [23:28:48.308] [mlr3] Running benchmark with 1 resampling iterations
    #> INFO  [23:28:48.316] [mlr3] Applying learner 'select.regr.ranger' on task 'mtcars' (iter 1/1)
    #> INFO  [23:28:48.413] [mlr3] Finished benchmark
    #> INFO  [23:28:48.480] [bbotk] Result of batch 1:
    #> INFO  [23:28:48.483] [bbotk]  regr.ranger.mtry.ratio regr.ranger.replace regr.ranger.sample.fraction
    #> INFO  [23:28:48.483] [bbotk]               0.4089994                TRUE                   0.1780138
    #> INFO  [23:28:48.483] [bbotk]  regr.ranger.num.trees MyMeasure warnings errors runtime_learners
    #> INFO  [23:28:48.483] [bbotk]                    620  38.86261        0      0            0.089
    #> INFO  [23:28:48.483] [bbotk]                                 uhash
    #> INFO  [23:28:48.483] [bbotk]  9b47bdb0-15dc-421d-9091-db2e6c41cbee
    #> INFO  [23:28:48.495] [bbotk] Finished optimizing after 1 evaluation(s)
    #> INFO  [23:28:48.496] [bbotk] Result:
    #> INFO  [23:28:48.498] [bbotk]  regr.ranger.mtry.ratio regr.ranger.replace regr.ranger.sample.fraction
    #> INFO  [23:28:48.498] [bbotk]               0.4089994                TRUE                   0.1780138
    #> INFO  [23:28:48.498] [bbotk]  regr.ranger.num.trees learner_param_vals  x_domain MyMeasure
    #> INFO  [23:28:48.498] [bbotk]                    620          <list[6]> <list[4]>  38.86261
    #> INFO  [23:28:48.646] [mlr3] Applying learner 'select.regr.ranger.tuned' on task 'mtcars' (iter 4/5)
    #> INFO  [23:28:48.763] [bbotk] Starting to optimize 4 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=1, k=0]'
    #> INFO  [23:28:48.788] [bbotk] Evaluating 1 configuration(s)
    #> INFO  [23:28:48.829] [mlr3] Running benchmark with 1 resampling iterations
    #> INFO  [23:28:48.837] [mlr3] Applying learner 'select.regr.ranger' on task 'mtcars' (iter 1/1)
    #> INFO  [23:28:48.959] [mlr3] Finished benchmark
    #> INFO  [23:28:49.027] [bbotk] Result of batch 1:
    #> INFO  [23:28:49.030] [bbotk]  regr.ranger.mtry.ratio regr.ranger.replace regr.ranger.sample.fraction
    #> INFO  [23:28:49.030] [bbotk]               0.3449179               FALSE                    0.344375
    #> INFO  [23:28:49.030] [bbotk]  regr.ranger.num.trees MyMeasure warnings errors runtime_learners
    #> INFO  [23:28:49.030] [bbotk]                   1004  11.96155        0      0            0.112
    #> INFO  [23:28:49.030] [bbotk]                                 uhash
    #> INFO  [23:28:49.030] [bbotk]  d14754c3-ab73-4777-84bd-10daa10318f0
    #> INFO  [23:28:49.043] [bbotk] Finished optimizing after 1 evaluation(s)
    #> INFO  [23:28:49.044] [bbotk] Result:
    #> INFO  [23:28:49.046] [bbotk]  regr.ranger.mtry.ratio regr.ranger.replace regr.ranger.sample.fraction
    #> INFO  [23:28:49.046] [bbotk]               0.3449179               FALSE                    0.344375
    #> INFO  [23:28:49.046] [bbotk]  regr.ranger.num.trees learner_param_vals  x_domain MyMeasure
    #> INFO  [23:28:49.046] [bbotk]                   1004          <list[6]> <list[4]>  11.96155
    #> INFO  [23:28:49.203] [mlr3] Applying learner 'select.regr.ranger.tuned' on task 'mtcars' (iter 5/5)
    #> INFO  [23:28:49.327] [bbotk] Starting to optimize 4 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=1, k=0]'
    #> INFO  [23:28:49.352] [bbotk] Evaluating 1 configuration(s)
    #> INFO  [23:28:49.393] [mlr3] Running benchmark with 1 resampling iterations
    #> INFO  [23:28:49.401] [mlr3] Applying learner 'select.regr.ranger' on task 'mtcars' (iter 1/1)
    #> INFO  [23:28:49.537] [mlr3] Finished benchmark
    #> INFO  [23:28:49.614] [bbotk] Result of batch 1:
    #> INFO  [23:28:49.616] [bbotk]  regr.ranger.mtry.ratio regr.ranger.replace regr.ranger.sample.fraction
    #> INFO  [23:28:49.616] [bbotk]               0.4485645               FALSE                   0.4184389
    #> INFO  [23:28:49.616] [bbotk]  regr.ranger.num.trees MyMeasure warnings errors runtime_learners
    #> INFO  [23:28:49.616] [bbotk]                   1931  12.59067        0      0            0.127
    #> INFO  [23:28:49.616] [bbotk]                                 uhash
    #> INFO  [23:28:49.616] [bbotk]  295d1dc0-810d-4351-9bb4-7255fca38be3
    #> INFO  [23:28:49.629] [bbotk] Finished optimizing after 1 evaluation(s)
    #> INFO  [23:28:49.630] [bbotk] Result:
    #> INFO  [23:28:49.631] [bbotk]  regr.ranger.mtry.ratio regr.ranger.replace regr.ranger.sample.fraction
    #> INFO  [23:28:49.631] [bbotk]               0.4485645               FALSE                   0.4184389
    #> INFO  [23:28:49.631] [bbotk]  regr.ranger.num.trees learner_param_vals  x_domain MyMeasure
    #> INFO  [23:28:49.631] [bbotk]                   1931          <list[6]> <list[4]>  12.59067
    #> INFO  [23:28:49.806] [mlr3] Finished benchmark
    

    Created on 2023-01-30 with reprex v2.0.2
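
As noted above, the measure hard-codes the PipeOp id regr.ranger. The sketch below shows one possible generalization. It assumes, as the measure's own code implies, that learner$state$model is a list of PipeOp states and that a PipeOpLearner's state stores the fitted ranger object in $model. It then returns the OOB error of the first state whose model inherits from class "ranger". The helper find_ranger_oob() and the mock state are hypothetical and have not been checked against a real mlr3 run.

```r
# Hypothetical helper: return the OOB error of the first ranger model found in
# a GraphLearner's trained state, instead of hard-coding the PipeOp id.
# Assumption: graph_model is learner$state$model, a list of PipeOp states,
# and a PipeOpLearner's state holds the fitted ranger object in $model.
find_ranger_oob = function(graph_model) {
  for (op_state in graph_model) {
    if (is.list(op_state) && inherits(op_state$model, "ranger")) {
      return(op_state$model$prediction.error)  # OOB MSE for regression forests
    }
  }
  stop("No ranger model found. Set store_models = TRUE.")
}

# Mock state with the same shape, for illustration only (not real mlr3 output):
mock_state = list(
  select = list(),
  regr.ranger = list(model = structure(list(prediction.error = 11.4), class = "ranger"))
)
find_ranger_oob(mock_state)  # 11.4
```

Inside .score, the lookup could then become find_ranger_oob(learner$state$model), so the measure no longer depends on the PipeOp being named regr.ranger.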