Tags: r, cross-validation, glmnet, cox, elasticnet

How can I do repeated cross-validation for an elastic-net penalized Cox model in R?


I would like to know the best way to do 10x repeated 10-fold cross-validation for an elastic-net penalized Cox model in R. I am currently using the glmnet package, tuning alpha first and lambda second. I would like to use the C-index as the performance metric for choosing the best lambda, but I have not been able to find where it is stored in my glmnet object. I was thinking of writing a loop that sets a different seed on each iteration to repeat the CV 10 times, but without the C-index for each lambda value I am not sure how to choose the best lambda across the repeats. I can visualize the C-index for the different lambda values with plot(glmnetfit), but I would have to inspect the plot for each repeat by eye, which does not seem very precise and becomes impractical with more than 10 repeats.

Here is the code I have been using:

library(glmnet)
library(survival)

data("CoxExample")
x <- CoxExample$x
y <- CoxExample$y

alphas <- seq(0, 1, by = 0.1) 

# Fit models with different alpha values
models <- lapply(alphas, function(alpha) {
  glmnet(x = x, y = y, family = "cox", alpha = alpha)
})

# Optimize alpha
set.seed(123)
cv_results <- lapply(seq_along(models), function(i) {
  alpha <- alphas[[i]]
  print(alpha)
  lambda <- models[[i]]$lambda
  if (is.null(alpha)) {
    alpha <- 1  # Default to lasso if alpha is missing
  }
  cv.glmnet(x = x, y = y, family = "cox", 
            alpha = alpha, lambda = lambda)
})

optimal_alpha_index <- sapply(cv_results, function(cv_result) which.min(cv_result$cvm))
optimal_cvm <- vector("list", length = length(cv_results))
for (i in seq_along(cv_results)) {
  optimal_index <- which.min(cv_results[[i]]$cvm)  # Get the index of the optimal model for alpha i
  optimal_cvm[[i]] <- cv_results[[i]]$cvm[optimal_index]  # Extract the corresponding cross-validation error
}


optimal_alpha = alphas[which.min(optimal_cvm)]
optimal_alpha #1

#Optimize lambda
models <- list()

for (i in 1:10) {
  set.seed(i*123)
  models[[i]] <- cv.glmnet(x=x, 
                           y = y, 
                           family = "cox", 
                           alpha = optimal_alpha, 
                           type.measure = "C"
  )
  print(models[[i]]$lambda.min)
}

# With type.measure = "C", cvm is a C-index (higher is better), so take the maximum
optimal_lambda_index <- sapply(models, function(m) which.max(m$cvm))

I end up with the best lambda for each of the 10 repeats. I would like to extract the C-index at the best lambda of each model or, if that is not possible, to hear what the best approach to repeated CV would be. Any help would be appreciated.

Thanks!
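For reference, the C-index asked about here is already stored in each cv.glmnet object fitted with type.measure = "C": cvm holds the mean cross-validated C-index for every lambda, cvsd its estimated standard error, and lambda.min is the lambda with the highest C-index. The sketch below adapts the question's repeat loop to collect those values into one table. It assumes alpha = 1 (the value the question's alpha step printed) and fixes the lambda path once on the full data so that every repeat is scored on the same grid; the names lambda_path and per_repeat are illustrative only.

```r
library(glmnet)

data("CoxExample")
x <- CoxExample$x
y <- CoxExample$y

n_repeats <- 10
alpha <- 1  # assumption: the alpha chosen in the question's first step

# One lambda path computed on the full data, reused in every repeat
lambda_path <- glmnet(x, y, family = "cox", alpha = alpha)$lambda

per_repeat <- lapply(seq_len(n_repeats), function(i) {
  set.seed(i * 123)
  fit <- cv.glmnet(x, y, family = "cox", alpha = alpha,
                   lambda = lambda_path, type.measure = "C")
  # Position of lambda.min in the (possibly trimmed) lambda vector of this fit
  best <- which(fit$lambda == fit$lambda.min)
  data.frame(repeat_id  = i,
             lambda_min = fit$lambda.min,
             c_index    = fit$cvm[best],   # cross-validated C-index at lambda.min
             c_index_se = fit$cvsd[best])  # its estimated standard error
})
per_repeat <- do.call(rbind, per_repeat)
per_repeat
```

Because lambda.min is defined as the maximizer of the C-index when type.measure = "C", fit$cvm[best] is the same value as max(fit$cvm).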


Solution

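  • Because you set type.measure = "C", cv.glmnet already stores the cross-validated C-index for every lambda in cvm (its standard error is in cvsd), and lambda.min is the lambda that maximizes it, so the C-index at the best lambda is simply max(model$cvm). Note that in the question's alpha-tuning step type.measure is left at its Cox default, the partial-likelihood deviance, so cvm there is a deviance to minimize rather than a C-index to maximize. The extraction and comparison can be refined by tuning alpha and lambda together on the C-index within each repeat: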

    library(glmnet)
    library(survival)
    library(foreach)
    library(doParallel)
    
    data("CoxExample")
    x <- CoxExample$x
    y <- CoxExample$y
    
    alphas <- seq(0, 1, by = 0.1) 
    n_repeats <- 10
    
    num_cores <- detectCores() - 1
    registerDoParallel(cores = num_cores)
    
    # Parallel execution over repeats using the foreach package
    results <- foreach(r = 1:n_repeats, .packages = 'glmnet') %dopar% {
      set.seed(r * 123)
      
      # One fold assignment per repeat, reused for every alpha so that
      # the alphas are compared on identical splits
      foldid <- sample(rep(1:10, length.out = nrow(x)))
      
      # Cross-validate the model for each alpha value
      models <- lapply(alphas, function(alpha) {
        cv.glmnet(x = x, y = y, family = "cox", alpha = alpha,
                  foldid = foldid, type.measure = "C")
      })
      
      # Best lambda for each alpha: with type.measure = "C", cvm holds the
      # C-index and lambda.min is the lambda that MAXIMIZES it
      best_lambda_per_alpha <- sapply(models, function(model) {
        c(lambda = model$lambda.min, c_index = max(model$cvm))
      })
      
      # Alpha/lambda combination with the highest C-index in this repeat
      optimal_alpha_index <- which.max(best_lambda_per_alpha["c_index", ])
      
      data.frame(repeat_id = r,
                 alpha   = alphas[optimal_alpha_index],
                 lambda  = best_lambda_per_alpha["lambda", optimal_alpha_index],
                 c_index = best_lambda_per_alpha["c_index", optimal_alpha_index])
    }
    
    stopImplicitCluster()
    
    # One row per repeat: chosen alpha, lambda and C-index
    do.call(rbind, results)
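The loop above still returns one optimal (alpha, lambda) pair per repeat. The cv.glmnet help page suggests reducing fold randomness by running cv.glmnet many times and averaging the error curves, and, when alpha is tuned too, reusing one pre-computed foldid across the alpha values. A minimal sketch of that approach, assuming the same alpha grid and seeds as above: each alpha gets a lambda path fixed on the full data, the C-index curves are averaged over the repeats, and the (alpha, lambda) pair with the highest mean C-index is selected. The objects lambda_paths, cindex and summary_tab are illustrative names, and the loop runs serially for clarity.

```r
library(glmnet)

data("CoxExample")
x <- CoxExample$x
y <- CoxExample$y

alphas    <- seq(0, 1, by = 0.1)
n_repeats <- 10
nfolds    <- 10

# Fixed lambda path per alpha, so curves from different repeats line up
lambda_paths <- lapply(alphas, function(a) {
  glmnet(x, y, family = "cox", alpha = a)$lambda
})

# cindex[[j]]: n_repeats x length(lambda) matrix of C-indices for alphas[j]
cindex <- lapply(lambda_paths, function(lam) {
  matrix(NA_real_, nrow = n_repeats, ncol = length(lam))
})

for (r in seq_len(n_repeats)) {
  set.seed(r * 123)
  # Same folds for every alpha within this repeat
  foldid <- sample(rep(seq_len(nfolds), length.out = nrow(x)))
  for (j in seq_along(alphas)) {
    fit <- cv.glmnet(x, y, family = "cox", alpha = alphas[j],
                     lambda = lambda_paths[[j]], foldid = foldid,
                     type.measure = "C")
    # Align by lambda value in case cv.glmnet dropped the end of the path
    cindex[[j]][r, ] <- fit$cvm[match(lambda_paths[[j]], fit$lambda)]
  }
}

# Mean C-index over repeats; a lambda missing in any repeat gets NA
summary_tab <- do.call(rbind, lapply(seq_along(alphas), function(j) {
  mean_c <- colMeans(cindex[[j]])
  best   <- which.max(mean_c)  # which.max ignores NA
  data.frame(alpha        = alphas[j],
             lambda       = lambda_paths[[j]][best],
             mean_c_index = mean_c[best],
             sd_c_index   = sd(cindex[[j]][, best]))
}))

best_row <- summary_tab[which.max(summary_tab$mean_c_index), ]
best_row
```

The sd_c_index column only reflects how much the estimate moves with the fold assignment; it is not a standard error of the model's performance on new data. The repeats could be parallelized with the same foreach pattern as above.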