Search code examples
Tags: r, glmnet, tidymodels

LASSO regression - Force variables in glmnet with tidymodels


I am doing feature selection using LASSO regression with tidymodels and glmnet.

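It is possible to force variables into a glmnet model by using the `penalty.factor` argument (see here and here, for example): a penalty factor of 0 leaves that variable's coefficient unpenalized, so the variable is always kept in the model.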

Is it possible to do the same using tidymodels?

library(tidymodels)
library(vip)
library(forcats)
library(dplyr)
library(ggplot2)
library(data.table)

# Define data split. The RNG is seeded so the split (and any resampling
# done afterwards) is reproducible from run to run.
set.seed(123)
datasplit <- rsample::initial_split(mtcars, prop = 0.8)
data_training <- rsample::training(datasplit)
data_testing <- rsample::testing(datasplit)

# Model specification: LASSO (mixture = 1) with the penalty left for tuning.
# Engine-specific glmnet arguments such as `penalty.factor` would be passed
# through `parsnip::set_engine("glmnet", ...)`.
model_spec <- parsnip::linear_reg(penalty = tune::tune(), mixture = 1) %>%
  parsnip::set_engine("glmnet")

# Model recipe: mpg against every other column. The training set is used as
# the template so held-out rows never inform the preprocessing.
rec <- recipes::recipe(mpg ~ ., data = data_training)

# Model workflow bundling the recipe and the model specification
wf <- workflows::workflow() %>%
  workflows::add_recipe(rec) %>%
  workflows::add_model(model_spec)
# Resampling: 2-fold cross-validation, repeated 3 times, on the training set
data_resample <- rsample::vfold_cv(data_training, v = 2, repeats = 3)

# Regular grid of 100 values over the default dials::penalty() range
hyperparam_grid <- dials::grid_regular(dials::penalty(), levels = 100)

# Metrics computed on every assessment set
metrics <- yardstick::metric_set(
  yardstick::rsq,
  yardstick::mape,
  yardstick::mpe
)

# Tune the penalty: fit every grid value on every resample
tune_grid_results <- tune::tune_grid(
  wf,
  resamples = data_resample,
  grid = hyperparam_grid,
  metrics = metrics
)
# Collect and finalise the best model. `metric` is passed by name because
# newer tune releases deprecate supplying it positionally.
selected_model <- tune_grid_results %>%
  tune::select_best(metric = "mape")

final_model <- tune::finalize_workflow(wf, selected_model)

# Refit on the whole training set and extract the underlying parsnip fit
final_model_fit <- final_model %>%
  parsnip::fit(data_training) %>%
  workflows::extract_fit_parsnip()
# Variable importance at the selected penalty. Variable is re-levelled by
# Importance so the bars are drawn in importance order, and the table is
# sorted (descending) for inspection.
t_importance <- final_model_fit %>%
  vip::vi(lambda = selected_model$penalty) %>%
  dplyr::mutate(Variable = forcats::fct_reorder(Variable, Importance)) %>%
  data.table() %>%
  setorder(-Importance)

# Plot importance per variable, filled by the `Sign` column returned by vi()
t_importance %>%
  ggplot(aes(x = Importance, y = Variable, fill = Sign)) +
  geom_col() +
  scale_x_continuous(expand = c(0, 0)) +
  labs(y = NULL) +
  theme_minimal()

Created on 2022-03-14 by the reprex package (v2.0.1)


Solution

  • As mentioned in the comment above, you can pass engine-specific arguments like penalty.factor in set_engine():

    library(tidyverse)
    library(tidymodels)
    library(vip)
    #> 
    #> Attaching package: 'vip'
    #> The following object is masked from 'package:utils':
    #> 
    #>     vi
    
    # Seed the RNG so the split and the repeated 2-fold CV are reproducible
    set.seed(123)
    datasplit <- initial_split(mtcars, prop = 0.8)
    car_train <- training(datasplit)
    car_test <- testing(datasplit)
    car_folds <- vfold_cv(car_train, v = 2, repeats = 3)
    

    You can pass penalty.factor here to the model specification as an engine-specific argument:

    # Predictors to keep in the model whatever the penalty: a penalty factor
    # of 0 leaves their coefficients unpenalized; all others get the usual 1.
    forced_vars <- c("cyl", "gear", "carb")
    # With `mpg ~ .` the predictor columns follow the data's column order,
    # and each mtcars predictor is a single numeric column, so the factor
    # vector can be built from the names instead of hard-coding positions.
    predictor_names <- setdiff(names(car_train), "mpg")
    car_penalty_factor <- ifelse(predictor_names %in% forced_vars, 0, 1)

    # penalty.factor is an engine-specific argument of the model
    # specification. `!!` splices in the value itself, so the fits (possibly
    # run on parallel workers) do not depend on a global variable.
    glmnet_spec <- linear_reg(penalty = tune(), mixture = 1) %>%
      set_engine("glmnet", penalty.factor = !!car_penalty_factor)

    car_wf <- workflow(mpg ~ ., glmnet_spec)
    glmnet_res <- tune_grid(car_wf, resamples = car_folds, grid = 5)
    glmnet_res
    #> # Tuning results
    #> # 2-fold cross-validation repeated 3 times 
    #> # A tibble: 6 × 5
    #>   splits          id      id2   .metrics          .notes          
    #>   <list>          <chr>   <chr> <list>            <list>          
    #> 1 <split [12/13]> Repeat1 Fold1 <tibble [10 × 5]> <tibble [0 × 3]>
    #> 2 <split [13/12]> Repeat1 Fold2 <tibble [10 × 5]> <tibble [0 × 3]>
    #> 3 <split [12/13]> Repeat2 Fold1 <tibble [10 × 5]> <tibble [0 × 3]>
    #> 4 <split [13/12]> Repeat2 Fold2 <tibble [10 × 5]> <tibble [0 × 3]>
    #> 5 <split [12/13]> Repeat3 Fold1 <tibble [10 × 5]> <tibble [0 × 3]>
    #> 6 <split [13/12]> Repeat3 Fold2 <tibble [10 × 5]> <tibble [0 × 3]>
    
    # Penalty with the best cross-validated RMSE. `metric` is named because
    # newer tune releases deprecate passing it positionally.
    best_penalty <- select_best(glmnet_res, metric = "rmse")

    # Refit the finalised workflow on the training set and extract the
    # underlying parsnip model fit
    final_fit <- car_wf %>%
      finalize_workflow(best_penalty) %>%
      fit(data = car_train) %>%
      extract_fit_parsnip()

    # Variable importance at the selected penalty, bars ordered by importance
    final_fit %>%
      vi(lambda = best_penalty$penalty) %>%
      mutate(Variable = fct_reorder(Variable, Importance)) %>%
      ggplot(aes(x = Importance, y = Variable, fill = Sign)) +
      geom_col() +
      scale_x_continuous(expand = c(0, 0)) +
      labs(y = NULL) +
      theme_minimal()
    

    Created on 2022-03-14 by the reprex package (v2.0.1)

    This does require that you know the number (and order) of the predictors when you create the model specification, which can become challenging for a complex recipe that includes many feature-engineering steps.