The dataset I use is the mushroom dataset: https://archive.ics.uci.edu/ml/datasets/mushroom
My tidymodels code:
# Preprocessing recipe: outcome `class` plus ten mixed numeric/nominal
# predictors. Numeric columns are standardised (scaled, then centred) and
# nominal predictors are expanded into dummy variables.
# The data is supplied once via `data =`; piping df_train in as well passed it
# a second time. The recipe is deliberately left untrained (no prep()):
# workflows expects an untrained recipe and trains it itself when the
# workflow below is fitted or tuned.
df_recipe_mixt <- recipe(
  class ~ cap_diameter + cap_color + does_bruise_or_bleed + gill_color +
    stem_height + stem_width + stem_color + has_ring + habitat + season,
  data = df_train
) |>
  step_scale(all_numeric()) |>
  step_center(all_numeric()) |>
  step_dummy(all_nominal(), -all_outcomes())
# Random-forest classifier on the ranger engine, with `mtry` and `trees`
# declared as tuning placeholders directly in the model constructor.
rf_mod <- rand_forest(mtry = tune(), trees = tune()) |>
  set_engine("ranger") |>
  set_mode("classification")

# Bundle the preprocessing recipe and the model into a single workflow.
rf_wf <- workflow() |>
  add_recipe(df_recipe_mixt) |>
  add_model(rf_mod)
# Tune mtry (1-5) and trees (50-500) over a 5 x 3 regular grid, scored by
# accuracy only. A doParallel backend is registered for the duration of the
# tuning, and the run is timed with tic()/toc().
# NOTE(review): registerDoParallel()/stopImplicitCluster() come from
# doParallel and tic()/toc() from tictoc; they must be attached elsewhere.
n_cores <- parallel::detectCores(logical = TRUE)
# detectCores() may return NA, and n_cores - 1 is 0 on a single-core machine;
# always request at least one worker so registration cannot fail.
n_workers <- max(1L, n_cores - 1L, na.rm = TRUE)
registerDoParallel(cores = n_workers)

rf_params <- extract_parameter_set_dials(rf_wf) |>
  update(mtry = mtry(c(1, 5)), trees = trees(c(50, 500)))
rf_grid <- grid_regular(rf_params, levels = c(mtry = 5, trees = 3))

tic("random forest model tuning ")
tune_res_rf <- tune_grid(
  rf_wf,
  resamples = df_folds,
  grid = rf_grid,
  metrics = metric_set(accuracy)
)
toc()
# Shut down the implicit cluster created by registerDoParallel().
stopImplicitCluster()
# Visualise the tuning results on a dark theme.
autoplot(tune_res_rf) + dark_mode(theme_minimal())

# Best (mtry, trees) combination by accuracy; print the chosen values.
rf_best <- select_best(tune_res_rf, metric = "accuracy")
rf_best$trees
rf_best$mtry

# Plug the winning hyperparameters into the workflow.
rf_final_wf <- finalize_workflow(rf_wf, rf_best)

# Fit on the training portion of df_split, evaluate on its test portion,
# and keep only the resulting test-set predictions.
rf_res <- rf_final_wf |>
  last_fit(split = df_split) |>
  collect_predictions()
The model works fine: I get all the metrics, the confusion matrix and the ROC curve afterwards. However, I couldn't find a way to get a graph of variable importance (preferably using vip).
To get variable importance from a ranger model, you need to specify which importance metric it should calculate. In tidymodels, we do this by setting `importance = "impurity"` inside `set_engine()`, so that the argument is passed through to the underlying {ranger} function. Without it, ranger computes no importance scores, so there is nothing for vip to plot.
Another place where this is shown is the tidymodels case study: https://www.tidymodels.org/start/case-study/
I also updated the recipe to use `all_numeric_predictors()` and `all_nominal_predictors()`, as you are more likely to want those: they select only predictors, so the outcome is never scaled or dummy-encoded.
(This reprex will differ slightly from yours because you didn't show how you created `df_split`.)
library(tidymodels)
# The UCI file agaricus-lepiota.data has no header row and 23 columns: the
# edible ("e") / poisonous ("p") label first, followed by the 22 attributes.
# "class" must therefore lead this vector. With only the 22 attribute names,
# every name lands on the column before the one it describes. The importance
# table printed further down was produced with that misaligned naming, so
# e.g. its `gill_attachment_n` / `gill_attachment_f` rows are the
# `odor_n` / `odor_f` dummies under the names used here.
mushroom_col_names <- c(
  "class",
  "cap_shape", "cap_surface", "cap_color", "bruises", "odor", "gill_attachment",
  "gill_spacing", "gill_size", "gill_color", "stalk_shape", "stalk_root",
  "stalk_surface_above_ring", "stalk_surface_below_ring",
  "stalk_color_above_ring", "stalk_color_below_ring", "veil_type", "veil_color",
  "ring_number", "ring_type", "spore_print_color", "population", "habitat"
)

# "?" is the file's missing-value code (used in stalk_root). Reading it as NA
# lets step_unknown() below give it an explicit "unknown" level.
# veil_type takes a single value in this file, so it carries no information
# and is dropped.
mushrooms <- readr::read_csv(
  "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data",
  col_names = mushroom_col_names,
  na = c("", "NA", "?"),
  show_col_types = FALSE
) |>
  select(-veil_type)

# Fixed seed so the train/test split and the CV folds are reproducible.
set.seed(123)
df_split <- initial_split(mushrooms)
df_train <- training(df_split)
df_folds <- vfold_cv(df_train, v = 2)

# Predict edibility (`class`) from every remaining attribute. The recipe is
# left untrained (no prep()): workflows expects an untrained recipe and
# trains it itself when the workflow is fitted or tuned.
df_recipe_mixt <- recipe(class ~ ., data = df_train) |>
  step_scale(all_numeric_predictors()) |>
  step_center(all_numeric_predictors()) |>
  # Handle factor levels unseen during training and missing values before
  # dummy encoding.
  step_novel(all_nominal_predictors()) |>
  step_unknown(all_nominal_predictors()) |>
  step_dummy(all_nominal_predictors())
# Random-forest classifier with `mtry` and `trees` marked for tuning.
# The engine argument importance = "impurity" is forwarded to ranger, so the
# fitted model stores impurity-based importance scores that vip can read.
rf_mod <- rand_forest(mtry = tune(), trees = tune()) |>
  set_mode("classification") |>
  set_engine("ranger", importance = "impurity")

# Combine preprocessing and model into one workflow.
rf_wf <- workflow() |>
  add_recipe(df_recipe_mixt) |>
  add_model(rf_mod)
# Narrow the tuning ranges (mtry 1-5, trees 50-500), build a 2 x 2 regular
# grid, and evaluate every grid point on the resamples using accuracy only.
rf_params <- rf_wf |>
  extract_parameter_set_dials() |>
  update(
    mtry = mtry(range = c(1, 5)),
    trees = trees(range = c(50, 500))
  )
rf_grid <- grid_regular(rf_params, levels = c(mtry = 2, trees = 2))

tune_res_rf <- rf_wf |>
  tune_grid(
    resamples = df_folds,
    grid = rf_grid,
    metrics = metric_set(accuracy)
  )
# Pick the most accurate configuration, fit it on the training portion of
# df_split and evaluate it on the test portion.
rf_best <- tune_res_rf |> select_best(metric = "accuracy")
rf_final_wf <- rf_wf |>
  finalize_workflow(rf_best)
rf_res <- last_fit(rf_final_wf, split = df_split)

# extract_fit_parsnip() accepts the last_fit() result directly, so there is
# no need to index into rf_res$.workflow.
rf_fit <- extract_fit_parsnip(rf_res)

# Importance scores as a table ...
vip::vi(rf_fit)
# ... and as a bar chart of the top features, which is what was asked for.
vip::vip(rf_fit, num_features = 10)
#> # A tibble: 135 × 2
#> Variable Importance
#> <chr> <dbl>
#> 1 gill_attachment_n 265.
#> 2 gill_color_n 180.
#> 3 gill_attachment_f 179.
#> 4 stalk_surface_below_ring_k 127.
#> 5 stalk_color_above_ring_k 106.
#> 6 odor 101.
#> 7 spore_print_color_p 100.
#> 8 population_h 94.4
#> 9 habitat_v 82.4
#> 10 stalk_surface_below_ring_s 79.2
#> # … with 125 more rows
Created on 2023-03-14 with reprex v2.0.2