如何解决如何在忽略类中使用 pytorch 闪电精度?
我有一些训练管道,它使用 CrossEntropyLoss
和一个忽略类。
模型输出形状为 log_probs
的 (150,3)
- 表示 3 个可能的类,每批 150 个。
label_batch
的形状为 150
,并且 torch.max(label_batch)
== tensor(3,device='cuda:0')
,这意味着有一个额外的类标记为 3
,即忽略类.
损失处理得很好:
self._criterion = nn.CrossEntropyLoss(
reduction='mean',ignore_index=3
)
但准确度指标认为类 3
是有效的,并给出了非常错误的结果:
self.train_acc = pl.metrics.Accuracy()
由于 self.train_acc.update(log_probs,label_batch)
标签而导致的 3
错误结果应该被忽略。
如何在忽略类中正确使用 pl.metrics.Accuracy()
?
解决方法
复制来自 github 论坛讨论主题的回复 https://github.com/PyTorchLightning/pytorch-lightning/discussions/6890
准确度指标目前不支持它,但我们有一个公开 PR 来实现该准确功能 PyTorchLightning/metrics#155
目前你可以做的是计算混淆矩阵,然后根据它忽略一些类(记住,真正的正/正确分类是在混淆矩阵的对角线上找到的):
ignore_index = 3
metric = ConfusionMatrix(num_classes=3)
confmat = metric(preds,target)
confmat = confmat[:2,:2] # remove last column and row corresponding to class 3
acc = confmat.trace() / confmat.sum()
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。