I created the following explain_tidymodels explainer to display partial dependence plots:
explainer <- explain_tidymodels(rf_vi_fit, data = Data_train, y = Data_train$Lead_week)
Now I'm creating the plots as follows:
model_profile(explainer, variables = c("AC", "Jaar", "Month", "Retentie")) %>% plot()
Now I'm getting the following image:
There are two problems. First, the text "Created for the workflow model" overlaps my AC header. Second, I want to change the colour from blue to red. I tried %>% plot(color = "red") and %>% plot(col = "red"), but neither seems to work.
Does anyone know how to fix either of these plotting issues? Thanks in advance!
You can access the data that creates these plots using the as_tibble()
function, and then you can create plots in whatever custom way you prefer:
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
library(DALEXtra)
#> Loading required package: DALEX
#> Welcome to DALEX (version: 2.2.0).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
#> Additional features will be available after installation of: ggpubr.
#> Use 'install_dependencies()' to get all suggested dependencies
#>
#> Attaching package: 'DALEX'
#> The following object is masked from 'package:dplyr':
#>
#> explain
data(ames)
ames_train <- ames %>%
  transmute(Sale_Price = log10(Sale_Price),
            Gr_Liv_Area = as.numeric(Gr_Liv_Area),
            Year_Built, Bldg_Type)

rf_model <-
  rand_forest(trees = 1000) %>%
  set_engine("ranger") %>%
  set_mode("regression")

rf_wflow <-
  workflow() %>%
  add_formula(
    Sale_Price ~ Gr_Liv_Area + Year_Built + Bldg_Type) %>%
  add_model(rf_model)

rf_fit <- rf_wflow %>% fit(data = ames_train)

explainer_rf <- explain_tidymodels(
  rf_fit,
  data = dplyr::select(ames_train, -Sale_Price),
  y = ames_train$Sale_Price,
  label = "random forest"
)
#> Preparation of a new explainer is initiated
#> -> model label : random forest
#> -> data : 2930 rows 3 cols
#> -> data : tibble converted into a data.frame
#> -> target variable : 2930 values
#> -> predict function : yhat.workflow will be used ( default )
#> -> predicted values : No value for predict function target column. ( default )
#> -> model_info : package tidymodels , ver. 0.1.3 , task regression ( default )
#> -> predicted values : numerical, min = 4.91122 , mean = 5.220561 , max = 5.520101
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.8113628 , mean = 7.953836e-05 , max = 0.3598514
#> A new explainer has been created!
pdp_rf <- model_profile(explainer_rf, N = NULL,
                        variables = "Gr_Liv_Area", groups = "Bldg_Type")

as_tibble(pdp_rf$agr_profiles) %>%
  mutate(`_label_` = stringr::str_remove(`_label_`, "random forest_")) %>%
  ggplot(aes(`_x_`, `_yhat_`, color = `_label_`)) +
  geom_line(size = 1.2, alpha = 0.8) +
  labs(x = "Gross living area",
       y = "Sale Price (log)",
       color = NULL,
       title = "Partial dependence profile for Ames housing sales",
       subtitle = "Predictions from a random forest model")
Created on 2021-05-27 by the reprex package (v2.0.0)
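For the original question (several variables, a single red line, and no subtitle overlapping the panel headers), the same idea might look roughly like the sketch below. It is only an illustration under assumptions: the data frame agr is a made-up stand-in with the `_vname_`, `_x_` and `_yhat_` columns found in agr_profiles, the numbers are invented, and plot_pdp_red() is a hypothetical helper name, not part of DALEX or DALEXtra. The linewidth argument assumes a recent ggplot2 (3.4.0 or later); on older versions size plays the same role.

```r
library(ggplot2)

# Made-up stand-in for as_tibble(pdp$agr_profiles). The real object has more
# columns; only these three are needed for the plot. All values are invented.
agr <- data.frame(
  vname = rep(c("AC", "Retentie"), each = 5),
  x     = c(1:5, seq(0, 1, length.out = 5)),
  yhat  = c(2.0, 2.4, 2.9, 3.1, 3.2, 3.5, 3.3, 3.0, 2.8, 2.7)
)
names(agr) <- c("_vname_", "_x_", "_yhat_")

# Hypothetical helper: one panel per variable, a single fixed line colour,
# and titles that are fully under your control (so nothing overlaps the
# facet headers).
plot_pdp_red <- function(agr, line_colour = "red") {
  ggplot(agr, aes(`_x_`, `_yhat_`)) +
    geom_line(colour = line_colour, linewidth = 1.2, alpha = 0.8) +
    facet_wrap(vars(`_vname_`), scales = "free_x") +
    labs(x = NULL,
         y = "Average prediction",
         title = "Partial dependence profiles")
}

p <- plot_pdp_red(agr)
print(p)
```

With a real explainer, you would presumably pass as_tibble(model_profile(explainer, variables = c("AC", "Jaar", "Month", "Retentie"))$agr_profiles) instead of the stand-in. Setting the colour outside aes() is what makes every line red rather than mapping colour to a column.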