微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Pytorch 为 nn.Module 函数添加自定义向后传递

如何解决Pytorch 为 nn.Module 函数添加自定义向后传递

我正在重新实现 Invertible Residual Networks 架构。

class iresnetBlock(nn.Module):
    def __init__(self,input_size,hidden_size):
        self.bottleneck = nn.Sequential(
            LinearContraction(input_size,hidden_size),LinearContraction(hidden_size,input_size),nn.ReLU(),)
        
    def forward(self,x):
        return x + self.bottleneck(x)
    
    def inverse(self,y):
        x = y.clone()

        while not converged:
            # fixed point iteration
            x = y - self.bottleneck(x)
   
        return x

我想向 inverse 函数添加自定义向后传递。由于是不动点迭代,因此可以利用隐函数定理来避免循环展开,而是通过求解线性系统来计算梯度。例如,这是在 Deep Equilibrium Models 架构中完成的。

    def inverse(self,y):
        with torch.no_grad():
            x = y.clone()
            while not converged:
                # fixed point iteration
                x = y - self.bottleneck(x)
   
            return x

    def custom_backward_inverse(self,grad_output):
        pass

我如何为这个函数注册我的自定义向后传递?我希望这样,当我稍后定义一些损失(例如 r = loss(y,model.inverse(other_model(model(x)))))时,r.backwards() 会正确使用我的自定义梯度进行反向调用

理想情况下,解决方案应与 torchscript 兼容。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。