DQN三大改进(一)-Double DQN
1、背景
我们简单回顾一下DQN的过程(这里是2015版的DQN):
DQN中有两个关键的技术,叫做经验回放和双网络结构。
DQN中的损失函数定义为:
其中,yi也被我们称为q-target值,而后面的Q(s,a)我们称为q-eval值,我们希望q-target和q-eval值越接近越好。
q-target如何计算呢?根据下面的公式:
上面的两个公式分别截取自两篇不同的文章,所以可能有些出入。我们之前说到过,我们有经验池存储的历史经验,经验池中每一条的结构是(s,a,r,s’),我们的q-target值根据该轮的奖励r以及将s’输入到target-net网络中得到的Q(s’,a’)的最大值决定。
我们进一步展开我们的q-target计算公式:
也就是说,我们根据状态s’选择动作a’的过程,以及估计Q(s’,a’)使用的是同一张Q值表,或者说使用的同一个网络参数,这可能导致选择过高的估计值,从而导致过于乐观的值估计。为了避免这种情况的出现,我们可以对选择和衡量进行解耦,从而就有了双Q学习,在Double DQN中,q-target的计算基于如下的公式:
我们根据一张Q表或者网络参数来选择我们的动作a’,再用另一张Q值表活着网络参数来衡量Q(s’,a’)的值。
2、代码实现
本文的代码还是根据莫烦大神的代码,它的github地址为:https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow
这里我们想要实现的效果类似于寻宝。
其中,红色的方块代表寻宝人,黑色的方块代表陷阱,黄色的方块代表宝藏,我们的目标就是让寻宝人找到最终的宝藏。
这里,我们的状态可以用横纵坐标表示,而动作有上下左右四个动作。使用tkinter来做这样一个动画效果。宝藏的奖励是1,陷阱的奖励是-1,而其他时候的奖励都为0。
接下来,我们重点看一下我们Double-DQN相关的代码。
定义输入
1 | # ------------------------input--------------------------- |
定义双网络结构
这里我们的双网络结构都简单的采用简单的全链接神经网络,包含一个隐藏层。这里我们得到的输出是一个向量,表示该状态才取每个动作可以获得的Q值:
1 | def build_layers(s,c_name,n_l1,w_initializer,b_initializer): |
接下来,我们定义两个网络:
1 | # ------------------ build evaluate_net ------------------ |
定义损失和优化器
接下来,我们定义我们的损失,和DQN一样,我们使用的是平方损失:
1 | with tf.variable_scope('loss'): |
定义我们的经验池
我们使用一个函数定义我们的经验池,经验池每一行的长度为 状态feature * 2 + 2。
1 | def store_transition(self,s,a,r,s_): |
选择action
我们仍然使用的是e-greedy的选择动作策略,即以e的概率选择随机动作,以1-e的概率通过贪心算法选择能得到最多奖励的动作a。
1 | def choose_action(self,observation): |
选择数据batch
我们从经验池中选择我们训练要使用的数据。
1 | if self.memory_counter > self.memory_size: |
更新target-net
这里,每个一定的步数,我们就更新target-net中的参数:
1 | t_params = tf.get_collection('target_net_params') |
更新网络参数
根据Double DQN的做法,我们需要用两个网络的来计算我们的q-target值,同时通过最小化损失来更新网络参数。这里的做法是,根据eval-net的值来选择动作,然后根据target-net的值来计算Q值。
1 | q_next,q_eval4next = self.sess.run([self.q_next, self.q_eval], |
q_next是根据经验池中下一时刻状态输入到target-net计算得到的q值,而q_eval4next是根据经验池中下一时刻状态s’输入到eval-net计算得到的q值,这个q值主要用来选择动作。
下面的动作用来得到我们batch中的实际动作和奖励
1 | batch_index = np.arange(self.batch_size, dtype=np.int32) |
接下来,我们就要来选择动作并计算该动作的q值了,如果是double dqn的话,我们是根据刚刚计算的q_eval4next来选择动作,然后根据q_next来得到q值的。而原始的dqn直接通过最大的q_next来得到q值:
1 | if self.double_q: |
那么我们的q-target值就可以计算得到了:
1 | q_target = q_eval.copy() |
有了q-target值,我们就可以结合eval-net计算的q-eval值来更新网络参数了:
1 | _, self.cost = self.sess.run([self.train_op, self.loss], |