r, cross-validation, glmnet

Extract the best parameters from a cva.glmnet object


I'm sure there is an elegant way to extract the best alpha and lambda after running cva.glmnet, but somehow I cannot find it.

Here is the code I am using in the meantime.

Thank you

library(data.table);library(glmnetUtils);library(useful)

# make some dummy data

data(iris)

x <- useful::build.x(data = iris,formula = Sepal.Length ~ .)
y <- iris$Sepal.Length

# run cv for alpha in c(0,0.5,1)

output.of.cva.glmnet <- cva.glmnet(x=x,y=y,alpha = c(0,0.5,1))

# extract the best parameters

number.of.alphas.tested <- length(output.of.cva.glmnet$alpha)

cv.glmnet.dt <- data.table()

for (i in 1:number.of.alphas.tested){
  glmnet.model <- output.of.cva.glmnet$modlist[[i]]
  min.mse <-  min(glmnet.model$cvm)
  min.lambda <- glmnet.model$lambda.min
  alpha.value <- output.of.cva.glmnet$alpha[i]
  new.cv.glmnet.dt <- data.table(alpha=alpha.value,min_mse=min.mse,min_lambda=min.lambda)
  cv.glmnet.dt <- rbind(cv.glmnet.dt,new.cv.glmnet.dt)
}

best.params <- cv.glmnet.dt[which.min(cv.glmnet.dt$min_mse)]
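For comparison, the same one-row-per-alpha table can be built without growing a data.table inside a loop. This is a minimal base-R sketch, assuming `fit` has the structure cva.glmnet returns (`fit$alpha` is the vector of alphas tested and `fit$modlist` is a list of cv.glmnet fits, one per alpha). The function name `cva_summary` is hypothetical, not part of glmnetUtils. It also assumes a loss-type measure such as the default MSE, where lower `cvm` is better.

```r
# Hypothetical helper: summarise a cva.glmnet fit as one row per alpha.
# Assumes a loss-type measure (lower cvm is better, e.g. the default MSE);
# for type.measure = "auc" the best row would be the one with the largest cvm.
cva_summary <- function(fit) {
  rows <- lapply(seq_along(fit$alpha), function(i) {
    mod <- fit$modlist[[i]]          # the cv.glmnet fit for alpha[i]
    data.frame(
      alpha      = fit$alpha[i],
      lambda_min = mod$lambda.min,   # lambda with the lowest mean CV error
      lambda_1se = mod$lambda.1se,   # largest lambda within 1 SE of that minimum
      min_cvm    = min(mod$cvm)      # mean CV error at lambda.min
    )
  })
  do.call(rbind, rows)               # bind once instead of inside a loop
}
```

With the object from the question, `cv.table <- cva_summary(output.of.cva.glmnet)` followed by `cv.table[which.min(cv.table$min_cvm), ]` should give the same row as `best.params`.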



Solution

  • Based on a thread I read on GitHub, the package author wants people to use plot(fit) rather than just outputting the best parameters. However, that isn't always possible, especially when cross-validation is involved. These helper functions can be a good workaround.

    # Train model (x and y built as in the question).
    library(glmnetUtils)
    fit <- cva.glmnet(x, y)
    
    # Get the best alpha: the one with the lowest minimum mean CV error.
    get_alpha <- function(fit) {
      alpha <- fit$alpha
      error <- sapply(fit$modlist, function(mod) {min(mod$cvm)})
      alpha[which.min(error)]
    }
    
    # Get the best alpha together with its lambda.min, lambda.1se and CV error.
    get_model_params <- function(fit) {
      alpha <- fit$alpha
      lambdaMin <- sapply(fit$modlist, `[[`, "lambda.min")
      lambdaSE <- sapply(fit$modlist, `[[`, "lambda.1se")
      error <- sapply(fit$modlist, function(mod) {min(mod$cvm)})
      best <- which.min(error)
      data.frame(alpha = alpha[best], lambdaMin = lambdaMin[best],
                 lambdaSE = lambdaSE[best], error = error[best])
    }
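Once the best alpha is known, the matching inner cv.glmnet fit in `fit$modlist` can be used directly with glmnet's own `coef()` and `predict()` methods, so no refit is needed. The sketch below rebuilds the question's data, assumes glmnetUtils and useful are installed, and adds `set.seed()` only because the CV folds are random. glmnetUtils also documents `coef()` and `predict()` methods for cva.glmnet objects that take an `alpha` argument, which should be an equivalent route.

```r
library(glmnetUtils)

# Same dummy data as in the question
data(iris)
x <- useful::build.x(formula = Sepal.Length ~ ., data = iris)
y <- iris$Sepal.Length

set.seed(1)  # CV folds are random; fix the seed so results are reproducible
fit <- cva.glmnet(x = x, y = y, alpha = c(0, 0.5, 1))

# Pick the alpha whose cv.glmnet fit reaches the lowest mean CV error
cv_error    <- sapply(fit$modlist, function(mod) min(mod$cvm))
best        <- which.min(cv_error)
best_alpha  <- fit$alpha[best]
best_model  <- fit$modlist[[best]]   # a plain cv.glmnet object
best_lambda <- best_model$lambda.min

# Coefficients and predictions at lambda.min for the chosen alpha
best_coef <- coef(best_model, s = "lambda.min")
pred      <- predict(best_model, newx = x, s = "lambda.min")
```

Using `s = "lambda.1se"` instead gives the sparser, more regularised model that cv.glmnet uses by default.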