多头注意力变体¶
🎯 本节目标¶
理解从MHA到MLA的技术演进,掌握不同注意力机制的优化原理和应用场景。
📖 阅读材料¶
核心技术文章¶
- Transformer的Attention及其各种变体 - 详细对比分析
- 缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA - 科学空间深度解析
📝 知识总结¶
技术演进路径¶
各变体详细对比¶
变体 | KV Cache需求 | 计算复杂度 | 性能表现 | 主要应用 |
---|---|---|---|---|
MHA | O(h×d×L) | O(H x N^2 x D) | 基准性能 | 标准Transformer |
MQA | O(d×L) | O(N^2×D) | 轻微下降 | 资源受限场景 |
GQA | O(g×d×L) | O(G×N^2×D) | 平衡优秀 | 主流大模型 |
MLA | 最优化 | O(R×N^2×D) | 接近MHA | 长上下文 |
其中:h=头数,d=维度,L=序列长度,g=组数
核心技术细节¶
1. MHA (Multi-Head Attention)¶
# 每个头都有独立的Q、K、V
for i in range(num_heads):
Q_i = input @ W_Q_i # 每个头独立的查询矩阵
K_i = input @ W_K_i # 每个头独立的键矩阵
V_i = input @ W_V_i # 每个头独立的值矩阵
head_i = attention(Q_i, K_i, V_i)
2. MQA (Multi-Query Attention)¶
# 所有头共享K、V,只有Q独立
K_shared = input @ W_K # 共享的键矩阵
V_shared = input @ W_V # 共享的值矩阵
for i in range(num_heads):
Q_i = input @ W_Q_i # 每个头独立的查询矩阵
head_i = attention(Q_i, K_shared, V_shared)
3. GQA (Grouped-Query Attention)¶
# 分组共享:每组内共享K、V
num_groups = num_heads // group_size
for g in range(num_groups):
K_g = input @ W_K_g # 组共享的键矩阵
V_g = input @ W_V_g # 组共享的值矩阵
for i in range(group_size):
Q_i = input @ W_Q_i
head_i = attention(Q_i, K_g, V_g)
💬 面试问题解答¶
Q1: MHA、MQA、GQA、MLA都是什么?¶
简洁回答: 这是Transformer注意力机制的四个演进阶段,主要优化KV Cache的存储需求:
- MHA: 标准多头注意力,每个头独立QKV
- MQA: 多查询注意力,所有头共享KV
- GQA: 分组查询注意力,分组内共享KV
- MLA: 多头潜在注意力,通过低秩分解优化
技术细节:
MHA问题: KV Cache随头数线性增长,内存开销大
MQA解决方案: 共享KV矩阵,内存需求降低h倍
GQA平衡方案: 分组共享,兼顾性能和效率
MLA终极优化: 潜在空间投影,最小化KV Cache
🔬 MLA技术深度解析¶
MLA核心创新¶
Multi-head Latent Attention (MLA) 是DeepSeek团队提出的革命性注意力机制,通过三大创新显著降低KV Cache内存需求:
1. 低秩KV联合压缩¶
核心思想: 将高维的Key和Value矩阵联合压缩到低维潜在空间
# 传统方式:每个头独立存储KV
traditional_kv_cache = num_heads × head_dim × seq_len × 2 # K和V
# MLA方式:压缩后的潜在向量
mla_kv_cache = compressed_dim × seq_len + rope_dim × seq_len
压缩过程: \(\(c_t^{KV} = x_t W^{DKV}\)\)
其中 \(W^{DKV} \in \mathbb{R}^{d \times d_c}\),\(d_c \ll h \cdot d_h\)
2. RoPE解耦机制¶
问题: 位置编码与压缩机制的冲突 - 传统RoPE需要在原始QK空间中应用 - 压缩破坏了位置信息的正确传递
解决方案: 将Query和Key分为两部分 - 语义部分 (\(q^C, k^C\)): 携带主要语义信息,可以压缩 - 位置部分 (\(q^R, k^R\)): 携带位置信息,保持原维度
def mla_with_rope_decoupling(x, position):
# 1. 生成潜在向量
c_kv = x @ W_down_kv # 压缩
# 2. Query分离
q = x @ W_q
q_c, q_r = q[:, :d_c], q[:, d_c:] # 语义 + 位置
# 3. Key分离和恢复
k_c = c_kv @ W_up_k # 从潜在空间恢复语义Key
k_r = x @ W_k_r # 直接生成位置Key
# 4. 分别应用RoPE
q_r = apply_rope(q_r, position)
k_r = apply_rope(k_r, position)
# 5. 组合计算
q_combined = concat([q_c, q_r])
k_combined = concat([k_c, k_r])
v = c_kv @ W_up_v
return attention(q_combined, k_combined, v)
3. 权重吸收优化¶
目标: 减少推理时的矩阵乘法操作
技术: 利用矩阵乘法结合律,预先合并权重矩阵
# 原始计算:两次矩阵乘法
c_kv = x @ W_down_kv
k = c_kv @ W_up_k
# 权重吸收:合并为一次乘法
W_combined = W_down_kv @ W_up_k
k = x @ W_combined
内存效率对比¶
方法 | KV Cache大小 | 压缩比 |
---|---|---|
MHA | \(2 h \cdot d_h \cdot L\) | 1.0× (基准) |
MQA | \(2 d_h \cdot L\) | \(h\)× |
GQA | \(2 g \cdot d_h \cdot L\) | \(h/g\)× |
MLA | \((d_c + d_h^R) \cdot L\) | ~10-20× |
具体例子 (LLaMA-7B规模): - 原始MHA: 32头 × 128维 × 2 = 8192维/token - MLA压缩: 512维 + 128维 = 640维/token - 压缩比: 12.8×
性能保持机制¶
尽管大幅压缩,MLA通过巧妙设计保持了接近MHA的性能:
1. 表达能力保持¶
- 低秩假设:大部分注意力模式可以用低秩矩阵近似
- 关键信息保留:位置信息通过解耦机制完整保留
- 渐进恢复:多层堆叠逐步恢复完整信息
2. 训练稳定性¶
# 残差连接确保训练稳定
def mla_block(x):
# MLA注意力
attn_out = mla_attention(x)
x = x + attn_out # 残差连接
# FFN
ffn_out = feed_forward(x)
x = x + ffn_out # 残差连接
return x
3. 位置敏感性¶
- RoPE解耦确保位置信息不丢失
- 位置编码维度可以根据任务需求调整
- 长序列外推能力得到保持
Q2: 为什么需要这些优化?¶
核心动机:
- 内存瓶颈
- 长序列推理时KV Cache占用大量显存
-
限制了模型的部署和扩展能力
-
推理速度
- 减少内存访问,提高计算效率
-
支持更大的batch size
-
成本考虑
- 降低硬件要求
- 提高服务并发能力
Q3: 各变体的优缺点对比?¶
维度 | MHA | MQA | GQA | MLA |
---|---|---|---|---|
性能 | 🟢 基准最好 | 🟡 轻微下降 | 🟢 接近MHA | 🟢 超越MHA |
内存 | 🔴 需求最高 | 🟢 显著降低 | 🟡 适中 | 🟢 最优 |
速度 | 🟡 标准 | 🟢 最快 | 🟢 较快 | 🟢 优秀 |
实现 | 🟢 简单 | 🟢 简单 | 🟡 中等 | 🔴 复杂 |
Q4: 如何选择合适的注意力机制?¶
选择策略:
if 资源充足 and 追求最佳性能:
选择 MHA
elif 资源严重受限 and 可接受性能损失:
选择 MQA
elif 需要平衡性能和效率:
选择 GQA # 主流选择
elif 长上下文 and 内存敏感:
选择 MLA
实际考虑因素:
- 硬件内存限制
- 序列长度需求
- 延迟要求
- 开发复杂度
💻 代码实现¶
练习1: 实现Softmax函数¶
平台: Deep-ML Softmax
练习2: MHA到GQA的适配¶
class GroupedQueryAttention(nn.Module):
def __init__(self, d_model, num_heads, num_groups):
super().__init__()
assert num_heads % num_groups == 0
self.d_model = d_model
self.num_heads = num_heads
self.num_groups = num_groups
self.heads_per_group = num_heads // num_groups
self.d_k = d_model // num_heads
# Q矩阵:每个头独立
self.W_q = nn.Linear(d_model, d_model)
# K,V矩阵:按组共享
self.W_k = nn.Linear(d_model, num_groups * self.d_k)
self.W_v = nn.Linear(d_model, num_groups * self.d_k)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x):
batch_size, seq_len, d_model = x.shape
# 生成Q:每个头独立
Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k)
# 生成K,V:按组共享
K = self.W_k(x).view(batch_size, seq_len, self.num_groups, self.d_k)
V = self.W_v(x).view(batch_size, seq_len, self.num_groups, self.d_k)
# 重复K,V以匹配Q的头数
K = K.repeat_interleave(self.heads_per_group, dim=2)
V = V.repeat_interleave(self.heads_per_group, dim=2)
# 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
attn_weights = F.softmax(scores, dim=-1)
out = torch.matmul(attn_weights, V)
# 合并多头输出
out = out.view(batch_size, seq_len, d_model)
return self.W_o(out)
练习3: KV Cache实现预览¶
class KVCache:
def __init__(self, max_seq_len, num_heads, d_k):
self.max_seq_len = max_seq_len
self.cache_k = torch.zeros(max_seq_len, num_heads, d_k)
self.cache_v = torch.zeros(max_seq_len, num_heads, d_k)
self.current_len = 0
def update(self, new_k, new_v):
"""更新缓存并返回完整的K,V"""
seq_len = new_k.size(0)
# 存储新的K,V
self.cache_k[self.current_len:self.current_len+seq_len] = new_k
self.cache_v[self.current_len:self.current_len+seq_len] = new_v
self.current_len += seq_len
# 返回到目前为止的完整K,V
return (self.cache_k[:self.current_len],
self.cache_v[:self.current_len])
✅ 学习检验¶
- [ ] 能解释各变体的核心区别
- [ ] 理解KV Cache优化的原理
- [ ] 完成GQA代码实现
- [ ] 能根据场景选择合适的注意力机制