
Predict on holdout set using mlr3


I am using the mlr3 package. I set the role of some rows to "holdout" and then trained the model:

library(mlr3)

# train on iris
task = tsk("iris")
task$nrow
task$set_row_roles(130:150, "holdout")
learner = lrn("classif.rpart")
learner$train(task)

How can I now use the holdout set to make predictions on it?

# predict on holdout
task$row_roles$holdout
## HOW TO PREDICT ON HOLDOUT SET ? 
# learner$predict()

Solution

  • Pass the ids of the rows you want predictions for to the row_ids argument of the $predict() method. Here, task$row_roles$holdout holds the ids of the holdout rows:

    library(mlr3)
    
    task = tsk("iris")
    task$nrow
    #> [1] 150
    task$set_row_roles(130:150, "holdout")
    learner = lrn("classif.rpart")
    learner$train(task)
    
    learner$predict(task, row_ids = task$row_roles$holdout)
    #> <PredictionClassif> for 21 observations:
    #>     row_ids     truth   response
    #>         130 virginica versicolor
    #>         131 virginica  virginica
    #>         132 virginica  virginica
    #> ---                             
    #>         148 virginica  virginica
    #>         149 virginica  virginica
    #>         150 virginica  virginica
    

    Created on 2023-02-10 with reprex v2.0.2
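If your mlr3 version handles the "holdout" row role differently, you can get the same result without relying on it. The minimal sketch below passes explicit row ids to both $train() and $predict(), then scores the held-out predictions with msr("classif.acc"). The split into rows 130-150 versus the rest mirrors the question; the names holdout_ids and train_ids are illustrative only.

```r
library(mlr3)

task = tsk("iris")

# Illustrative split: rows 130-150 are held out, all other rows train the model
holdout_ids = 130:150
train_ids = setdiff(task$row_ids, holdout_ids)

learner = lrn("classif.rpart")
learner$train(task, row_ids = train_ids)

# Predict only on the held-out rows
prediction = learner$predict(task, row_ids = holdout_ids)

# Classification accuracy on the held-out rows
acc = prediction$score(msr("classif.acc"))
print(acc)
```

Because the held-out rows never reach $train(), the accuracy printed here is an estimate on data the model has not seen.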