Search code examples
Tags: r, forecasting, fable-r, fabletools

Unnest <dist> column from `fabletools::forecast()` output


I am asking this question again because the existing answers to it don't work in my environment (see "Session info" below), and I therefore assume they don't work with the current versions of the relevant packages.

Question

How can I take the elements of a <dist> column output by fabletools::forecast() and place each element in its own column of a fable (convertible to a data.frame)? E.g., all the mu elements in a mu column, all the sigma elements in a sigma column, etc.

Old Answers Fail

library(magrittr)
# The aus_production quarterly dataset (includes a Beer column)
data(aus_production, package = "tsibbledata")
# Fit one ETS model on log(Beer) and one on Beer, then forecast 6 months
# (two quarters) ahead; the Beer column of the result has class <dist>.
aus_production %>% 
    fabletools::model(ets_log = fable::ETS(log(Beer) ~ error("M") + trend("Ad") + season("A")),
                      ets = fable::ETS(Beer ~ error("M") + trend("Ad") + season("A"))) %>% 
    fabletools::forecast(h = "6 months")
#> # A fable: 4 x 4 [1Q]
#> # Key:     .model [2]
#>   .model  Quarter              Beer .mean
#>   <chr>     <qtr>            <dist> <dbl>
#> 1 ets_log 2010 Q3   t(N(6, 0.0013))  407.
#> 2 ets_log 2010 Q4 t(N(6.2, 0.0014))  483.
#> 3 ets     2010 Q3       N(408, 237)  408.
#> 4 ets     2010 Q4       N(483, 340)  483.

# Same models and forecast as above, stored as beer_fc for inspection.
beer_fc <- aus_production %>% 
    fabletools::model(ets_log = fable::ETS(log(Beer) ~ error("M") + trend("Ad") + season("A")),
                      ets = fable::ETS(Beer ~ error("M") + trend("Ad") + season("A"))) %>% 
    fabletools::forecast(h = "6 months")

# Inspect one element: a "dist_transformed" object whose `dist` field is a
# Normal with `mu` and `sigma`, alongside `transform`/`inverse` functions.
str(beer_fc$Beer[1])
#> dist [1:1] 
#> $ :List of 3
#>  ..$ dist     :List of 2
#>  .. ..$ mu   : num 6.01
#>  .. ..$ sigma: num 0.036
#>  .. ..- attr(*, "class")= chr [1:2] "dist_normal" "dist_default"
#>  ..$ transform:function (Beer)  
#>  .. ..- attr(*, "class")= chr "transformation"
#>  .. ..- attr(*, "inverse")=function (Beer)  
#>  ..$ inverse  :function (Beer)  
#>  ..- attr(*, "class")= chr [1:2] "dist_transformed" "dist_default"
#> @ vars: chr "Beer"
# The column is a vctrs-based "distribution" vector, so `[[` on it
# dispatches to `[[.distribution` rather than doing plain list indexing.
class(beer_fc$Beer)
#> [1] "distribution" "vctrs_vctr"   "list"


##### Answer 1 fails (https://stackoverflow.com/a/64960991/15723919)
# NOTE(review): str() above shows no `x` field (the Normal's parameters are
# `mu` and `sigma`), so each Beer2 element is presumably NULL, and unnest()
# then returns zero rows.
beer_fc$Beer2 <- purrr::map(beer_fc$Beer, ~ .x[[1]]$x)
beer_fc %>% 
    tidyr::unnest(c(Beer2))
#> # A tibble: 0 × 5
#> # … with 5 variables: .model <chr>, Quarter <qtr>, Beer <dist>, .mean <dbl>,
#> #   Beer2 <???>


##### Answer 2 fails (https://stackoverflow.com/a/64963665/15723919)
# NOTE(review): as with Answer 1 there is no `x` field, so pluck() presumably
# yields NULL for every row and unnest() drops them all. (Beer2 from the
# previous chunk is still a column of beer_fc.)
beer_fc %>% 
    dplyr::mutate(value = purrr::map(Beer, purrr::pluck, 'dist', 'x')) %>% 
    tidyr::unnest(value)
#> # A tibble: 0 × 6
#> # … with 6 variables: .model <chr>, Quarter <qtr>, Beer <dist>, .mean <dbl>,
#> #   Beer2 <list>, value <???>


##### Change 'x' to something else for Answers 1 and 2
# Errors: the element purrr passes in (`.x[[i]]`) is still a distribution
# object, so `.x[[1]][["mu"]]` dispatches to `[[.distribution` and vctrs
# rejects the character index (see the backtrace below).
beer_fc$Beer2 <- purrr::map(beer_fc$Beer, ~ .x[[1]][["mu"]])
#> Error in `vec_slice()`:
#> ! Can't use character names to index an unnamed vector.

#> Backtrace:
#>     ▆
#>  1. ├─purrr::map(beer_fc$Beer, ~.x[[1]][["mu"]])
#>  2. │ └─global .f(.x[[i]], ...)
#>  3. │   └─.x[[1]][["mu"]]
#>  4. │     ├─base (local) `[[.distribution`(.x[[1]], "mu")
#>  5. │     └─vctrs:::`[.vctrs_vctr`(.x[[1]], "mu")
#>  6. │       └─vctrs:::vec_index(x, i, ...)
#>  7. │         └─vctrs::vec_slice(x, i)
#>  8. └─rlang::abort(message = message)
# Beer2 was not reassigned (the map() call above errored), so this unnests
# the same column as in Answer 1 and again returns zero rows.
beer_fc %>% 
    tidyr::unnest(c(Beer2))
#> # A tibble: 0 × 5
#> # … with 5 variables: .model <chr>, Quarter <qtr>, Beer <dist>, .mean <dbl>,
#> #   Beer2 <???>


# NOTE(review): plucking 'dist' then 'mu' still returns zero rows, presumably
# because pluck() on each distribution element yields NULL.
beer_fc %>% 
    dplyr::mutate(value = purrr::map(Beer, purrr::pluck, 'dist', 'mu')) %>% 
    tidyr::unnest(value)
#> # A tibble: 0 × 6
#> # … with 6 variables: .model <chr>, Quarter <qtr>, Beer <dist>, .mean <dbl>,
#> #   Beer2 <list>, value <???>

Created on 2022-11-14 with reprex v2.0.2

Structure of Ideal Output

My ideal output looks like the result of unnesting the elements of the dist column; however, it's critical to extract the mu and sigma elements (or whatever the distribution's parameters might be) row-wise, i.e. one value per forecast row:

#> # A fable: 4 x 8 [1Q]
#> # Key:     .model [2]
#>    .model   Quarter              Beer .mean    mu sigma transform inverse
#>    <chr>      <qtr>            <dist> <dbl> <dbl> <dbl> <chr>     <chr>
#>  1 ets_log  2010 Q3   t(N(6, 0.0013))  407.  6.01 0.036 log       exp
#>  2 ets_log  2010 Q4 t(N(6.2, 0.0014))  483.   ...   ... ...       ...
#>  3 ets      2010 Q3       N(408, 237)  408.   408  15.4  NA        NA
#>  4 ets      2010 Q4       N(483, 340)  483.   ...   ... ...       ...
Session info
# Package versions used to produce the reprex above
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.1 (2022-06-23)
#>  os       Ubuntu 20.04.5 LTS
#>  system   x86_64, linux-gnu
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Etc/UTC
#>  date     2022-11-14
#>  pandoc   2.19.2 @ /usr/lib/rstudio-server/bin/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package        * version    date (UTC) lib source
#>  anytime          0.3.9      2020-08-27 [1] RSPM (R 4.2.0)
#>  assertthat       0.2.1      2019-03-21 [1] RSPM (R 4.2.0)
#>  cli              3.4.0      2022-09-08 [1] RSPM (R 4.2.0)
#>  colorspace       2.0-3      2022-02-21 [1] RSPM (R 4.2.0)
#>  DBI              1.1.3      2022-06-18 [1] RSPM (R 4.2.0)
#>  digest           0.6.29     2021-12-01 [1] RSPM (R 4.2.0)
#>  distributional   0.3.1      2022-09-02 [1] RSPM (R 4.2.0)
#>  dplyr            1.0.10     2022-09-01 [1] RSPM (R 4.2.0)
#>  ellipsis         0.3.2      2021-04-29 [1] RSPM (R 4.2.0)
#>  evaluate         0.16       2022-08-09 [1] RSPM (R 4.2.0)
#>  fable            0.3.2      2022-09-01 [1] RSPM (R 4.2.0)
#>  fabletools       0.3.2      2021-11-29 [1] RSPM (R 4.2.0)
#>  fansi            1.0.3      2022-03-24 [1] RSPM (R 4.2.0)
#>  farver           2.1.1      2022-07-06 [1] RSPM (R 4.2.0)
#>  fastmap          1.1.0      2021-01-25 [1] RSPM (R 4.2.0)
#>  fs               1.5.2      2021-12-08 [1] RSPM (R 4.2.0)
#>  generics         0.1.3      2022-07-05 [1] RSPM (R 4.2.0)
#>  ggplot2          3.3.6      2022-05-03 [1] RSPM (R 4.2.0)
#>  glue             1.6.2      2022-02-24 [1] RSPM (R 4.2.0)
#>  gtable           0.3.1      2022-09-01 [1] RSPM (R 4.2.0)
#>  highr            0.9        2021-04-16 [1] RSPM (R 4.2.0)
#>  htmltools        0.5.3      2022-07-18 [1] RSPM (R 4.2.0)
#>  knitr            1.40       2022-08-24 [1] RSPM (R 4.2.0)
#>  lifecycle        1.0.2      2022-09-09 [1] RSPM (R 4.2.0)
#>  lubridate        1.8.0      2021-10-07 [1] RSPM (R 4.2.0)
#>  magrittr       * 2.0.3      2022-03-30 [1] RSPM (R 4.2.0)
#>  munsell          0.5.0      2018-06-12 [1] RSPM (R 4.2.0)
#>  numDeriv         2016.8-1.1 2019-06-06 [1] RSPM (R 4.2.0)
#>  pillar           1.8.1      2022-08-19 [1] RSPM (R 4.2.0)
#>  pkgconfig        2.0.3      2019-09-22 [1] RSPM (R 4.2.0)
#>  progressr        0.11.0     2022-09-02 [1] RSPM (R 4.2.0)
#>  purrr            0.3.4      2020-04-17 [1] RSPM (R 4.2.0)
#>  R6               2.5.1      2021-08-19 [1] RSPM (R 4.2.0)
#>  Rcpp             1.0.9      2022-07-08 [1] RSPM (R 4.2.0)
#>  reprex           2.0.2      2022-08-17 [1] RSPM (R 4.2.0)
#>  rlang            1.0.5      2022-08-31 [1] RSPM (R 4.2.0)
#>  rmarkdown        2.16       2022-08-24 [1] RSPM (R 4.2.0)
#>  rstudioapi       0.14       2022-08-22 [1] RSPM (R 4.2.0)
#>  scales           1.2.1      2022-08-20 [1] RSPM (R 4.2.0)
#>  sessioninfo      1.2.2      2021-12-06 [1] RSPM (R 4.2.0)
#>  stringi          1.7.8      2022-07-11 [1] RSPM (R 4.2.0)
#>  stringr          1.4.1      2022-08-20 [1] RSPM (R 4.2.0)
#>  tibble           3.1.8      2022-07-22 [1] RSPM (R 4.2.0)
#>  tidyr            1.2.1      2022-09-08 [1] RSPM (R 4.2.0)
#>  tidyselect       1.1.2      2022-02-21 [1] RSPM (R 4.2.0)
#>  tsibble          1.1.3      2022-10-09 [1] RSPM (R 4.2.0)
#>  utf8             1.2.2      2021-07-24 [1] RSPM (R 4.2.0)
#>  vctrs            0.4.1      2022-04-13 [1] RSPM (R 4.2.0)
#>  withr            2.5.0      2022-03-03 [1] RSPM (R 4.2.0)
#>  xfun             0.33       2022-09-12 [1] RSPM (R 4.2.0)
#>  yaml             2.3.5      2022-02-21 [1] RSPM (R 4.2.0)
#> 
#>  [1] /usr/local/lib/R/site-library
#>  [2] /usr/local/lib/R/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Solution

  • You can obtain the parameters of a distribution using the parameters() function, and its family (shape) using the family() function.

    For example, you might have a Normally distributed forecast:

    # Load distributional, which provides dist_normal()
    library(distributional)
    # A Normal distribution with mean 10 and standard deviation 3
    fc <- dist_normal(mu = 10, sigma = 3)
    

    Obtain the parameters from the distribution

    # Returns a data frame with one column per parameter (here mu and sigma)
    parameters(fc)
    #>   mu sigma
    #> 1 10     3
    

    Obtain the shape of the distribution

    # Returns the distribution family name as a character string
    family(fc)
    #> [1] "normal"
    

    For your example…

    # fable attaches fabletools, which provides model() and forecast()
    library(fable)
    #> Loading required package: fabletools
    # One ETS model on log(Beer), forecast 3 years (12 quarters) ahead
    fc <- tsibbledata::aus_production %>% 
      model(ets = ETS(log(Beer) ~ error("M") + trend("Ad") + season("A"))) %>% 
      forecast(h = "3 years")
    

    You could extract the parameters and the family:

    # dplyr provides mutate()
    library(dplyr)
    #> 
    #> Attaching package: 'dplyr'
    #> The following objects are masked from 'package:stats':
    #> 
    #>     filter, lag
    #> The following objects are masked from 'package:base':
    #> 
    #>     intersect, setdiff, setequal, union
    # Unnamed data frames returned inside mutate() are spliced into columns:
    # parameters(Beer) adds dist/transform/inverse, and family(Beer) becomes
    # a column literally named `family(Beer)`.
    fc %>% 
      mutate(parameters(Beer), family(Beer))
    #> # A fable: 12 x 8 [1Q]
    #> # Key:     .model [1]
    #>    .model Quarter              Beer .mean           dist trans…¹ inverse famil…²
    #>    <chr>    <qtr>            <dist> <dbl>         <dist> <list>  <list>  <chr>  
    #>  1 ets    2010 Q3   t(N(6, 0.0013))  407.   N(6, 0.0013) <fn>    <fn>    transf…
    #>  2 ets    2010 Q4 t(N(6.2, 0.0014))  483. N(6.2, 0.0014) <fn>    <fn>    transf…
    #>  3 ets    2011 Q1   t(N(6, 0.0014))  419.   N(6, 0.0014) <fn>    <fn>    transf…
    #>  4 ets    2011 Q2   t(N(6, 0.0015))  384.   N(6, 0.0015) <fn>    <fn>    transf…
    #>  5 ets    2011 Q3   t(N(6, 0.0019))  405.   N(6, 0.0019) <fn>    <fn>    transf…
    #>  6 ets    2011 Q4 t(N(6.2, 0.0022))  481. N(6.2, 0.0022) <fn>    <fn>    transf…
    #>  7 ets    2012 Q1   t(N(6, 0.0023))  417.   N(6, 0.0023) <fn>    <fn>    transf…
    #>  8 ets    2012 Q2 t(N(5.9, 0.0025))  383. N(5.9, 0.0025) <fn>    <fn>    transf…
    #>  9 ets    2012 Q3   t(N(6, 0.0032))  403.   N(6, 0.0032) <fn>    <fn>    transf…
    #> 10 ets    2012 Q4 t(N(6.2, 0.0036))  479. N(6.2, 0.0036) <fn>    <fn>    transf…
    #> 11 ets    2013 Q1   t(N(6, 0.0039))  416.   N(6, 0.0039) <fn>    <fn>    transf…
    #> 12 ets    2013 Q2 t(N(5.9, 0.0043))  382. N(5.9, 0.0043) <fn>    <fn>    transf…
    #> # … with abbreviated variable names ¹​transform, ²​`family(Beer)`
    

    This produces the columns ‘dist’, ‘transform’ and ‘inverse’, which are the parameters of the transformed distribution. However, you’re probably more interested in the parameters of the underlying Normal, which you can obtain by applying parameters() to dist.

    # The second mutate() applies parameters() to the inner Normal `dist`
    # column, adding mu and sigma (plus a `family(dist)` column).
    fc %>% 
      mutate(parameters(Beer), family(Beer)) %>% 
      mutate(parameters(dist), family(dist))
    #> # A fable: 12 x 11 [1Q]
    #> # Key:     .model [1]
    #>    .model Quarter              Beer .mean           dist trans…¹ inverse famil…²
    #>    <chr>    <qtr>            <dist> <dbl>         <dist> <list>  <list>  <chr>  
    #>  1 ets    2010 Q3   t(N(6, 0.0013))  407.   N(6, 0.0013) <fn>    <fn>    transf…
    #>  2 ets    2010 Q4 t(N(6.2, 0.0014))  483. N(6.2, 0.0014) <fn>    <fn>    transf…
    #>  3 ets    2011 Q1   t(N(6, 0.0014))  419.   N(6, 0.0014) <fn>    <fn>    transf…
    #>  4 ets    2011 Q2   t(N(6, 0.0015))  384.   N(6, 0.0015) <fn>    <fn>    transf…
    #>  5 ets    2011 Q3   t(N(6, 0.0019))  405.   N(6, 0.0019) <fn>    <fn>    transf…
    #>  6 ets    2011 Q4 t(N(6.2, 0.0022))  481. N(6.2, 0.0022) <fn>    <fn>    transf…
    #>  7 ets    2012 Q1   t(N(6, 0.0023))  417.   N(6, 0.0023) <fn>    <fn>    transf…
    #>  8 ets    2012 Q2 t(N(5.9, 0.0025))  383. N(5.9, 0.0025) <fn>    <fn>    transf…
    #>  9 ets    2012 Q3   t(N(6, 0.0032))  403.   N(6, 0.0032) <fn>    <fn>    transf…
    #> 10 ets    2012 Q4 t(N(6.2, 0.0036))  479. N(6.2, 0.0036) <fn>    <fn>    transf…
    #> 11 ets    2013 Q1   t(N(6, 0.0039))  416.   N(6, 0.0039) <fn>    <fn>    transf…
    #> 12 ets    2013 Q2 t(N(5.9, 0.0043))  382. N(5.9, 0.0043) <fn>    <fn>    transf…
    #> # … with 3 more variables: mu <dbl>, sigma <dbl>, `family(dist)` <chr>, and
    #> #   abbreviated variable names ¹​transform, ²​`family(Beer)`
    

    Or go directly to the parameters of the underlying (untransformed) distribution:

    # parameters(Beer) is a data frame whose `dist` column holds the
    # untransformed Normal distributions; their parameters become mu/sigma.
    fc %>%
      mutate(parameters(parameters(Beer)[["dist"]]))
    #> # A fable: 12 x 6 [1Q]
    #> # Key:     .model [1]
    #>    .model Quarter              Beer .mean    mu  sigma
    #>    <chr>    <qtr>            <dist> <dbl> <dbl>  <dbl>
    #>  1 ets    2010 Q3   t(N(6, 0.0013))  407.  6.01 0.0360
    #>  2 ets    2010 Q4 t(N(6.2, 0.0014))  483.  6.18 0.0377
    #>  3 ets    2011 Q1   t(N(6, 0.0014))  419.  6.04 0.0380
    #>  4 ets    2011 Q2   t(N(6, 0.0015))  384.  5.95 0.0390
    #>  5 ets    2011 Q3   t(N(6, 0.0019))  405.  6.00 0.0437
    #>  6 ets    2011 Q4 t(N(6.2, 0.0022))  481.  6.17 0.0467
    #>  7 ets    2012 Q1   t(N(6, 0.0023))  417.  6.03 0.0482
    #>  8 ets    2012 Q2 t(N(5.9, 0.0025))  383.  5.95 0.0504
    #>  9 ets    2012 Q3   t(N(6, 0.0032))  403.  6.00 0.0562
    #> 10 ets    2012 Q4 t(N(6.2, 0.0036))  479.  6.17 0.0600
    #> 11 ets    2013 Q1   t(N(6, 0.0039))  416.  6.03 0.0624
    #> 12 ets    2013 Q2 t(N(5.9, 0.0043))  382.  5.94 0.0653
    

    Created on 2022-11-15 by the reprex package (v2.0.1)