Search code examples
Tags: r, dataframe, dplyr, group-by

Add predicted values from multiple group-specific models to a new dataset, by group


Given a data.frame of grouped data:

library(tidyverse)

# fake up some grouped data: 100 standard-normal (x, y) pairs split into
# ten groups "a".."j", ten consecutive rows per group
set.seed(123)
dat <- data.frame(
  x = rnorm(100),
  y = rnorm(100),
  group = rep(letters[1:10], each = 10)
)
head(dat)
> head(dat)
            x           y group
1 -0.56047565 -0.71040656     a
2 -0.23017749  0.25688371     a
3  1.55870831 -0.24669188     a
4  0.07050839 -0.34754260     a
5  0.12928774 -0.95161857     a
6  1.71506499 -0.04502772     a

I want to build a set of independent models, one per group defined by one (or more) grouping columns:

# store models by group in a named list ("mdl_a", "mdl_b", ...). The groups
# are taken from the data itself rather than a hard-coded letters[1:10]
# (which would silently skip new groups or fail on absent ones), and
# split() + lapply() builds the whole list at once instead of growing it
# inside a for loop.
models <- lapply(split(dat, dat$group), function(d) lm(y ~ x, data = d))
names(models) <- paste0("mdl_", names(models))

names(models)
 [1] "mdl_a" "mdl_b" "mdl_c" "mdl_d" "mdl_e" "mdl_f" "mdl_g" "mdl_h" "mdl_i" "mdl_j"

I can add the model predictions (fitted values) to the original data frame in a number of ways; this one is convenient:

# add model predictions (fitted values) column to original data frame.
# na.action = na.exclude makes fitted() pad any rows lm() drops for missing
# values with NA, so each group gets exactly one fit per row and mutate()
# does not fail on a length mismatch. ungroup() afterwards so the grouping
# does not leak into later steps.
dat <- dat %>%
  group_by(group) %>%
  mutate(fits = fitted(lm(y ~ x, na.action = na.exclude))) %>%
  ungroup()

# verify that predictions from the stored models match the fitted-values
# column. all.equal() with tolerance = 1e-10 replaces comparing round()ed
# values with ==, which can report a spurious mismatch when two nearly
# equal numbers round to different 10th decimals. unname() drops any names
# carried by either vector so only the values are compared. Groups are
# taken from the data instead of a hard-coded letters[1:10].
for (i in unique(dat$group)) {
  tmp <- dat %>%
    filter(group == i) %>%
    select(group, x, y, fits)
  tmp$stored_fit <- predict(models[[paste0("mdl_", i)]], tmp)
  matched <- isTRUE(all.equal(unname(tmp$stored_fit), unname(tmp$fits),
                              tolerance = 1e-10))
  print(paste("mdl", i, "results match:", matched))
}
[1] "mdl a results match: TRUE"
[1] "mdl b results match: TRUE"
[1] "mdl c results match: TRUE"
[1] "mdl d results match: TRUE"
[1] "mdl e results match: TRUE"
[1] "mdl f results match: TRUE"
[1] "mdl g results match: TRUE"
[1] "mdl h results match: TRUE"
[1] "mdl i results match: TRUE"
[1] "mdl j results match: TRUE"

All of these steps have been discussed in other questions like this one.

Now I want to generate the predictions from these models on a new data.frame and add those predictions as a column to that data.frame.

Here are a couple of things I tried:

# fake up some new grouped data: same layout (ten groups "a".."j" of ten
# rows each) but drawn with a different seed
set.seed(456)
dat2 <- data.frame(
  x = rnorm(100),
  y = rnorm(100),
  group = rep(letters[1:10], each = 10)
)

Method 1 (apply):

# Nest dat2 into one row per group: a `group` column plus a `data`
# list-column holding that group's rows.
tmp <- dat2 %>%
  group_by(group) %>%
  nest() # %>%
  # mutate(fits = map())

# apply() walks the rows of `tmp`. For each row it looks up that group's
# stored model and predicts on the nested data.
# NOTE(review): apply() coerces the tibble to a matrix first. It only
# simplifies the per-group results into a matrix (and hence a data frame via
# as.data.frame) when every group yields the same number of predictions, as
# in dat2 with 10 rows per group. Unequal group sizes would break this step.
fits = as.data.frame(apply(X = tmp, MARGIN=1, FUN = function(X) predict(models[[paste0("mdl_",X$group)]], X$data)))
names(fits) = tmp$group
# Reshape to long form with the group name in `group.fits`, sorted by group.
# This assumes the groups in `tmp` are also in alphabetical order, because
# the predictions are matched back to rows purely by position below.
fits <- fits %>% 
  pivot_longer(cols = everything(), names_to = "group.fits") %>% 
  arrange(group.fits)

# Expand back to one row per observation and attach the predictions side by
# side. Correctness depends on both tables having the same row order.
tmp <- tmp %>%
  unnest(cols = c(data)) %>%
  bind_cols(fits)

... which just feels error-prone and inelegant.

Method 2 (for loop, base R):

# Start from an all-NA column, then overwrite each group's rows with the
# predictions from that group's stored model.
tmp$fits <- NA
for (g in unique(tmp$group)) {
  in_group <- tmp$group == g
  tmp[in_group, ]$fits <-
    predict(models[[paste0("mdl_", g)]], tmp[in_group, ])
}
tmp

There's nothing particularly wrong with this, other than that for loops are notoriously slow on larger datasets.

Method 3 (nest/map):

I thought something like the following would work, but I have something wrong in the syntax...

# NOTE: this attempt does not work as written. `.f` must be a function (or a
# `~` formula), but here it is given the *result* of calling `predict(...)`
# (with no new data), not a function to call on each element. map() also
# iterates over `data` alone. Pairing each group key with its data in
# lockstep is what map2() is for.
dat2 %>%
  group_by(group) %>%
  nest() %>%
  mutate(fits = map(.f = predict(models[[paste0("mdl_",group)]]), data))

or

  # Also does not work: without a leading `~` (or a `function(...)`), `.x` is
  # not a lambda placeholder. `predict(..., .x)` is evaluated directly
  # instead of being called once per element of `data`.
  mutate(fits = map(.x = data, 
                    .f = predict(models[[paste0("mdl_",group)]],
                                 .x)))

I'm looking for an answer along the lines of Method 3, ideally all within one chain of dplyr commands.


Solution

  • Option 1: purrr::map2

    To follow your Method 3 approach, use map2() to iterate over each group's key and its nested data in parallel, predicting with the matching model.

    # Pair each group key with its nested data and predict from that group's
    # stored model. unnest() then expands both list-columns back to rows.
    dat2 %>%
      nest(.by = group) %>% # .by: {tidyr} >= v1.3.0
      mutate(
        fits = map2(
          group, data,
          function(g, d) predict(models[[paste0("mdl_", g)]], d)
        )
      ) %>%
      unnest(c(data, fits))
    

    Option 2: rowwise

    You can also use rowwise() instead of map2(); wrap the predicted values in list() so each row stores its prediction vector in a list-column.

    # rowwise() makes each nested row its own group, so `group` and `data`
    # are single values inside mutate(). list() stores each row's
    # prediction vector in a list-column for unnest() to expand.
    dat2 %>%
      nest(.by = group) %>% # .by: {tidyr} >= v1.3.0
      rowwise() %>%
      mutate(fits = list(
        predict(object = models[[paste0("mdl_", group)]], newdata = data)
      )) %>%
      unnest(cols = c(data, fits))
    

    Option 3: group_modify

    You don't even need nest()/unnest() from {tidyr}; just take advantage of dplyr::group_modify():

    # group_modify() calls the function once per group with that group's rows
    # and a one-row key tibble (`key$group`). Each group gets a `fits` column
    # from its own stored model.
    dat2 %>%
      group_by(group) %>%
      group_modify(function(rows, key) {
        model <- models[[paste0("mdl_", key$group)]]
        mutate(rows, fits = predict(model, rows))
      }) %>%
      ungroup()
    

    All three approaches return the same output:

    # # A tibble: 100 × 4
    #    group      x       y   fits
    #    <chr>  <dbl>   <dbl>  <dbl>
    #  1 a     -1.34   0.118  -0.677
    #  2 a      0.622  0.870  -0.287
    #  3 a      0.801 -0.0919 -0.252
    #  4 a     -1.39   0.0689 -0.686
    #  5 a     -0.714 -1.68   -0.552
    #  6 a     -0.324  1.12   -0.475
    #  7 a      0.691 -1.35   -0.274
    #  8 a      0.251 -0.537  -0.361
    #  9 a      1.01  -0.370  -0.211
    # 10 a      0.573  0.354  -0.297
    # # ℹ 90 more rows
    # # ℹ Use `print(n = ...)` to see more rows
    

    Benchmark

    # Time the three options above on dat2. Each expression is copied from
    # its option unchanged, so the timings compare like for like.
    # NOTE(review): bench::mark() is believed to check by default that all
    # expressions return equal results; confirm against the bench docs.
    bench::mark(
      `purrr::map2` = {
        dat2 %>%
          nest(.by = group) %>%
          mutate(fits = map2(group, data, ~ predict(models[[paste0("mdl_", .x)]], .y))) %>%
          unnest(c(data, fits))
      }, `dplyr::rowwise` = {
        dat2 %>%
          nest(.by = group) %>%
          rowwise() %>%
          mutate(fits = list(predict(models[[paste0("mdl_", group)]], data))) %>%
          unnest(c(data, fits))
      }, `dplyr::group_modify` = {
        dat2 %>%
          group_by(group) %>%
          group_modify(~ {
            .x %>% mutate(fits = predict(models[[paste0("mdl_", .y$group)]], .x))
          }) %>%
          ungroup()
      },
      # Request 100 runs per expression. min_time = Inf presumably keeps a time
      # budget from ending the runs early; confirm in the bench::mark docs.
      iterations = 100, min_time = Inf
    )
    
    # # A tibble: 3 × 13
    #   expression            min median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time
    #   <bch:expr>         <bch:> <bch:>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm>
    # 1 purrr::map2        16.7ms 17.3ms      57.0    36.4KB     3.00    95     5      1.67s
    # 2 dplyr::rowwise     19.9ms 20.2ms      48.7    40.7KB     3.11    94     6      1.93s
    # 3 dplyr::group_modi… 33.1ms 34.3ms      28.8     186KB     3.20    90    10      3.13s