
mlr3: encoding with GraphLearner


I am wondering why the output at the bottom of this post shows the learner object's state as untrained, even though I called $train().

# Load Libraries
packages <- c("mlr3", "mlr3proba", "mlr3learners", "mlr3extralearners", "glmnet", "mlr3pipelines")
lapply(packages, require, character.only = TRUE)
## Load the built-in pbc task
task = tsk("pbc")
## Define a pipeline with one-hot encoding preprocessing step using mlr3pipelines
po_onehot = po("encode", method = "one-hot")
## Define the learner and set parameters
learner0 = lrn("surv.cv_glmnet", alpha = 0.5)
## Chain the encoding step and the learner into one pipeline
glrn0 = po_onehot %>>% po("learner", learner0)
## Wrap the graph in a GraphLearner and train the full pipeline
learner = GraphLearner$new(glrn0)
learner$train(task)

Output when verifying the state of the learner:

> print(learner$graph)
Graph with 2 PipeOps:
             ID         State       sccssors prdcssors
         <char>        <char>         <char>    <char>
         encode <<UNTRAINED>> surv.cv_glmnet          
 surv.cv_glmnet <<UNTRAINED>>                   encode

I expected to see the string <<TRAINED>> in the State column of the output.


Solution

  • Please read the documentation for GraphLearner: the $graph field contains only the prototype Graph object, which stays untrained. The trained model is stored in the $graph_model field, so in your case learner$graph_model$is_trained returns TRUE, and print(learner$graph_model) shows the class of each PipeOp's trained state in the State column instead of <<UNTRAINED>>.
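A minimal sketch of the check described in the answer, assuming the same packages as in the question (in particular that surv.cv_glmnet is supplied by mlr3extralearners and glmnet is installed). It compares $graph with $graph_model after training and then pulls the fitted model out of the wrapped learner through the learner_model field of the learner PipeOp; the variable names are illustrative only.

```r
# Assumes mlr3, mlr3proba, mlr3extralearners, mlr3pipelines and glmnet are installed
library(mlr3)
library(mlr3proba)
library(mlr3extralearners)
library(mlr3pipelines)

task = tsk("pbc")
graph = po("encode", method = "one-hot") %>>%
  po("learner", lrn("surv.cv_glmnet", alpha = 0.5))
learner = GraphLearner$new(graph)
learner$train(task)

# $graph is the prototype: its PipeOp states are reset after training
print(learner$graph$is_trained)        # expected: FALSE

# $graph_model is a copy of the graph with the trained states attached
print(learner$graph_model$is_trained)  # expected: TRUE
print(learner$graph_model)

# The trained Learner wrapped by the PipeOp, and its fitted model object
trained_lrn = learner$graph_model$pipeops$surv.cv_glmnet$learner_model
fit = trained_lrn$model
print(class(fit))

# Prediction uses the stored states, so it works without touching $graph
prediction = learner$predict(task)
print(prediction)
```

Because training stores the states in the learner's model rather than in $graph, the same prototype graph can be reused or cloned without carrying a fitted model along.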