如何解决我的“焦点损失”函数中是否有任何逻辑错误?
“焦点损失”的诞生是为了解决困难的样本。我用二进制交叉包裹了我的焦点损失函数 pytorch 中的熵:
class FocalLoss(nn.Module):
def __init__(self,gamma=2):
super(FocalLoss,self).__init__()
self.gamma = gamma
def forward(self,pred,label):
# label is not the one-hot
true = torch.zeros_like(pred,dtype=torch.float)
for i,j in enumerate(label):
true[i,j] = 1.0
loss = nn.BCEWithLogitsLoss(pred,true)
pred_prob = torch.sigmoid(pred) # sigmoid
p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
modulating_factor = (1.0 - p_t) ** self.gamma
loss *= modulating_factor
return loss.mean()
并且我在 cifar10 数据集中训练我的 resnet18 以使用它和 nn.CrossEntropyLoss 进行分类任务 准确率对比。
但结果与我的预期相差甚远:CrossEntropyLoss 的准确率比我的焦点损失高约 5%!任何人都可以在我上面的焦点损失代码中找到逻辑错误吗?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。