微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

简单的 glmnet 模型,predict() 导致“lambda[1] 中的错误 - s:二元运算符的非数字参数”

如何解决简单的 glmnet 模型,predict() 导致“lambda[1] 中的错误 - s:二元运算符的非数字参数”

所以我一直在尝试将 predict() 与各种形式的数据帧格式一起使用,但它们似乎不起作用。我已经尝试过 1) 排除因变量,2) 包括带有切片数据的因变量,3) 包括其中包含 NA 值的因变量,以及许多其他事情。

R 4.1.0
R Studio 1.4.1717

下面的代码演示了 3).

library(tidyverse)
library(lubridate)
library(tidymodels)

df <- data.frame(y  = sample(5000000:120000000,100,replace = TRUE),yearr = sample(2015:2021,monthh = sample(1:12,dayy = sample(1:31,replace = TRUE))

rm(df_slice)
df_slice = df |>
  slice(1:50) |>
  select(yearr,monthh,dayy) |>
  mutate(y = NA)

m = linear_reg(mode = 'regression',penalty = varying(),mixture = 0.6) |>
  set_engine("glmnet") |>
  fit(y ~ .,data = df)

predict(m,df_slice)
predict.model_fit(m,df_slice)
predict_raw(m,df_slice)

最后三行代码抛出 Error in lambda[1] - s : non-numeric argument to binary operator 调试消息。我确保 dfdf_slice 中的所有变量都是数字,但仍然不确定发生了什么。如果我要进行训练测试拆分,我只想获得预测/拟合值以及“未来”值。为什么这不起作用?

解决方法

您正在使用 glmnet,您正在调整的 penalty 是 L2 规范,在 glmnet 中也称为 lambda,请参阅 the help page

如果你设置 penalty = varying() ,你是在一系列 L2 范数上运行 glmnet,当你调用 predict 时,你需要提供一个 lambda 值来预测。因此,对于现在的示例,您不应使用 penalty = varying(),而是提供 lambda 的值:

library(tidyverse)
library(lubridate)
library(tidymodels)

m = linear_reg(mode = 'regression',penalty = 1,mixture = 0.6) %>%
  set_engine("glmnet") %>%
  fit(y ~ .,data = df)

predict(m,df_slice)

否则,您需要调整并找到一个合适的lambda,然后通过它来重新拟合模型:

my_cv = vfold_cv(df)
rec = recipe(y ~.,data=df) %>% prep(training = df,retain=TRUE)
fit = linear_reg(mode = 'regression',penalty = tune(),mixture = 0.6) %>%
  set_engine("glmnet") 

wflow = workflow() %>%
add_recipe(rec) %>%
add_model(fit)

res = wflow %>% tune_grid(my_cv)

best_params = res %>% select_best(metric = "rmse")

m = wflow %>%
  finalize_workflow(best_params) %>%
  fit(data = df)

predict(m,df_slice)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?