I have the following code for creating a tidymodels workflow with a lightgbm model. However, I run into problems when I save the fitted workflow to an .rds file and then try to predict with it:
library(AmesHousing)
library(treesnip)
library(lightgbm)
library(tidymodels)
tidymodels_prefer()
### Model ###
# data
data <- make_ames() %>%
  janitor::clean_names()
data <- subset(data, select = c(sale_price, bedroom_abv_gr, bsmt_full_bath, bsmt_half_bath, enclosed_porch, fireplaces,
                                full_bath, half_bath, kitchen_abv_gr, garage_area, garage_cars, gr_liv_area, lot_area,
                                lot_frontage, year_built, year_remod_add, year_sold))
data$id <- c(1:nrow(data))
data <- data %>%
  mutate(id = as.character(id)) %>%
  select(id, everything())
# model specification
lgbm_model <- boost_tree(
  mtry = 7,
  trees = 347,
  min_n = 10,
  tree_depth = 12,
  learn_rate = 0.0106430579211173,
  loss_reduction = 0.000337948798058139
) %>%
  set_mode("regression") %>%
  set_engine("lightgbm", objective = "regression")
# recipe and workflow
lgbm_recipe <- recipe(sale_price ~ ., data = data) %>%
  update_role(id, new_role = "ID") %>%
  step_corr(all_predictors(), threshold = 0.7) %>%
  prep()
lgbm_workflow <- workflow() %>%
  add_recipe(lgbm_recipe) %>%
  add_model(lgbm_model)
# fit workflow
fit_lgbm_workflow <- lgbm_workflow %>%
  fit(data = data)
# predict
data_predict <- subset(data, select = -c(sale_price))
predict(fit_lgbm_workflow, new_data = data_predict)
### CASE 1: Save the workflow with saveRDS()
saveRDS(object = fit_lgbm_workflow, file = "lgbm_workflow.rds")
new_lgbm_workflow <- readRDS(file = "lgbm_workflow.rds")
# Predict - error: Attempting to use a Booster which no longer exists
predict(new_lgbm_workflow, new_data = data_predict)
### CASE 2: Save the workflow and the fitted model separately
fitted_model <- (fit_lgbm_workflow %>% extract_fit_parsnip())$fit
saveRDS(object = fit_lgbm_workflow, file = "lgbm_workflow.rds")
lightgbm::saveRDS.lgb.Booster(object = fitted_model, file = "lgbm_model.rds")
new_lgbm_workflow <- readRDS(file = "lgbm_workflow.rds")
new_lgbm_model <- lightgbm::readRDS.lgb.Booster(file = "lgbm_model.rds")
new_lgbm_workflow$fit$fit <- new_lgbm_model
# Predict - error: cannot predict on data of class ‘tbl_df’‘tbl’‘data.frame’
predict(new_lgbm_workflow, new_data = data_predict)
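In Case 2 the reloaded booster is assigned to `new_lgbm_workflow$fit$fit`. To see what actually lives at that level of a fitted workflow, here is a minimal sketch that fits a small workflow on mtcars with the base-R lm engine as a stand-in for lightgbm. It assumes that a fitted workflow nests its objects the same way regardless of engine, and it relies on workflows internals that could change between versions.

```r
library(parsnip)
library(workflows)

# Small self-contained workflow; lm is a stand-in for the lightgbm engine
wf <- workflow()
wf <- add_formula(wf, mpg ~ wt + hp)
wf <- add_model(wf, set_engine(linear_reg(), "lm"))
wf_fit <- fit(wf, data = mtcars)

# Level 2: the parsnip "model_fit" wrapper, i.e. what extract_fit_parsnip() returns
level2 <- wf_fit$fit$fit
# Level 3: the raw engine object (an "lm" here; an lgb.Booster in the question)
level3 <- wf_fit$fit$fit$fit

class(level2)
class(level3)
identical(level2, extract_fit_parsnip(wf_fit))
```

If the lightgbm workflow follows the same layout, the reloaded booster would belong in `$fit$fit$fit`; assigning it to `$fit$fit` replaces the parsnip wrapper itself, which would be consistent with predict() dispatching straight to predict.lgb.Booster() on a tibble. This is an untested inference.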
Only workflows with a lightgbm model seem to have this problem. For other types of models (random forest, xgboost, glm, etc.), I can save the fitted workflow with saveRDS(), read it back with readRDS(), and predict on new data just fine.
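Judging by the Case 1 error message, a likely reason lightgbm differs is that an lgb.Booster refers to a model held in C++ memory, which a plain saveRDS() does not capture (an assumption, not checked against the lightgbm source; newer lightgbm releases may handle this differently, so check the documentation for the installed version). The sketch below leaves tidymodels out entirely, uses mtcars instead of the Ames data, and shows lightgbm's own text-file round trip with lgb.save() and lgb.load().

```r
library(lightgbm)

# Tiny numeric training set; lightgbm expects a numeric matrix
x <- as.matrix(mtcars[, c("wt", "hp", "disp", "cyl")])
y <- mtcars$mpg

dtrain <- lgb.Dataset(data = x, label = y)
booster <- lgb.train(
  params = list(objective = "regression", min_data_in_leaf = 5, verbose = -1),
  data = dtrain,
  nrounds = 20
)
pred_before <- predict(booster, x)

# Write the model in LightGBM's text format, then load it into a fresh Booster
model_file <- tempfile(fileext = ".txt")
lgb.save(booster, model_file)
booster2 <- lgb.load(model_file)
pred_after <- predict(booster2, x)

all.equal(pred_before, pred_after)
```

This only covers the bare booster; it does not by itself restore the recipe or the rest of the workflow.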
In Case 2, the underlying predict function apparently changes to predict.lgb.Booster(), which takes a matrix as input. But my id variable is a character column, whereas all columns in a matrix must have the same type.
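The type problem described above is ordinary base R behaviour: as.matrix() on a data frame that mixes a character column with numeric ones coerces every cell to character. A minimal illustration with made-up toy values (not the Ames data):

```r
# Toy data frame shaped like the question's data: character id plus numeric predictors
df <- data.frame(
  id = c("1", "2", "3"),
  gr_liv_area = c(1000, 1500, 2000),
  year_built = c(1960, 1975, 2001),
  stringsAsFactors = FALSE
)

# With the character id present, the whole matrix becomes character
m_all <- as.matrix(df)
# Dropping the id column keeps the matrix numeric
m_num <- as.matrix(df[, setdiff(names(df), "id")])

typeof(m_all)  # "character"
typeof(m_num)  # "double"
```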
Is there a way to save the entire workflow for future use?