kfold 交叉验证后如何绘制适合每个折叠的数据和模型?

如何解决kfold 交叉验证后如何绘制适合每个折叠的数据和模型?

我试图根据一个特征预测一个标签变量。两者似乎是高度线性相关的。我选择了一个线性回归模型来描述数据。我的代码输出显示了训练和测试数据的 R2 分数。我的模型表现良好,期望测试样本的一倍,其中 R2 为负。 我想绘制每个折叠的数据和模型的拟合,以了解出了什么问题。但是,我无法从 python 编码的角度弄清楚如何做到这一点。

有人可以帮忙吗?


Test_scores = list()
Train_scores =list()
n_splits = 5
kfold = KFold(n_splits=n_splits,shuffle=False)
for train_ix,test_ix in kfold.split(Feature_X):
    Train_Feature_X,Test_Feature_X=Feature_X[train_ix],Feature_X[test_ix]
    Train_label_X,Test_label_X= label_X[train_ix],label_X[test_ix]
    model = LinearRegression()
    model.fit(Train_Feature_X,Train_label_X)
    pred_label_train = model.predict(Train_Feature_X)
    acc_train = r2_score(Train_label_X,pred_label_train)
    pred_label_test = model.predict(Test_Feature_X)
    acc_test = r2_score(Test_label_X,pred_label_test)
    Test_scores.append(acc_test)
    Train_scores.append(acc_train)
    print('> ','Train:'+ str(acc_train),"Test:"+ str(acc_test))
Test_mean,Test_std = np.mean(Test_scores),np.std(Test_scores)
Train_mean,Train_std = np.mean(Train_scores),np.std(Train_scores)

print('Mean of test: %.3f,Standard Deviation: %.3f' % (Test_mean,Test_std))
print('Mean of train: %.3f,Standard Deviation: %.3f' % (Train_mean,Train_std))



代码输出

>  Train:0.9948113361306588 Test:0.9715872368615199
>  Train:0.9905854864161807 Test:0.9917503220348162
>  Train:0.9888929852977923 Test:-4.996610921978263
>  Train:0.990942242777374 Test:0.9960355777732937
>  Train:0.9925744355834707 Test:0.9458246438971184
Mean of test: -0.218,Standard Deviation: 2.389
Mean of train: 0.992,Standard Deviation: 0.002

enter image description here

解决方法

您可以将绘图添加到循环循环中。

每次迭代您都可以访问训练测试折叠和预测,因此在打印值之前,即 print('> ','Train:'+ str(acc_train),"Test:"+ str(acc_test)) 您可以执行以下操作:

fig,ax = plt.subplots(nrows=1,ncols=5)
curr_split = 1
for ...

    plt.subplot(1,5,curr_split)
    plt.plot(x,y)
    curr_split += 1
plt.show()

这将绘制 5 个子图,每个子图代表折叠。

请注意,这是您应该做的大纲,请参阅以下链接中的文档 plt.subplots()

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?