I have a binary classification problem and fitted both a random forest and a logistic regression. Based on the results of conf_mat(), collect_metrics() and collect_predictions(), I want to change my models so that they classify an observation as TRUE only if the model is "sure", say at a predicted probability of 75% or higher. I just don't know where to specify this change. It would be great if someone could give me a hint. My intuition tells me that it should be somewhere in the model specification, e.g. somewhere here, but maybe I am wrong.

canc_rf_model <- rand_forest(
    mtry = tune(),
    min_n = tune(),
    trees = 500) %>%
  set_engine("ranger") %>%
  set_mode("classification")

canc_log_model <- logistic_reg() %>% 
  set_engine("glm") %>% 
  set_mode("classification")

Thank you very much in advance! M.

Mischa

1 Answer

The hard class predictions come from the underlying ranger::predictions() function, not from a tidymodels function, so there's not much to be done in the fitting itself.

However, you can pretty fluently change this if you like after fitting. Let's make an example classification model:

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip

data("ad_data")
alz <- ad_data

# data splitting
set.seed(100)
alz_split  <- initial_split(alz, strata = Class, prop = .9)
alz_train  <- training(alz_split)
alz_test   <- testing(alz_split)

# data resampling
set.seed(100)
alz_folds <- 
    vfold_cv(alz_train, v = 10, strata = Class)

rf_mod <-
    rand_forest(trees = 1e3) %>% 
    set_engine("ranger") %>% 
    set_mode("classification")

rf_wf <-
    workflow() %>% 
    add_formula(Class ~ .) %>% 
    add_model(rf_mod)

set.seed(100)
rf_preds <- rf_wf %>% 
    fit_resamples(
        resamples = alz_folds, 
        control = control_resamples(save_pred = TRUE)) %>% 
    collect_predictions()

Here is the default confusion matrix:

rf_preds %>%
    conf_mat(Class, .pred_class)
#>           Truth
#> Prediction Impaired Control
#>   Impaired       37       5
#>   Control        45     213

You can use the probably package to post-process your class probability estimates and just overwrite the default values:

library(probably)
#> 
#> Attaching package: 'probably'
#> The following objects are masked from 'package:base':
#> 
#>     as.factor, as.ordered

rf_preds %>%
    mutate(.pred_class = make_two_class_pred(.pred_Impaired, 
                                             levels(rf_preds$Class),
                                             threshold = 0.75),
           .pred_class = factor(.pred_class, levels = levels(rf_preds$Class))) %>%
    conf_mat(Class, .pred_class)
#>           Truth
#> Prediction Impaired Control
#>   Impaired        0       0
#>   Control        82     218
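
If you'd rather see what the thresholding step amounts to without an extra package, here is a minimal base R sketch. The helper `classify_at_threshold()` is hypothetical (it is not part of tidymodels or probably), the probabilities are made up for illustration, and treating the first level as the event with a `>=` comparison is an assumption of this sketch:

```r
# Hypothetical helper (not part of tidymodels or probably): assign the event
# class only when its predicted probability reaches the threshold.
classify_at_threshold <- function(prob_event, class_levels, threshold = 0.75) {
  stopifnot(length(class_levels) == 2, threshold >= 0, threshold <= 1)
  # Assumption: the first level is the event of interest.
  labels <- ifelse(prob_event >= threshold, class_levels[1], class_levels[2])
  factor(labels, levels = class_levels)
}

# Made-up event probabilities, for illustration only.
probs <- c(0.10, 0.55, 0.74, 0.75, 0.92)
class_levels <- c("Impaired", "Control")

pred_050 <- classify_at_threshold(probs, class_levels, threshold = 0.50)
pred_075 <- classify_at_threshold(probs, class_levels, threshold = 0.75)

# Cross-tabulate to see which predictions flip when the cutoff is raised.
table(threshold_0.50 = pred_050, threshold_0.75 = pred_075)
```

The same idea should carry over to the logistic regression: as long as you have a column of predicted probabilities for the event class, you can re-derive the hard class after fitting rather than changing the model specification.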

Created on 2021-03-23 by the reprex package (v1.0.0)

Julia Silge