I would like to use tidymodels and cross-validation to compare three linear regression models, specified as follows:
y ~ a
y ~ b
y ~ a + b
In the following, y denotes the target variable, while a and b denote the independent variables.
Without cross-validation, it is (I hope) quite clear to me what I have to do:
library(tidymodels)

split <- data %>% initial_split(strata = y)
data_train <- training(split)
data_test <- testing(split)   # held-out test set

linear_reg() %>%
  set_engine("lm") %>%
  fit(y ~ a + b, data = data_train) %>%
  augment(new_data = data_test) %>%
  rmse(truth = y, estimate = .pred)
The output looks something like this:
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 rmse standard x.xxx
I can repeat the model-fitting step for the other two formulas and compare the three models on RMSE (the metric chosen for this example).
For example, I can create a dummy dataset and run the steps described above:
library(tidymodels)

n <- 1e4
data <- tibble(a = rnorm(n),
               b = rnorm(n),
               y = 1 + 3*a - 2*b + rnorm(n))

split <- data %>% initial_split(strata = y)
data_train <- training(split)
data_test <- testing(split)   # held-out test set

linear_reg() %>%
  set_engine("lm") %>%
  fit(y ~ a, data = data_train) %>%
  augment(new_data = data_test) %>%
  rmse(truth = y, estimate = .pred)
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 rmse standard 2.23
linear_reg() %>%
  set_engine("lm") %>%
  fit(y ~ b, data = data_train) %>%
  augment(new_data = data_test) %>%
  rmse(truth = y, estimate = .pred)
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 rmse standard 3.17
linear_reg() %>%
  set_engine("lm") %>%
  fit(y ~ a + b, data = data_train) %>%
  augment(new_data = data_test) %>%
  rmse(truth = y, estimate = .pred)
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 rmse standard 1.00
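The three pipelines above differ only in their formula, so they can be collapsed into a single loop. This is a minimal sketch, assuming purrr's map_dfr() (installed alongside tidymodels) and a split where data_test comes from testing(split); the names formulas and test_rmse are illustrative, and the seed is arbitrary. Note that it still uses a single train/test split, not cross-validation.

```r
library(tidymodels)

set.seed(123)  # arbitrary seed (assumption), only makes the split reproducible
n <- 1e4
data <- tibble(a = rnorm(n),
               b = rnorm(n),
               y = 1 + 3*a - 2*b + rnorm(n))

split <- initial_split(data, strata = y)
data_train <- training(split)
data_test  <- testing(split)

# Hypothetical named list of the candidate formulas
formulas <- list(a = y ~ a, b = y ~ b, ab = y ~ a + b)

# Fit each formula on the training set, score it on the test set,
# and stack the results; .id = "model" keeps the list names as a column
test_rmse <- purrr::map_dfr(formulas, function(f) {
  linear_reg() %>%
    set_engine("lm") %>%
    fit(f, data = data_train) %>%
    augment(new_data = data_test) %>%
    rmse(truth = y, estimate = .pred)
}, .id = "model")

test_rmse
```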
My question is: how can I use cross-validation to estimate the RMSE of three models that differ only in their set of predictors?
In this video, Julia Silge does this with three different models (logistic regression, k-nearest neighbors, and decision trees) that all use the same set of predictors. What I want instead is to compare models that differ in their set of predictors.
Any suggestions and/or references?
When you have many different models to compare, one option is the workflowsets package.
It lets you specify any number of models and preprocessors, fits every combination of them, and returns the results in a tidy format.
Notice that recipe() is used here only to declare which variables each model uses.
Additionally, you can pass a metric_set() to the metrics argument of workflow_map() if you want to use metrics other than the defaults.
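To see what such an object is, note that metric_set() returns a function that computes all of its metrics in one call; the same kind of object is what gets passed to the metrics argument. A small sketch on a hypothetical toy data frame (the names reg_metrics, toy and res are illustrative):

```r
library(yardstick)
library(tibble)

# Bundle several regression metrics into one callable metric set
reg_metrics <- metric_set(rmse, rsq, mae)

# Hypothetical toy predictions, just to show the output shape
toy <- tibble(truth    = c(1, 2, 3, 4),
              estimate = c(1.1, 1.9, 3.2, 3.8))

# One row per metric: .metric, .estimator, .estimate
res <- reg_metrics(toy, truth = truth, estimate = estimate)
res
```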
library(tidymodels)

n <- 1e4
data <- tibble(a = rnorm(n),
               b = rnorm(n),
               y = 1 + 3*a - 2*b + rnorm(n))

split <- data %>% initial_split(strata = y)
data_train <- training(split)
data_test <- testing(split)   # held-out test set
lm_spec <- linear_reg()

# Each recipe only declares which predictors the model uses
rec_a <- recipe(y ~ a, data = data_train)
rec_b <- recipe(y ~ b, data = data_train)
rec_ab <- recipe(y ~ a + b, data = data_train)

all_models_wfs <- workflow_set(
  preproc = list(a = rec_a, b = rec_b, c = rec_ab),
  models = list(lm = lm_spec),
  cross = TRUE
)
all_models_wfs
#> # A workflow set/tibble: 3 × 4
#> wflow_id info option result
#> <chr> <list> <list> <list>
#> 1 a_lm <tibble [1 × 4]> <opts[0]> <list [0]>
#> 2 b_lm <tibble [1 × 4]> <opts[0]> <list [0]>
#> 3 c_lm <tibble [1 × 4]> <opts[0]> <list [0]>
# Cross-validate every workflow on the training data
all_models_fit <- all_models_wfs %>%
  workflow_map(
    resamples = vfold_cv(data_train),
    metrics = metric_set(rmse, rsq, mape)
  )
all_models_fit %>%
  collect_metrics()
#> # A tibble: 9 × 9
#> wflow_id .config preproc model .metric .esti…¹ mean n std_err
#> <chr> <chr> <chr> <chr> <chr> <chr> <dbl> <int> <dbl>
#> 1 a_lm Preprocessor1_Mo… recipe line… mape standa… 261. 10 3.99e+1
#> 2 a_lm Preprocessor1_Mo… recipe line… rmse standa… 2.26 10 2.89e-2
#> 3 a_lm Preprocessor1_Mo… recipe line… rsq standa… 0.627 10 7.72e-3
#> 4 b_lm Preprocessor1_Mo… recipe line… mape standa… 258. 10 2.07e+1
#> 5 b_lm Preprocessor1_Mo… recipe line… rmse standa… 3.10 10 2.13e-2
#> 6 b_lm Preprocessor1_Mo… recipe line… rsq standa… 0.298 10 7.61e-3
#> 7 c_lm Preprocessor1_Mo… recipe line… mape standa… 144. 10 3.66e+1
#> 8 c_lm Preprocessor1_Mo… recipe line… rmse standa… 1.01 10 6.51e-3
#> 9 c_lm Preprocessor1_Mo… recipe line… rsq standa… 0.926 10 2.06e-3
#> # … with abbreviated variable name ¹.estimator
Created on 2022-09-19 by the reprex package (v2.0.1)
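A possible next step, sketched end to end under a few assumptions: since these recipes do no preprocessing, plain formulas can be passed to workflow_set() as preprocessors instead; fn = "fit_resamples" is given explicitly because nothing is tuned; rank_results() orders the workflows by mean cross-validated RMSE, and last_fit() refits the winner on the full training set and scores it once on the test set. The seed and the object names (folds, wfs, cv_res, best_id, final) are illustrative.

```r
library(tidymodels)

set.seed(2022)  # arbitrary seed (assumption)
n <- 1e4
data <- tibble(a = rnorm(n), b = rnorm(n),
               y = 1 + 3*a - 2*b + rnorm(n))

split <- initial_split(data, strata = y)
data_train <- training(split)
folds <- vfold_cv(data_train, v = 10)

# Formulas work as preprocessors too; ids become "a_lm", "b_lm", "ab_lm"
wfs <- workflow_set(
  preproc = list(a = y ~ a, b = y ~ b, ab = y ~ a + b),
  models = list(lm = linear_reg())
)

# Resample every workflow (no tuning parameters, so fit_resamples)
cv_res <- workflow_map(wfs, "fit_resamples",
                       resamples = folds,
                       metrics = metric_set(rmse))

# Rank by mean cross-validated RMSE; rank 1 is the lowest RMSE
ranked <- rank_results(cv_res, rank_metric = "rmse")
best_id <- ranked$wflow_id[ranked$rank == 1]

# Refit the best workflow on the training set, evaluate once on the test set
final <- last_fit(extract_workflow(cv_res, id = best_id), split,
                  metrics = metric_set(rmse))
collect_metrics(final)
```

Keeping the test set out of the model comparison and touching it only once, for the selected workflow, is what makes its RMSE an honest estimate.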