
Combining Rolling Origin Forecast Resampling and Group V-Fold Cross-Validation in rsample


I would like to use the R package rsample to generate resamples of my data.

The package offers the function rolling_origin to produce resamples that preserve the time-series structure of the data. This means that the training data (called analysis data in the package) always precede the test data (assessment data) in time.

On the other hand, I would like to perform block sampling of the data, meaning that groups of rows are kept together during sampling. This can be done with the function group_vfold_cv. Months are a natural choice of group: say we want to do time-series cross-validation while always keeping each month together.

Is there a way to combine the two approaches in rsample?

Here are examples of each procedure on its own:

    ## generate some data
    library(tidyverse)
    library(lubridate)
    library(rsample)
    my_dates = seq(as.Date("2018/1/1"), as.Date("2018/8/20"), "days")
    some_data = tibble(dates = my_dates)  # data_frame() is deprecated in favour of tibble()
    some_data$values = runif(length(my_dates))
    some_data = some_data %>% mutate(month = as.factor(month(dates)))

This gives data of the following form:

    # A tibble: 232 x 3
       dates      values month 
       <date>      <dbl> <fctr>
     1 2018-01-01 0.235  1     
     2 2018-01-02 0.363  1     
     3 2018-01-03 0.146  1     
     4 2018-01-04 0.668  1     
     5 2018-01-05 0.0995 1     
     6 2018-01-06 0.163  1     
     7 2018-01-07 0.0265 1     
     8 2018-01-08 0.273  1     
     9 2018-01-09 0.886  1     
    10 2018-01-10 0.239  1     

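We can then, for example, produce resamples that start with 20 weeks of training data (the training window grows with each resample because cumulative = TRUE) and test on the following 5 weeks. The parameter skip = 7 thins out the resamples, so that only every eighth possible forecast origin is used: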

    rolling_origin_resamples <- rolling_origin(
      some_data,
      initial    = 7*20,
      assess     = 7*5,
      cumulative = TRUE,
      skip       = 7
    )
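To make the effect of initial, assess and skip concrete, here is a base-R sketch of the row-index arithmetic for a cumulative rolling origin. The helper rolling_indices() is hypothetical (it is not part of rsample) and assumes the behaviour described in the rolling_origin documentation, where skip drops that many additional resamples between the ones that are kept.

```r
# Hypothetical helper (not part of rsample): row indices of a cumulative
# rolling-origin scheme, thinned by `skip` as described in ?rolling_origin
rolling_indices <- function(n, initial, assess, skip = 0) {
  stops <- seq(initial, n - assess)                     # last analysis row of every possible origin
  stops <- stops[seq(1, length(stops), by = skip + 1)]  # keep every (skip + 1)-th origin
  lapply(stops, function(s) list(
    analysis   = seq_len(s),            # cumulative: always starts at row 1
    assessment = (s + 1):(s + assess)   # the next `assess` rows
  ))
}

dates <- seq(as.Date("2018-01-01"), as.Date("2018-08-20"), by = "day")  # 232 rows
idx <- rolling_indices(length(dates), initial = 7 * 20, assess = 7 * 5, skip = 7)

length(idx)                        # 8 resamples
range(dates[idx[[1]]$analysis])    # 2018-01-01 to 2018-05-20
range(dates[idx[[1]]$assessment])  # 2018-05-21 to 2018-06-24
range(dates[idx[[2]]$analysis])    # origin moved on by 8 days: ends 2018-05-28
```

The first split ends its training data on 2018-05-20 and starts its test data on 2018-05-21, which agrees with the first split inspected with rsample. Note that these origins are counted in rows, not months, so a split boundary can fall in the middle of a month.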

We can inspect the first split and confirm that the training and test data do not overlap:

    ### training data of first split (last rows):
    rolling_origin_resamples$splits[[1]] %>% analysis %>% tail
    # A tibble: 6 x 3
      dates       values month 
      <date>       <dbl> <fctr>
    1 2018-05-15 0.678   5     
    2 2018-05-16 0.00112 5     
    3 2018-05-17 0.339   5     
    4 2018-05-18 0.0864  5     
    5 2018-05-19 0.918   5     
    6 2018-05-20 0.317   5     

    ### test data of first split (first rows):
    rolling_origin_resamples$splits[[1]] %>% assessment %>% head
    # A tibble: 6 x 3
      dates      values month 
      <date>      <dbl> <fctr>
    1 2018-05-21  0.912 5     
    2 2018-05-22  0.403 5     
    3 2018-05-23  0.366 5     
    4 2018-05-24  0.159 5     
    5 2018-05-25  0.223 5     
    6 2018-05-26  0.375 5     

Alternatively, we can split by month:

    ## sampling by month:
    gcv_resamples = group_vfold_cv(some_data, group = "month", v = 5)
    gcv_resamples$splits[[1]] %>% analysis %>% select(month) %>% summary
    gcv_resamples$splits[[1]] %>% assessment %>% select(month) %>% summary
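The key property of group V-fold cross-validation is that each group (month) lands in exactly one fold. The following base-R sketch illustrates that idea; it is only an illustration of the concept and is not rsample's internal algorithm. The objects fold_of_month and fold are hypothetical names.

```r
# Group-wise fold assignment in base R, for illustration only:
# every month is placed in exactly one fold, so no month is split.
set.seed(2018)
dates <- seq(as.Date("2018-01-01"), as.Date("2018-08-20"), by = "day")
month <- factor(as.integer(format(dates, "%m")))  # 8 groups: months 1..8

v <- 5
months <- levels(month)
# spread the 8 months over 5 folds as evenly as possible, in random order
fold_of_month <- setNames(sample(rep_len(seq_len(v), length(months))), months)
fold <- fold_of_month[as.character(month)]        # fold of every row

# resample 1: assessment = rows in fold 1, analysis = all other rows
assess_months   <- unique(as.character(month[fold == 1]))
analysis_months <- unique(as.character(month[fold != 1]))
intersect(assess_months, analysis_months)         # character(0): no month is shared
```

Nothing in this assignment orders the folds in time, so a held-out month can precede the months used for training. That is why group V-fold CV on its own does not respect the time-series structure.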

Solution

  • As discussed in the comments on the answer from @missuse, the way to achieve this is documented in the GitHub issue: https://github.com/tidymodels/rsample/issues/42

    Essentially, the idea is to first nest the data by your "blocks" (here, months), so that each block becomes a single row of a nested data frame. rolling_origin() then rolls over those rows, so every analysis and assessment set consists of complete blocks.

    library(dplyr)
    library(lubridate)
    library(rsample)
    library(tidyr)
    library(tibble)
    
    # same data generation as before
    my_dates = seq(as.Date("2018/1/1"), as.Date("2018/8/20"), "days")
    some_data = tibble(dates = my_dates)
    some_data$values = runif(length(my_dates))
    some_data = some_data %>% mutate(month = as.factor(month(dates)))
    
    # nest by month (one row per month; tidyr >= 1.0 syntax), then resample
    rset <- some_data %>%
      nest(data = c(dates, values)) %>%
      rolling_origin(initial = 1)
    
    # doesn't show which month is which :(
    rset
    #> # Rolling origin forecast resampling 
    #> # A tibble: 7 x 2
    #>   splits       id    
    #>   <list>       <chr> 
    #> 1 <S3: rsplit> Slice1
    #> 2 <S3: rsplit> Slice2
    #> 3 <S3: rsplit> Slice3
    #> 4 <S3: rsplit> Slice4
    #> 5 <S3: rsplit> Slice5
    #> 6 <S3: rsplit> Slice6
    #> 7 <S3: rsplit> Slice7
    
    
    # only January (31 days)
    analysis(rset$splits[[1]])$data
    #> [[1]]
    #> # A tibble: 31 x 2
    #>    dates      values
    #>    <date>      <dbl>
    #>  1 2018-01-01 0.373 
    #>  2 2018-01-02 0.0389
    #>  3 2018-01-03 0.260 
    #>  4 2018-01-04 0.803 
    #>  5 2018-01-05 0.595 
    #>  6 2018-01-06 0.875 
    #>  7 2018-01-07 0.273 
    #>  8 2018-01-08 0.180 
    #>  9 2018-01-09 0.662 
    #> 10 2018-01-10 0.849 
    #> # ... with 21 more rows
    
    
    # only February (28 days)
    assessment(rset$splits[[1]])$data
    #> [[1]]
    #> # A tibble: 28 x 2
    #>    dates      values
    #>    <date>      <dbl>
    #>  1 2018-02-01 0.402 
    #>  2 2018-02-02 0.556 
    #>  3 2018-02-03 0.764 
    #>  4 2018-02-04 0.134 
    #>  5 2018-02-05 0.0333
    #>  6 2018-02-06 0.907 
    #>  7 2018-02-07 0.814 
    #>  8 2018-02-08 0.0973
    #>  9 2018-02-09 0.353 
    #> 10 2018-02-10 0.407 
    #> # ... with 18 more rows
    

    Created on 2018-08-28 by the reprex package (v0.2.0).
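As a sketch of how this can be taken further, assuming tidyr 1.0 or later (where nest(data = ...) and unnest(data) are the current syntax), the code below trains on at least three whole months and tests on the next whole month. It also labels each resample with its assessment month, since the printed rset does not show it, and flattens a split back into ordinary rows for model fitting. The objects nested, split_months, train and test and the column name assess_month are illustrative names only.

```r
library(dplyr)
library(tidyr)
library(rsample)

set.seed(42)
my_dates <- seq(as.Date("2018-01-01"), as.Date("2018-08-20"), by = "day")
some_data <- tibble(
  dates  = my_dates,
  values = runif(length(my_dates)),
  month  = factor(as.integer(format(my_dates, "%m")))  # months 1..8
)

# one row per month; arrange() keeps the blocks in time order
nested <- some_data %>%
  nest(data = c(dates, values)) %>%
  arrange(month)

# train on (at least) 3 whole months, test on the next whole month
rset <- rolling_origin(nested, initial = 3, assess = 1, cumulative = TRUE)

# label each resample with the month it is assessed on
split_months <- tibble(
  id           = rset$id,
  assess_month = vapply(rset$splits,
                        function(s) as.character(assessment(s)$month),
                        character(1))
)
split_months

# flatten the first split back into row-level tibbles
train <- analysis(rset$splits[[1]]) %>% unnest(data)
test  <- assessment(rset$splits[[1]]) %>% unnest(data)
range(train$dates)  # 2018-01-01 to 2018-03-31
range(test$dates)   # 2018-04-01 to 2018-04-30
```

With eight months of data and initial = 3, this yields five resamples assessed on months 4 to 8. Replacing the month factor with any other block identifier (for example, a week number) works the same way.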