微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

这是 sklearn 分类报告对多标签分类报告的正确使用吗?

如何解决这是 sklearn 分类报告对多标签分类报告的正确使用吗?

我正在用 tf-keras 训练一个神经网络。这是一个标签分类,其中每个样本属于多个类 [1,1,0..etc] .. 最终模型线(只是为了清楚起见)是:

model.add(tf.keras.layers.Dense(9,activation='sigmoid'))#final layer

model.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=[tf.keras.metrics.BinaryAccuracy(),tfa.metrics.F1score(num_classes=9,average='macro',threshold=0.5)])

我需要为这些生成准确率、召回率和 F1 分数(我已经在训练期间获得了 F1 分数)。为此,我使用 sklearns 分类报告,但我需要确认我在多标签设置中正确使用它。

from sklearn.metrics import classification_report

pred = model.predict(x_test)
pred_one_hot = np.around(pred)#this generates a one hot representation of predictions

print(classification_report(one_hot_ground_truth,pred_one_hot))

这很好用,我得到了每个班级的完整报告,包括与来自 tensorflow 插件的 F1score 指标相匹配的 F1 分数(对于宏 F1)。抱歉,这篇文章很冗长,但我不确定的是:

在多标签设置的情况下,预测需要进行单热编码是否正确?如果我传入正常的预测分数(sigmoid 概率),则会抛出错误...

谢谢。

解决方法

对二元、多类和多标签分类使用 classification_report 是正确的。

在多类分类的情况下,标签不是单热编码的。它们只需要是 indiceslabels

您可以看到下面的两个代码产生相同的输出:

索引示例

from sklearn.metrics import classification_report
import numpy as np

labels = np.array(['A','B','C'])


y_true = np.array([1,2,1,0])
y_pred = np.array([1,0])
print(classification_report(y_true,y_pred,target_names=labels))

标签示例

from sklearn.metrics import classification_report
import numpy as np

labels = np.array(['A','C'])

y_true = labels[np.array([1,0])]
y_pred = labels[np.array([1,0])]
print(classification_report(y_true,y_pred))

两者都返回

              precision    recall  f1-score   support

           A       1.00      0.50      0.67         2
           B       0.50      1.00      0.67         2
           C       1.00      0.50      0.67         2

    accuracy                           0.67         6
   macro avg       0.83      0.67      0.67         6
weighted avg       0.83      0.67      0.67         6

在多标签分类的上下文中,classification_report 可以像下面的例子一样使用:

from sklearn.metrics import classification_report
import numpy as np

labels =['A','C']

y_true = np.array([[1,1],[0,0],[1,1]])
y_pred = np.array([[1,1]])

print(classification_report(y_true,target_names=labels))

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。