如何解决如何消除 Pytorch 中的就地操作错误?
RuntimeError:梯度计算所需的变量之一已被原位操作修改:[torch.DoubleTensor [3]] 为版本 10;而是预期版本 9。
正如所见,代码没有就地操作。
import torch
device = torch.device('cpu')
class MesNet(torch.nn.Module):
def __init__(self):
super(MesNet,self).__init__()
self.cov_lin = torch.nn.Sequential(torch.nn.Linear(6,5)).double()
def forward(self,u):
z_cov = self.cov_lin(u.transpose(0,2).squeeze(-1))
return z_cov
class UpdateModel(torch.nn.Module):
def __init__(self):
torch.nn.Module.__init__(self)
self.P_dim = 18
self.Id3 = torch.eye(3).double()
def run_KF(self):
N = 10
u = torch.randn(N,6).double()
v = torch.zeros(N,3).double()
model = MesNet()
measurements_covs_l = model(u.t().unsqueeze(0))
# remember to remove this afterwards
torch.autograd.set_detect_anomaly(True)
for i in range(1,N):
v[i] = self.update_pos(v[i].detach(),measurements_covs_l[i-1])
criterion = torch.nn.MSELoss(reduction="sum")
targ = torch.rand(10,3).double()
loss = criterion(v,targ)
loss = torch.mean(loss)
loss.backward()
return v,p
def update_pos(self,v,measurement_cov):
Omega = torch.eye(3).double()
H = torch.ones((5,self.P_dim)).double()
R = torch.diag(measurement_cov)
Kt = H.t().mm(torch.inverse(R))
# it is indicating inplace error even with this:
# Kt = H.t().mm(R)
dx = Kt.mv(torch.ones(5).double())
dR = self.trans(dx[:9].clone())
v_up = dR.mv(v)
return v_up
def trans(self,xi):
phi = xi[:3].clone()
angle = torch.norm(phi.clone())
if angle.abs().lt(1e-10):
skew_phi = torch.eye(3).double()
J = self.Id3 + 0.5 * skew_phi
Rot = self.Id3 + skew_phi
else:
axis = phi / angle
skew_axis = torch.eye(3).double()
s = torch.sin(angle)
c = torch.cos(angle)
Rot = c * self.Id3
return Rot
net = UpdateModel()
net.run_KF()
解决方法
我认为问题在于您覆盖了 v[i]
元素。
您可以从循环中构造一个辅助列表 v_
,然后将其转换为张量:
v_ = [v[0]]
for i in range(1,N):
v_.append(self.update_pos(v[i].detach(),measurements_covs_l[i-1]))
v = torch.stack(v_)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。