1

I'm having trouble with the trafo function for SMOTE {smotefamily}'s K parameter. In particular, when the number of nearest neighbours K is greater than or equal to the sample size, an error is returned (warning("k should be less than sample size!")) and the tuning process is terminated.

The user cannot control K to be smaller than the sample size during the internal resampling process. This would have to be controlled internally so that if, for instance, trafo_K = 2 ^ K >= sample_size for some value of K, then, say, trafo_K = sample_size - 1.

I was wondering if there's a solution to this or if one is already on its way?

library("mlr3") # mlr3 base package
library("mlr3misc") # contains some helper functions
library("mlr3pipelines") # create ML pipelines
library("mlr3tuning") # tuning ML algorithms
library("mlr3learners") # additional ML algorithms
library("mlr3viz") # autoplot for benchmarks
library("paradox") # hyperparameter space
library("OpenML") # to obtain data sets
library("smotefamily") # SMOTE algorithm for imbalance correction

# Query OpenML for curated binary classification data sets
# (see https://arxiv.org/abs/1708.03731v2) with 1-100 features and
# 5000-10000 instances.
ds <- listOMLDataSets(
  number.of.classes = 2,
  number.of.features = c(1, 100),
  number.of.instances = c(5000, 10000)
)

# Keep the imbalanced ones (minority class below 20 % of the instances) that
# have exactly one symbolic feature -- presumably the target itself -- because
# SMOTE cannot handle categorical features.
ds <- subset(
  ds,
  minority.class.size / number.of.instances < 0.2 &
    number.of.symbolic.features == 1
)
ds

# Download data set 980 and print its summary.
d <- getOMLDataSet(980)
d

# Convert to a data.frame, turn the target into a factor and wrap the result
# in an mlr3 classification task.
data <- as.data.frame(d)
data[[d$target.features]] <- as.factor(data[[d$target.features]])
task <- TaskClassif$new(
  id = d$desc$name,
  backend = data,
  target = d$target.features
)
task

# Code above copied from https://mlr3gallery.mlr-org.com/posts/2020-03-30-imbalanced-data/

# Majority-to-minority class-count ratio of the full task. max()/min() yields
# a single unnamed number. Indexing the table with `== max` / `== min` instead
# returned a table object and, when several classes tie, a vector of ratios
# that po("smote", dup_size = ...) cannot accept.
class_counts <- table(task$truth())
majority_to_minority_ratio <- max(class_counts) / min(class_counts)

# SMOTE preprocessing step; dup_size starts at the rounded class ratio.
po_smote <- po("smote", dup_size = round(majority_to_minority_ratio))

# Random forest (ranger) learner predicting class probabilities.
rf <- lrn("classif.ranger", predict_type = "prob")

# Pipeline: SMOTE followed by the random forest.
graph <- po_smote %>>%
  po("learner", rf, id = "rf")
graph$plot()

# Wrap the pipeline as a single learner with probability predictions.
rf_smote <- GraphLearner$new(graph, predict_type = "prob")
rf_smote$predict_type <- "prob"

# Parameter set of the graph learner as a table (id, class, bounds, ...).
ps_table <- as.data.table(rf_smote$param_set)
# View() opens a data viewer, so only call it in interactive sessions.
if (interactive()) {
  View(ps_table[, 1:4])
}

# Search space for the SMOTE parameters only: dup_size and K, both integers in
# [1, round(majority_to_minority_ratio)]. Ids are matched with fixed
# prefixes/suffixes, because in regexes such as 'smote.' or '.K' the '.'
# matches any character. Plain lapply() also avoids relying on `%>%`, which
# none of the attached packages provides.
param_set <- lapply(ps_table$id, function(id) {
  is_smote <- startsWith(id, "smote.")
  is_tuned <- endsWith(id, ".dup_size") || endsWith(id, ".K")
  if (is_smote && is_tuned) {
    ParamInt$new(id, lower = 1, upper = round(majority_to_minority_ratio))
  }
})
param_set <- Filter(Negate(is.null), param_set)
param_set <- ParamSet$new(param_set)

# Transformation applied to each configuration proposed by the tuner: SMOTE's
# K (the number of nearest neighbours used for sampling new values, see
# SMOTE()) is replaced by round(3 ^ K). This is intentionally left unbounded
# to reproduce the failure: 3 ^ K can exceed the number of minority-class rows
# in a training fold.
param_set$trafo <- function(x, param_set) {
  # Exact-suffix match; in '.K' the unescaped '.' would match any character.
  k_index <- which(grepl("\\.K$", names(x)))
  for (i in k_index) {
    x[[i]] <- round(3 ^ x[[i]]) # Intentionally define a trafo that won't work
  }
  x
}

# Two-fold cross-validation, instantiated on the task so that every
# configuration is scored on the same splits.
cv <- rsmp("cv", folds = 2)
cv$instantiate(task)

# Tuning instance scored with classif.bbrier and stopped after 3 evaluations.
# `param_set` (with its trafo) is the search space; it is passed positionally,
# exactly as before.
instance <- TuningInstance$new(
  task = task,
  learner = rf_smote,
  resampling = cv,
  measures = msr("classif.bbrier"),
  param_set,
  terminator = term("evals", n_evals = 3),
  store_models = TRUE
)
tuner <- TunerRandomSearch$new()

# Random search over the SMOTE hyperparameters.
tuner$optimize(instance)

And here's what happens

INFO  [11:00:14.904] Benchmark with 2 resampling iterations 
INFO  [11:00:14.919] Applying learner 'smote.rf' on task 'optdigits' (iter 2/2) 
Error in get.knnx(data, query, k, algorithm) : ANN: ERROR------->
In addition: Warning message:
In get.knnx(data, query, k, algorithm) : k should be less than sample size!

Session info

R version 3.6.2 (2019-12-12)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 16299)

Matrix products: default

locale:
[1] LC_COLLATE=English_United Kingdom.1252  LC_CTYPE=English_United Kingdom.1252   
[3] LC_MONETARY=English_United Kingdom.1252 LC_NUMERIC=C                           
[5] LC_TIME=English_United Kingdom.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] smotefamily_1.3.1        OpenML_1.10              mlr3viz_0.1.1.9002      
 [4] mlr3tuning_0.1.2-9000    mlr3pipelines_0.1.2.9000 mlr3misc_0.2.0          
 [7] mlr3learners_0.2.0       mlr3filters_0.2.0.9000   mlr3_0.2.0-9000         
[10] paradox_0.2.0            yardstick_0.0.5          rsample_0.0.5           
[13] recipes_0.1.9            parsnip_0.0.5            infer_0.5.1             
[16] dials_0.0.4              scales_1.1.0             broom_0.5.4             
[19] tidymodels_0.0.3         reshape2_1.4.3           janitor_1.2.1           
[22] data.table_1.12.8        forcats_0.4.0            stringr_1.4.0           
[25] dplyr_0.8.4              purrr_0.3.3              readr_1.3.1             
[28] tidyr_1.0.2              tibble_3.0.1             ggplot2_3.3.0           
[31] tidyverse_1.3.0         

loaded via a namespace (and not attached):
  [1] utf8_1.1.4              tidyselect_1.0.0        lme4_1.1-21            
  [4] htmlwidgets_1.5.1       grid_3.6.2              ranger_0.12.1          
  [7] pROC_1.16.1             munsell_0.5.0           codetools_0.2-16       
 [10] bbotk_0.1               DT_0.12                 future_1.17.0          
 [13] miniUI_0.1.1.1          withr_2.2.0             colorspace_1.4-1       
 [16] knitr_1.28              uuid_0.1-4              rstudioapi_0.10        
 [19] stats4_3.6.2            bayesplot_1.7.1         listenv_0.8.0          
 [22] rstan_2.19.2            lgr_0.3.4               DiceDesign_1.8-1       
 [25] vctrs_0.2.4             generics_0.0.2          ipred_0.9-9            
 [28] xfun_0.12               R6_2.4.1                markdown_1.1           
 [31] mlr3measures_0.1.3-9000 rstanarm_2.19.2         lhs_1.0.1              
 [34] assertthat_0.2.1        promises_1.1.0          nnet_7.3-12            
 [37] gtable_0.3.0            globals_0.12.5          processx_3.4.1         
 [40] timeDate_3043.102       rlang_0.4.5             workflows_0.1.1        
 [43] BBmisc_1.11             splines_3.6.2           checkmate_2.0.0        
 [46] inline_0.3.15           yaml_2.2.1              modelr_0.1.5           
 [49] tidytext_0.2.2          threejs_0.3.3           crosstalk_1.0.0        
 [52] backports_1.1.6         httpuv_1.5.2            rsconnect_0.8.16       
 [55] tokenizers_0.2.1        tools_3.6.2             lava_1.6.6             
 [58] ellipsis_0.3.0          ggridges_0.5.2          Rcpp_1.0.4.6           
 [61] plyr_1.8.5              base64enc_0.1-3         visNetwork_2.0.9       
 [64] ps_1.3.0                prettyunits_1.1.1       rpart_4.1-15           
 [67] zoo_1.8-7               haven_2.2.0             fs_1.3.1               
 [70] furrr_0.1.0             magrittr_1.5            colourpicker_1.0       
 [73] reprex_0.3.0            GPfit_1.0-8             SnowballC_0.6.0        
 [76] packrat_0.5.0           matrixStats_0.55.0      tidyposterior_0.0.2    
 [79] hms_0.5.3               shinyjs_1.1             mime_0.8               
 [82] xtable_1.8-4            XML_3.99-0.3            tidypredict_0.4.3      
 [85] shinystan_2.5.0         readxl_1.3.1            gridExtra_2.3          
 [88] rstantools_2.0.0        compiler_3.6.2          crayon_1.3.4           
 [91] minqa_1.2.4             StanHeaders_2.21.0-1    htmltools_0.4.0        
 [94] later_1.0.0             lubridate_1.7.4         DBI_1.1.0              
 [97] dbplyr_1.4.2            MASS_7.3-51.4           boot_1.3-23            
[100] Matrix_1.2-18           cli_2.0.1               parallel_3.6.2         
[103] gower_0.2.1             igraph_1.2.4.2          pkgconfig_2.0.3        
[106] xml2_1.2.2              foreach_1.4.7           dygraphs_1.1.1.6       
[109] prodlim_2019.11.13      farff_1.1               rvest_0.3.5            
[112] snakecase_0.11.0        janeaustenr_0.1.5       callr_3.4.1            
[115] digest_0.6.25           cellranger_1.1.0        curl_4.3               
[118] shiny_1.4.0             gtools_3.8.1            nloptr_1.2.1           
[121] lifecycle_0.2.0         nlme_3.1-142            jsonlite_1.6.1         
[124] fansi_0.4.1             pillar_1.4.3            lattice_0.20-38        
[127] loo_2.2.0               fastmap_1.0.1           httr_1.4.1             
[130] pkgbuild_1.0.6          survival_3.1-8          glue_1.4.0             
[133] xts_0.12-0              FNN_1.1.3               shinythemes_1.1.2      
[136] iterators_1.0.12        class_7.3-15            stringi_1.4.4          
[139] memoise_1.1.0           future.apply_1.5.0     

Many thanks.

  • What kind of solution do you propose? Many hyper parameter combinations are not possible (result in error) due to data peculiarities, as in the example you show. – missuse May 13 '20 at 18:08
  • I propose adding an `ifelse` process internally, so that when `K` exceeds the sample size, then `K` is set to an appropriate max value (could that be `sample_size - 1`?). A warning should be returned, so that the process actually runs instead of being terminated. The process, as it currently stands, really depends on how lucky you are. For example, try running the code in my Q several times. Sometimes you'll get the reported error and the script will stop, sometimes you won't and it'll run. I've been quite unlucky with my own data, so I can't really use the said `trafo` with my data. – andreassot10 May 14 '20 at 10:43
  • So you propose that, for all mlr3-supported algorithms, hyperparameter combinations are first checked for feasibility with the data and tweaked into a permissible range with a warning? Sounds like a lot of work. I think it's much easier to just skip over bad hyperparameter combos during tuning. See encapsulation: https://mlr3book.mlr-org.com/error-handling.html – missuse May 14 '20 at 11:50
  • I was actually proposing to do that for SMOTE, not all supported algorithms. Anyway, thanks for the resource- I'll see what I can do. – andreassot10 May 15 '20 at 07:39

1 Answer

0

I've found a workaround.

As pointed out earlier, the problem is that SMOTE {smotefamily}'s K cannot be greater than or equal to the sample size.

I dug into the process and discovered that SMOTE {smotefamily} uses knearest {smotefamily}, which uses knnx.index {FNN}, which in turn uses get.knnx {FNN}. It is get.knnx that emits the warning "k should be less than sample size!", followed by the ANN error that terminates the tuning process in mlr3.

Now, within SMOTE {smotefamily}, the three arguments for knearest {smotefamily} are P_set, P_set and K. From an mlr3 resampling perspective, data frame P_set is the training data of a cross-validation fold, filtered to contain only the records of the minority class. The 'sample size' that the error refers to is the number of rows of P_set.

Thus, it becomes more likely that K >= nrow(P_set) as K increases via a trafo such as some_integer ^ K (e.g. 2 ^ K).

We need to ensure that K will never be greater than or equal to the number of rows of P_set.

Here's my proposed solution:

  1. Define a variable cv_folds before defining the CV resampling strategy with rsmp().
  2. Define the CV resampling strategy where folds = cv_folds in rsmp(), before defining the trafo.
  3. Instantiate the CV. Now the dataset is split into training and test/validation data in each fold.
  4. Find the minimum sample size of the minority class among all training data folds and set that as the threshold for K:
# Exclusive upper bound for SMOTE's K: the smallest minority-class count over
# all training sets of the instantiated resampling `cv`. Each training set's
# target is fetched with task$truth(), which selects rows by row id, so this
# stays correct even when row ids are not simply 1..n. Positional indexing of
# as.data.frame(task$data()) only works when they are. Base vapply() also
# removes the dependency on dplyr's bind_cols() and magrittr's `%>%`.
smote_k_thresh <- min(vapply(
  seq_len(cv$iters),
  function(i) as.numeric(min(table(task$truth(cv$train_set(i))))),
  numeric(1)
))
  5. Now define the trafo as follows:
# Transformation applied to each configuration proposed by the tuner: SMOTE's
# K is mapped to round(2 ^ K) and kept only while it stays below
# smote_k_thresh (the smallest minority-class count in any training set).
# Otherwise K is replaced by a random integer in 1..(smote_k_thresh - 1), so
# the configuration can be evaluated instead of aborting the tuning run.
# NOTE(review): the random fallback makes the trafo non-deterministic (the same
# proposed K can map to different neighbour counts); clamping with
# min(k, smote_k_thresh - 1) would be a reproducible alternative.
param_set$trafo <- function(x, param_set) {
  # Exact-suffix match; in '.K' the unescaped '.' would match any character.
  k_index <- which(grepl("\\.K$", names(x)))
  max_k <- smote_k_thresh - 1
  if (length(k_index) > 0 && max_k < 1) {
    # sample(0, 1) would silently return K = 0 here.
    stop("SMOTE needs at least 2 minority-class rows in every training set; ",
         "smote_k_thresh is ", smote_k_thresh, ".", call. = FALSE)
  }
  for (i in k_index) {
    k <- round(2 ^ x[[i]])
    x[[i]] <- if (k < smote_k_thresh) k else sample.int(max_k, 1)
  }
  x
}

In other words, when the transformed K remains smaller than the sample size, keep it. Otherwise, set it to a randomly sampled integer between 1 and smote_k_thresh - 1.

Implementation

Original code slightly modified to accommodate proposed tweaks:

# Attach every package the script relies on. mlr3, mlr3misc, mlr3pipelines and
# mlr3tuning were missing from this modified copy, although it uses
# TaskClassif (mlr3), po()/%>>%/GraphLearner (mlr3pipelines) and
# TuningInstance/term()/TunerRandomSearch (mlr3tuning).
library("mlr3") # mlr3 base package
library("mlr3misc") # contains some helper functions
library("mlr3pipelines") # create ML pipelines
library("mlr3tuning") # tuning ML algorithms
library("mlr3learners") # additional ML algorithms
library("mlr3viz") # autoplot for benchmarks
library("paradox") # hyperparameter space
library("OpenML") # to obtain data sets
library("smotefamily") # SMOTE algorithm for imbalance correction

# Query OpenML for curated binary classification data sets
# (see https://arxiv.org/abs/1708.03731v2) with 1-100 features and
# 5000-10000 instances.
ds <- listOMLDataSets(
  number.of.classes = 2,
  number.of.features = c(1, 100),
  number.of.instances = c(5000, 10000)
)

# Keep the imbalanced ones (minority class below 20 % of the instances) that
# have exactly one symbolic feature -- presumably the target itself -- because
# SMOTE cannot handle categorical features.
ds <- subset(
  ds,
  minority.class.size / number.of.instances < 0.2 &
    number.of.symbolic.features == 1
)
ds

# Download data set 980 and print its summary.
d <- getOMLDataSet(980)
d

# Convert to a data.frame, turn the target into a factor and wrap the result
# in an mlr3 classification task.
data <- as.data.frame(d)
data[[d$target.features]] <- as.factor(data[[d$target.features]])
task <- TaskClassif$new(
  id = d$desc$name,
  backend = data,
  target = d$target.features
)
task

# Code above copied from https://mlr3gallery.mlr-org.com/posts/2020-03-30-imbalanced-data/

# Majority-to-minority class-count ratio of the full task. max()/min() yields
# a single unnamed number. Indexing the table with `== max` / `== min` instead
# returned a table object and, when several classes tie, a vector of ratios
# that po("smote", dup_size = ...) cannot accept.
class_counts <- table(task$truth())
majority_to_minority_ratio <- max(class_counts) / min(class_counts)

# SMOTE preprocessing step; dup_size starts at the rounded class ratio.
po_smote <- po("smote", dup_size = round(majority_to_minority_ratio))

# Define and instantiate the resampling strategy BEFORE defining the trafo:
# the K threshold derived from it depends on the concrete training folds.
cv_folds <- 2
cv <- rsmp("cv", folds = cv_folds)
cv$instantiate(task)

# Exclusive upper bound for SMOTE's K: the smallest minority-class count over
# all training sets of the instantiated resampling `cv`. SMOTE searches
# neighbours among the minority-class rows it receives, and FNN fails once K
# reaches that sample size. Each training set's target is fetched with
# task$truth(), which selects rows by row id, so this stays correct even when
# row ids are not simply 1..n. Positional indexing of
# as.data.frame(task$data()) only works when they are.
smote_k_thresh <- min(vapply(
  seq_len(cv$iters),
  function(i) as.numeric(min(table(task$truth(cv$train_set(i))))),
  numeric(1)
))

# Random forest (ranger) learner predicting class probabilities.
rf <- lrn("classif.ranger", predict_type = "prob")

# Pipeline: SMOTE followed by the random forest; plot it for a visual check.
graph <- po_smote %>>% po("learner", rf, id = "rf")
graph$plot()

# Expose the pipeline as a single learner with probability predictions.
rf_smote <- GraphLearner$new(graph, predict_type = "prob")
rf_smote$predict_type <- "prob"

# Parameter set of the graph learner as a table (id, class, bounds, ...).
ps_table <- as.data.table(rf_smote$param_set)
# View() opens a data viewer, so only call it in interactive sessions.
if (interactive()) {
  View(ps_table[, 1:4])
}

# Search space for the SMOTE parameters only: dup_size and K, both integers in
# [1, round(majority_to_minority_ratio)]. K here is the exponent that the trafo
# turns into the actual neighbour count. Ids are matched with fixed
# prefixes/suffixes, because in regexes such as 'smote.' or '.K' the '.'
# matches any character. Plain lapply() also avoids relying on `%>%`, which
# none of the attached packages provides.
param_set <- lapply(ps_table$id, function(id) {
  is_smote <- startsWith(id, "smote.")
  is_tuned <- endsWith(id, ".dup_size") || endsWith(id, ".K")
  if (is_smote && is_tuned) {
    ParamInt$new(id, lower = 1, upper = round(majority_to_minority_ratio))
  }
})
param_set <- Filter(Negate(is.null), param_set)
param_set <- ParamSet$new(param_set)

# Transformation applied to each configuration proposed by the tuner: SMOTE's
# K is mapped to round(5 ^ K) (a deliberately large base for the sake of the
# example) and kept only while it stays below smote_k_thresh, the smallest
# minority-class count in any training set. Otherwise K is replaced by a random
# integer in 1..(smote_k_thresh - 1), so it never equals or exceeds the sample
# size and the tuning run is not aborted.
# NOTE(review): the random fallback makes the trafo non-deterministic (the same
# proposed K can map to different neighbour counts); clamping with
# min(k, smote_k_thresh - 1) would be a reproducible alternative.
param_set$trafo <- function(x, param_set) {
  # Exact-suffix match; in '.K' the unescaped '.' would match any character.
  k_index <- which(grepl("\\.K$", names(x)))
  max_k <- smote_k_thresh - 1
  if (length(k_index) > 0 && max_k < 1) {
    # sample(0, 1) would silently return K = 0 here.
    stop("SMOTE needs at least 2 minority-class rows in every training set; ",
         "smote_k_thresh is ", smote_k_thresh, ".", call. = FALSE)
  }
  for (i in k_index) {
    k <- round(5 ^ x[[i]]) # Try a large value here for the sake of the example
    x[[i]] <- if (k < smote_k_thresh) k else sample.int(max_k, 1)
  }
  x
}

# Tuning instance on the pre-instantiated CV splits, scored with
# classif.bbrier and stopped after 10 evaluations. `param_set` (with its
# bounded trafo) is the search space; it is passed positionally, exactly as
# before.
instance <- TuningInstance$new(
  task = task,
  learner = rf_smote,
  resampling = cv,
  measures = msr("classif.bbrier"),
  param_set,
  terminator = term("evals", n_evals = 10),
  store_models = TRUE
)
tuner <- TunerRandomSearch$new()

# Random search over the SMOTE hyperparameters.
tuner$optimize(instance)

# Archive with the K values as originally proposed by the tuner ...
instance$archive$data

# ... and their transformed counterparts.
instance$archive$data$opt_x