mlr3: Predicted values for surv.coxph learner with case weights


I am trying to get predicted values for a surv.coxph learner trained on a task with case weights. Please advise on how to fix the execution error below.

  • Task definition
library(mlr3)
library(mlr3proba)

rm(list=ls())
# Load the rats dataset and create a survival task
task0 = tsk("rats")

# Define task for surv.coxph learner with case weights
set.seed(123)
selected_features = c("litter", "rx", "sex")
task  = task0$clone()
task$cbind(data.frame(weights = runif(task$nrow, 1, 2)))
task$select(selected_features)
task$col_roles$weight = "weights"
task
## <TaskSurv:rats> (300 x 5): Rats
## * Target: time, status
## * Properties: weights
## * Features (3):
##  - int (2): litter, rx
##  - fct (1): sex
## * Weights: weights
  • Train the model
# Define the learner - Cox Proportional Hazards Model (with case weights)
learner = lrn("surv.coxph")

# Perform a training/test split, stratified on `status` by default
part = partition(task)

# Train the learner on the training split
learner$train(task, part$train)

  • Execution error
# Make predictions on the testing split
p = learner$predict(task, part$test)

## Error in model.frame.default(formula = Surv(time, status, type = "right") ~  : 
##   variable lengths differ (found for '(weights)')




Solution

  • Thanks so much for reporting this. It took thorough testing to find the cause; our own weights test had been failing for ages and I had never figured out why! The issue is now fixed in [email protected].

    The survival package author actually mentions this in the survfit.coxph() documentation:

    If the following pair of lines is used inside of another function then the model=TRUE argument must be added to the coxph call: fit <- coxph(...); survfit(fit). This is a consequence of the non-standard evaluation process used by the model.frame function when a formula is involved.

    This is exactly what mlr3proba does: it calls coxph(...) during $train() and survfit(fit) during $predict(). With only the task features we never hit this error, but with the weights argument the model frame gains an extra '(weights)' column that has to be re-evaluated at prediction time, which caused the error you reported.
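
The pattern recommended in the documentation can be sketched with the survival package alone. This is a minimal illustration, not mlr3proba's actual code: the helper `fit_weighted_cox()`, the `lung` dataset, and the covariates `age` and `sex` are assumptions picked for the example. It does not reproduce the error above; it only shows the fit and the survfit() call living in different scopes, with model = TRUE set as the documentation advises.

```r
library(survival)

# Hypothetical helper (not part of mlr3proba): fits a weighted Cox model
# inside a function, much as a learner's $train() step would.
fit_weighted_cox = function(data, w) {
  # model = TRUE stores the model frame, including the case weights, in the
  # fitted object, so survfit() can reuse it instead of re-evaluating the call.
  coxph(Surv(time, status) ~ age + sex, data = data, weights = w, model = TRUE)
}

set.seed(123)
fit = fit_weighted_cox(lung, runif(nrow(lung), 1, 2))

# Later, outside the fitting function (as in a $predict() step):
sf = survfit(fit, newdata = lung[1:5, ])
```

After fitting, `fit$model` holds the stored model frame, so the survival curves for the new rows are computed without looking up the original weights vector again.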

    Code with some comments:

    task = tsk("rats")
    task$cbind(data.frame(weights = runif(task$nrow, 1, 2)))
    task$set_col_roles(cols = "weights", roles = "weight") # removes "weights" from features
    task$set_col_roles(cols = "status", add_to = "stratum") # keeps "status" as a target column
    # YOU NEED TO DO THE ABOVE IF YOU WANT STRATIFICATION BY CENSORING STATUS IN RESAMPLING
    part = partition(task)
    
    cox = lrn("surv.coxph")
    p = cox$train(task, part$train)$predict(task, part$test)