Search code examples
Tags: r, tidymodels, r-recipes

R: Partial dependence plots from a tidymodels workflow


I created the following recipe and workflow to tune and fit a random forest in R:

# Stratified 4-fold cross-validation on the training data
set.seed(123456)
cv_folds <- Data_train %>% vfold_cv(v = 4, strata = Lead_week)

# Create a recipe: model Lead_week on the listed predictors and
# center/scale Leeftijd
rf_mod_recipe <- recipe(Lead_week ~ Jaar + Aantal + Verzekering + Leeftijd +
                          Retentie + Aantal_proeven + Geslacht +
                          FLG_ADVERTISING + FLG_MAIL + FLG_PHONE + FLG_EMAIL +
                          Proef1 + Proef2 + Regio + Month + AC,
                        data = Data_train) %>%
  step_normalize(Leeftijd)

# Specify the model: a ranger random forest (regression) with tunable
# mtry and min_n; ranger computes permutation importance when fitted
rf_mod <- rand_forest(mtry = tune(), min_n = tune(), trees = 200) %>%
  set_mode("regression") %>%
  set_engine("ranger", importance = "permutation")

# Create a workflow that bundles the recipe and the model
rf_mod_workflow <- workflow() %>%
  add_model(rf_mod) %>%
  add_recipe(rf_mod_recipe)
rf_mod_workflow

# State our error metrics (regression metrics, despite the object name)
class_metrics <- metric_set(rmse, mae)

# Regular grid: 5 levels each of mtry and min_n -> 25 candidates
rf_grid <- grid_regular(
  mtry(range = c(5, 15)),
  min_n(range = c(10, 200)),
  levels = 5
)
rf_grid

# Tune the model: evaluate every grid candidate on each CV fold
set.seed(654321)
rf_tune_res <- tune_grid(
  rf_mod_workflow,
  resamples = cv_folds,
  grid = rf_grid,
  metrics = class_metrics
)

# Cross-validated RMSE and MAE for every candidate
rf_tune_res %>%
  collect_metrics()

# Pick the mtry/min_n combination with the lowest CV RMSE and plug it into
# the workflow. The metric is passed by name because recent versions of
# tune no longer accept it positionally.
best_rmse <- select_best(rf_tune_res, metric = "rmse")
rf_final_wf <- finalize_workflow(rf_mod_workflow, best_rmse)
rf_final_wf

# A finalized workflow is still untrained, so fit it before predicting
rf_final_fitted <- fit(rf_final_wf, data = Data_train)

# Predicted Lead_week against AC on the training data. NOTE: this is a
# scatter of individual predictions, NOT a partial dependence plot -- it
# does not average out the effect of the other predictors. geom_point() is
# used because the rows are not ordered by AC, so a path would zig-zag.
predict(rf_final_fitted, new_data = Data_train) %>%
  bind_cols(Data_train %>% select(AC)) %>%
  ggplot(aes(y = .pred, x = AC)) +
  geom_point()

After retrieving the cross-validated performance, I use the finalized workflow to fit on the training set and predict on the holdout data.

# Fit the finalized workflow on the training portion of `splits` and
# evaluate it once on the held-out test portion
set.seed(56789)
rf_final_fit <- last_fit(rf_final_wf, splits, metrics = class_metrics)

# Held-out predictions, then a numeric summary of the predicted values
summary_rf <- collect_predictions(rf_final_fit)

summary(summary_rf$.pred)
# Held-out RMSE and MAE
collect_metrics(rf_final_fit)

So I used cross-validation to fine-tune the model and then evaluated it on holdout data. How do I now get partial dependence plots to 'open the black box'?


Solution

  • We recommend using DALEX for model explainability tasks like this, because it has good support for tidymodels (via the DALEXtra package).

    After you have a final fitted model (such as your random forest), you need to:

    • create a DALEX explainer
    • compute the partial dependence profile (PDP) with model_profile()
    library(tidymodels)
    #> Registered S3 method overwritten by 'tune':
    #>   method                   from   
    #>   required_pkgs.model_spec parsnip
    library(DALEXtra)
    #> Loading required package: DALEX
    #> Welcome to DALEX (version: 2.2.0).
    #> Find examples and detailed introduction at: http://ema.drwhy.ai/
    #> Additional features will be available after installation of: ggpubr.
    #> Use 'install_dependencies()' to get all suggested dependencies
    #> 
    #> Attaching package: 'DALEX'
    #> The following object is masked from 'package:dplyr':
    #> 
    #>     explain
    
    # Example data: log10 sale price plus three predictors
    data(ames)
    ames_train <- ames %>%
        transmute(Sale_Price = log10(Sale_Price),
                  Gr_Liv_Area = as.numeric(Gr_Liv_Area), 
                  Year_Built, Bldg_Type)
    
    rf_model <- 
        rand_forest(trees = 1000) %>% 
        set_engine("ranger") %>% 
        set_mode("regression")
    
    rf_wflow <- 
        workflow() %>% 
        add_formula(
            Sale_Price ~ Gr_Liv_Area + Year_Built + Bldg_Type) %>% 
        add_model(rf_model) 
    
    rf_fit <- rf_wflow %>% fit(data = ames_train)
    
    # Wrap the fitted workflow in a DALEX explainer: predictors only in
    # `data`, the observed outcome in `y`
    explainer_rf <- explain_tidymodels(
        rf_fit, 
        data = dplyr::select(ames_train, -Sale_Price), 
        y = ames_train$Sale_Price,
        label = "random forest"
    )
    #> Preparation of a new explainer is initiated
    #>   -> model label       :  random forest 
    #>   -> data              :  2930  rows  3  cols 
    #>   -> data              :  tibble converted into a data.frame 
    #>   -> target variable   :  2930  values 
    #>   -> predict function  :  yhat.workflow  will be used ( default )
    #>   -> predicted values  :  No value for predict function target column. ( default )
    #>   -> model_info        :  package tidymodels , ver. 0.1.3 , task regression ( default ) 
    #>   -> predicted values  :  numerical, min =  4.896018 , mean =  5.220595 , max =  5.518857  
    #>   -> residual function :  difference between y and yhat ( default )
    #>   -> residuals         :  numerical, min =  -0.8083636 , mean =  4.509735e-05 , max =  0.3590898  
    #>  A new explainer has been created! 
    
    # Partial dependence of the prediction on Gr_Liv_Area, computed
    # separately for each building type; N = NULL uses every row rather
    # than a random subsample
    pdp_rf <- model_profile(explainer_rf, N = NULL, 
                            variables = "Gr_Liv_Area", groups = "Bldg_Type")
    
    # Plot the aggregated profiles; strip the explainer-label prefix so the
    # legend shows only the building type. `linewidth` replaces the
    # deprecated `size` for lines in ggplot2 >= 3.4.0 (use `size` on older
    # ggplot2 versions).
    as_tibble(pdp_rf$agr_profiles) %>%
        mutate(`_label_` = stringr::str_remove(`_label_`, "random forest_")) %>%
        ggplot(aes(`_x_`, `_yhat_`, color = `_label_`)) +
        geom_line(linewidth = 1.2, alpha = 0.8) +
        labs(x = "Gross living area", 
             y = "Sale Price (log)", 
             color = NULL,
             title = "Partial dependence profile for Ames housing sales",
             subtitle = "Predictions from a random forest model")
    

    Created on 2021-05-27 by the reprex package (v2.0.0)

    Judging from the plot, the x-axis would probably be better on a log scale (for example, by adding scale_x_log10()).

    You can call plot(pdp_rf) to use the default DALEX plot method, but here I showed how to build a more customized plot from the underlying computed profiles (pdp_rf$agr_profiles).