Compare performance of linear regression models that differ in the predictors used, using cross-validation


I would like to use tidymodels and cross-validation to compare three linear regression models, specified as follows:

  • (model_A) y ~ a
  • (model_B) y ~ b
  • (model_AB) y ~ a + b

Here y denotes the target variable, while a and b denote the independent variables.

Without cross-validation, it is (I hope) quite clear to me what I have to do:

  1. Split my data into train and test sets
set.seed(1234)
split <- data %>% initial_split(strata = y)
data_train <- training(split)
data_test <- testing(split)
  2. Specify, fit, and evaluate my model in one go (for example, for model_AB)
linear_reg() %>%
    set_engine("lm") %>%
    fit(y ~ a + b, data = data_train) %>%
    augment(new_data = data_test) %>%
    rmse(truth = y, estimate = .pred)

The output looks something like this:

# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard       x.xxx

I can then repeat step 2 for the other two models and compare all three on RMSE, the metric chosen for this example.

For example, I can create a dummy dataset and run the steps described above.

library(tidyverse)
library(tidymodels)

set.seed(1234)
n <- 1e4
data <- tibble(a = rnorm(n),
               b = rnorm(n),
               y = 1 + 3*a - 2*b + rnorm(n))

set.seed(1234)
split <- data %>% initial_split(strata = y)
data_train <- training(split)
data_test <- testing(split)
  • Model_A
linear_reg() %>%
    set_engine("lm") %>%
    fit(y ~ a, data = data_train) %>%
    augment(new_data = data_test) %>%
    rmse(truth = y, estimate = .pred)

result

# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard        2.23
  • Model_B
linear_reg() %>%
    set_engine("lm") %>%
    fit(y ~ b, data = data_train) %>%
    augment(new_data = data_test) %>%
    rmse(truth = y, estimate = .pred)

result

# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard        3.17
  • Model_AB
linear_reg() %>%
    set_engine("lm") %>%
    fit(y ~ a + b, data = data_train) %>%
    augment(new_data = data_test) %>%
    rmse(truth = y, estimate = .pred)

result

# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard        1.00
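
The three blocks above differ only in the formula, so the holdout comparison can be written once and applied to each formula. What follows is a minimal sketch under the same dummy-data assumptions; `formulas`, `holdout_rmse`, and `results` are illustrative names of my own, not part of tidymodels.

```r
library(tidymodels)

# Same dummy data as above
set.seed(1234)
n <- 1e4
data <- tibble(a = rnorm(n),
               b = rnorm(n),
               y = 1 + 3*a - 2*b + rnorm(n))

set.seed(1234)
split <- initial_split(data, strata = y)
data_train <- training(split)
data_test <- testing(split)

# One formula per candidate model (names are illustrative)
formulas <- list(model_A = y ~ a, model_B = y ~ b, model_AB = y ~ a + b)

# Fit on the training set, then compute RMSE on the held-out test set
holdout_rmse <- function(f) {
  linear_reg() %>%
    set_engine("lm") %>%
    fit(f, data = data_train) %>%
    augment(new_data = data_test) %>%
    rmse(truth = y, estimate = .pred)
}

# One row per model, labelled by the list names
results <- bind_rows(lapply(formulas, holdout_rmse), .id = "model")
results
```

This still relies on a single train/test split, so it does not answer the cross-validation question below; it only removes the repetition.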

My question is: how can I evaluate the RMSE under cross-validation for three models that differ in their set of predictors?

In this video, Julia Silge does this with three different model types (logistic regression, kNN, and decision trees) using the same set of predictors. However, what I aim to do is compare models that differ in their set of predictors.

Any suggestions or references?


Solution

  • When you have a lot of different models you want to compare, one way to deal with that is to use the workflowsets package.

    This way you can specify any number of models and preprocessors, and it will run all of them and return the results in a tidy format.

    Notice that recipe() is used here only to denote which variables each model uses.

    Additionally, you can pass a metric_set() to the metrics argument of workflow_map() if you want to use different metrics than the defaults.

    library(tidymodels)
    set.seed(1234)
    n <- 1e4
    data <- tibble(a = rnorm(n),
                   b = rnorm(n),
                   y = 1 + 3*a - 2*b + rnorm(n))
    
    set.seed(1234)
    split <- data %>% initial_split(strata = y)
    data_train <- training(split)
    data_test <- testing(split)
    
    lm_spec <- linear_reg()
    
    rec_a <- recipe(y ~ a, data = data_train)
    rec_b <- recipe(y ~ b, data = data_train)
    rec_ab <- recipe(y ~ a + b, data = data_train)
    
    all_models_wfs <- workflow_set(
      preproc = list(a = rec_a, b = rec_b, c = rec_ab),
      models = list(lm = lm_spec),
      cross = TRUE
    )
    
    all_models_wfs
    #> # A workflow set/tibble: 3 × 4
    #>   wflow_id info             option    result    
    #>   <chr>    <list>           <list>    <list>    
    #> 1 a_lm     <tibble [1 × 4]> <opts[0]> <list [0]>
    #> 2 b_lm     <tibble [1 × 4]> <opts[0]> <list [0]>
    #> 3 c_lm     <tibble [1 × 4]> <opts[0]> <list [0]>
    
    # Cross-validate on the training set; the test set stays held out
    all_models_fit <- workflow_map(
      all_models_wfs,
      resamples = vfold_cv(data_train),
      metrics = metric_set(rmse, rsq, mape)
    )
    
    all_models_fit %>%
      collect_metrics()
    #> # A tibble: 9 × 9
    #>   wflow_id .config           preproc model .metric .esti…¹    mean     n std_err
    #>   <chr>    <chr>             <chr>   <chr> <chr>   <chr>     <dbl> <int>   <dbl>
    #> 1 a_lm     Preprocessor1_Mo… recipe  line… mape    standa… 261.       10 3.99e+1
    #> 2 a_lm     Preprocessor1_Mo… recipe  line… rmse    standa…   2.26     10 2.89e-2
    #> 3 a_lm     Preprocessor1_Mo… recipe  line… rsq     standa…   0.627    10 7.72e-3
    #> 4 b_lm     Preprocessor1_Mo… recipe  line… mape    standa… 258.       10 2.07e+1
    #> 5 b_lm     Preprocessor1_Mo… recipe  line… rmse    standa…   3.10     10 2.13e-2
    #> 6 b_lm     Preprocessor1_Mo… recipe  line… rsq     standa…   0.298    10 7.61e-3
    #> 7 c_lm     Preprocessor1_Mo… recipe  line… mape    standa… 144.       10 3.66e+1
    #> 8 c_lm     Preprocessor1_Mo… recipe  line… rmse    standa…   1.01     10 6.51e-3
    #> 9 c_lm     Preprocessor1_Mo… recipe  line… rsq     standa…   0.926    10 2.06e-3
    #> # … with abbreviated variable name ¹​.estimator
    

    Created on 2022-09-19 by the reprex package (v2.0.1)
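
Once the resampled metrics are in, a natural next step is to rank the candidates and check the winner on the held-out test set. The following is a minimal sketch, assuming the same dummy data; it passes formulas directly as preprocessors instead of recipes, and the names `folds`, `wfs`, `res`, `ranked`, `best_id`, and `final` are illustrative only.

```r
library(tidymodels)

set.seed(1234)
n <- 1e4
data <- tibble(a = rnorm(n),
               b = rnorm(n),
               y = 1 + 3*a - 2*b + rnorm(n))

set.seed(1234)
split <- initial_split(data, strata = y)
data_train <- training(split)

# 10-fold cross-validation on the training set only
folds <- vfold_cv(data_train, v = 10)

# Formulas can be used as preprocessors in place of recipes
wfs <- workflow_set(
  preproc = list(a = y ~ a, b = y ~ b, ab = y ~ a + b),
  models = list(lm = linear_reg())
)

# Fit every workflow on every fold and compute RMSE
res <- workflow_map(wfs, "fit_resamples",
                    resamples = folds,
                    metrics = metric_set(rmse))

# Rank workflows by mean cross-validated RMSE (rank 1 is best)
ranked <- rank_results(res, rank_metric = "rmse")
ranked

# Refit the top-ranked workflow on the full training set and
# evaluate it once on the test set
best_id <- ranked$wflow_id[1]
final <- res %>%
  extract_workflow(best_id) %>%
  last_fit(split)
collect_metrics(final)
```

Keeping the test set out of `vfold_cv()` means the final `last_fit()` estimate is not influenced by the model selection step.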