
The code below works correctly and has no errors that I know of, but I want to add more to it.

The two things I want to add are:

1 - Add the model's predictions on the training data to the final plot. I want to run collect_predictions() on the model fitted to the training data.

2 - View the model's metrics on the training data. I want to run collect_metrics() on the model fitted to the training data.

How do I get this information?

# Setup
library(tidyverse)
library(tidymodels)

parks <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-06-22/parks.csv')

modeling_df <- parks %>% 
  select(pct_near_park_data, spend_per_resident_data, med_park_size_data) %>% 
  rename(nearness = "pct_near_park_data",
         spending = "spend_per_resident_data",
         acres = "med_park_size_data") %>% 
  mutate(nearness = (parse_number(nearness)/100)) %>% 
  mutate(spending = parse_number(spending))

# Start building models
set.seed(123)
park_split <- initial_split(modeling_df)
park_train <- training(park_split)
park_test <- testing(park_split)

tree_rec <- recipe(nearness ~., data = park_train)
tree_prep <- prep(tree_rec)
juiced <- juice(tree_prep)

tune_spec <- rand_forest(
  mtry = tune(),
  trees = 1000,
  min_n = tune()
) %>% 
  set_mode("regression") %>% 
  set_engine("ranger")

tune_wf <- workflow() %>% 
  add_recipe(tree_rec) %>% 
  add_model(tune_spec)

set.seed(234)
park_folds <- vfold_cv(park_train)

# Make a grid of various different models
doParallel::registerDoParallel()

set.seed(345)
tune_res <- tune_grid(
  tune_wf,
  resamples = park_folds,
  grid = 20,
  control = control_grid(verbose = TRUE)
)

best_rmse <- select_best(tune_res, metric = "rmse")

# Finalize a model with the best grid
final_rf <- finalize_model(
  tune_spec,
  best_rmse
)

final_wf <- workflow() %>% 
  add_recipe(tree_rec) %>% 
  add_model(final_rf)

final_res <- final_wf %>% 
  last_fit(park_split)

# Visualize the performance
# My issue here is that this only shows the test data.
# How can I also get this model's performance on the training data?
# I want to plot both (with facet_wrap() or a color indication) and compare them numerically with collect_metrics().

final_res %>% 
  collect_predictions() %>% 
  ggplot(aes(nearness, .pred)) +
    geom_point() +
    geom_abline()
Indescribled
  • https://tune.tidymodels.org/reference/control_grid.html Add `save_pred=TRUE` to control_grid to save predictions. And to collect metrics on training model, do collect_metrics(tune_res) – Desmond Jun 25 '21 at 04:19
  • @Desmond no, this is for test predictions. The question is for train – Sergey Skripko Feb 10 '23 at 21:19

1 Answer


You can pull the fitted workflow out of final_res and use it to create predictions on the training data set.

final_model <- final_res$.workflow[[1]]
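Recent releases of tune also provide `extract_workflow()` for `last_fit()` results, which returns the same fitted workflow without indexing into the `.workflow` list column. A minimal sketch, assuming tidymodels is installed; `mtcars` and a plain `lm` model are hypothetical stand-ins for the parks data and the tuned random forest, so it runs without a download or the ranger package:

```r
library(tidymodels)

# Hypothetical stand-ins: mtcars replaces the parks data and a linear model
# replaces the tuned random forest, so the sketch is self-contained
set.seed(123)
car_split <- initial_split(mtcars)
car_train <- training(car_split)

car_wf <- workflow() %>%
  add_formula(mpg ~ .) %>%
  add_model(linear_reg() %>% set_engine("lm"))

# last_fit() fits on the training portion and evaluates on the test portion
car_res <- last_fit(car_wf, car_split)

# Returns the fitted workflow, like car_res$.workflow[[1]]
fitted_wf <- extract_workflow(car_res)

# Resubstitution predictions: the model predicting the rows it was trained on
train_preds <- augment(fitted_wf, new_data = car_train)
head(train_preds[, c("mpg", ".pred")])
```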

Now you can use augment() on both the test and training data sets to visualize the performance.

final_model %>% 
  augment(new_data = park_test) %>%
  ggplot(aes(nearness, .pred)) +
  geom_point() +
  geom_abline()

final_model %>% 
  augment(new_data = park_train) %>%
  ggplot(aes(nearness, .pred)) +
  geom_point() +
  geom_abline()

You can also combine the results with bind_rows() so you can compare more easily.

all_predictions <- bind_rows(
  augment(final_model, new_data = park_train) %>% 
    mutate(type = "train"),
  augment(final_model, new_data = park_test) %>% 
    mutate(type = "test")
)

all_predictions %>%
  ggplot(aes(nearness, .pred)) +
  geom_point() +
  geom_abline() +
  facet_wrap(~type)
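The question also mentioned a color indication as an alternative to faceting. Here is a sketch of a single-panel version, assuming tidymodels is installed and using `mtcars` and `lm` as hypothetical stand-ins for the parks data and the random forest. `coord_obs_pred()` is a tune helper that puts observed and predicted values on equal axes, which makes the distance from the diagonal easier to read:

```r
library(tidymodels)

# Hypothetical stand-in data and model (mtcars + lm) so the sketch runs on its own
set.seed(123)
car_split <- initial_split(mtcars)
car_wf <- workflow() %>%
  add_formula(mpg ~ .) %>%
  add_model(linear_reg() %>% set_engine("lm"))
fitted_wf <- extract_workflow(last_fit(car_wf, car_split))

# Stack training and test predictions, labelled by which set they came from
all_preds <- bind_rows(
  augment(fitted_wf, new_data = training(car_split)) %>% mutate(type = "train"),
  augment(fitted_wf, new_data = testing(car_split)) %>% mutate(type = "test")
)

# One panel, with color distinguishing training from test predictions
p <- all_preds %>%
  ggplot(aes(mpg, .pred, color = type)) +
  geom_point(alpha = 0.7) +
  geom_abline(linetype = "dashed") +
  coord_obs_pred()
p
```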

All the yardstick metric functions work on grouped data frames as well, so you can get the metrics for both sets in one call.


all_predictions %>%
  group_by(type) %>%
  metrics(nearness, .pred)
#> # A tibble: 6 x 4
#>   type  .metric .estimator .estimate
#>   <chr> <chr>   <chr>          <dbl>
#> 1 test  rmse    standard      0.0985
#> 2 train rmse    standard      0.0473
#> 3 test  rsq     standard      0.725 
#> 4 train rsq     standard      0.943 
#> 5 test  mae     standard      0.0706
#> 6 train mae     standard      0.0350

Created on 2021-06-24 by the reprex package (v2.0.0)
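To choose which metrics are reported, yardstick's `metric_set()` bundles several metric functions into one function that also respects grouping. A sketch assuming tidymodels is installed, with `mtcars` and `lm` as hypothetical stand-ins for the parks data and the random forest. It also calls `collect_metrics()` on the `last_fit()` result, whose rmse and rsq should agree with the test rows. The training rows are resubstitution estimates, so expect them to look optimistic, especially for a flexible model like a random forest:

```r
library(tidymodels)

# Hypothetical stand-ins: mtcars and a linear model replace the parks data and the forest
set.seed(123)
car_split <- initial_split(mtcars)
car_wf <- workflow() %>%
  add_formula(mpg ~ .) %>%
  add_model(linear_reg() %>% set_engine("lm"))
car_res <- last_fit(car_wf, car_split)
fitted_wf <- extract_workflow(car_res)

all_preds <- bind_rows(
  augment(fitted_wf, new_data = training(car_split)) %>% mutate(type = "train"),
  augment(fitted_wf, new_data = testing(car_split)) %>% mutate(type = "test")
)

# metric_set() combines yardstick metrics into a single function
reg_metrics <- metric_set(rmse, rsq, mae)

# One row per set and metric
split_metrics <- all_preds %>%
  group_by(type) %>%
  reg_metrics(truth = mpg, estimate = .pred)
split_metrics

# Test-set metrics as reported by last_fit(); compare with the "test" rows above
collect_metrics(car_res)
```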

EmilHvitfeldt