微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

关于焦点损失函数实现的问题

如何解决关于焦点损失函数实现的问题

在介绍焦点损失的 paper 中,他们声明损失函数的公式如下:

enter image description here

哪里

enter image description here

我在另一个作者的 Github 页面上找到了它的实现,他在他们的 paper 中使用了它。我在我拥有的分割问题数据集上尝试了该功能,它似乎工作得很好,但对我来说实现起来似乎很奇怪。

以下是实现:

def binary_focal_loss(pred,truth,gamma=2.,alpha=.25):
    eps = 1e-8
    pred = nn.softmax(1)(pred)
    truth = F.one_hot(truth,num_classes = pred.shape[1]).permute(0,3,1,2).contiguous()

    pt_1 = torch.where(truth == 1,pred,torch.ones_like(pred))
    pt_0 = torch.where(truth == 0,torch.zeros_like(pred))

    pt_1 = torch.clamp(pt_1,eps,1. - eps)
    pt_0 = torch.clamp(pt_0,1. - eps)

    out1 = -torch.mean(alpha * torch.pow(1. - pt_1,gamma) * torch.log(pt_1)) 
    out0 = -torch.mean((1 - alpha) * torch.pow(pt_0,gamma) * torch.log(1. - pt_0))

    return out1 + out0

我不明白的部分是out0的计算。由于论文将 y =/= 1 时的 Pt 值定义为 1-p,因此我希望 out0 改为这样写:

out0 = -torch.mean((1 - alpha) * torch.pow((1 - (1 - pt_0)),gamma) * torch.log(1. - pt_0))

代替

out0 = -torch.mean((1 - alpha) * torch.pow(pt_0,gamma) * torch.log(1. - pt_0))

谁能帮我解释一下?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。