微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

为什么用保存为 “.rds” 的 LightGBM 模型进行预测时,R 会话会崩溃?

如何解决:用保存为 “.rds” 的 LightGBM 模型进行预测时 R 会话崩溃的问题?

在使用 tidymodels 和 treesnip 拟合 LightGBM 模型后,我可以用拟合好的工作流(workflow)毫无问题地对新数据进行预测。但是,如果把调优后的模型保存为 “.rds” 格式,关闭会话,再在新会话中加载这个 “.rds” 模型,那么一生成预测,R 会话就会崩溃。

这个问题只出现在 LightGBM 模型上,其他类型的模型都不会出现。下面是一个可重现的示例:

LightGBM 包的安装方式如下:

# Install the LightGBM 3.0.0 R package from its GitHub release tarball.
# The download and build are skipped when exactly that version is already
# installed; system.file() and packageVersion() inspect the installed package
# without loading its namespace.
PKG_URL <- "https://github.com/microsoft/LightGBM/releases/download/v3.0.0/lightgbm-3.0.0-r-cran.tar.gz"
if (!nzchar(system.file(package = "lightgbm")) ||
    utils::packageVersion("lightgbm") != "3.0.0") {
  remotes::install_url(PKG_URL)
}
library(dplyr)
library(parsnip)
library(rsample)
library(yardstick)
library(recipes)
library(workflows)
library(dials)
library(tune)
library(treesnip)

# Three stacked copies of iris (3 x 150 = 1050 rows).
data <- bind_rows(iris, iris, iris)

set.seed(2)
# 75% training / 25% testing split. The `prop` argument is named in full
# instead of relying on partial argument matching of `p`.
initial_split <- initial_split(data, prop = 0.75)
train <- training(initial_split)
test <- testing(initial_split)
initial_split
#> <Analysis/Assess/Total>
#> <788/262/1050>

# Preprocessing recipe: Sepal.Length is the outcome and every other column is
# a predictor; nominal predictors are expanded into dummy columns. `data` is
# only used as a template for the column names and types.
recipe <- recipe(Sepal.Length ~ ., data = data) %>%
  step_dummy(all_nominal(), -all_outcomes())

# Boosted-tree regression specification for the "lightgbm" engine (provided by
# treesnip). The four tune() placeholders are resolved by the grid search.
model <- boost_tree(
  mtry = 3,
  trees = 1000,
  min_n = tune(),
  tree_depth = tune(),
  loss_reduction = tune(),
  learn_rate = tune(),
  sample_size = 0.75
) %>%
  set_mode("regression") %>%
  set_engine("lightgbm")

# Bundle the recipe and the model specification into one workflow object.
wf <- workflow() %>%
  add_model(model) %>%
  add_recipe(recipe)

wf
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: boost_tree()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 1 Recipe Step
#> 
#> ● step_dummy()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Boosted Tree Model Specification (regression)
#> 
#> Main Arguments:
#>   mtry = 3
#>   trees = 1000
#>   min_n = tune()
#>   tree_depth = tune()
#>   learn_rate = tune()
#>   loss_reduction = tune()
#>   sample_size = 0.75
#> 
#> computational engine: lightgbm

# Resampling scheme: 3-fold cross-validation drawn from the training split only.
resamples <- vfold_cv(train, v = 3)

# Candidate grid: 10 random combinations for the tune() parameters of `model`.
# finalize() resolves any data-dependent parameter ranges against `train`.
grid <- model %>%
  parameters() %>%
  finalize(train) %>%
  grid_random(size = 10)

head(grid)
#> # A tibble: 6 x 4
#>   min_n tree_depth   learn_rate loss_reduction
#>   <int>      <int>        <dbl>          <dbl>
#> 1     2          4 0.000282          0.0000402
#> 2    13         10 0.00333          13.0      
#> 3    32         11 0.000000585       0.000106 
#> 4    32          7 0.000258          0.163    
#> 5    31         13 0.0000000881      0.000479 
#> 6    19         14 0.000000167       0.00174


# Grid search: score every candidate in `grid` on each resample, using RMSE as
# the only metric.
tune_grid <- tune_grid(
  wf,
  resamples = resamples,
  grid = grid,
  metrics = metric_set(rmse),
  control = control_grid(verbose = FALSE)
)


# Pick the candidate with the best RMSE and substitute its values for the
# tune() placeholders in the workflow.
best_params <- select_best(tune_grid, metric = "rmse")
wf <- finalize_workflow(wf, best_params)

wf
#> ══ Workflow ════════════════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: boost_tree()
#> 
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 1 Recipe Step
#> 
#> ● step_dummy()
#> 
#> ── Model ───────────────────────────────────────────────────────────────────────
#> Boosted Tree Model Specification (regression)
#> 
#> Main Arguments:
#>   mtry = 3
#>   trees = 1000
#>   min_n = 13
#>   tree_depth = 10
#>   learn_rate = 0.00333377440294304
#>   loss_reduction = 13.0320661814971
#>   sample_size = 0.75
#> 
#> computational engine: lightgbm

# Final evaluation: fit the finalized workflow on the training part of
# `initial_split` and score it on the held-out test part.
last_fit <- wf %>%
  last_fit(initial_split)

# Test-set metrics of that final fit.
collect_metrics(last_fit)
#> # A tibble: 2 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 rmse    standard       0.380
#> 2 rsq     standard       0.837

# Fit the finalized workflow on all of `data` so it can score new observations.
model_fit <- fit(wf, data)
#> [LightGBM] [Warning] Auto-choosing row-wise multi-threading, the overhead of testing was 0.000020 seconds.
#> You can set `force_row_wise=true` to remove the overhead.
#> And if memory is not enough, you can set `force_col_wise=true`.
#> [LightGBM] [Info] Total Bins 95
#> [LightGBM] [Info] Number of data points in the train set: 1050, number of used features: 5
#> [LightGBM] [Info] Start training from score 5.843333
#> [LightGBM] [Warning] No further splits with positive gain, best gain: -inf
#> [LightGBM] [Warning] No further splits with positive gain, best gain: -inf
#> ... (remaining LightGBM log output truncated)
predicciones <- predict(model_fit, iris)

head(predicciones)
#> # A tibble: 6 x 1
#>   .pred
#>   <dbl>
#> 1  5.13
#> 2  5.12
#> 3  5.12
#> 4  5.12
#> 5  5.13
#> 6  5.25

# Save the model.
# The LightGBM booster inside the workflow (pull_workflow_fit(model_fit)$fit)
# wraps a C++ object through an external pointer. saveRDS() cannot serialize
# what that pointer refers to: after readRDS() the pointer is NULL, and calling
# predict() on it crashes the R session. The booster is therefore also written
# in LightGBM's own model format, so it can be re-attached to the workflow
# after loading. (NOTE(review): LightGBM 4.x reportedly keeps a serialized copy
# that makes plain saveRDS() safe; this script pins 3.0.0.)
saveRDS(model_fit, "model_fit.rds")
lightgbm::lgb.save(pull_workflow_fit(model_fit)$fit, "lightgbm.model")

保存模型后,我关闭会话并在新的会话中加载模型。

# In a fresh session, attach the packages used at training time. treesnip is
# what provides the "lightgbm" engine that predict() dispatches to.
library(workflows)
library(treesnip)

model <- readRDS("model_fit.rds")

# readRDS() cannot rebuild the LightGBM booster's C++ object; predicting with
# the restored NULL external pointer crashes the R session. Swap in a booster
# reloaded from LightGBM's own model file, in the slot that
# pull_workflow_fit(model)$fit reads, so the recipe stays in the workflow.
booster_file <- "lightgbm.model"
if (!file.exists(booster_file)) {
  stop(
    "LightGBM model file '", booster_file, "' not found. Save the booster ",
    "with lightgbm::lgb.save(pull_workflow_fit(model_fit)$fit, \"",
    booster_file, "\") before reloading the workflow.",
    call. = FALSE
  )
}
model$fit$fit$fit <- lightgbm::lgb.load(booster_file)

predicciones <- predict(model, iris)

当我尝试生成预测时,R 会话崩溃了。目前最可行的替代方法是从工作流中提取(pull)拟合好的模型,并用 LightGBM 自己的方法保存,但这样就会丢失工作流中保存的其他所有内容(例如预处理配方)。非常感谢任何帮助或建议。

# Workaround: pull the fitted parsnip model out of the workflow. Its `$fit`
# element is the LightGBM booster itself.
pull_lightgbm <- pull_workflow_fit(model_fit)


library(lightgbm)

# lgb.save() writes the booster in LightGBM's own model format, which, unlike
# saveRDS(), stores the trained model instead of a C++ pointer.
lgb.save(pull_lightgbm$fit, "lightgbm.model")

# Loading that file by itself gives a bare booster and loses the recipe. To
# keep the whole workflow, restore it from the .rds file and swap the reloaded
# booster into the slot that pull_workflow_fit(model)$fit reads. After that,
# predict(model, iris) applies the recipe before calling LightGBM.
model <- readRDS("model_fit.rds")
model$fit$fit$fit <- lgb.load("lightgbm.model")
sessionInfo()
#> R version 4.0.3 (2020-10-10)
#> Platform: x86_64-apple-darwin17.0 (64-bit)
#> Running under: macOS Mojave 10.14.6
#> 
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#>  [1] treesnip_0.1.0.9000 tune_0.1.1          dials_0.0.9        
#>  [4] scales_1.1.1        workflows_0.2.1     recipes_0.1.14     
#>  [7] yardstick_0.0.7     rsample_0.0.8       parsnip_0.1.4      
#> [10] dplyr_1.0.2        
#> 
#> loaded via a namespace (and not attached):
#>  [1] Rcpp_1.0.5         lubridate_1.7.9    lattice_0.20-41    tidyr_1.1.2       
#>  [5] listenv_0.8.0      class_7.3-17       assertthat_0.2.1   digest_0.6.27     
#>  [9] ipred_0.9-9        foreach_1.5.1      parallelly_1.21.0  R6_2.5.0          
#> [13] plyr_1.8.6         evaluate_0.14      ggplot2_3.3.2      highr_0.8         
#> [17] pillar_1.4.6       rlang_0.4.8        DiceDesign_1.8-1   furrr_0.2.1       
#> [21] rpart_4.1-15       Matrix_1.2-18      rmarkdown_2.5      splines_4.0.3     
#> [25] gower_0.2.2        stringr_1.4.0      munsell_0.5.0      compiler_4.0.3    
#> [29] xfun_0.19          pkgconfig_2.0.3    globals_0.13.1     htmltools_0.5.0   
#> [33] nnet_7.3-14        tidyselect_1.1.0   tibble_3.0.4       prodlim_2019.11.13
#> [37] codetools_0.2-16   GPfit_1.0-8        fansi_0.4.1        future_1.20.1     
#> [41] crayon_1.3.4       withr_2.3.0        MASS_7.3-53        grid_4.0.3        
#> [45] gtable_0.3.0       lifecycle_0.2.0    magrittr_1.5       pROC_1.16.2       
#> [49] cli_2.1.0          stringi_1.5.3      timeDate_3043.102  ellipsis_0.3.1    
#> [53] lhs_1.1.1          generics_0.1.0     vctrs_0.3.4        lava_1.6.8.1      
#> [57] iterators_1.0.13   tools_4.0.3        glue_1.4.2         purrr_0.3.4       
#> [61] parallel_4.0.3     survival_3.2-7     yaml_2.2.1         colorspace_1.4-1  
#> [65] knitr_1.30

由 reprex 包(v0.3.0)于 2020-11-08 创建

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。