关于随着回归数据集的增加,交叉验证模型的 MSE 增加

如何解决关于随着回归数据集的增加,交叉验证模型的 MSE 增加

我有以下用于回归问题的实验设置。

使用以下例程,将大约 1800 个条目的数据集分为三组:验证、测试和训练。

X_train,X_test,y_train,y_test = train_test_split(inputs,targets,test_size=0.2,random_state=42,shuffle=True)

X_train,X_val,y_val = train_test_split(X_train,test_size=0.25,shuffle=True)

所以本质上,训练大小 ~ 1100,验证和测试大小 ~ 350,然后每个子集都有唯一的数据点集,这是在其他子集中看不到的。

通过这些子集,我可以使用 scikit-learn 提供的任意数量的回归模型进行拟合,使用以下例程:

model = LinearRegression()
clf = make_pipeline(StandardScaler(),model)
clf.fit(X_train,y_train)
predictions = clf.predict(X_test)

这样做之后,我会计算预测的 RMSE,在线性回归量的情况下,大约为 ~ 0.948。

现在,我可以改为使用交叉验证,而不用担心拆分数据,使用以下例程:

model = LinearRegression()
clf = make_pipeline(StandardScaler(),model)
predictions2 = cross_val_predict(clf,X,y,cv=KFold(n_splits=10,shuffle=True,random_state=42))

然而,当我计算这些预测的 RMSE 时,大约是 ~2.4!为了比较,我尝试使用类似的例程,但将 X 切换为 X_train,将 y 切换为 y_train,即

model = LinearRegression()
clf = make_pipeline(StandardScaler(),model)
predictions3 = cross_val_predict(clf,X_train,random_state=42))

并获得约 0.956 的 RMSE。

我真的不明白为什么在使用整个数据集时,交叉验证的 RMSE 要高得多,而且与减少数据集的预测相比,预测很糟糕。

附加说明

此外,我尝试运行上述例程,这次使用简化的子集 X_val,y_val 作为交叉验证的输入,并且仍然收到较小的 RMSE。此外,当我简单地在简化的子集 X_val,y_val 上拟合模型,然后在 X_train,y_train 上进行预测时,RMSE 仍然比交叉验证的 RMSE 更好(更低)!

这不仅适用于 LinearRegressor,也适用于 RandomForrestRegressor 等。我还尝试更改拆分中的随机状态,并在将数据交给 train_test_split 之前完全混洗数据,但仍然发生相同的结果。

编辑 1.)

我在来自 scikit 的 make_regression 数据集上对此进行了测试,但没有得到相同的结果,但所有 RMSE 都很小且相似。我的猜测是这与 my 数据集有关。 如果有人能帮助我理解这一点,我将不胜感激。

编辑 2.)

嗨,谢谢 (@desertnaut) 的建议,解决方案实际上很简单,事实上,在我处理数据的例程中,我使用了 (targets,inputs) = (X,y),这确实是错误的。我将其与 (targets,inputs) = (y,X) 交换,现在 RMSE 与其他配置文件大致相同。我制作了数据的直方图配置文件并发现了这个问题。谢谢!我将问题保存约 1 小时,然后将其删除

解决方法

你过拟合了。

假设您有 10 个数据点和 10 个参数,那么 RMSE 将为零,因为模型可以完美地拟合数据,现在将数据点增加到 100,RMSE 将增加(假设您的数据存在一些差异)当然添加)因为您的模型不再完全适合数据。

RMSE 低(或 R 平方高)通常并不意味着千斤顶,您需要考虑参数估计的标准误差。 . .如果你只是增加参数的数量(或者相反,在你的情况下,减少观察的数量)你只是在咀嚼你的自由度。

我敢打赌,您对 X 模型参数估计的标准误差估计小于您在 X_train 模型中的标准误差估计,即使 X_train 模型中的 RMSE“较低”。

编辑:我要补充一点,您的数据集表现出高度的多重共线性。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?