
MLR: How can I wrap the selection of specified features around the learner?


I would like to compare simple logistic regression models, each of which considers only a specified set of features. I would like to perform these comparisons on resamples of the data.

The R package mlr allows me to select columns at the task level using dropFeatures. The code would be something like:

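library(mlr)
full_task = makeClassifTask(id = "full task", data = my_data, target = "target")
reduced_task = dropFeatures(full_task, setdiff(getTaskFeatureNames(full_task), list_feat_keep))
reduced_task$task.desc$id = "reduced task"  # dropFeatures() keeps the id "full task"; benchmark() needs unique task ids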
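For comparison, mlr also provides subsetTask(), which takes the features to keep rather than the ones to drop. A minimal sketch on the bundled sonar.task, assuming the hypothetical feature set c("V1", "V10"); as far as I can tell, the resulting task also keeps the id of the original task, just like dropFeatures().

```r
library(mlr)

# Hypothetical feature set; sonar.task ships with mlr.
list_feat_keep = c("V1", "V10")

# subsetTask() keeps only the listed features (the target is always kept),
# so the complement does not have to be computed with setdiff().
reduced_task = subsetTask(sonar.task, features = list_feat_keep)

print(getTaskFeatureNames(reduced_task))
```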

Then I can run benchmark experiments on a list of tasks:

lrn = makeLearner("classif.logreg", predict.type = "prob") 
rdesc = makeResampleDesc(method = "Bootstrap", iters = 50, stratify = TRUE)
bmr = benchmark(lrn, list(full_task, reduced_task), rdesc, measures = auc, show.info = FALSE)
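Because my_data is not shown, here is a self-contained sketch of the same task-level benchmark on the Sonar data bundled with mlr (target column Class). The feature set and the reduced number of bootstrap iterations are assumptions for illustration. Building each task with its own makeClassifTask() call gives every task a distinct id, which benchmark() requires.

```r
library(mlr)

# Stand-in for my_data: the Sonar data shipped with mlr.
my_data = getTaskData(sonar.task)
list_feat_keep = c("V1", "V10")  # hypothetical feature set

# One task per feature set, each with its own id.
full_task = makeClassifTask(id = "full", data = my_data, target = "Class")
reduced_task = makeClassifTask(id = "reduced",
                               data = my_data[, c(list_feat_keep, "Class")],
                               target = "Class")

lrn = makeLearner("classif.logreg", predict.type = "prob")
# Fewer iterations than in the question, only to keep the sketch quick.
rdesc = makeResampleDesc(method = "Bootstrap", iters = 10, stratify = TRUE)
bmr = benchmark(lrn, list(full_task, reduced_task), rdesc,
                measures = auc, show.info = FALSE)

# One row per task/learner pair, AUC averaged over the iterations.
res = getBMRAggrPerformances(bmr, as.df = TRUE)
print(res)
```

glm() may warn about fitted probabilities of 0 or 1 on the full 60-feature task; the benchmark still completes.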

How can I create a learner that only considers a specified set of features? As far as I know, the filter and selection methods always apply some statistical procedure and do not let me select the features directly. Thank you!


Solution

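  • The first solution is lazy and not optimal, because the filter values are still calculated even though they are not used. The selected features are passed as mandatory features of a filter wrapper, and fw.abs is set to their number so that nothing else is kept: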

    library(mlr)
    task = sonar.task
    sel.feats = c("V1", "V10")
    lrn = makeLearner("classif.logreg", predict.type = "prob")
    # mandatory features are always kept; fw.abs must not be smaller than their number
    lrn.reduced = makeFilterWrapper(learner = lrn, fw.method = "variance", fw.abs = length(sel.feats), fw.mandatory.feat = sel.feats)
    bmr = benchmark(list(lrn, lrn.reduced), task, cv3, measures = auc, show.info = FALSE)
    

    The second solution uses a preprocessing wrapper to subset the columns directly. It should be the fastest, because no filter values are computed, and it is also more flexible:

    lrn.reduced.2 = makePreprocWrapper(
      learner = lrn, 
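      # keep only the selected features (plus the target when training);
      # drop = FALSE keeps a data.frame even if sel.feats has a single entry
      train = function(data, target, args) list(data = data[, c(sel.feats, target), drop = FALSE], control = list()),
      predict = function(data, target, args, control) data[, sel.feats, drop = FALSE]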
    )
    bmr = benchmark(list(lrn, lrn.reduced.2), task, cv3, measures = auc, show.info = FALSE)
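To compare several feature sets at once, the preprocessing approach can be wrapped in a small helper. This is a sketch only: makeSubsetLearner() and the feature sets below are hypothetical, and sonar.task stands in for the real data. setLearnerId() gives each wrapped learner its own id, since benchmark() rejects duplicate learner ids, and force() fixes each feature set when the wrapper is created.

```r
library(mlr)

# Hypothetical helper: wrap `learner` so that it only ever sees the columns in `feats`.
makeSubsetLearner = function(learner, feats, id) {
  force(feats)  # evaluate the feature set now, not when the wrapper first runs
  wrapped = makePreprocWrapper(
    learner = learner,
    train = function(data, target, args) {
      list(data = data[, c(feats, target), drop = FALSE], control = list())
    },
    predict = function(data, target, args, control) data[, feats, drop = FALSE]
  )
  setLearnerId(wrapped, id)  # unique id per learner for benchmark()
}

lrn = makeLearner("classif.logreg", predict.type = "prob")

# Hypothetical feature sets to compare; names become learner ids.
feat.sets = list(v1.v10 = c("V1", "V10"),
                 v1.v10.v20 = c("V1", "V10", "V20"),
                 v5 = "V5")
lrns = Map(function(feats, id) makeSubsetLearner(lrn, feats, id),
           feat.sets, names(feat.sets))

rdesc = makeResampleDesc(method = "Bootstrap", iters = 10, stratify = TRUE)
bmr = benchmark(unname(lrns), sonar.task, rdesc, measures = auc, show.info = FALSE)

res = getBMRAggrPerformances(bmr, as.df = TRUE)
print(res)
```

Passing the feature set as an argument, instead of reading a global sel.feats inside the wrapper, means each learner keeps its own columns even if sel.feats is reassigned later.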