r, r-caret, tidymodels

How can I use the same cross-validation sets in R caret and rsample?


I am trying to learn the tidymodels ecosystem by converting caret::train() code into tidymodels workflows. I am getting differences that I think are a byproduct of the different resampling algorithms in caret and rsample. A colleague wrote a gist showing how the resampled datasets differ under the same seed: https://gist.github.com/bradleyboehmke/7794b79a07afb443da11d930ff84bed7

You can see small differences even between simple models that I believe I coded to be the same:

library(caret)
library(tidyverse)
library(tidymodels)
data(ames)

set.seed(123)
(cv_model1 <- train(
  form = Sale_Price ~ Gr_Liv_Area, 
  data = ames,
  method = "lm",
  trControl = trainControl(method="cv", number = 10)
))

vs.

set.seed(123)
folds <- vfold_cv(ames, v = 10)

the_lm_model <- 
  linear_reg() %>% 
  set_engine("lm")

the_rec <- 
  recipe(Sale_Price ~ Gr_Liv_Area, data = ames)

the_workflow <- 
  workflow() %>% 
  add_recipe(the_rec) %>% 
  add_model(the_lm_model) 


the_results <- 
  fit_resamples(the_workflow, folds)

collect_metrics(the_results)

Is there a straightforward way to use caret resamples (from caret::createFolds()) in a tidymodels workflow, in place of the resamples that would normally be created with rsample::vfold_cv()? I am hoping that if I can figure out this detail, I can replicate complex old code with the new ecosystem (for teaching).


Solution

  • Edit, thanks to a comment from Julia Silge:

    The rsample functions rsample2caret() and caret2rsample() can be used to convert resampling objects between the two formats.

    The answer below can still be useful for converting from arbitrary formats to rsample.
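
    A minimal sketch of that round trip, assuming caret, rsample, and modeldata (for the ames data) are installed. The object names (folds, caret_idx, ctrl, cv_model, folds_back) are illustrative only and do not come from the original post. It starts from rsample folds, hands them to caret, and then rebuilds an rsample object from the fitted model's control element:

```r
library(caret)
library(rsample)

data(ames, package = "modeldata")

# Create the folds once with rsample ...
set.seed(123)
folds <- vfold_cv(ames, v = 10)

# ... and convert them to caret's index / indexOut lists
caret_idx <- rsample2caret(folds)

ctrl <- trainControl(
  method   = "cv",
  index    = caret_idx$index,     # rows used for fitting in each fold
  indexOut = caret_idx$indexOut   # rows held out in each fold
)

cv_model <- train(
  Sale_Price ~ Gr_Liv_Area,
  data = ames,
  method = "lm",
  trControl = ctrl
)

# Going the other way: rebuild an rset from the control object of the fit.
# caret2rsample() needs the original data as well as the indices.
folds_back <- caret2rsample(cv_model$control, data = ames)
```

    Because both packages then see the same row indices, the seed no longer affects which rows end up in which fold.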

    Old Answer

    Here is an approach to convert the output of caret::createFolds() to an rsample object:

    library(caret)
    library(tidyverse)
    library(tidymodels)
    
    data(ames)
    
    # create training folds
    set.seed(123)
    folds_train <- caret::createFolds(ames$Sale_Price, returnTrain = TRUE, k = 10)
    
    # get test indices: for each fold, the rows that are not in its training set
    folds_test <- lapply(folds_train, function(x) setdiff(seq_along(ames$Sale_Price), x))
    

    Combine the training and test indices into a list of analysis/assessment lists, as described in the manual_rset() documentation:

    rsplit <- map2(folds_train,
                   folds_test,
                   function(x,y) list(analysis = x, assessment = y))
    
    splits <- lapply(rsplit, make_splits, data = ames)
    splits <- manual_rset(splits, names(splits))
    > splits
    # Manual resampling 
    # A tibble: 10 x 2
       splits             id    
       <named list>       <chr> 
     1 <split [2637/293]> Fold01
     2 <split [2638/292]> Fold02
     3 <split [2637/293]> Fold03
     4 <split [2637/293]> Fold04
     5 <split [2638/292]> Fold05
     6 <split [2637/293]> Fold06
     7 <split [2637/293]> Fold07
     8 <split [2636/294]> Fold08
     9 <split [2636/294]> Fold09
    10 <split [2637/293]> Fold10
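
    The conversion above can also be wrapped in a small helper and sanity-checked. This is only a sketch: indices_to_rset() is a hypothetical name, not a function from caret or rsample, and it assumes the input is a named list of training-row indices such as the one returned by createFolds(returnTrain = TRUE). The check at the end confirms that every row is held out exactly once, as expected for k-fold cross-validation.

```r
library(caret)
library(rsample)

data(ames, package = "modeldata")

# Hypothetical helper (not part of caret or rsample): build a manual rset
# from a named list of training-row indices.
indices_to_rset <- function(train_idx, data) {
  all_rows <- seq_len(nrow(data))
  splits <- lapply(train_idx, function(rows) {
    make_splits(
      list(analysis = rows, assessment = setdiff(all_rows, rows)),
      data = data
    )
  })
  manual_rset(splits, names(train_idx))
}

set.seed(123)
folds_train <- createFolds(ames$Sale_Price, returnTrain = TRUE, k = 10)
splits <- indices_to_rset(folds_train, ames)

# Collect the held-out row numbers of every fold
held_out <- unlist(lapply(splits$splits, as.integer, data = "assessment"))

# Each row should appear in exactly one assessment set
stopifnot(
  length(held_out) == nrow(ames),
  !anyDuplicated(held_out)
)
```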
    

    Check that both approaches give the same result:

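    set.seed(123)
    # method is left at its default ("boot"), so the printout below is labeled
    # "Bootstrapped", but the resamples actually used are the folds in index
    cv_model1 <- train(
      form = Sale_Price ~ Gr_Liv_Area, 
      data = ames,
      method = "lm",
      trControl = trainControl(index = folds_train))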
    > cv_model1
    Linear Regression 
    
    2930 samples
       1 predictor
    
    No pre-processing
    Resampling: Bootstrapped (10 reps) 
    Summary of sample sizes: 2637, 2638, 2637, 2637, 2638, 2637, ... 
    Resampling results:
    
      RMSE      Rsquared   MAE     
      56364.67  0.5066935  38575.21
    
    Tuning parameter 'intercept' was held constant at a value of TRUE
    
    the_lm_model <- 
      linear_reg() %>% 
      set_engine("lm")
    
    the_rec <- 
      recipe(Sale_Price ~ Gr_Liv_Area, data = ames)
    
    the_workflow <- 
      workflow() %>% 
      add_recipe(the_rec) %>% 
      add_model(the_lm_model) 
    
    set.seed(123)
    the_results <- 
      fit_resamples(the_workflow, splits)
    
    > collect_metrics(the_results)
    # A tibble: 2 x 6
      .metric .estimator      mean     n   std_err .config             
      <chr>   <chr>          <dbl> <int>     <dbl> <chr>               
    1 rmse    standard   56365.       10 1782.     Preprocessor1_Model1
    2 rsq     standard       0.507    10    0.0220 Preprocessor1_Model1
    
    all.equal(
      cv_model1$results$RMSE,
      collect_metrics(the_results)$mean[1])
    [1] TRUE
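
    The same comparison can be made fold by fold rather than on the averaged RMSE. The sketch below assumes caret, tidymodels, and modeldata are installed. The names wf, res, and per_fold are illustrative, and it uses add_formula() instead of a recipe purely for brevity. It passes method = "cv" so that caret labels the resampling correctly; the folds themselves still come from index and indexOut.

```r
library(caret)
library(tidymodels)

data(ames, package = "modeldata")

set.seed(123)
folds_train <- createFolds(ames$Sale_Price, returnTrain = TRUE, k = 10)
folds_test  <- lapply(folds_train, function(x) setdiff(seq_len(nrow(ames)), x))

# caret fit on the predefined folds
cv_model <- train(
  Sale_Price ~ Gr_Liv_Area,
  data = ames,
  method = "lm",
  trControl = trainControl(method = "cv", index = folds_train, indexOut = folds_test)
)

# the same folds as an rsample object
splits <- Map(
  function(a, b) make_splits(list(analysis = a, assessment = b), data = ames),
  folds_train, folds_test
)
splits <- manual_rset(splits, names(splits))

# tidymodels fit on the same folds
wf <- workflow() %>%
  add_formula(Sale_Price ~ Gr_Liv_Area) %>%
  add_model(linear_reg() %>% set_engine("lm"))
res <- fit_resamples(wf, splits)

# one row per fold, with the RMSE from each framework side by side
per_fold <- collect_metrics(res, summarize = FALSE) %>%
  filter(.metric == "rmse") %>%
  select(id, rmse_tidymodels = .estimate) %>%
  inner_join(
    tibble(id = cv_model$resample$Resample, rmse_caret = cv_model$resample$RMSE),
    by = "id"
  )

all.equal(per_fold$rmse_tidymodels, per_fold$rmse_caret)
```

    If the per-fold values agree, any remaining difference between the two frameworks would have to come from something other than the resampling.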
    

    Perhaps there is a more straightforward way, but I don't use tidymodels enough to know for sure.

    If you did not create the folds before calling caret::train():

    set.seed(123)
    cv_model1 <- train(
      form = Sale_Price ~ Gr_Liv_Area, 
      data = ames,
      method = "lm",
      trControl = trainControl(number = 10, method = "cv"))
    

    you can use

    cv_model1$control$index
    
    cv_model1$control$indexOut
    

    to create an rsample object

    rsplit <- map2(cv_model1$control$index,
                   cv_model1$control$indexOut,
                   function(x,y) list(analysis = x, assessment = y))
    

    and proceed as described above.

    splits <- lapply(rsplit, make_splits, data = ames)
    splits <- manual_rset(splits, names(splits))