
How to fix the error: one of the variables needed for gradient computation has been modified by an inplace operation

I have two agents, a and b, that use the same network structure. I try to use a replay buffer to update the networks' parameters. If I update a first, then updating b raises the error: one of the variables needed for gradient computation has been modified by an inplace operation. So I tried to find the in-place operation, but I couldn't.
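
For reference, this is the kind of pattern PyTorch complains about: an in-place write into a tensor whose value autograd still needs for the backward pass. A minimal standalone example (not the agents' code, just an illustration):

    import torch

    w = torch.ones(3, requires_grad=True)
    y = torch.sigmoid(w)   # sigmoid's backward needs its saved output y
    y[0] = 0.0             # in-place index assignment bumps y's version counter
    y.sum().backward()     # RuntimeError: one of the variables needed for gradient
                           # computation has been modified by an inplace operation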

    while total_timesteps < episodes:
        total_timesteps = total_timesteps + 1
        episodes_reward = []
        state = env.reset()
        state = state.to(device)              # Tensor.to is not in-place; reassign
        for j in range(max_steps):
            action = torch.zeros(env.num_agents, env.dim_world)
            action[0] = policy_agent1.select_action(state)
            action[1] = policy_agent2.select_action(state)
            action_agent0 = action[0]
            action_agent1 = action[1]
            next_state, reward, done = env.step(action)
            next_state = next_state.to(device)
            action = action.to(device)
            reward = reward.to(device)
            reward_agent = reward[0]
            reward_agent1 = reward[1]
            # store transitions for each agent (both need next_state for the target)
            replay_buffer1.add((state, next_state, action_agent0, reward_agent, done))
            replay_buffer2.add((state, next_state, action_agent1, reward_agent1, done))
            # print(next_state, action)
            state = next_state
            if done:
                break                          # end the episode once done
        policy_agent1.train(replay_buffer1.sample(batch_size), gamma)
        policy_agent2.train(replay_buffer2.sample(batch_size), gamma)
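
One thing worth noting in the loop above: for a plain tensor, `.to(device)` is not in-place; it returns a new tensor that has to be assigned back, whereas `nn.Module.to` does move the module's parameters in place. A small standalone sketch of the difference:

    import torch
    import torch.nn as nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    x = torch.zeros(3)
    x.to(device)        # no lasting effect: the moved copy is discarded
    x = x.to(device)    # correct: keep the tensor returned by .to

    net = nn.Linear(3, 1)
    net.to(device)      # fine: Module.to moves parameters and buffers in place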

And the train function:

    def train(self, replay_buffer, gamma):
        state, next_state, action, reward, done = replay_buffer
        q = torch.zeros(len(reward)).to(device)
        q_ = torch.zeros(len(reward)).to(device)
        q_target = torch.zeros(len(reward)).to(device)
        done = torch.Tensor(done).to(device)

        for j, r in enumerate(reward):
            q1_target = torch.zeros(len(reward)).to(device)
            # target Q: critic evaluated at next_state and the actor's action there
            q_[j] = self.critic_network(torch.transpose(next_state[j].to(device), 0, 1),
                                        self.actor_network(torch.transpose(next_state[j].to(device), 0, 1)).view(1, 1))
            q_target[j] = r.to(device) + (done[j] * gamma * q_[j])
            q1_target = q1_target + q_target
            # q_target[j] = r.to(device) + (done[j] * gamma * q_[j]).detach().clone()
            # current Q: critic evaluated at state and the stored action
            q1 = torch.zeros(len(reward)).to(device)
            q1[j] = self.critic_network(torch.transpose(state[j].to(device), 0, 1),
                                        action[j].view(1, 1).to(device))
            q = q + q1
        loss_critic = F.mse_loss(q, q1_target)
        self.critic_optimizer.zero_grad()
        loss_critic.backward(retain_graph=True)
        self.critic_optimizer.step()

        # actor update: maximize the critic's value of the actor's own action
        b = torch.zeros(len(reward)).to(device)
        for j, _ in enumerate(reward):
            b1 = torch.zeros(len(reward)).to(device)
            b1[j] = self.critic_network(torch.transpose(state[j].to(device), 0, 1),
                                        self.actor_network(torch.transpose(state[j].to(device), 0, 1)))
            b = b + b1
            # b[j] = self.critic_network(torch.transpose(state[j].to(device), 0, 1))
        loss_actor = -torch.mean(b)
        self.actor_optimizer.zero_grad()
        loss_actor.backward(retain_graph=True)
        self.actor_optimizer.step()

I also changed a few places that I thought were problematic, so the code looks a bit verbose.
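
For what it's worth, two things in `train` that commonly trigger this message are the index assignments into preallocated tensors (`q1[j] = self.critic_network(...)` is an in-place write on a tensor that is part of the graph) and reusing graphs across updates via `backward(retain_graph=True)` while `optimizer.step()` modifies the parameters in place. Below is a minimal sketch of an alternative `train`, under the question's own setup (same `self.critic_network`, `self.actor_network`, optimizers, `device`, and `import torch.nn.functional as F`); the shapes and the `torch.transpose(..., 0, 1)` calls are assumptions, not a verified fix:

    def train(self, replay_buffer, gamma):
        state, next_state, action, reward, done = replay_buffer
        q_list, target_list = [], []
        for s, s_, a, r, d in zip(state, next_state, action, reward, done):
            s = torch.transpose(s.to(device), 0, 1)      # assumed layout
            s_ = torch.transpose(s_.to(device), 0, 1)
            with torch.no_grad():                         # no graph for the target branch
                a_ = self.actor_network(s_).view(1, 1)
                target = r.to(device) + d * gamma * self.critic_network(s_, a_)
            q_list.append(self.critic_network(s, a.view(1, 1).to(device)))
            target_list.append(target)
        # build the vectors with stack instead of writing into a preallocated tensor,
        # so no in-place index assignment touches the graph
        loss_critic = F.mse_loss(torch.stack(q_list).squeeze(),
                                 torch.stack(target_list).squeeze())
        self.critic_optimizer.zero_grad()
        loss_critic.backward()                            # fresh graph per loss, no retain_graph
        self.critic_optimizer.step()

        # actor loss is built after the critic step, from fresh forward passes
        b_list = [self.critic_network(torch.transpose(s.to(device), 0, 1),
                                      self.actor_network(torch.transpose(s.to(device), 0, 1)))
                  for s in state]
        loss_actor = -torch.mean(torch.stack(b_list))
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()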
