survival-analysis, survival, mlr3

Model Interpretability for Survival task in MLR3


I have previously performed model interpretation (partial dependence plots, PDP) on my survival models in mlr. However, I am unable to do the same in mlr3 because the "Predictor" object does not accept survival models.

Below is sample code from mlr. Is there any way to do model interpretation in mlr3?

mod = train(lrn,surv.task)

task.pred = predict(mod,newdata = traindata)

getLearnerModel(mod)

pd = generatePartialDependenceData(mod,surv.task,c("X","Y","Z"))

plotPartialDependence(pd)

Here c("X","Y","Z") stands for three arbitrary features. I want to be able to do this for all features as well as for a selected subset, as in mlr.


Solution

  • To my knowledge, mlr3 cannot directly apply interpretable machine learning methods to survival models. However, the survex package can perform a wide range of explainability analyses and is compatible with mlr3proba (survex package GitHub link).

    Here is a minimal code example of how to generate PDPs from a fitted mlr3proba model; check the package vignettes for more information:

    library(mlr3proba)
    library(mlr3extralearners)
    library(mlr3pipelines)
    library(survex)
    library(survival)
    
    veteran_task <- as_task_surv(veteran,
                         time = "time",
                         event = "status",
                         type = "right")
    ranger_learner <- lrn("surv.ranger")    
    ranger_learner$train(veteran_task)
    ranger_learner_explainer <- explain(ranger_learner, 
                         data = veteran[, -c(3,4)],
                         y = Surv(veteran$time, veteran$status),
                         label = "Ranger model")
    
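    # Compute partial dependence profiles from the explainer created above
    # (result stored under a new name so model_profile() is not shadowed)
    ranger_profile <- model_profile(ranger_learner_explainer)
    
    # Plot the profile for a single variable
    plot(ranger_profile,
         variables = "celltype")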
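
The question also asks for PDPs on a selected subset of features and on all features. The following is a minimal sketch of one way to do that, assuming survex's `model_profile()` accepts a `variables` argument and profiles every variable when it is left at its default. It rebuilds the same ranger model so that it runs on its own. The object names (`task`, `learner`, `explainer`, `pdp_selected`, `pdp_all`) and the choice of `karno` and `celltype` are illustrative only.

```r
library(mlr3proba)
library(mlr3extralearners)
library(mlr3pipelines)
library(survex)
library(survival)

# Same setup as in the answer: random survival forest on the veteran data
task <- as_task_surv(veteran, time = "time", event = "status", type = "right")
learner <- lrn("surv.ranger")
learner$train(task)

explainer <- explain(learner,
                     data = veteran[, -c(3, 4)],  # drop the time and status columns
                     y = Surv(veteran$time, veteran$status),
                     label = "Ranger model")

# PDPs for a selected subset of features (illustrative choice)
pdp_selected <- model_profile(explainer, variables = c("karno", "celltype"))
plot(pdp_selected)

# PDPs for all features: leave `variables` at its default
pdp_all <- model_profile(explainer)
plot(pdp_all)
```

Restricting `variables` in `model_profile()` avoids computing profiles that are not needed, whereas computing all profiles once and passing `variables` to `plot()` makes it easy to inspect features one at a time.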