微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

无法加载保存的策略TF 代理

如何解决无法加载保存的策略TF 代理

我使用策略保护程序保存了训练有素的策略,如下所示:

  tf_policy_saver = policy_saver.PolicySaver(agent.policy)
  tf_policy_saver.save(policy_dir)

我想使用保存的策略继续训练。所以我尝试用保存的策略初始化训练,这导致了一些错误

agent = dqn_agent.DqnAgent(
tf_env.time_step_spec(),tf_env.action_spec(),q_network=q_net,optimizer=optimizer,td_errors_loss_fn=common.element_wise_squared_loss,train_step_counter=train_step_counter)

agent.initialize()

agent.policy=tf.compat.v2.saved_model.load(policy_dir)

错误

  File "C:/Users/Rohit/PycharmProjects/pythonProject/waypoint.py",line 172,in <module>
agent.policy=tf.compat.v2.saved_model.load('waypoints\\Two_rewards')


File "C:\Users\Rohit\anaconda3\envs\btp36\lib\site-packages\tensorflow\python\training\tracking\tracking.py",line 92,in __setattr__
    super(AutoTrackable,self).__setattr__(name,value)
AttributeError: can't set attribute

我只是想节省每次从第一次重新训练开始的时间。如何加载保存的策略并继续训练?

提前致谢

解决方法

是的,如前所述,您应该使用检查指针来执行此操作,请查看下面的示例代码。

agent = ... # Agent Definition
policy = agent.policy
# Policy --> Y
policy_checkpointer = common.Checkpointer(ckpt_dir='path/to/dir',policy=policy)

... # Train the agent

# Policy --> X
policy_checkpointer.save(global_step=epoch_counter.numpy())

当您以后想要重新加载策略时,您只需运行相同的初始化代码。

agent = ... # Agent Definition
policy = agent.policy
# Policy --> Y1,possibly Y1==Y depending on agent class you are using,if it's DQN
#               then they are different because of random initialization of network weights
policy_checkpointer = common.Checkpointer(ckpt_dir='path/to/dir',policy=policy)
# Policy --> X

创建后,policy_checkpointer 将自动识别是否存在任何预先存在的检查点。如果有,它会在创建时自动更新它正在跟踪的变量的值。

需要做的几个注意事项:

  1. 使用检查指针可以节省的不仅仅是策略,我确实建议这样做。 TF-Agent 的 Checkpointer 对象非常灵活,例如:
train_checkpointer = common.Checkpointer(ckpt_dir=first/dir,agent=tf_agent,# tf_agent.TFAgent
                                         train_step=train_step,# tf.Variable
                                         epoch_counter=epoch_counter,# tf.Variable
                                         metrics=metric_utils.MetricsGroup(
                                                 train_metrics,'train_metrics'))

policy_checkpointer = common.Checkpointer(ckpt_dir=second/dir,policy=agent.policy)

rb_checkpointer = common.Checkpointer(ckpt_dir=third/dir,max_to_keep=1,replay_buffer=replay_buffer  # TFUniformReplayBuffer
                                      )
  1. 请注意,在 DqnAgent 的情况下,agent.policyagent.collect_policy 本质上是 QNetwork 的包装器。其含义如下面的代码所示(查看关于策略变量状态的注释)
agent = DqnAgent(...)
policy = agent.policy      # Random initial policy ---> X

dataset = replay_buffer.as_dataset(...)
for data in dataset:
   experience,_ = data
   loss_agent_info = agent.train(experience=experience)

# policy variable stores a trained Policy object ---> Y

发生这种情况是因为 TF 中的张量在您的运行时共享。因此,当您使用 QNetwork 更新代理的 agent.train 权重时,这些相同的权重也会隐式更新到您的 policy 变量的 QNetwork 中。实际上,并不是 policy 的张量得到更新,而是它们与您的 agent 中的张量相同。

,

您应该为此目的查看 Checkpointer。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。