Skip to content

Transformer 架构详解

Transformer 是大语言模型(LLM)的核心架构。本文深入解析其原理和实现。

架构概览

Transformer 由 Vaswani 等人在 2017 年提出,完全基于注意力机制,摒弃了传统的 RNN 和 CNN。

主要组件

  • Self-Attention(自注意力)
  • Multi-Head Attention(多头注意力)
  • Positional Encoding(位置编码)
  • Feed-Forward Network(前馈网络)
  • Layer Normalization(层归一化)
  • Residual Connection(残差连接)

Self-Attention 机制

核心思想

Self-Attention 允许序列中的每个位置关注序列中的所有其他位置,计算上下文相关的表示。

数学公式

Attention(Q, K, V) = softmax(QK^T / √d_k) V

其中:

  • Q (Query): 查询向量
  • K (Key): 键向量
  • V (Value): 值向量
  • d_k: Key 的维度(用于缩放)

计算步骤

  1. 计算注意力分数:Q 与每个 K 做点积
  2. 缩放并归一化:除以 √d_k 后经过 softmax
  3. 加权求和:用注意力权重对 V 加权

Multi-Head Attention

为什么需要多头?

单个注意力头可能只关注一种模式。多头允许模型同时关注不同类型的关系:

  • 语法关系
  • 语义关系
  • 指代关系
  • 等等

实现

python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # Q, K, V 投影层
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # 投影并分割多头
        Q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # 计算注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = torch.softmax(scores, dim=-1)

        # 加权求和
        out = torch.matmul(attn, V)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        return self.out_proj(out)

Positional Encoding

Transformer 没有循环结构,必须显式注入位置信息。

正弦位置编码

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

其中 pos 是位置,i 是维度索引。

python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                           (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

Transformer 架构

Encoder Block

python
class EncoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # 残差 + 层归一化 + 注意力
        attn_out = self.attn(self.norm1(x), mask)
        x = x + self.dropout(attn_out)

        # 残差 + 层归一化 + 前馈网络
        ff_out = self.ffn(self.norm2(x))
        x = x + self.dropout(ff_out)

        return x

Decoder Block

在 Encoder 基础上增加:

  • Encoder-Decoder Attention(交叉注意力)
  • Masked Self-Attention(防止看到未来信息)

关键特性

  1. 并行计算:所有位置同时处理,训练高效
  2. 长程依赖:自注意力可直接建模任意距离的依赖
  3. 可解释性:注意力权重可视化显示模型关注点
  4. 可扩展性:更大模型 = 更好性能(数据充足前提下)

变体与演进

  • BERT:仅 Encoder,双向预训练
  • GPT:仅 Decoder,单向自回归
  • T5:Encoder-Decoder,文本到文本统一框架
  • Vision Transformer (ViT):图像分类应用 Transformer
  • Whisper:音频转文本

性能与复杂度

  • 时间复杂度:O(n²·d),n 为序列长度,d 为特征维度
  • 空间复杂度:O(n²) 存储注意力矩阵
  • 优化技术
    • Sparse Attention(稀疏注意力)
    • Linear Attention(线性注意力)
    • FlashAttention(内存优化)

实践建议

  • 序列长度较长时考虑稀疏或线性注意力
  • 残差连接使用 Pre-LN 或 Post-LN(Pre-LN 更稳定)
  • 初始化使用 Xavier 或正态分布
  • 学习率预热(warmup)对训练稳定很重要

Transformer 是当今 AI 的基石,深入理解它是成为 AI 专家的必经之路!