Search code examples
r resampling mlr3

How to extract mlr3 tuned graph step by step?


My code is as follows:

library(mlr3verse)
library(mlr3pipelines)
library(mlr3filters)
library(paradox)
# Filter step: score features with ranger impurity importance and keep a
# fraction of 0.7 of them (filter.frac).
filter_importance <- mlr_pipeops$get(
  "filter",
  filter = FilterImportance$new(
    learner = lrn("classif.ranger", importance = "impurity")
  ),
  param_vals = list(filter.frac = 0.7)
)

# Final classifier, wrapped as a PipeOp so it can be chained after the filter.
learner_classif <- lrn(
  "classif.ranger",
  predict_type = "prob",
  importance = "impurity",
  num.trees = 500
)
polrn_classif <- PipeOpLearner$new(learner_classif)

# Graph learner: filter -> classifier, predicting probabilities.
glrn_classif <- GraphLearner$new(filter_importance %>>% polrn_classif)
glrn_classif$predict_type <- "prob"

# Task
task <- tsk("german_credit")

# Search space for the two tuned ranger hyperparameters.
ps_classif <- ParamSet$new(list(
  ParamInt$new("classif.ranger.num.trees", lower = 300, upper = 500),
  ParamDbl$new("classif.ranger.sample.fraction", lower = 0.7, upper = 0.8)
))

# Auto-tuning: random search, 3 evaluations, scored by AUC on an inner 3-fold CV.
at <- AutoTuner$new(
  learner = glrn_classif,
  tuner = tnr("random_search"),
  terminator = trm("evals", n_evals = 3),
  resampling = rsmp("cv", folds = 3),
  measure = msr("classif.auc"),
  search_space = ps_classif
)

# Outer resampling: 2-fold CV around the AutoTuner.
rr <- resample(task = task, learner = at, resampling = rsmp("cv", folds = 2))

After I have the rr object from resampling and the trained learner at, how can I extract what each of these steps was doing?

Ex:

  • How can I rerun the steps manually once I have the results from the at object?
  • Which samples were used in each step (train_index, test_index)?
  • Which variables were selected by the filter_importance step, and what score did each variable get in this step?

Many thanks !!!


Solution

  • To be able to inspect the models after resampling, it's best to call resample with store_models = TRUE.

    Using your example

    library(mlr3verse)
    
    # Fix the RNG seed so the run can be reproduced.
    set.seed(1)
    # store_models = TRUE keeps the fitted AutoTuners inside the result.
    rr <- resample(
      task = task,
      learner = at,
      resampling = rsmp("cv", folds = 2),
      store_models = TRUE
    )
    

    After you have finished resampling you can access the inner structure of the generated object like this:

    To get the row ids in each fold:

    rr$resampling$instance
    #output
          row_id fold
       1:      5    1
       2:      8    1
       3:      9    1
       4:     12    1
       5:     13    1
      ---            
     996:    989    2
     997:    993    2
     998:    994    2
     999:    995    2
    1000:    996    2
    

    With these row ids and the tuned autotuners we can manually generate the predictions.

    Generate a list of test indexes

    # Row ids grouped by fold: one vector of row ids per fold, named by fold.
    rsample <- with(rr$resampling$instance, split(row_id, fold))
    

    Iterate over the folds and tuned autotuners and predict:

    # Rebuild the resampling predictions by hand: for every fold, predict on
    # that fold's rows with the AutoTuner stored for the same iteration.
    # Iterating over seq_along(rsample) avoids hard-coding the number of folds.
    preds_manual <- lapply(seq_along(rsample), function(i) {
      test_ids <- rsample[[i]] # test row ids of fold i
      task_test <- task$clone() # clone so the original task is not changed
      task_test$filter(test_ids) # keep only the test rows
      # predict with the trained autotuner of iteration i on the filtered task
      rr$learners[[i]]$predict(task_test)
    })
    

    To check if these predictions match the output of resample:

    # Compare the manually generated predictions with the ones stored in rr.
    all.equal(
      target = preds_manual,
      current = rr$predictions()
    )
    #output
    TRUE
    

    To get information about the tuning

    # Trained AutoTuners, one per outer resampling iteration, taken from the
    # public rr$learners field rather than the internal rr$data object.
    # They only carry fitted models when resample() ran with store_models = TRUE.
    zz <- rr$learners
    
    # Best hyperparameter configuration found by each AutoTuner.
    lapply(zz, function(x) x$tuning_result)
    #output
    [[1]]
       classif.ranger.num.trees classif.ranger.sample.fraction learner_param_vals
    1:                      342                      0.7931022          <list[7]>
        x_domain classif.auc
    1: <list[2]>   0.7981283
    
    [[2]]
       classif.ranger.num.trees classif.ranger.sample.fraction learner_param_vals
    1:                      407                      0.7964164          <list[7]>
        x_domain classif.auc
    1: <list[2]>   0.7706533
    

    the slot

    # State stored under "importance" in the first tuned learner's model.
    first_tuned <- zz[[1]]$learner
    first_tuned$state$model$importance
    

    contains information about the filter_importance step,

    specifically

    # Pull the per-feature filter scores out of one trained AutoTuner.
    get_filter_scores <- function(tuned) {
      tuned$learner$state$model$importance$scores
    }
    lapply(zz, get_filter_scores)
    #output
    [[1]]
                     amount                  status                     age 
                  27.491369               25.776145               22.021369 
                   duration                 purpose          credit_history 
                  18.732521               16.251643               14.884843 
        employment_duration                 savings                property 
                  11.225678               10.796583                9.078619 
        personal_status_sex       present_residence        installment_rate 
                   8.914802                7.875384                7.491573 
                        job          number_credits other_installment_plans 
                   6.293323                5.662485                5.345666 
                    housing               telephone           other_debtors 
                   4.869471                3.742213                3.548856 
              people_liable          foreign_worker 
                   2.632163                1.054919 
    
    [[2]]
                     amount                duration                     age 
                  26.764389               22.139400               20.749865 
                     status                 purpose     employment_duration 
                  20.524764               11.793789               10.962301 
             credit_history        installment_rate                 savings 
                  10.416572                9.597835                9.491894 
                   property       present_residence                     job 
                   9.403157                7.877391                6.760945 
        personal_status_sex                 housing other_installment_plans 
                   6.699065                5.811131                5.710761 
                  telephone           other_debtors          number_credits 
                   4.716322                4.318972                3.974793 
              people_liable          foreign_worker 
                   3.196563                0.846520 
    

    contains the ranking of the features. While

    # Pull the filter step's output task layout (feature id and type) out of
    # one trained AutoTuner.
    get_filter_layout <- function(tuned) {
      tuned$learner$state$model$importance$outtasklayout
    }
    lapply(zz, get_filter_layout)
    #output
    [[1]]
                         id    type
     1:                 age integer
     2:              amount integer
     3:      credit_history  factor
     4:            duration integer
     5: employment_duration  factor
     6:    installment_rate ordered
     7:                 job  factor
     8:      number_credits ordered
     9: personal_status_sex  factor
    10:   present_residence ordered
    11:            property  factor
    12:             purpose  factor
    13:             savings  factor
    14:              status  factor
    
    [[2]]
                         id    type
     1:                 age integer
     2:              amount integer
     3:      credit_history  factor
     4:            duration integer
     5: employment_duration  factor
     6:             housing  factor
     7:    installment_rate ordered
     8:                 job  factor
     9: personal_status_sex  factor
    10:   present_residence ordered
    11:            property  factor
    12:             purpose  factor
    13:             savings  factor
    14:              status  factor
    

    contains the features that were kept after the filter step.