Tags: r, machine-learning, survival-analysis, gbm, boosting

Making sense of gbm survival prediction model


I am new to ML methods and am currently doing survival analysis with the gbm package in R.

I have difficulty understanding some of the output of the survival prediction model. I have checked this tutorial and this post, but I still have trouble making sense of the model's output.

Here is my code for analysis based on example data:

rm(list=ls(all=TRUE))
library(randomForestSRC)
library(gbm)
library(survival)
library(Hmisc)

data(pbc, package="randomForestSRC")
data <- na.omit(pbc)

set.seed(9512)
train <- sample(1:nrow(data), round(nrow(data)*0.7))
data.train <- data[train, ]
data.test <- data[-train, ]

set.seed(9741)
model <- gbm(Surv(days, status)~.,
           data.train,
           interaction.depth=2,
           shrinkage=0.01,
           n.trees=500,
           distribution="coxph",
           cv.folds = 5)

summary(model)

best.iter <- gbm.perf(model, plot.it = TRUE, method = 'cv',
                      overlay = TRUE) #to get the optimal number of Boosting iterations
best.iter

# Use the best number of trees to produce predicted values for each observation in newdata.
# predict() returns a vector of predictions f(x) based on the first n.trees trees.
# By default the predictions are on the log hazard scale for coxph.
# The proportional hazards model assumes h(t|x)=lambda(t)*exp(f(x)),
# so these predictions estimate the f(x) component of the hazard function.
pred.train <- predict(object=model, newdata=data.train, n.trees = best.iter)
pred.test <- predict(object=model, newdata=data.test, n.trees = best.iter)


#training set
Hmisc::rcorr.cens(-pred.train, Surv(data.train$days, data.train$status))
#test set
Hmisc::rcorr.cens(-pred.test, Surv(data.test$days, data.test$status))

# Estimate the cumulative baseline hazard function using training data
basehaz.cum <- basehaz.gbm(t=data.train$days,       #The survival times.
                           delta=data.train$status, #The censoring indicator
                           f.x=pred.train,          #The predicted values of the regression model on the log hazard scale.
                           t.eval = data.train$days,  #Values at which the baseline hazard will be evaluated
                           cumulative = TRUE,       #If TRUE the cumulative baseline hazard is computed
                           smooth = FALSE)          #If TRUE basehaz.gbm will smooth the estimated baseline hazard using Friedman's super smoother supsmu.

basehaz.cum

# Estimated survival probability of each training subject at their own observed time:
surv.rate <- exp(-exp(pred.train)*basehaz.cum)
surv.rate

res_train <- data.train
# predicted outcome for train set
res_train$pred <- pred.train
res_train$survival_rate <- surv.rate
res_train


# Estimate the cumulative baseline hazard function, this time from the test data
basehaz.cum <- basehaz.gbm(t=data.test$days,       #The survival times.
                           delta=data.test$status, #The censoring indicator
                           f.x=pred.test,          #The predicted values of the regression model on the log hazard scale.
                           t.eval = data.test$days,  #Values at which the baseline hazard will be evaluated
                           cumulative = TRUE,       #If TRUE the cumulative baseline hazard is computed
                           smooth = FALSE)          #If TRUE basehaz.gbm will smooth the estimated baseline hazard using Friedman's super smoother supsmu.

basehaz.cum
# Estimated survival probability of each test subject at their own observed time:
surv.rate <- exp(-exp(pred.test)*basehaz.cum)
surv.rate

res_test <- data.test
# predicted outcome for test set
res_test$pred <- pred.test
res_test$survival_rate <- surv.rate
res_test

#--------------------------------------------------
#Estimate survival rate at time of interest

# Specify time of interest
time.interest <- sort(unique(data.train$days[data.train$status==1]))

# Estimate the cumulative baseline hazard function using training data
basehaz.cum <- basehaz.gbm(t=data.train$days,       #The survival times.
                           delta=data.train$status, #The censoring indicator
                           f.x=pred.train,          #The predicted values of the regression model on the log hazard scale.
                           t.eval = time.interest,  #Values at which the baseline hazard will be evaluated
                           cumulative = TRUE,       #If TRUE the cumulative baseline hazard is computed
                           smooth = FALSE)          #If TRUE basehaz.gbm will smooth the estimated baseline hazard using Friedman's super smoother supsmu.


# For the first individual in the test set, the estimated survival function over time.interest is:
surf.i <- exp(-exp(pred.test[1])*basehaz.cum) #survival rate

#Estimation of survival rate of all at specified time is:
specif.time <- time.interest[10]
surv.rate <- exp(-exp(pred.test)*basehaz.cum[10])
cat("Survival Rate of all at time", specif.time, "\n")
print(surv.rate)

The output returned from the predict function represents the f(x) component of the hazard function ( h(t|x)=lambda(t)*exp(f(x)) ).

My questions:

• I am a bit confused about whether hazard ratios can be calculated here.

• How can I divide the population into low-risk and high-risk groups? Can I rely on the estimated f(x) component of the hazard function to build a scoring system from the training set? My aim is a scoring system for which I can show KM plots for the low- and high-risk groups in both the training and test sets.

• How can I construct calibration curves that plot observed survival against predicted survival for the training set and the test set?


Solution

  • Amer, thanks for reading my tutorial!

    As you mentioned, "The output returned from the predict function represents the f(x) component of the hazard function ( h(t|x)=lambda(t)*exp(f(x)) )", so it helps to first understand the hazard function, i.e. h(t|x).

    Before going further, please make sure you have a basic knowledge of survival analysis. If not, I recommend reading this great post, which I think would help you answer these questions.
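For reference, here is how the quantities in this thread fit together under the proportional hazards form quoted above. This is only a sketch of the standard relations; the symbol \Lambda(t) for the cumulative baseline hazard is notation introduced here, and corresponds to what `basehaz.gbm(..., cumulative = TRUE)` estimates in the question's code.

```latex
% Hazard for a subject with covariates x, as quoted in the question
h(t \mid x) = \lambda(t)\, e^{f(x)}

% Hazard ratio between two subjects: the baseline cancels
\frac{h(t \mid x_1)}{h(t \mid x_2)} = e^{f(x_1) - f(x_2)}

% Cumulative baseline hazard and the resulting survival function
\Lambda(t) = \int_0^t \lambda(u)\, du,
\qquad
S(t \mid x) = \exp\!\left(-\Lambda(t)\, e^{f(x)}\right)
```

The last line is the same expression as `exp(-exp(pred.test) * basehaz.cum)` in the question.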

    Back to your questions:

    • Yes. The predict function returns f(x), the log of each subject's hazard relative to the baseline, so exp(f(x)) gives the hazard ratio relative to the baseline, and exp(f(x1) - f(x2)) gives the hazard ratio between two subjects.
    • Sure! You can divide the population into low-risk and high-risk groups based on these hazard ratios, for example by using their median as the cutoff value. I think the cutoff should be derived from the training set and then applied to the test set. If your model is effective, the KM curves for the low- and high-risk groups should differ significantly (as assessed by the log-rank test).
    • Calibration curves are often used to evaluate the performance of a model that outputs probabilities in the range [0, 1]. We can calculate the survival function and then specify a time point of interest, e.g. 5 years. Finally, we compare the predicted survival probabilities with the actual survival status at that time, much as we would when evaluating a binary classification model. For details on obtaining the survival function, see my tutorial; the underlying principles are covered in the post mentioned above.
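The three points above can be sketched end to end in R. This is a minimal illustration under stated assumptions, not the tutorial's code. To stay self-contained it simulates log-hazard scores (`f_train`, `f_test`) as stand-ins for `pred.train` and `pred.test` from `predict()`. The helper `breslow_cumhaz` is a hypothetical hand-written Breslow-type estimator playing the role of `basehaz.gbm`. The time point `t0 = 500` and the four calibration bins are arbitrary choices. Using a Kaplan-Meier estimate within each bin as the "observed" survival is also an assumption made here to account for censoring.

```r
library(survival)

# --- Simulated stand-in data (assumption: replaces the gbm predictions) ---
set.seed(1)
n <- 300
f_x <- rnorm(n)                                   # log relative hazard f(x)
event_time  <- rexp(n, rate = 0.001 * exp(f_x))   # hazard = lambda * exp(f(x))
censor_time <- rexp(n, rate = 0.0005)             # independent censoring
days   <- pmin(event_time, censor_time)
status <- as.integer(event_time <= censor_time)
train  <- seq_len(n) <= 200
f_train <- f_x[train]
f_test  <- f_x[!train]
test_df <- data.frame(days = days[!train], status = status[!train])

# --- 1. Hazard ratios ---
hr_vs_baseline <- exp(f_test)                     # each subject vs. baseline
hr_1_vs_2 <- exp(f_test[1] - f_test[2])           # subject 1 vs. subject 2

# --- 2. Risk groups: cutoff from the training set, applied to the test set ---
cutoff <- median(f_train)
test_df$risk <- factor(ifelse(f_test > cutoff, "high", "low"),
                       levels = c("low", "high"))
km_groups <- survfit(Surv(days, status) ~ risk, data = test_df)
logrank   <- survdiff(Surv(days, status) ~ risk, data = test_df)
print(logrank)

# --- 3. Calibration at a fixed time t0 ---
# Hypothetical helper: Breslow-type cumulative baseline hazard at t_eval
breslow_cumhaz <- function(time, status, f, t_eval) {
  event_times <- sort(unique(time[status == 1]))
  increments <- vapply(event_times, function(s) {
    sum(time == s & status == 1) / sum(exp(f[time >= s]))
  }, numeric(1))
  sum(increments[event_times <= t_eval])
}

t0 <- 500
H0_t0 <- breslow_cumhaz(days[train], status[train], f_train, t0)
pred_surv_test <- exp(-H0_t0 * exp(f_test))       # S(t0 | x) per test subject

# Bin subjects by predicted survival; compare mean prediction with KM at t0
bins <- cut(pred_surv_test,
            breaks = quantile(pred_surv_test, probs = 0:4 / 4),
            include.lowest = TRUE)
calib <- do.call(rbind, lapply(levels(bins), function(b) {
  idx <- bins == b
  km  <- survfit(Surv(days, status) ~ 1, data = test_df[idx, ])
  data.frame(predicted = mean(pred_surv_test[idx]),
             observed  = summary(km, times = t0, extend = TRUE)$surv)
}))
print(calib)
```

To draw the figures, `plot(km_groups, col = c("blue", "red"))` gives the KM curves for the two groups, and `plot(calib$predicted, calib$observed); abline(0, 1)` gives a simple calibration plot where points near the diagonal indicate good calibration. With the real model, `f_train` and `f_test` would be replaced by `pred.train` and `pred.test`.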