r, knn, tidymodels

How do I see the selected k from parsnip::nearest_neighbor()?


If I fit a k-nearest neighbors model using parsnip::nearest_neighbor(), what k is selected if I don't specify how to tune?

I am trying to figure out what k is selected here:

the_model <- nearest_neighbor() %>%
  set_engine("kknn") %>% 
  set_mode("classification") 

the_workflow <- 
  workflow() %>% 
  add_recipe(the_recipe) %>% 
  add_model(the_model) 

the_results <-
  the_workflow %>%
  fit_resamples(resamples = cv_folds, 
                metrics = metric_set(roc_auc),
                control = control_resamples(save_pred = TRUE)) 

I know that if I use nearest_neighbor(neighbors = tune()), I can get the k back using select_best("roc_auc"). Without specifying how to tune, I still get results, but select_best() does not return a k. What k value is it using (and how did you figure out the answer)?


Solution

  • If you don't specify parameters for a model specification in parsnip, the value will be determined by the defaults in the underlying engine implementation unless otherwise specified in the documentation.

    Look at the documentation for nearest_neighbor() and go down to the Arguments section. Under neighbors it says:

    For kknn, a value of 5 is used if neighbors is not specified.

    You can also use the translate() function from {parsnip} to see the code that the model specification creates:

    library(parsnip)
    
    the_model <- nearest_neighbor() %>%
      set_engine("kknn") %>% 
      set_mode("classification") 
    
    the_model %>%
      translate()
    #> K-Nearest Neighbor Model Specification (classification)
    #> 
    #> Computational engine: kknn 
    #> 
    #> Model fit template:
    #> kknn::train.kknn(formula = missing_arg(), data = missing_arg(), 
    #>     ks = min_rows(5, data, 5))
    

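    Here ks is set to min_rows(5, data, 5), so k = 5 by default. min_rows() is a parsnip helper that caps the requested number of neighbors at the number of rows in the data minus the offset (the final 5). If we specify neighbors in nearest_neighbor(), the first argument changes accordingly: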

    nearest_neighbor(neighbors = 25) %>%
      set_engine("kknn") %>% 
      set_mode("classification") %>%
      translate()
    #> K-Nearest Neighbor Model Specification (classification)
    #> 
    #> Main Arguments:
    #>   neighbors = 25
    #> 
    #> Computational engine: kknn 
    #> 
    #> Model fit template:
    #> kknn::train.kknn(formula = missing_arg(), data = missing_arg(), 
    #>     ks = min_rows(25, data, 5))
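
    To double-check the value on an actual fit, one option is to fit the specification and inspect the underlying kknn object. The following is a minimal sketch, not part of the original answer: it assumes the built-in iris data as a stand-in for your own data, that the kknn package is installed, and that the object returned by kknn::train.kknn() stores the chosen k in best.parameters. With neighbors left unset, both values should match the documented default.

    ```r
    library(parsnip)

    # Same specification as in the question, with neighbors left unset.
    the_model <- nearest_neighbor() %>%
      set_engine("kknn") %>%
      set_mode("classification")

    # Evaluate the helper from the fit template directly:
    # the requested k (5), capped at nrow(iris) minus the offset of 5.
    k_requested <- min_rows(5, iris, 5)

    # Fit on iris (assumed stand-in data) and pull out the engine object
    # created by kknn::train.kknn().
    knn_fit <- fit(the_model, Species ~ ., data = iris)
    k_used <- extract_fit_engine(knn_fit)$best.parameters$k

    k_requested
    k_used
    ```

    For a fitted workflow, calling extract_fit_engine() on the workflow object should give access to the same kknn object. Note that fit_resamples() does not keep the fitted models by default, so this check is done on a separate fit.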