xgboost, mlr3proba

xgboost prediction (of the Cox linear predictor) differs from mlr3 surv.xgboost.cox


I would like to reproduce the fitting (training and subsequent prediction) of an XGBoost model with both mlr3 and the xgboost package. See the following example, which uses the lung dataset and, for simplicity, predicts on the training data. The predictions from xgboost (xgb_pred, on the hazard-ratio scale) and from mlr3 (exp(mlr3_xgb$lp)) are not quite the same. Any advice on why this might be the case would be greatly appreciated (hopefully it is just a glitch in my code or a gap in my understanding).

library(mlr3)
library(mlr3pipelines)
library(mlr3extralearners)
library(mlr3proba)
library(xgboost)

## mlr3 as an example ----
task_lung = tsk('lung')
lung = task_lung$data()
xgb_basic = as_learner(
  po("encode") %>>%  lrn("surv.xgboost.cox",  eta = 0.0103))

set.seed(123)
xgb_basic$train(task_lung)
mlr3_xgb = xgb_basic$predict(task_lung)


## use xgboost package -----
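# label for survival:cox (xgboost treats negative labels as right-censored)
label <- ifelse(lung$status == 0, lung$time, -lung$time)
# lower/upper bounds are only needed for the survival:aft objective (unused below)
y_lower_bound = lung$time
y_upper_bound = ifelse(lung$status==0, +Inf, lung$time)

xgb_data = model.matrix(~ . + 0, data = lung[, -c(1, 2), with = FALSE]) # one-hot encoding; drops time and status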

# Data matrix
dmat = xgb.DMatrix(xgb_data, label=label) # for cox


params <- list(objective='survival:cox',  # train
               eval_metric='cox-nloglik',
               learning_rate=0.0103)  #aka eta

set.seed(123) 
bst <- xgb.train(params=params, 
                 data = dmat, 
                 nrounds=1, 
                 watchlist=list(train = dmat, eval=dmat))
#> [1]  train-cox-nloglik:3.896049  eval-cox-nloglik:3.896049

xgb_pred = predict(bst, newdata=dmat)

round(exp(mlr3_xgb$lp),3)
#>   [1] 0.500 0.510 0.496 0.510 0.500 0.503 0.500 0.498 0.498 0.499 0.498 0.510
#>  [13] 0.500 0.501 0.498 0.499 0.499 0.513 0.502 0.513 0.499 0.513 0.510 0.513
#>  [25] 0.499 0.496 0.513 0.499 0.500 0.505 0.500 0.500 0.513 0.505 0.498 0.499
#>  [37] 0.496 0.499 0.498 0.498 0.500 0.499 0.499 0.500 0.503 0.499 0.510 0.510
#>  [49] 0.495 0.499 0.505 0.499 0.499 0.513 0.504 0.499 0.498 0.499 0.505 0.498
#>  [61] 0.503 0.503 0.504 0.496 0.500 0.499 0.498 0.499 0.513 0.499 0.505 0.510
#>  [73] 0.513 0.499 0.495 0.499 0.505 0.499 0.503 0.513 0.505 0.503 0.500 0.513
#>  [85] 0.510 0.500 0.502 0.505 0.505 0.499 0.513 0.500 0.505 0.510 0.496 0.496
#>  [97] 0.499 0.503 0.505 0.496 0.499 0.503 0.513 0.505 0.513 0.499 0.498 0.499
#> [109] 0.503 0.505 0.500 0.510 0.513 0.502 0.502 0.499 0.495 0.504 0.500 0.499
#> [121] 0.498 0.495 0.503 0.499 0.499 0.499 0.501 0.499 0.500 0.503 0.499 0.513
#> [133] 0.499 0.495 0.498 0.499 0.501 0.495 0.505 0.498 0.510 0.503 0.505 0.498
#> [145] 0.499 0.500 0.499 0.495 0.502 0.499 0.513 0.495 0.495 0.495 0.510 0.495
#> [157] 0.505 0.502 0.498 0.513 0.500 0.495 0.496 0.499 0.499 0.500 0.499 0.499
round(xgb_pred,3)
#>   [1] 0.496 0.499 0.495 0.501 0.497 0.495 0.502 0.499 0.499 0.496 0.495 0.501
#>  [13] 0.496 0.496 0.496 0.498 0.496 0.495 0.499 0.495 0.496 0.495 0.501 0.495
#>  [25] 0.495 0.497 0.495 0.496 0.500 0.501 0.500 0.495 0.501 0.503 0.500 0.505
#>  [37] 0.498 0.505 0.505 0.496 0.498 0.496 0.496 0.501 0.499 0.505 0.505 0.495
#>  [49] 0.505 0.505 0.500 0.496 0.500 0.495 0.498 0.496 0.498 0.499 0.503 0.498
#>  [61] 0.495 0.506 0.510 0.506 0.496 0.505 0.505 0.503 0.495 0.496 0.505 0.503
#>  [73] 0.498 0.505 0.500 0.505 0.496 0.499 0.498 0.501 0.503 0.506 0.505 0.495
#>  [85] 0.495 0.502 0.495 0.500 0.497 0.497 0.495 0.496 0.496 0.499 0.505 0.507
#>  [97] 0.496 0.498 0.498 0.502 0.496 0.499 0.501 0.503 0.495 0.505 0.500 0.510
#> [109] 0.499 0.500 0.495 0.495 0.495 0.495 0.510 0.507 0.500 0.506 0.502 0.505
#> [121] 0.503 0.505 0.498 0.510 0.505 0.510 0.496 0.507 0.498 0.499 0.505 0.495
#> [133] 0.500 0.510 0.505 0.506 0.498 0.506 0.497 0.505 0.500 0.498 0.500 0.507
#> [145] 0.496 0.505 0.510 0.498 0.505 0.500 0.495 0.505 0.510 0.510 0.501 0.510
#> [157] 0.503 0.505 0.500 0.498 0.501 0.507 0.505 0.505 0.506 0.501 0.505 0.505

Created on 2024-08-30 by the reprex package (v2.0.1)
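In the question, `exp(mlr3_xgb$lp)` is compared with `xgb_pred`. That pairing is consistent, because for `objective = "survival:cox"` xgboost's `predict()` returns hazard ratios, i.e. `exp()` of the raw margin. Passing `outputmargin = TRUE` returns the linear predictor directly, which is what one would compare with `mlr3_xgb$lp`, assuming (as the question does) that mlr3's `lp` is xgboost's raw margin. A minimal self-contained check on simulated data (the data and variable names below are illustrative assumptions, not taken from the question):

```r
library(xgboost)

# Simulated right-censored data (illustrative only, not the lung data)
set.seed(42)
n <- 100
x <- matrix(rnorm(n * 3), ncol = 3, dimnames = list(NULL, c("x1", "x2", "x3")))
time <- rexp(n, rate = exp(0.5 * x[, "x1"]))
status <- rbinom(n, size = 1, prob = 0.7)

# survival:cox convention: negative labels mark right-censored observations
dmat <- xgb.DMatrix(x, label = ifelse(status == 1, time, -time))

bst <- xgb.train(
  params = list(objective = "survival:cox", eta = 0.0103, nthread = 1),
  data = dmat,
  nrounds = 1
)

hr <- predict(bst, dmat)                       # hazard-ratio scale, exp(margin)
lp <- predict(bst, dmat, outputmargin = TRUE)  # linear predictor (raw margin)

# TRUE up to float32 precision
all.equal(hr, exp(lp), tolerance = 1e-6)
```

Comparing the margins numerically (for example with `all.equal()` or `summary(lp_a - lp_b)`) is also more reliable than eyeballing vectors rounded to three decimals.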

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.1.1 (2021-08-10)
#>  os       macOS Big Sur 10.16         
#>  system   x86_64, darwin17.0          
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       Australia/Adelaide          
#>  date     2024-08-30                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package           * version    date       lib
#>  backports           1.5.0      2024-05-23 [1]
#>  checkmate           2.3.1      2023-12-04 [1]
#>  cli                 3.6.3      2024-06-21 [1]
#>  codetools           0.2-18     2020-11-04 [2]
#>  colorspace          2.1-0      2023-01-23 [1]
#>  crayon              1.4.1      2021-02-08 [2]
#>  data.table          1.15.4     2024-03-30 [1]
#>  dictionar6          0.1.3      2021-09-13 [1]
#>  digest              0.6.36     2024-06-23 [1]
#>  distr6              1.8.4      2024-06-13 [1]
#>  dplyr               1.1.3      2023-09-03 [1]
#>  evaluate            0.24.0     2024-06-10 [1]
#>  fansi               1.0.6      2023-12-08 [1]
#>  fastmap             1.1.0      2021-01-25 [2]
#>  fs                  1.5.0      2020-07-31 [2]
#>  future              1.33.2     2024-03-26 [1]
#>  generics            0.1.3      2022-07-05 [1]
#>  ggplot2             3.5.1      2024-04-23 [1]
#>  globals             0.16.3     2024-03-08 [1]
#>  glue                1.7.0      2024-01-09 [1]
#>  gtable              0.3.5      2024-04-22 [1]
#>  highr               0.9        2021-04-16 [2]
#>  htmltools           0.5.6      2023-08-10 [1]
#>  jsonlite            1.7.2      2020-12-09 [2]
#>  knitr               1.33       2021-04-24 [2]
#>  lattice             0.20-44    2021-05-02 [2]
#>  lgr                 0.4.4      2022-09-05 [1]
#>  lifecycle           1.0.4      2023-11-07 [1]
#>  listenv             0.9.1      2024-01-29 [1]
#>  magrittr            2.0.3      2022-03-30 [1]
#>  Matrix              1.3-4      2021-06-01 [2]
#>  mlr3              * 0.20.0     2024-06-28 [1]
#>  mlr3extralearners * 0.8.0-9000 2024-06-15 [1]
#>  mlr3misc            0.15.1     2024-06-24 [1]
#>  mlr3pipelines     * 0.6.0      2024-07-16 [1]
#>  mlr3proba         * 0.6.3      2024-06-13 [1]
#>  mlr3viz             0.9.0      2024-07-01 [1]
#>  munsell             0.5.1      2024-04-01 [1]
#>  ooplah              0.2.0      2022-01-21 [1]
#>  palmerpenguins      0.1.1      2022-08-15 [1]
#>  paradox             1.0.1      2024-07-09 [1]
#>  parallelly          1.37.1     2024-02-29 [1]
#>  param6              0.2.4      2023-11-22 [1]
#>  pillar              1.9.0      2023-03-22 [1]
#>  pkgconfig           2.0.3      2019-09-22 [2]
#>  R6                  2.5.1      2021-08-19 [1]
#>  Rcpp                1.0.12     2024-01-09 [1]
#>  reprex              2.0.1      2021-08-05 [1]
#>  RhpcBLASctl         0.23-42    2023-02-11 [1]
#>  rlang               1.1.4      2024-06-04 [1]
#>  rmarkdown           2.10       2021-08-06 [2]
#>  rstudioapi          0.15.0     2023-07-07 [1]
#>  scales              1.3.0      2023-11-28 [1]
#>  sessioninfo         1.1.1      2018-11-05 [2]
#>  set6                0.2.6      2023-11-22 [1]
#>  stringi             1.7.3      2021-07-16 [2]
#>  stringr             1.5.0      2022-12-02 [1]
#>  survival            3.7-0      2024-06-05 [1]
#>  tibble              3.2.1      2023-03-20 [1]
#>  tidyselect          1.2.0      2022-10-10 [1]
#>  utf8                1.2.4      2023-10-22 [1]
#>  uuid                1.2-0      2024-01-14 [1]
#>  vctrs               0.6.5      2023-12-01 [1]
#>  withr               3.0.0      2024-01-16 [1]
#>  xfun                0.25       2021-08-06 [2]
#>  xgboost           * 1.7.8.1    2024-07-24 [1]
#>  yaml                2.2.1      2020-02-01 [2]
#>  source                                    
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  Github (xoopR/distr6@95d7359)             
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  Github (mlr-org/mlr3extralearners@6dc6965)
#>  CRAN (R 4.1.1)                            
#>  Github (mlr-org/mlr3pipelines@c542a26)    
#>  Github (mlr-org/mlr3proba@5205752)        
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  Github (xoopR/param6@0fa3577)             
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  Github (xoopR/set6@a901255)               
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#> 
#> [1] /Users/Lee/Library/R/x86_64/4.1/library
#> [2] /Library/Frameworks/R.framework/Versions/4.1/Resources/library

Thank you.


Solution

  • It is difficult to say for sure, though I am fairly confident that mlr3 does the correct data transformation. Some things to check:

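    • Are you using the same encoding for the xgboost::xgb.DMatrix? For the Cox objective, the label is the observed time, negated for the censored observations (see the internal xgboost helper functions in mlr3extralearners; a similar question has code that does exactly this conversion). Note that the question's line `ifelse(lung$status == 0, lung$time, -lung$time)` does the opposite: it keeps censored times (status == 0) positive and negates the events. Following the convention above, it should be `ifelse(lung$status == 1, lung$time, -lung$time)`.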
    • Is nrounds = 1 in both versions? The default was recently changed to 1000 (see NEWS).
    • To simplify the investigation, I would test on a dataset that needs no factor encoding, e.g. tsk("gbcs").
    • You set a watchlist in the manual version, whereas in mlr3 it is empty (NULL). The newest version supports internal validation and early stopping for xgboost, but here I would first make sure that every parameter is the same and that the raw xgboost models produced are identical; identical predictions would then follow.
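The checks above can be combined into one side-by-side comparison: no factor encoding (tsk("gbcs")), `nrounds = 1` pinned in both fits, no watchlist, the censored observations negated in the label, and both linear predictors taken on the margin scale. This is a sketch under assumptions: it assumes `surv.xgboost.cox` accepts `nrounds` and `nthread` in your installed mlr3extralearners version and that its `lp` equals xgboost's raw margin. It has not been run against a specific version, so treat the final difference as something to inspect rather than a guaranteed zero.

```r
library(mlr3)
library(mlr3proba)
library(mlr3extralearners)
library(xgboost)

# gbcs needs no factor encoding, so no po("encode") step is required
task <- tsk("gbcs")

# Pin nrounds explicitly, since its default has changed across versions
learner <- lrn("surv.xgboost.cox", eta = 0.0103, nrounds = 1, nthread = 1)
learner$train(task)
lp_mlr3 <- learner$predict(task)$lp

# Same features and rows (row_id order) for the raw xgboost fit
x <- as.matrix(task$data(cols = task$feature_names))
y <- task$truth()  # survival::Surv object with columns "time" and "status"

# survival:cox label: observed time, negated for censored rows (status == 0)
label <- ifelse(y[, "status"] == 1, y[, "time"], -y[, "time"])
dmat <- xgb.DMatrix(x, label = label)

# No watchlist: it only affects evaluation logging
bst <- xgb.train(
  params = list(objective = "survival:cox", eta = 0.0103, nthread = 1),
  data = dmat,
  nrounds = 1
)
lp_xgb <- predict(bst, dmat, outputmargin = TRUE)

# Should be zero up to float precision if both fits are identical
summary(lp_mlr3 - lp_xgb)
```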