
Target encoding in the tidymodels framework using embed


I would like to do target encoding for a categorical variable with too many levels.

I have seen this vignette, which proposes the following approaches to target-encode a variable:

step_lencode_glm()
step_lencode_bayes() 
step_lencode_mixed()

All three approaches use every record to create the estimates, so the encoding tends to overfit to that column.

Using tidymodels, is there an easy way to split my training set into 5 folds and compute the target encoding for each fold from the other 4 folds?
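The out-of-fold scheme described in the question can be sketched in base R, without embed. This is a minimal sketch, assuming a plain per-level mean of a 0/1 outcome as the encoding (a simpler swapped-in technique, not the GLM, Bayesian, or mixed-model estimates that step_lencode_glm(), step_lencode_bayes(), and step_lencode_mixed() compute). The helper name `oof_target_encode()` and its fallback for unseen levels are hypothetical.

```r
# Hypothetical helper (not part of embed): out-of-fold mean target encoding.
# Each row's encoding is computed only from rows in the other v - 1 folds.
oof_target_encode <- function(x, y, v = 5, seed = 1) {
  stopifnot(length(x) == length(y), v >= 2)
  set.seed(seed)  # note: this resets the global RNG state
  # Randomly assign each row to one of v folds of (nearly) equal size
  fold <- sample(rep_len(seq_len(v), length(x)))
  encoded <- numeric(length(x))
  for (k in seq_len(v)) {
    held_out <- fold == k
    # Per-level mean of the outcome, estimated on the other folds only
    level_means <- tapply(y[!held_out], x[!held_out], mean)
    # Fallback for levels that never appear in the other folds
    fallback <- mean(y[!held_out])
    est <- level_means[as.character(x[held_out])]
    est[is.na(est)] <- fallback
    encoded[held_out] <- est
  }
  encoded
}

# Usage on toy data: a character predictor and a 0/1 outcome
set.seed(42)
x <- sample(c("a", "b", "c"), 100, replace = TRUE)
y <- rbinom(100, 1, 0.5)
enc <- oof_target_encode(x, y, v = 5)
```

Because a row's own outcome never enters its encoding, a level seen only once cannot simply memorize that row's class, which is the overfitting the question is worried about.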

Thanks


Solution

  • That is exactly what happens if you use a function like fit_resamples(): for each resample, the recipe (including the encoding step) and the model are fit to the other v - 1 folds, and performance is evaluated on the held-out fold.

    If you want to explore this in more detail, you can follow along with this vignette.

    library(tidymodels)
    library(embed)
    
    data(grants, package = "modeldata")
    
    set.seed(1)
    folds <- vfold_cv(grants_other, v = 3)
    folds
    #> #  3-fold cross-validation 
    #> # A tibble: 3 × 2
    #>   splits              id   
    #>   <list>              <chr>
    #> 1 <split [5460/2730]> Fold1
    #> 2 <split [5460/2730]> Fold2
    #> 3 <split [5460/2730]> Fold3
    
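    # The encoding is estimated from the outcome, so it has to be
    # re-estimated within each resample
    rec <- 
      recipe(class ~ sponsor_code, data = grants_other) %>%
      step_lencode_glm(sponsor_code, outcome = vars(class))
    
    # prepper() estimates the recipe on each fold's analysis set only;
    # tidy(number = 1) returns the per-level encodings from step 1
    res <-
      folds %>%
      mutate(recipe = map(splits, prepper, recipe = rec),
             processed = map(recipe, tidy, number = 1))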
    
    res %>%
      select(fold_id = id, processed) %>%
      unnest(processed)
    #> # A tibble: 757 × 5
    #>    fold_id level   value terms        id               
    #>    <chr>   <chr>   <dbl> <chr>        <chr>            
    #>  1 Fold1   100D    0.288 sponsor_code lencode_glm_gfHLA
    #>  2 Fold1   101A   -1.50  sponsor_code lencode_glm_gfHLA
    #>  3 Fold1   103C   -1.95  sponsor_code lencode_glm_gfHLA
    #>  4 Fold1   105A   -1.39  sponsor_code lencode_glm_gfHLA
    #>  5 Fold1   107C   16.6   sponsor_code lencode_glm_gfHLA
    #>  6 Fold1   10B    16.6   sponsor_code lencode_glm_gfHLA
    #>  7 Fold1   111C  -16.6   sponsor_code lencode_glm_gfHLA
    #>  8 Fold1   112D    0.560 sponsor_code lencode_glm_gfHLA
    #>  9 Fold1   113A    0.223 sponsor_code lencode_glm_gfHLA
    #> 10 Fold1   118B    0     sponsor_code lencode_glm_gfHLA
    #> # … with 747 more rows
    

    Created on 2022-02-22 by the reprex package (v2.0.1)

    We would recommend resampling like this to estimate the performance of an embedding strategy, and then using the whole training set to fit the final embedding.
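The recommended pattern (resample to estimate performance, then refit on the full training set) can be sketched end to end with a workflow and fit_resamples(). This is a minimal sketch, assuming a plain logistic regression (logistic_reg() with its default glm engine) as the model and v = 5 to match the question; the model choice, the fold count, and the object names wf and final_fit are illustrative and not part of the original answer.

```r
library(tidymodels)
library(embed)

data(grants, package = "modeldata")

set.seed(1)
folds <- vfold_cv(grants_other, v = 5)

rec <-
  recipe(class ~ sponsor_code, data = grants_other) %>%
  step_lencode_glm(sponsor_code, outcome = vars(class))

# Assumed model: a plain logistic regression on the encoded column
wf <-
  workflow() %>%
  add_recipe(rec) %>%
  add_model(logistic_reg())

# For each fold, the recipe (including the encoding) is estimated on the
# other v - 1 folds only, then applied to the held-out fold
res <- fit_resamples(wf, resamples = folds)
collect_metrics(res)

# Final embedding: refit the whole workflow on the full training set
final_fit <- fit(wf, data = grants_other)

# Per-level estimates of the final encoding, fitted on all training rows
extract_recipe(final_fit) %>% tidy(number = 1)
```

Because the encoding step lives inside the workflow, fit_resamples() re-estimates it on each fold's analysis set, so the metrics reflect encodings that never saw the held-out rows; fit() then estimates the final encoding from all of grants_other.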