r  tidyverse  cross-validation  purrr  tidymodels

rsample vfold_cv function not accepting .y parameter from purrr::map2


I'm trying to create nested cross-validations with the rsample package, using purrr::map2 to build several of them, each with a different number of outer folds set through the v parameter. However, vfold_cv does not accept the value that map2 passes to v, and I get this error instead: Error: `v` must be a single integer.

In the reprex below, I simulate the situation with the mtcars data by creating a nested cross-validation for each value of cyl. Replacing .y with a literal number works, but I need the number of folds to vary per group, taken from the n column.

library(dplyr)    # select(), group_by(), mutate()
library(purrr)
library(parsnip)
library(rsample)
library(tidyr)

data("mtcars")

nested <- mtcars %>% 
    select(cyl, disp:gear) %>% 
    group_by(cyl) %>% 
    nest(data = disp:gear) %>% 
    cbind(n = 2:4)

nested %>% 
    group_by(cyl) %>% 
    mutate(cv = map2(data, n,
                     ~nested_cv(.x,
                                inside = vfold_cv(v = 10, repeats = 3),
                                outside = vfold_cv(v = .y))))

Error: `v` must be a single integer.


Solution

  • The problem is the vfold_cv() call inside nested_cv(), not purrr. You can reproduce it with a plain function:

    # Same call as in the question, wrapped in a plain function (no purrr involved)
    createNested <- function(x, y) {
        nested_cv(x, inside = vfold_cv(v = 10, repeats = 3), outside = vfold_cv(v = y))
    }
    
    createNested(nested$data[[1]], 3)
    Error in vfold_splits(data = data, v = v, strata = strata, breaks = breaks) : 
      object 'y' not found
    

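    So nested_cv() cannot see the y variable (just as it cannot see your .y). The error is consistent with nested_cv() evaluating the vfold_cv() call given as outside by itself, somewhere the arguments of the surrounding function are not visible. The workaround is to call vfold_cv() yourself and pass the resulting object as outside. It takes a few more lines of code, but it works: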

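    createNested <- function(x, y) {
        # Build the outer resamples first, while y is still in scope
        outside_cv <- vfold_cv(x, v = y)
        # Pass the finished object as outside, not a call that mentions y
        nested_cv(x, inside = vfold_cv(v = 10, repeats = 3), outside = outside_cv)
    }
    
    # One row per cyl value; n gives the number of outer folds for that row
    nested <- mtcars %>%
        select(cyl, disp:gear) %>%
        nest(data = disp:gear) %>%
        mutate(n = 2:4)
    
    nested %>% mutate(cv = map2(data, n, .f = createNested))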
    
    # A tibble: 3 x 4
        cyl data                  n cv              
      <dbl> <list>            <int> <list>          
    1     6 <tibble [7 × 8]>      2 <tibble [2 × 3]>
    2     4 <tibble [11 × 8]>     3 <tibble [3 × 3]>
    3     8 <tibble [14 × 8]>     4 <tibble [4 × 3]>
    

    Note that once the data is nested there is already one row per cyl value, so you no longer need group_by().
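
To illustrate why precomputing the outer resamples helps, here is a toy model in base R. It is not rsample's code: capture_and_eval(), wrapper_broken() and wrapper_fixed() are hypothetical names, and the assumption is only that nested_cv() re-evaluates an argument written as a call somewhere the caller's local variables are not visible, which is what the "object 'y' not found" error suggests.

```r
# Toy stand-in, NOT rsample code: if the argument was written as a call,
# re-evaluate that call in an environment that cannot see the caller's locals;
# otherwise use the argument's value as passed.
capture_and_eval <- function(arg) {
  expr <- match.call()[["arg"]]
  if (is.call(expr)) {
    eval(expr, envir = baseenv())
  } else {
    arg
  }
}

# Mirrors outside = vfold_cv(v = y): the call sqrt(y) is re-evaluated
# where y does not exist, so this raises an error.
wrapper_broken <- function(y) capture_and_eval(sqrt(y))

# Mirrors the fix: compute the value first, then pass a plain variable.
wrapper_fixed <- function(y) {
  value <- sqrt(y)
  capture_and_eval(value)
}

# Capture the error message instead of stopping the script
broken_msg <- tryCatch(wrapper_broken(9), error = function(e) conditionMessage(e))
fixed_value <- wrapper_fixed(9)
```

Here broken_msg holds the error message about y, while fixed_value holds the computed square root. The same pattern applies above: outside_cv is built while y is still in scope, so nested_cv() receives a finished object rather than a call that refers to y.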