Search code examples
r random-forest tidymodels

How to get a variable importance graph from a random forest using Tidymodels and vip


The dataset I use is the following (mushroom) : https://archive.ics.uci.edu/ml/datasets/mushroom

  • I specified the following recipe, model and workflow of a Random Forest using Tidymodels :
# Preprocessing recipe: scale and centre the numeric columns, then dummy-encode
# the nominal predictors. The training data is passed once, via `data =`.
# The recipe is deliberately left untrained (no prep()): the workflow trains it
# itself during fitting/resampling, and add_recipe() expects an unprepped
# recipe (recent workflows releases reject a trained one).
df_recipe_mixt <- recipe(
  class ~ cap_diameter + cap_color + does_bruise_or_bleed + gill_color +
    stem_height + stem_width + stem_color + has_ring + habitat + season,
  data = df_train
) |>
  step_scale(all_numeric()) |>
  step_center(all_numeric()) |>
  step_dummy(all_nominal(), -all_outcomes())

# Random forest classifier on the ranger engine; mtry and trees are left as
# tuning placeholders to be filled in by the grid search.
rf_mod <- rand_forest() |>
  set_engine("ranger") |>
  set_mode("classification") |>
  set_args(mtry = tune(), trees = tune())

# Bundle the model specification and the (untrained) recipe.
rf_wf <- workflow() |>
  add_model(rf_mod) |>
  add_recipe(df_recipe_mixt)

  • Then I tuned the model and applied it to test data
# Register a parallel backend, leaving one core free for the rest of the
# system. detectCores() may return NA, and on a single-core machine
# `n_cores - 1` would be 0, so always ask for at least one worker.
n_cores <- parallel::detectCores(logical = TRUE)
registerDoParallel(cores = max(1L, n_cores - 1L, na.rm = TRUE))

# Restrict the tuning ranges: mtry in [1, 5], trees in [50, 500].
rf_params <- extract_parameter_set_dials(rf_wf) |>
  update(mtry = mtry(c(1, 5)), trees = trees(c(50, 500)))

# Regular grid: 5 values of mtry crossed with 3 values of trees.
rf_grid <- grid_regular(rf_params, levels = c(mtry = 5, trees = 3))

tic("random forest model tuning ")

# Evaluate every grid point on the resamples, scoring by accuracy.
tune_res_rf <- tune_grid(
  rf_wf,
  resamples = df_folds,
  grid = rf_grid,
  metrics = metric_set(accuracy)
)

toc()

stopImplicitCluster()

autoplot(tune_res_rf) + dark_mode(theme_minimal())

# Pick the hyperparameter combination with the best accuracy and print it.
rf_best <- tune_res_rf |> select_best(metric = "accuracy")

rf_best$trees
rf_best$mtry

rf_final_wf <- rf_wf |>
  finalize_workflow(rf_best)

# Keep the last_fit() result itself: it carries the fitted workflow, which is
# needed for anything beyond predictions (e.g. inspecting the final model).
rf_last_fit <- last_fit(rf_final_wf, split = df_split)

# Test-set predictions of the finalized model.
rf_res <- rf_last_fit |> collect_predictions()

The model works fine. I got back all the metrics, the confusion matrix along with the ROC Curve afterwards. However, I couldn't find a way to get a graph of variable importance (preferably using vip).


Solution

  • To get variable importance from a ranger model you need to specify which importance metric it should calculate. In tidymodels we do this by setting importance = "impurity" inside set_engine() so that this argument is passed to the underlying {ranger} function.

    Another place where this is shown is here: https://www.tidymodels.org/start/case-study/

    I also updated the recipe to use all_numeric_predictors() and all_nominal_predictors() as you are more likely to want to use those.

    (this reprex is going to differ slightly from yours because you didn't show how you created df_split)

    library(tidymodels)

    # The raw UCI file has no header row, and its first column is the
    # edible/poisonous label, so `class` must lead the names. Without it every
    # label ends up attached to the column one position to its left.
    mushroom_col_names <- c(
      "class",
      "cap_shape", "cap_surface", "cap_color", "bruises", "odor", "gill_attachment",
      "gill_spacing", "gill_size", "gill_color", "stalk_shape", "stalk_root",
      "stalk_surface_above_ring", "stalk_surface_below_ring",
      "stalk_color_above_ring", "stalk_color_below_ring", "veil_type", "veil_color",
      "ring_number", "ring_type", "spore_print_color", "population", "habitat"
    )

    # veil_type takes a single value in this data set, so it is dropped.
    mushrooms <- readr::read_csv(
      "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data",
      col_names = mushroom_col_names,
      show_col_types = FALSE
    ) |> select(-veil_type)

    df_split <- initial_split(mushrooms)

    df_train <- training(df_split)
    df_folds <- vfold_cv(df_train, v = 2)

    # Leave the recipe untrained (no prep()): the workflow trains it itself, and
    # add_recipe() expects an unprepped recipe.
    df_recipe_mixt <- recipe(class ~ .,
                             data = df_train) |>
      step_scale(all_numeric_predictors()) |>
      step_center(all_numeric_predictors()) |>
      step_novel(all_nominal_predictors()) |>
      step_unknown(all_nominal_predictors()) |>
      step_dummy(all_nominal_predictors())

    # importance = "impurity" is forwarded to ranger so that the fitted forest
    # stores variable importance scores; without it there is nothing to extract.
    rf_mod <- rand_forest() |>
      set_engine("ranger", importance = "impurity") |>
      set_mode("classification") |>
      set_args(mtry = tune(), trees = tune())

    rf_wf <- workflow() |>
      add_model(rf_mod) |>
      add_recipe(df_recipe_mixt)

    rf_params <- extract_parameter_set_dials(rf_wf) |>
      update(mtry = mtry(c(1, 5)), trees = trees(c(50, 500)))

    # Small 2 x 2 grid to keep the example fast.
    rf_grid <- grid_regular(rf_params, levels = c(mtry = 2, trees = 2))

    tune_res_rf <- tune_grid(rf_wf,
                             resamples = df_folds,
                             grid = rf_grid,
                             metrics = metric_set(accuracy)
    )


    rf_best <- tune_res_rf |> select_best(metric = "accuracy")

    rf_final_wf <- rf_wf |>
      finalize_workflow(rf_best)

    # Keep the last_fit() result (do not pipe it into collect_predictions()), so
    # the fitted workflow stays available for extraction.
    rf_res <- last_fit(rf_final_wf, split = df_split)

    # vi() returns the importance scores as a tibble; use vip::vip() in its
    # place to draw the importance plot.
    extract_fit_parsnip(rf_res$.workflow[[1]]) |>
      vip::vi()
    # NOTE: the printout below was captured before the column names were
    # corrected, so its labels are shifted by one column (e.g. `gill_attachment_n`
    # there is really `odor_n`, and the bare `odor` row is really `bruises`).
    # Rerun the code to get correctly labelled values.
    #> # A tibble: 135 × 2
    #>    Variable                   Importance
    #>    <chr>                           <dbl>
    #>  1 gill_attachment_n               265. 
    #>  2 gill_color_n                    180. 
    #>  3 gill_attachment_f               179. 
    #>  4 stalk_surface_below_ring_k      127. 
    #>  5 stalk_color_above_ring_k        106. 
    #>  6 odor                            101. 
    #>  7 spore_print_color_p             100. 
    #>  8 population_h                     94.4
    #>  9 habitat_v                        82.4
    #> 10 stalk_surface_below_ring_s       79.2
    #> # … with 125 more rows
    

    Created on 2023-03-14 with reprex v2.0.2