微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在 Pytorch 中定义“无关”类?

如何解决如何在 Pytorch 中定义“无关”类?

我有一个时间序列分类任务,我应该为每个时间戳 t 输出 3 个类别的分类

所有数据都按帧标记

数据集中有 3 个以上的类 [它们也是不平衡的]。

我的网络应该按顺序查看所有样本,因为它会将其用于历史信息。
因此,我不能只在预处理时消除所有不相关的类样本。

如果对标记与这 3 个类别不同的​​帧进行预测,我不关心结果。


如何在 Pytorch 中正确执行此操作?

解决方法

this discussion 之后,Google 无法搜索,有两个选项,都是 CrossEntropyLoss 的选项:

选项 1

如果只有一个类要忽略,则在实例化损失时使用 ignore_index=class_index

选项 2

如果有更多的类,使用weight=weights,带有weights.shape==n_classestorch.sum(weights[ignored_classes]) == 0

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。