如何解决二元交叉熵计算中的pos_weight
当我们处理不平衡的训练数据(负样本较多,正样本较少)时,通常会使用pos_weight
参数。
pos_weight
的期望是当 positive sample
得到错误标签时,模型会比 negative sample
得到更高的损失。
当我使用 binary_cross_entropy_with_logits
函数时,我发现:
import torch
import torch.nn.functional as F
pos_weight = torch.FloatTensor([5])
preds_pos_wrong = torch.FloatTensor([0.5,1.5])
label_pos = torch.FloatTensor([1,0])
loss_pos_wrong = F.binary_cross_entropy_with_logits(preds_pos_wrong,label_pos,pos_weight=pos_weight)
'''
loss_pos_wrong = tensor(2.0359)
'''
preds_neg_wrong = torch.FloatTensor([1.5,0.5])
label_neg = torch.FloatTensor([0,1])
loss_neg_wrong = F.binary_cross_entropy_with_logits(preds_neg_wrong,label_neg,pos_weight=pos_weight)
'''
loss_neg_wrong = tensor(2.0359)
'''
错误的正样本和负样本得到的损失是一样的,那么pos_weight
在不平衡数据损失计算中是如何工作的?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。