Set tuning parameter range a priori


I know that in tidymodels you can set a custom tunable parameter space by interacting directly with the workflow object as follows:

library(tidymodels)

model <- linear_reg(
  mode = "regression", 
  engine = "glmnet", 
  penalty = tune()
  )

rec_cars <- recipe(mpg ~ ., data = mtcars)
 
wkf <- workflow() %>% 
  add_recipe(rec_cars) %>% 
  add_model(model) 

wkf_new_param_space <- wkf %>%
  parameters() %>%
  update(penalty = penalty(range = c(0.9, 1)))

but sometimes it makes more sense to do this at the moment I specify a recipe or a model.

Does anyone know a way to achieve this?


Solution

  • Parameter ranges are inherently separate from the model specification and the recipe specification in tidymodels. When you set tune(), you are signaling to the tuning functions that this parameter will take multiple values and should be tuned over.

    So the short answer is that you cannot specify parameter ranges when you specify a recipe or a model, but you can create the parameters object right afterwards, as you did.

    In the end, you need the parameter set to construct the grid values that you use for hyperparameter tuning, and you can create those grid values in at least four ways.

    The first way is to do it the way you are doing it, by pulling the needed parameters out of the workflow and modifying them when needed.

    The second way is to create a parameters object from scratch that matches the parameters you need. This option and the remaining ones require you to make sure that you supply values for every parameter you are tuning.

    The third way is to skip the parameters object altogether and create the grid by passing dials parameter functions directly to a grid_*() function.

    The fourth way is to skip the dials functions altogether and create the data frame yourself. I find tidyr::crossing() a useful replacement for grid_regular(). This way is a lot easier when you are working with integer parameters and parameters that don't benefit from transformations.

    library(tidymodels)
    
    model <- linear_reg(
      mode = "regression", 
      engine = "glmnet", 
      penalty = tune()
      )
    
    rec_cars <- recipe(mpg ~ ., data = mtcars)
     
    wkf <- workflow() %>% 
      add_recipe(rec_cars) %>% 
      add_model(model) 
    
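    # Option 1: using parameters() on workflow
    # (recent tidymodels releases prefer extract_parameter_set_dials() here)
    # penalty() ranges are on the log10 scale: c(-5, 5) spans 1e-5 to 1e5
    wkf_new_param_space <- wkf %>%
      parameters() %>%
      update(penalty = penalty(range = c(-5, 5)))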
    
    wkf_new_param_space %>%
      grid_regular(levels = 10)
    #> # A tibble: 10 × 1
    #>          penalty
    #>            <dbl>
    #>  1      0.00001 
    #>  2      0.000129
    #>  3      0.00167 
    #>  4      0.0215  
    #>  5      0.278   
    #>  6      3.59    
    #>  7     46.4     
    #>  8    599.      
    #>  9   7743.      
    #> 10 100000
    
    # Option 2: Using parameters() on list
    my_params <- parameters(
      list(
        penalty(range = c(-5, 5))
      )
    )
    
    my_params %>%
      grid_regular(levels = 10)
    #> # A tibble: 10 × 1
    #>          penalty
    #>            <dbl>
    #>  1      0.00001 
    #>  2      0.000129
    #>  3      0.00167 
    #>  4      0.0215  
    #>  5      0.278   
    #>  6      3.59    
    #>  7     46.4     
    #>  8    599.      
    #>  9   7743.      
    #> 10 100000
    
    # Option 3: Use grid_*() with dials objects directly
    grid_regular(
      penalty(range = c(-5, 5)),
      levels = 10
    )
    #> # A tibble: 10 × 1
    #>          penalty
    #>            <dbl>
    #>  1      0.00001 
    #>  2      0.000129
    #>  3      0.00167 
    #>  4      0.0215  
    #>  5      0.278   
    #>  6      3.59    
    #>  7     46.4     
    #>  8    599.      
    #>  9   7743.      
    #> 10 100000
    
    
    
    # Option 4: Create grid values manually
    tidyr::crossing(
      penalty = 10 ^ seq(-5, 5, length.out = 10)
    )
    #> # A tibble: 10 × 1
    #>          penalty
    #>            <dbl>
    #>  1      0.00001 
    #>  2      0.000129
    #>  3      0.00167 
    #>  4      0.0215  
    #>  5      0.278   
    #>  6      3.59    
    #>  7     46.4     
    #>  8    599.      
    #>  9   7743.      
    #> 10 100000
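
    All four grids above come out identical because penalty() interprets its range in log10 units, which is why Option 4 has to apply 10 ^ by hand. The base R sketch below spells out that back-transformation. The variable names are made up for illustration, and it assumes penalty() keeps its default log10 transformation.

    ```r
    # Assumption: dials::penalty() reads `range` in log10 units (its default transformation).
    log10_range <- c(-5, 5)
    n_levels <- 10

    # Evenly spaced points on the log10 scale, back-transformed to the penalty scale.
    manual_grid <- 10^seq(log10_range[1], log10_range[2], length.out = n_levels)
    signif(manual_grid, 3)

    # The range from the question, c(0.9, 1), back-transformed the same way.
    question_range <- 10^c(0.9, 1)
    question_range
    ```

    Under that assumption, the range c(0.9, 1) in the question corresponds to penalties between roughly 7.94 and 10, not between 0.9 and 1.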
    

    Created on 2021-08-17 by the reprex package (v2.0.1)
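
    Whichever option you pick, the custom range only takes effect once it reaches the tuning call. Here is a minimal sketch of two ways that could look, assuming the glmnet package is installed. The 5-fold resampling of mtcars, the seed, and the object names folds, penalty_grid, my_params, tuned_grid and tuned_info are illustrative choices, not part of the question.

    ```r
    library(tidymodels)

    set.seed(123)  # arbitrary seed so the resamples are reproducible

    model <- linear_reg(penalty = tune()) %>%
      set_engine("glmnet")

    rec_cars <- recipe(mpg ~ ., data = mtcars)

    wkf <- workflow() %>%
      add_recipe(rec_cars) %>%
      add_model(model)

    # Illustrative resampling scheme; any rset object would do.
    folds <- vfold_cv(mtcars, v = 5)

    # Route A: build the grid yourself (Options 1 to 4) and pass it as `grid`.
    penalty_grid <- grid_regular(penalty(range = c(-5, 5)), levels = 10)
    tuned_grid <- tune_grid(wkf, resamples = folds, grid = penalty_grid)

    # Route B: pass the parameter set via `param_info` and let tune_grid()
    # generate 10 candidate values inside the custom range.
    my_params <- parameters(list(penalty(range = c(-5, 5))))
    tuned_info <- tune_grid(wkf, resamples = folds, grid = 10, param_info = my_params)

    show_best(tuned_grid, metric = "rmse", n = 3)
    ```

    Route B is probably the closest you can get to declaring the range up front: the parameter set can be defined next to the model specification and handed over unchanged when tuning.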