如何解决使用 sklearn 对减少的交叉验证数据集执行网格搜索的有效方法
我正在使用网格搜索来查找 2 个模型的最佳参数。我必须使用整个数据集构建一个模型,并使用减少的数据集构建另一个模型(需要保持两个模型的折叠相同)。因此,在第二个模型的情况下,将从已用于第一个模型(具有整个数据集的模型)的同一折叠中省略/删除数据点列表。以下是我的代码:
rkf = RepeatedKFold(n_splits=2,n_repeats=5,random_state=24)
rkf_new_indices = []
for train_idx,test_idx in rkf.split(x):
Model1x_train,Model1x_test = x[train_idx],x[test_idx]
Model1y_train,Model1y_test = y[train_idx],y[test_idx]
temp_list1 = train_idx.copy()
temp_list2 = test_idx.copy()
Model2trn_idx = remove_datapoints(temp_list1,out_list)
Model2tst_idx = remove_datapoints(temp_list2,out_list)
Model2train_idx = list(Model2trn_idx)
Model2test_idx = list(Model2tst_idx)
rkf_new_indices = np.append(Model2train_idx,Model2test_idx)
param_grid = [{'C': [1,10,100,1000],'kernel': ['linear']},{'C': [1,'gamma': [0.001,0.0001],'kernel': ['rbf']},]
svr_model = SVR()
# define search for model with entire dataset
BASE_SVR = gridsearchcv(svr_model,param_grid,scoring='neg_mean_absolute_error',n_jobs=-1,cv=rkf,return_train_score=True)
BASE_SVR_grid_results = BASE_SVR.fit(x,y)
# define search for model with reduced dataset
New_SVR = gridsearchcv(svr_model,cv=rkf_new_indices,return_train_score=True)
# ^^^^^^^^^^^^ raises TypeError
New_SVR_grid_results = New_SVR.fit(x,y)
对于第二个 GridSearch(第 19 行),出现错误:
for train,test in self.cv:
> TypeError: cannot unpack non-iterable numpy.int32 object
我在使用 cv=rkf_new_indices
时做错了什么,我该如何解决?
解决方法
当你在段下运行时,分割的输出是
rkf_new_indices = []
for train_idx,test_idx in rkf.split([8,8,8]):
print(train_idx,test_idx)
rkf_new_indices = np.append(train_idx,test_idx)
[0 1 2 3] [4 5 6 7 8]
[4 5 6 7 8] [0 1 2 3]
[2 3 4 7] [0 1 5 6 8]
[0 1 5 6 8] [2 3 4 7]
[1 3 7 8] [0 2 4 5 6]
[0 2 4 5 6] [1 3 7 8]
[1 4 7 8] [0 2 3 5 6]
[0 2 3 5 6] [1 4 7 8]
[1 2 6 7] [0 3 4 5 8]
[0 3 4 5 8] [1 2 6 7]
然而,rkf_new_indices = np.append(train_idx,test_idx)
只获取最后一个实例:
array([0,3,4,5,1,2,6,7])
您可以尝试使用 rkf_new_indices.append((train_idx,test_idx))
将它们全部配对:
[(array([0,3]),array([4,7,8])),(array([4,8]),array([0,3])),(array([2,7]),(array([0,array([2,7])),(array([1,6])),6]),array([1,7]))]
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。