library(tidymodels)

Train %>%
  nest(data = -Groups) %>%                                  # one row per group
  mutate(fit       = map(data, ~ lm(X ~ Y + Z, data = .x)),  # one lm() per group
         augmented = map(fit, augment)) %>%                  # fitted values, residuals, ...
  unnest(augmented) %>%
  select(-data)

This works perfectly with the Train data. I can get fitted values, model summaries, etc. by using different broom functions like glance() or augment(). And each group has its own model, the way I wanted.

The challenge is when I want to use this model on the test data.

It seems straightforward, but somehow the solution eludes me :(

  • I suppose what you mean by "use this model on the test data" is you wanted to make predictions using the test data. You can try `predict(fit, new_data = test_df)`. – nyk Jan 25 '21 at 02:01
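That suggestion works on one fitted model at a time. As a minimal sketch, using the built-in `Orange` data as an assumed stand-in for the asker's `Train`/`Test`, predicting from a single `lm` object takes the new rows through the `newdata` argument of `predict.lm()`; the `new_data` spelling is the convention for parsnip/tidymodels model objects. The names `tree1`, `fit1`, `new_obs`, and `preds` are illustrative only.

```r
# A single model for one tree (Orange stands in for the asker's data)
tree1 <- subset(Orange, Tree == "1")
fit1  <- lm(age ~ circumference, data = tree1)

# New observations must contain the predictor column(s) used in the formula
new_obs <- data.frame(circumference = c(50, 100, 150))

# predict.lm() receives the new rows via `newdata`
preds <- predict(fit1, newdata = new_obs)
preds
```

With a nested fit, `fit` is a list-column of such models, so this call has to be mapped over it rather than applied to the column directly.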

1 Answer


When you fit to nested data like that, you end up with many models, not just one, so you will also need to set yourself up to predict with many models.

library(tidyverse)
library(broom)

data(Orange)

Orange <- as_tibble(Orange)

orange_fit <- Orange %>% 
  nest(data = c(-Tree)) %>%    ## this sets up five separate models
  mutate(
    fit = map(data, ~ lm(age ~ circumference, data = .x))
  ) 

## the "test data" here is `circumference = c(50, 100, 150)`
orange_fit %>%
  select(Tree, fit) %>%
  crossing(circumference = c(50, 100, 150)) %>%
  mutate(new_data = map(circumference, ~tibble(circumference = .)),
         predicted_age = map2_dbl(fit, new_data, predict))
#> # A tibble: 15 x 5
#>    Tree  fit    circumference new_data         predicted_age
#>    <ord> <list>         <dbl> <list>                   <dbl>
#>  1 3     <lm>              50 <tibble [1 × 1]>          392.
#>  2 3     <lm>             100 <tibble [1 × 1]>          994.
#>  3 3     <lm>             150 <tibble [1 × 1]>         1596.
#>  4 1     <lm>              50 <tibble [1 × 1]>          331.
#>  5 1     <lm>             100 <tibble [1 × 1]>          927.
#>  6 1     <lm>             150 <tibble [1 × 1]>         1523.
#>  7 5     <lm>              50 <tibble [1 × 1]>          385.
#>  8 5     <lm>             100 <tibble [1 × 1]>          824.
#>  9 5     <lm>             150 <tibble [1 × 1]>         1264.
#> 10 2     <lm>              50 <tibble [1 × 1]>          257.
#> 11 2     <lm>             100 <tibble [1 × 1]>          647.
#> 12 2     <lm>             150 <tibble [1 × 1]>         1037.
#> 13 4     <lm>              50 <tibble [1 × 1]>          282.
#> 14 4     <lm>             100 <tibble [1 × 1]>          640.
#> 15 4     <lm>             150 <tibble [1 × 1]>          999.

Created on 2021-01-25 by the reprex package (v0.3.0)

Notice that at the end we have a prediction for each of the 3 points in the test set from each of the 5 models, 15 rows in total.
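Crossing every model with every test point suits a small grid of new values. If, as in the question, the test set carries the same grouping column as the training set and each group should be scored only by its own model, one alternative (not part of the answer above) is to nest the test rows by group, join them to the fitted models, and call broom's `augment()` with `newdata`. The sketch below assumes a hypothetical split of `Orange` (first five measurements of each tree for training, last two for testing) as a stand-in for `Train`/`Test`; `orange`, `train`, `test`, `fits`, and `test_preds` are illustrative names.

```r
library(dplyr)
library(tidyr)
library(purrr)
library(broom)

# Hypothetical split: first 5 rows of each tree train, the other 2 test
orange <- as_tibble(Orange) %>%
  group_by(Tree) %>%
  mutate(set = if_else(row_number() <= 5, "train", "test")) %>%
  ungroup()

train <- filter(orange, set == "train")
test  <- filter(orange, set == "test")

# One lm() per tree, fit on the training rows only
fits <- train %>%
  nest(data = -Tree) %>%
  mutate(fit = map(data, ~ lm(age ~ circumference, data = .x))) %>%
  select(Tree, fit)

# Nest the test rows by the same grouping column, pair each group's
# test data with that group's model, then score it via `newdata`
test_preds <- test %>%
  nest(test_data = -Tree) %>%
  inner_join(fits, by = "Tree") %>%
  mutate(pred = map2(fit, test_data, ~ augment(.x, newdata = .y))) %>%
  select(Tree, pred) %>%
  unnest(pred)

test_preds
```

With the question's data, `Tree` would become `Groups` and the formula `X ~ Y + Z`. Note that `inner_join()` silently drops test groups that have no fitted model, so it may be worth checking for such groups first.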

Julia Silge