架构视角:理解LLM的设计哲学

大语言模型(LLM)的崛起标志着自然语言处理进入了一个新时代。从GPT到LLaMA,从Claude到Qwen,这些模型虽然在规模和训练数据上有所不同,但其底层架构都建立在Transformer的基础之上。本文将从架构师视角,深入剖析LLM的核心组件与设计原理。

LLM架构核心设计原则

  • 可扩展性:架构支持从百万到千亿参数的平滑扩展
  • 并行化:摒弃序列依赖,支持高效的并行训练
  • 上下文感知:通过注意力机制捕捉长距离依赖关系
  • 通用性:统一架构适用于多种任务和模态

Transformer:LLM的基石

注意力机制:核心创新

注意力机制是Transformer的灵魂,它让模型能够动态地关注输入序列的不同部分:

import torch
import torch.nn as nn
import math

class ScaledDotProductAttention(nn.Module):
    """
    缩放点积注意力机制
    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
    """
    
    def __init__(self, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (batch, seq_len, d_k)
            key: (batch, seq_len, d_k)
            value: (batch, seq_len, d_v)
            mask: (batch, seq_len, seq_len) 用于屏蔽某些位置
        
        Returns:
            output: (batch, seq_len, d_v)
            attention_weights: (batch, seq_len, seq_len)
        """
        d_k = query.size(-1)
        
        # 1. 计算注意力分数: Q * K^T
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        # scores: (batch, seq_len, seq_len)
        
        # 2. 应用掩码(如因果掩码)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # 3. Softmax归一化
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        # 4. 加权求和: attention_weights * V
        output = torch.matmul(attention_weights, value)
        
        return output, attention_weights


class MultiHeadAttention(nn.Module):
    """
    多头注意力机制
    将注意力机制并行执行多次,捕捉不同子空间的信息
    """
    
    def __init__(self, d_model: int = 512, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model必须能被num_heads整除"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 每个头的维度
        
        # 线性投影层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.attention = ScaledDotProductAttention(dropout)
        self.dropout = nn.Dropout(dropout)
    
    def split_heads(self, x):
        """
        将张量分割成多个头
        (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        """
        batch_size, seq_len, _ = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    
    def combine_heads(self, x):
        """
        合并多个头的输出
        (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        """
        batch_size, _, seq_len, _ = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
    
    def forward(self, query, key, value, mask=None):
        # 1. 线性投影并分头
        Q = self.split_heads(self.W_q(query))
        K = self.split_heads(self.W_k(key))
        V = self.split_heads(self.W_v(value))
        
        # 2. 执行注意力计算
        attn_output, attn_weights = self.attention(Q, K, V, mask)
        
        # 3. 合并头并线性投影
        output = self.W_o(self.combine_heads(attn_output))
        
        return self.dropout(output), attn_weights

多头注意力的优势

  • 多视角表示:不同头可以关注不同的特征子空间
  • 并行计算:各头计算相互独立,可高效并行
  • 表达能力:多个头的组合增强了模型的表达能力

位置编码:注入序列顺序信息

class PositionalEncoding(nn.Module):
    """
    正弦位置编码
    为模型提供序列中每个位置的绝对位置信息
    """
    
    def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        # 预计算位置编码
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        
        # 使用不同频率的正弦和余弦函数
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * 
            (-math.log(10000.0) / d_model)
        )
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # 注册为buffer(不参与训练)
        self.register_buffer('pe', pe.unsqueeze(0))
    
    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class RotaryPositionalEmbedding(nn.Module):
    """
    RoPE (Rotary Position Embedding)
    LLaMA等现代模型使用的位置编码方式
    通过旋转矩阵编码相对位置信息
    """
    
    def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
        super().__init__()
        
        # 计算旋转角度
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        
        # 预计算位置编码
        t = torch.arange(max_seq_len)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        
        self.register_buffer('cos_cached', emb.cos()[None, None, :, :])
        self.register_buffer('sin_cached', emb.sin()[None, None, :, :])
    
    def rotate_half(self, x):
        """旋转张量的一半维度"""
        x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
        return torch.cat((-x2, x1), dim=-1)
    
    def forward(self, q, k, seq_len):
        """
        应用旋转位置编码到Q和K
        """
        cos = self.cos_cached[:, :, :seq_len, :]
        sin = self.sin_cached[:, :, :seq_len, :]
        
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        
        return q_embed, k_embed

LLM架构变体对比

三种主要架构范式

架构类型 代表模型 注意力方向 主要用途
Encoder-only BERT, RoBERTa, DeBERTa 双向 理解任务(分类、NER)
Decoder-only GPT, LLaMA, Claude 因果(单向) 生成任务(文本生成)
Encoder-Decoder T5, BART, UL2 编码器双向,解码器因果 翻译、摘要、转换

Decoder-only架构详解(GPT系列)

class DecoderBlock(nn.Module):
    """
    Transformer解码器块
    包含掩码自注意力、交叉注意力(可选)和前馈网络
    """
    
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        
        # 掩码自注意力(因果注意力)
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        
        # 前馈网络
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        self.norm2 = nn.LayerNorm(d_model)
    
    def forward(self, x, mask=None):
        # 自注意力子层(带残差连接和层归一化)
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + attn_output)
        
        # 前馈子层
        ff_output = self.feed_forward(x)
        x = self.norm2(x + ff_output)
        
        return x


class GPTModel(nn.Module):
    """
    GPT风格的大语言模型
    Decoder-only架构,自回归生成
    """
    
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        d_ff: int = 3072,
        max_seq_len: int = 1024,
        dropout: float = 0.1,
        use_rope: bool = True
    ):
        super().__init__()
        
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        
        # 词嵌入
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        
        # 位置编码
        if use_rope:
            self.rope = RotaryPositionalEmbedding(d_model // num_heads, max_seq_len)
            self.pos_encoding = None
        else:
            self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
            self.rope = None
        
        # Transformer解码器层堆叠
        self.layers = nn.ModuleList([
            DecoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        # 最终层归一化
        self.norm = nn.LayerNorm(d_model)
        
        # 语言模型头
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        
        # 权重共享(输入嵌入与输出投影共享权重)
        self.token_embedding.weight = self.lm_head.weight
        
        self.dropout = nn.Dropout(dropout)
        
        # 初始化权重
        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def create_causal_mask(self, seq_len):
        """
        创建因果(三角)掩码
        确保模型只能看到当前位置及之前的位置
        """
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        return ~mask  # 返回可参加注意力的位置
    
    def forward(self, input_ids, labels=None):
        batch_size, seq_len = input_ids.shape
        
        # 词嵌入
        x = self.token_embedding(input_ids)
        x = self.dropout(x)
        
        # 位置编码
        if self.pos_encoding:
            x = self.pos_encoding(x)
        
        # 创建因果掩码
        mask = self.create_causal_mask(seq_len).to(input_ids.device)
        
        # 通过Transformer层
        for layer in self.layers:
            x = layer(x, mask)
        
        x = self.norm(x)
        
        # 语言模型预测
        logits = self.lm_head(x)
        
        loss = None
        if labels is not None:
            # 计算交叉熵损失
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        
        return {'logits': logits, 'loss': loss}

现代LLM的关键优化技术

1. Layer Normalization改进

class PreNormDecoderBlock(nn.Module):
    """
    Pre-LayerNorm架构
    相比Post-LayerNorm,训练更稳定,可以使用更大的学习率
    """
    
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        
        self.norm1 = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
    
    def forward(self, x, mask=None):
        # Pre-Norm:先归一化,再计算
        x = x + self.self_attn(self.norm1(x), self.norm1(x), self.norm1(x), mask)[0]
        x = x + self.feed_forward(self.norm2(x))
        return x


class RMSNorm(nn.Module):
    """
    RMSNorm (Root Mean Square Layer Normalization)
    LLaMA使用的归一化方式,相比LayerNorm计算更高效
    """
    
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    
    def forward(self, x):
        # RMSNorm: x / sqrt(mean(x^2) + eps) * weight
        norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return norm * self.weight

2. 激活函数:SwiGLU

class SwiGLU(nn.Module):
    """
    SwiGLU激活函数
    PaLM、LLaMA 2等模型使用,相比ReLU/GELU表现更好
    """
    
    def __init__(self, dim: int):
        super().__init__()
        # SwiGLU需要2/3倍的中间维度
        self.w1 = nn.Linear(dim, dim * 8 // 3, bias=False)
        self.w2 = nn.Linear(dim, dim * 8 // 3, bias=False)
        self.w3 = nn.Linear(dim * 8 // 3, dim, bias=False)
    
    def forward(self, x):
        # SwiGLU(x) = Swish(xW) ⊙ (xV) W2
        # 其中 Swish(x) = x * sigmoid(βx),通常β=1
        return self.w3(nn.functional.silu(self.w1(x)) * self.w2(x))


class FeedForwardSwiGLU(nn.Module):
    """
    使用SwiGLU的前馈网络
    """
    
    def __init__(self, d_model: int, hidden_dim: int = None, dropout: float = 0.1):
        super().__init__()
        if hidden_dim is None:
            # SwiGLU使用不同的维度计算
            hidden_dim = int(2 / 3 * 4 * d_model)
        
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        return self.dropout(self.w3(nn.functional.silu(self.w1(x)) * self.w2(x)))

3. 分组查询注意力(GQA)

class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention (GQA)
    LLaMA 2使用,平衡多查询注意力(MQA)的效率和多头注意力的质量
    """
    
    def __init__(self, d_model: int, num_heads: int, num_kv_heads: int, dropout: float = 0.1):
        super().__init__()
        assert num_heads % num_kv_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = num_heads // num_kv_heads
        self.d_k = d_model // num_heads
        
        # Q投影保持num_heads个头
        self.W_q = nn.Linear(d_model, d_model)
        # K、V投影使用较少的头
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, query, key, value, mask=None):
        batch_size, seq_len, _ = query.size()
        
        # 投影
        Q = self.W_q(query).view(batch_size, seq_len, self.num_heads, self.d_k)
        K = self.W_k(key).view(batch_size, -1, self.num_kv_heads, self.d_k)
        V = self.W_v(value).view(batch_size, -1, self.num_kv_heads, self.d_k)
        
        # 扩展K、V以匹配Q的头数
        K = K.repeat_interleave(self.num_queries_per_kv, dim=2)
        V = V.repeat_interleave(self.num_queries_per_kv, dim=2)
        
        # 转置为(batch, num_heads, seq_len, d_k)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        
        # 注意力计算
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        output = torch.matmul(attn_weights, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        return self.W_o(output), attn_weights

推理优化技术

KV缓存

class KVCache:
    """
    键值缓存优化
    避免在自回归生成中重复计算已生成token的K、V
    """
    
    def __init__(self, max_batch_size: int, max_seq_len: int, num_heads: int, head_dim: int):
        self.k_cache = torch.zeros(max_batch_size, num_heads, max_seq_len, head_dim)
        self.v_cache = torch.zeros(max_batch_size, num_heads, max_seq_len, head_dim)
        self.current_len = 0
    
    def update(self, k: torch.Tensor, v: torch.Tensor):
        """
        更新缓存
        k, v shape: (batch, num_heads, seq_len, head_dim)
        """
        seq_len = k.size(2)
        self.k_cache[:, :, self.current_len:self.current_len+seq_len] = k
        self.v_cache[:, :, self.current_len:self.current_len+seq_len] = v
        self.current_len += seq_len
    
    def get(self) -> tuple[torch.Tensor, torch.Tensor]:
        """获取缓存的K、V"""
        return self.k_cache[:, :, :self.current_len], self.v_cache[:, :, :self.current_len]


class OptimizedDecoderBlock(nn.Module):
    """
    带KV缓存优化的解码器块
    """
    
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = FeedForwardSwiGLU(d_model, d_ff, dropout)
        self.norm2 = nn.LayerNorm(d_model)
    
    def forward(self, x, mask=None, kv_cache=None, use_cache=False):
        # 注意力计算
        normed = self.norm1(x)
        attn_output, new_kv_cache = self.self_attn(
            normed, normed, normed, 
            mask=mask,
            kv_cache=kv_cache,
            use_cache=use_cache
        )
        x = x + attn_output
        
        # 前馈网络
        x = x + self.ff(self.norm2(x))
        
        return x, new_kv_cache

LLM架构设计注意事项

  • 忽视内存带宽:大模型推理受限于内存带宽,而非计算
  • 位置编码错误:外推长度超过训练长度时性能急剧下降
  • 数值稳定性:大模型训练需要 careful 的初始化和平滑技术
  • 忽略量化影响:低精度量化可能影响特定任务表现

架构决策总结

组件 传统方案 现代推荐 理由
归一化 Post-LayerNorm Pre-RMSNorm 训练更稳定,计算更快
激活函数 ReLU/GELU SwiGLU 表达能力更强
位置编码 绝对位置编码 RoPE 更好的长度外推
注意力 MHA GQA 平衡效率与质量
词表 BPE 32K SentencePiece 128K+ 多语言支持更好

总结

LLM架构的设计是一门平衡的艺术。从Transformer的基础设计,到Pre-Norm、RoPE、SwiGLU、GQA等现代优化,每一步改进都在追求更高的效率、更好的性能和更强的能力。

理解这些架构细节,不仅能帮助我们更好地使用和优化现有模型,也为未来设计下一代架构提供了基础。随着硬件的发展和理论的突破,LLM架构还将继续演进,但注意力机制、并行化设计、可扩展性这些核心原则将长期适用。