如何解决调整 LASSO 模型并使用 tidymodels 进行预测
我想对 LASSO 算法执行惩罚选择并使用 tidymodels
预测结果。我将使用波士顿住房数据集来说明问题。
library(tidymodels)
library(tidyverse)
library(mlbench)
data("BostonHousing")
dt <- BostonHousing
我首先将数据集拆分为训练/测试子集。
dt_split <- initial_split(dt)
dt_train <- training(dt_split)
dt_test <- testing(dt_split)
使用 recipe
包定义预处理。
rec <- recipe(medv ~ .,data = dt_train) %>%
step_center(all_predictors(),-all_nominal()) %>%
step_dummy(all_nominal()) %>%
prep()
模型和工作流程的初始化。我使用 glmnet
引擎。 mixture = 1
表示我选择 LASSO 惩罚,penalty = tune()
表示我稍后将使用交叉验证来选择最佳惩罚参数 lambda
。
lasso_mod <- linear_reg(mode = "regression",penalty = tune(),mixture = 1) %>%
set_engine("glmnet")
wf <- workflow() %>%
add_model(lasso_mod) %>%
add_recipe(rec)
准备分层 5 折交叉验证和惩罚网格:
folds <- rsample::vfold_cv(dt_train,v = 5,strata = medv,nbreaks = 5)
my_grid <- tibble(penalty = 10^seq(-2,-1,length.out = 10))
让我们运行交叉验证:
my_res <- wf %>%
tune_grid(resamples = folds,grid = my_grid,control = control_grid(verbose = FALSE,save_pred = TRUE),metrics = metric_set(rmse))
我现在可以从网格中获得最佳罚分并更新我的工作流程以获得最佳罚分:
best_mod <- my_res %>% select_best("rmse")
print(best_mod)
final_wf <- finalize_workflow(wf,best_mod)
print(final_wf)
== Workflow ===================================================================================================================
Preprocessor: Recipe
Model: linear_reg()
-- Preprocessor ---------------------------------------------------------------------------------------------------------------
2 Recipe Steps
* step_center()
* step_dummy()
-- Model ----------------------------------------------------------------------------------------------------------------------
Linear Regression Model Specification (regression)
Main Arguments:
penalty = 0.0278255940220712
mixture = 1
computational engine: glmnet
到目前为止一切顺利。现在我想将工作流应用于训练数据以获得我的最终模型:
final_mod <- fit(final_wf,data = dt_train) %>%
pull_workflow_fit()
问题来了。
final_mod$fit
是一个 elnet
和 glmnet
对象。它包含惩罚参数的 75 个值的网格上的完整正则化路径。因此,之前的惩罚调整步骤几乎没有用。所以预测步骤失败了:
predict(final_mod,new_data = dt)
返回错误:
Error in cbind2(1,newx) %*% nbeta :
invalid class 'NA' to dup_mMatrix_as_dgeMatrix
当然,我可以使用 glmnet::cv.glmnet
来获得最佳惩罚,然后使用方法 predict.cv.glmnet
,但我需要一个通用的工作流程,能够使用相同的界面处理多个机器学习模型。在 parsnip::linear_reg
的 documentation 中有关于 glmnet 引擎的注释:
对于 glmnet 模型,完整的正则化路径总是适合的 无论给予惩罚的价值如何。此外,还有一个选项 将多个值(或无值)传递给惩罚参数。什么时候 在这些情况下使用 predict() 方法,返回值取决于 惩罚的价值。使用 predict() 时,只有一个值 可以使用惩罚。在预测多个惩罚时, 可以使用 multi_predict() 函数。它返回一个带有列表的小标题 名为 .pred 的列包含一个带有所有惩罚的 tibble 结果。
但是,我不明白我应该如何使用 tidymodels
框架继续获得调整后的 LASSO 模型的预测。 multi_predict
函数抛出与 predict
相同的错误。
解决方法
您真的很接近让一切正常工作。
让我们读入数据,将其拆分为训练/测试并创建重采样折叠。
library(tidymodels)
library(tidyverse)
library(mlbench)
data("BostonHousing")
dt <- BostonHousing
dt_split <- initial_split(dt)
dt_train <- training(dt_split)
dt_test <- testing(dt_split)
folds <- vfold_cv(dt_train,v = 5,strata = medv,nbreaks = 5)
现在让我们创建一个预处理配方。 (请注意,如果您使用的是 prep()
,则不需要 workflow()
;如果您的数据很大,那可能会变慢,所以最好不要这样做,直到 workflow()
以后照顾它。)
rec <- recipe(medv ~ .,data = dt_train) %>%
step_center(all_predictors(),-all_nominal()) %>%
step_dummy(all_nominal())
现在让我们制作我们的模型,将它与我们的配方放在一个 workflow()
中,并使用网格调整工作流程。
lasso_mod <- linear_reg(mode = "regression",penalty = tune(),mixture = 1) %>%
set_engine("glmnet")
wf <- workflow() %>%
add_model(lasso_mod) %>%
add_recipe(rec)
my_grid <- tibble(penalty = 10^seq(-2,-1,length.out = 10))
my_res <- wf %>%
tune_grid(resamples = folds,grid = my_grid,control = control_grid(verbose = FALSE,save_pred = TRUE),metrics = metric_set(rmse))
这是我们得到的最好的惩罚:
best_mod <- my_res %>% select_best("rmse")
best_mod
#> # A tibble: 1 x 2
#> penalty .config
#> <dbl> <chr>
#> 1 0.0215 Preprocessor1_Model04
在这里,我们的做法与您的做法略有不同。我将用最佳惩罚最终确定我的工作流程,然后将最终确定的工作流程拟合到训练数据中。此处的输出是适合的工作流程。我不想从中拉出底层模型,因为模型需要预处理才能正常工作;它接受过训练,期望进行预处理。
相反,我可以predict()
直接使用经过训练的工作流程:
final_fitted <- finalize_workflow(wf,best_mod) %>%
fit(data = dt_train)
predict(final_fitted,dt_train)
#> # A tibble: 379 x 1
#> .pred
#> <dbl>
#> 1 18.5
#> 2 24.2
#> 3 23.3
#> 4 21.6
#> 5 37.6
#> 6 21.5
#> 7 16.7
#> 8 15.6
#> 9 21.3
#> 10 21.3
#> # … with 369 more rows
predict(final_fitted,dt_test)
#> # A tibble: 127 x 1
#> .pred
#> <dbl>
#> 1 30.2
#> 2 25.1
#> 3 19.6
#> 4 17.0
#> 5 13.9
#> 6 15.4
#> 7 13.7
#> 8 20.8
#> 9 31.1
#> 10 21.3
#> # … with 117 more rows
由 reprex package (v1.0.0) 于 2021 年 3 月 16 日创建
如果您调整工作流程,那么您通常希望最终确定、调整和预测工作流程。如果您在工作流程中使用非常简单的预处理器(例如可以传递给 fit()
; 的公式),则可能会有例外。我show an example that you could do that with here。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。