一、自注意力的数学本质

自注意力(Self-Attention)将序列中每个位置与其他所有位置建立关联,其本质是对序列做全局信息聚合。理解其矩阵运算形式,是掌握Transformer一切优化的前提。

1.1 Scaled Dot-Product Attention

# 自注意力的数学形式
# Attention(Q, K, V) = softmax(QK^T / √d_k) × V

# Q(Query)、K(Key)、V(Value)的来源
# 输入X经过三个不同的线性变换得到
Q = X · W_Q  # (batch, seq_len, d_model) → (batch, seq_len, d_k)
K = X · W_K  # (batch, seq_len, d_model) → (batch, seq_len, d_k)
V = X · W_V  # (batch, seq_len, d_model) → (batch, seq_len, d_v)

# 注意力分数矩阵
scores = Q @ K.transpose(-2, -1) / √d_k
# scores shape: (batch, seq_len, seq_len)
# 每个token对所有其他token的注意力权重

# softmax归一化
attn_weights = softmax(scores, dim=-1)

# 加权求和
output = attn_weights @ V
# output shape: (batch, seq_len, d_v) = (batch, seq_len, d_model)

# ⚠️ 为什么除以 √d_k?
# 点积的方差随d_k增大 → softmax梯度消失
# 除以√d_k使方差稳定在1

二、多头注意力与特征空间分离

2.1 多头注意力的设计动机

# 单头注意力的局限:
# 一次注意力操作只能学到一种"相关性模式"
# 例如:主语-谓语关系、语义相似性、位置关系
# 多头让模型同时学习多种关系

# 多头注意力(Multi-Head Attention)
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # 每个头有独立的W_Q, W_K, W_V
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)  # 输出的线性变换

    def split_heads(self, x, batch_size):
        # 将d_model维度拆分为num_heads个
        # (batch, seq_len, d_model) → (batch, num_heads, seq_len, d_k)
        x = x.reshape(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # 线性变换 + 分头
        Q = self.split_heads(self.W_Q(query), batch_size)
        K = self.split_heads(self.W_K(key), batch_size)
        V = self.split_heads(self.W_V(value), batch_size)

        # 单头注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)

        # 合并多头:(batch, num_heads, seq_len, d_k) → (batch, seq_len, d_model)
        context = context.transpose(1, 2).contiguous()
        context = context.reshape(batch_size, -1, self.num_heads * self.d_k)

        return self.W_O(context)

2.2 位置编码:让序列有序

# 自注意力是位置无关的(Permutation Invariant)
# 需要显式注入位置信息

# Transformer原版的正弦位置编码
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

# 正弦编码的优势:
# ① 任意两个位置的相对距离可线性表示(sin/cos线性组合)
# ② 外推到训练时未见过的序列长度
# ③ 无需学习参数(位置信息硬编码)

三、计算复杂度与Flash Attention

3.1 标准注意力的O(n²)问题

# 标准注意力的计算瓶颈:
# scores = Q @ K^T: (n, d) @ (d, n) → O(n²d)
# 其中n为序列长度,d为模型维度

# 问题:n=8192时,n²d计算量爆炸
# 典型LLM上下文长度:
# GPT-3: n=2048, n²=4M
# GPT-4: n=32K, n²=1B  ← 内存成为瓶颈

# 标准实现的空间复杂度:O(n²)
# attention_weights: (n, n) float16 = 8192² × 2B = 128MB
# 如果batch=32, 128×32=4GB显存仅存注意力矩阵!

# Flash Attention的核心思想:不一次计算完整注意力矩阵
# 分块计算 + 在线softmax → 显存从O(n²)降至O(n)
# 原理:分块计算softmax的分子分母,逐块累加

3.2 Flash Attention V2的实现原理

# Flash Attention V2伪代码(简化)
def flash_attention(Q, K, V, BLOCK_M=64, BLOCK_N=64):
    # Q, K, V分块加载到SRAM(高速缓存)
    # 分块计算注意力,不生成完整矩阵

    # 外循环:遍历Q的块
    for i in range(0, n, BLOCK_M):
        # 加载Q的第i块到SRAM
        Qi = Q[i:i+BLOCK_M]

        # 初始化:row_max和row_sum(用于在线softmax)
        mi = torch.full((BLOCK_M,), -float('inf'), device=Q.device)
        zi = torch.zeros((BLOCK_M,), device=Q.device)
        Oi = torch.zeros((BLOCK_M, d), device=Q.device)

        # 内循环:遍历K的块
        for j in range(0, n, BLOCK_N):
            # 加载K、V的第j块
            Kj = K[j:j+BLOCK_N]
            Vj = V[j:j+BLOCK_N]

            # 计算当前块的注意力
            Sij = Qi @ Kj.T  # (BLOCK_M, BLOCK_N)
            mij = Sij.max(dim=-1).values  # 当前块行最大值

            # 在线softmax更新:
            # m_next = max(m_old, m_new)
            # z_next = z_old * exp(m_old - m_next) + exp(m_new - m_next)
            m_next = torch.maximum(mi, mij)
            z_prev = zi * torch.exp(mi - m_next)
            z_curr = torch.exp(mij - m_next)
            zi = z_prev + z_curr

            # 更新输出
            Pij = torch.exp(Sij - m_next.unsqueeze(-1))
            Oi = (Oi * z_prev.unsqueeze(-1) + Pij @ Vj) / zi.unsqueeze(-1)
            mi = m_next

        Oi_final[i:i+BLOCK_M] = Oi

    return Oi_final

# 性能提升对比(A100, seq_len=4098, batch=1)
# 标准attention: 峰值显存 45GB, 耗时 620ms
# Flash Attention v2: 峰值显存 8GB, 耗时 89ms
# 提速 7x,显存节省 5.6x

四、注意力机制的变种与优化

# ① Grouped Query Attention(GQA)— LLaMA2/3/Mistral使用
# Query分组:h个query头,kv_num个KV头(kv_num < h)
# 减少KV读取次数,适合长上下文

# ② Multi-Query Attention(MQA)— 极端版,所有Query共享1个KV头
# 更激进的显存压缩,但质量可能下降

# ③ Sparse Attention(滑动窗口+全局+随机)
# BigBird/Longformer:在n²注意力中只计算部分位置
# 局部窗口(前后各w个token)+ 少量全局token + 随机采样

# ④ Linear Attention(线性近似)
# 将softmax(QK^T) ≈ φ(Q)φ(K)^T
# 将O(n²d)降至O(nd²),但表达能力受限

# ⑤ Ring Attention(分布式长序列)
# 多GPU协同计算,梯度在环上流动
# 每个GPU计算序列的一个片段的注意力
# 用于百万级token上下文训练