Search code examples
Tags: r, purrr, fable-r, tsibble

Problem using a dynamic trend or seasonal parameter in ETS with fable and purrr


I have a tsibble as shown below.

# Example data as dput() output: a monthly (interval month = 1) tsibble
# indexed by RSFMTH and keyed by RSLITM, RSSEAS, RSTREND and RSMODE.
# It holds two SKUs ("004" in rows 1:37, "005" in rows 38:74); RSFQTY is the
# response modelled below, and RSSEAS/RSTREND carry the per-SKU ETS settings.
test.data <- structure(list(RSLITM = c("004", "004", "004", "004", "004", 
"004", "004", "004", "004", "004", "004", "004", "004", "004", 
"004", "004", "004", "004", "004", "004", "004", "004", "004", 
"004", "004", "004", "004", "004", "004", "004", "004", "004", 
"004", "004", "004", "004", "004", "005", "005", "005", "005", 
"005", "005", "005", "005", "005", "005", "005", "005", "005", 
"005", "005", "005", "005", "005", "005", "005", "005", "005", 
"005", "005", "005", "005", "005", "005", "005", "005", "005", 
"005", "005", "005", "005", "005", "005"), RSFMTH = structure(c(17713, 
17744, 17775, 17805, 17836, 17866, 17897, 17928, 17956, 17987, 
18017, 18048, 18078, 18109, 18140, 18170, 18201, 18231, 18262, 
18293, 18322, 18353, 18383, 18414, 18444, 18475, 18506, 18536, 
18567, 18597, 18628, 18659, 18687, 18718, 18748, 18779, 18809, 
17713, 17744, 17775, 17805, 17836, 17866, 17897, 17928, 17956, 
17987, 18017, 18048, 18078, 18109, 18140, 18170, 18201, 18231, 
18262, 18293, 18322, 18353, 18383, 18414, 18444, 18475, 18506, 
18536, 18567, 18597, 18628, 18659, 18687, 18718, 18748, 18779, 
18809), class = c("yearmonth", "vctrs_vctr")), RSFQTY = c(285600, 
352200, 273600, 282700, 175800, 138700, 177700, 245900, 165000, 
180100, 298000, 173800, 257300, 282800, 164500, 155100, 232300, 
175500, 226000, 287100, 221400, 270800, 286200, 394400, 336600, 
331000, 224600, 216800, 351600, 374700, 173500, 423700, 357700, 
245200, 454700, 361700, 381200, 79000, 58100, 66300, 52700, 68600, 
33000, 76600, 85600, 84100, 49000, 98000, 113500, 83800, 64000, 
116800, 72000, 65200, 49800, 33300, 79800, 48000, 81600, 125000, 
53500, 97600, 80000, 81900, 80000, 53800, 39000, 73800, 76600, 
33700, 60200, 84000, 66600, 32400), RSSEAS = c("A", "A", "A", 
"A", "A", "A", "A", "A", "A", "A", "A", "A", "A", "A", "A", "A", 
"A", "A", "A", "A", "A", "A", "A", "A", "A", "A", "A", "A", "A", 
"A", "A", "A", "A", "A", "A", "A", "A", "A", "A", "A", "A", "A", 
"A", "A", "A", "A", "A", "A", "A", "A", "A", "A", "A", "A", "A", 
"A", "A", "A", "A", "A", "A", "A", "A", "A", "A", "A", "A", "A", 
"A", "A", "A", "A", "A", "A"), RSTREND = c("N", "N", "N", "N", 
"N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", 
"N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", 
"N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", 
"N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", 
"N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", "N", 
"N", "N", "N", "N", "N"), RSMODE = c("EXP", "EXP", "EXP", "EXP", 
"EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", 
"EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", 
"EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", 
"EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", 
"EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", 
"EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", 
"EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP", 
"EXP", "EXP", "EXP", "EXP", "EXP", "EXP", "EXP")), row.names = c(NA, 
-74L), key = structure(list(RSLITM = c("004", "005"), RSSEAS = c("A", 
"A"), RSTREND = c("N", "N"), RSMODE = c("EXP", "EXP"), .rows = structure(list(
    1:37, 38:74), ptype = integer(0), class = c("vctrs_list_of", 
"vctrs_vctr", "list"))), row.names = c(NA, -2L), class = c("tbl_df", 
"tbl", "data.frame"), .drop = TRUE), index = structure("RSFMTH", ordered = TRUE), index2 = "RSFMTH", interval = structure(list(
    year = 0, quarter = 0, month = 1, week = 0, day = 0, hour = 0, 
    minute = 0, second = 0, millisecond = 0, microsecond = 0, 
    nanosecond = 0, unit = 0), .regular = TRUE, class = c("interval", 
"vctrs_rcrd", "vctrs_vctr")), class = c("tbl_ts", "tbl_df", "tbl", 
"data.frame"))

I would like to apply a modified ETS function using the saved parameters from the tsibble. For instance, whatever is in the RSSEAS and RSTREND columns will be used to estimate the ETS model.

The following works:

# Fit one ETS model per key with hard-coded components:
# no trend ("N") and additive seasonality ("A").
test.data %>%
  model(
    EXP = ETS(RSFQTY ~ trend("N") + season("A"))
  )

However, when I try to use a function below to extract the parameters for each SKU (since presumably they could be different for each SKU), I get an error message.

# NOTE(review): as described below, this version returns a fable whose
# forecasts are NULL/NA even though season.param/trend.param hold "A"/"N";
# the Solution avoids the problem by writing the values into the formula text.
# Calling map_dfr(test.data, ets.function) fails for a different reason:
# purrr's map functions iterate over the columns of a data frame, so `tsib`
# arrives as a bare column vector and tsib[1, "RSSEAS"] has too many dimensions.
ets.function <- function(tsib){
  # Parameter values are read from the first row only (assumes they are
  # constant across `tsib`).
  season.param <- as.character(tsib[1, "RSSEAS"])
  trend.param <- as.character(tsib[1, "RSTREND"])
  tsib %>% model(EXP = ETS(RSFQTY ~ trend(trend.param) + season(season.param))) %>% forecast(h = "3 years")
}

If I call `ets.function(test.data)`, R returns a fable, but its forecasts are NULL/NA because the model is not being estimated with the specified parameters.

Calling map_dfr(test.data, ets.function) gives me the following error:

Error in tsib[1, "RSSEAS"] : incorrect number of dimensions

This doesn't make sense to me since if I run the code for season.param or trend.param in my console, I get "A" or "N" as appropriate, which is exactly the value the trend and season specials take inside the ETS function.

Basically I am trying to figure out a way to map ETS over my tsibble using prespecified parameters for each unique key combination. I am open to other ideas about how to achieve this (pmap_dfr for vectors of parameters, etc).


Solution

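  • We could build the formula as a string (with glue, paste or sprintf) and convert it with as.formula(), so that the parameter values are written directly into the formula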

    library(fable)
    # Fit an ETS model whose trend and season components are taken from the
    # RSTREND and RSSEAS columns of `tsib`, then forecast `h` ahead.
    #
    # tsib:    a tsibble with RSFQTY, RSTREND and RSSEAS columns.
    # h:       forecast horizon passed to forecast() (default "3 years").
    # verbose: if TRUE (default), print the formula that is fitted.
    # Returns the fable produced by forecast().
    ets.function <- function(tsib, h = "3 years", verbose = TRUE) {
      season.param <- unique(tsib[["RSSEAS"]])
      trend.param <- unique(tsib[["RSTREND"]])
      # One formula is fitted to every key in `tsib`, so the parameters must be
      # identical across all rows; otherwise later keys would silently be
      # fitted with the first key's settings.
      if (length(season.param) != 1L || length(trend.param) != 1L ||
          anyNA(c(season.param, trend.param))) {
        stop("`tsib` must have a single, non-missing RSSEAS and RSTREND ",
             "value across all rows; split it by key (e.g. with ",
             "group_split()) before calling ets.function().", call. = FALSE)
      }
      # Inline the values as literals, so ETS() receives e.g. trend('N')
      # rather than a reference to a local variable.
      fmla <- as.formula(glue::glue("RSFQTY ~ trend('{trend.param}') +",
                                    " season('{season.param}')"))
      if (verbose) print(fmla)
      tsib %>%
        model(EXP = ETS(fmla)) %>%
        forecast(h = h)
    }
    

    -testing

    > ets.function(test.data)
    RSFQTY ~ trend("N") + season("A")
    <environment: 0x7fface3d9778>
    # A fable: 72 x 8 [1M]
    # Key:     RSLITM, RSSEAS, RSTREND, RSMODE, .model [2]
       RSLITM RSSEAS RSTREND RSMODE .model   RSFMTH             RSFQTY   .mean
       <chr>  <chr>  <chr>   <chr>  <chr>     <mth>             <dist>   <dbl>
     1 004    A      N       EXP    EXP    2021 Aug  N(4e+05, 1.4e+10) 395706.
     2 004    A      N       EXP    EXP    2021 Sep N(279181, 8.6e+09) 279181.
     3 004    A      N       EXP    EXP    2021 Oct N(266837, 8.8e+09) 266837.
     4 004    A      N       EXP    EXP    2021 Nov N(349230, 1.4e+10) 349230.
     5 004    A      N       EXP    EXP    2021 Dec N(327811, 1.4e+10) 327811.
     6 004    A      N       EXP    EXP    2022 Jan N(265657, 1.2e+10) 265657.
     7 004    A      N       EXP    EXP    2022 Feb N(375557, 1.9e+10) 375557.
     8 004    A      N       EXP    EXP    2022 Mar  N(3e+05, 1.6e+10) 300908.
     9 004    A      N       EXP    EXP    2022 Apr N(318455, 1.8e+10) 318455.
    10 004    A      N       EXP    EXP    2022 May  N(4e+05, 2.4e+10) 400250.
    # … with 62 more rows
    
    

    Or we could use sprintf() instead of glue:

    # Fit an ETS model whose trend and season components are taken from the
    # RSTREND and RSSEAS columns of `tsib`, then forecast `h` ahead. The
    # formula text is built with base sprintf(), so no extra package is needed.
    #
    # tsib:    a tsibble with RSFQTY, RSTREND and RSSEAS columns.
    # h:       forecast horizon passed to forecast() (default "3 years").
    # verbose: if TRUE (default), print the formula that is fitted.
    # Returns the fable produced by forecast().
    ets.function <- function(tsib, h = "3 years", verbose = TRUE) {
      season.param <- unique(tsib[["RSSEAS"]])
      trend.param <- unique(tsib[["RSTREND"]])
      # One formula is fitted to every key in `tsib`, so the parameters must be
      # identical across all rows; otherwise later keys would silently be
      # fitted with the first key's settings.
      if (length(season.param) != 1L || length(trend.param) != 1L ||
          anyNA(c(season.param, trend.param))) {
        stop("`tsib` must have a single, non-missing RSSEAS and RSTREND ",
             "value across all rows; split it by key (e.g. with ",
             "group_split()) before calling ets.function().", call. = FALSE)
      }
      # Inline the values as literals, so ETS() receives e.g. trend('N')
      # rather than a reference to a local variable.
      fmla <- as.formula(sprintf("RSFQTY ~ trend('%s') + season('%s')",
                                 trend.param, season.param))
      if (verbose) print(fmla)
      tsib %>%
        model(EXP = ETS(fmla)) %>%
        forecast(h = h)
    }
    ets.function(test.data)
    

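    If we need to apply the function per group (e.g. per SKU, so that each SKU uses its own parameters), we could split the data with group_split() and apply the function to each piece with map(). Note that map_dfr(test.data, ets.function) cannot work directly: purrr's map functions iterate over the columns of a data frame, not over its keys, which is why tsib[1, "RSSEAS"] reported "incorrect number of dimensions".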

    library(purrr)
    # One tsibble per SKU (RSLITM); each piece is then fitted and forecast
    # separately, so every SKU uses its own RSSEAS/RSTREND settings.
    test.data %>%
      group_split(RSLITM) %>%
      map(ets.function)
    

    -output

    RSFQTY ~ trend("N") + season("A")
    <environment: 0x7ffac94171b8>
    RSFQTY ~ trend("N") + season("A")
    <environment: 0x7ffacf3f1950>
    [[1]]
    # A fable: 36 x 8 [1M]
    # Key:     RSLITM, RSSEAS, RSTREND, RSMODE, .model [1]
       RSLITM RSSEAS RSTREND RSMODE .model   RSFMTH             RSFQTY   .mean
       <chr>  <chr>  <chr>   <chr>  <chr>     <mth>             <dist>   <dbl>
     1 004    A      N       EXP    EXP    2021 Aug  N(4e+05, 1.4e+10) 395706.
     2 004    A      N       EXP    EXP    2021 Sep N(279181, 8.6e+09) 279181.
     3 004    A      N       EXP    EXP    2021 Oct N(266837, 8.8e+09) 266837.
     4 004    A      N       EXP    EXP    2021 Nov N(349230, 1.4e+10) 349230.
     5 004    A      N       EXP    EXP    2021 Dec N(327811, 1.4e+10) 327811.
     6 004    A      N       EXP    EXP    2022 Jan N(265657, 1.2e+10) 265657.
     7 004    A      N       EXP    EXP    2022 Feb N(375557, 1.9e+10) 375557.
     8 004    A      N       EXP    EXP    2022 Mar  N(3e+05, 1.6e+10) 300908.
     9 004    A      N       EXP    EXP    2022 Apr N(318455, 1.8e+10) 318455.
    10 004    A      N       EXP    EXP    2022 May  N(4e+05, 2.4e+10) 400250.
    # … with 26 more rows
    
    [[2]]
    # A fable: 36 x 8 [1M]
    # Key:     RSLITM, RSSEAS, RSTREND, RSMODE, .model [1]
       RSLITM RSSEAS RSTREND RSMODE .model   RSFMTH            RSFQTY   .mean
       <chr>  <chr>  <chr>   <chr>  <chr>     <mth>            <dist>   <dbl>
     1 005    A      N       EXP    EXP    2021 Aug N(67121, 4.1e+08)  67121.
     2 005    A      N       EXP    EXP    2021 Sep N(95706, 8.3e+08)  95706.
     3 005    A      N       EXP    EXP    2021 Oct N(73173, 4.9e+08)  73173.
     4 005    A      N       EXP    EXP    2021 Nov N(57981, 3.1e+08)  57981.
     5 005    A      N       EXP    EXP    2021 Dec N(42901, 1.7e+08)  42901.
     6 005    A      N       EXP    EXP    2022 Jan N(62766, 3.6e+08)  62766.
     7 005    A      N       EXP    EXP    2022 Feb N(79394, 5.7e+08)  79394.
     8 005    A      N       EXP    EXP    2022 Mar N(61960, 3.5e+08)  61960.
     9 005    A      N       EXP    EXP    2022 Apr N(60882, 3.4e+08)  60882.
    10 005    A      N       EXP    EXP    2022 May  N(106257, 1e+09) 106257.
    # … with 26 more rows