How can I implement a custom metric that relies on additional features in the training data? Below is an example for a within R-squared metric, with additional code to reproduce the core issue. The implementation uses an additional argument called `group` for the name of the grouping variable. This implementation works when I call it directly and pass the entire training data to the `data` argument with an additional column for `estimate`. More precisely, the tibble passed to the `data` argument needs three columns: `truth`, `estimate`, and `group`.
However, this approach does not work with any of the tuning functions, such as `tune_grid` or `workflow_map`, because these functions pass a `data.frame` with only three variables to any metric function: `.pred`, `.row`, and the observed values `truth`.
How can I implement a custom metric that requires additional columns from the training data?
# Custom metric --------------------------------------------------------------
#' Within R-squared for numeric vectors
#'
#' Demeans `truth` and `estimate` within each level of `group` and returns the
#' squared Pearson correlation between the demeaned vectors.
#'
#' @param truth Numeric vector of observed values.
#' @param estimate Numeric vector of predictions, same length as `truth`.
#' @param group Atomic vector of group labels, same length as `truth`.
#' @param na_rm If `TRUE`, observations with a missing `truth`, `estimate`, or
#'   `group` are dropped together. If `FALSE`, any missing value makes the
#'   result `NA_real_`.
#' @param ... Not used.
#' @return A single numeric value. `0` when the demeaned estimates show no
#'   variation, `NA_real_` when fewer than two observations are available.
rsq_within_vec <- function(truth, estimate, group, na_rm = TRUE, ...) {
  rsq_within_impl <- function(truth, estimate, group) {
    # With fewer than two observations sd() is NA and the correlation is
    # undefined, so report a missing value instead of failing in `if`.
    if (length(truth) < 2L) {
      return(NA_real_)
    }

    # ave() returns the group mean aligned with each observation.
    truth_within <- truth - stats::ave(truth, group)
    estimate_within <- estimate - stats::ave(estimate, group)

    spread <- stats::sd(estimate_within)
    if (is.na(spread)) {
      return(NA_real_)
    }

    # Demeaning a constant can leave floating-point noise, so compare the
    # spread against a tolerance relative to the scale of the estimates
    # rather than testing for exact equality with zero.
    tolerance <- sqrt(.Machine$double.eps) * max(abs(estimate))
    if (spread <= tolerance) {
      return(0)
    }

    stats::cor(truth_within, estimate_within)^2
  }

  if (!is.numeric(truth) || !is.numeric(estimate)) {
    stop("`truth` and `estimate` must be numeric.", call. = FALSE)
  }
  if (!is.atomic(group)) {
    stop("`group` must be an atomic vector.", call. = FALSE)
  }
  if (length(estimate) != length(truth) || length(group) != length(truth)) {
    stop(
      "`truth`, `estimate`, and `group` must all have the same length.",
      call. = FALSE
    )
  }

  truth <- as.numeric(truth)
  estimate <- as.numeric(estimate)

  # Missing values are removed from all three vectors at once so that the
  # group labels stay aligned with the remaining observations.
  incomplete <- is.na(truth) | is.na(estimate) | is.na(group)
  if (na_rm) {
    keep <- !incomplete
    truth <- truth[keep]
    estimate <- estimate[keep]
    group <- group[keep]
  } else if (any(incomplete)) {
    return(NA_real_)
  }

  rsq_within_impl(truth, estimate, group)
}
# S3 generic for the within R-squared metric. Dispatch happens on the class of
# `data`; the generic itself is wrapped by new_numeric_metric() with
# direction = "maximize".
rsq_within <- new_numeric_metric(
  function(data, ...) {
    UseMethod("rsq_within")
  },
  direction = "maximize"
)
#' Data frame method for `rsq_within()`
#'
#' Extracts the grouping column from `data` and hands it to `rsq_within_vec()`
#' through `fn_options`, while `truth` and `estimate` are forwarded to
#' `numeric_metric_summarizer()` as quosures.
#'
#' @param data A data frame containing the `truth`, `estimate`, and `group`
#'   columns.
#' @param truth,estimate Unquoted column names of observed and predicted values.
#' @param group Unquoted name of the column that defines the groups used for
#'   demeaning.
#' @param na_rm Passed on to `numeric_metric_summarizer()`.
#' @param ... Passed on to `numeric_metric_summarizer()`.
rsq_within.data.frame <- function(data, truth, estimate, group, na_rm = TRUE, ...) {
  # The group labels are extracted once for the whole data frame. If `data`
  # carried several dplyr groups, the metric function would be evaluated on
  # per-group slices of truth/estimate while still receiving the full-length
  # group vector, so fail early with a clear message instead.
  if (dplyr::n_groups(data) > 1L) {
    stop(
      "`rsq_within()` does not support data frames with more than one dplyr group.",
      call. = FALSE
    )
  }

  # pull() returns exactly the requested column. select() on a grouped data
  # frame keeps the grouping columns first, so `[[1]]` could pick the wrong one.
  group_values <- dplyr::pull(data, !!enquo(group))

  numeric_metric_summarizer(
    name = "rsq_within",
    fn = rsq_within_vec,
    data = data,
    truth = !!enquo(truth),
    estimate = !!enquo(estimate),
    fn_options = list(group = group_values),
    na_rm = na_rm,
    ...
  )
}
# Wrapper --------------------------------------------------------------
# Within R-squared with the grouping column fixed to `gear`, exposing only the
# (data, truth, estimate, na_rm) signature. The result is wrapped by
# new_numeric_metric() with direction = "maximize".
rsq_within_gear <- function(data, truth, estimate, na_rm = TRUE, ...) {
  # Capture the caller's column expressions so they can be forwarded intact.
  truth_quo <- rlang::enquo(truth)
  estimate_quo <- rlang::enquo(estimate)

  rsq_within(
    data,
    truth = !!truth_quo,
    estimate = !!estimate_quo,
    group = gear,
    na_rm = na_rm,
    ...
  )
}
rsq_within_gear <- new_numeric_metric(rsq_within_gear, direction = "maximize")
# Illustrate problem with example -----------------------------------------------------
# Fix the seed so the 5-fold split of mtcars is reproducible.
set.seed(6735)
folds <- vfold_cv(mtcars, v = 5)

# Linear regression of mpg on cyl, bundled into a workflow.
recipe <- recipes::recipe(mpg ~ cyl, data = head(mtcars))
model <- set_engine(linear_reg(), "lm")
wf <- add_model(add_recipe(workflow(), recipe), model)

# Does not work: as described above, the data handed to the metric function
# reportedly lacks the `gear` column that rsq_within_gear() needs.
collect_metrics(
  tune_grid(
    object = wf,
    resamples = folds,
    grid = 1,
    metrics = metric_set(rmse, rsq, rsq_within_gear)
  )
)