I am trying to plot a decision tree in R after using a tidymodels workflow, but I am having trouble finding the right function and/or the right model object to use. After code like this, how would you plot the tree?
xgboost_spec <-
  boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(),
             loss_reduction = tune(), sample_size = tune()) %>%
  set_mode("classification") %>%
  set_engine("xgboost")

xgboost_workflow <-
  workflow() %>%
  add_recipe(data_recipe) %>%
  add_model(xgboost_spec)

xgboost_tune <-
  tune_grid(xgboost_workflow, resamples = data_folds, grid = 10)

final_xgboost <- xgboost_workflow %>%
  finalize_workflow(select_best(xgboost_tune, "roc_auc"))

xgboost_results <- final_xgboost %>%
  fit_resamples(
    resamples = data_folds,
    metrics = metric_set(roc_auc, accuracy, sensitivity, specificity),
    control = control_resamples(save_pred = TRUE)
  )
Or after decision tree code like this?
tree_spec <- decision_tree(
  cost_complexity = tune(),
  tree_depth = tune(),
  min_n = tune()) %>%
  set_engine("rpart") %>%
  set_mode("classification")

tree_workflow <-
  workflow() %>%
  add_recipe(data_recipe) %>%
  add_model(tree_spec)

tree_grid <- grid_regular(cost_complexity(),
                          tree_depth(),
                          min_n(),
                          levels = 4)

tree_tune <- tree_workflow %>%
  tune_grid(
    resamples = data_folds,
    grid = tree_grid,
    metrics = metric_set(roc_auc, accuracy, sensitivity, specificity)
  )

final_tree <- tree_workflow %>%
  finalize_workflow(select_best(tree_tune, "roc_auc"))

tree_results <- final_tree %>%
  fit_resamples(
    resamples = data_folds,
    metrics = metric_set(roc_auc, accuracy, sensitivity, specificity),
    control = control_resamples(save_pred = TRUE)
  )
Is it possible? Or should I use the model returned by last_fit()?
Thank you!
I don't think it makes much sense to plot an xgboost model, because it is boosted trees (lots and lots of trees), but you can plot a single decision tree.
The key is that most packages for visualizing tree results require you to repair the model's call object first.
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from
#>   required_pkgs.model_spec parsnip

data(penguins)
penguins <- na.omit(penguins)

cart_spec <-
  decision_tree() %>%
  set_engine("rpart") %>%
  set_mode("classification")

cart_fit <-
  cart_spec %>%
  fit(sex ~ species + bill_length_mm + body_mass_g, data = penguins)

# repair_call() rebuilds the underlying rpart call so that downstream
# plotting packages can find the original data
cart_fit <- repair_call(cart_fit, data = penguins)

library(rattle)
#> Loading required package: bitops
#> Rattle: A free graphical interface for data science with R.
#> Version 5.4.0 Copyright (c) 2006-2020 Togaware Pty Ltd.
#> Type 'rattle()' to shake, rattle, and roll your data.

fancyRpartPlot(cart_fit$fit)
Created on 2021-08-07 by the reprex package (v2.0.0)
The rattle package isn't the only thing out there; ggparty is another good option.
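For example, here is a minimal sketch of how ggparty could be used on the same fit. It assumes the cart_fit object from the reprex above; the conversion via partykit::as.party() and the particular geoms are my choices for illustration, not something required by ggparty.

# Sketch: plot the same repaired rpart fit with ggparty
library(ggparty)  # attaches partykit and ggplot2

# convert rpart -> party; this needs the repaired call so the model
# frame can be reconstructed
cart_party <- partykit::as.party(cart_fit$fit)

ggparty(cart_party) +
  geom_edge() +
  geom_edge_label() +     # split rules on the edges
  geom_node_splitvar() +  # splitting variable in the inner nodes
  geom_node_plot(         # bar chart of the outcome in each terminal node
    gglist = list(geom_bar(aes(x = "", fill = sex),
                           position = position_fill())),
    shared_axis_labels = TRUE
  )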
This does mean you must use a parsnip model plus a preprocessor, not a workflow. You can see a tutorial on how to tune a parsnip model plus a preprocessor here.
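To connect that back to your tuned tree, one possible approach (a sketch only, not tested against your data) is to finalize the model spec itself with the best tuning parameters, fit it with a formula on your training data, and then repair the call before plotting. The names data_train and the formula outcome ~ . below are placeholders for your own objects; if your recipe does more than select variables, you would need to prep() and bake() the data first, or follow the linked tutorial.

# Sketch only: tree_spec and tree_tune come from your code above;
# data_train and `outcome ~ .` are placeholders for your training data
# and model formula
best_params <- select_best(tree_tune, metric = "roc_auc")

final_tree_spec <- finalize_model(tree_spec, best_params)  # finalize the spec, not the workflow

final_tree_fit <-
  final_tree_spec %>%
  fit(outcome ~ ., data = data_train) %>%
  repair_call(data = data_train)

rattle::fancyRpartPlot(final_tree_fit$fit)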