Attention MQA GQA MLA Flash Attention

注意力机制的工程革命:从MHA到MLA

系统解析注意力机制从标准MHA到GQA、MLA、FlashAttention的完整演进路径与工程权衡

一、为什么需要注意力优化

1.1 O(n²)复杂度的本质瓶颈

Attention机制的计算复杂度是O(n² · d),其中n是序列长度,d是隐藏维度。这个二次方复杂度来自注意力分数矩阵——对于每个token,需要计算它与所有其他token的点积,生成一个n×n的注意力权重矩阵。

以LLaMA-2 70B为例,处理2048个token时,Attention层需要约 2,048² × 4,096 × 8 ≈ 131 TFLOPs——而这只是单层的计算量。80层叠加,每层计算量翻倍(Attention QKV投影 + O投影 + 残差),总计算量达到数十PFLOPS的天文数字。

1.2 KV-Cache的显存墙

比计算量更严重的是显存问题。自回归生成需要缓存所有历史token的Key和Value向量——这叫做KV-Cache。每个token的KV向量大小 = 2 × num_heads × head_dim × bytes_per_element。

# KV-Cache 显存计算

# LLaMA-2 70B KV-Cache 配置:
num_heads = 8 × 8 = 64  (GQA with 8 KV heads)
head_dim = 4096 / 64 = 128
bytes_per_element = 2 (BF16)

单个 token 的 KV 大小:
  KV_single = 2 × 64 × 128 × 2 bytes = 32,768 bytes ≈ 32 KB

# 128K 上下文的 KV-Cache:
KV_total = 32KB × 131,072 ≈ 4.2 GB per sample

# 32 层叠加:
KV_layers = 4.2 GB × 32 ≈ 134 GB

# 对比:
# - H100 SXM: 80 GB HBM
# - A100 80GB: 80 GB HBM
# → 128K上下文在单卡上根本装不下!

# 不同注意力变体的 KV-Cache 对比:
MHA (num_kv_heads = num_q_heads = 64):
  KV_single = 2 × 64 × 128 × 2 = 32 KB

GQA (num_kv_heads = 8):
  KV_single = 2 × 8 × 128 × 2 = 4 KB
  压缩率: 8x smaller

MQA (num_kv_heads = 1):
  KV_single = 2 × 1 × 128 × 2 = 0.5 KB
  压缩率: 64x smaller

MLA (latent_dim = 512):
  KV_single = 2 × 512 × 2 = 2 KB
  压缩率: 16x smaller (with low-rank compression)

1.3 多样性驱动优化方向

注意力优化已经发展出多个方向:减少KV-Cache体积(GQA、MQA、MLA)、降低计算复杂度(线性注意力、稀疏注意力)、优化硬件效率(Flash Attention)、扩展上下文长度(NTK-aware Scaling)。每个方向都代表了不同的工程权衡。

二、Multi-Head Attention:基准线

2.1 MHA的原始设计

标准Multi-Head Attention(MHA)将Query、Key、Value分别投影到h个子空间,在每个子空间独立计算注意力,最后拼接输出。隐藏维度d_model被均分为h个head,每个head的维度是d_k = d_v = d_model / h。

# MHA 数学公式

# 单头注意力:
Attention(Q, K, V) = softmax(QK^T / √d_k)V

# 多头版本:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O
where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)

# 参数分析 (LLaMA-2 70B 配置):
d_model = 4096
num_heads = 32  (每个head 128维)
d_k = d_v = 128

Q/K/V 投影矩阵: W_Q, W_K, W_V ∈ R^(d_model × d_model)
  每个: 4096 × 4096 = 16M 参数
  三个总计: 48M 参数

W_O 投影矩阵: W^O ∈ R^(d_model × d_model)
  16M 参数

总Attention参数量: 64M per layer (仅QKV + O)

# KV-Cache 体积:
num_kv_heads = num_heads = 32
KV_per_token = 2 × 32 × 128 × 2 bytes = 16 KB

2.2 为什么多头有效

Multi-Head Attention的核心价值在于"分而治之":不同的head可以学习不同的注意力模式。有的head专注于局部上下文(语法结构),有的head专注于远距离依赖(指代消解),有的head专注于语义相似性。

这个现象有实验证据:注意力头的可视化显示,不同head确实呈现出明显不同的注意力模式——某些head形成对角线(局部注意力),某些形成垂直线(特定token的注意力)。

三、Group-Query Attention:中间地带

3.1 GQA的设计动机

MHA的KV-Cache是主要瓶颈:每个head都需要独立的K和V。当num_heads=32时,KV-Cache是单头情况(num_kv=1)的32倍。但Query头的数量同样影响计算量——减少Query头会直接影响模型表达能力。

GQA(Group-Query Attention)提出了一个优雅的中间方案:Query头数量保持不变(维持表达能力),但将KV头数量从num_heads减少到num_kv_heads(通常为4~16)。每组Query头共享一对KV。

3.2 GQA的配置与效果

# GQA vs MHA vs MQA

MHA (LLaMA-1):
  num_q_heads = 32, num_kv_heads = 32
  KV-Cache: 16 KB/token
  推理速度: baseline

GQA (LLaMA-2, Mistral):
  num_q_heads = 32, num_kv_heads = 8
  KV-Cache: 4 KB/token  (减少 4x)
  推理速度: ~2x faster  (KV-Cache小了,memory带宽压力小)
  质量损失: 可忽略

MQA (some production systems):
  num_q_heads = 32, num_kv_heads = 1
  KV-Cache: 0.5 KB/token  (减少 32x)
  推理速度: ~4x faster
  质量损失: 可能在某些任务上明显

# LLaMA-2 的 GQA 配置:
# 70B: num_q=80, num_kv=8  (10x reduction)
# 7B: num_q=32, num_kv=8  (4x reduction)

# GQA 的核心公式:
# 每个 KV head 被 G = num_q / num_kv 个 Query 组共享
# Q 被分成 num_kv 组,每组 G 个 Query heads
# 组内的 Query heads 共享同一对 K, V

# 计算时:
# K_expanded = K.expand(num_q_heads)  # 复制到每组
# V_expanded = V.expand(num_q_heads)  # 复制到每组
# 然后按标准 MHA 计算

3.3 LLaMA-2的选择:为什么是8个KV头

LLaMA-2 70B选择了num_kv_heads=8,相比LLaMA-1(无GQA)的num_kv_heads=32,这是一个8x的KV-Cache压缩。LLaMA-3进一步扩展到128K上下文,继续使用GQA(num_kv_heads=8)。

选择8而非1的原因是:质量与效率的帕累托最优。实验表明,num_kv_heads=1(MQA)在某些任务(特别是需要细粒度位置感知的任务)上有质量损失,而num_kv_heads=8几乎可以在所有任务上保持MHA的质量水平。

四、Multi-Query Attention:极致效率

4.1 MQA的极致压缩

MQA(Multi-Query Attention)将GQA推向极端:num_kv_heads=1,即所有Query头共享同一对Key和Value。这是1997年以来注意力机制中最重要的单变量优化之一(Shazeer 2019)。

MQA的理论计算量与GQA相同(因为Query头数量不变),但KV-Cache压缩到了极致:32x压缩(相对MHA的32头)。这使得MQA在高并发推理场景(如大批量在线服务)中具有显著优势。

4.2 MQA的局限性

MQA并非没有代价。在训练阶段,MQA的KV共享可能影响模型的表达能力——多个Query头原本可以从不同的KV子空间获取信息,MQA强制它们共享相同的信息源。

实践中的发现:MQA在推理时质量通常接近MHA,但在训练时可能收敛到次优解。因此,MQA更多用于推理优化(如FasterTransformer),而非训练新模型。

五、Multi-Head Latent Attention:DeepSeek的突破

5.1 低秩压缩的核心思想

MLA(Multi-Head Latent Attention,DeepSeek-V2提出)引入了完全不同的优化思路:不是减少KV头的数量,而是对KV进行低秩压缩(Low-Rank Compression)。

其核心洞察是:KV矩阵的信息量其实远低于其维度所暗示的量。Transformer中的KV存储了大量冗余信息——不同位置token的Key向量之间存在高度相关性。通过学习一个低秩的隐空间,可以将KV压缩到远小于原始维度,同时保留绝大部分有效信息。

5.2 MLA的数学实现

# MLA 数学公式

# 标准 MHA/GQA:
# K, V ∈ R^(num_kv_heads × head_dim) per token
# KV-Cache 体积 = num_kv_heads × head_dim × 2 bytes

# MLA: 低秩 KV 压缩
# c_KV = 低秩维度 (latent dim, 如 512)
# W_KV ∈ R^(c_KV × d_model)  # 压缩矩阵
# W_K, W_V ∈ R^(d_model × c_KV)  # 解压矩阵

# 压缩:
c_KV = down_proj(KV) = W_KV @ x  # [seq, c_KV]
# c_KV = 512 << d_model = 4096

# 解压(推理时):
K = W_K @ c_KV  # [seq, num_kv_heads × head_dim]
V = W_V @ c_KV

# KV-Cache 体积对比 (LLaMA-2 70B):
# MHA: 2 × 32 × 128 × 2 = 16 KB/token
# GQA (8 KV heads): 2 × 8 × 128 × 2 = 4 KB/token
# MLA (latent=512): 2 × 512 × 2 = 2 KB/token

# MLA 的优势:
# 1. KV-Cache 比 GQA 更小 (2KB vs 4KB)
# 2. 信息保留比 MQA 更完整 (低秩 vs 硬共享)
# 3. 所有 head 共享压缩表示,信息冗余最小

5.3 MLA vs GQA:工程权衡

MLA和GQA代表了两种不同的KV优化路线。GQA通过减少KV头数量来压缩,思路直观、实现简单,但压缩率受限于num_kv_heads的最小值。MLA通过低秩压缩来优化,思路更抽象,但可以在更细粒度的维度上控制压缩率。

DeepSeek-V2(236B总参数,21B激活)使用MLA + DeepSeekMoE,KV-Cache降低到极致,同时在MMLU上达到77.8%,证明了MLA路线的工程可行性。

六、线性注意力:理论突破

6.1 核函数近似的数学框架

线性注意力的核心思想来自一个数学技巧:如果将softmax近似为一个核函数,就可以将注意力计算重新排列,从O(n²)降为O(n)。

# 线性注意力的核近似

# 标准 softmax attention:
A_ij = softmax(QK^T)_ij = exp(Q_i · K_j) / Σ_j exp(Q_i · K_j)

# 核函数近似: exp(Q_i · K_j) ≈ φ(Q_i)^T ψ(K_j)
# 其中 φ, ψ 是特征映射函数

# 线性注意力:
Attention_linear(Q, K, V) = φ(Q) @ (ψ(K)^T @ V)
                           = φ(Q) @ S  # O(nd)
其中 S = ψ(K)^T @ V 可以递归计算

# 关键性质:
# 1. 计算顺序从 (QK^T)V 变为 Q(K^TV)
# 2. K^TV 可以预计算并递归更新
# 3. 复杂度从 O(n²d) 降为 O(nd)

# 常见核函数:
# Katharopoulos (2020): φ(x) = elu(x)  (exponential linear unit)
# Performer: φ(x) = exp(x)  (需要随机投影保证准确性)
# Linformer: 额外对 K,V 做降维 (K_hat = W_K @ K)

# 线性注意力的局限:
# 1. 无法精确还原 softmax attention 的结果
# 2. 在需要"全局搜索"的任务上表现较差
# 3. 无法动态学习"关注哪些位置"(核函数是固定的)

6.2 线性注意力的适用场景

线性注意力在以下场景表现出色:长序列任务(音频、视频处理,序列长度>>d_model时线性注意力优势明显)、实时流式处理(递归形式适合在线推理)、资源受限场景(边缘设备、手机端侧部署)。

但在需要精确全局注意力的任务(如精确检索、代码生成)中,标准注意力仍然是首选。线性注意力与标准注意力的差距在高选择性任务上尤为明显。

七、Flash Attention:硬件革命

7.1 IO复杂度分析

Flash Attention(Dao et al., 2022, 2023)的核心贡献不是改变注意力算法,而是重新排列计算顺序以优化硬件IO。

在现代GPU上,HBM(High Bandwidth Memory)的读写速度远慢于SRAM(On-chip Shared Memory)的计算速度。标准Attention需要反复读写HBM——对于每个token,需要将Q、K、V矩阵从HBM读到SRAM,计算注意力,再写回HBM。这使得计算单元大部分时间在等待内存IO。

7.2 Tiling算法

Flash Attention通过Tiling(分块)解决了这个问题:将Q、K、V矩阵分成小块(tiles),每次只将一块从HBM读到SRAM,在SRAM上完成该块的注意力计算,然后通过Online Softmax累加到最终结果。

# Flash Attention 2 的分块计算

# 标准 Attention (需要 O(n²) HBM 读写):
for i in range(n):
    for j in range(n):
        read Q[i], K[j], V[j] from HBM  ← IO瓶颈
        compute exp(QK^T) / Σexp(QK^T)
        write partial_result to HBM

# Flash Attention (只读写 O(n) 数据):
# Block size = 64~128 (SRAM 能容纳的大小)
for block_i in blocks(Q):
    for block_j in blocks(K, V):
        read Q_block, K_block, V_block from HBM  ← 一次性读入
        # 在 SRAM 上计算该块的 attention
        online_softmax_update(Q_block, K_block, V_block, state)
        # state 只维护 m(x), l(x), O 三个向量 (长度 n)
        write state to HBM  ← 只有 O(n) 数据

# IO 量对比 (n=4096, d=64, block=64):
# 标准: O(n²d) = 4096² × 64 ≈ 1G elements HBM 读写
# Flash: O(n²d / block_size) ≈ 1G / 64 = 16M elements  ← 约 60x 减少

# Flash Attention 3 (2024) 新增优化:
# 1. Warp Specialization: 不同 warp 处理不同 block,并行度更高
# 2. FP8 Support: 在 H100 FP8 单元上实现
# 3. 交错式 Tiling: 进一步减少 warp 间的同步开销

7.3 Flash Attention的革命性影响

Flash Attention不仅加速了标准Attention(通常2-4x),更重要的是,它使得在有限显存下训练更长序列成为可能。通过减少HBM访问量,Flash Attention可以将序列长度扩展2-4倍(取决于具体配置)。

八、稀疏注意力:选择性注意

8.1 稀疏注意力的统计发现

研究表明,Transformer的注意力模式具有高度的结构性。分析显示,大多数注意力分数集中在少数几个位置上——远距离token之间的注意力权重往往接近零。

这个发现催生了稀疏注意力:与其计算所有位置的注意力分数,不如只计算最有价值的那些位置。关键是确定哪些位置是"有价值的"。

8.2 BigBird与Longformer的方案

BigBird(Google, 2020)和Longformer(AllenAI, 2020)提出了三种稀疏注意力模式的组合:

  • 滑动窗口注意力(Window):每个token只attend到其相邻的w个token(w=512或1024)。符合局部性假设——大多数语言任务依赖局部上下文。
  • 全局注意力(Global):某些特殊token(如[CLS]或[SEP])attend到所有位置。这确保模型不会丢失全局信息。
  • 随机注意力(Random):每个token随机attend到r个位置(r=3~5)。引入随机连接以增强远距离信息传递。
# BigBird 稀疏注意力配置

# 对于每个 token,注意力目标包括:
# 1. 滑动窗口: ±w/2 个相邻 tokens (w = 512)
# 2. 全局 tokens: [CLS] + 随机 r 个 (r = 3)
# 3. 随机 tokens: 随机选择的 r 个 tokens

# 总注意力连接数:
# 窗口: w + 全局: 2r + 随机: r ≈ w + 3r
# 相比全连接的 n: 从 O(n²) 降到 O(n·w)

# Longformer 的 Dilated Attention:
# 滑动窗口之间留有空隙 (dilations)
# layer 0: dilation=1 (每个token attend到相邻token)
# layer 1: dilation=2 (每个token attend到间隔1的token)
# layer 2: dilation=4 (每个token attend到间隔3的token)
# → receptive field 指数增长: 窗口w=512时,3层可达 512×2³ = 4096

九、滑动窗口注意力

9.1 Mistral的实现

Mistral 7B(2023)是最早大规模采用滑动窗口注意力的开源LLM。其核心设计是:使用滑动窗口注意力(Window Attention)处理局部上下文,同时每隔一层插入一个全局注意力层来处理远距离依赖。

9.2 Attention Sink现象

StreamingLLM(Xiao et al., 2023)揭示了一个重要现象:LLM的注意力存在"注意力接收器"(Attention Sink)——某些token(通常是第一个token或初始token)获得了异常高的注意力权重。

这可能是因为第一个token在位置编码上占据特殊地位(作为全局位置信息的锚点),或者是因为它参与了大量的语义结构(如句子开头)。StreamingLLM利用这一发现,只缓存4个"锚点token"(第一个token + 最近的4个token),就可以维持模型质量。

十、膨胀注意力

10.1 膨胀注意力与卷积的类比

膨胀注意力(Dilated Attention)借鉴了膨胀卷积的思想:在标准滑动窗口之间引入空隙(dilations),使单个层的感受野指数级增长,同时保持计算量线性。

与Longformer的Dilated Attention不同,膨胀注意力通常应用于线性注意力框架(如RFormer、Hyena架构),追求理论最优的O(n log n)复杂度。

10.2 Hyena架构的多元多项式路线

Hyena(Manning et al., 2023)提出了另一种绕过注意力二次复杂度的路线:用多层Hankel矩阵乘法来近似长程依赖。这本质上是在频域(Fourier空间)处理信息,比标准注意力更高效。

十一、硬件感知的注意力

11.1 FlashDecoding

FlashDecoding(FlashAttention团队2023年提出)针对推理场景做了专门优化。核心洞察是:在推理时,KV-Cache是已知的,只需要为每个Query向量计算注意力。FlashDecoding将KV分成更大的块,利用并行性加速。

11.2 PagedAttention与vLLM

PagedAttention(vLLM)从操作系统的分页内存管理中汲取灵感,将KV-Cache分成固定大小的"页"(pages),动态分配给不同请求。这解决了传统方案中的内存碎片化问题。

# PagedAttention 的分页管理

# 传统 KV-Cache:
# 每个请求预分配完整的 seq_len 空间
# → 大量内部碎片(实际使用 < 50%)
# → 无法同时运行更多请求

# PagedAttention:
# KV-Cache 被分成固定大小的 blocks (如 16 tokens/block)
# 每个请求动态分配所需数量的 blocks
# blocks 可以非连续存储(通过 block table 映射)

# 示例:
# 请求1: 生成 500 tokens → 分配 32 blocks (16×16=512 capacity)
# 请求2: 生成 300 tokens → 分配 19 blocks (16×19=304 capacity)
# 总 KV-Cache: (32+19) × block_size × kv_dim
# vs 传统: max(500,300) × kv_dim (大量浪费)

# 效果:
# 在相同显存下,vLLM 可以服务的并发请求数提升 2-4x
# 这对在线推理服务至关重要

11.3 Tensor并行与注意力

对于超大规模模型,单GPU无法容纳所有参数。Tensor并行(Megatron-LM)将注意力矩阵按维度切分到多GPU。这带来了新的注意力优化机会:每个GPU只持有部分head的QKV,可以通过AllReduce高效地实现分布式注意力计算。

十二、变体综合对比

12.1 核心指标对比

理解各种注意力变体在不同维度上的权衡,是做出正确架构决策的基础。

变体KV-Cache复杂度计算复杂度质量工程成本适用场景
MHAO(n·h·d)O(n²·d)最高高显存研究基准
GQAO(n·g·d)O(n²·d)≈MHA中等LLaMA-2/3, Mistral
MQAO(n·d)O(n²·d)略低低显存极致推理优化
MLAO(n·c)O(n²·d)≈MHA中等DeepSeek-V2
FlashAttention同上游变体同上游变体无损需CUDA优化所有场景
线性注意力O(n·d)O(n·d)任务相关需特殊实现超长序列
稀疏注意力O(n·w)O(n·w·n)任务相关需稀疏实现长文档处理

注:h=注意力头数, g=KV头数(GQA), c=隐维度(MLA), w=窗口大小, d=隐藏维度, n=序列长度。

12.2 架构选择的决策树

基于以上分析,一个实用的注意力变体选择框架:

  • 通用场景:GQA(num_kv=8)是当前最优选择——LLaMA-2/3、Mistral、Qwen-2等主流模型的一致选择
  • 极致推理效率:MQA或极低num_kv_heads(如2~4),需验证质量损失可接受
  • 超长上下文(>100K):稀疏注意力或线性注意力 + RoPE扩展
  • 训练效率:Flash Attention是必选,与具体变体无关
  • 工程团队能力:MLA需要更复杂的实现,评估团队是否具备CUDA内核开发能力