Tags: r, tidymodels, imbalanced-data

Tidymodels class cost


I am dealing with a prediction problem where the binary target is strongly imbalanced. Is there a way to penalize wrong predictions of the minority class with a cost matrix in tidymodels? I know that caret implemented this, but the information I find for tidymodels is quite confusing. All I find is the baguette::class_cost() function from the experimental baguette package, which seems to apply only to bagged tree models.


Solution

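  • Yes, you can use yardstick's classification_cost() metric, which scores class-probability predictions against a user-defined cost for each kind of misclassification: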

    library(yardstick)
    #> For binary classification, the first factor level is assumed to be the event.
    #> Use the argument `event_level = "second"` to alter this as needed.
    library(dplyr)
    #> 
    #> Attaching package: 'dplyr'
    #> The following objects are masked from 'package:stats':
    #> 
    #>     filter, lag
    #> The following objects are masked from 'package:base':
    #> 
    #>     intersect, setdiff, setequal, union
    
    # Two class example
    data(two_class_example)
    
    # Assuming `Class1` is our "event", this penalizes false positives heavily
    costs1 <- tribble(
      ~truth,   ~estimate, ~cost,
      "Class1", "Class2",  1,
      "Class2", "Class1",  2
    )
    
    # Assuming `Class1` is our "event", this penalizes false negatives heavily
    costs2 <- tribble(
      ~truth,   ~estimate, ~cost,
      "Class1", "Class2",  2,
      "Class2", "Class1",  1
    )
    
    classification_cost(two_class_example, truth, Class1, costs = costs1)
    #> # A tibble: 1 × 3
    #>   .metric             .estimator .estimate
    #>   <chr>               <chr>          <dbl>
    #> 1 classification_cost binary         0.288
    classification_cost(two_class_example, truth, Class1, costs = costs2)
    #> # A tibble: 1 × 3
    #>   .metric             .estimator .estimate
    #>   <chr>               <chr>          <dbl>
    #> 1 classification_cost binary         0.260
    

    Created on 2021-10-27 by the reprex package (v2.0.1)
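To make the cost matrix idea from the question concrete, here is a minimal base-R sketch of what we understand classification_cost() to compute, based on the yardstick documentation: each observation's class probabilities are multiplied by the cost of predicting that class given the truth, and the results are averaged. The helper `expected_class_cost()` and the toy data are hypothetical and are not part of yardstick; the cost matrix uses caret-style layout (rows = true class, columns = predicted class) with zero cost for correct predictions, which we assume mirrors yardstick's treatment of unspecified pairs.

```r
# Hypothetical helper (not part of yardstick): our reading of classification_cost()
# probs:    numeric matrix, one row per observation, one column per class
# truth:    factor of true classes, matching colnames(probs)
# cost_mat: square matrix, rows = true class, columns = predicted class
expected_class_cost <- function(probs, truth, cost_mat) {
  truth <- as.character(truth)
  # Pick, for each observation, the row of costs for its true class,
  # with columns aligned to the probability columns
  row_costs <- cost_mat[truth, colnames(probs), drop = FALSE]
  # Expected cost per observation = sum over classes of P(class) * cost
  per_obs <- rowSums(probs * row_costs)
  mean(per_obs)
}

# Toy data (hypothetical): probabilities for three observations
probs <- cbind(Class1 = c(0.9, 0.4, 0.2),
               Class2 = c(0.1, 0.6, 0.8))
truth <- factor(c("Class1", "Class1", "Class2"), levels = c("Class1", "Class2"))

# Same costs as `costs2`: a false negative (truth Class1, predicted Class2) costs 2
cost_mat <- matrix(c(0, 2,
                     1, 0),
                   nrow = 2, byrow = TRUE,
                   dimnames = list(truth    = c("Class1", "Class2"),
                                   estimate = c("Class1", "Class2")))

expected_class_cost(probs, truth, cost_mat)
#> [1] 0.5333333
```

Here the per-observation costs are 0.1 * 2, 0.6 * 2 and 0.2 * 1, so the second observation, a confident false negative, dominates the average. With a cost of 1 for both errors, the same helper reduces to the mean probability assigned to the wrong class.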

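    In tidymodels, you can use this metric either to evaluate predictions after the fact, as above, or during tuning, where it lets you select the candidate model with the lowest expected cost. Note that the metric scores predictions; it does not change how the model itself is fit.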
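For tuning, the custom costs have to be attached to the metric before it goes into a metric set. A minimal sketch, assuming a yardstick version that provides metric_tweak(): the name `fn_cost` is our own choice, and the commented tune_grid()/select_best() calls use a hypothetical workflow `wf` and resamples `folds` that you would replace with your own.

```r
library(yardstick)
library(dplyr)

data(two_class_example)

# Same false-negative-heavy costs as `costs2` above
costs2 <- tribble(
  ~truth,   ~estimate, ~cost,
  "Class1", "Class2",  2,
  "Class2", "Class1",  1
)

# Bake the custom costs into a new metric so it can be used in metric_set()
fn_cost <- metric_tweak("fn_cost", classification_cost, costs = costs2)

# Both metrics take class probabilities, so they can share one metric set
cost_metrics <- metric_set(roc_auc, fn_cost)

cost_metrics(two_class_example, truth, Class1)

# In a tuning run (hypothetical `wf` and `folds`), pass the same object:
# tuned <- tune::tune_grid(wf, resamples = folds, metrics = cost_metrics)
# tune::select_best(tuned, metric = "fn_cost")
```

Because classification_cost() is minimized, selecting on `fn_cost` should favor the candidate whose predicted probabilities incur the lowest average cost under `costs2`.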