When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch

1. Measuring the Mismatch: The vllm-kl Metric

\small{\mathbb{E}_{s\sim d_{\textcolor{red}{\pi^\text{vllm}_\theta}}}\left[\text{KL}\left(\textcolor{red}{\pi^\text{vllm}_\theta}\left(\cdot|s\right),\textcolor{blue}{\pi^\text{fsdp}_\theta}\left(\cdot|s\right)\right)\right] = \mathbb{E}_{s\sim d_{\textcolor{red}{\pi^\text{vllm}_\theta}},a\sim {\textcolor{red}{\pi^\text{vllm}_\theta}\left(\cdot|s\right)}} \left[\log\left(\frac{\textcolor{red}{\pi^\text{vllm}_\theta}(a|s)}{\textcolor{blue}{\pi^\text{fsdp}_\theta}(a|s)}\right)\right],}

但是代码用的是 kl3 散度

kl = \frac{\pi_1}{\pi_2} - log(\frac{\pi_1}{\pi_2}) - 1
rollout_log_probs = batch.batch["rollout_log_probs"] # pi_vllm
actor_old_log_probs = batch.batch["old_log_probs"] # pi_fsdp
response_mask = batch.batch["response_mask"]
log_ratio = actor_old_log_probs - rollout_log_probs
vllm_k3_kl_matrix = torch.exp(log_ratio) - log_ratio - 1
vllm_k3_kl = masked_mean(vllm_k3_kl_matrix,response_mask)

2. The Smoking Gun: The Low-Probability Token Pitfall

这个洞察意义其实不大,这个相关性不需要通过数学分析来找到这种洞察,vllm-kl很小,从数学上就能直接推导出 \frac{\pi_{fsdp}}{\pi_{vllm}} 会很发散

不过这里有一个重要的洞察是:当熵很大的时候,fsdp的log_probs要比vllm的log_probs小


An Introduction to Reinforcement Learning for llm

最近在研究大模型的强化学习,整理下相关的算法基础,了解一下当前主流强化学习算法的一些改进思路

1. 目标函数

1.1. 预训练

自回归模型的目标是预测序列中的下一个词元,也是使用交叉熵损失

其中:

  • T 是序列长度。
  • w_t 是时间步 t 的真实词元。
  • w_{<t} 是时间步 t 之前的所有词元。
  • P(w_t | w_{<t}) 是模型预测给定历史词元后,下一个词元为 w_t 的概率。

1.2. 强化学习

最大化目标函数(最大化优势) = 最小化 -目标函数

所以强化学习把 loss 定义为 = - si(θ) * adv

典型目标函数:

(1)PPO

(2)GRPO

其中:

  1. G 表示同一个 prompt 的G条样本答案
  2. y_i表示第几条样本,|y_i|表示向量的长度的,也就是 token 数,但是实际上这里面loss是求平均还是求和,有很多算法:token-mean,seq-mean-token-sum,seq-sum-token-mean

GRPO 计算的是 token 粒度的平均 loss

(3)GSPO

优势计算和 GRPO 一样

GSPO 计算的是 seq 粒度的平均 loss

(4)DAPO

xx

(5)BAPO

xx

和预训练不一样,强化学习是基于 reward 而不是交叉熵,token-level-reward

上面公式,除了clip的目的主要是为了防止模型训崩,loss计算的核心是重要性系数adv优势