微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

什么时候使用 torch.no_grad() 在前向传播中是安全的?为什么它会严重伤害我的模型?

如何解决什么时候使用 torch.no_grad() 在前向传播中是安全的?为什么它会严重伤害我的模型?

我训练了一个 CNN 模型,它的前向传播是这样的:

*Part1*: learnable preprocess​
*Part2*: Mixup which does not need to calculate gradient
*Part3*: CNN backbone and classifier head

part1part3 都需要计算梯度并在反向传播时需要更新权重,但是 part2 只是一个简单的混合和不需要梯度,所以我尝试用 torch.no_grad() 包裹这个 Mixup 以节省计算资源并加快训练速度,这确实大大加快了我的训练速度,但模型的预测准确度下降了很多。

我想知道Mixup是否不需要计算梯度,为什么用torch.no_grad()包裹它如此伤害模型的能力,是不是由于Part1,或者诸如打破Part1Part2间的链条之类的东西?


编辑:

感谢@Ivan 的回复,听起来很合理,我也有同样的想法,但不知道如何证明。

在我的实验中,当我在 Part2 上应用 torch.no_grad() 时,GPU 内存消耗下降了很多,而且训练速度要快得多,所以我猜这个 Part2即使没有可学习的参数,仍然需要梯度。

那么我们是否可以得出结论,torch.no_grad() 不应应用在 2 个或更多可学习块之间,否则会降低 no_grad() 部分之前的块的学习能力?

解决方法

但是part2只是简单的混合,不需要渐变

确实如此!为了计算梯度流并成功反向传播到模型的 part1(根据您的说法,这是可以学习的),您还需要计算 part2 上的梯度。即使模型的 part2 上没有可学习的参数。

当您在 part2 上应用 torch.no_grad() 时,我假设只有 part3 的模型能够学习而 part1 保持原状。


编辑

那么我们是否可以得出结论,torch.no_grad() 不应该应用在 2 个或更多可学习块之间,否则会降低 no_grad() 部分之前块的学习能力?

推理很简单:要计算第 1 部分的梯度,您需要计算中间结果的梯度,而不管您不会使用这些梯度来更新第 2 部分的张量。所以确实,你是对的。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。