Where does mlr3 store the final model after training a learner with `learner$train(data)`? By "final model", I mean the object returned by code such as the following:

    model <- xgboost::xgb.train(data = data_train,
                                max.depth = 8, nthread = 2, nrounds = 15,
                                verbose = 0)

Is there a way to extract this list/object?


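    library(mlr3)
    library(mlr3learners)   # provides the "regr.xgboost" learner
    library(mlr3pipelines)  # provides po(), %>>% and GraphLearner

    task <- TaskRegr$new("data", data, "y")
    learner <- lrn("regr.xgboost")
    preprocess <- po("scale", param_vals = list(center = TRUE, scale = TRUE))
    pp <- preprocess %>>% learner
    gg <- GraphLearner$new(pp)
    gg$train(task)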
Nip
  • I've been learning about this myself recently (mlr3 uses R6 classes instead of S3/S4 classes; see e.g. https://mlr3book.mlr-org.com/r6.html), so I'm probably not able to answer your question, but if you post your actual code it will be much easier for someone else to provide a useful answer. – jared_mamrot Oct 08 '20 at 02:23
  • Does `gg$param_set` return anything? – jared_mamrot Oct 08 '20 at 02:44
  • Yes. A list with parameters, but it is not the object that is used to make predictions. – Nip Oct 08 '20 at 02:48
  • Yep, the model will be stored in `gg$model` after training has been conducted. From the docs (https://mlr3book.mlr-org.com/train-predict.html): "The field $model stores the model that is produced in the training step. Before the $train() method is called on a learner object, this field is NULL". So, does `print(gg$model)` return anything? – jared_mamrot Oct 08 '20 at 02:50
  • I think `gg$model$regr.xgboost$model` is the answer. I need to test it. – Nip Oct 08 '20 at 02:53
  • @jared_mamrot, thank you. `gg$model$regr.xgboost$model` is exactly the same list. – Nip Oct 08 '20 at 03:07

1 Answer

In xgboost, the fitted model is the object returned by `xgb.train()`:

    model <- xgboost::xgb.train(data = data_train,
                                max.depth = 8, nthread = 2, nrounds = 15,
                                verbose = 0)

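In mlr3, when the learner is wrapped in a `GraphLearner` and trained like this:

    library(mlr3)
    library(mlr3learners)   # provides the "regr.xgboost" learner
    library(mlr3pipelines)  # provides po(), %>>% and GraphLearner

    task <- TaskRegr$new("data", data, "y")
    learner <- lrn("regr.xgboost")
    preprocess <- po("scale", param_vals = list(center = TRUE, scale = TRUE))
    pp <- preprocess %>>% learner
    gg <- GraphLearner$new(pp)
    gg$train(task)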

the equivalent of `model` is stored at:

    gg$model$regr.xgboost$model
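
To make the lookup path concrete, here is a minimal sketch that trains the same learner once on its own and once inside a pipeline, then pulls out the underlying xgboost object in both cases. It assumes the built-in `mtcars` task (`tsk("mtcars")`) as a stand-in for your data, and the variable names `plain_model` and `piped_model` are purely illustrative. The hyperparameters mirror the `xgb.train()` call in the question.

```r
library(mlr3)
library(mlr3learners)   # provides the "regr.xgboost" learner
library(mlr3pipelines)  # provides po(), %>>% and GraphLearner

# Assumption: the built-in mtcars task stands in for the real data.
task <- tsk("mtcars")

# Case 1: a plain learner keeps the fitted object directly in $model.
learner <- lrn("regr.xgboost", nrounds = 15, max_depth = 8,
               nthread = 2, verbose = 0)
learner$train(task)
plain_model <- learner$model

# Case 2: a GraphLearner keeps one entry per PipeOp, named by PipeOp id.
graph <- po("scale") %>>%
  lrn("regr.xgboost", nrounds = 15, max_depth = 8, nthread = 2, verbose = 0)
gg <- GraphLearner$new(graph)
gg$train(task)

names(gg$model)                       # ids of the PipeOps in the graph
piped_model <- gg$model$regr.xgboost$model

# Both objects are what xgboost itself returned.
class(plain_model)
class(piped_model)
```

The element name `regr.xgboost` is the default id of the learner's PipeOp, so if you give the PipeOp a different id, the path under `gg$model` changes accordingly.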
jared_mamrot