为什么使用保存为 “.rds” 的 LightGBM 模型进行预测时 R 会话会崩溃？如何解决？
在使用 tidymodels 和 treesnip 拟合 LightGBM 模型后，我可以用拟合好的工作流（workflow）对新数据进行预测，没有任何问题。但是，如果将调优后的模型保存为 “.rds” 格式，关闭会话，再在新会话中加载该 “.rds” 模型，那么在尝试生成预测时 R 会话就会崩溃。
这种情况只出现在 LightGBM 模型上，其他类型的模型都不会出现这个问题。下面是一个可复现的示例：
LightGBM R 包的安装方式如下：
# Install the LightGBM 3.0.0 R package from its GitHub release tarball.
lightgbm_tarball <- "https://github.com/microsoft/LightGBM/releases/download/v3.0.0/lightgbm-3.0.0-r-cran.tar.gz"
remotes::install_url(url = lightgbm_tarball)
library(dplyr)
library(parsnip)
library(rsample)
library(yardstick)
library(recipes)
library(workflows)
library(dials)
library(tune)
library(treesnip)
# Stack three copies of iris (1050 rows) and hold out a test split.
data <- bind_rows(iris, iris, iris)
set.seed(2)
# rsample's argument is `prop`; the former `p =` only worked through R's
# partial argument matching. Note that this object masks the name of
# rsample::initial_split(); calls still reach the function because R skips
# non-function bindings when looking up a function.
initial_split <- initial_split(data, prop = 0.75)
train <- training(initial_split)
test <- testing(initial_split)
initial_split
#> <Analysis/Assess/Total>
#> <788/262/1050>
# Preprocessing: model Sepal.Length on every other column and expand the
# nominal predictors into dummy variables.
recipe <- recipe(Sepal.Length ~ ., data = data) %>%
  step_dummy(all_nominal(), -all_outcomes())

# Boosted-tree spec for the lightgbm engine; tune() marks the hyperparameters
# that the grid search below will choose.
model <- boost_tree(
  mtry = 3,
  trees = 1000,
  min_n = tune(),
  tree_depth = tune(),
  loss_reduction = tune(),
  learn_rate = tune(),
  sample_size = 0.75
) %>%
  set_mode("regression") %>%
  set_engine("lightgbm")

# Bundle the recipe and the model spec into one workflow.
wf <- workflow() %>%
  add_model(model) %>%
  add_recipe(recipe)
wf
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: boost_tree()
#>
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 1 Recipe Step
#>
#> ● step_dummy()
#>
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Boosted Tree Model Specification (regression)
#>
#> Main Arguments:
#> mtry = 3
#> trees = 1000
#> min_n = tune()
#> tree_depth = tune()
#> learn_rate = tune()
#> loss_reduction = tune()
#> sample_size = 0.75
#>
#> computational engine: lightgbm
# 3-fold cross-validation folds drawn from the training split.
resamples <- vfold_cv(train, v = 3)

# 10 random candidate combinations of the tune() parameters; finalize() fills
# in any data-dependent parameter ranges from `train`.
grid <- model %>%
  parameters() %>%
  finalize(train) %>%
  grid_random(size = 10)
head(grid)
#> # A tibble: 6 x 4
#> min_n tree_depth learn_rate loss_reduction
#> <int> <int> <dbl> <dbl>
#> 1 2 4 0.000282 0.0000402
#> 2 13 10 0.00333 13.0
#> 3 32 11 0.000000585 0.000106
#> 4 32 7 0.000258 0.163
#> 5 31 13 0.0000000881 0.000479
#> 6 19 14 0.000000167 0.00174

# Grid search: score every candidate on every fold by RMSE.
tune_grid <- tune_grid(
  wf,
  resamples = resamples,
  grid = grid,
  control = control_grid(verbose = FALSE),
  metrics = metric_set(rmse)
)
# Pick the best-scoring hyperparameters by RMSE and substitute them for the
# tune() placeholders in the workflow.
best_params <- tune_grid %>% select_best(metric = "rmse")
wf <- finalize_workflow(wf, best_params)
wf
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: boost_tree()
#>
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 1 Recipe Step
#>
#> ● step_dummy()
#>
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Boosted Tree Model Specification (regression)
#>
#> Main Arguments:
#> mtry = 3
#> trees = 1000
#> min_n = 13
#> tree_depth = 10
#> learn_rate = 0.00333377440294304
#> loss_reduction = 13.0320661814971
#> sample_size = 0.75
#>
#> computational engine: lightgbm
# Fit the finalized workflow on the training split and evaluate it on the
# held-out test split.
last_fit <- wf %>% last_fit(split = initial_split)
last_fit %>% collect_metrics()
#> # A tibble: 2 x 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 rmse standard 0.380
#> 2 rsq standard 0.837
# Fit the finalized workflow on all of `data`, then predict on iris.
model_fit <- fit(wf, data)
#> [LightGBM] [Warning] Auto-choosing row-wise multi-threading, the overhead of testing was 0.000020 seconds.
#> You can set `force_row_wise=true` to remove the overhead.
#> And if memory is not enough, you can set `force_col_wise=true`.
#> [LightGBM] [Info] Total Bins 95
#> [LightGBM] [Info] Number of data points in the train set: 1050, number of used features: 5
#> [LightGBM] [Info] Start training from score 5.843333
#> [LightGBM] [Warning] No further splits with positive gain, best gain: -inf
#> [LightGBM] [Warning] No further splits with positive gain, best gain: -inf
#> ... (remaining LightGBM log lines truncated)
predicciones <- predict(model_fit, iris)
head(predicciones)
#> # A tibble: 6 x 1
#> .pred
#> <dbl>
#> 1 5.13
#> 2 5.12
#> 3 5.12
#> 4 5.12
#> 5 5.13
#> 6 5.25
# save model
# An lgb.Booster keeps its trained trees in C++ memory behind an external
# pointer. saveRDS() cannot serialize that memory, so after readRDS() in a new
# session the booster inside the workflow points at nothing and predict()
# crashes R. Persist the booster with LightGBM's own serializer next to the
# .rds, then re-attach it after readRDS() so the rest of the workflow (recipe,
# etc.) is kept.
booster <- pull_workflow_fit(model_fit)$fit
lightgbm::lgb.save(booster, filename = "model_fit_lightgbm.txt")
saveRDS(model_fit, "model_fit.rds")

# ... later, e.g. in a new R session:
model <- readRDS("model_fit.rds")
# workflow -> parsnip model_fit -> engine object (the lgb.Booster)
model$fit$fit$fit <- lightgbm::lgb.load(filename = "model_fit_lightgbm.txt")
predicciones <- predict(model, iris)
当我尝试生成预测时，R 会话崩溃。目前最有效的替代方法是从工作流中提取拟合好的模型，并用 LightGBM 自带的方法保存，但这样会丢失工作流（workflow）中保存的其他所有内容。如能提供任何帮助或建议，我将不胜感激。
# Workaround: extract the engine fit (the lgb.Booster) from the workflow and
# persist it in LightGBM's own model format, which survives a new session.
# To keep the whole workflow as well, also saveRDS() it and, after readRDS(),
# assign the reloaded booster back to `<workflow>$fit$fit$fit`.
pull_lightgbm <- pull_workflow_fit(model_fit)
library(lightgbm)
lgb.save(pull_lightgbm$fit, "lightgbm.model")
model <- lgb.load("lightgbm.model")
# Environment the reprex was run in.
utils::sessionInfo()
#> R version 4.0.3 (2020-10-10)
#> Platform: x86_64-apple-darwin17.0 (64-bit)
#> Running under: macOS Mojave 10.14.6
#>
#> Matrix products: default
#> BLAS: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
#>
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] treesnip_0.1.0.9000 tune_0.1.1 dials_0.0.9
#> [4] scales_1.1.1 workflows_0.2.1 recipes_0.1.14
#> [7] yardstick_0.0.7 rsample_0.0.8 parsnip_0.1.4
#> [10] dplyr_1.0.2
#>
#> loaded via a namespace (and not attached):
#> [1] Rcpp_1.0.5 lubridate_1.7.9 lattice_0.20-41 tidyr_1.1.2
#> [5] listenv_0.8.0 class_7.3-17 assertthat_0.2.1 digest_0.6.27
#> [9] ipred_0.9-9 foreach_1.5.1 parallelly_1.21.0 R6_2.5.0
#> [13] plyr_1.8.6 evaluate_0.14 ggplot2_3.3.2 highr_0.8
#> [17] pillar_1.4.6 rlang_0.4.8 DiceDesign_1.8-1 furrr_0.2.1
#> [21] rpart_4.1-15 Matrix_1.2-18 rmarkdown_2.5 splines_4.0.3
#> [25] gower_0.2.2 stringr_1.4.0 munsell_0.5.0 compiler_4.0.3
#> [29] xfun_0.19 pkgconfig_2.0.3 globals_0.13.1 htmltools_0.5.0
#> [33] nnet_7.3-14 tidyselect_1.1.0 tibble_3.0.4 prodlim_2019.11.13
#> [37] codetools_0.2-16 GPfit_1.0-8 fansi_0.4.1 future_1.20.1
#> [41] crayon_1.3.4 withr_2.3.0 MASS_7.3-53 grid_4.0.3
#> [45] gtable_0.3.0 lifecycle_0.2.0 magrittr_1.5 pROC_1.16.2
#> [49] cli_2.1.0 stringi_1.5.3 timeDate_3043.102 ellipsis_0.3.1
#> [53] lhs_1.1.1 generics_0.1.0 vctrs_0.3.4 lava_1.6.8.1
#> [57] iterators_1.0.13 tools_4.0.3 glue_1.4.2 purrr_0.3.4
#> [61] parallel_4.0.3 survival_3.2-7 yaml_2.2.1 colorspace_1.4-1
#> [65] knitr_1.30
由reprex package(v0.3.0)于2020-11-08创建
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。