How do I fix a "Dimension out of range" error when extracting Q-values for next_states?
I defined a class QValues with a static method .get_next() to extract the q-values of the next states. However, when I call .get_next(), I get an error saying my dimension is out of range, and I'm quite confused. Why is this happening?
Here is some of the code:
# Run the CartPole-v0 environment
episode_durations = []
for episode in range(num_episodes):
    em.reset()
    state = em.get_state()
    for timestep in count():
        action = agent.select_action(state, policy_net)
        reward = em.take_action(action)
        next_state = em.get_state()
        memory.push(Experience(state, action, reward, next_state))
        if memory.can_provide_sample(batch_size):
            experiences = memory.sample(batch_size)
            states, actions, rewards, next_states = extract_tensors(experiences)
            current_q_values = QValues.get_current(policy_net, states, actions)
            next_q_values = QValues.get_next(target_net, next_states)
Here is the error I get:
<ipython-input-91-9001413ab21d> in <module>()
17
18 current_q_values = QValues.get_current(policy_net,actions) # get current return q-values for any given state action pairs predicted by the policy network
---> 19 next_q_values = QValues.get_next(target_net,next_states) # maximum q-values for the for the next states for the best corresponding actions
20 target_q_values = (next_q_values * gamma) + rewards # bellman equation basically
21
<ipython-input-78-18dcfcc8c5d6> in get_next(target_net,next_states)
8 @staticmethod
9 def get_next(target_net,next_states):
---> 10 final_state_locations = next_states.flatten(start_dim = 1).unsqueeze(1).max(dim=1).eq(0).type(torch.bool) # final state == True; non-final state == False
11 non_final_state_locations = (final_state_locations == False) # true for non-final state; false for final state
12 non_final_states = next_states[non_final_state_locations]
IndexError: Dimension out of range (expected to be in range of [-1,0],but got 1)
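For context, this IndexError means the tensor passed in has fewer dimensions than the `dim` argument asks for: a 1-D tensor only has dim 0 (or -1), so `flatten(start_dim=1)` or `max(dim=1)` on it is out of range. A minimal sketch reproducing the message, with hypothetical shapes standing in for the real `next_states` batch:

```python
import torch

# A 1-D tensor has only dim 0, so asking for dim 1 raises
# "Dimension out of range (expected to be in range of [-1, 0], but got 1)".
t = torch.zeros(4)  # hypothetical: what next_states looks like if the batch
                    # was flattened/concatenated into a single axis
try:
    t.flatten(start_dim=1)
except IndexError as e:
    print("IndexError:", e)

# With a proper 2-D batch of shape (batch_size, state_size),
# the same chain of calls works:
batch = torch.zeros(4, 8)  # hypothetical (batch_size, state_size)
max_values = batch.flatten(start_dim=1).max(dim=1)[0]  # note: max returns (values, indices)
print(max_values.shape)  # one max per sample in the batch
```

So the first thing to check is `next_states.shape` right before the failing line: if it is 1-D, the batching upstream (e.g. in `extract_tensors`) is collapsing the batch dimension.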