微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

PyTorch 使用外部库保留梯度

如何解决PyTorch 使用外部库保留梯度

我有一个返回预测 (n,n) 的 GAN。为了指导这个网络,我有一个损失函数,它是二元交叉熵损失 (torch.tensor) 和 Wasserstein 距离的总和。但是,为了计算 Wasserstein 距离,我使用了 bceloss 库中的 scipy.stats.wasserstein_distance 函数。您可能知道,此函数需要两个 SciPy 数组作为输入。所以,为了使用这个函数,我将我的预测张量和地面实况张量转换为 NumPy 数组,如下所示

NumPy

然后,将pred_np = pred_tensor.detach().cpu().clone().numpy().ravel() target_np = target_tensor.detach().cpu().clone().numpy().ravel() W_loss = wasserstein_distance(pred_np,target_np) W_loss相加得到总损失。我现在展示这部分是因为它有点不必要并且与我的问题无关。

我担心的是我正在分离梯度,所以我想在优化和更新模型参数时它不会考虑 bceloss。我是个新手,所以我希望我的问题很清楚,并感谢您提前回答。

解决方法

将一个不是张量的对象添加到您的损失中本质上是添加一个常量。常数的导数为零,所以这个增加的项对您的网络的权重没有任何影响。

tl;博士: 您需要在 pytorch 中重写损失计算(或者只是找到一个现有的实现,互联网上有很多)。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。