I am trying to plot a decision tree in R after fitting a model with a tidymodels workflow, but I can't figure out which function to use or which model object to pass to it. After code like this, how do you produce a plot?

xgboost_spec <- 
  boost_tree(trees = tune(), min_n = tune(), tree_depth = tune(), learn_rate = tune(), 
             loss_reduction = tune(), sample_size = tune()) %>% 
  set_mode("classification") %>% 
  set_engine("xgboost") 

xgboost_workflow <- 
  workflow() %>% 
  add_recipe(data_recipe) %>% 
  add_model(xgboost_spec) 

xgboost_tune <-
  tune_grid(xgboost_workflow, resamples = data_folds, grid = 10)

final_xgboost <- xgboost_workflow %>% 
  finalize_workflow(select_best(xgboost_tune, metric = "roc_auc"))

xgboost_results <- final_xgboost %>% 
  fit_resamples(
    resamples = data_folds,
    metrics = metric_set(roc_auc, accuracy, sensitivity, specificity),
    control = control_resamples(save_pred = TRUE)
  )

Or after decision tree code like this?

tree_spec <- decision_tree(
  cost_complexity = tune(),
  tree_depth = tune(),
  min_n = tune()) %>% 
  set_engine("rpart") %>% 
  set_mode("classification")

tree_workflow <- 
  workflow() %>% 
  add_recipe(data_recipe) %>% 
  add_model(tree_spec) 

tree_grid <- grid_regular(cost_complexity(),
                          tree_depth(),
                          min_n(), levels = 4)

tree_tune <- tree_workflow %>% 
  tune_grid(
    resamples = data_folds,
    grid = tree_grid,
    metrics = metric_set(roc_auc, accuracy, sensitivity, specificity)
  )

final_tree <- tree_workflow %>% 
  finalize_workflow(select_best(tree_tune, metric = "roc_auc"))

tree_results <- final_tree %>% 
  fit_resamples(
    resamples = data_folds,
    metrics = metric_set(roc_auc, accuracy, sensitivity, specificity),
    control = control_resamples(save_pred = TRUE)
  )

Is it possible? Or should I use the model after last_fit()?

Thank you!

RCchelsie

1 Answer

I don't think it makes much sense to plot an xgboost model, because it is an ensemble of boosted trees (lots and lots of trees), but you can plot a single decision tree.

The key is that most packages for visualizing tree results require you to repair the call object first, because the call that `fit()` stores on the underlying rpart object may not be usable by other functions.

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip

data(penguins)
penguins <- na.omit(penguins)

cart_spec <-
   decision_tree() %>%
   set_engine("rpart") %>%
   set_mode("classification")

cart_fit <- 
   cart_spec %>%
   fit(sex ~ species + bill_length_mm + body_mass_g, data = penguins)
# Point the stored call back at the original data so plotting packages can use it
cart_fit <- repair_call(cart_fit, data = penguins)

library(rattle)
#> Loading required package: bitops
#> Rattle: A free graphical interface for data science with R.
#> Version 5.4.0 Copyright (c) 2006-2020 Togaware Pty Ltd.
#> Type 'rattle()' to shake, rattle, and roll your data.
fancyRpartPlot(cart_fit$fit)

Created on 2021-08-07 by the reprex package (v2.0.0)

The rattle package isn't the only option out there; ggparty is another good one.

This does mean you must use a parsnip model plus a preprocessor rather than a workflow. You can see a tutorial on how to tune a parsnip model plus a preprocessor here.
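As a rough sketch of what tuning without a workflow might look like, here is the same penguins example with a formula as the preprocessor. The object names (`penguin_folds`, `tune_spec`, `penguin_formula`), the 5-fold resampling, and the grid size are assumptions made for illustration, not something from the question's data.

```r
library(tidymodels)

# Load the example data explicitly from modeldata
data(penguins, package = "modeldata")
penguins <- na.omit(penguins)

# Illustrative resamples (assumption: 5 folds, stratified by the outcome)
set.seed(123)
penguin_folds <- vfold_cv(penguins, v = 5, strata = sex)

# A parsnip model spec with tunable parameters, no workflow involved
tune_spec <-
  decision_tree(cost_complexity = tune(), tree_depth = tune()) %>%
  set_engine("rpart") %>%
  set_mode("classification")

# The formula acts as the preprocessor
penguin_formula <- sex ~ species + bill_length_mm + body_mass_g

# tune_grid() accepts a model spec plus a preprocessor directly
tree_tune <- tune_grid(
  tune_spec,
  preprocessor = penguin_formula,
  resamples = penguin_folds,
  grid = 5,
  metrics = metric_set(roc_auc)
)

best_params <- select_best(tree_tune, metric = "roc_auc")

# Finalize the model spec (not a workflow) and fit it on the full data
final_fit <-
  finalize_model(tune_spec, best_params) %>%
  fit(penguin_formula, data = penguins)

# Repair the call so that plotting packages can use the fitted rpart object
final_fit <- repair_call(final_fit, data = penguins)

# Requires the rattle package, as in the example above
rattle::fancyRpartPlot(final_fit$fit)
```

With a recipe instead of a formula you would pass the recipe as `preprocessor` for tuning, but the final fit would then need the data to be prepped and baked by hand, which is part of why this route is less convenient than a workflow.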

Julia Silge
  • Thanks @Julia! In that case, do you think it is not necessary to use `workflow()` at all? – RCchelsie Aug 10 '21 at 22:22
  • That's correct; don't use a `workflow()` in this situation but instead just a parsnip model plus a preprocessor. It will be a bit less convenient IMO but necessary to be able to use `repair_call()`. – Julia Silge Aug 11 '21 at 22:57
  • It works, thanks! I also found another solution here: [link](https://emilhvitfeldt.github.io/ISLR-tidymodels-labs/tree-based-methods.html). So I tried `final_tree <- tree_workflow %>% finalize_workflow(select_best(tree_tune, "roc_auc"))`, then `tree_final_fit <- fit(final_tree, data = train_data)`, then `tree_final_fit %>% extract_fit_engine() %>% rpart.plot()`, and that works too. I don't know which is best... – RCchelsie Aug 13 '21 at 17:50
  • Yes, the `rpart.plot()` should work like that, but if you want the fancier plots from rattle or ggparty I think you'll need to `repair_call()`. – Julia Silge Aug 13 '21 at 23:40
  • Understood! Thank you so much! – RCchelsie Aug 16 '21 at 17:06
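
For reference, the workflow-based route from the comments could be sketched roughly like this on the penguins data. This assumes the rpart.plot package is installed; `penguin_recipe` and the untuned `decision_tree()` spec are stand-ins for the finalized workflow in the question, and `roundint = FALSE` is there only to quiet the warning rpart.plot tends to raise when it cannot recover the training data from the fitted object.

```r
library(tidymodels)
library(rpart.plot)

data(penguins, package = "modeldata")
penguins <- na.omit(penguins)

# Stand-in recipe (assumption: no extra preprocessing steps)
penguin_recipe <- recipe(sex ~ species + bill_length_mm + body_mass_g,
                         data = penguins)

# Stand-in for the finalized workflow from the question
tree_workflow <-
  workflow() %>%
  add_recipe(penguin_recipe) %>%
  add_model(
    decision_tree() %>%
      set_engine("rpart") %>%
      set_mode("classification")
  )

# Fit the workflow, then pull out the underlying rpart object to plot it
tree_final_fit <- fit(tree_workflow, data = penguins)

tree_final_fit %>%
  extract_fit_engine() %>%
  rpart.plot(roundint = FALSE)
```

This keeps the convenience of `workflow()`, at the cost of being limited to plotting functions that don't need the repaired call.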