
I'm looking for a faster way to extract predicted survival distributions with mlr3 and mlr3proba.
The prediction procedure is very time-consuming, especially with datasets of hundreds of observations and no ties in the time variable.
Is there any option to estimate the individual distributions only at pre-defined time points, rather than at every observed time?
If that is not possible, is there an option similar to ntime in randomForestSRC::rfsrc?
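For reference, here is a rough sketch of the rfsrc behaviour I have in mind, shown on survival::veteran rather than on my data (ntime = 50 is just an illustrative value): ntime constrains the ensemble survival estimates to a coarser grid of time points instead of every unique event time.

library(randomForestSRC)
library(survival)

# Constrain the ensemble survival estimates to a grid of roughly 50 time
# points instead of every unique event time (illustrative value only).
fit <- rfsrc(Surv(time, status) ~ ., data = veteran, ntime = 50)
length(fit$time.interest)  # at most ~50 grid points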

Here is an example using survivalmodels::akritas, in which estimation at a single time point takes about 10 minutes:

pacman::p_load("survival","mltools","paradox","mlr3misc","mlr3tuning",
               "devtools","mlr3extralearners","mlr3proba","mlr3learners",
               "survivalmodels","mlr3pipelines", "tictoc", "casebase","distr6")

dat <- survival::rotterdam[,-c(1,2,12,13)]
length(unique(dat$dtime)) # 2215 unique times

set.seed(220311) 
sample.train <- sample(nrow(dat), nrow(dat)*.2)
dat_train <- dat[sample.train, ]
length(unique(dat_train$dtime)) # 558 unique times

sample.test <- setdiff(seq_len(nrow(dat)), sample.train)
dat_test <- dat[sample.test, ]
length(unique(dat_test$dtime)) # 1875 unique times


task = mlr3proba::TaskSurv$new(id = "dat_train", backend = dat_train,
                               time = "dtime", event = "death")

search_space <- ps(
  lambda = p_dbl(lower = 0, upper = 0.25))

learner.dh <- lrn("surv.akritas", reverse = FALSE)
learner.dh$encapsulate = c(train = "evaluate")

at <- AutoTuner$new(
  learner = learner.dh,
  search_space = search_space,
  resampling = rsmp("cv", folds = 5),
  measure = msr("surv.cindex"),
  terminator = trm("evals", n_evals = 10), # n_evals very low, just for the example
  tuner = tnr("random_search")
)
tic()
at$train(task)
toc() #807.46  sec elapsed

tic()
pred.S_t2638 <- 1 - as.numeric(at$predict_newdata(dat_test)$distr$cdf(2638))
toc() #559.5 sec elapsed
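To be concrete, this is the kind of call I am after: evaluating only a handful of pre-specified time points in one go rather than the full distribution. A sketch only; cdf() does accept a vector of evaluation points, but the shape/orientation of the returned object may differ between distr6 versions, so check its dimensions:

pred <- at$predict_newdata(dat_test)
# Evaluate the predicted survival only at a few pre-defined times (in days)
# instead of the whole time grid.
S_grid <- 1 - as.matrix(pred$distr$cdf(c(365, 1095, 2638)))
dim(S_grid)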
  • Hey, good question! Speeding this up is something I've been working on for a while, but it's a general problem in constructing R6 objects. The issue is not that it's trying to predict all time points at once, but that it constructs an R6 distr6 distribution object over all predictions (hence taking a while). I think your example might have just given me a breakthrough in how to fix the issue though... give me 1-2 days and I'll be back with an answer – RaphaelS Mar 16 '22 at 09:07
  • Thank you very much for your answer! That seems like a reasonable explanation, and it would also explain why the PC-Hazard survival neural network doesn't have this problem (probably because it uses discrete times, imho). I'm glad this question has triggered some curiosity: can't wait to hear about it, thank you again! – GabrieleInfante Mar 16 '22 at 18:41
  • I've got your times down to 600s and 350s, give me 1 more day and I'll third it (maybe) – RaphaelS Mar 17 '22 at 20:36
  • Good news, bad news. I've got a 300x speed-up in predicting a single time point and a 20x speed-up in prediction overall. However, that involved changing learners, as there is a separate bug in akritas, so akritas is much faster than before but still slow. The fixes will be on CRAN in a few days. For now: `remotes::install_github("alan-turing-institute/distr6#271"); remotes::install_github("mlr-org/mlr3proba#262")` – RaphaelS Mar 17 '22 at 23:25
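Regarding the construction bottleneck described in the comments above, a minimal workaround sketch, assuming the installed mlr3proba version stores the raw survival matrix in pred$data$distr with rows as observations and columns named by time point (this is an assumption, check str(pred$data$distr); older versions may store a distr6 object there instead). A single time point can then be read off with a step-function lookup, without building the distr6 object at all:

pred <- at$predict_newdata(dat_test)
surv_mat <- pred$data$distr  # assumed: matrix of survival probabilities
if (is.matrix(surv_mat)) {
  times <- as.numeric(colnames(surv_mat))
  idx <- findInterval(2638, times)  # last predicted time point <= 2638
  S_t2638 <- if (idx == 0) rep(1, nrow(surv_mat)) else surv_mat[, idx]
}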

1 Answer


Hi, sorry it's taken so long, but the fixes are now on CRAN: install.packages(c("distr6", "mlr3proba")). Unfortunately Akritas is still slow; my C++ code includes four loops, which is not nice, and I'll try to think of a better fix in the future. The actual prediction part, however, is now <1 s. See below for a test on rfsrc; I've also included notes to self about the remaining bottlenecks across my packages.

library(paradox)
library(mlr3extralearners)
library(mlr3tuning)
library(tictoc)

dat <- survival::rotterdam[, -c(1, 2, 12, 13)]

set.seed(220311)
sample.train <- sample(nrow(dat), nrow(dat) * .2)
dat_train <- dat[sample.train, ]

sample.test <- setdiff(seq_len(nrow(dat)), sample.train)
dat_test <- dat[sample.test, ]


task = mlr3proba::TaskSurv$new(
  id = "dat_train", backend = dat_train,
  time = "dtime", event = "death"
)

learner = lrn("surv.rfsrc", ntree = to_tune(50, 200))

at <- AutoTuner$new(
  learner = learner,
  resampling = rsmp("cv", folds = 5),
  measure = msr("surv.cindex"),
  terminator = trm("evals", n_evals = 10),
  tuner = tnr("random_search")
)
tic()
at$train(task)
toc()
#> 15.531 sec elapsed

tic()
pred <- at$predict_newdata(dat_test) # bottleneck is predict.rfsrc
toc()
#> 0.309 sec elapsed

tic()
distr <- pred$distr # bottleneck is support checks
toc()
#> 0.721 sec elapsed

tic()
pred.S_t2638 <- distr$survival(2638) # bottleneck is param6 transform
toc()
#> 0.274 sec elapsed

Created on 2022-03-25 by the reprex package (v2.0.1)
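As a side note on the original question about pre-defined time points: rfsrc's own ntime argument coarsens the predicted time grid, and it should be reachable through the mlr3 learner as well. This is a sketch under the assumption that surv.rfsrc exposes ntime (check lrn("surv.rfsrc")$param_set$ids() to confirm):

# Train on a coarsened time grid so the predicted distribution is defined
# on at most ~50 time points (ntime assumed to be exposed by the learner).
learner_coarse <- lrn("surv.rfsrc", ntree = 100, ntime = 50)
learner_coarse$train(task)
p <- learner_coarse$predict_newdata(dat_test)
p$distr$survival(2638)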

RaphaelS