
Tuning with classification_cost and custom cost matrix in Tidymodels

I am using tidymodels for building a model where false negatives are more costly than false positives. Hence I'd like to use the yardstick::classification_cost metric for hyperparameter tuning, but with a custom classification cost matrix that reflects this fact.

Doing this after fitting a model is simple enough:


library(tidymodels)

# load simulated prediction output
data("two_class_example")

# cost matrix penalizing false negatives
cost_matrix <- tribble(
  ~truth, ~estimate, ~cost,
  "Class1", "Class2",  2,
  "Class2", "Class1",  1
)

# use function on simulated prediction output
classification_cost(
  data = two_class_example,
  truth = truth,
  # target class probability
  Class1,
  # supply the function with the cost matrix
  costs = cost_matrix)
#> # A tibble: 1 × 3
#>   .metric             .estimator .estimate
#>   <chr>               <chr>          <dbl>
#> 1 classification_cost binary         0.260

Created on 2021-11-01 by the reprex package (v2.0.1)
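To see where this kind of number comes from, the metric can be reproduced by hand. As far as the yardstick documentation describes it, each row's cost is the sum, over the possible predicted classes, of the predicted probability times the cost of that truth/estimate pair (pairs missing from `costs`, such as correct predictions, cost 0), and the reported value is the mean over rows. The base-R sketch below assumes that definition; `manual_classification_cost()` is a hypothetical helper and the three-row data is made up, so it does not attempt to reproduce the 0.260 above.

```r
# Hypothetical helper: recomputes a binary classification cost by hand,
# ASSUMING yardstick's definition (probability-weighted cost, averaged over rows)
manual_classification_cost <- function(truth, prob_class1, costs) {
  # cost of one (truth, estimate) pair; pairs not listed in `costs` cost 0
  pair_cost <- function(t, e) {
    hit <- costs$cost[costs$truth == t & costs$estimate == e]
    if (length(hit) == 0) 0 else hit
  }
  row_costs <- vapply(seq_along(truth), function(i) {
    # binary case: the Class2 probability is 1 minus the Class1 probability
    p <- c(Class1 = prob_class1[i], Class2 = 1 - prob_class1[i])
    # expected cost of row i: sum over the possible predicted classes
    sum(vapply(names(p), function(e) p[[e]] * pair_cost(truth[i], e), numeric(1)))
  }, numeric(1))
  mean(row_costs)
}

# same cost matrix as above, as a plain data frame
cost_matrix <- data.frame(
  truth    = c("Class1", "Class2"),
  estimate = c("Class2", "Class1"),
  cost     = c(2, 1)
)

# made-up toy data (not two_class_example)
toy_truth <- c("Class1", "Class2", "Class1")
toy_prob1 <- c(0.9, 0.3, 0.4)

manual_classification_cost(toy_truth, toy_prob1, cost_matrix)
# row costs: 2 * 0.1, 1 * 0.3, 2 * 0.6, so the mean is 1.7 / 3
```

With this cost matrix, a Class1 row that puts a given probability on Class2 (a likely false negative) contributes twice as much as a Class2 row that puts the same probability on Class1.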

But using this function during hyperparameter tuning is where I run into problems. The documentation states that for setting options the metric should be wrapped in a custom function. Here's my attempt and the resulting error. Note that this wrapper works fine for evaluating predictions from a fitted model, but throws an error when used for tuning:


library(tidymodels)

# load data
data("two_class_example")
data("two_class_dat")

# create custom metric penalizing false negatives 
classification_cost_penalized <- function(
  data,
  truth,
  class_proba,
  na_rm = TRUE
) {
  # cost matrix penalizing false negatives
  cost_matrix <- tribble(
    ~truth, ~estimate, ~cost,
    "Class1", "Class2",  2,
    "Class2", "Class1",  1
  )
  
  classification_cost(
    data = data,
    truth = !! rlang::enquo(truth),
    # supply the function with the class probabilities
    !! rlang::enquo(class_proba), 
    # supply the function with the cost matrix
    costs = cost_matrix,
    na_rm = na_rm
  )
}

# Use `new_prob_metric()` to formalize this new metric function
classification_cost_penalized <- new_prob_metric(classification_cost_penalized, "minimize")

# test if this works on the simulated estimates
two_class_example %>% 
  classification_cost_penalized(truth = truth, class_prob = Class1)
#> # A tibble: 1 × 3
#>   .metric             .estimator .estimate
#>   <chr>               <chr>          <dbl>
#> 1 classification_cost binary         0.260

# test if this works with hyperparameter tuning

# specify a RF model
my_model <- 
  rand_forest(mtry = tune(), 
              min_n = tune(),
              trees = 500) %>% 
  set_engine("ranger") %>% 
  set_mode("classification")

# specify recipe
my_recipe <- recipe(Class ~ A + B, data = two_class_dat)

# bundle to workflow
my_wf <- workflow() %>% 
  add_model(my_model) %>% 
  add_recipe(my_recipe)

# start tuning
tuned_rf <- my_wf %>% 
  # set up tuning grid
  tune_grid(
    resamples = vfold_cv(two_class_dat, 
                         v = 5),
    grid = 5,
    metrics = metric_set(classification_cost_penalized))
#> i Creating pre-processing data to finalize unknown parameter: mtry
#> x Fold1: internal: Error: In metric: `classification_cost_penalized`
#> unused argum...
#> x Fold2: internal: Error: In metric: `classification_cost_penalized`
#> unused argum...
#> x Fold3: internal: Error: In metric: `classification_cost_penalized`
#> unused argum...
#> x Fold4: internal: Error: In metric: `classification_cost_penalized`
#> unused argum...
#> x Fold5: internal: Error: In metric: `classification_cost_penalized`
#> unused argum...
#> Warning: All models failed. See the `.notes` column.

Created on 2021-11-01 by the reprex package (v2.0.1)

Unnesting the notes shows that there are unused arguments: "internal: Error: In metric: classification_cost_penalized\nunused arguments (estimator = ~prob_estimator, event_level = ~event_level)". But the `yardstick_event_level()` function, which according to this documentation is how `event_level` should be set, apparently does not exist: no function under that name shows up when I search for it.
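The note suggests that, during tuning, the metric is called with extra named arguments (`estimator` and `event_level`) that the wrapper above does not declare, and R rejects undeclared named arguments at call time. The base-R sketch below illustrates only that argument-matching behavior: `narrow_metric()`, `wide_metric()` and `call_with_extras()` are hypothetical stand-ins and nothing here calls yardstick or tune.

```r
# Hypothetical stand-in for the wrapper in the question: it declares only
# the arguments it uses, so any extra named argument is an error
narrow_metric <- function(data, truth, class_proba, na_rm = TRUE) {
  "computed"
}

# Hypothetical stand-in that also declares `estimator` and `event_level`
# (placeholder defaults; a real wrapper would forward them to the metric)
wide_metric <- function(data, truth, class_proba, estimator = NULL,
                        na_rm = TRUE, event_level = "first") {
  "computed"
}

# Hypothetical caller that, like the tuning code, passes the extra arguments
call_with_extras <- function(metric_fn) {
  tryCatch(
    metric_fn(data = NULL, truth = NULL, class_proba = NULL,
              estimator = "binary", event_level = "first"),
    error = function(e) conditionMessage(e)
  )
}

call_with_extras(narrow_metric)  # returns the "unused arguments (...)" message
call_with_extras(wide_metric)    # returns "computed"
```

In a real wrapper the declared `estimator` and `event_level` would have to be passed on to `classification_cost()`; which default `event_level` should take depends on the installed yardstick version, so it is only a placeholder here.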

I don't know how to proceed here. Thank you for your time.


  • When you are tweaking an existing yardstick metric, it is much easier to use the `metric_tweak()` function, which lets you hard-code certain optional arguments (like `costs`) while keeping everything else the same. It is sort of like `purrr::partial()`, but for yardstick metrics.

    library(tidymodels)
    # load data
    data("two_class_example")
    data("two_class_dat")
    # cost matrix penalizing false negatives
    cost_matrix <- tribble(
      ~truth, ~estimate, ~cost,
      "Class1", "Class2",  2,
      "Class2", "Class1",  1
    )
    classification_cost_penalized <- metric_tweak(
      .name = "classification_cost_penalized",
      .fn = classification_cost,
      costs = cost_matrix
    )
    # test if this works on the simulated estimates
    two_class_example %>% 
      classification_cost_penalized(truth = truth, class_prob = Class1)
    #> # A tibble: 1 × 3
    #>   .metric                       .estimator .estimate
    #>   <chr>                         <chr>          <dbl>
    #> 1 classification_cost_penalized binary         0.260
    # specify a RF model
    my_model <- 
      rand_forest(
        mtry = tune(), 
        min_n = tune(),
        trees = 500
      ) %>% 
      set_engine("ranger") %>% 
      set_mode("classification")
    # specify recipe
    my_recipe <- recipe(Class ~ A + B, data = two_class_dat)
    # bundle to workflow
    my_wf <- workflow() %>% 
      add_model(my_model) %>% 
      add_recipe(my_recipe)
    # start tuning
    tuned_rf <- my_wf %>% 
      tune_grid(
        resamples = vfold_cv(two_class_dat, v = 5),
        grid = 5,
        metrics = metric_set(classification_cost_penalized)
      )
    #> i Creating pre-processing data to finalize unknown parameter: mtry
    collect_metrics(tuned_rf)
    #> # A tibble: 5 × 8
    #>    mtry min_n .metric              .estimator  mean     n std_err .config       
    #>   <int> <int> <chr>                <chr>      <dbl> <int>   <dbl> <chr>         
    #> 1     1    35 classification_cost… binary     0.407     5  0.0162 Preprocessor1…
    #> 2     1    23 classification_cost… binary     0.403     5  0.0146 Preprocessor1…
    #> 3     1    10 classification_cost… binary     0.403     5  0.0137 Preprocessor1…
    #> 4     2    27 classification_cost… binary     0.396     5  0.0166 Preprocessor1…
    #> 5     2     6 classification_cost… binary     0.401     5  0.0161 Preprocessor1…

    Created on 2021-11-03 by the reprex package (v2.0.1)
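The `purrr::partial()` comparison can be illustrated with a plain closure that pre-fills one optional argument and passes everything else through. This base-R sketch uses hypothetical names (`toy_metric()`, `fix_costs()`) and only shows the partial-application idea; as far as the yardstick documentation describes it, `metric_tweak()` additionally keeps the metric's class and optimization direction so that `metric_set()` still accepts the result, which a bare closure like this does not.

```r
# Hypothetical metric-like function with an optional `costs` argument
toy_metric <- function(truth, estimate, costs = NULL) {
  weight <- if (is.null(costs)) 1 else costs
  weight * mean(truth != estimate)
}

# Hypothetical helper: fix `costs`, forward all other arguments unchanged
fix_costs <- function(fn, costs) {
  force(costs)  # evaluate now so the closure captures the value
  function(...) fn(..., costs = costs)
}

toy_metric_penalized <- fix_costs(toy_metric, costs = 2)

truth    <- c("a", "b", "a", "b")
estimate <- c("a", "a", "a", "b")

toy_metric(truth, estimate)            # 1 * 1/4 = 0.25
toy_metric_penalized(truth, estimate)  # 2 * 1/4 = 0.5
```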