r dplyr purrr glmnet

dplyr: accessing a sub-element of every element of a list


I'm trying to use safely() from the purrr package together with lasso regression from the glmnet package. I'm stuck on the cross-validation step: safely() returns a list with two elements, $result and $error, and I want to extract only $result inside a dplyr pipeline, but I can't get it to work.
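For reference, here is a minimal sketch of what safely() returns. It uses base log() as a stand-in for cv.glmnet, and the names safe_log, ok and bad are illustrative only. Every call yields a list named result and error: on success error is NULL, and on failure result is NULL (the default value of safely()'s otherwise argument).

```r
library(purrr)

# Wrap log() so that errors are captured instead of thrown
safe_log <- safely(log)

ok  <- safe_log(10)   # succeeds: $result holds the value, $error is NULL
bad <- safe_log("a")  # fails: $result is NULL, $error holds the condition

str(ok)
str(bad)
```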

I can get the cross-validation to work for a single species, but not for all species at once with dplyr.

library(dplyr)
library(glmnet)
library(purrr)
data(iris)

# Group to perform regressions for every species
grouped <- iris %>% 
  group_by(Species)

# Make model matrices
mm <- grouped %>%
  do(mm = safely(model.matrix)(Sepal.Length ~ Sepal.Width + Petal.Width, data = .)[1])

# Join the dependent values with the model matrices in a tibble
sepallengths <- grouped %>% 
  summarise(Sepal.Length = list(Sepal.Length))
temp <- inner_join(sepallengths, mm, by = "Species")

# Perform cross validation using the tibble above
cv_holder <- temp %>% 
  group_by(Species) %>% 
  # How to get this to work inside dplyr?
  do(cv = safely(cv.glmnet)(.$mm[1]$result, .$Sepal.Length, alpha = 1, nlambda = 100))

# Contains only errors when it should contain the cross validations
cv_holder$cv

# Works for individual lists this way
safely(cv.glmnet)(temp$mm[[1]][1]$result, temp$Sepal.Length[[1]], alpha = 1, nlambda = 100)$result

I expect the output to be a tibble (cv_holder) with a list column (cv) that holds the cross-validation output for each species. Instead, dplyr returns only errors such as "simpleError in rep(1, N): invalid 'times' argument".
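A likely cause can be reproduced on a toy list column (mm_col below is hypothetical and only mirrors one group's .$mm). Inside do(), each group's .$mm is a list of length 1, and single-bracket [1] returns a one-element list rather than its contents. That sub-list has no result name, so $result gives NULL. Passing NULL as x presumably makes cv.glmnet compute nrow(NULL), which is NULL, and rep(1, NULL) fails with the "invalid 'times' argument" error. The same applies to .$Sepal.Length, which is a list rather than a numeric vector.

```r
library(purrr)

# One group's list column, as seen inside do(): a list of length 1
# whose single element is the list kept from safely(model.matrix)(...)[1]
mm_col <- list(list(result = matrix(1:6, nrow = 3)))

wrong <- mm_col[1]$result        # NULL: [1] keeps the outer list wrapper
right <- mm_col[[1]]$result      # the matrix itself
same  <- pluck(mm_col, 1)$result # pluck(x, 1) behaves like x[[1]]

nrow(wrong)  # NULL
# rep(1, nrow(wrong)) fails with "invalid 'times' argument"
```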

This is how it could be done by looping:

cv_list <- vector("list", length(temp$mm))
for (i in seq_along(temp$mm)) {
  # Keep only the $result part of each safely() output
  cv_list[[i]] <- safely(cv.glmnet)(temp$mm[[i]][1]$result, temp$Sepal.Length[[i]],
                                    alpha = 1, nlambda = 100)$result
}
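The same per-species loop can also be written with purrr's map(), which avoids pre-allocating and indexing by position. This is a sketch under assumptions: it rebuilds the model matrices from split(iris, iris$Species) instead of reusing temp, and the names safe_cv, by_species, cv_list and cv_results are illustrative.

```r
library(purrr)
library(glmnet)

# Capture errors from cv.glmnet instead of stopping at the first failure
safe_cv <- safely(cv.glmnet)

# One data frame per species, named by species
by_species <- split(iris, iris$Species)

cv_list <- map(by_species, function(d) {
  x <- model.matrix(Sepal.Length ~ Sepal.Width + Petal.Width, data = d)
  safe_cv(x, d$Sepal.Length, alpha = 1, nlambda = 100)
})

# Keep only the $result element of every safely() output
cv_results <- map(cv_list, "result")
```

Passing a string such as "result" to map() is purrr's shorthand for extracting the element with that name from each list item, so no explicit index is needed. Because cv.glmnet assigns observations to folds at random, the fitted objects will differ slightly between runs unless a seed is set.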

Solution

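  • I got it to work with purrr's pluck(). Inside do(), each group's .$mm and .$Sepal.Length are list columns of length 1, so .$mm[1] is a one-element list whose $result is NULL, and .$Sepal.Length is a list rather than a numeric vector. pluck(x, 1) returns the first element itself, the same as x[[1]]: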

    cv_holder <- temp %>% 
        group_by(Species) %>% 
        # Using pluck()
        do(cv = safely(cv.glmnet)(pluck(.$mm, 1)$result, pluck(.$Sepal.Length, 1), alpha = 1, nlambda = 100))
    
    # Now works as intended
    cv_holder$cv
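The title asks how to access one sub-element of every element of a list. For a list column of safely() outputs, purrr's string shorthand in map() does this directly. The sketch below again uses safely(log) as a stand-in for cv.glmnet; the names outputs, results, errors and succeeded are illustrative.

```r
library(purrr)

safe_log <- safely(log)

# A list of safely() outputs, like the cv column built with do()
outputs <- map(list(1, 10, "a"), safe_log)

# A string in map() extracts the element with that name from every item
results <- map(outputs, "result")  # values, with NULL where the call failed
errors  <- map(outputs, "error")   # conditions, with NULL where the call succeeded

# TRUE for the calls that succeeded
succeeded <- map_lgl(errors, is.null)
```

Applied to the objects above, map(cv_holder$cv, "result") should return just the cross-validation objects, one per species.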