如何解决Sklearn 分组 k-fold - 测试和训练中的同一组
来自文档 (https://scikit-learn.org/stable/modules/cross_validation.html#group-k-fold):
GroupKFold 是 k-fold 的变体,可确保同一组不会出现在测试和训练集中
然后,稍微调整一下例子,我们有:
from sklearn.model_selection import GroupKFold
X = np.array([0.1,0.2,2.2,2.4,2.3,4.55,5.8,8.8,9,10])
y = np.array(["a","b","c","d","d"])
groups = [1,1,2,3,3]
gkf = GroupKFold(n_splits=3)
for train,test in gkf.split(X,y,groups=groups):
print("%s %s" % (train,test))
打印:
[0 1 2 3 4 5] [6 7 8 9]
[0 1 2 6 7 8 9] [3 4 5]
[3 4 5 6 7 8 9] [0 1 2]
在我看来,组 b
在这里的测试和训练集中,我们有 -
[3 4 5 6 7 8 9] [0 1 2]
对于最后一个输出,其中测试索引是 [0,2]
,这给了我们组 a
和组 b
中的两个值,这意味着组 {{ 中有一个值1}} 在测试集中以及训练(其中索引 b
)。
大概文档/模块是正确的,我错了,但我不明白如何。
要明确 - 我不希望在测试和培训中看到同一组的值,而且确实如此。
解决方法
您将课程误认为是小组。正如评论已经指出的那样,它们仅由 group
参数决定,与类无关。
按照您已经链接到的描述,您可以更好地理解示例:
例如,如果数据是从不同的受试者获得的,每个受试者有几个样本,并且如果模型足够灵活,可以从高度个性化的特征中学习,它可能无法推广到新的主题。
因此问题 GroupKFold
可能是这样一种情况:您从不同来源(在示例中为主题)获取数据,并希望控制您的模型是否具有足够好的泛化能力在其他来源的数据上表现良好。或者换句话说,您希望确保您的模型没有过度拟合来自一个或多个特定来源的数据。这就是 GroupKFold
的用途:
GroupKFold
可以检测这种过度拟合的情况。
因此这些来源(或主题)由 group
参数确定,并由 GroupKFold
分隔,因此在测试和训练中永远不会出现相同的来源折叠。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。