I want to fit a model on the data from each group in a dataframe, then use that model to predict all rows of the dataframe that are not in the group and compute a metric such as the RMSE. I'm struggling to see how to do this without a lot of manual steps. Here is a toy example using the mtcars data.
I want to fit the model lm(mpg ~ wt, data=mtcars) for each group of cyl, then use that model and the data from the remaining points to compute a value like the RMSE.
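For reference, the RMSE on the cars outside a group is the square root of the mean squared prediction error, where $y_i$ is the observed mpg of the $i$-th car outside the group, $\hat{y}_i$ is the model's prediction for it, and $n$ is the number of such cars:

$$
\mathrm{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2}
$$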
I wrote the following code, but it doesn't work and also feels clumsy. I'd be happy to hear any tips and tricks :)
library(tidyverse)
# 'global' model
lm(mpg ~ wt, data=mtcars)
# 1. fit one model for each class of cyl,
# 2. use it to predict the remaining ones,
# 3. get the RMSE for each class of cyl
res = mtcars %>%
  group_by(cyl) %>%
  mutate(
    models = list(lm(mpg ~ wt, data = cur_data())), # why does this need to be a list?
    ref_data = list(mtcars %>% filter(!cyl %in% cur_data()$cyl[[1]])), # get all the data minus the current group and put it in a list column
    predict = map(models, ~predict(.x, newdata=mtcars %>% filter(!cyl %in% cur_data()$cyl[[1]]))), # predict it on all others -> will store one df per row...
    rmse = map2_dbl(predict, ref_data, ~sqrt((sum(.y - .y)^2))/length(predict))
  )
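On the list() question: mutate() expects each new column to have one value per row (or a single value that gets recycled), and an lm fit is not a single atomic value, so wrapping it in list() turns it into a length-one list that can live in a list-column. If you would rather avoid the grouped mutate() entirely, a minimal sketch with plain dplyr and purrr is to write a function for one cyl value and map it over the distinct values. rmse_for_group() is a hypothetical helper name, and the RMSE is computed directly with the formula rather than with a metrics package:

```r
library(dplyr)
library(purrr)

# Hypothetical helper: for one value of cyl, fit on that group,
# predict mpg for all other cars, and return the RMSE on those cars.
rmse_for_group <- function(cyl_value, data = mtcars) {
  train <- filter(data, cyl == cyl_value)   # cars in the group
  test  <- filter(data, cyl != cyl_value)   # all remaining cars
  fit   <- lm(mpg ~ wt, data = train)
  preds <- predict(fit, newdata = test)
  sqrt(mean((test$mpg - preds)^2))          # RMSE on the held-out cars
}

# One row per cyl value, with its out-of-group RMSE
res <- tibble(cyl = sort(unique(mtcars$cyl))) %>%
  mutate(rmse = map_dbl(cyl, rmse_for_group))
res
```

Because each model only lives inside the function call, no list-columns are needed here; the trade-off is that the fitted models and predictions are not kept for later inspection.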
You could consider using tidymodels functions for this, because this situation is exactly what an rsample::rsplit object is intended for.
For example, here the "analysis" set contains the cars with cyl == 6 and the "assessment" set contains the cars with any other value of cyl:
library(rsample)
ind <- list(analysis = which(mtcars$cyl == 6),
assessment = which(mtcars$cyl != 6))
make_splits(ind, mtcars)
#> <Analysis/Assess/Total>
#> <7/25/32>
Created on 2021-08-11 by the reprex package (v2.0.1)
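To see what such a split holds, a short sketch using rsample's analysis() and assessment() accessors, which return the two subsets as data frames (the name cyl6_split is only an illustrative choice):

```r
library(rsample)

# Same manual split as above: analysis = cyl 6, assessment = all other cars
ind <- list(analysis = which(mtcars$cyl == 6),
            assessment = which(mtcars$cyl != 6))
cyl6_split <- make_splits(ind, mtcars)

train <- analysis(cyl6_split)    # rows used to fit the model
test  <- assessment(cyl6_split)  # rows used to evaluate it

table(train$cyl)  # only cyl == 6
table(test$cyl)   # only cyl == 4 and cyl == 8
```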
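To run your whole modeling analysis, you can write a function that makes a split for a given value of your parameter (cyl here), then use purrr::map() to map over the values of that parameter and, for each split, fit a model on the analysis set, predict on the assessment set, and compute the RMSE: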
library(tidyverse)
library(tidymodels)
# Make an rsplit with one cyl value as the analysis set and all other cars as the assessment set
manual_split_from_cyl <- function(cyl_value) {
  ind <- list(analysis = which(mtcars$cyl == cyl_value),
              assessment = which(mtcars$cyl != cyl_value))
  make_splits(ind, mtcars)
}

tibble(cyl = unique(mtcars$cyl)) %>%
  mutate(splits = map(cyl, manual_split_from_cyl),                              # one split per cyl value
         model = map(splits, ~ lm(mpg ~ wt, data = analysis(.))),               # fit on the cyl group
         preds = map2(model, splits, ~ predict(.x, newdata = assessment(.y))),  # predict all other cars
         rmse = map2_dbl(splits, preds, ~ rmse_vec(assessment(.x)$mpg, .y)))    # RMSE on the other cars
#> # A tibble: 3 × 5
#> cyl splits model preds rmse
#> <dbl> <list> <list> <list> <dbl>
#> 1 6 <split [7/25]> <lm> <dbl [25]> 4.38
#> 2 4 <split [11/21]> <lm> <dbl [21]> 3.35
#> 3 8 <split [14/18]> <lm> <dbl [18]> 6.94
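If you want the split helper to work for other data frames or grouping columns, one possible generalization is to pass the data and the column name as arguments instead of hard-coding mtcars and cyl. manual_split_from_value() is a hypothetical name, not part of rsample; for the cyl column it should give the same RMSE values as the table above, with rows ordered by sorted cyl:

```r
library(dplyr)
library(purrr)
library(rsample)
library(yardstick)

# Hypothetical helper: rows where `col` equals `value` form the analysis set,
# all other rows form the assessment set.
manual_split_from_value <- function(data, col, value) {
  in_group <- data[[col]] == value
  ind <- list(analysis = which(in_group),
              assessment = which(!in_group))
  make_splits(ind, data)
}

results <- tibble(cyl = sort(unique(mtcars$cyl))) %>%
  mutate(splits = map(cyl, ~ manual_split_from_value(mtcars, "cyl", .x)),
         model  = map(splits, ~ lm(mpg ~ wt, data = analysis(.x))),
         preds  = map2(model, splits, ~ predict(.x, newdata = assessment(.y))),
         rmse   = map2_dbl(splits, preds, ~ rmse_vec(assessment(.x)$mpg, .y)))
results
```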