微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

即使在四舍五入预测后也无法获得分类报告

如何解决即使在四舍五入预测后也无法获得分类报告

我正在尝试使用 keras 对我的数据集进行分类,但出现 ValueError: Classification metrics can't handle a mix of multiclass and multilabel-indicator targets 错误y_pred 中的值如下

array([[2.95522604e-02,9.70325887e-01,3.20542094e-05,...,1.74383260e-07,1.98587145e-07,9.88743452e-08],[3.25102806e-01,6.68996394e-01,1.65001326e-03,5.84201662e-05,5.91963508e-05,4.68929684e-05],[8.87618303e-01,1.12024814e-01,1.22764613e-04,1.44616331e-06,1.33618846e-06,1.68983024e-06],[3.09438616e-01,6.83520675e-01,1.94711238e-03,7.57295784e-05,7.51852640e-05,5.94857411e-05],[6.73729360e-01,3.21534157e-01,1.41171378e-03,4.93246625e-05,4.61974196e-05,4.73670734e-05],[1.33120596e-01,8.64127636e-01,7.41749362e-04,1.87505502e-05,1.95825924e-05,1.34223355e-05]],dtype=float32)

我正在按照 this 问题中提到的方式将它们四舍五入,因为 y_test 值是

array([1,1,1]) 

y_pred 四舍五入 y_pred = y_pred.round().astype(int) 后,我有

array([[0,0],[1,[0,0]])

即使在此之后,当我尝试使用 print(metrics.classification_report(y_test,y_pred)) 获取分类报告时,我也会遇到与上述相同的错误。有人可以帮助我了解我在这里做错了什么吗?谢谢

解决方法

scikit-learn docs 声明 y_pred 输入必须是 1d 类数组。你需要 argmax 你的 logits。

import numpy as np
import tensorflow as tf
from sklearn.metrics import classification_report


y_pred = tf.math.abs(tf.random.normal([32,2])).numpy()
y_test = tf.random.uniform([32,1],minval=0,maxval=2,dtype=tf.int32).numpy()

# this will explode
print(classification_report(y_test,y_pred))

# ValueError: Classification metrics can't handle a mix of binary and 
# continuous-multioutput targets

# get predicted indices
y_pred = np.argmax(y_pred,1)

# try again
print(classification_report(y_test,y_pred))

#                precision    recall  f1-score   support
#
#             0       0.41      0.50      0.45        14
#             1       0.53      0.44      0.48        18
# 
#      accuracy                           0.47        32
#     macro avg       0.47      0.47      0.47        32
#  weighted avg       0.48      0.47      0.47        32

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。