微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

混淆矩阵错误“列表”对象没有属性“argmax”

如何解决混淆矩阵错误“列表”对象没有属性“argmax”

我正在编写 DCNN 模型的分类报告,但遇到错误。我的代码

from sklearn.metrics import confusion_matrix

test = ImageDataGenerator()
test_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
test_data = test_generator.flow_from_directory(directory="/content/dataset/test",target_size=IMAGE_SHAPE,color_mode="rgb",class_mode='categorical',batch_size=1,shuffle = False )
test_data.reset()

predicted_class_indices=np.argmax(pred,axis=1)
cm = confusion_matrix(test_labels,predictions.argmax(axis=1))

错误

AttributeError: 'list' object has no attribute 'argmax'

解决方法

你的 predictions 显然是一个 Python 列表,列表没有 argmax 属性;您需要使用 Numpy 函数 argmax():

predictions = [[0.1,0.9],[0.8,0.2]] # dummy data
y_pred_binary = predictions.argmax(axis=1)
# AttributeError: 'list' object has no attribute 'argmax'

# Use Numpy:
import numpy as np
y_pred_binary = np.argmax(predictions,axis=1)
y_pred_binary
# array([1,0])

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。