如何解决关于焦点损失函数实现的问题
在介绍焦点损失的 paper 中,他们声明损失函数的公式如下:
哪里
我在另一个作者的 Github 页面上找到了它的实现,他在他们的 paper 中使用了它。我在我拥有的分割问题数据集上尝试了该功能,它似乎工作得很好,但对我来说实现起来似乎很奇怪。
以下是实现:
def binary_focal_loss(pred,truth,gamma=2.,alpha=.25):
eps = 1e-8
pred = nn.softmax(1)(pred)
truth = F.one_hot(truth,num_classes = pred.shape[1]).permute(0,3,1,2).contiguous()
pt_1 = torch.where(truth == 1,pred,torch.ones_like(pred))
pt_0 = torch.where(truth == 0,torch.zeros_like(pred))
pt_1 = torch.clamp(pt_1,eps,1. - eps)
pt_0 = torch.clamp(pt_0,1. - eps)
out1 = -torch.mean(alpha * torch.pow(1. - pt_1,gamma) * torch.log(pt_1))
out0 = -torch.mean((1 - alpha) * torch.pow(pt_0,gamma) * torch.log(1. - pt_0))
return out1 + out0
我不明白的部分是out0的计算。由于论文将 y =/= 1 时的 Pt 值定义为 1-p,因此我希望 out0 改为这样写:
out0 = -torch.mean((1 - alpha) * torch.pow((1 - (1 - pt_0)),gamma) * torch.log(1. - pt_0))
代替
out0 = -torch.mean((1 - alpha) * torch.pow(pt_0,gamma) * torch.log(1. - pt_0))
谁能帮我解释一下?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。