广义优势估计 GAE

GAE 提出了一种优势函数的一般形式,综合考虑了多步真实采样的累积回报和价值函数的估计,权衡偏差和方差,是目前 PPO 等策略梯度类方法的主流选择。本文将首先梳理策略梯度中 Ψ t \Psi_t Ψt 的多种不同选择,理解它们各自的偏差/方差特性,然后介绍 GAE 提出的优势函数的一般形式。

前置:

强化学习策略梯度

强化学习Q-Learning/DQN

基于累积回报的 Ψ t \Psi_t Ψt

在之前,我们介绍了策略梯度的一般形式:
∇ θ J ( π θ ) = E τ ∼ π θ [ ∑ t = 0 T Ψ t ∇ θ log ⁡ π θ ( a t ∣ s t ) ] \nabla_\theta J(\pi_\theta)=\mathbb{E}_{\tau\sim\pi_\theta}\left[\sum_{t=0}^T\Psi_t\nabla_\theta\log\pi_\theta(a_t|s_t)\right] \notag \\ θJ(πθ)=Eτπθ[t=0TΨtθlogπθ(atst)]
其中 Ψ t \Psi_t Ψt 用来表示状态 s t s_t st 下,动作 a t a_t at 有多 “好”。策略梯度的目标就是要增大 “好” 动作出现的概率,减小 “不好” 动作出现的概率。

在之前策略梯度最简形式的推导中,我们一开始是简单地采用整个 rollout 的累积回报作为 Ψ t \Psi_t Ψt
Ψ t = R ( τ ) = ∑ t = 0 T r t (1) \Psi_t=R(\tau)=\sum_{t=0}^T r_t \tag{1} \\ Ψt=R(τ)=t=0Trt(1)
后来我们对取 Ψ t = R ( τ ) \Psi_t=R(\tau) Ψt=R(τ) 的几个问题进行了分析,得到了两个优化版本的 Ψ t \Psi_t Ψt

首先,动作 a t a_t at 有多好,应该与其之前获得的奖励无关,而只与其被采取之后获得的奖励有关。从而得到了 reward-to-go 形式的累积回报:
Ψ t = R ′ ( τ ) = ∑ t ′ = t T r t ′ (2) \Psi_t=R'(\tau)=\sum_{t'=t}^Tr_{t'} \tag{2} \\ Ψt=R(τ)=t=tTrt(2)
第二,全正的 Ψ t \Psi_t Ψt 可能会导致训练前期恰好没有被采样到的高回报动作一直没有办法被探索到,因此我们进一步减去 baseline b ( s t ) b(s_t) b(st),即取
Ψ t = R ′ ′ ( τ ) = ∑ t ′ = t T r t ′ − b ( s t ) (3) \Psi_t=R''(\tau)=\sum_{t'=t}^T r_{t'}-b(s_t) \tag{3} \\ Ψt=R′′(τ)=t=tTrtb(st)(3)
这样就能使得 Ψ t \Psi_t Ψt 有正有负,鼓励好的动作,同时打压不够好的动作。EGLP 定理显示,加入只与状态 s t s_t st 有关的 b ( s t ) b(s_t) b(st),不会影响整体的期望,因此加入 baseline 后整体还是一个无偏估计。

基于价值函数的 Ψ t \Psi_t Ψt

以上是之前在策略梯度中介绍过的内容,这三种 Ψ t = R ( τ ) \Psi_t=R(\tau) Ψt=R(τ) 的形式都是在 policy 与 environment 的实际互动采样中收集的真实数据。从 bias 的角度来看,它当然是无偏的,但是由于需要进行多步采样,policy 本身有随机性,environment 的状态转移也有随机性,方差在多步采样中逐渐积累,导致最终的方差非常大。无偏但是方差大,如果我们每次都能采样足够多次,那当然没有问题。但实际中我们的采样次数是很有限的,因此方差是一个不得不考虑的问题,大方差的训练会很不稳定。

那么,我们怎么才能降低 Ψ t \Psi_t Ψt 的方差呢?在这里,我们想避免多步采样带来的累积方差,那就要直接一步估计出 R R R 的期望值,而不是逐步采样到最后,这样虽然可能会引入一些 bias(模型的预测可能会不准),但是能够降低方差。这里的 biase/variance 权衡,与我们之前介绍过的 value-based 训练目标中,MC 方法与 TD 方法的差异是一个道理。

“累积回报的期望值”,这听起来非常熟悉,它正是我们之前介绍的 value-based 方法中价值函数的定义。

回顾一下:

状态价值函数,$V^\pi(s)=\mathbb{E}_{\tau\sim\pi}[R(\tau)|s_0=s] \notag \$,是指在状态 s s s,接下来按照特定的策略 π \pi π 来进行动作,最终累积回报的期望。

动作-状态价值函数,$Q^\pi(s,a)=\mathbb{E}_{\tau\sim\pi}[R(\tau)|s_0=s,a_0=a] \notag \$,是指在状态 s s s,先强制采取动作 a a a(这个动作不一定符合当前策略 π \pi π,是人为采取的),接下来按照特定的策略 π \pi π 来进行动作,最终累积回报的期望。

状态价值函数 V V V 是动作-价值函数 Q Q Q 在策略 π \pi π 下的期望 V π ( s ) = E a ∼ π [ Q π ( s , a ) ] V^\pi(s)=\mathbb{E}_{a\sim\pi}[Q^\pi(s,a)] Vπ(s)=Eaπ[Qπ(s,a)]

Ψ t \Psi_t Ψt 的目标就是要表达出 “在状态 s t s_t st 下,动作 a t a_t at 有多好”,即需要的是 “状态 s t s_t st 下采取动作 a t a_t at 的期望回报”,而这正是动作-状态价值函数 Q π ( s t , a t ) Q^\pi(s_t,a_t) Qπ(st,at) 的定义。因此,将 Q Q Q 函数作为 Ψ t \Psi_t Ψt 是一个合理的选择:
Ψ t = Q π ( s t , a t ) (4) \Psi_t=Q^\pi(s_t,a_t) \tag{4} \\ Ψt=Qπ(st,at)(4)

Actor-Criric 方法:用价值函数来表示 Ψ t \Psi_t Ψt,同 value-based 方法中一样,我们需要训练一个价值网络来拟合价值函数,这样就将 policy-based 方法(策略梯度)和 value-based 方法(价值网络)结合起来了,这就是所谓的 actor-critic 方法。其中 actor 就是 policy network,即策略网络;critic 就是 value network,是价值网络,用来估计策略对应的价值函数。

更进一步地,如果再将 baseline 考虑在内, Ψ t \Psi_t Ψt 要表达的就是:“在状态 s t s_t st 下,动作 a t a_t at 相比于该状态下其他动作(的期望)要好多少”,而 “该状态下其他动作的期望”,正是状态价值函数 V π ( s t , a t ) V^\pi(s_t,a_t) Vπ(st,at) 的定义。因此,可以在 Q Q Q 的基础上,再将 V V V 作为 baseline,减去 baseline,即取:
Ψ t = A π ( s t , a t ) = Q π ( s t , a t ) − V π ( s t ) (5) \Psi_t=A^\pi(s_t,a_t)=Q^\pi(s_t,a_t)-V^\pi(s_t) \tag{5} \\ Ψt=Aπ(st,at)=Qπ(st,at)Vπ(st)(5)
这里的 A π ( s t , a t ) A^\pi(s_t,a_t) Aπ(st,at) 就是所谓的优势函数。

但是在优势函数 A A A 中,我们需要估计估计 Q Q Q V V V 两个价值函数,这样就需要训练两个神经网络。这样的训练起来误差会比较大。我们再进行一些变换。首先将 V V V 写成 TD 残差的形式 V π ( s t ) = r t + V π ( s t + 1 ) V^\pi(s_t)=r_t+V^\pi(s_{t+1}) Vπ(st)=rt+Vπ(st+1)

我们知道, Q Q Q V V V 的期望:
Q π ( s t , a t ) = E [ V π ( s t ) ] = E [ r t + V π ( s t + 1 ) ] Q^\pi(s_t,a_t)=\mathbb{E}[V^\pi(s_t)]=\mathbb{E}[r_t+V^\pi(s_{t+1})] \notag \\ Qπ(st,at)=E[Vπ(st)]=E[rt+Vπ(st+1)]
把期望拿掉,把 Q Q Q 写成 TD 残差的形式并代入:
A π ( s t , a t ) = Q π ( s t , a t ) − V π ( s t ) = r t + V π ( s t + 1 ) − V π ( s t ) A^\pi(s_t,a_t)=Q^\pi(s_t,a_t)-V^\pi(s_t)=r_t+V^\pi(s_{t+1})-V^\pi(s_t) \notag \\ Aπ(st,at)=Qπ(st,at)Vπ(st)=rt+Vπ(st+1)Vπ(st)
这样我们就得到了优势函数 A A A 的 TD 残差形式,它也可以作为 Ψ t \Psi_t Ψt
Ψ t = r t + V π ( s t + 1 ) − V π ( s t ) (6) \Psi_t=r_t+V^\pi(s_{t+1})-V^\pi(s_t) \tag{6} \\ Ψt=rt+Vπ(st+1)Vπ(st)(6)
我们拿掉了期望,会引入一些 bias,但是这样优势函数的估计就只需要训练一个状态价值网络 V V V 就行了。

广义优势估计 GAE

至此,我们已经介绍了基于累积回报的 Ψ t \Psi_t Ψt (式 1-3),它们的特点是无偏但是高方差;基于价值函数函数的 Ψ t \Psi_t Ψt(式 4-6),它们的特点是有偏但是低方差。这正是 GAE 论文 Introduction 中总结的 6 种 Ψ t \Psi_t Ψt 的选择,即下图中的 1-6。

在这里插入图片描述

那么 GAE 是做了个什么事儿呢?我们上面看到

  • 基于累积回报的 Ψ t \Psi_t Ψt 是一直采样到最后,获取真实的累积奖励,是 MC 方法,无偏但是多次采样方差很大;
  • 基于价值函数的 Ψ t \Psi_t Ψt 是用一个价值网络估计动作的价值(或优势),是 TD 方法,TD 残差形式下只有一步采样,方差小,但是毕竟网络估计可能不准,因此是有偏的估计

现在要解决的问题就是 偏差/方差 权衡的问题。

先补充说明一件事,我们之前都是用的无折扣的累积回报,将未来所有回报都直接加到 a t a_t at 的累积回报中。但是实际情况下,应当是更近获得的奖励受到当前动作 a t a_t at 的影响较大,更远获得的奖励则受到 a t a_t at 影响较小,因此我们一般会对未来获得的奖励累乘一个小于 1 的折扣因子 γ \gamma γ,因此上面的几种 Ψ t \Psi_t Ψt 形式都要改写。举例来说,reward-to-go 累积回报(式 2)应该写为:
R ′ ( τ ) = ∑ l = 0 ∞ γ l r t (7) R'(\tau)=\sum_{l=0}^\infty \gamma^lr_t \tag{7} \\ R(τ)=l=0γlrt(7)
而优势函数的 TD 残差形式(式 6) 应该写为:
r t + γ V π ( s t + 1 ) − V π ( s t ) = 记作 δ t (8) r_t+\gamma V^\pi(s_{t+1})-V^\pi(s_t)\overset{记作}{=}\delta_t \tag{8} \\ rt+γVπ(st+1)Vπ(st)=记作δt(8)
我们现在将其记作 δ t \delta_t δt

GAE 怎么权衡偏差和方差呢?

首先将 k k k 个这样的 δ \delta δ 项加起来,记为 A ^ t ( k ) \hat{A}^{(k)}_t A^t(k)
A ^ t ( 1 ) : = δ t = − V ( s t ) + r t + γ V ( s t + 1 ) A ^ t ( 2 ) : = δ t + γ δ t + 1 = − V ( s t ) + r t + γ r t + 1 + γ 2 V ( s t + 2 ) A ^ t ( 3 ) : = δ t + γ δ t + 1 + γ 2 δ t + 2 = − V ( s t ) + r t + γ r t + 1 + γ 2 r t + 2 + γ 3 V ( s t + 3 ) . . . . . . A ^ t ( k ) : = ∑ l = 0 k − 1 γ l δ t + l = − V ( s t ) + r t + γ r t + 1 + γ k − 1 r t + k − 1 + γ k V ( s t + k ) \begin{aligned} \hat{A}^{(1)}_t&:=\delta_t&&=-V(s_t)+r_t+\gamma V(s_{t+1}) \\ \hat{A}^{(2)}_t&:=\delta_t+\gamma \delta_{t+1}&&=-V(s_t)+r_t+\gamma r_{t+1}+\gamma^2V(s_{t+2}) \\ \hat{A}^{(3)}_t&:=\delta_t+\gamma \delta_{t+1}+\gamma^2\delta_{t+2}&&=-V(s_t)+r_t+\gamma r_{t+1}+\gamma^2r_{t+2}+\gamma^3V(s_{t+3}) \\ ... &&... \\ \hat{A}^{(k)}_t&:=\sum_{l=0}^{k-1}\gamma^l\delta_{t+l}&&=-V(s_t)+r_t+\gamma r_{t+1}+\gamma^{k-1}r_{t+k-1}+\gamma^kV(s_{t+k}) \\ \end{aligned} \notag \\ A^t(1)A^t(2)A^t(3)...A^t(k):=δt:=δt+γδt+1:=δt+γδt+1+γ2δt+2:=l=0k1γlδt+l...=V(st)+rt+γV(st+1)=V(st)+rt+γrt+1+γ2V(st+2)=V(st)+rt+γrt+1+γ2rt+2+γ3V(st+3)=V(st)+rt+γrt+1+γk1rt+k1+γkV(st+k)
可以看到, A ^ t ( k ) \hat{A}^{(k)}_t A^t(k) 相当于是先采样 k k k 步,得到 k k k 步的真实回报,然后估计剩余的价值函数 γ V ( s t + k ) \gamma V(s_{t+k}) γV(st+k),再减去 baseline V ( s t ) V(s_t) V(st)。从 A ^ t ( 1 ) = δ t \hat{A}^{(1)}_t=\delta_t A^t(1)=δt 推广( δ t \delta_t δt 是优势函数的 TD 残差形式的估计),我们可以将所有 A ^ t ( k ) \hat{A}^{(k)}_t A^t(k) 都看作是对优势函数的一种估计,

  • k k k 越大,真实采样的步数越多,对价值函数估计的依赖越小,偏差越小,方差越大;
  • k k k 越小,真实采样的步数越少,对价值函数估计的依赖越大,偏差越大,方差越小;

特别的,当 k → ∞ k\rightarrow\infty k 时,有
A ^ t ( ∞ ) = ∑ l = 0 ∞ γ l δ t + l = − V ( s t ) + ∑ l = 0 ∞ γ l r t + l \hat{A}^{(\infty)}_t=\sum_{l=0}^\infty\gamma^l\delta_{t+l}=-V(s_t)+\sum_{l=0}^\infty\gamma^lr_{t+l} \notag \\ A^t()=l=0γlδt+l=V(st)+l=0γlrt+l
这样就是一直采样到结束,相当于又成了基于累积回报的 Ψ t \Psi_t Ψt(式 7),只是 baseline 采用了价值函数的估计 V ( s t ) V(s_t) V(st)

GAE,就是对这一系列不同的对优势函数的估计 A ^ t ( k ) \hat{A}^{(k)}_t A^t(k) 的指数加权平均
A ^ t GAE : = ( 1 − λ ) ( A ^ t ( 1 ) + λ A ^ t ( 2 ) + λ 2 A ^ t ( 3 ) + . . . ) = ( 1 − λ ) ( δ t + λ ( δ t + γ δ t + 1 ) + λ 2 ( δ t + 1 + γ δ t + 1 + γ 2 δ t + 2 ) + . . . ) = ( 1 − λ ) ( δ t ( 1 + λ + λ 2 + . . . ) + γ δ t + 1 ( λ + λ 2 + λ 3 + . . . ) + γ 2 δ t + 2 ( λ 2 + λ 3 + λ 4 + . . . ) + . . . ) = ( 1 − λ ) ( δ t 1 1 − λ + γ δ t + 1 λ 1 − λ + γ 2 δ t + 2 λ 2 1 − λ + . . . ) = ∑ l = 0 ∞ ( γ λ ) l δ t + l \begin{aligned} \hat{A}_t^\text{GAE}&:=(1-\lambda)(\hat{A}^{(1)}_t+\lambda\hat{A}^{(2)}_t+\lambda^2\hat{A}^{(3)}_t+...) \\ &=(1-\lambda)(\delta_t+\lambda(\delta_t+\gamma\delta_{t+1})+\lambda^2(\delta_{t+1}+\gamma\delta_{t+1}+\gamma^2\delta_{t+2})+...) \\ &=(1-\lambda)(\delta_t(1+\lambda+\lambda^2+...)+\gamma\delta_{t+1}(\lambda+\lambda^2+\lambda^3+...)+\gamma^2\delta_{t+2}(\lambda^2+\lambda^3+\lambda^4+...)+...) \\ &=(1-\lambda)(\delta_t\frac{1}{1-\lambda}+\gamma\delta_{t+1}\frac{\lambda}{1-\lambda}+\gamma^2\delta_{t+2}\frac{\lambda^2}{1-\lambda}+...) \\ &=\sum_{l=0}^\infty(\gamma\lambda)^l\delta_{t+l} \end{aligned} \notag \\ A^tGAE:=(1λ)(A^t(1)+λA^t(2)+λ2A^t(3)+...)=(1λ)(δt+λ(δt+γδt+1)+λ2(δt+1+γδt+1+γ2δt+2)+...)=(1λ)(δt(1+λ+λ2+...)+γδt+1(λ+λ2+λ3+...)+γ2δt+2(λ2+λ3+λ4+...)+...)=(1λ)(δt1λ1+γδt+11λλ+γ2δt+21λλ2+...)=l=0(γλ)lδt+l
(这里乘上个 ( 1 − λ ) (1-\lambda) (1λ) 是为了化简形式?)可以看到,我们引入一个超参数 λ \lambda λ,对所有的 A ^ t ( k ) \hat{A}_t^{(k)} A^t(k) 进行加权表示,得到了一个一般形式的对优势函数的估计,也就是 “广义优势估计 Generalized Advantage Estimation”,一般直接简记作 A ^ t \hat{A}_t A^t。这样,我们就可以通过调节超参数 λ \lambda λ,来对偏差和方差进行 trade-off。

  • λ \lambda λ 越小, k k k 值高的项 A ^ t ( k ) \hat{A}_t^{(k)} A^t(k) 衰减得越厉害,价值网络越主导,偏差越大,方差越小;
  • λ \lambda λ 越大, k k k 值高的项 A ^ t ( k ) \hat{A}_t^{(k)} A^t(k) 衰减得越少,真实采样越主导,偏差越小,方差越大;

特别地,当 λ = 0 \lambda=0 λ=0 λ = 1 \lambda=1 λ=1 时,有:
λ = 0 , A ^ t : = δ t = r t + γ V π ( s t + 1 ) − V π ( s t ) λ = 1 , A ^ t : = ∑ l = 0 ∞ γ l δ t + l = ∑ l = 0 ∞ γ l r t + l − V ( s t ) \begin{aligned} \lambda=0,\quad\hat{A}_t&:=\delta_t=r_t+\gamma V^\pi(s_{t+1})-V^\pi(s_t) \\ \lambda=1,\quad \hat{A}_t&:=\sum_{l=0}^\infty\gamma^l\delta_{t+l}=\sum_{l=0}^\infty\gamma^lr_{t+l}-V(s_t) \end{aligned} \notag \\ λ=0,A^tλ=1,A^t:=δt=rt+γVπ(st+1)Vπ(st):=l=0γlδt+l=l=0γlrt+lV(st)
分别对应了只一次真实采样,偏差最大、方差最小的情况,和全都真实采样,方差最大、偏差最小的情况。

顺便提一句,GAE 这种构造形式和 TD ( λ ) \text{TD}(\lambda) TD(λ) 很像,但是 TD ( λ ) \text{TD}(\lambda) TD(λ) 是对价值函数的估计,而 GAE 是对优势函数的估计。

最终,我们在策略梯度中,就可以将 GAE A ^ t \hat{A}_t A^t 作为 Ψ t \Psi_t Ψt

∇ θ J ( π θ ) = E τ ∼ π θ [ ∑ t = 0 T A ^ t ∇ θ log ⁡ π θ ( a t ∣ s t ) ] A ^ t = ∑ l = 0 ∞ ( γ λ ) l δ t + l , δ t = r t + γ V π ( s t + 1 ) − V π ( s t ) \nabla_\theta J(\pi_\theta)=\mathbb{E}_{\tau\sim\pi_\theta}\left[\sum_{t=0}^T\hat{A}_t\nabla_\theta\log\pi_\theta(a_t|s_t)\right] \\ \hat{A}_t=\sum_{l=0}^\infty(\gamma\lambda)^l\delta_{t+l},\quad \delta_t=r_t+\gamma V^\pi(s_{t+1})-V^\pi(s_t) \notag \\ θJ(πθ)=Eτπθ[t=0TA^tθlogπθ(atst)]A^t=l=0(γλ)lδt+l,δt=rt+γVπ(st+1)Vπ(st)

总结

本文中,我们首先回顾了之前介绍过的基于累积回报的 Ψ t \Psi_t Ψt 选择,这类选择虽然是无偏的,但是由于需要多次采样,会造成累积方差很大。为了减小方差,我们考虑了基于价值函数的 Ψ t \Psi_t Ψt 选择,但是在价值网络估计的价值函数不准时,这类选择的偏差又比较大。GAE 提出了一种一般形式,将多项对优势函数的估计值进行加权表示,通过调节超参数 λ \lambda λ,实现了偏差和方差之间进行权衡,目前已经成为策略梯度类算法中优势函数形式的主流选择。

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐