微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

强化学习系列七--DDPG

DDPG(deep deterministic policy gradient),深度确定性策略梯度算法。

PPO(Proximal Policy Optimization),近端策略优化算法。

DDPG和PPO都是AC框架。

本文主要介绍DDPG。

DDPG

从名字我们也可以看出DDPG就是DPG和DQN的结合。

这篇文章很详细的介绍了三者关系:https://zhuanlan.zhihu.com/p/337976595

我们先先回顾一些算法

DPG--deterministic policy gradient

PG之前已经介绍过,就是通过参数化概率分布来表示策略,选择一个动作,目的是让累计价值最高。其中动作a是根据概率的随机选取,也就是stochastic Policy Gradient。

DPG就是不用概率分布表示policy,而是用一个确定的函数表示。也就是给定了s,选取的a是确定的,也就是deterministic Policy Gradient。

相对于stochastic Policy,deterministic Policy 计算上更高效,缺点也很明显就是缺少了探索性。为了解决探索问题,DPG采用off-policy方法,也就是采样的policy和待优化目标policy是不同策略;采样的policy是随机的,而待优化的策略是确定的。

DQN--deep Q-network

DQN也有Actor和Critic网络,其中Actor网络输出一个动作A,动作和状态都要输出到Critic网络,目标是获取的Q值最大化。

在DQN之前,用神经网络拟合Q函数,但是训练会非常不稳定,主要原因是:

  1. Q函数(待优化的对象)和r+\gamma maxQ(s',a') 优化的target都是来自于同一个Q函数,也就是data和label是来自一个模型,会导致模型很难学习;
  2. Q函数的一小点改动会很大影响策略,从而改变观测数据的分布;
  3. 观测序列数据,后面的状态,动作和reward强烈依赖前面的数据。

QDN提出改进来解决以上问题:

  1. 数据只存储四元组(s_t,a_t,r_t,s_{t+1}),让四元组和四元组之间没有关联性
  2. 把待优化的Q函数和target的Q函数分为不同策略,待优化的Q函数更新一段时间参数后,才将参数赋值给目标Q函数。训练数据的target都是target的Q函数生成的,而待优化的Q函数参数更新对训练数据不会有改变。所以data(待优化的Q生成)和label(target的Q函数生成)之间没有相关性。

DDPG--deep deterministic policy gradient

DDPG是结合了DPG和DQN。

先看下DQN的流程:

DQN流程图

在选择Q值最大的A_{t+1}时,用到了max,所以DQN不能解决连续控制问题。而DPG没有采用随机policy,而是采用的确定policy,不用寻找最大化操作,所以DDPG就将DQN中神经网络拟合Q函数的两个优化点用到DPG中,将DPG中的Q函数一个神经网络预测,但是其中使用了off-policy。

所以DDPG和DPG一样,更新网络和目标网络也是不同的策略,所以属于off_policy。

借鉴https://blog.csdn.net/kenneth_yu/article/details/78478356中流程图,可以比较清晰的了解DDPG的算法。

总结步骤:

代码参考:https://github.com/louisnino/RLcode/blob/master/tutorial_DDPG.py

  1. 迭代探索,每次探索time-steps,每个step,actor网络选取动作,环境执行动作得到新的状态和reward,进行存储。一次探索后可以存储序列数据,对序列数据进行采样。注意这里choose_action和存储数据的来源是online策略网络

核心代码

for i in range(MAX_EPISODES):
    t1 = time.time()
    s = env.reset()
    ep_reward = 0
    for j in range(MAX_EP_STEPS):
        # Add exploration noise
        a = ddpg.choose_action(s)
        # 增加探索性
        a = np.clip(np.random.normal(a, VAR), -2, 2)
        # 与环境进行互动
        s_, r, done, info = env.step(a)
        # 保存s,a,r,s_
        ddpg.store_transition(s, a, r / 10, s_)
        if ddpg.pointer > MEMORY_CAPACITY:
            ddpg.learn()

  1. 保存了一定量数据后,就可以进行learn了。

采样数据N含有(s,a,r,s_{-}), 将s_{-}给到target策略网络actor_target得到s_{-}(a_{-},s_{-})给到target-Q网络critic_target得到q_{-}(a, s)给到待优化online-Q网络critic得到q

所以td-error= r + \gamma * q_{-}。以此更新critic网络。

s给到actor得到a,(s,a)给到critic得到q。Actor的目标就是让q最大化,以此更新Actor网络。

核心代码

def learn(self):
    indices = np.random.choice(MEMORY_CAPACITY, size=BATCH_SIZE)  
    bt = self.memory[indices, :]  # 根据indices,选取数据bt,相当于随机
    bs = bt[:, :self.s_dim]  # 从bt获得数据s
    ba = bt[:, self.s_dim:self.s_dim + self.a_dim]  # 从bt获得数据a
    br = bt[:, -self.s_dim - 1:-self.s_dim]  # 从bt获得数据r
    bs_ = bt[:, -self.s_dim:]  # 从bt获得数据s'

    # Critic:td_error = br + gamma * q_ - q
    with tf.GradientTape() as tape:
        # q_由critic_target预测,critic的a_由actor_target预测
        a_ = self.actor_target(bs_)
        q_ = self.critic_target([bs_, a_])
        y = br + GAMMA * q_
        # q由critic预测
        q = self.critic([bs, ba])
        td_error = tf.losses.mean_squared_error(y, q)
    c_grads = tape.gradient(td_error, self.critic.trainable_weights)
    self.critic_opt.apply_gradients(zip(c_grads, self.critic.trainable_weights))

    # Actor:最大化Q值
    with tf.GradientTape() as tape:
        # 待优化actor,输入是s
        a = self.actor(bs)
        # q值由critic得到
        q = self.critic([bs, a])
        # 最大化q,等价最小化-q
        a_loss = -tf.reduce_mean(q)
    a_grads = tape.gradient(a_loss, self.actor.trainable_weights)
    self.actor_opt.apply_gradients(zip(a_grads, self.actor.trainable_weights))

3. 一次探索学习后,将online网络的critic和actor参数更新到target网络中。

核心代码

def ema_update(self):
    """
    滑动平均更新
    """
    # 其实和之前的硬更新类似,不过在更新赋值之前,用一个ema.average。
    paras = self.actor.trainable_weights + self.critic.trainable_weights  # 获取要更新的参数包括actor和critic的
    self.ema.apply(paras)  # 主要是建立影子参数
    for i, j in zip(self.actor_target.trainable_weights + self.critic_target.trainable_weights, paras):
        i.assign(self.ema.average(j))  # 用滑动平均赋值

参考:

https://zhuanlan.zhihu.com/p/111257402

https://blog.csdn.net/kenneth_yu/article/details/78478356

http://speech.ee.ntu.edu.tw/~tlkagk/courses_MLDS18.html

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐