如何解决用于线性模型的 R 中 K 折交叉验证的通用函数
我正在使用的数据集称为工资,它来自一个名为 library(ISLR) 的包 数据(工资)。
无论如何,我正在尝试开发一个函数,允许我对任何一般线性模型执行 k 折交叉验证。
我正在使用的函数的输入/参数是 function(numberOfFolds,y,x,InputData)
y 是因变量 x 是数据集中的所有其他变量 inputdata 是工资数据集 numberOfFolds 基本上是 k。
我已经开发了以下代码,但我得到了 NaN 值。不知道发生了什么问题!有人可以帮忙吗
my.k.fold.1<- function(numberOfFolds,inputData){
index<-sample(1:numberOfFolds,nrow(inputData),replace = T)
inputData$index<-index
mse<-vector('numeric',length = numberOfFolds)
for (n in 1:numberOfFolds) {
data.train<-inputData[index!=n,]
data.test<-inputData[index==n,]
my.equation<-paste(y,paste(x,collapse = '+'),sep='~')
formula.1<-formula(my.equation)
model.test<-lm(formula.1,data = data.train)
predictions<-predict(model.test,newdata=data.test)
mse[[n]]<-mean((data.test$y-predictions)^2)
}
return(mse)
}
my.k.fold.1(numberOfFolds = 5,y='earn',x=c('race','sex','ed','height','age'),inputData = wages)
我想保持参数不变,我可以在 y 和 xs 中写下列名
解决方法
这是因为[
{
"github_open_issues": {
"0": {
"git_url": "https://github.com/","git_assignees": "None","git_open_date": "2019-09-26","git_id": 253113,"repo": "repoA","git_user": "userA","state": "open"
},"1": {
"git_url": "https://github.com/","git_open_date": "2019-11-15","git_id": 294398,"repo": "repoB","git_user": "userB","2": {
"git_url": "https://github.com/","git_open_date": "2021-04-12","git_id": 661208,"state": "open"
}
},"unique_label_seen": {
"568": {
"label_name": "some label","times_seen": 12,"535": {
"label_name": "another label","times_seen": 1
}
}
}
}
]
变量是一个字符串,所以y
等价于data.test$y
。您应该将其替换为 data.test[["y"]]
,这相当于 data.test[[y]]
if data.test$earn
:
y="earn"
,
这是一个通用函数。参数名称是自描述的。我添加了一个参数 verbose
,默认为 FALSE
。
下面使用内置数据集 mtcars
进行测试。
my.k.fold.1 <- function(numberOfFolds,inputData,response,regressors,verbose = FALSE){
fmla <- paste(regressors,collapse = "+")
fmla <- paste(response,fmla,sep = "~")
fmla <- as.formula(fmla)
index <- sample(numberOfFolds,nrow(inputData),replace = TRUE)
mse.all <- numeric(numberOfFolds)
for (n in seq_len(numberOfFolds)) {
inx <- which(index != n)
data.training <- inputData[inx,]
data.test <- inputData[-inx,]
if(verbose){
msg <- paste("fold:",n,"nrow(training):",nrow(data.training),"nrow(test):",nrow(data.test))
message(msg)
}
model <- lm(fmla,data = data.training)
predicted <- predict(model,newdata = data.test)
mse <- mean((data.test[[response]] - predicted)^2)
mse.all[n] <- mse
}
return(mse.all)
}
X <- names(mtcars)[-c(1,3,5,7)]
y <- "mpg"
set.seed(2021)
mse.kcv <- my.k.fold.1(5,mtcars,response = y,regressors = X,verbose = TRUE)
mse.kcv
#[1] 14.255583 8.355831 2.765447 7.539299 10.151655
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。