Search code examples
r tidymodels c5.0

Extract Rules from Trained C5.0 Model in Tidymodels


I could, and should, have made a simpler reprex, but this comes straight out of my work. After training a rule-based model (C5.0 or, as in the example below, Cubist) in the tidymodels framework, how do I "see" the rules that the model generated? I tried to replicate what is illustrated here:

https://www.tidyverse.org/blog/2020/05/rules-0-0-1/

but I did not get very far (although I am sure the solution must be a one-liner).

Many thanks!

library(tidymodels)
#> ── Attaching packages ────────────────────────────────────── tidymodels 0.1.2 ──
#> ✔ broom     0.7.2          ✔ recipes   0.1.15    
#> ✔ dials     0.0.9          ✔ rsample   0.0.8     
#> ✔ dplyr     1.0.2          ✔ tibble    3.0.4     
#> ✔ ggplot2   3.3.2          ✔ tidyr     1.1.2     
#> ✔ infer     0.5.3          ✔ tune      0.1.2.9000
#> ✔ modeldata 0.1.0          ✔ workflows 0.2.1     
#> ✔ parsnip   0.1.4.9000     ✔ yardstick 0.0.7     
#> ✔ purrr     0.3.4
#> ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ purrr::discard() masks scales::discard()
#> ✖ dplyr::filter()  masks stats::filter()
#> ✖ dplyr::lag()     masks stats::lag()
#> ✖ recipes::step()  masks stats::step()
library(rules)
#> 
#> Attaching package: 'rules'
#> The following object is masked from 'package:dials':
#> 
#>     max_rules


df_ini <- structure(list(year = c(2002, 2004, 2005, 2006, 2007, 2008, 2009, 
2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019), 
    berd = c(3130.884, 3556.479, 4207.669, 4448.676, 4845.861, 
    5232.63, 5092.902, 5520.422, 5692.841, 6540.457, 6778.42, 
    7324.679, 7498.488, 7824.51, 7888.444, 8461.72, 8865.96), 
    gbaord = c(1537.89, 1537.89, 1619.74, 1697.55, 1770.144, 
    1986.775, 2149.787, 2269.986, 2428.143, 2452.955, 2587.586, 
    2647.489, 2744.844, 2875.706, 2889.779, 2913.369, 3081.087
    ), employment_be = c(2775.22, 2731.08, 2709.59, 2708.39, 
    2774.06, 2795.6, 2703.36, 2668.1, 2705.1, 2731.67, 2727.16, 
    2725.66, 2735.69, 2750.52, 2782.9, 2852.33, 2890.9), employment_c = c(2562.53, 
    2518.57, 2496.98, 2499.54, 2558.88, 2578, 2483.97, 2447.65, 
    2483.1, 2507.41, 2500.94, 2499.6, 2511.75, 2523.97, 2555.48, 
    2622.5, 2656.89), employment_j = c(400.75, 387.53, 384.64, 
    389.29, 385.77, 392.86, 383.91, 392.18, 410.85, 419.75, 427.59, 
    438.96, 440.33, 460.84, 473.4, 494.4, 513.62), employment_k = c(502.42, 
    504.63, 515.39, 523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 
    534.98, 524.13, 518.89, 511.57, 505.32, 496.41, 495.4, 495.98
    ), employment_mn = c(1248.01, 1365.29, 1425.81, 1537.88, 
    1622.95, 1727.76, 1704.65, 1762.55, 1838.16, 1896.09, 1929.09, 
    1950.02, 1968.83, 2021.51, 2109.71, 2189.27, 2225), employment_oq = c(3241.36, 
    3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 3759.23, 
    3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 4171.72, 
    4238.87, 4284.27), employment_total = c(15113.52, 15307.28, 
    15491.61, 15762.92, 16050.92, 16356.53, 16269.97, 16392.87, 
    16647.79, 16820.66, 16879.06, 17039.6, 17142.13, 17365.32, 
    17650.21, 17951.61, 18156.52), value_be = c(47967.1, 50737.6, 
    52955.2, 56872.4, 60864.9, 61029, 56837.8, 58433.6, 61443, 
    63655.1, 64132.3, 65542.6, 67495.4, 71152.6, 72698.8, 75792.3, 
    77284), value_c = c(40192.9, 42014.6, 44229, 47735.5, 51552.4, 
    51165.9, 47129.7, 48759.3, 51467.7, 53234.6, 53431.4, 55169, 
    57458.7, 60962.8, 62196, 65063.5, 66063.6), value_j = c(7737.1, 
    7756.1, 8134.2, 8378.8, 8532.3, 8740, 8493.9, 8518.9, 9217.1, 
    9405.1, 9802.1, 10361.4, 10695.4, 11455.3, 11720.6, 12871, 
    13540.3), value_k = c(10225.2, 10541.9, 11005.3, 11912.3, 
    13102.7, 13205.2, 12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 
    12962.4, 13482.9, 13236.4, 13744.1, 14152.6, 14739.1), value_mn = c(15074, 
    16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2, 
    24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6, 
    33781.9, 35152.9), value_oq = c(35065.6, 37329.6, 38288.8, 
    40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 49381.5, 
    50261.7, 51624.3, 53715, 55926.4, 57637.1, 59648.9, 61680.1
    ), value_total = c(202353.5, 216098.3, 225888.1, 239076, 
    253604.6, 262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 
    297230.1, 307037.7, 318952.7, 329396.1, 344338.6, 355359.1
    ), gdp_b1gq = c(226735.3, 242348.3, 254075, 267824.4, 283978, 
    293761.9, 288044.1, 295896.6, 310128.6, 318653.1, 323910.2, 
    333146.1, 344269.3, 357608, 369341.3, 385361.9, 397575.4), 
    gdp_p3 = c(164107.8, 176316.4, 185871.1, 194102, 200944.4, 
    208857.1, 213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 
    243860.6, 249404.3, 257166.5, 265900.2, 274583.7, 282863.3
    ), gdp_p61 = c(74691.6, 83074.9, 90010.4, 100076.8, 110157.2, 
    113368.1, 91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 
    126109.3, 129183.6, 131524, 140057.8, 150278.2, 152545.2), 
    gdp_p62 = c(28063.4, 30507, 33520.2, 36089.5, 39104, 43056.8, 
    38781.9, 39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 
    55885.5, 59584.7, 64333.5, 68409.7), turnover_manu_dom = c(80, 
    87, 93.2, 99.9, 104.9, 113.6, 97, 100, 110, 112.1, 110.7, 
    107.1, 104.7, 102.9, 107.9, 107.9, 107.9), turnover_manu_non_dom = c(70.9, 
    81.9, 86.2, 95.3, 102.8, 106.5, 86.8, 100, 112.8, 111.7, 
    112.8, 114.9, 118.2, 120.1, 129.2, 129.2, 129.2), turnover_manu_tot = c(74.7, 
    84, 89.1, 97.2, 103.7, 109.4, 91.1, 100, 111.6, 111.9, 111.9, 
    111.7, 112.6, 112.9, 120.3, 120.3, 120.3), price_index = c(1.7, 
    2, 2.1, 1.7, 2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 0.8, 
    1, 2.2, 2.1, 1.5), capital_n1132g = c(3638.4, 3633.3, 3616.2, 
    3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 4005.6, 3718.6, 
    3467.9, 4214.2, 4237.4, 4450.2, 4598.6, 4721), capital_n117g = c(8369.6, 
    8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 11554.3, 11849.9, 
    13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 17647.1, 18273.8, 
    19642.4, 20713.1), capital_n11mg = c(18749.6, 19433.5, 20051.6, 
    20569.8, 22646.1, 23674.5, 21200.6, 20919.6, 23157.7, 23520.7, 
    24057.7, 23832.8, 25019.2, 27608.2, 29790.1, 30998, 32856.8
    ), lagged_gbaord = c(1537.89, 1537.89, 1537.89, 1619.74, 
    1697.55, 1770.144, 1986.775, 2149.787, 2269.986, 2428.143, 
    2452.955, 2587.586, 2647.489, 2744.844, 2875.706, 2889.779, 
    2913.369), lagged_employment_be = c(2775.22, 2775.22, 2731.08, 
    2709.59, 2708.39, 2774.06, 2795.6, 2703.36, 2668.1, 2705.1, 
    2731.67, 2727.16, 2725.66, 2735.69, 2750.52, 2782.9, 2852.33
    ), lagged_employment_c = c(2562.53, 2562.53, 2518.57, 2496.98, 
    2499.54, 2558.88, 2578, 2483.97, 2447.65, 2483.1, 2507.41, 
    2500.94, 2499.6, 2511.75, 2523.97, 2555.48, 2622.5), lagged_employment_j = c(400.75, 
    400.75, 387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 392.18, 
    410.85, 419.75, 427.59, 438.96, 440.33, 460.84, 473.4, 494.4
    ), lagged_employment_k = c(502.42, 502.42, 504.63, 515.39, 
    523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 534.98, 524.13, 
    518.89, 511.57, 505.32, 496.41, 495.4), lagged_employment_mn = c(1248.01, 
    1248.01, 1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 1704.65, 
    1762.55, 1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 2021.51, 
    2109.71, 2189.27), lagged_employment_oq = c(3241.36, 3241.36, 
    3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 3759.23, 
    3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 4171.72, 
    4238.87), lagged_employment_total = c(15113.52, 15113.52, 
    15307.28, 15491.61, 15762.92, 16050.92, 16356.53, 16269.97, 
    16392.87, 16647.79, 16820.66, 16879.06, 17039.6, 17142.13, 
    17365.32, 17650.21, 17951.61), lagged_value_be = c(47967.1, 
    47967.1, 50737.6, 52955.2, 56872.4, 60864.9, 61029, 56837.8, 
    58433.6, 61443, 63655.1, 64132.3, 65542.6, 67495.4, 71152.6, 
    72698.8, 75792.3), lagged_value_c = c(40192.9, 40192.9, 42014.6, 
    44229, 47735.5, 51552.4, 51165.9, 47129.7, 48759.3, 51467.7, 
    53234.6, 53431.4, 55169, 57458.7, 60962.8, 62196, 65063.5
    ), lagged_value_j = c(7737.1, 7737.1, 7756.1, 8134.2, 8378.8, 
    8532.3, 8740, 8493.9, 8518.9, 9217.1, 9405.1, 9802.1, 10361.4, 
    10695.4, 11455.3, 11720.6, 12871), lagged_value_k = c(10225.2, 
    10225.2, 10541.9, 11005.3, 11912.3, 13102.7, 13205.2, 12123.9, 
    12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 13482.9, 13236.4, 
    13744.1, 14152.6), lagged_value_mn = c(15074, 15074, 16569.1, 
    18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2, 24895.4, 
    25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6, 33781.9
    ), lagged_value_oq = c(35065.6, 35065.6, 37329.6, 38288.8, 
    40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 49381.5, 
    50261.7, 51624.3, 53715, 55926.4, 57637.1, 59648.9), lagged_value_total = c(202353.5, 
    202353.5, 216098.3, 225888.1, 239076, 253604.6, 262414.7, 
    256671, 263633.5, 276404, 283548.2, 288624.3, 297230.1, 307037.7, 
    318952.7, 329396.1, 344338.6), lagged_gdp_b1gq = c(226735.3, 
    226735.3, 242348.3, 254075, 267824.4, 283978, 293761.9, 288044.1, 
    295896.6, 310128.6, 318653.1, 323910.2, 333146.1, 344269.3, 
    357608, 369341.3, 385361.9), lagged_gdp_p3 = c(164107.8, 
    164107.8, 176316.4, 185871.1, 194102, 200944.4, 208857.1, 
    213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 243860.6, 
    249404.3, 257166.5, 265900.2, 274583.7), lagged_gdp_p61 = c(74691.6, 
    74691.6, 83074.9, 90010.4, 100076.8, 110157.2, 113368.1, 
    91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 126109.3, 
    129183.6, 131524, 140057.8, 150278.2), lagged_gdp_p62 = c(28063.4, 
    28063.4, 30507, 33520.2, 36089.5, 39104, 43056.8, 38781.9, 
    39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 55885.5, 
    59584.7, 64333.5), lagged_turnover_manu_dom = c(80, 80, 87, 
    93.2, 99.9, 104.9, 113.6, 97, 100, 110, 112.1, 110.7, 107.1, 
    104.7, 102.9, 107.9, 107.9), lagged_turnover_manu_non_dom = c(70.9, 
    70.9, 81.9, 86.2, 95.3, 102.8, 106.5, 86.8, 100, 112.8, 111.7, 
    112.8, 114.9, 118.2, 120.1, 129.2, 129.2), lagged_turnover_manu_tot = c(74.7, 
    74.7, 84, 89.1, 97.2, 103.7, 109.4, 91.1, 100, 111.6, 111.9, 
    111.9, 111.7, 112.6, 112.9, 120.3, 120.3), lagged_price_index = c(1.7, 
    1.7, 2, 2.1, 1.7, 2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 
    0.8, 1, 2.2, 2.1), lagged_capital_n1132g = c(3638.4, 3638.4, 
    3633.3, 3616.2, 3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 
    4005.6, 3718.6, 3467.9, 4214.2, 4237.4, 4450.2, 4598.6), 
    lagged_capital_n117g = c(8369.6, 8369.6, 8679.9, 8938.9, 
    9823.8, 10467.1, 11047.1, 11554.3, 11849.9, 13465.4, 13927.5, 
    15510.2, 15754.4, 16584.7, 17647.1, 18273.8, 19642.4), lagged_capital_n11mg = c(18749.6, 
    18749.6, 19433.5, 20051.6, 20569.8, 22646.1, 23674.5, 21200.6, 
    20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 25019.2, 27608.2, 
    29790.1, 30998), country = c("AT", "AT", "AT", "AT", "AT", 
    "AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT", 
    "AT", "AT")), row.names = c(NA, -17L), class = c("tbl_df", 
"tbl", "data.frame"))




# Time-ordered hold-out split: the last `time_back` rows of `df_ini` become
# the assessment (test) set and all earlier rows the analysis (training) set.
set.seed(1234) # seed the RNG so the resampling below is reproducible

nn <- nrow(df_ini)

# Number of most recent rows held out for assessment.
time_back <- 1

# seq_len()/tail() are used instead of `1:(nn - time_back)` and
# `(nn - time_back + 1):nn`: the colon forms count backwards when one side
# of the split is empty (e.g. `(nn + 1):nn`), which would silently select
# wrong or out-of-range rows.
indices <- list(
  analysis   = seq_len(nn - time_back),
  assessment = tail(seq_len(nn), time_back)
)

df_split <- make_splits(indices, df_ini)

df_train <- training(df_split)
df_test <- testing(df_split)

# 3-fold cross-validation on the training rows.
folded_data <- vfold_cv(df_train, v = 3)

# Preprocessing: model `berd` on every other column, give `year` the "ID"
# role so it is not used as a predictor, and apply step_zv() to the
# predictors.
cubist_recipe <- recipe(berd ~ ., data = df_train)
## (disabled) step_string2factor(one_of("country"))
cubist_recipe <- update_role(cubist_recipe, year, new_role = "ID")
cubist_recipe <- step_zv(cubist_recipe, all_predictors())

# Cubist rule-based model; both hyperparameters are left as tune()
# placeholders to be filled in by the grid search.
cubist_spec <- set_engine(
  cubist_rules(committees = tune(), neighbors = tune()),
  "Cubist"
)

# Bundle preprocessing and model into a single workflow.
cubist_workflow <- add_recipe(workflow(), cubist_recipe)
cubist_workflow <- add_model(cubist_workflow, cubist_spec)

# Candidate grid: committees in 1..9 and 10, 20, ..., 50, crossed with four
# neighbor settings.
cubist_grid <- tidyr::crossing(
  committees = c(1:9, seq(10, 50, by = 10)),
  neighbors = c(0, 3, 6, 9)
)

# Evaluate every grid point on the cross-validation folds.
cubist_tune <- cubist_workflow %>%
  tune_grid(resamples = folded_data, grid = cubist_grid)
#> 
#> Attaching package: 'rlang'
#> The following objects are masked from 'package:purrr':
#> 
#>     %@%, as_function, flatten, flatten_chr, flatten_dbl, flatten_int,
#>     flatten_lgl, flatten_raw, invoke, list_along, modify, prepend,
#>     splice
#> 
#> Attaching package: 'vctrs'
#> The following object is masked from 'package:tibble':
#> 
#>     data_frame
#> The following object is masked from 'package:dplyr':
#> 
#>     data_frame
#> Loading required package: lattice


# Pick the tuning-parameter combination with the best cross-validated RMSE.
# `metric` is passed by name: newer tune releases place `metric` after `...`
# in select_best(), so a positional "rmse" is no longer matched to it.
best_cub <- select_best(cubist_tune, metric = "rmse")

# Replace the tune() placeholders in the workflow with the selected values.
final_cub <- finalize_workflow(cubist_workflow, best_cub)

# Print the finalized (not yet fitted) workflow.
final_cub
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: cubist_rules()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 1 Recipe Step
#> 
#> ● step_zv()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Cubist Model Specification (regression)
#> 
#> Main Arguments:
#>   committees = 1
#>   neighbors = 3
#> 
#> Computational engine: Cubist
   
# Fit the finalized workflow on the training rows, then print the result.
fit_model <- fit(final_cub, data = df_train)

fit_model
#> ══ Workflow [trained] ══════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: cubist_rules()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 1 Recipe Step
#> 
#> ● step_zv()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> 
#> Call:
#> cubist.default(x = x, y = y, committees = 1)
#> 
#> Number of samples: 16 
#> Number of predictors: 52 
#> 
#> Number of committees: 1 
#> Number of rules: 1

 ### At this point: how can I see the rules of the model trained on the data?

Created on 2020-12-10 by the reprex package (v0.3.0)


Solution

  • Admittedly, tidymodels does not yet offer an ideal way to get the rules out. I believe the best approach currently is to pull out the underlying fit object, which sits several layers deep inside the workflow, and then call summary() on it. You want to do: summary(fit_model$fit$fit$fit).

    library(tidymodels)
    library(rules)
    #> 
    #> Attaching package: 'rules'
    #> The following object is masked from 'package:dials':
    #> 
    #>     max_rules
    
    df_ini <- structure(list(year = c(2002, 2004, 2005, 2006, 2007, 2008, 2009, 
                                      2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019), 
                             berd = c(3130.884, 3556.479, 4207.669, 4448.676, 4845.861, 
                                      5232.63, 5092.902, 5520.422, 5692.841, 6540.457, 6778.42, 
                                      7324.679, 7498.488, 7824.51, 7888.444, 8461.72, 8865.96), 
                             gbaord = c(1537.89, 1537.89, 1619.74, 1697.55, 1770.144, 
                                        1986.775, 2149.787, 2269.986, 2428.143, 2452.955, 2587.586, 
                                        2647.489, 2744.844, 2875.706, 2889.779, 2913.369, 3081.087
                             ), employment_be = c(2775.22, 2731.08, 2709.59, 2708.39, 
                                                  2774.06, 2795.6, 2703.36, 2668.1, 2705.1, 2731.67, 2727.16, 
                                                  2725.66, 2735.69, 2750.52, 2782.9, 2852.33, 2890.9), 
                             employment_c = c(2562.53, 
                                              2518.57, 2496.98, 2499.54, 2558.88, 2578, 2483.97, 2447.65, 
                                              2483.1, 2507.41, 2500.94, 2499.6, 2511.75, 2523.97, 2555.48, 
                                              2622.5, 2656.89), 
                             employment_j = c(400.75, 387.53, 384.64, 
                                              389.29, 385.77, 392.86, 383.91, 392.18, 410.85, 419.75, 427.59, 
                                              438.96, 440.33, 460.84, 473.4, 494.4, 513.62), 
                             employment_k = c(502.42, 
                                              504.63, 515.39, 523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 
                                              534.98, 524.13, 518.89, 511.57, 505.32, 496.41, 495.4, 495.98
                             ), 
                             employment_mn = c(1248.01, 1365.29, 1425.81, 1537.88, 
                                               1622.95, 1727.76, 1704.65, 1762.55, 1838.16, 1896.09, 1929.09, 
                                               1950.02, 1968.83, 2021.51, 2109.71, 2189.27, 2225), 
                             employment_oq = c(3241.36, 
                                               3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 3759.23, 
                                               3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 4171.72, 
                                               4238.87, 4284.27), 
                             employment_total = c(15113.52, 15307.28, 
                                                  15491.61, 15762.92, 16050.92, 16356.53, 16269.97, 16392.87, 
                                                  16647.79, 16820.66, 16879.06, 17039.6, 17142.13, 17365.32, 
                                                  17650.21, 17951.61, 18156.52), 
                             value_be = c(47967.1, 50737.6, 
                                          52955.2, 56872.4, 60864.9, 61029, 56837.8, 58433.6, 61443, 
                                          63655.1, 64132.3, 65542.6, 67495.4, 71152.6, 72698.8, 75792.3, 
                                          77284), 
                             value_c = c(40192.9, 42014.6, 44229, 47735.5, 51552.4, 
                                         51165.9, 47129.7, 48759.3, 51467.7, 53234.6, 53431.4, 55169, 
                                         57458.7, 60962.8, 62196, 65063.5, 66063.6), 
                             value_j = c(7737.1, 
                                         7756.1, 8134.2, 8378.8, 8532.3, 8740, 8493.9, 8518.9, 9217.1, 
                                         9405.1, 9802.1, 10361.4, 10695.4, 11455.3, 11720.6, 12871, 
                                         13540.3), 
                             value_k = c(10225.2, 10541.9, 11005.3, 11912.3, 
                                         13102.7, 13205.2, 12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 
                                         12962.4, 13482.9, 13236.4, 13744.1, 14152.6, 14739.1), 
                             value_mn = c(15074, 
                                          16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2, 
                                          24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6, 
                                          33781.9, 35152.9), 
                             value_oq = c(35065.6, 37329.6, 38288.8, 
                                          40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 49381.5, 
                                          50261.7, 51624.3, 53715, 55926.4, 57637.1, 59648.9, 61680.1
                             ), 
                             value_total = c(202353.5, 216098.3, 225888.1, 239076, 
                                             253604.6, 262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 
                                             297230.1, 307037.7, 318952.7, 329396.1, 344338.6, 355359.1
                             ), 
                             gdp_b1gq = c(226735.3, 242348.3, 254075, 267824.4, 283978, 
                                          293761.9, 288044.1, 295896.6, 310128.6, 318653.1, 323910.2, 
                                          333146.1, 344269.3, 357608, 369341.3, 385361.9, 397575.4), 
                             gdp_p3 = c(164107.8, 176316.4, 185871.1, 194102, 200944.4, 
                                        208857.1, 213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 
                                        243860.6, 249404.3, 257166.5, 265900.2, 274583.7, 282863.3
                             ), gdp_p61 = c(74691.6, 83074.9, 90010.4, 100076.8, 110157.2, 
                                            113368.1, 91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 
                                            126109.3, 129183.6, 131524, 140057.8, 150278.2, 152545.2), 
                             gdp_p62 = c(28063.4, 30507, 33520.2, 36089.5, 39104, 43056.8, 
                                         38781.9, 39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 
                                         55885.5, 59584.7, 64333.5, 68409.7), 
                             turnover_manu_dom = c(80, 
                                                   87, 93.2, 99.9, 104.9, 113.6, 97, 100, 110, 112.1, 110.7, 
                                                   107.1, 104.7, 102.9, 107.9, 107.9, 107.9), 
                             turnover_manu_non_dom = c(70.9, 
                                                       81.9, 86.2, 95.3, 102.8, 106.5, 86.8, 100, 112.8, 111.7, 
                                                       112.8, 114.9, 118.2, 120.1, 129.2, 129.2, 129.2), 
                             turnover_manu_tot = c(74.7, 
                                                   84, 89.1, 97.2, 103.7, 109.4, 91.1, 100, 111.6, 111.9, 111.9, 
                                                   111.7, 112.6, 112.9, 120.3, 120.3, 120.3), 
                             price_index = c(1.7, 
                                             2, 2.1, 1.7, 2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 0.8, 
                                             1, 2.2, 2.1, 1.5), 
                             capital_n1132g = c(3638.4, 3633.3, 3616.2, 
                                                3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 4005.6, 3718.6, 
                                                3467.9, 4214.2, 4237.4, 4450.2, 4598.6, 4721), 
                             capital_n117g = c(8369.6, 
                                               8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 11554.3, 11849.9, 
                                               13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 17647.1, 18273.8, 
                                               19642.4, 20713.1), capital_n11mg = c(18749.6, 19433.5, 20051.6, 
                                                                                    20569.8, 22646.1, 23674.5, 21200.6, 20919.6, 23157.7, 23520.7, 
                                                                                    24057.7, 23832.8, 25019.2, 27608.2, 29790.1, 30998, 32856.8
                                               ), 
                             lagged_gbaord = c(1537.89, 1537.89, 1537.89, 1619.74, 
                                               1697.55, 1770.144, 1986.775, 2149.787, 2269.986, 2428.143, 
                                               2452.955, 2587.586, 2647.489, 2744.844, 2875.706, 2889.779, 
                                               2913.369), 
                             lagged_employment_be = c(2775.22, 2775.22, 2731.08, 
                                                      2709.59, 2708.39, 2774.06, 2795.6, 2703.36, 2668.1, 2705.1, 
                                                      2731.67, 2727.16, 2725.66, 2735.69, 2750.52, 2782.9, 2852.33
                             ), lagged_employment_c = c(2562.53, 2562.53, 2518.57, 2496.98, 
                                                        2499.54, 2558.88, 2578, 2483.97, 2447.65, 2483.1, 2507.41, 
                                                        2500.94, 2499.6, 2511.75, 2523.97, 2555.48, 2622.5), 
                             lagged_employment_j = c(400.75, 
                                                     400.75, 387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 392.18, 
                                                     410.85, 419.75, 427.59, 438.96, 440.33, 460.84, 473.4, 494.4
                             ), 
                             lagged_employment_k = c(502.42, 502.42, 504.63, 515.39, 
                                                     523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 534.98, 524.13, 
                                                     518.89, 511.57, 505.32, 496.41, 495.4), 
                             lagged_employment_mn = c(1248.01, 
                                                      1248.01, 1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 1704.65, 
                                                      1762.55, 1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 2021.51, 
                                                      2109.71, 2189.27), 
                             lagged_employment_oq = c(3241.36, 3241.36, 
                                                      3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 3759.23, 
                                                      3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 4171.72, 
                                                      4238.87), 
                             lagged_employment_total = c(15113.52, 15113.52, 
                                                         15307.28, 15491.61, 15762.92, 16050.92, 16356.53, 16269.97, 
                                                         16392.87, 16647.79, 16820.66, 16879.06, 17039.6, 17142.13, 
                                                         17365.32, 17650.21, 17951.61), 
                             lagged_value_be = c(47967.1, 
                                                 47967.1, 50737.6, 52955.2, 56872.4, 60864.9, 61029, 56837.8, 
                                                 58433.6, 61443, 63655.1, 64132.3, 65542.6, 67495.4, 71152.6, 
                                                 72698.8, 75792.3), 
                             lagged_value_c = c(40192.9, 40192.9, 42014.6, 
                                                44229, 47735.5, 51552.4, 51165.9, 47129.7, 48759.3, 51467.7, 
                                                53234.6, 53431.4, 55169, 57458.7, 60962.8, 62196, 65063.5
                             ), 
                             lagged_value_j = c(7737.1, 7737.1, 7756.1, 8134.2, 8378.8, 
                                                8532.3, 8740, 8493.9, 8518.9, 9217.1, 9405.1, 9802.1, 10361.4, 
                                                10695.4, 11455.3, 11720.6, 12871), 
                             lagged_value_k = c(10225.2, 
                                                10225.2, 10541.9, 11005.3, 11912.3, 13102.7, 13205.2, 12123.9, 
                                                12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 13482.9, 13236.4, 
                                                13744.1, 14152.6), 
                             lagged_value_mn = c(15074, 15074, 16569.1, 
                                                 18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2, 24895.4, 
                                                 25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6, 33781.9
                             ), 
                             lagged_value_oq = c(35065.6, 35065.6, 37329.6, 38288.8, 
                                                 40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 49381.5, 
                                                 50261.7, 51624.3, 53715, 55926.4, 57637.1, 59648.9), 
                             lagged_value_total = c(202353.5, 
                                                    202353.5, 216098.3, 225888.1, 239076, 253604.6, 262414.7, 
                                                    256671, 263633.5, 276404, 283548.2, 288624.3, 297230.1, 307037.7, 
                                                    318952.7, 329396.1, 344338.6), 
                             lagged_gdp_b1gq = c(226735.3, 
                                                 226735.3, 242348.3, 254075, 267824.4, 283978, 293761.9, 288044.1, 
                                                 295896.6, 310128.6, 318653.1, 323910.2, 333146.1, 344269.3, 
                                                 357608, 369341.3, 385361.9), 
                             lagged_gdp_p3 = c(164107.8, 
                                               164107.8, 176316.4, 185871.1, 194102, 200944.4, 208857.1, 
                                               213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 243860.6, 
                                               249404.3, 257166.5, 265900.2, 274583.7), 
                             lagged_gdp_p61 = c(74691.6, 
                                                74691.6, 83074.9, 90010.4, 100076.8, 110157.2, 113368.1, 
                                                91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 126109.3, 
                                                129183.6, 131524, 140057.8, 150278.2), 
                             lagged_gdp_p62 = c(28063.4, 
                                                28063.4, 30507, 33520.2, 36089.5, 39104, 43056.8, 38781.9, 
                                                39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 55885.5, 
                                                59584.7, 64333.5), 
                             lagged_turnover_manu_dom = c(80, 80, 87, 
                                                          93.2, 99.9, 104.9, 113.6, 97, 100, 110, 112.1, 110.7, 107.1, 
                                                          104.7, 102.9, 107.9, 107.9), 
                             lagged_turnover_manu_non_dom = c(70.9, 
                                                              70.9, 81.9, 86.2, 95.3, 102.8, 106.5, 86.8, 100, 112.8, 111.7, 
                                                              112.8, 114.9, 118.2, 120.1, 129.2, 129.2), 
                             lagged_turnover_manu_tot = c(74.7, 
                                                          74.7, 84, 89.1, 97.2, 103.7, 109.4, 91.1, 100, 111.6, 111.9, 
                                                          111.9, 111.7, 112.6, 112.9, 120.3, 120.3), 
                             lagged_price_index = c(1.7, 
                                                    1.7, 2, 2.1, 1.7, 2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 
                                                    0.8, 1, 2.2, 2.1), lagged_capital_n1132g = c(3638.4, 3638.4, 
                                                                                                 3633.3, 3616.2, 3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 
                                                                                                 4005.6, 3718.6, 3467.9, 4214.2, 4237.4, 4450.2, 4598.6), 
                             lagged_capital_n117g = c(8369.6, 8369.6, 8679.9, 8938.9, 
                                                      9823.8, 10467.1, 11047.1, 11554.3, 11849.9, 13465.4, 13927.5, 
                                                      15510.2, 15754.4, 16584.7, 17647.1, 18273.8, 19642.4), 
                             lagged_capital_n11mg = c(18749.6, 
                                                      18749.6, 19433.5, 20051.6, 20569.8, 22646.1, 23674.5, 21200.6, 
                                                      20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 25019.2, 27608.2, 
                                                      29790.1, 30998), country = c("AT", "AT", "AT", "AT", "AT", 
                                                                                   "AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT", "AT", 
                                                                                   "AT", "AT")), 
                        row.names = c(NA, -17L), class = c("tbl_df", 
                                                           "tbl", "data.frame"))
    
    
    
    
    # Time-ordered hold-out split: the last `time_back` rows of `df_ini`
    # become the assessment (test) set and all earlier rows the analysis
    # (training) set.
    set.seed(1234) # seed the RNG so the resampling below is reproducible
    
    nn <- nrow(df_ini)
    
    # Number of most recent rows held out for assessment.
    time_back <- 1
    
    # seq_len()/tail() are used instead of `1:(nn - time_back)` and
    # `(nn - time_back + 1):nn`: the colon forms count backwards when one
    # side of the split is empty (e.g. `(nn + 1):nn`), which would silently
    # select wrong or out-of-range rows.
    indices <- list(
      analysis   = seq_len(nn - time_back),
      assessment = tail(seq_len(nn), time_back)
    )
    
    df_split <- make_splits(indices, df_ini)
    
    df_train <- training(df_split)
    df_test <- testing(df_split)
    
    # 3-fold cross-validation on the training rows.
    folded_data <- vfold_cv(df_train, v = 3)
    
    # Preprocessing: model `berd` on every other column, give `year` the
    # "ID" role so it is not used as a predictor, and apply step_zv() to
    # the predictors.
    cubist_recipe <- recipe(berd ~ ., data = df_train)
    ## (disabled) step_string2factor(one_of("country"))
    cubist_recipe <- update_role(cubist_recipe, year, new_role = "ID")
    cubist_recipe <- step_zv(cubist_recipe, all_predictors())
    
    # Cubist rule-based model; both hyperparameters are left as tune()
    # placeholders to be filled in by the grid search.
    cubist_spec <- set_engine(
      cubist_rules(committees = tune(), neighbors = tune()),
      "Cubist"
    )
    
    # Bundle preprocessing and model into a single workflow.
    cubist_workflow <- add_recipe(workflow(), cubist_recipe)
    cubist_workflow <- add_model(cubist_workflow, cubist_spec)
    
    # Candidate grid: committees in 1..9 and 10, 20, ..., 50, crossed with
    # four neighbor settings.
    cubist_grid <- tidyr::crossing(
      committees = c(1:9, seq(10, 50, by = 10)),
      neighbors = c(0, 3, 6, 9)
    )
    
    # Evaluate every grid point on the cross-validation folds.
    cubist_tune <- cubist_workflow %>%
      tune_grid(resamples = folded_data, grid = cubist_grid)
    #> 
    #> Attaching package: 'rlang'
    #> The following objects are masked from 'package:purrr':
    #> 
    #>     %@%, as_function, flatten, flatten_chr, flatten_dbl, flatten_int,
    #>     flatten_lgl, flatten_raw, invoke, list_along, modify, prepend,
    #>     splice
    #> 
    #> Attaching package: 'vctrs'
    #> The following object is masked from 'package:tibble':
    #> 
    #>     data_frame
    #> The following object is masked from 'package:dplyr':
    #> 
    #>     data_frame
    #> Loading required package: lattice
    
    # Pick the tuning-parameter combination with the best cross-validated
    # RMSE. `metric` is passed by name: newer tune releases place `metric`
    # after `...` in select_best(), so a positional "rmse" is no longer
    # matched to it.
    best_cub <- select_best(cubist_tune, metric = "rmse")
    
    # Replace the tune() placeholders in the workflow with the selected
    # values.
    final_cub <- finalize_workflow(cubist_workflow, best_cub)
    
    # Fit the finalized workflow on the training rows.
    fit_model <- final_cub %>%
      fit(df_train)
    
    # Reach through the fitted workflow to the underlying engine fit (the
    # object produced by cubist.default(), per the printed Call below) and
    # print its summary, which lists the rules and their linear models.
    # NOTE(review): newer workflows releases appear to offer
    # extract_fit_engine() for this; confirm against the installed version.
    summary(fit_model$fit$fit$fit)
    #> 
    #> Call:
    #> cubist.default(x = x, y = y, committees = 1)
    #> 
    #> 
    #> Cubist [Release 2.07 GPL Edition]  Thu Dec 10 16:52:59 2020
    #> ---------------------------------
    #> 
    #>     Target attribute `outcome'
    #> 
    #> Read 16 cases (53 attributes) from undefined.data
    #> 
    #> Model:
    #> 
    #>   Rule 1: [16 cases, mean 5877.817, range 3130.884 to 8461.72, est err 251.023]
    #> 
    #>  outcome = -5043.087 + 0.0357 gdp_b1gq
    #> 
    #> 
    #> Evaluation on training data (16 cases):
    #> 
    #>     Average  |error|            196.045
    #>     Relative |error|               0.14
    #>     Correlation coefficient        0.99
    #> 
    #> 
    #>  Attribute usage:
    #>    Conds  Model
    #> 
    #>           100%    gdp_b1gq
    #> 
    #> 
    #> Time: 0.0 secs
    

    Created on 2020-12-10 by the reprex package (v0.3.0.9001)

    If you want the coefficients in a form you can work with programmatically, look at the result of as_tibble(fit_model$fit$fit$fit$coefficients).