I have a binary classification problem and used a random forest and a logistic regression.
From the results of conf_mat(), collect_metrics() and collect_predictions(), I want to change my models so that they classify an observation as TRUE only if the model is "sure", say with a probability of 75% or even higher. I just don't know where to specify this change, so it would be amazing if someone could give me a hint. My intuition tells me that it should be somewhere in the model specification, e.g. somewhere here, but maybe I am wrong:
canc_rf_model <- rand_forest(
  mtry = tune(),
  min_n = tune(),
  trees = 500) %>%
  set_engine("ranger") %>%
  set_mode("classification")

canc_log_model <- logistic_reg() %>%
  set_engine("glm") %>%
  set_mode("classification")
Thank you very much in advance! M.
The hard class predictions come from the underlying ranger::predictions() function, not from a tidymodels function, so there's not much to be done in the fitting itself.
However, you can change this fairly easily after fitting if you like. Let's make an example classification model:
library(tidymodels)
data("ad_data")
alz <- ad_data
# data splitting
set.seed(100)
alz_split <- initial_split(alz, strata = Class, prop = .9)
alz_train <- training(alz_split)
alz_test <- testing(alz_split)
# data resampling
set.seed(100)
alz_folds <-
  vfold_cv(alz_train, v = 10, strata = Class)
rf_mod <-
  rand_forest(trees = 1e3) %>%
  set_engine("ranger") %>%
  set_mode("classification")
rf_wf <-
  workflow() %>%
  add_formula(Class ~ .) %>%
  add_model(rf_mod)
set.seed(100)
rf_preds <- rf_wf %>%
  fit_resamples(
    resamples = alz_folds,
    control = control_resamples(save_pred = TRUE)) %>%
  collect_predictions()
Here is the default confusion matrix:
rf_preds %>%
  conf_mat(Class, .pred_class)
#>             Truth
#> Prediction   Impaired Control
#>   Impaired         37       5
#>   Control          45     213
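For two classes, the default hard prediction amounts (assuming it simply picks whichever class has the larger probability) to an implicit cutoff of 0.5 on `.pred_Impaired`. A minimal base-R sketch with made-up probabilities (not the `ad_data` results) shows that "pick the most probable class" and "cut at 0.5" give the same labels:

```r
# Made-up class probabilities for five observations (illustrative only)
prob_impaired <- c(0.91, 0.62, 0.48, 0.30, 0.77)
probs <- cbind(Impaired = prob_impaired, Control = 1 - prob_impaired)

# "Default" hard class: the column with the larger probability in each row
argmax_class <- colnames(probs)[apply(probs, 1, which.max)]

# The same rule written as an explicit 0.5 cutoff on the first class
cutoff_class <- ifelse(prob_impaired >= 0.5, "Impaired", "Control")

identical(argmax_class, cutoff_class)
#> [1] TRUE
```

So asking for 75% certainty just means moving that cutoff from 0.5 to 0.75, which is a post-processing step rather than a model setting.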
You can use the probably package to post-process your class probability estimates and overwrite the default class predictions:
library(probably)
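rf_preds %>%
  # label as Impaired (the first level) only when .pred_Impaired is at or above 0.75
  mutate(.pred_class = make_two_class_pred(.pred_Impaired,
                                           levels(rf_preds$Class),
                                           threshold = 0.75),
         # convert the class_pred object back to an ordinary factor for conf_mat()
         .pred_class = factor(.pred_class, levels = levels(rf_preds$Class))) %>%
  conf_mat(Class, .pred_class)
#>             Truth
#> Prediction   Impaired Control
#>   Impaired          0       0
#>   Control          82     218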
Created on 2021-03-23 by the reprex package (v1.0.0)
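If you'd rather not add a dependency, the same post-processing can be sketched in base R. The helper `threshold_class()` below is hypothetical (it is not part of probably or yardstick); it assumes the first factor level is the event class, as with `.pred_Impaired` above, and assigns it only when its probability is at least the threshold. The data are made up for illustration:

```r
# Hypothetical helper (not from probably/yardstick): assign the first level
# (the "event") only when its probability is >= threshold, otherwise the second.
threshold_class <- function(prob_event, levels, threshold = 0.5) {
  stopifnot(length(levels) == 2, threshold >= 0, threshold <= 1)
  factor(ifelse(prob_event >= threshold, levels[1], levels[2]), levels = levels)
}

# Made-up probabilities and true classes, for illustration only
truth <- factor(c("Impaired", "Impaired", "Control", "Control", "Impaired"),
                levels = c("Impaired", "Control"))
prob_impaired <- c(0.91, 0.62, 0.48, 0.30, 0.77)

pred_default <- threshold_class(prob_impaired, levels(truth), threshold = 0.5)
pred_strict  <- threshold_class(prob_impaired, levels(truth), threshold = 0.75)

# Base-R confusion matrices (rows = prediction, columns = truth)
table(Prediction = pred_default, Truth = truth)
table(Prediction = pred_strict, Truth = truth)
```

With the stricter threshold, the observation with probability 0.62 is no longer labeled Impaired. In general, raising the threshold can only keep or reduce the number of positive calls, so sensitivity can only stay the same or drop while specificity can only stay the same or rise.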