一、自注意力的数学本质
自注意力(Self-Attention)将序列中每个位置与其他所有位置建立关联,其本质是对序列做全局信息聚合。理解其矩阵运算形式,是掌握Transformer一切优化的前提。
1.1 Scaled Dot-Product Attention
# 自注意力的数学形式
# Attention(Q, K, V) = softmax(QK^T / √d_k) × V
# Q(Query)、K(Key)、V(Value)的来源
# 输入X经过三个不同的线性变换得到
Q = X · W_Q # (batch, seq_len, d_model) → (batch, seq_len, d_k)
K = X · W_K # (batch, seq_len, d_model) → (batch, seq_len, d_k)
V = X · W_V # (batch, seq_len, d_model) → (batch, seq_len, d_v)
# 注意力分数矩阵
scores = Q @ K.transpose(-2, -1) / √d_k
# scores shape: (batch, seq_len, seq_len)
# 每个token对所有其他token的注意力权重
# softmax归一化
attn_weights = softmax(scores, dim=-1)
# 加权求和
output = attn_weights @ V
# output shape: (batch, seq_len, d_v) = (batch, seq_len, d_model)
# ⚠️ 为什么除以 √d_k?
# 点积的方差随d_k增大 → softmax梯度消失
# 除以√d_k使方差稳定在1
二、多头注意力与特征空间分离
2.1 多头注意力的设计动机
# 单头注意力的局限:
# 一次注意力操作只能学到一种"相关性模式"
# 例如:主语-谓语关系、语义相似性、位置关系
# 多头让模型同时学习多种关系
# 多头注意力(Multi-Head Attention)
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 每个头有独立的W_Q, W_K, W_V
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model) # 输出的线性变换
def split_heads(self, x, batch_size):
# 将d_model维度拆分为num_heads个
# (batch, seq_len, d_model) → (batch, num_heads, seq_len, d_k)
x = x.reshape(batch_size, -1, self.num_heads, self.d_k)
return x.transpose(1, 2)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# 线性变换 + 分头
Q = self.split_heads(self.W_Q(query), batch_size)
K = self.split_heads(self.W_K(key), batch_size)
V = self.split_heads(self.W_V(value), batch_size)
# 单头注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1)
context = torch.matmul(attn_weights, V)
# 合并多头:(batch, num_heads, seq_len, d_k) → (batch, seq_len, d_model)
context = context.transpose(1, 2).contiguous()
context = context.reshape(batch_size, -1, self.num_heads * self.d_k)
return self.W_O(context)
2.2 位置编码:让序列有序
# 自注意力是位置无关的(Permutation Invariant)
# 需要显式注入位置信息
# Transformer原版的正弦位置编码
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, max_len, d_model)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
# 正弦编码的优势:
# ① 任意两个位置的相对距离可线性表示(sin/cos线性组合)
# ② 外推到训练时未见过的序列长度
# ③ 无需学习参数(位置信息硬编码)
三、计算复杂度与Flash Attention
3.1 标准注意力的O(n²)问题
# 标准注意力的计算瓶颈:
# scores = Q @ K^T: (n, d) @ (d, n) → O(n²d)
# 其中n为序列长度,d为模型维度
# 问题:n=8192时,n²d计算量爆炸
# 典型LLM上下文长度:
# GPT-3: n=2048, n²=4M
# GPT-4: n=32K, n²=1B ← 内存成为瓶颈
# 标准实现的空间复杂度:O(n²)
# attention_weights: (n, n) float16 = 8192² × 2B = 128MB
# 如果batch=32, 128×32=4GB显存仅存注意力矩阵!
# Flash Attention的核心思想:不一次计算完整注意力矩阵
# 分块计算 + 在线softmax → 显存从O(n²)降至O(n)
# 原理:分块计算softmax的分子分母,逐块累加
3.2 Flash Attention V2的实现原理
# Flash Attention V2伪代码(简化)
def flash_attention(Q, K, V, BLOCK_M=64, BLOCK_N=64):
# Q, K, V分块加载到SRAM(高速缓存)
# 分块计算注意力,不生成完整矩阵
# 外循环:遍历Q的块
for i in range(0, n, BLOCK_M):
# 加载Q的第i块到SRAM
Qi = Q[i:i+BLOCK_M]
# 初始化:row_max和row_sum(用于在线softmax)
mi = torch.full((BLOCK_M,), -float('inf'), device=Q.device)
zi = torch.zeros((BLOCK_M,), device=Q.device)
Oi = torch.zeros((BLOCK_M, d), device=Q.device)
# 内循环:遍历K的块
for j in range(0, n, BLOCK_N):
# 加载K、V的第j块
Kj = K[j:j+BLOCK_N]
Vj = V[j:j+BLOCK_N]
# 计算当前块的注意力
Sij = Qi @ Kj.T # (BLOCK_M, BLOCK_N)
mij = Sij.max(dim=-1).values # 当前块行最大值
# 在线softmax更新:
# m_next = max(m_old, m_new)
# z_next = z_old * exp(m_old - m_next) + exp(m_new - m_next)
m_next = torch.maximum(mi, mij)
z_prev = zi * torch.exp(mi - m_next)
z_curr = torch.exp(mij - m_next)
zi = z_prev + z_curr
# 更新输出
Pij = torch.exp(Sij - m_next.unsqueeze(-1))
Oi = (Oi * z_prev.unsqueeze(-1) + Pij @ Vj) / zi.unsqueeze(-1)
mi = m_next
Oi_final[i:i+BLOCK_M] = Oi
return Oi_final
# 性能提升对比(A100, seq_len=4098, batch=1)
# 标准attention: 峰值显存 45GB, 耗时 620ms
# Flash Attention v2: 峰值显存 8GB, 耗时 89ms
# 提速 7x,显存节省 5.6x
四、注意力机制的变种与优化
# ① Grouped Query Attention(GQA)— LLaMA2/3/Mistral使用
# Query分组:h个query头,kv_num个KV头(kv_num < h)
# 减少KV读取次数,适合长上下文
# ② Multi-Query Attention(MQA)— 极端版,所有Query共享1个KV头
# 更激进的显存压缩,但质量可能下降
# ③ Sparse Attention(滑动窗口+全局+随机)
# BigBird/Longformer:在n²注意力中只计算部分位置
# 局部窗口(前后各w个token)+ 少量全局token + 随机采样
# ④ Linear Attention(线性近似)
# 将softmax(QK^T) ≈ φ(Q)φ(K)^T
# 将O(n²d)降至O(nd²),但表达能力受限
# ⑤ Ring Attention(分布式长序列)
# 多GPU协同计算,梯度在环上流动
# 每个GPU计算序列的一个片段的注意力
# 用于百万级token上下文训练