微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何消除 Pytorch 中的就地操作错误?

如何解决如何消除 Pytorch 中的就地操作错误?

我从以下 Pytorch 代码中收到此错误

RuntimeError:梯度计算所需的变量之一已被原位操作修改:[torch.DoubleTensor [3]] 为版本 10;而是预期版本 9。

正如所见,代码没有就地操作。

import torch
device = torch.device('cpu')
class MesNet(torch.nn.Module):
        def __init__(self):
            super(MesNet,self).__init__()
            self.cov_lin = torch.nn.Sequential(torch.nn.Linear(6,5)).double()
        def forward(self,u):
            z_cov = self.cov_lin(u.transpose(0,2).squeeze(-1))
            return z_cov 
class UpdateModel(torch.nn.Module):
    def __init__(self):
        torch.nn.Module.__init__(self)
        self.P_dim = 18
        self.Id3 = torch.eye(3).double()
    def run_KF(self):
        N = 10
        u = torch.randn(N,6).double()
        v = torch.zeros(N,3).double()
        model = MesNet()
        measurements_covs_l = model(u.t().unsqueeze(0))
        # remember to remove this afterwards
        torch.autograd.set_detect_anomaly(True)
        for i in range(1,N):
            v[i] = self.update_pos(v[i].detach(),measurements_covs_l[i-1])

        criterion = torch.nn.MSELoss(reduction="sum")
        targ = torch.rand(10,3).double()
        loss = criterion(v,targ)
        loss = torch.mean(loss)
        loss.backward()
        return v,p


    def update_pos(self,v,measurement_cov):
        Omega = torch.eye(3).double() 
        H = torch.ones((5,self.P_dim)).double()
        R = torch.diag(measurement_cov)
        Kt = H.t().mm(torch.inverse(R))
        # it is indicating inplace error even with this: 
        # Kt = H.t().mm(R)
        dx = Kt.mv(torch.ones(5).double())
        dR = self.trans(dx[:9].clone())
        v_up = dR.mv(v)
        return v_up

    def trans(self,xi):
        phi = xi[:3].clone()
        angle = torch.norm(phi.clone())

        if angle.abs().lt(1e-10):

            skew_phi = torch.eye(3).double()
            J = self.Id3 + 0.5 * skew_phi
            Rot = self.Id3 + skew_phi
        else:
            axis = phi / angle
            skew_axis = torch.eye(3).double()
            s = torch.sin(angle)
            c = torch.cos(angle)

            Rot = c * self.Id3
        return Rot
net =  UpdateModel()
net.run_KF()

解决方法

我认为问题在于您覆盖了 v[i] 元素。

您可以从循环中构造一个辅助列表 v_,然后将其转换为张量:

v_ = [v[0]]
for i in range(1,N):
    v_.append(self.update_pos(v[i].detach(),measurements_covs_l[i-1]))
v = torch.stack(v_)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。