
How to plot a tree produced by C5.0 in tidymodels?


Why do I get an error when plotting a C5.0 tree fitted with tidymodels in the short reprex below, but not when I fit the same model with the C50 package directly?

I used the same C5.0 parameters in both cases. I tried to find documentation about this, without success.

How can I plot a tree produced by C5.0 in tidymodels?

library(tidyverse)
library(tidymodels)

rcp <- recipe(Species ~ ., iris)
mdl <- decision_tree(mode = "classification", engine = "C5.0")
wf <- workflow(rcp, mdl)
fttd <- fit(wf, iris)

summary(fttd$fit$fit$fit)
#> 
#> Call:
#> C5.0.default(x = x, y = y, trials = 1, control = C50::C5.0Control(minCases =
#>  2, sample = 0))
#> 
#> 
#> C5.0 [Release 2.07 GPL Edition]      Mon Sep  5 14:01:58 2022
#> -------------------------------
#> 
#> Class specified by attribute `outcome'
#> 
#> Read 150 cases (5 attributes) from undefined.data
#> 
#> Decision tree:
#> 
#> Petal.Length <= 1.9: setosa (50)
#> Petal.Length > 1.9:
#> :...Petal.Width > 1.7: virginica (46/1)
#>     Petal.Width <= 1.7:
#>     :...Petal.Length <= 4.9: versicolor (48/1)
#>         Petal.Length > 4.9: virginica (6/2)
#> 
#> 
#> Evaluation on training data (150 cases):
#> 
#>      Decision Tree   
#>    ----------------  
#>    Size      Errors  
#> 
#>       4    4( 2.7%)   <<
#> 
#> 
#>     (a)   (b)   (c)    <-classified as
#>    ----  ----  ----
#>      50                (a): class setosa
#>            47     3    (b): class versicolor
#>             1    49    (c): class virginica
#> 
#> 
#>  Attribute usage:
#> 
#>  100.00% Petal.Length
#>   66.67% Petal.Width
#> 
#> 
#> Time: 0.0 secs
plot(fttd$fit$fit$fit) # Error in eval(parse(text = paste(obj$call)[xspot])) : object 'x' not found
#> Error in eval(parse(text = paste(obj$call)[xspot])): object 'x' not found

library(C50)
c50tr <- C50::C5.0(iris[, 1:4], iris$Species, trials = 1, control = C50::C5.0Control(minCases = 2, sample = 0))

summary(c50tr)
#> 
#> Call:
#> C5.0.default(x = iris[, 1:4], y = iris$Species, trials = 1, control
#>  = C50::C5.0Control(minCases = 2, sample = 0))
#> 
#> 
#> C5.0 [Release 2.07 GPL Edition]      Mon Sep  5 14:01:58 2022
#> -------------------------------
#> 
#> Class specified by attribute `outcome'
#> 
#> Read 150 cases (5 attributes) from undefined.data
#> 
#> Decision tree:
#> 
#> Petal.Length <= 1.9: setosa (50)
#> Petal.Length > 1.9:
#> :...Petal.Width > 1.7: virginica (46/1)
#>     Petal.Width <= 1.7:
#>     :...Petal.Length <= 4.9: versicolor (48/1)
#>         Petal.Length > 4.9: virginica (6/2)
#> 
#> 
#> Evaluation on training data (150 cases):
#> 
#>      Decision Tree   
#>    ----------------  
#>    Size      Errors  
#> 
#>       4    4( 2.7%)   <<
#> 
#> 
#>     (a)   (b)   (c)    <-classified as
#>    ----  ----  ----
#>      50                (a): class setosa
#>            47     3    (b): class versicolor
#>             1    49    (c): class virginica
#> 
#> 
#>  Attribute usage:
#> 
#>  100.00% Petal.Length
#>   66.67% Petal.Width
#> 
#> 
#> Time: 0.0 secs
plot(c50tr) # plots nice tree

Created on 2022-09-05 with reprex v2.0.2


Solution

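  • You are getting an error because the plot() method for C5.0 objects (from the {C50} package) uses the call stored in the model to retrieve the training data: it evaluates the x argument of that call. With {parsnip}, the call is recorded as C5.0.default(x = x, y = y, ...), where x and y were local variables inside parsnip's fitting code, so they cannot be found when plot() evaluates them. To fix this, we have to modify the object before plotting.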

    library(tidymodels)
    library(C50)
    
    rcp <- recipe(Species ~ ., iris)
    mdl <- decision_tree(mode = "classification", engine = "C5.0")
    wf <- workflow(rcp, mdl)
    fttd <- fit(wf, iris)
    

    We fit the model as usual, then extract the engine's model fit with extract_fit_engine() (this is safer than using $fit$fit$fit).

    fttd_engine <- extract_fit_engine(fttd)
    

    The call stored in the fit object has x and y arguments that we need to fill in. We can regenerate the predictors that come out of the workflow's recipe using the following code:

    prepped_predictors <- fttd |> 
      extract_recipe() |> 
      prep(iris) |>
      bake(new_data = NULL, all_predictors())
    

    Then we use call_modify() from the {rlang} package to replace the x and y arguments in the fitted model's call with the actual data:

    fttd_engine$call <- rlang::call_modify(
      fttd_engine$call,
      x = prepped_predictors,
      y = iris$Species
    )
    

    The modified object can now be passed to plot():

    fttd_engine |>
      plot()
    

    Created on 2022-09-06 by the reprex package (v2.0.1)
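To see the cause described in the solution directly, one can inspect the call stored in the engine fit. This is a minimal sketch that refits the same workflow; it assumes a fresh R session in which no global variable named x exists.

```r
library(tidymodels)
library(C50)

rcp <- recipe(Species ~ ., iris)
mdl <- decision_tree(mode = "classification", engine = "C5.0")
fttd <- fit(workflow(rcp, mdl), iris)
fttd_engine <- extract_fit_engine(fttd)

# The stored call only contains the symbols `x` and `y`,
# not the data they referred to during fitting
fttd_engine$call
fttd_engine$call$x

# Evaluating that symbol outside parsnip's fitting code fails the same way
# plot() does (in a fresh session: "object 'x' not found")
try(eval(fttd_engine$call$x))
```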
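If this is needed more than once, the steps can be wrapped in a small function. The sketch below is hypothetical: plot_ready_c50() is not part of tidymodels or C50, and instead of re-running prep() and bake() it takes the processed data from workflows::extract_mold(), assuming a workflows version that provides that function. The mold's predictors and outcomes are the processed data the workflow handed to the model.

```r
library(tidymodels)
library(C50)

# Hypothetical helper: returns the C5.0 engine fit from a trained workflow,
# with the x and y arguments of its call replaced by the processed training data
plot_ready_c50 <- function(wflow_fit) {
  engine_fit <- extract_fit_engine(wflow_fit)
  # extract_mold() returns the output of hardhat::mold() for this workflow
  mold <- extract_mold(wflow_fit)
  engine_fit$call <- rlang::call_modify(
    engine_fit$call,
    x = mold$predictors,    # tibble of processed predictors
    y = mold$outcomes[[1]]  # the single outcome column (a factor here)
  )
  engine_fit
}

rcp <- recipe(Species ~ ., iris)
mdl <- decision_tree(mode = "classification", engine = "C5.0")
fttd <- fit(workflow(rcp, mdl), iris)

plot(plot_ready_c50(fttd))
```

This version assumes a single outcome column, which holds for classification with C5.0.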