全文核心

MoR = 同一权重循环复用(省参数)+ token 级动态深度(省算力)+ 精打细算 KV(省显存)。模型学会“难题多想、简单题少想”。

1 为什么经典 Transformer 会浪费?

经典 Transformer 从 2017 年诞生至今架构几乎没变:同样的 24 层(或 32、48 层)串行堆叠,每个 token 必须完整穿过所有层。这就像一条 24 站的装配线,不论零件多简单,都强制走完整流程,浪费主要体现在三方面:

浪费点 详细原因 直观后果
算力冗余 复杂度固定 ≈ 层数 × 序列长度。简单 token 在后期层几乎不再增益,却仍消耗 FLOPs。 A100 上 2048 token 单次推理≈ 0.8 s;如能提前退出理想可降到 <0.4 s。
显存暴涨 每层都需存 Key‑Value (KV) 对。d=1280、L=2048 时单层 KV≈16 MB,24 层≈384 MB。 对话长一点就 OOM,即便 40 GB A100 也会爆显存。
延迟不均 所有 token 同速 → 必须等最慢的长句跑完 24 层。 响应时间抖动,GPU 利用率低。

小结:固定深度 = 多算、全存、慢响应。MoR 要做的就是“让每个 token 拿到刚刚够用的计算预算”。


2 MoR 的核心思路

MoR 由 递归块 Recursion BlockRouter选择性 KV 缓存 组成,形成“按需深度 × 参数共享 × 显存精简”闭环。

2.1 递归块 Recursion Block

关键词 设计逻辑 直接收益
多层打包 把 4–6 层合成函数 fθ,一次定义反复用 权重只存一份,参数 ↓4–6×
循环调用 每个 token 最多跑 Dmax 圈: h→fθ(h)→… 复杂词能深入,简单词早退
Middle‑Cycle 共享 只共享中间几圈,首尾层保持独立 验证困惑度最低、收敛快

数学一眼看懂

h_i(0)  =  输入嵌入
h_i(d)  =  fθ( h_i(d‑1) )
              d = 1 … Dmax

2.2 Router — 给 token 发“深度配额”

Router 是一层或两层 MLP,参数量<0.1 %。它读首圈隐藏状态,输出概率向量 p(d)。

路由模式 决策时机 典型场景 稳定训练秘笈
Token‑choice 开局一次性给出 d 在线对话、低延迟 Balancing Loss + 温度退火
Expert‑choice 每圈重新挑 top‑k % 难词 离线大批量推理 路由辅助损 + Gumbel‑Softmax

常见坑:Router 过热→全部浅层;解决:温度逐步降 & 熵正则。

2.3 选择性 KV 缓存

技巧 峰值显存节省 典型场景 性能影响
递归级缓存 理论 (Nr+1)/(2Nr) 普通推理 无精度损
Recursive Sharing ≈50 % 64k+ 长上下文 PPL 升 0.1,可忽略

实现关键:已经“毕业”的 token 不再占用 KV 内存;在 Flash‑Attention 里先看一张“活跃名单”,名单外的 token 直接跳过、完全不算。


3 递归块:参数极致复用

  1. 参数对比:以 360 M baseline(24×15 M/层)为例,改为 4 层/块 ×4 圈,唯一权重≈ 4×15 M=60 M,参数直接砍 >70 %。
  2. 梯度截断:最长路径 Dmax 圈;后向传播时只回传到 d≤token 实际深度,梯度爆炸/消失问题明显减弱。
  3. 多尺度共享:共享同一组权重迫使 fθ 同时适应浅语法与深语义两种特征,实测 perplexity 比完全独立层还低约 0.3。

4 Router:思考预算分配

  • 输出维度:Dmax 通常设 3‑4;更大深度收益递减。
  • 平衡损L_balance = (mean_depth − target)^2,把平均深度压到设定预算,例如 1.6 圈。
  • 辅助路由损:(Expert‑choice 专用)soft label 让路由提前“猜”下一圈是否仍被选中,提高稳定性。
  • 示例温度计划:训练前 10 % step τ=0.75 → 中期线性降到 0.5 → 收尾固定。

小贴士:若训练中观察到深度塌缩(全部 d=1),先提高 τ 或增大 Balancing Loss 权重再继续。


5 KV 缓存:显存精打细算

  1. 递归级缓存公式:假设 batch 内 token 平均递归圈数 r̄,显存≈(r̄/Dmax)×原始 100 %。若平均只跑 1.5 圈,24 层模型显存即降到 37 %。
  2. KV Sharing 细节:首圈 KV 在显存中维持,后续圈对同一序列复用查询。需保证块内投影矩阵 weight tying,否则维度不一致。
  3. Prefill 优势:长上下文生成阶段,prefill 占用高峰由 O(layer)→O(active_layer)。MoR‑3 模型在 1M token 上下文可省约 14 GB 显存。

6 优缺点速览

优势 说明
真正三效合一 同时省参数、算力、显存
长上下文友好 缓存减半,窗口可达百万级
边缘端可用 8 GB NPU 跑百兆模型,延迟‑30 %
迁移成本低 在原权重上继续训 3‑5 B token 即可
限制 对策
小模型收益有限 建议 Dmax ≤3
Router 不稳 温度 & Balancing Loss 调参
KV 逻辑需改内核 参考官方 Flash‑Attn 分支

7 未来可探索

  • 层内稀疏 × MoR:MoE 或稀疏注意力塞进递归块,双重稀疏。
  • 连续深度预算:Router 输出实值 budget,实现可微动态计算。
  • 4‑bit KV + MoR:极端压显存,让手机端跑 Llama‑3。
  • 多模态 MoR‑ViT:图像已验证,视频/语音尚在路上。

参考文献

Mixture-of-Recursions: Parameter Sharing for Efficient Token-level Adaptive Computation
Making Transformers More Efficient with MoR
DeepMind’s Mixture‑of‑Recursions could power smaller LLMs

更多推荐