
Tidymodels: Classify as TRUE only if the probability is 75% or higher


I have a binary classification problem and fitted a random forest and a logistic regression. Based on the results of conf_mat(), collect_metrics() and collect_predictions(), I want to change my models so that they classify an observation as TRUE only if the model is "sure", say with a probability of 75% or even higher. I just don't know where to specify this change. It would be amazing if someone could give me a hint. My intuition tells me that it should be somewhere in the model specification, e.g. somewhere here, but maybe I am wrong:

canc_rf_model <- rand_forest(
    mtry = tune(),
    min_n = tune(),
    trees = 500) %>%
  set_engine("ranger") %>%
  set_mode("classification")

canc_log_model <- logistic_reg() %>% 
  set_engine("glm") %>% 
  set_mode("classification")

Thank you very much in advance! M.


Solution

  • The hard class predictions come from the underlying ranger::predictions() function, not from a tidymodels function, so there's not much to be done in the model specification or the fitting itself.
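
    To make this concrete: with two classes, picking the class with the larger probability is the same as cutting the first class's probability at 0.5, so "TRUE only at 75% or more" just means moving that cut-off. Below is a minimal base-R sketch on made-up probabilities. It assumes the default hard class is simply the most probable class, and cutoff_class() is a hypothetical helper, not part of ranger or tidymodels.

    ```r
    # Made-up two-class probabilities, one row per sample (illustration only)
    probs <- matrix(
      c(0.90, 0.10,
        0.60, 0.40,
        0.50, 0.50,
        0.30, 0.70),
      ncol = 2, byrow = TRUE,
      dimnames = list(NULL, c("Impaired", "Control"))
    )

    # Default-style hard prediction: the most probable class
    # (which.max() returns the first column on a tie)
    argmax_class <- colnames(probs)[apply(probs, 1, which.max)]

    # Hypothetical helper: the same rule written as a cut-off on the first class
    cutoff_class <- function(p_first, threshold) {
      ifelse(p_first >= threshold, "Impaired", "Control")
    }

    default_class <- cutoff_class(probs[, "Impaired"], threshold = 0.5)
    strict_class  <- cutoff_class(probs[, "Impaired"], threshold = 0.75)

    identical(argmax_class, default_class)  # TRUE: same labels at 0.5
    strict_class  # only the first sample (0.90) is still labelled Impaired
    ```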

    However, you can easily change this after fitting if you like. Let's make an example classification model:

    library(tidymodels)
    
    data("ad_data")
    alz <- ad_data
    
    # data splitting
    set.seed(100)
    alz_split  <- initial_split(alz, strata = Class, prop = .9)
    alz_train  <- training(alz_split)
    alz_test   <- testing(alz_split)
    
    # data resampling
    set.seed(100)
    alz_folds <- 
        vfold_cv(alz_train, v = 10, strata = Class)
    
    rf_mod <-
        rand_forest(trees = 1e3) %>% 
        set_engine("ranger") %>% 
        set_mode("classification")
    
    rf_wf <-
        workflow() %>% 
        add_formula(Class ~ .) %>% 
        add_model(rf_mod)
    
    set.seed(100)
    rf_preds <- rf_wf %>% 
        fit_resamples(
            resamples = alz_folds, 
            control = control_resamples(save_pred = TRUE)) %>% 
        collect_predictions()
    

    Here is the default confusion matrix:

    rf_preds %>%
        conf_mat(Class, .pred_class)
    #>           Truth
    #> Prediction Impaired Control
    #>   Impaired       37       5
    #>   Control        45     213
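
    Reading this matrix with Impaired as the event of interest helps explain why you might want a stricter cut-off. This small base-R sketch uses only the counts printed above to compute sensitivity, specificity and precision (the share of Impaired predictions that are correct, which is closest to the question's idea of the model being "sure"):

    ```r
    # Counts copied from the default confusion matrix above
    # (rows = prediction, columns = truth; matrix() fills column by column)
    cm <- matrix(c(37, 45, 5, 213), nrow = 2,
                 dimnames = list(Prediction = c("Impaired", "Control"),
                                 Truth      = c("Impaired", "Control")))

    # Share of truly Impaired samples that were predicted Impaired
    sensitivity <- cm["Impaired", "Impaired"] / sum(cm[, "Impaired"])  # 37 / 82
    # Share of truly Control samples that were predicted Control
    specificity <- cm["Control", "Control"] / sum(cm[, "Control"])     # 213 / 218
    # Share of Impaired predictions that were correct
    precision <- cm["Impaired", "Impaired"] / sum(cm["Impaired", ])    # 37 / 42

    round(c(sensitivity = sensitivity, specificity = specificity,
            precision = precision), 3)
    ```

    That is roughly 0.451, 0.977 and 0.881: the forest finds fewer than half of the impaired cases, but when it does predict Impaired it is usually right. Within tidymodels, yardstick's sens(), spec() and precision() compute the same quantities directly, e.g. sens(rf_preds, Class, .pred_class). Raising the threshold typically lowers sensitivity in exchange for higher precision.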
    

    You can use the probably package to post-process your class probability estimates and overwrite the default hard class predictions:

    library(probably)
    #> 
    #> Attaching package: 'probably'
    #> The following objects are masked from 'package:base':
    #> 
    #>     as.factor, as.ordered
    
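    rf_preds %>%
        # make_two_class_pred() expects the probability of the first level
        # (Impaired); with threshold = 0.75, Impaired is predicted only when
        # .pred_Impaired is at least 0.75, otherwise Control
        mutate(.pred_class = make_two_class_pred(.pred_Impaired, 
                                                 levels(rf_preds$Class),
                                                 threshold = 0.75),
               # turn the class_pred result back into a regular factor
               # with the same levels as the truth column
               .pred_class = factor(.pred_class, levels = levels(rf_preds$Class))) %>%
        conf_mat(Class, .pred_class)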
    #>           Truth
    #> Prediction Impaired Control
    #>   Impaired        0       0
    #>   Control        82     218
    

    Created on 2021-03-23 by the reprex package (v1.0.0)
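
    The same post-processing applies to the logistic regression in the question, since its predictions also come with class probability columns. One thing to watch with a TRUE/FALSE outcome: make_two_class_pred() treats its estimate as the probability of the first level in levels, and factor() on a logical puts "FALSE" first. Assuming your outcome was built that way, you would pass .pred_TRUE together with c("TRUE", "FALSE") rather than the outcome's own levels. Also note that with this data the 0.75 cut-off labels no sample as Impaired, so it is worth checking the confusion matrix after moving the threshold. Here is a dependency-free base-R sketch of the same relabelling for a TRUE/FALSE outcome, on made-up data; preds and classify_event() are hypothetical names, not part of probably or tidymodels:

    ```r
    # Made-up predictions for a TRUE/FALSE outcome (illustration only).
    # factor() on a logical sorts the levels, so "FALSE" comes first.
    preds <- data.frame(
      truth      = factor(c(TRUE, TRUE, FALSE, FALSE, TRUE)),
      .pred_TRUE = c(0.92, 0.70, 0.55, 0.20, 0.76)
    )

    # Hypothetical helper: predict the event only when its probability
    # reaches the threshold, keeping the truth column's level order
    classify_event <- function(p_event, event, lvls, threshold = 0.75) {
      other <- setdiff(lvls, event)
      factor(ifelse(p_event >= threshold, event, other), levels = lvls)
    }

    preds$.pred_class <- classify_event(preds$.pred_TRUE, event = "TRUE",
                                        lvls = levels(preds$truth))

    # Base-R confusion matrix, laid out like conf_mat() (prediction x truth)
    table(Prediction = preds$.pred_class, Truth = preds$truth)
    ```

    Keeping the predicted factor's levels identical to the truth column's levels is what lets confusion matrices and event-first metrics line up correctly.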