ReFT: Reasoning with REinforced Fine-Tuning

最近 openai 发布圣诞系列的第一弹,就强调了强化微调,基于这个,可以让小模型结合行业数据,做到比大模型更强的推理效果

然后研究了下字节之前发过的类似的一篇论文:https://arxiv.org/pdf/2401.08967

 

1. 背景

1.1. 传统的 CoT 训练方法

虚线之上是传统的 CoT 训练方法,就是使用数据集(x, e, y)不断的训练基础模型,让基础模型获得推理能力

比如 gsm8k 数据集:https://huggingface.co/datasets/openai/gsm8k/viewer/main/train?p=1&row=167

这个数据集里面每一行就是一条训练数据,包括3部分:

x就是问题

e就是解决这个问题的思路

y是答案

但是这种训练方法,模型推理的泛华能力是比较弱的,因为它只能学习到一种解题思路,就是数据集中的思路

1.2. ReFT:Reinforced Fine-Tuning

训练思路和 SFT 很不一样同一条数据集,SFT 会反复训练多次,让模型在数据集上误差最小。这样训练出来的模型,对于解决数据集中的问题肯定是没问题的,但是对于解决数据集的其他问题,就不一定是最佳的了。这个时候回答问题的质量取决于数据集的质量和规模ReFT 只需要1~2次预热,得到一个基础的模型,然后通过强化学习,让模型主动去探索不同的解题路径,这样得到的模型,泛化能力是最强的

如上图,ReFT 有2个核心阶段:

  1. warm-up:xx
  2. 强化学习阶段:specifically Proximal Policy Optimization (PPO)

特别注意:ReFT 并不依赖额外的训练数据集通过这个方法,论文使 CodeLLAMA 和 Galactica 模型在GSM8K、MathQA、SVAMP数据集上,泛化能力得到了显著提高

2. 方法

核心算法和前面所述一样,2个阶段:

  1. warm-up
  2. 强化学习阶段

2.1. warm-up state

使用原始数据,训练 W 次,让模型获得 “初步” 的解题能力

warm-up 阶段定义的损失函数是:

其中:

  1. θ是个超参,用来控制 πθ 策略的
  2. L是CoT长度,也就是e的长度
  3. a_t 表示 e 的第 t 个 token,e = [a1, a2, …, aL−1, aL],每次 decode 一个 token
  4. s_t 表示第 t 个时刻的环境状态,这是增强学习领域的基本术语,s_(t+1) 表示在 s_t 时刻选择 a_t 后的下一个状态,当 t = 0 时,也就是初始时刻,s_0 就是输入

  1. πθ是一个环境策略,πθ(a_t|s_t) 表示在超参θ下,给定状态s_t下选择a_t的概率

疑问:

  1. 看起来这个策略函数πθ,非常关键,这个环境策略怎么设计的?

2.2. Reinforcement Learning

上述 warm-up 阶段只是让模型具备基础的推理能力,还远不够。强化学习需要通过环境反馈(奖励),让模型学习到最佳的策略

这一阶段,论文定义了一个新的损失函数 $L_{RL}(θ, \phi)$,这个损失函数,是根据 δ_t,A_t,R_t 计算出来的

2.2.1. 奖励模型

首先定义下基本的奖励模型,如下:

  1. 设计奖励(环境反馈)
    1. 正确的结果(如果算出y)要奖励
    2. 错误的结果(非y)不奖励
    3. 近似的结果(如果是个数字),部分奖励。不过论文后面也实验证明,即使去掉这部分,对效果影响也不大,所以这一步不是关键部分
  2. 计算 Kullback-Leibler 离散程度防止过于发散,这个值衡量的是模型当前的策略与初始策略之间的差异。这样做的目的是鼓励模型不要偏离初始策略太远,同时也能学到有效的策略

2.2.2. 优势估计 – A

a_t 表示我们在第 t 时刻选择生成的 token

但我们怎么知道选择 a_t 是不是一个最佳的决策?这个时候要引入一个叫优势估计的函数

其中:

  1. λ ∈ (0, 1] is the discount factor for rewards, and
  2. γ ∈ [0, 1] is the discount factor for TD
  3. δ_t:这是一个时间差分因子,表示实际观测值和预测值之间的差距。这个非常重要

因此,A 计算的是 t 时刻之后未来所有时间差分 δ_t 的折现值,t 越大,参考价值就越小(这2个因子都是 < 1,因此 l 次方一定是越来越小)

A 越大,表示当前这个动作在未来就是平均最优的(不管后面你选啥 token),从而帮助模型做出更优的决策

再把 δ_t 展开理解一下:

其中 VΦ 表示 s_t’ 状态下的预测价值因此,δ_t 计算的是 t 时刻实际观测到的奖励值 r_total,和预测的状态价值 VΦ 之间的差距

参考:

  1. 时间差分法(TD方法)
  2. OpenAI o1的真正前世竟来自字节?ReFT技术超越传统的数学微调能力,让GPT实现进化_字节跳动reft csdn-CSDN博客

2.2.3. 总回报 – R

基于上面的奖励模型和优势估计,我们很容易得到总回报的函数定义

这里表示的是预期的总汇报?当前状态 s_t 下的预测的价值 VΦ 和选择 a_t 的所带来的未来的贴现值 A

疑问:

  1. VΦ 是当前状态下的预测价值,为啥不用实际观测到的价值?

2.2.3. 损失函数 – L

这里面分为策略损失和价值损失

todo(fangdong):先知道这么算吧,还不太理解为什么这么设计。论文也没细讲,可能有些增强学习的背景知识在这里面

最后总的损失:

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注