An Architect's View: Understanding the Design Philosophy of LLMs
The rise of large language models (LLMs) marks a new era for natural language processing. From GPT to LLaMA, from Claude to Qwen, these models differ in scale and training data, but their underlying architectures all build on the Transformer. This article examines the core components and design principles of LLMs from an architect's perspective.
Core Design Principles of LLM Architectures
- Scalability: the architecture scales smoothly from millions to hundreds of billions of parameters
- Parallelism: sequential recurrence is discarded, enabling efficient parallel training
- Context awareness: the attention mechanism captures long-range dependencies
- Generality: a single architecture serves many tasks and modalities
Transformer: The Foundation of LLMs
The Attention Mechanism: The Core Innovation
Attention is the heart of the Transformer: it lets the model dynamically focus on different parts of the input sequence:
import torch
import torch.nn as nn
import math

class ScaledDotProductAttention(nn.Module):
    """
    Scaled dot-product attention:
    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
    """
    def __init__(self, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (batch, ..., seq_len, d_k); an optional heads dim may precede seq_len
            key:   (batch, ..., seq_len, d_k)
            value: (batch, ..., seq_len, d_v)
            mask:  broadcastable to (..., seq_len, seq_len); 0/False marks hidden positions
        Returns:
            output: (batch, ..., seq_len, d_v)
            attention_weights: (batch, ..., seq_len, seq_len)
        """
        d_k = query.size(-1)
        # 1. Attention scores: Q @ K^T, scaled by sqrt(d_k)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        # scores: (..., seq_len, seq_len)
        # 2. Apply the mask (e.g. a causal mask)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        # 3. Normalize with softmax
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # 4. Weighted sum: attention_weights @ V
        output = torch.matmul(attention_weights, value)
        return output, attention_weights
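A minimal, illustrative sanity check of the module above (the tensor sizes are arbitrary, chosen only to show the shapes):

attn = ScaledDotProductAttention(dropout=0.0)
q = torch.randn(2, 10, 64)   # (batch=2, seq_len=10, d_k=64)
k = torch.randn(2, 10, 64)
v = torch.randn(2, 10, 64)
out, weights = attn(q, k, v)
print(out.shape, weights.shape)  # torch.Size([2, 10, 64]) torch.Size([2, 10, 10])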
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention.
    Runs several attention heads in parallel so different heads can
    capture information from different representation subspaces.
    """
    def __init__(self, d_model: int = 512, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension of each head

        # Linear projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(dropout)
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x):
        """
        Split a tensor into heads:
        (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        """
        batch_size, seq_len, _ = x.size()
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """
        Merge the head outputs back together:
        (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        """
        batch_size, _, seq_len, _ = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

    def forward(self, query, key, value, mask=None):
        # 1. Project and split into heads
        Q = self.split_heads(self.W_q(query))
        K = self.split_heads(self.W_k(key))
        V = self.split_heads(self.W_v(value))
        # 2. Apply attention
        attn_output, attn_weights = self.attention(Q, K, V, mask)
        # 3. Merge heads and apply the output projection
        output = self.W_o(self.combine_heads(attn_output))
        return self.dropout(output), attn_weights
Advantages of Multi-Head Attention
- Multiple views: different heads can attend to different feature subspaces
- Parallel computation: heads are independent of one another and run efficiently in parallel
- Expressiveness: combining several heads strengthens the model's representational power (see the shape check after this list)
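To make the head bookkeeping concrete, here is a minimal shape check using the MultiHeadAttention module defined above (sizes are illustrative):

mha = MultiHeadAttention(d_model=512, num_heads=8, dropout=0.0)
x = torch.randn(2, 16, 512)   # (batch, seq_len, d_model)
out, weights = mha(x, x, x)   # self-attention: Q = K = V = x
print(out.shape)      # torch.Size([2, 16, 512])
print(weights.shape)  # torch.Size([2, 8, 16, 16]): one attention map per head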
Positional Encoding: Injecting Sequence Order
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding.
    Gives the model the absolute position of every element in the sequence.
    """
    def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Precompute the encodings
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        # Sine and cosine at geometrically spaced frequencies
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Register as a buffer (not a trainable parameter)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
class RotaryPositionalEmbedding(nn.Module):
    """
    RoPE (Rotary Position Embedding).
    Used by LLaMA and other modern models; encodes relative position
    information by rotating the query and key vectors.
    """
    def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
        super().__init__()
        # Rotation frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        # Precompute the angle table (t must be float for einsum with inv_freq)
        t = torch.arange(max_seq_len).float()
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos()[None, None, :, :])
        self.register_buffer('sin_cached', emb.sin()[None, None, :, :])

    def rotate_half(self, x):
        """Rotate half of the last dimension."""
        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, q, k, seq_len):
        """
        Apply the rotary embedding to Q and K.
        q, k: (batch, num_heads, seq_len, head_dim)
        """
        cos = self.cos_cached[:, :, :seq_len, :]
        sin = self.sin_cached[:, :, :seq_len, :]
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed
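A minimal sketch of how RoPE is wired into attention: rotate Q and K before computing scores (head counts and sizes here are illustrative):

rope = RotaryPositionalEmbedding(dim=64, max_seq_len=128)
q = torch.randn(2, 8, 16, 64)   # (batch, num_heads, seq_len, head_dim)
k = torch.randn(2, 8, 16, 64)
q_rot, k_rot = rope(q, k, seq_len=16)
# Scores computed from the rotated tensors depend on relative offsets.
scores = torch.matmul(q_rot, k_rot.transpose(-2, -1)) / math.sqrt(64)
print(scores.shape)  # torch.Size([2, 8, 16, 16])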
Comparing LLM Architecture Variants
Three Main Architectural Paradigms
| Architecture | Representative models | Attention direction | Typical use |
|---|---|---|---|
| Encoder-only | BERT, RoBERTa, DeBERTa | Bidirectional | Understanding tasks (classification, NER) |
| Decoder-only | GPT, LLaMA, Claude | Causal (unidirectional) | Generation tasks (text generation) |
| Encoder-Decoder | T5, BART, UL2 | Bidirectional encoder, causal decoder | Translation, summarization, transformation |
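The "attention direction" column comes down to the mask each paradigm applies; a small sketch of the difference:

seq_len = 4
# Bidirectional (encoder-style): every position may attend to every other.
bidirectional_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
# Causal (decoder-style): position i may attend only to positions <= i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
print(causal_mask.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])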
Decoder-only Architecture in Detail (the GPT Family)
class DecoderBlock(nn.Module):
    """
    Transformer decoder block:
    masked (causal) self-attention followed by a feed-forward network.
    Cross-attention is omitted, as in decoder-only models.
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Masked (causal) self-attention
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        # Feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Self-attention sublayer (residual connection + layer norm)
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + attn_output)
        # Feed-forward sublayer
        ff_output = self.feed_forward(x)
        x = self.norm2(x + ff_output)
        return x
class GPTModel(nn.Module):
    """
    A GPT-style large language model:
    decoder-only architecture with autoregressive generation.
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        d_ff: int = 3072,
        max_seq_len: int = 1024,
        dropout: float = 0.1,
        use_rope: bool = False
    ):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # Positional information. Note: RoPE must be applied to Q and K inside
        # each attention layer; this sketch only instantiates the module, so
        # keep use_rope=False unless the attention layers are made RoPE-aware.
        if use_rope:
            self.rope = RotaryPositionalEmbedding(d_model // num_heads, max_seq_len)
            self.pos_encoding = None
        else:
            self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
            self.rope = None
        # Stack of Transformer decoder layers
        self.layers = nn.ModuleList([
            DecoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        # Final layer norm
        self.norm = nn.LayerNorm(d_model)
        # Language-model head
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the input embedding shares its matrix with the output projection
        self.token_embedding.weight = self.lm_head.weight
        self.dropout = nn.Dropout(dropout)
        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def create_causal_mask(self, seq_len):
        """
        Build the causal (triangular) mask so the model only sees
        the current position and the positions before it.
        """
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        return ~mask  # True marks positions that may be attended to

    def forward(self, input_ids, labels=None):
        batch_size, seq_len = input_ids.shape
        # Token embedding
        x = self.token_embedding(input_ids)
        x = self.dropout(x)
        # Positional encoding
        if self.pos_encoding:
            x = self.pos_encoding(x)
        # Causal mask
        mask = self.create_causal_mask(seq_len).to(input_ids.device)
        # Run the Transformer layers
        for layer in self.layers:
            x = layer(x, mask)
        x = self.norm(x)
        # Language-model predictions
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Shift so that position t predicts token t+1, then compute cross-entropy
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
        return {'logits': logits, 'loss': loss}
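A toy run of the model above (tiny hyperparameters and random weights, so the output is meaningless; the snippet only demonstrates the forward pass and a greedy decoding loop):

model = GPTModel(vocab_size=1000, d_model=64, num_layers=2, num_heads=4,
                 d_ff=256, max_seq_len=128)
model.eval()
input_ids = torch.randint(0, 1000, (1, 8))       # (batch=1, seq_len=8)
out = model(input_ids, labels=input_ids)
print(out['logits'].shape, out['loss'].item())   # torch.Size([1, 8, 1000]) <loss>

# Greedy autoregressive decoding: append the argmax token and feed the
# sequence back in (no KV cache here, so each step recomputes everything).
with torch.no_grad():
    for _ in range(5):
        logits = model(input_ids)['logits']
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
print(input_ids.shape)  # torch.Size([1, 13])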
Key Optimization Techniques in Modern LLMs
1. Layer Normalization Improvements
class PreNormDecoderBlock(nn.Module):
    """
    Pre-LayerNorm block.
    Compared with Post-LayerNorm, training is more stable and
    tolerates larger learning rates.
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask=None):
        # Pre-Norm: normalize first, then apply the sublayer
        normed = self.norm1(x)
        x = x + self.self_attn(normed, normed, normed, mask)[0]
        x = x + self.feed_forward(self.norm2(x))
        return x
class RMSNorm(nn.Module):
    """
    RMSNorm (Root Mean Square Layer Normalization).
    Used by LLaMA; cheaper to compute than LayerNorm because it
    skips the mean-centering step.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # RMSNorm: x / sqrt(mean(x^2) + eps) * weight
        norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return norm * self.weight
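A quick check that RMSNorm normalizes scale without centering (the input values are illustrative):

rms = RMSNorm(dim=8)
x = torch.randn(2, 4, 8) * 10 + 5   # large scale, nonzero mean
y = rms(x)
# The root-mean-square of each feature vector is ~1 after RMSNorm,
# but the mean is NOT forced to zero, unlike LayerNorm.
print(y.pow(2).mean(-1).sqrt())  # ~1.0 everywhere
print(y.mean(-1))                # generally nonzero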
2. Activation Function: SwiGLU
class SwiGLU(nn.Module):
    """
    SwiGLU activation.
    Used by PaLM, LLaMA 2, and others; tends to outperform ReLU/GELU.
    """
    def __init__(self, dim: int):
        super().__init__()
        # The gated layout uses a hidden size of 2/3 * 4 * dim so the
        # parameter count matches a standard 4*dim feed-forward layer
        self.w1 = nn.Linear(dim, dim * 8 // 3, bias=False)
        self.w2 = nn.Linear(dim, dim * 8 // 3, bias=False)
        self.w3 = nn.Linear(dim * 8 // 3, dim, bias=False)

    def forward(self, x):
        # SwiGLU(x) = (Swish(x W1) ⊙ (x W2)) W3,
        # where Swish(x) = x * sigmoid(βx), usually with β = 1 (i.e. SiLU)
        return self.w3(nn.functional.silu(self.w1(x)) * self.w2(x))
class FeedForwardSwiGLU(nn.Module):
    """
    Feed-forward network using SwiGLU.
    """
    def __init__(self, d_model: int, hidden_dim: int = None, dropout: float = 0.1):
        super().__init__()
        if hidden_dim is None:
            # SwiGLU sizes the hidden layer differently: 2/3 of the usual 4x
            hidden_dim = int(2 / 3 * 4 * d_model)
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.w3(nn.functional.silu(self.w1(x)) * self.w2(x)))
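Why the 2/3 factor: the gated variant has three weight matrices instead of two, so shrinking the hidden size keeps the parameter count level with a classic 4x GELU MLP. A quick count (pure arithmetic on the sizes used above):

d_model = 768
gelu_params = 2 * d_model * (4 * d_model)                 # w_in + w_out
hidden = int(2 / 3 * 4 * d_model)                         # 2048
swiglu_params = 2 * d_model * hidden + hidden * d_model   # w1 + w2 + w3
print(gelu_params, swiglu_params)  # 4718592 4718592: identical here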
3. Grouped Query Attention (GQA)
class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention (GQA).
    Used by LLaMA 2; balances the efficiency of multi-query attention
    (MQA) against the quality of full multi-head attention.
    """
    def __init__(self, d_model: int, num_heads: int, num_kv_heads: int, dropout: float = 0.1):
        super().__init__()
        assert num_heads % num_kv_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = num_heads // num_kv_heads
        self.d_k = d_model // num_heads
        # Q keeps the full number of heads
        self.W_q = nn.Linear(d_model, d_model)
        # K and V use fewer heads
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        batch_size, seq_len, _ = query.size()
        # Projections
        Q = self.W_q(query).view(batch_size, seq_len, self.num_heads, self.d_k)
        K = self.W_k(key).view(batch_size, -1, self.num_kv_heads, self.d_k)
        V = self.W_v(value).view(batch_size, -1, self.num_kv_heads, self.d_k)
        # Repeat K and V so each query head has a matching KV head
        K = K.repeat_interleave(self.num_queries_per_kv, dim=2)
        V = V.repeat_interleave(self.num_queries_per_kv, dim=2)
        # Transpose to (batch, num_heads, seq_len, d_k)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        # Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        output = torch.matmul(attn_weights, V)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(output), attn_weights
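The payoff of GQA is a smaller KV cache: only num_kv_heads worth of K/V is produced per token. A quick comparison using the class above (sizes illustrative):

gqa = GroupedQueryAttention(d_model=512, num_heads=8, num_kv_heads=2, dropout=0.0)
x = torch.randn(2, 16, 512)
out, w = gqa(x, x, x)
print(out.shape)  # torch.Size([2, 16, 512])
# K/V projections emit 2 heads * 64 dims = 128 features instead of 512,
# so a KV cache for this layer would shrink by 4x.
print(gqa.W_k.out_features, gqa.W_q.out_features)  # 128 512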
Inference Optimization Techniques
KV Cache
class KVCache:
    """
    Key/value cache.
    Avoids recomputing K and V for already-generated tokens during
    autoregressive decoding.
    """
    def __init__(self, max_batch_size: int, max_seq_len: int, num_heads: int, head_dim: int):
        self.k_cache = torch.zeros(max_batch_size, num_heads, max_seq_len, head_dim)
        self.v_cache = torch.zeros(max_batch_size, num_heads, max_seq_len, head_dim)
        self.current_len = 0

    def update(self, k: torch.Tensor, v: torch.Tensor):
        """
        Append new entries to the cache.
        k, v shape: (batch, num_heads, seq_len, head_dim)
        """
        seq_len = k.size(2)
        self.k_cache[:, :, self.current_len:self.current_len + seq_len] = k
        self.v_cache[:, :, self.current_len:self.current_len + seq_len] = v
        self.current_len += seq_len

    def get(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the cached K and V up to the current length."""
        return self.k_cache[:, :, :self.current_len], self.v_cache[:, :, :self.current_len]
class OptimizedDecoderBlock(nn.Module):
    """
    Decoder block with KV-cache support.
    During incremental decoding only the new tokens are projected;
    attention then runs against all cached keys and values.
    (The attention is implemented inline here because the plain
    MultiHeadAttention above is not cache-aware.)
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = FeedForwardSwiGLU(d_model, d_ff, dropout)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None, kv_cache=None):
        batch_size, seq_len, _ = x.size()
        normed = self.norm1(x)
        # Project only the new tokens
        q = self.W_q(normed).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        k = self.W_k(normed).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        v = self.W_v(normed).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        if kv_cache is not None:
            # Append the new K/V, then attend over everything seen so far
            kv_cache.update(k, v)
            k, v = kv_cache.get()
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        x = x + self.W_o(output)
        # Feed-forward sublayer (pre-norm)
        x = x + self.ff(self.norm2(x))
        return x
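A sketch of the prefill-then-decode pattern using the two classes above (hyperparameters are illustrative):

block = OptimizedDecoderBlock(d_model=64, num_heads=4, d_ff=128, dropout=0.0)
cache = KVCache(max_batch_size=1, max_seq_len=32, num_heads=4, head_dim=16)

# Prefill: process the whole prompt once, filling the cache.
prompt = torch.randn(1, 8, 64)
causal = torch.tril(torch.ones(8, 8, dtype=torch.bool))
out = block(prompt, mask=causal, kv_cache=cache)

# Decode: each step feeds in ONE new token; the cache supplies the rest,
# so the per-step cost grows linearly with context length instead of
# re-running attention over the full sequence.
for _ in range(3):
    new_token = torch.randn(1, 1, 64)
    out = block(new_token, mask=None, kv_cache=cache)
print(cache.current_len)  # 11 (8 prompt tokens + 3 generated)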
Pitfalls in LLM Architecture Design
- ❌ Ignoring memory bandwidth: large-model inference is typically bound by memory bandwidth, not compute
- ❌ Positional-encoding mistakes: performance degrades sharply when extrapolating beyond the trained context length
- ❌ Numerical instability: training large models demands careful initialization and stabilization techniques
- ❌ Overlooking quantization effects: low-precision quantization can hurt performance on specific tasks
Summary of Architecture Decisions
| Component | Traditional choice | Modern recommendation | Rationale |
|---|---|---|---|
| Normalization | Post-LayerNorm | Pre-RMSNorm | More stable training, cheaper to compute |
| Activation | ReLU/GELU | SwiGLU | Stronger expressiveness |
| Positional encoding | Absolute positional encoding | RoPE | Better length extrapolation |
| Attention | MHA | GQA | Balances efficiency and quality |
| Vocabulary | BPE 32K | SentencePiece 128K+ | Better multilingual coverage |
Conclusion
Designing an LLM architecture is an art of balance. From the original Transformer design to modern refinements such as Pre-Norm, RoPE, SwiGLU, and GQA, every improvement pursues higher efficiency, better performance, and stronger capability.
Understanding these architectural details not only helps us use and optimize today's models more effectively, but also lays the groundwork for designing the next generation of architectures. As hardware and theory advance, LLM architectures will keep evolving, yet the core principles of attention, parallel design, and scalability will remain relevant for a long time.