如何解决Pytorch 闪电指标:ValueError:preds 和 target 必须具有相同数量的维度,或者为 preds 增加一个维度
在谷歌上搜索这个会让你无处可去,所以我决定通过将此作为可搜索问题发布来帮助未来的我和其他人。
def __init__():
...
self.val_acc = pl.metrics.Accuracy()
def validation_step(self,batch,batch_index):
...
self.val_acc.update(log_probs,label_batch)
给予
ValueError: preds and target must have same number of dimensions,or one additional dimension for preds
对于log_probs.shape == (16,4)
和对于label_batch.shape == (16,4)
有什么问题?
解决方法
pl.metrics.Accuracy()
需要一批 dtype=torch.long
标签,而不是单热编码标签。
因此,它应该被喂
self.val_acc.update(log_probs,torch.argmax(label_batch.squeeze(),dim=1))
这和torch.nn.CrossEntropyLoss
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。