
How to tune a model using grid search and a single validation fold with tidymodels?


I have just learnt about the KNN algorithm and machine learning. It is a lot for me to take in and we are using tidymodels in R to practice.

Now, I know how to implement a grid search using k-fold cross-validation as follows:

    library(tidymodels)

    hist_data_split <- initial_split(hist_data, strata = fraud)
    hist_data_train <- training(hist_data_split)
    hist_data_test <- testing(hist_data_split)
    folds <- vfold_cv(hist_data_train, strata = fraud)
    nearest_neighbor_grid <- grid_regular(neighbors(range = c(1, 500)), levels = 25)
    knn_rec_1 <- recipe(fraud ~ ., data = hist_data_train)
    knn_spec_1 <- nearest_neighbor(mode = "classification", engine = "kknn", neighbors = tune(), weight_func = "rectangular")
    knn_wf_1 <- workflow(preprocessor = knn_rec_1, spec = knn_spec_1)
    knn_fit_1 <- tune_grid(
      knn_wf_1,
      resamples = folds,
      metrics = metric_set(accuracy, sens, spec, roc_auc),
      control = control_grid(save_pred = TRUE),
      grid = nearest_neighbor_grid
    )

In the above case, I am essentially running a 10-fold cross-validated grid search to tune my model. However, hist_data has 169,173 rows, which gives an optimal K of about 411 (roughly the square root of the number of rows), and with 10-fold cross-validation the tuning is going to take forever. The hint given is to use a single validation fold instead of cross-validation.

Thus, I am wondering how I can tweak my code to implement this. When I add the argument v = 1 in vfold_cv, R throws me an error which says, "At least one row should be selected for the analysis set." Should I instead change resamples = folds in tune_grid to resamples = 1?

Any intuitive suggestions will be greatly appreciated :)

P.S. I did not include an MWE in the sense that the data is not provided because I feel like this is a really trivial question which can be answered as is!


Solution

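  • If cross-validation is not feasible, for whatever reason, you can use a single validation split, which is conceptually the closest thing to v = 1 cross-validation. vfold_cv(v = 1) itself cannot work: its only fold would hold out every row for assessment and leave none for the analysis set, which is what the error message says. resamples must still be a resampling object rather than a number, so build one with validation_split() and pass it to tune_grid() as before. The example uses the ames data with Street as a stand-in outcome.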

    library(tidymodels)
    
    hist_data_split <- initial_split(ames, strata = Street)
    hist_data_train <- training(hist_data_split)
    hist_data_test <- testing(hist_data_split)
    
    folds <- validation_split(hist_data_train, strata = Street)
    
    nearest_neighbor_grid <- grid_regular(
      neighbors(range = c(1, 500)), 
      levels = 25
    )
    
    knn_rec_1 <- recipe(Street ~ ., data = hist_data_train)
    knn_spec_1 <- nearest_neighbor(neighbors = tune()) %>%
      set_mode("classification") %>%
      set_engine("kknn") %>%
      set_args(weight_func = "rectangular")
    
    knn_wf_1 <- workflow(preprocessor = knn_rec_1, spec = knn_spec_1)
    
    knn_fit_1 <- tune_grid(
      knn_wf_1,
      resamples = folds,
      metrics = metric_set(accuracy, sens, spec, roc_auc),
      control = control_grid(save_pred = TRUE),
      grid = nearest_neighbor_grid
    )
    
    knn_fit_1
    #> # Tuning results
    #> # Validation Set Split (0.75/0.25)  using stratification 
    #> # A tibble: 1 × 5
    #>   splits             id         .metrics           .notes           .predictions
    #>   <list>             <chr>      <list>             <list>           <list>      
    #> 1 <split [1647/550]> validation <tibble [100 × 5]> <tibble [0 × 3]> <tibble>
    
    knn_fit_1 %>%
      collect_metrics()
    #> # A tibble: 100 × 7
    #>    neighbors .metric  .estimator  mean     n std_err .config              
    #>        <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>                
    #>  1         1 accuracy binary     0.996     1      NA Preprocessor1_Model01
    #>  2         1 roc_auc  binary     0.5       1      NA Preprocessor1_Model01
    #>  3         1 sens     binary     0         1      NA Preprocessor1_Model01
    #>  4         1 spec     binary     1         1      NA Preprocessor1_Model01
    #>  5        21 accuracy binary     0.996     1      NA Preprocessor1_Model02
    #>  6        21 roc_auc  binary     0.495     1      NA Preprocessor1_Model02
    #>  7        21 sens     binary     0         1      NA Preprocessor1_Model02
    #>  8        21 spec     binary     1         1      NA Preprocessor1_Model02
    #>  9        42 accuracy binary     0.996     1      NA Preprocessor1_Model03
    #> 10        42 roc_auc  binary     0.486     1      NA Preprocessor1_Model03
    #> # … with 90 more rows
    

    Created on 2022-09-06 by the reprex package (v2.0.1)
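
    As a follow-up, here is a minimal sketch of picking the best K from a single validation fold and then scoring the tuned model once on the test set. It makes several assumptions that are not part of the answer above: it uses the two_class_dat data from modeldata purely as a stand-in, it needs the kknn package installed, and it relies on initial_validation_split() and validation_set(), which, as far as I know, are only available in rsample 1.2.0 and later (where they are the suggested replacement for validation_split()). The object names and the small neighbors range are illustrative only.

```r
library(tidymodels)

# Illustrative data only: two_class_dat ships with modeldata and has a
# two-level factor outcome `Class`. Substitute your own data and outcome.
data(two_class_dat, package = "modeldata")

set.seed(123)

# One three-way split: 60% training, 20% validation, 20% test.
three_way <- initial_validation_split(
  two_class_dat,
  prop = c(0.6, 0.2),
  strata = Class
)

# A resampling object holding that single training/validation fold.
val_fold <- validation_set(three_way)

knn_spec <- nearest_neighbor(neighbors = tune(), weight_func = "rectangular") %>%
  set_mode("classification") %>%
  set_engine("kknn")

knn_wf <- workflow(
  preprocessor = recipe(Class ~ ., data = training(three_way)),
  spec = knn_spec
)

# Small grid chosen for this small data set; widen it for real data.
knn_grid <- grid_regular(neighbors(range = c(1, 51)), levels = 6)

knn_res <- tune_grid(
  knn_wf,
  resamples = val_fold,
  grid = knn_grid,
  metrics = metric_set(accuracy, roc_auc)
)

# Pick the K with the best ROC AUC on the validation fold.
best_k <- select_best(knn_res, metric = "roc_auc")

# Plug that K into the workflow, refit on the training set,
# and evaluate once on the held-out test set.
final_fit <- knn_wf %>%
  finalize_workflow(best_k) %>%
  last_fit(three_way)

collect_metrics(final_fit)
```

    With a single fold there is only one estimate per candidate K, so std_err is NA and the ranking is noisier than with cross-validation. That is the price paid for the shorter run time.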