如何解决Pytorch 为 nn.Module 函数添加自定义向后传递
我正在重新实现 Invertible Residual Networks 架构。
class iresnetBlock(nn.Module):
def __init__(self,input_size,hidden_size):
self.bottleneck = nn.Sequential(
LinearContraction(input_size,hidden_size),LinearContraction(hidden_size,input_size),nn.ReLU(),)
def forward(self,x):
return x + self.bottleneck(x)
def inverse(self,y):
x = y.clone()
while not converged:
# fixed point iteration
x = y - self.bottleneck(x)
return x
我想向 inverse
函数添加自定义向后传递。由于是不动点迭代,因此可以利用隐函数定理来避免循环展开,而是通过求解线性系统来计算梯度。例如,这是在 Deep Equilibrium Models 架构中完成的。
def inverse(self,y):
with torch.no_grad():
x = y.clone()
while not converged:
# fixed point iteration
x = y - self.bottleneck(x)
return x
def custom_backward_inverse(self,grad_output):
pass
我如何为这个函数注册我的自定义向后传递?我希望这样,当我稍后定义一些损失(例如 r = loss(y,model.inverse(other_model(model(x))))
)时,r.backwards()
会正确使用我的自定义梯度进行反向调用。
理想情况下,解决方案应与 torchscript
兼容。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。