Tags: r, dataframe, tidyverse, tidymodels

Fit a model on each group and evaluate it using data from all rows not in this group, R


I want to fit a model to the data from each group in a dataframe. Then I want to use this model to predict all the data in the dataframe that was not in the group and compute a metric like the RMSE. I am having trouble wrapping my head around how to achieve this without many manual steps. I have the following toy example using the mtcars data.

I want to fit the model lm(mpg ~ wt, data=mtcars) for each group of cyl and use this model and the data from the remaining points to compute a value like the RMSE.

I wrote the following code, but it is not working and also does not feel quite right. I'd be happy to hear about any tips and tricks :)


library(tidyverse)

# 'global' model
lm(mpg ~ wt, data=mtcars)

# 1. fit one model for each  class of cyl, 
# 2. use it to predict the remaining ones,
# 3. get the RMSE for each class of cyl

res = mtcars %>% 
  group_by(cyl) %>% 
  mutate(
    models = list(lm(mpg ~ wt, data = cur_data())), # why does this need to be a list?
    ref_data = list(mtcars %>% filter(!cyl %in% cur_data()$cyl[[1]])), # get all the data minus the current group and put it in a list column
    predict = map(models, ~predict(.x, newdata=mtcars %>% filter(!cyl %in% cur_data()$cyl[[1]]))), # predict it on all others -> will store one df per row...
    rmse = map2_dbl(predict, ref_data, ~sqrt((sum(.y - .y)^2))/length(predict)) 
  )


Solution

  • You could consider using functions from rsample for this, because this situation is exactly what an rsample::rsplit object is intended for.

    For example here, the "analysis" set has cyl == 6 and the "assessment" set has cyl equal to other values:

    library(rsample)
    
    ind <- list(analysis = which(mtcars$cyl == 6), 
                assessment = which(mtcars$cyl != 6))
    make_splits(ind, mtcars)
    #> <Analysis/Assess/Total>
    #> <7/25/32>
    

    Created on 2021-08-11 by the reprex package (v2.0.1)
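
To see what such a split holds, the two partitions can be pulled back out as ordinary data frames. The following is a minimal sketch, assuming the rsample accessors analysis() and assessment() (both are also used further down in this answer); the names cyl6_split, fit_rows and eval_rows are only illustrative.

```r
library(rsample)

# Same manual split as above: cyl == 6 for fitting, all other rows for evaluation
ind <- list(analysis = which(mtcars$cyl == 6),
            assessment = which(mtcars$cyl != 6))
cyl6_split <- make_splits(ind, mtcars)

fit_rows  <- analysis(cyl6_split)    # the 7 six-cylinder cars
eval_rows <- assessment(cyl6_split)  # the remaining 25 cars

unique(fit_rows$cyl)
#> [1] 6
nrow(eval_rows)
#> [1] 25
any(eval_rows$cyl == 6)
#> [1] FALSE
```

Because the rsplit object only stores row indices alongside the data, building one split per group stays cheap even when the data frame is large.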

    To go through your modeling analysis, you would create a function to make a split according to your parameter (cyl here), then use purrr::map() to map over the values of that parameter and:

    • make the splits
    • fit the models to each split
    • predict with each model on each split (notice you predict on the assessment set)
    • compute the RMSE

    library(tidyverse)
    library(tidymodels)
    #> Registered S3 method overwritten by 'tune':
    #>   method                   from   
    #>   required_pkgs.model_spec parsnip
    
    manual_split_from_cyl <- function(cyl_value) {
        ind <- list(analysis = which(mtcars$cyl == cyl_value), 
                    assessment = which(mtcars$cyl != cyl_value))
        make_splits(ind, mtcars)
    }
    
    tibble(cyl = unique(mtcars$cyl)) %>%
        mutate(splits = map(cyl, manual_split_from_cyl),
               model = map(splits, ~ lm(mpg ~ wt, data = analysis(.))),
               preds = map2(model, splits, ~ predict(.x, newdata = assessment(.y))),
               rmse = map2_dbl(splits, preds, ~ rmse_vec(assessment(.x)$mpg, .y)))
    #> # A tibble: 3 × 5
    #>     cyl splits          model  preds       rmse
    #>   <dbl> <list>          <list> <list>     <dbl>
    #> 1     6 <split [7/25]>  <lm>   <dbl [25]>  4.38
    #> 2     4 <split [11/21]> <lm>   <dbl [21]>  3.35
    #> 3     8 <split [14/18]> <lm>   <dbl [18]>  6.94
    

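
As a cross-check, the same quantity can be written out in base R. This is only a sketch: rmse_manual() and rmse_fit_on_group() are hypothetical helper names, not part of any package, and the RMSE is taken to be the square root of the mean squared difference between observed and predicted mpg. Note that the expression in the question, sqrt((sum(.y - .y)^2)), subtracts the reference data from itself and squares the sum rather than summing the squares, so it cannot measure prediction error.

```r
# RMSE written out by hand: square root of the mean squared error
rmse_manual <- function(truth, estimate) {
  sqrt(mean((truth - estimate)^2))
}

# Hypothetical helper: fit on one cyl group, evaluate on all other rows
rmse_fit_on_group <- function(cyl_value, data = mtcars) {
  train <- data[data$cyl == cyl_value, ]
  test  <- data[data$cyl != cyl_value, ]
  fit   <- lm(mpg ~ wt, data = train)
  rmse_manual(test$mpg, predict(fit, newdata = test))
}

cyl_values <- unique(mtcars$cyl)
manual_rmse <- data.frame(
  cyl  = cyl_values,
  rmse = vapply(cyl_values, rmse_fit_on_group, numeric(1))
)
manual_rmse
```

If the definition assumed here matches the one used by rmse_vec(), the rmse column should agree with the tibble printed above.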