Search code examples
r · machine-learning · gbm · mlr

makeClassifTask with mlr - ID column excluded from task


I have data that has an ID column in it. I drop this column from my trainTask as it is not a feature. However, I would like to link the prediction probability up with the actual ID number in the data.

The column I want to match on is Init_Acct which is the ID number in the data.frame

My code is below:

# Make classification tasks ----
# Bug fix: the original call passed a bare `id` argument with no value,
# which fails with "object 'id' not found". makeClassifTask's `id` expects
# a character string naming the task, so supply one explicitly.
trainTask <- makeClassifTask(
  data = train.df %>% dplyr::select(-Init_Acct) # Init_Acct is the ID I want to match
  , id = "trainTask"
  , target = "READMIT_FLAG"
  , positive = "Y"
)
testTask <- makeClassifTask(
  data = test.df %>% dplyr::select(-Init_Acct)
  , id = "testTask"
  , target = "READMIT_FLAG"
  , positive = "Y"
)

# Oversample the minority class on the TRAINING task only.
# NOTE(review): the original also ran smote() on testTask; synthetically
# resampling the hold-out set leaks information and distorts every metric
# computed from it, so that call is removed.
trainTask <- smote(trainTask, rate = 6)

# GBM learner ----
# List the tunable hyper-parameters exposed by the gbm wrapper
getParamSet("classif.gbm")

# Probability predictions are required to link scores back to accounts later
gbm.learner <- makeLearner("classif.gbm", predict.type = "prob")
plotLearnerPrediction(gbm.learner, trainTask)

# Tuning strategy ----
# Random search: 50 draws from the parameter space
gbm.tune.ctl <- makeTuneControlRandom(maxit = 50L)

# 3-fold cross-validation to score each candidate configuration
gbm.cv <- makeResampleDesc("CV", iters = 3L)

# Hyper-parameter search space ----
# Ranges sampled by the random search; distribution is pinned to bernoulli
# because READMIT_FLAG is a two-class outcome.
gbm.par <- makeParamSet(
  makeDiscreteParam("distribution", values = "bernoulli"),
  makeIntegerParam("n.trees", lower = 10, upper = 1000),
  makeIntegerParam("interaction.depth", lower = 2, upper = 10),
  makeIntegerParam("n.minobsinnode", lower = 10, upper = 80),
  makeNumericParam("shrinkage", lower = 0.01, upper = 1)
)

# Tune hyper-parameters ----
# Evaluate candidates on 4 socket workers; parallelise only the tuning level
parallelMap::parallelStartSocket(4, level = "mlr.tuneParams")

gbm.tune <- tuneParams(
  learner    = gbm.learner,
  task       = trainTask,
  resampling = gbm.cv,
  measures   = acc,
  par.set    = gbm.par,
  control    = gbm.tune.ctl
)

parallelMap::parallelStop()

# Inspect tuning results ----
gbm.tune$y  # cross-validated accuracy of the best configuration
gbm.tune$x  # winning hyper-parameter values

# Bake the tuned hyper-parameters into the learner
gbm.ps <- setHyperPars(learner = gbm.learner, par.vals = gbm.tune$x)

# Fit the final model ----
# Bug fix: the original called train(gbm.ps, testTask), i.e. the model was
# fit on the very hold-out data it is later evaluated against. Train on
# trainTask; keep testTask strictly for prediction.
gbm.train <- train(gbm.ps, trainTask)
plotLearningCurve(
  generateLearningCurveData(
    gbm.learner
    , trainTask  # learning curves characterise the training task, not the test set
  )
)

# Predict on the held-out test task ----
gbm.pred <- predict(gbm.train, testTask)
plotResiduals(gbm.pred)

# Assemble the submission / evaluation table ----
# pred$data holds id, truth, per-class probabilities and response
gbm.submit <- as.data.frame(gbm.pred$data)
head(gbm.submit, n = 5)
table(gbm.submit$truth, gbm.submit$response)

# Classification diagnostics ----
calculateConfusionMatrix(gbm.pred)
calculateROCMeasures(gbm.pred)
conf_mat_f1_func(gbm.pred)          # project-local helper, defined elsewhere
perf_plots_func(Model = gbm.pred)   # project-local helper, defined elsewhere

Data would look something like this:

glimpse(train.df)
Observations: 33,031
Variables: 17
$ Init_Acct         <chr> "12345678", "87654321", "81734650", "11223344", "1422...
$ Init_LOS          <dbl> 2, 2, 5, 1, 12, 3, 16, 9, 3, 14, 1, 1, 4, 7, 4, 1, 3,...
$ Init_LACE         <dbl> 2, 7, 7, 9, 8, 8, 11, 10, 8, 10, 5, 4, 8, 8, 4, 5, 3,...
$ READMIT_FLAG      <fct> N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, Y,...
$ Init_Hosp_Pvt     <fct> PRIVATE, HOSPITALIST, HOSPITALIST, HOSPITALIST, PRIVA...
$ Age_at_Init_Admit <dbl> 37, 26, 56, 67, 51, 53, 48, 57, 92, 67, 72, 22, 60, 6...
$ Age_Bucket        <fct> 3, 2, 5, 6, 5, 5, 4, 5, 9, 6, 7, 2, 6, 6, 7, 6, 9, 5,...
$ Gender            <fct> F, M, F, M, M, F, M, F, M, M, M, F, M, F, F, F, F, M,...
$ Init_ROM          <dbl> 1, 1, 3, 4, 2, 3, 1, 3, 4, 1, 1, 1, 1, 1, 2, 1, 2, 4,...
$ Init_SOI          <dbl> 1, 1, 3, 4, 3, 3, 3, 3, 3, 2, 1, 1, 2, 2, 2, 2, 2, 4,...
$ Has_Diabetes      <fct> N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N,...
$ reduced_dispo     <fct> AHR, AHR, AHR, ATH, ATW, ATW, ATW, AHR, AHR, ATW, AHR...
$ reduced_hsvc      <fct> SUR, MED, MED, Other, MED, MED, MED, MED, MED, MED, M...
$ reduced_abucket   <fct> 3, 2, 5, 6, 5, 5, 4, 5, Other, 6, 7, 2, 6, 6, 7, 6, O...
$ reduced_spclty    <fct> Other, HOSIM, HOSIM, HOSIM, Other, HOSIM, HOSIM, HOSI...
$ reduced_lihn      <fct> Other, Medical, Pneumonia, Medical, Medical, Medical,...
$ discharge_month   <fct> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,...

The output:

glimpse(gbm.submit)
Observations: 23,896
Variables: 5
$ id       <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,...
$ truth    <fct> Y, N, N, N, N, N, N, Y, N, N, N, N, N, N, Y, N, Y, N, N, N, N,...
$ prob.N   <dbl> 0.9150623, 0.7914781, 0.9661108, 0.9198683, 0.8502536, 0.94376...
$ prob.Y   <dbl> 0.08493774, 0.20852192, 0.03388919, 0.08013167, 0.14974644, 0....
$ response <fct> N, N, N, N, N, N, N, Y, N, N, N, N, N, N, N, N, Y, N, N, N, N,...

Solution

  • MLR's predict() preserves row names and produces an additional id column in its output that indexes the original data. You can use either one to associate predictions with their original sample IDs.

    Setup

    library(tidyverse)
    library(mlr)

    ## Attach a synthetic per-row account ID so we can demonstrate re-linking it
    iris2 <- iris %>% mutate( Init_Acct = paste0("Acct", row_number()) )
    lrn   <- makeLearner( "classif.gbm", predict.type = "prob" )
    

    Option 1: Using id column to index the original data

    ## Drop the custom column, exactly as in the original post
    task <- makeClassifTask( data = iris2 %>% select(-Init_Acct), target = "Species" )
    mdl  <- train( lrn, task )
    pred <- predict( mdl, task )

    ## pred$data's "id" column indexes rows of the original data frame,
    ## so number the original rows and join on it
    iris2 %>% mutate( id = row_number() ) %>% select( Init_Acct, id ) %>%
        inner_join( pred$data ) %>% select( -id )
    #   Init_Acct  truth prob.setosa prob.versicolor prob.virginica response
    # 1     Acct1 setosa   0.9998775    1.225043e-04   2.836942e-08   setosa
    # 2     Acct2 setosa   0.9999652    3.468690e-05   1.118015e-07   setosa
    # 3     Acct3 setosa   0.9999538    4.611200e-05   8.389636e-08   setosa
    

    Option 2: Use rownames

    ## Move the sample IDs into rownames before building the task
    task <- makeClassifTask( data = iris2 %>% column_to_rownames("Init_Acct"),
                             target = "Species" )
    mdl  <- train( lrn, task )
    pred <- predict( mdl, task )

    ## predict() preserves rownames: lift them back into a column and
    ## drop the now-redundant integer id
    pred$data %>% rownames_to_column( "Init_Acct" ) %>% select( -id )
    #     Init_Acct      truth  prob.setosa prob.versicolor prob.virginica   response
    # 1       Acct1     setosa 9.999266e-01    7.331226e-05   6.889259e-08     setosa
    # 2       Acct2     setosa 9.999751e-01    2.462816e-05   3.154618e-07     setosa
    # 3       Acct3     setosa 9.999656e-01    3.421543e-05   1.449155e-07     setosa