Expected Sarsa 与 n-步 Sarsa

在上一篇文章中,我们学习了 Sarsa 算法。虽然它效果不错,但它有两个小软肋:

  1. 太依赖运气:下一步动作 At+1A_{t+1}At+1 是随机选出来的(特别是 ϵ\epsilonϵ-greedy 策略下),这导致训练过程波动较大。
  2. 视野太窄:只看一步(只看 Rt+1R_{t+1}Rt+1 和下一刻的 QQQ),就像近视眼开车,容易陷入局部最优。

今天,我们通过两个变体算法来解决这两个问题。


一、 Expected Sarsa:与其赌运气,不如算概率

1.1 核心思想

在标准 Sarsa 中,我们的更新公式依赖于下一个实际发生的动作 at+1a_{t+1}at+1
Target=rt+1+γQ(st+1,at+1) \text{Target} = r_{t+1} + \gamma Q(s_{t+1}, \mathbf{a_{t+1}}) Target=rt+1+γQ(st+1,at+1)
这就好像你预估明天的花费,完全取决于你明天随机抽到的那张优惠券。如果抽到了大奖,你觉得明天很省钱;没抽到,你觉得很费钱。这种随机性导致了方差(Variance)很大。

Expected Sarsa 的想法是:既然我已经知道策略 π\piπ 在下一状态选各个动作的概率,为什么不直接算一个**平均值(期望)**呢?

1.2 算法公式

Q(st,at)←Q(st,at)+α[rt+1+γ∑a′π(a′∣st+1)Q(st+1,a′)⏟Expected TD Target−Q(st,at)] Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[ \underbrace{r_{t+1} + \gamma \sum_{a'} \pi(a'|s_{t+1}) Q(s_{t+1}, a')}_{\text{Expected TD Target}} - Q(s_t, a_t) \right] Q(st,at)Q(st,at)+α Expected TD Target rt+1+γaπ(ast+1)Q(st+1,a)Q(st,at)

这里的核心变化在于 TD Target 的第二部分:

  • Sarsa: 直接用 Q(st+1,at+1)Q(s_{t+1}, a_{t+1})Q(st+1,at+1)
  • Expected Sarsa: 计算期望值 E[Q]=∑aπ(a∣st+1)×Q(st+1,a)\mathbb{E}[Q] = \sum_{a} \pi(a|s_{t+1}) \times Q(s_{t+1}, a)E[Q]=aπ(ast+1)×Q(st+1,a)

与 Sarsa 相比:

  • TD 目标发生了变化:在 Sarsa 中为 rt+1+γqt(st+1,at+1)r_{t+1} + \gamma q_t(s_{t+1}, a_{t+1})rt+1+γqt(st+1,at+1),而在 Expected Sarsa 中为 rt+1+γE[qt(st+1,A)]r_{t+1} + \gamma \mathbb{E}[q_t(s_{t+1}, A)]rt+1+γE[qt(st+1,A)]
  • 需要更多的计算量。但它是有益的,因为它减少了估计方差,因为它将 Sarsa 中的随机变量从 {st,at,rt+1,st+1,at+1}\{s_t, a_t, r_{t+1}, s_{t+1}, a_{t+1}\}{st,at,rt+1,st+1,at+1} 减少到了 {st,at,rt+1,st+1}\{s_t, a_t, r_{t+1}, s_{t+1}\}{st,at,rt+1,st+1}

理解
Sarsa 说:“我下一步正好走到了左边,所以用左边的价值更新。”
Expected Sarsa 说:“我下一步有 90% 概率向左,10% 概率向右,所以我用 (0.9×左 + 0.1×右) 的加权平均值来更新。”

1.3 Expected Sarsa 的优缺点

优点 缺点
方差小:消除了 At+1A_{t+1}At+1 选择带来的随机性,训练曲线更平滑。 计算量大:每次更新都要把下一状态的所有动作遍历一遍求和。
收敛稳:通常比 Sarsa 收敛得稍快一些。 代码略繁:需要显式地计算策略分布。

1.4 代码实现对比

# Sarsa 的 Target 计算
next_action = agent.take_action(next_state) # 真的选了一个动作
td_target = reward + gamma * Q_table[next_state, next_action]

# Expected Sarsa 的 Target 计算 (假设是 epsilon-greedy)
# 1. 计算最大 Q 值动作的概率 (1 - epsilon + epsilon/n_actions)
# 2. 计算其他动作的概率 (epsilon/n_actions)
# 3. 加权求和
expected_value = 0
for action in range(n_actions):
    prob = get_probability(next_state, action) # 获取策略概率
    expected_value += prob * Q_table[next_state, action]
    
td_target = reward + gamma * expected_value

二、 n-步 Sarsa (n-step Sarsa):统一 Sarsa 与 MC

2.1 视野的差异

我们在之前的文章中对比过 TD 和 MC:

  • TD(0) / Sarsa:走一步,看一步。利用 Rt+1+γQnextR_{t+1} + \gamma Q_{next}Rt+1+γQnext 更新。
    • 特点:偏差大(盲目信任 Q),方差小。
  • Monte Carlo (MC):走到终点,算总账。利用 GtG_tGt 更新。
    • 特点:无偏差(全是真金白银的奖励),方差大(路途遥远,变数多)。

如果不只走一步,也不走到终点,而是走 n 步呢? 这就是 n-step Sarsa

2.2 回报 (Return) 的谱系

让我们看看 nnn 的取值如何改变我们的“视野”:

  • 1-step (Sarsa): Gt(1)=Rt+1+γQ(St+1,At+1)G_t^{(1)} = R_{t+1} + \gamma Q(S_{t+1}, A_{t+1})Gt(1)=Rt+1+γQ(St+1,At+1)
  • 2-step: Gt(2)=Rt+1+γRt+2+γ2Q(St+2,At+2)G_t^{(2)} = R_{t+1} + \gamma R_{t+2} + \gamma^2 Q(S_{t+2}, A_{t+2})Gt(2)=Rt+1+γRt+2+γ2Q(St+2,At+2)
  • n-step: Gt(n)=Rt+1+γRt+2+⋯+γn−1Rt+n+γnQ(St+n,At+n)G_t^{(n)} = R_{t+1} + \gamma R_{t+2} + \dots + \gamma^{n-1} R_{t+n} + \gamma^n Q(S_{t+n}, A_{t+n})Gt(n)=Rt+1+γRt+2++γn1Rt+n+γnQ(St+n,At+n)
  • ∞\infty-step (MC): Gt(∞)=Rt+1+γRt+2+⋯+γT−1RTG_t^{(\infty)} = R_{t+1} + \gamma R_{t+2} + \dots + \gamma^{T-1} R_TGt()=Rt+1+γRt+2++γT1RT

需要注意的是,Gt=Gt(1)=Gt(2)=Gt(n)=Gt(∞)G_t = G_t^{(1)} = G_t^{(2)} = G_t^{(n)} = G_t^{(\infty)}Gt=Gt(1)=Gt(2)=Gt(n)=Gt(),其中上标仅表示 GtG_tGt 的不同分解结构。

2.3 n-step Sarsa 的更新机制

这就带来了一个问题:延迟更新
为了计算 Gt(n)G_t^{(n)}Gt(n),我们必须等到时间步 t+nt+nt+n 真的发生后,拿到了 Rt+nR_{t+n}Rt+nSt+nS_{t+n}St+n,才能回头去更新时间步 ttt 的 Q 值。

更新公式为:
Q(st,at)←Q(st,at)+α[Gt(n)−Q(st,at)] Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha [ G_t^{(n)} - Q(s_t, a_t) ] Q(st,at)Q(st,at)+α[Gt(n)Q(st,at)]

  • Sarsa 旨在求解:
    qπ(s,a)=E[Gt(1)∣s,a]=E[Rt+1+γqπ(St+1,At+1)∣s,a]. q_\pi(s, a) = \mathbb{E}[G_t^{(1)} | s, a] = \mathbb{E}[R_{t+1} + \gamma q_\pi(S_{t+1}, A_{t+1}) | s, a]. qπ(s,a)=E[Gt(1)s,a]=E[Rt+1+γqπ(St+1,At+1)s,a].

  • MC 学习旨在求解:
    qπ(s,a)=E[Gt(∞)∣s,a]=E[Rt+1+γRt+2+γ2Rt+3+…∣s,a]. q_\pi(s, a) = \mathbb{E}[G_t^{(\infty)} | s, a] = \mathbb{E}[R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + \dots | s, a]. qπ(s,a)=E[Gt()s,a]=E[Rt+1+γRt+2+γ2Rt+3+s,a].

  • 一种称为 n步 Sarsa 的中间算法旨在求解:
    qπ(s,a)=E[Gt(n)∣s,a]=E[Rt+1+γRt+2+⋯+γnqπ(St+n,At+n)∣s,a]. q_\pi(s, a) = \mathbb{E}[G_t^{(n)} | s, a] = \mathbb{E}[R_{t+1} + \gamma R_{t+2} + \dots + \gamma^n q_\pi(S_{t+n}, A_{t+n}) | s, a]. qπ(s,a)=E[Gt(n)s,a]=E[Rt+1+γRt+2++γnqπ(St+n,At+n)s,a].

  • n步 Sarsa 的算法是:

    qt+1(st,at)=qt(st,at)−αt(st,at)[qt(st,at)−[rt+1+γrt+2+⋯+γnqt(st+n,at+n)]] \begin{aligned} q_{t+1}(s_t, a_t) = q_t(s_t, a_t) & \\ & - \alpha_t(s_t, a_t) \left[ q_t(s_t, a_t) - \left[ r_{t+1} + \gamma r_{t+2} + \dots + \gamma^n q_t(s_{t+n}, a_{t+n}) \right] \right] \end{aligned} qt+1(st,at)=qt(st,at)αt(st,at)[qt(st,at)[rt+1+γrt+2++γnqt(st+n,at+n)]]

    n步 Sarsa 更具一般性,因为当 n=1n=1n=1 时,它变成了(一步)Sarsa 算法;当 n=∞n=\inftyn= 时,它变成了 MC 学习算法。


  • n步 Sarsa 需要 (st,at,rt+1,st+1,at+1,…,rt+n,st+n,at+n)(s_t, a_t, r_{t+1}, s_{t+1}, a_{t+1}, \dots, r_{t+n}, s_{t+n}, a_{t+n})(st,at,rt+1,st+1,at+1,,rt+n,st+n,at+n)

  • 由于在时间 ttt 时尚未收集到 (rt+n,st+n,at+n)(r_{t+n}, s_{t+n}, a_{t+n})(rt+n,st+n,at+n),我们无法在步骤 ttt 实施 n步 Sarsa 来更新 (st,at)(s_t, a_t)(st,at) 的 q 值。然而,我们可以等到时间 t+nt+nt+n 再进行更新:

    qt+n(st,at)=qt+n−1(st,at)−αt+n−1(st,at)[qt+n−1(st,at)−[rt+1+γrt+2+⋯+γnqt+n−1(st+n,at+n)]] \begin{aligned} q_{t+n}(s_t, a_t) = q_{t+n-1}(s_t, a_t) & \\ & - \alpha_{t+n-1}(s_t, a_t) \left[ q_{t+n-1}(s_t, a_t) - \left[ r_{t+1} + \gamma r_{t+2} + \dots + \gamma^n q_{t+n-1}(s_{t+n}, a_{t+n}) \right] \right] \end{aligned} qt+n(st,at)=qt+n1(st,at)αt+n1(st,at)[qt+n1(st,at)[rt+1+γrt+2++γnqt+n1(st+n,at+n)]]

  • 由于 n步 Sarsa 包含了 Sarsa 和 MC 学习作为两个极端情况,其性能是 Sarsa 和 MC 学习的混合:

    • 如果 nnn 很大,其性能接近 MC 学习,因此具有较大的方差但较小的偏差。
    • 如果 nnn 很小,其性能接近 Sarsa,因此具有相对较大的偏差(源于初始猜测)和相对较低的方差。
  • 最后,n步 Sarsa 也可用于策略评估。它可以与策略改进步骤结合,以搜索最优策略。

2.4 为什么要用 n-step?

n-step Sarsa 提供了一个调节 偏差(Bias)方差(Variance) 的旋钮:

  1. 当 n 较小时(如 n=1):
    • 利用了 Bootstrapping(用估计更新估计),方差小,学习快,但容易受到初始 Q 值不准确的影响(偏差大)。
  2. 当 n 较大时(接近 MC):
    • 利用了更多的真实奖励,偏差小,最终结果准。但因为链路长,中间任何一步的随机性都会累积,导致方差大,学习过程抖动。
  3. 折中方案
    • 通常选择 中间值 往往能取得最好的效果,既比单步看得远,又不至于像 MC 那样等到花儿都谢了。

三、 总结与对比

我们现在学习了三种主要的无模型控制算法,它们的关系如下表所示:

算法 核心目标 (Target) 特点 适用场景
Sarsa (1-step) R+γQ(s′,a′)R + \gamma Q(s', a')R+γQ(s,a) 简单,在线更新,方差略大 入门首选,快速验证
Expected Sarsa R+γE[Q(s′,⋅)]R + \gamma \mathbb{E}[Q(s', \cdot)]R+γE[Q(s,)] 消除动作随机性,更稳定 想要减少震荡,算力充足时
n-step Sarsa R+⋯+γnQ(st+n,⋅)R + \dots + \gamma^n Q(s_{t+n}, \cdot)R++γnQ(st+n,) 平衡偏差与方差,视野更广 追求极致性能,处理复杂回报延迟
Monte Carlo GtG_tGt (累计总回报) 无偏差,方差极大,离线 只有回合制任务,且对准确性要求高

核心结论

  • Expected Sarsa 改进了 Target 的计算方式(由点到面)。
  • n-step Sarsa 改进了 Target 的计算深度(由近及远)。
  • 它们本质上都是为了更准确地逼近真实的动作价值 qπ(s,a)q_\pi(s, a)qπ(s,a),从而找到最优策略。

更多推荐