
I would like to compare, using tidymodels and cross-validation, three linear regression models that can be specified as follows:

  • (model_A) y ~ a
  • (model_B) y ~ b
  • (model_AB) y ~ a + b

In the following y will denote the target variable, while a and b will denote independent variables.

Without using cross-validation, it is (I hope) quite clear to me what I have to do:

  1. Split my data into train and test sets
set.seed(1234)
split <- data %>% initial_split(strata = y)
data_train <- training(split)
data_test <- testing(split)
  2. I can specify, fit, and evaluate my model in one go (for example for model_AB)
linear_reg() %>%
    set_engine("lm") %>%
    fit(y ~ a + b, data = data_train) %>%
    augment(new_data = data_test) %>%
    rmse(truth = y, estimate = .pred)

The output looks something like this:

# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard       x.xxx

I can repeat step 2 for the other two models and compare the three models based on RMSE (the metric chosen for this example).
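A compact sketch of that repetition (purely illustrative, assuming purrr is available via the tidyverse) could look like this:

formulas <- list(model_A = y ~ a, model_B = y ~ b, model_AB = y ~ a + b)

# Fit each formula on the training set, predict on the test set,
# and collect the RMSEs in a single tibble (one row per model).
map_dfr(formulas, function(f) {
  linear_reg() %>%
    set_engine("lm") %>%
    fit(f, data = data_train) %>%
    augment(new_data = data_test) %>%
    rmse(truth = y, estimate = .pred)
}, .id = "model")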

For example, I can create a dummy dataset and run the steps described above.

library(tidyverse)
library(tidymodels)

set.seed(1234)
n <- 1e4
data <- tibble(a = rnorm(n),
               b = rnorm(n),
               y = 1 + 3*a - 2*b + rnorm(n))

set.seed(1234)
split <- data %>% initial_split(strata = y)
data_train <- training(split)
data_test <- testing(split)
  • Model_A
linear_reg() %>%
    set_engine("lm") %>%
    fit(y ~ a, data = data_train) %>%
    augment(new_data = data_test) %>%
    rmse(truth = y, estimate = .pred)

Result:

# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard        2.23
  • Model_B
linear_reg() %>%
    set_engine("lm") %>%
    fit(y ~ b, data = data_train) %>%
    augment(new_data = data_test) %>%
    rmse(truth = y, estimate = .pred)

Result:

# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard        3.17
  • Model_AB
linear_reg() %>%
    set_engine("lm") %>%
    fit(y ~ a + b, data = data_train) %>%
    augment(new_data = data_test) %>%
    rmse(truth = y, estimate = .pred)

Result:

# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard        1.00

My question is: how can I evaluate the RMSE after performing cross-validation on three models that differ in the set of features they use?
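For a single, fixed formula, the cross-validated RMSE can be obtained along these lines (a sketch using vfold_cv() and fit_resamples(); the folds object and its settings are just for illustration). What I am missing is how to do this for several models with different sets of predictors:

# 10-fold cross-validation on the training set
set.seed(1234)
folds <- vfold_cv(data_train, v = 10, strata = y)

linear_reg() %>%
  set_engine("lm") %>%
  fit_resamples(y ~ a + b, resamples = folds, metrics = metric_set(rmse)) %>%
  collect_metrics()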

In this video, Julia Silge compares three different models (logistic regression, kNN, and decision trees) that all use the same set of predictors. However, what I aim to do is compare models that differ in their sets of predictors.

Any suggestion and/or reference?

filebb
  • Greetings! Usually it is helpful to provide a minimally reproducible dataset for questions here. One way of doing this is by using the `dput` function. You can find out how to use it here: https://youtu.be/3EID3P1oisg – Shawn Hemelstrand Sep 11 '22 at 10:36
  • @ShawnHemelstrand Thanks for your suggestion, dataset included. – filebb Sep 11 '22 at 12:26

1 Answer


When you have a lot of different models you want to compare, one way to deal with that is to use the workflowsets package.

This way, you can specify any number of models and preprocessors, and it will run all of them and give you back the results in a tidy format.

Notice how recipe() is used here just to denote which variables are used in each model.

Additionally, you can pass a metric_set() to the metrics argument of workflow_map() if you want to use metrics other than the defaults.

library(tidymodels)
set.seed(1234)
n <- 1e4
data <- tibble(a = rnorm(n),
               b = rnorm(n),
               y = 1 + 3*a - 2*b + rnorm(n))

set.seed(1234)
split <- data %>% initial_split(strata = y)
data_train <- training(split)
data_test <- testing(split)

lm_spec <- linear_reg()

rec_a <- recipe(y ~ a, data = data_train)
rec_b <- recipe(y ~ b, data = data_train)
rec_ab <- recipe(y ~ a + b, data = data_train)

all_models_wfs <- workflow_set(
  preproc = list(a = rec_a, b = rec_b, c = rec_ab),
  models = list(lm = lm_spec),
  cross = TRUE
)

all_models_wfs
#> # A workflow set/tibble: 3 × 4
#>   wflow_id info             option    result    
#>   <chr>    <list>           <list>    <list>    
#> 1 a_lm     <tibble [1 × 4]> <opts[0]> <list [0]>
#> 2 b_lm     <tibble [1 × 4]> <opts[0]> <list [0]>
#> 3 c_lm     <tibble [1 × 4]> <opts[0]> <list [0]>

all_models_fit <- workflow_map(
  all_models_wfs, 
  resamples = vfold_cv(data_train),
  metrics = metric_set(rmse, rsq, mape)
)

all_models_fit %>%
  collect_metrics()
#> # A tibble: 9 × 9
#>   wflow_id .config           preproc model .metric .esti…¹    mean     n std_err
#>   <chr>    <chr>             <chr>   <chr> <chr>   <chr>     <dbl> <int>   <dbl>
#> 1 a_lm     Preprocessor1_Mo… recipe  line… mape    standa… 261.       10 3.99e+1
#> 2 a_lm     Preprocessor1_Mo… recipe  line… rmse    standa…   2.26     10 2.89e-2
#> 3 a_lm     Preprocessor1_Mo… recipe  line… rsq     standa…   0.627    10 7.72e-3
#> 4 b_lm     Preprocessor1_Mo… recipe  line… mape    standa… 258.       10 2.07e+1
#> 5 b_lm     Preprocessor1_Mo… recipe  line… rmse    standa…   3.10     10 2.13e-2
#> 6 b_lm     Preprocessor1_Mo… recipe  line… rsq     standa…   0.298    10 7.61e-3
#> 7 c_lm     Preprocessor1_Mo… recipe  line… mape    standa… 144.       10 3.66e+1
#> 8 c_lm     Preprocessor1_Mo… recipe  line… rmse    standa…   1.01     10 6.51e-3
#> 9 c_lm     Preprocessor1_Mo… recipe  line… rsq     standa…   0.926    10 2.06e-3
#> # … with abbreviated variable name ¹​.estimator

Created on 2022-09-19 by the reprex package (v2.0.1)
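If you want to compare the workflows more directly, workflowsets also provides helpers for that. A small sketch building on the fitted object above:

# Rank the workflows by their cross-validated RMSE (lower is better).
rank_results(all_models_fit, rank_metric = "rmse")

# Or plot all collected metrics, one point per workflow and metric.
autoplot(all_models_fit)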

EmilHvitfeldt
  • Thanks! One question about your code: shouldn't you have specified the engine with the `set_engine` function? – filebb Sep 12 '22 at 23:18
  • That would have been better style-wise. Since we were using the `lm` engine, I relied on the fact that `linear_reg()` defaults to the `lm` engine and didn't specify it directly – EmilHvitfeldt Sep 12 '22 at 23:51
  • All works perfectly. I just have a final side question about your approach that I will ask here: how can I change the metric used? – filebb Sep 19 '22 at 10:52
  • I have updated the example to show how to change metrics – EmilHvitfeldt Sep 19 '22 at 16:13