r, rpart, tidymodels, rattle

Node link diagram in R using Rpart.plot and rattle


I am trying to create a node-link diagram (decision tree) using parsnip and tidymodels. I am building a decision tree model for the StackOverflow dataset using the tidymodels package with rpart as the model engine. The model should predict whether a developer will work remotely (variable remote) based on the number of years of programming experience (years_coded_job), degree of career satisfaction (career_satisfaction), whether the job title is "Data Scientist" (data_scientist, yes/no), and the size of the employing company (company_size_number).

My pipeline

library(tidyverse)
library(tidymodels)
library(rpart.plot)
library(rpart)
library(rattle)

so <- read_rds(here::here("stackoverflow.rds"))

fit <- rpart(remote ~ years_coded_job + career_satisfaction + data_scientist + company_size_number,
             data = so,
             control = rpart.control(minsplit = 20, minbucket = 2))

fancyRpartPlot(fit, sub = "")

The plot I obtain

[Image: the resulting decision tree plot]

I would like to know whether this is the correct approach for determining the predictors. Since I am not building a model, is this the right way?


Solution

  • If you are going to use tidymodels and parsnip to fit your model, it's better to use that actual fitted model for any visualizations like this. You can get the underlying engine object from a parsnip model using $fit.

    library(tidyverse)
    library(tidymodels)
    library(rattle)
    #> Loading required package: bitops
    #> Rattle: A free graphical interface for data science with R.
    #> Version 5.4.0 Copyright (c) 2006-2020 Togaware Pty Ltd.
    #> Type 'rattle()' to shake, rattle, and roll your data.
    data(kyphosis, package = "rpart")
    
    tree_fit <- decision_tree(min_n = 20) %>%
      set_engine("rpart") %>%
      set_mode("classification") %>%
      fit(Kyphosis ~ Age + Number + Start,
          data = kyphosis)
    
    fancyRpartPlot(tree_fit$fit, sub = "")
    

    Created on 2021-05-25 by the reprex package (v2.0.0)

    For some kinds of visualizations, you will need to use repair_call().
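
    As a rough sketch of that last point, here is the same kyphosis tree passed through repair_call() before plotting. The choice of rpart.plot() as the plotting function is an assumption for illustration; it is one function that may warn that it cannot retrieve the data used to build the model when given an unrepaired parsnip fit. The name tree_fit_fixed is made up for this example.

```r
library(tidymodels)
library(rpart.plot)

data(kyphosis, package = "rpart")

# Same model as above, fitted through parsnip with the rpart engine.
tree_fit <- decision_tree(min_n = 20) %>%
  set_engine("rpart") %>%
  set_mode("classification") %>%
  fit(Kyphosis ~ Age + Number + Start,
      data = kyphosis)

# parsnip stores a call that refers to an internal placeholder for the data.
# repair_call() rewrites that stored call so it points at the real data object,
# which lets plotting functions look the training data up again.
tree_fit_fixed <- repair_call(tree_fit, data = kyphosis)

# Plot the underlying rpart object from the repaired fit.
rpart.plot(tree_fit_fixed$fit)
```

    You can compare tree_fit$fit$call with tree_fit_fixed$fit$call to see what changed; only the stored call is modified, not the fitted tree itself.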