多头潜在注意力论文:
- deepseek-v2: https://arxiv.org/pdf/2405.04434
- deepseek-v3: https://arxiv.org/pdf/2412.19437
MLA 最核心的理念就是低秩转换
我们回顾一下最基本的 attention 计算,这里直接省略各种 MHA,MQA,GQA,因为这些 attention 变种并没有本质的改变了 attention 的计算公式,只是简单的共享
Q = W_q x\\ K = W_k x\\ V = W_v x\\ A=softmax(Q^TK / \sqrt{d}) \\ O = AV \\ Y = W_o O \\ MLP其中 x 是 [1, h] 矩阵,W_k 是 [h, h] 矩阵,因此,QKV 都是 [1, h] 矩阵,推理过程中的 KV 显存占用 sizeof(fp16) * 2 * b * l * h * s = 4bhls,即使使用最先进的 GQA,显存占用也是 4bhls / 8
有没有一种无损的方法,降低 KV 缓存服用,又不影响模型的效果?
deepseek-v3 探索出了一种新思路
1. W 权重矩阵的低秩转换
大模型推理过程中保存的是 KV,计算的 A,W_k也只在这2个地方会用到
W_k = W_{UK} W_{DK}其中 DK 表示 down-projection matrix 降维矩阵,UK 表示 up-projection matrix 升维矩阵
K = W_{UK} W_{DK} * x\\ A=softmax((W_q * x)^T * (W_{UK} W_{DK} * x) / \sqrt{d}) \\到这里就很关键了,虽然矩阵乘法不支持交换律,但是矩阵乘法还遵循一个定理(AB)^T = B^T A^T
因此有:
K = W_{UK} W_{DK} * x\\ A=softmax((x^T * W^T_q * (W_{UK} W_{DK} * x) / \sqrt{d}) \\如果只是简单的做 W_k的低秩转换,每次计算 A 的时候,K 还需要乘上一个 W_{UK},虽然节省了显存,但是增加了计算量,不是好事
亮点来了,这个计算 A 的时候,W_{UK}是可以被 W^T_q吸收掉的!
于是公式就变成了:
C^{KV} = W_{DK} * x \\ W^{new}_q = W^T_q * W_{UK} \\ A = softmax(x^T W^{new}_q C^{KV})实际上在作者的论文里面:
- KV 2个缓存直接合并成1个缓存 C^{KV}
- W_{UK}被吸收到W_q里面,W_{UV}被吸收到了W_o里面
- Q 也做了同样的低秩转换,节省了训练的内存
如下所示:
2. RoPE 解耦
但是现在有一个最大的问题就是,现在基本所有的大预言模型都会启用 RoPE,RoPE 本质是一个对角矩阵,目的是为了在计算 attention 的时候引入 token 之间的位置信息,这样能够更好的理解上下文和生成质量
Q,K 都需要应用 RoPE,V不需要
原始的 RoPE:
Q^T_mK_n = (R^d_{\Theta,m} W_q x_m)^T(R^d_{\Theta,n} W_k x_n) = x^T_m W^T_q R^d_{\Theta,n-m} W_k x_n低秩转换之后:
这就导致x^T * W^T_q * (W_{UK} W_{DK} * x)之间对会多一个 R^d_{\Theta,n-m}
x^T * W^T_q * R^d_{\Theta,n-m} * (W_{UK} W_{DK} * x)这样W_{UK}就没法简单的被W^T_q吸收了,因为矩阵运算不支持交换律,如果没法吸收,那就必须在推理的时候,增加一次计算,这样又不划算了
对于这个问题,目前唯一可借鉴的方法,就是放弃 RoPE,使用其他位置编码算法如ALIBI,但 DeepSeek 的实验显示其他现有方法的效果都不太行
最后论文使用了一个折中的算法:Q、K新增d_r个维度用来保存RoPE信息,计算 Attention 的时候,计算 Attention 的时候,矩阵拼接就直接变成加法了
相当于原有的 Q^TK加上 q^R_t k^R_t,再做 softmax,这样 Q^TK可以继续使用之前的 MLA 算法
疑问:
- 能想到这个方法,并且效果不失真,怎么找到的?
3. 推理的性能提升
参考 https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json 的数据
d_c = 512 对应配置 kv_lora_rank
d^R_h = 64 对应配置 qk_rope_head_dim
性能提升:
-
对比 MHA:kv 大小相当于从 MHA 的 128 * 128 = 16384 优化到了 512 + 64 = 576,显存降低 2*16384/57=56 倍,相当于bs提升56倍
-
对比 GQA = 8:也相当于bs提升了 2 * 8 * 128 / 576 = 3.6 倍
-
对比 GQA = 16:由于这个模型更大,同等参数量,GQA可能得是 g = 16 才能保证效果?bs提升 7.2 倍
4. 源码实现
v2 的实现:https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py
v3 的实现:https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py