CS294踩坑记录(二)

主要是CS294 Lecture 4 部分以及homework 2的内容。好像Lecture 5还会介绍一些PG的内容,到时候再补充。

PG好难啊,推导式子多,代码trick也多,自己写的代码跑的效果烂的一笔,对着别人的代码调真难受….

放学习资料:资料。这个大佬写的东西挺详细的,而且感觉补充了一些课上没讲的内容。至于代码我还是用的别人提供的pytorch版本。

尽可能减少讲义部分,不写公式推导直接放ppt了。

上面这张ppt是如何把一个对采样的轨迹的数学期望形式转换成对于log_prob乘以reward的期望形式。而且这个转log或者转回去在后面推式子也经常用到。

上面是进一步展开轨迹那个式子,就可以得到最原始的PG式子。数学期望我们可以通过采样n次轨迹来实现。既然是DRL,那么就需要用神经网络来拟合那个π。所以整个PG可以被包含在REINFORCE框架下:

实际上上述梯度的式子和常见的最大似然估计的梯度式子相似,其实我们可以认为PG在干的事情是通过采样得到的一系列动作action作为label,然后针对Policy Network进行supervised learning,增加好的奖励的概率,降低不好的奖励的概率。而这其实也是写代码的时候loss函数。

但是PG毕竟是通过采样n次得到的东西,按照学习资料里说的:不同的action可能导致相同的Expected Reward,这种不确定性导致原始的PG式子具有高方差的特性。

降方差的方法:

(1)Causality:就是如果我们把梯度式子的括号打开我们会发现,(t-1)时刻的奖励会和t乃至之后时刻的梯度相乘,这其实从因果来说是过去的奖励不会影响到未来的梯度的。所以就可以t时刻的梯度只和之后的奖励相乘,从而减少项数降低方差(我是觉得这是能减少方差的原因了,但是其实感觉这种方法也不是很说的通,因为是不是可以直接减去一个数值也可以让值变小啊。。。)但是写代码的时候就发现好像挺自然地,使用reward to go的方式计算q_n,会在每个时刻的q_n都不一样,但是如果原始的那样的话则是每个时刻相同的,这点来看q_n的确不合理。而且实验效果也显示这种方式更加稳定方差小,且学习速率更快,效果更好。学习资料给了一个证明。。

(2)baseline:引入baseline感觉可以增大高于baseline奖励的部分。

其实在作业代码部分中,用的是value function。其实在作业里还让证明其无偏性,也类似于上述的证明过程。这部分实际上也需要一个神经网络来拟合value function。

PG是on-policy的,就是之前的采样不能直接用在之后的梯度计算中,每次都得重新采。当然也有一些变种使得其off-policy。

上面我们说的采样,在连续动作空间的时候,我们实际上得到的是一个均值和一个std(如果假设一个多维正态分布的话)。实际上mean就是Policy Network的输出(或者加上一个tanh之类的激活函数,因为有些env的决策空间在一定范围内需要clip掉,但我在一个CartPole任务上做了反而效果更差了,可能是梯度消失等问题。),std是一个可学习的参数。但是我们会发现这个采样过程是不可导的,所以需要reparameterization这样的trick,来把mean和std参数引入。所以关于均值的梯度就可以回传到Policy Network了。

有了value function 和 Q function 之差,叫做advantage  function。需要说的是这个东西计算完后,代码提示我们要归一化,我发现加上Advantage Normalization这个trick的确效果好了。

考虑如何训练拟合value function,我们实际上希望输入一个状态,知道这个情况下的q_n。所以可以两个网络一起训练。这里还有个trick就是把value function的分布会归一化到q_n的分布上。这里摘一段学习资料:为什么要对价值函数神经网络的输出归一化到 Q 的分布上?然后再在训练此网络的时候,将作为 target 的 Q 归一化到标准正态分布上?

答案就是我们希望网络的输出,从头到尾,从开始训练到训练的后期,都满足标准正态分布的形式。这样,我们就可以避免了作为 Target 的 Q 值,在训练过程中的分布发生变化所带来的问题。也就是说,因为 Policy Network 渐渐的训练而引起的 Q 值的变化,不会影响价值网络的训练。这是一种广为使用的,在有多个网络互相交互的时候,避免互相影响的办法。

其实还是不太懂为啥这么搞。。。。

感觉还是应该往下继续看看才会对这个PG有更深的理解吧。为什么师兄们都说这个简单啊,看了好几天都不太会啊。。。。ORZ。。。

 

发表评论

电子邮件地址不会被公开。 必填项已用*标注