Mix Precision Training 混合精度训练

这是一篇百度和 nvidia 合作的论文
实际上在 MPT 之前,已经有一些论文在做减少精度训练相关的工作了,比如:
  1. Binaryconnect: Training deep neural networks withbinary weights during propagations @2015
  2. Binarized neural networks. InAdvances in Neural Information Processing Systems @2016

1. 背景

随着现在深度神经网络模型越来越大,万亿参数,百万亿参数,深度学习训练所需要的GPU内存越来越多,内存优化迫在眉睫
Year
Name
Param
From
2018
110M
OpenAI
2018
349M
Google
2020
175B
OpenAI
2022
540B
Google
传统的神经网络训练里面,都是用FP32来保存权重、梯度等数据。但是FP32在大模型场景下(万亿参数)内存开销太大了
为了解决这个问题,MPT 论文提出了一种使用 FP16 精度来训练神经网络,同时基本不影响训练效果的方法

2. 实现

使用FP16会有很多好处:
  1. 大幅减少内存消耗:50%?
  2. 运算加速,FP16肯定要比FP32要快
但是随之带来的问题是:精度损失
IEEE标准中的FP16格式如下:
0
取值范围是5.96× 10−8 ~ 65504,而FP32则是1.4×10-45 ~ 3.4×1038
精度损失在神经网络训练里面可是比较致命的,可能会训练出截然相反的结果
为了解决这个问题,论文提出了3种方法

2.1. FP32 master copy of weights

为每个权重保留一份FP32的副本
使用FP16训练神经网络的基本流程是:
0
简单来说就是:只有权重是FP32的,前向计算和反向计算都用FP16,在运行时进行转换
比如执行前向计算时,权重从FP32转成FP16,得到loss之后,用FP16计算梯度,再转成FP32更新到FP32的权重上
用FP32保存权重不是必须的(optional),主要是为了避免溢出:
  1. 一种是梯度的更新值太小,FP16没法表示小于2^-24的值,会直接变为了0
  2. 二是FP16表示权重的话,和梯度的计算结果也有可能变成0
实验表明,用FP16保存权重会造成80%的精度损失。

2.2. Loss Scaling

作者通过分析训练过程中产生的所有 gradient values,发现一些有趣的特征:
  1. 绝大部分值都集中在一小部分区域内
  2. 大部分值没法直接转换成FP16来保存(损失精度)
如下:
0
0
怎么能够把这个值,用FP16保存起来同时又不损失精度呢?
作者想了个方法,就是 scaling
前向计算得到FP32值之后,向右移动(乘以一个 scaling factor),然后进行反向传播,更新的时候再缩小回FP32保存
One efficient way to shift the gradient values into FP16-representable range is to scale the loss valuecomputed in the forward pass, prior to starting back-propagation
scaling factor 怎么设定?2个方法:
  1. 直接设定一个常数值,经过测试,作者发现大部分神经网络使用 8-32k 的scaling factor 都是可以的
  2. 如果你有梯队的统计数据的话,直接取 scalling factor = 65535 / 梯度最大值 即可

2.3. Arithmetic precision

对于一些模型来说,FP16精度向量的点积运算,需要使用FP32来执行累加运算然后保存到FP16结果里,否则会出现精度损失
对于这个问题,解法是通过nvidia硬件提供的FP16&FP32混合运算来解决
Therefore, in mix precision training, it is required to useNVIDIA Volta GPUs which introduce Tensor Cores to multiply FP16 input matrices and accumulate products into either FP16 orFP32 outputs [65] whereas previous GPUs supported only FP16operation.

3. 性能评估

0
可以看到,使用混合精度训练,性能基本不损失
发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注