使用 cross_validate 生成混淆矩阵

如何解决使用 cross_validate 生成混淆矩阵

我想弄清楚如何使用 cross_validate 生成混淆矩阵。我可以用目前的代码打印出分数。

# Instantiating model
model = DecisionTreeClassifier()

#scores
scoring = {'accuracy' : make_scorer(accuracy_score),'precision' : make_scorer(precision_score),'recall' : make_scorer(recall_score),'f1_score' : make_scorer(f1_score)}

# 10-fold cross validation
scores = cross_validate(model,X,y,cv=10,scoring=scoring)

print("Accuracy (Testing):  %0.2f (+/- %0.2f)" % (scores['test_accuracy'].mean(),scores['test_accuracy'].std() * 2))
print("Precision (Testing):  %0.2f (+/- %0.2f)" % (scores['test_precision'].mean(),scores['test_precision'].std() * 2))
print("Recall (Testing):  %0.2f (+/- %0.2f)" % (scores['test_recall'].mean(),scores['test_recall'].std() * 2))
print("F1-score (Testing):  %0.2f (+/- %0.2f)" % (scores['test_f1_score'].mean(),scores['test_f1_score'].std() * 2))

但我正在尝试将这些数据放入混淆矩阵中。我可以通过使用 cross_val_predict 来制作混淆矩阵 -

y_train_pred = cross_val_predict(model,cv=10)
confusion_matrix(y,y_train_pred)

这很好,但由于它进行了自己的交叉验证,结果将不匹配。我只是在寻找一种方法来产生匹配的结果。

任何帮助或指示都会很棒。谢谢!

解决方法

答案是你不能。

混淆矩阵的思想是使用一个训练好的模型评估一个数据。结果是矩阵,而不是像准确率这样的分数。所以你不能计算平均值或类似的东西。 cross_val_score 顾名思义,仅适用于分数。混淆矩阵不是一个分数,它是对评估过程中发生的事情的一种总结。

cross_val_predict 与您要查找的内容非常相似。此函数将数据拆分为 K 部分。每个部分都将使用您使用其他部分数据获得的模型进行测试。所有测试过的样本将被合并。但是要小心这个函数: “将这些预测传递到评估指标中可能不是衡量泛化性能的有效方法。结果可能与 cross_validate 和 cross_val_score 不同,除非所有测试集具有相同的大小并且指标在样本上分解。 “

,

我认为最好的方法是将混淆矩阵定义为一个记分器,而不是您定义的其他矩阵,或者除此之外。幸运的是,这是用户指南中的一个示例;查看第三个项目符号 here

def confusion_matrix_scorer(clf,X,y):
    y_pred = clf.predict(X)
    cm = confusion_matrix(y,y_pred)
    return {'tn': cm[0,0],'fp': cm[0,1],'fn': cm[1,'tp': cm[1,1]}
cv_results = cross_validate(svm,y,cv=5,scoring=confusion_matrix_scorer)

然后cv_results['test_tp'](等)是每个折叠的真阳性数量的列表。现在您可以聚合混淆矩阵,但最适合您。


首先想到了另一种方法,我会在这里添加它,以防它有助于理解 sklearn 如何处理事物。但我绝对认为第一种方法更好。

您可以在 return_estimator 中设置 cross_validate,在这种情况下,返回的字典有一个键 estimator,其值为拟合模型列表。不过,您仍然需要能够找到相应的测试折叠。为此,您可以手动定义您的 cv 对象(例如 cv = StratifiedKFold(10)cross_validate(...,cv=cv);然后 cv 仍将包含用于进行拆分的相关数据。因此您可以使用拟合估计器对适当的测试折叠进行评分,生成混淆矩阵。或者您可以使用 cross_val_predict(...,cv=cv),但此时您会重复拟合,因此您可能应该跳过 cross_validate 并自己进行循环。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?