How to solve "RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation"?
I have seen similar questions, but I do not understand why my code contains an in-place operation when I set the boolean self.noisy to True. More explanation follows after the code.
The code:
with:
l = self.backward()
using:
def backward(self):
    transitions = self.replay_memory.sample(self.batch_size)
    batch = Transition(*zip(*transitions))
    a_batch = torch.cat(batch.action).to(self.device)         # [BS x 1]
    cs_batch = torch.stack(batch.state).to(self.device)       # [BS x state_size]
    ns_batch = torch.stack(batch.next_state).to(self.device)  # [BS x state_size]
    r_batch = torch.tensor(np.expand_dims(np.array(batch.reward), 1), dtype=torch.float).to(self.device)  # [BS x 1]
    # STEP 2: PREDICTIONS
    pred = self.calc_pred(cs_batch, a_batch)
    # STEP 3: TARGETS
    with torch.no_grad():
        target = self.calc_target(ns_batch, r_batch)
    # STEP 4: LOSS
    loss = self.calc_loss(pred, target)  # [BS x 1]
    loss = loss.mean()
    # STEP 5: TRAIN
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return loss
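For context, here is a minimal sketch (toy code, not from my project) of the mechanism behind this error: autograd stamps every tensor it saves for the backward pass with a version counter, and any in-place update between the forward and backward pass invalidates it.

```python
import torch

w = torch.ones(5, requires_grad=True)  # stand-in for a network parameter
loss = (w ** 2).sum()                  # forward pass: autograd saves w (version 0)

with torch.no_grad():
    w.add_(1.0)                        # in-place update: w's version becomes 1

try:
    loss.backward()                    # autograd sees version 1, expected 0
    failed = False
except RuntimeError as e:
    failed = True
    print(e)                           # "...modified by an inplace operation..."
```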
And:
def calc_pred(self, cs_batch, a_batch):
    if self.noisy:
        self.network.sample_noise()
    pred = self.network(cs_batch).gather(1, a_batch)
    return pred
And the target calculation:
def calc_target(self, ns_batch, r_batch):
    if self.double:
        if self.noisy:
            self.network.sample_noise()
            self.target_network.sample_noise()
        next_Q = self.network(ns_batch)
        max_a = next_Q.max(1)[1].unsqueeze(1)
        Q_target = self.target_network(ns_batch).gather(1, max_a)
    target = r_batch + (self.gamma ** self.multi_step_n) * Q_target
    return target
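My suspicion, sketched with a toy noisy layer (ToyNoisy and its buffer names are my own, not the real network; this assumes sample_noise() refreshes the epsilon buffer in place, as typical NoisyNet implementations do): calc_pred() makes autograd save the current noise, and the sample_noise() call inside calc_target() then mutates that saved tensor before backward() runs.

```python
import torch
import torch.nn as nn

class ToyNoisy(nn.Module):
    def __init__(self, n=5):
        super().__init__()
        self.sigma = nn.Parameter(torch.full((n,), 0.5))
        self.register_buffer("epsilon", torch.zeros(n))

    def sample_noise(self):
        self.epsilon.normal_()          # in-place: bumps epsilon's version

    def forward(self, x):
        # sigma * epsilon makes autograd save epsilon for sigma's gradient
        return x + self.sigma * self.epsilon

net = ToyNoisy()
net.sample_noise()
pred = net(torch.ones(5)).sum()         # like calc_pred: epsilon gets saved here
with torch.no_grad():
    net.sample_noise()                  # like calc_target: mutates epsilon again

try:
    pred.backward()
    failed = False
except RuntimeError as e:
    failed = True
    print(e)                            # same "[torch.FloatTensor [5]]" error
```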
And the loss function:

def calc_loss(self, pred, target):
    loss = F.mse_loss(pred, target, reduction='none')
    return loss

This works as long as I do not use noisy networks. With noisy networks (self.noisy = True), however, it raises an error. I never change any variable in place anywhere in this code when using noisy networks, so I do not understand why I get an in-place error. I also use the noisy network in another piece of code without getting this error, so I do not think the bug is in that part. For reference, the full error is:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [5]] is at version 1003; expected version 1002 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
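If the problem is indeed the noise being resampled in place after it was used, two workarounds come to mind: compute the no-grad target before the prediction in backward(), or hand autograd a detached copy of the noise so later resampling cannot invalidate the saved tensor. A sketch of the second option on the same toy layer as above (again my own illustration, not the real network):

```python
import torch
import torch.nn as nn

class ToyNoisy(nn.Module):
    def __init__(self, n=5):
        super().__init__()
        self.sigma = nn.Parameter(torch.full((n,), 0.5))
        self.register_buffer("epsilon", torch.zeros(n))

    def sample_noise(self):
        self.epsilon.normal_()

    def forward(self, x):
        # clone() decouples the saved tensor from the mutable buffer
        return x + self.sigma * self.epsilon.clone()

net = ToyNoisy()
net.sample_noise()
pred = net(torch.ones(5)).sum()
with torch.no_grad():
    net.sample_noise()          # resampling is now harmless
pred.backward()                 # no RuntimeError
print(net.sigma.grad.shape)     # gradient for sigma is computed normally
```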