I am trying to get predicted values from a surv.coxph learner trained on a task with case weights. Please advise on how to fix the execution error.
library(mlr3)
library(mlr3proba)
rm(list=ls())
# Load the rats dataset and create a survival task
task0 = tsk("rats")
# Define task for surv.coxph learner with case weights
set.seed(123)
selected_features= c("litter", "rx", "sex")
task = task0$clone()
task$cbind(data.frame(weights = runif(task$nrow, 1, 2)))
task$select(selected_features)
task$col_roles$weight = "weights"
task
## <TaskSurv:rats> (300 x 5): Rats
## * Target: time, status
## * Properties: weights
## * Features (3):
## - int (2): litter, rx
## - fct (1): sex
## * Weights: weights
# Define the learner - Cox Proportional Hazards Model (with case weights)
learner = lrn("surv.coxph")
# Perform a training/test split, stratified on `status` by default
part = partition(task) # Train/test split balanced by status (default)
# Train the learner on the training split
learner$train(task, part$train)
# Make predictions on the testing split
p = learner$predict(task, part$test)
## Error in model.frame.default(formula = Surv(time, status, type = "right") ~ :
## variable lengths differ (found for '(weights)')
Thanks so much for reporting this. I had to test thoroughly to find out what the issue was. Also, the weights test we have had been failing for ages and I had never figured out why! The issue is now fixed in [email protected].
The survival package author actually mentions this in the survfit.coxph() documentation:

If the following pair of lines is used inside of another function then the model=TRUE argument must be added to the coxph call: fit <- coxph(...); survfit(fit). This is a consequence of the non-standard evaluation process used by the model.frame function when a formula is involved.

This is exactly what mlr3proba does: it calls coxph(...) during $train() and survfit(fit) during $predict(). With only the task features we never ran into this error, but with the weights argument the model frame becomes more complex (it also contains a "(weights)" column), which caused the error you reported.
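A minimal sketch of the pattern the survival documentation recommends, using the survival package directly rather than mlr3proba's internals: coxph() is fitted inside one function and survfit() is called from another, loosely mimicking the $train()/$predict() split. The helper names train_cox() and predict_cox(), the 200/100 random split and the random weights are illustrative assumptions. Passing model = TRUE stores the model frame, including its "(weights)" column, in the fitted object, so survfit() does not have to re-evaluate the original coxph() call. This sketch only shows the workaround; it is not meant to reproduce the original error.

```r
library(survival)

# Hypothetical "train" step: fit a weighted Cox model inside a function.
# model = TRUE keeps the model frame (with its "(weights)" column) in the fit.
train_cox = function(data, w) {
  coxph(Surv(time, status) ~ rx + sex, data = data, weights = w, model = TRUE)
}

# Hypothetical "predict" step: survival curves for new observations,
# computed in a different function from the one that called coxph().
predict_cox = function(fit, newdata) {
  survfit(fit, newdata = newdata)
}

set.seed(123)
rats_df = survival::rats
w = runif(nrow(rats_df), 1, 2)          # random case weights, as in the question
train_idx = sample(nrow(rats_df), 200)  # simple random split (not stratified)

fit = train_cox(rats_df[train_idx, ], w[train_idx])
sf = predict_cox(fit, rats_df[-train_idx, ])

"(weights)" %in% names(fit$model)  # the weights are stored with the fit
ncol(sf$surv)                      # one survival curve per test row
```

In mlr3proba itself this is handled inside the learner once you use the fixed version, so you do not need to call survival directly.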
Code with some comments:
library(mlr3)
library(mlr3proba)
set.seed(123)
task = tsk("rats")
task$cbind(data.frame(weights = runif(task$nrow, 1, 2)))
task$set_col_roles(cols = "weights", roles = "weight") # removes "weights" from the features
task$set_col_roles(cols = "status", add_to = "stratum") # "status" also stays a target column
# You need the line above if you want resampling to be stratified by censoring status
part = partition(task)
cox = lrn("surv.coxph")
p = cox$train(task, part$train)$predict(task, part$test)