When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch

1. Measuring the Mismatch: The vllm-kl Metric

\small{\mathbb{E}_{s\sim d_{\textcolor{red}{\pi^\text{vllm}_\theta}}}\left[\text{KL}\left(\textcolor{red}{\pi^\text{vllm}_\theta}\left(\cdot|s\right),\textcolor{blue}{\pi^\text{fsdp}_\theta}\left(\cdot|s\right)\right)\right] = \mathbb{E}_{s\sim d_{\textcolor{red}{\pi^\text{vllm}_\theta}},a\sim {\textcolor{red}{\pi^\text{vllm}_\theta}\left(\cdot|s\right)}} \left[\log\left(\frac{\textcolor{red}{\pi^\text{vllm}_\theta}(a|s)}{\textcolor{blue}{\pi^\text{fsdp}_\theta}(a|s)}\right)\right],}

但是代码用的是 kl3 散度

kl = \frac{\pi_1}{\pi_2} - log(\frac{\pi_1}{\pi_2}) - 1
rollout_log_probs = batch.batch["rollout_log_probs"] # pi_vllm
actor_old_log_probs = batch.batch["old_log_probs"] # pi_fsdp
response_mask = batch.batch["response_mask"]
log_ratio = actor_old_log_probs - rollout_log_probs
vllm_k3_kl_matrix = torch.exp(log_ratio) - log_ratio - 1
vllm_k3_kl = masked_mean(vllm_k3_kl_matrix,response_mask)

2. The Smoking Gun: The Low-Probability Token Pitfall

这个洞察意义其实不大,这个相关性不需要通过数学分析来找到这种洞察,vllm-kl很小,从数学上就能直接推导出 \frac{\pi_{fsdp}}{\pi_{vllm}} 会很发散

不过这里有一个重要的洞察是:当熵很大的时候,fsdp的log_probs要比vllm的log_probs小

3. More Tool Calls, More Training Instability

对于没有见过的上下文,大模型是很难学习到这之间的相关性的

猜测

传统的工具调用,工具生成的结果,在训练阶段会被 mask 掉。这样会导致RL过程中,大模型会强制计算2个不想关片段的相关性,但是可能这个相关性在现有大模型的能力里就是没有的。因此训练很容易崩溃。

工具调用的RL可能需要一个新的范式。

4. The Environmental Factor: The Critical Role of Hardware

xx

5. The Mismatch is Not Static: A Vicious Cycle Driven by Optimization

如果vllm-kl > 0.1,过滤掉当前的batch数据,不更新模型权重。

本质上等价于 gspo

6. Masked Importance Sampling (MIS)

原文 4.2.1 首先从数学上证明了 token-level 的重要性采样是一个有偏估计,原因有2:

  1. 分布的误差是由2方面引入的:状态空间的偏差和动作空间的偏差,token-level IS只修正了动作空间的偏差,没有修正状态空间的偏差(疑问:但是token的概率其实隐含了状态空间的概率,所以理论上也一定程度的修正了状态空间的误差)
  2. 奖励是基于 fsdp 策略空间的,而不是 vllm 动作空间的,所以这里也有误差

作者提出了 MIS 策略:对于 seq 的重要性误差超过 C 的,直接 masked 掉,不参与梯度更新

这个做法和 GSPO 是类似的,和 GSPO 不同的是,MIS保留了GRPO的token-level梯度计算方式,而GSPO的token-level梯度公式不一样了。

如下:

gspo的梯度公式可以看gspo的论文

效果:

7. Top-p Sampling

这个更感觉像是个反例

top-p越大,vllm-kl越大,但是 reward 反而更好了。。。

8. Diable Cascade Attention in vLLM

这个问题根源是 flashattention2 的一个bug,在a100和L20这种型号的GPU上,大batch size下会触发split-kv,这个split-kv操作会导致一个LSE错误。

这个bug已经被修复了:

 

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注