ValueError:至少需要一个数组才能与sklearn cross_val_predict方法连接

如何解决ValueError:至少需要一个数组才能与sklearn cross_val_predict方法连接

我正在尝试使用SVM分类器通过自定义交叉验证折叠对二进制分类问题进行建模,但是它给我带来了错误**需要至少一个数组以将cross_val_predict连接在一起**。该代码在cros_val_predict中的cv = 3上可以正常工作,但是当我使用custom_cv时,会出现此错误

下面是代码


from sklearn.model_selection import LeavePOut
import numpy as np
from sklearn.svm import SVC
from time import *
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import cross_val_predict,cross_val_score
clf = SVC(kernel='linear',C=25)
X = np.array([[1,2],[3,4],[5,6],[7,8],[9,10]])
y = np.array([0,1,0])
lpo = LeavePOut(2)
print(lpo.get_n_splits(X))
LeavePOut(p=2)
test_index_list=[]
train_index_list=[]
for train_index,test_index in lpo.split(X,y):
  
  if(y[test_index[0]]==y[test_index[1]]):
    pass
  else:
    print("TRAIN:",train_index,"TEST:",test_index)
    X_train,X_test = X[train_index],X[test_index]
    y_train,y_test = y[train_index],y[test_index]
    train_index_list.append(train_index)
    test_index_list.append(test_index)
custom_cv = zip(train_index_list,test_index_list)
scores = cross_val_score(clf,X,y,cv=custom_cv)

print(scores)
print('accuracy:',scores.mean())
predicted=cross_val_predict(clf,cv=custom_cv) # error with this line
print('Confusion matrix:',confusion_matrix(labels,predicted))

下面是完整的错误痕迹:

ValueError                                Traceback (most recent call last)
<ipython-input-11-d78feac932b2> in <module>()
     31 print(scores)
     32 print('accuracy:',scores.mean())
---> 33 predicted=cross_val_predict(clf,cv=custom_cv)
     34 
     35 print('Confusion matrix:',predicted))

/usr/local/lib/python3.6/dist-packages/sklearn/model_selection/_validation.py in cross_val_predict(estimator,groups,cv,n_jobs,verbose,fit_params,pre_dispatch,method)
    758     predictions = [pred_block_i for pred_block_i,_ in prediction_blocks]
    759     test_indices = np.concatenate([indices_i
--> 760                                    for _,indices_i in prediction_blocks])
    761 
    762     if not _check_is_permutation(test_indices,_num_samples(X)):

<__array_function__ internals> in concatenate(*args,**kwargs)

ValueError: need at least one array to concatenate

关于如何解决错误的任何建议?

解决方法

这里有2个错误:

  1. 如果要重用zip对象,请从中创建一个列表。使用一次后,该对象将耗尽。您可以像这样修复它:
custom_cv = [*zip(train_index_list,test_index_list)]
    cross_val_predict
  1. 交叉验证列表应该是实际数组(Each sample should only belong to exactly one test set)的分区。在您的情况下不是。如果考虑一下,从交叉验证列表中堆叠输出将得到长度为 6 的数组,而原始的 y 则为长度5。您可以像这样实现自定义交叉值预测:
def custom_cross_val_predict(clf,X,y,cv):
    y_pred,y_true = [],[]
    for tr_idx,vl_idx in cv:
        X_tr,y_tr = X[tr_idx],y[tr_idx]
        X_vl,y_vl = X[vl_idx],y[vl_idx]
        clf.fit(X_tr,y_tr)
        y_true.extend(y_vl)
        y_pred.extend(clf.predict(X_vl))
        
    return y_true,y_pred

labels,predicted = custom_cross_val_predict(clf,cv=custom_cv)
print('Confusion matrix:',confusion_matrix(labels,predicted))

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?