AI

Transformer 训练和推理

动手深度学习

Posted by LXG on April 27, 2026

模型训练



# Transformer主干超参数:隐藏维、层数、dropout,以及数据批处理设置
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
# 优化与训练轮次配置;设备优先使用GPU
lr, num_epochs, device = 0.005, 10, d2l.try_gpu()
# 前馈网络与多头注意力配置
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
# Q/K/V投影维度(当前实现保持一致)
key_size, query_size, value_size = 32, 32, 32
# LayerNorm归一化维度
norm_shape = [32]

# 加载NMT训练数据与词表
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)

# 根据TRAIN_LOG_CFG决定是否启用“终端+文件”双写日志
enable_file_logging_if_needed()

# 构建编码器与解码器
encoder = TransformerEncoder(
    len(src_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
decoder = TransformerDecoder(
    len(tgt_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
# 组装Seq2Seq整体模型(Encoder-Decoder)
net = d2l.EncoderDecoder(encoder, decoder)
# 执行训练(内部含分层训练日志)
train_seq2seq_with_planned_logs(
    net, train_iter, lr, num_epochs, src_vocab, tgt_vocab, device)

实现了一个 带完整可观测日志的 Transformer Seq2Seq(机器翻译模型)

整体训练流程(首轮)


======================== 训练开始(Epoch 1 / Batch 1) ========================

【输入数据】
src (源语言token id)                    tgt (目标语言token id)
(64, 10)                               (64, 10)
  │                                        │
  │                                        └──► 右移 + <bos>
  │                                             (作用:构造Decoder输入,
  │                                              实现Teacher Forcing)
  │
  ▼
======================== Encoder(语义理解) ========================

Embedding
  │
  ├──► (64,10) → (64,10,32)
  │     (作用:把离散token映射到连续向量空间)
  │
  ▼
× √d_model
  │
  ├──► (作用:放大embedding数值,
  │         防止位置编码被“淹没”)
  │
  ▼
+ Positional Encoding
  │
  ├──► (作用:注入位置信息,
  │         否则模型不知道词序)
  │
  ▼
X0(带语义+位置的信息表示)
  │
  ▼

-------------------- Encoder Block 0 --------------------

Multi-Head Self Attention(自注意力)
  │
  ├──► Q = XWq, K = XWk, V = XWv
  │     (作用:把输入投影到“查询/键/值”空间)
  │
  ├──► QK^T / √d
  │     (作用:计算“相关性/相似度”)
  │
  ├──► softmax
  │     (作用:变成概率分布=注意力权重)
  │
  ├──► attention_weights @ V
  │     (作用:加权汇总全局信息)
  │
  ▼
attn_out(融合上下文后的表示)

AddNorm1(残差 + 归一化)
  │
  ├──► X + attn_out
  │     (作用:保留原信息,避免信息丢失)
  │
  ├──► LayerNorm
  │     (作用:稳定训练,防止梯度爆炸)
  │
  ▼
Y

FFN(前馈网络)
  │
  ├──► dense1 → ReLU → dense2
  │     (作用:引入非线性,提高表达能力)
  │
  ▼
ffn_out

AddNorm2
  │
  ├──► Y + ffn_out
  │     (作用:再次融合 & 稳定)
  │
  ▼
X1(Block0输出)

-------------------- Encoder Block 1 --------------------

(作用完全相同,但语义更抽象、更高级)

X1 → Self-Attn → AddNorm → FFN → AddNorm → X2

  │
  ▼
X_enc(Encoder最终输出)
(作用:整句的“上下文语义表示”)

==================================================================

======================== Decoder(逐词生成) ========================

Decoder 输入(右移后的tgt)
  │
  ▼
Embedding → √d → +位置编码
  │
  ▼
D0

-------------------- Decoder Block 0 --------------------

Masked Self Attention(带mask的自注意力)
  │
  ├──► 只能看过去(下三角mask)
  │     (作用:防止“偷看未来词”)
  │
  ▼
AddNorm1
  │
  ▼

Cross Attention(编码器-解码器注意力)🔥关键
  │
  ├──► Query = Decoder当前状态
  ├──► Key/Value = Encoder输出
  │
  ├──► (作用:在“源句子”中找当前词对应的信息)
  │
  ▼
AddNorm2

FFN
  │
  ├──► (作用:非线性变换,增强表达)
  │
  ▼
AddNorm3
  │
  ▼
D1

-------------------- Decoder Block 1 --------------------

(同上,进一步 refinement)

D1 → MaskedAttn → CrossAttn → FFN → D2

  │
  ▼
======================== 输出层 ========================

Linear(全连接)
  │
  ├──► (64,10,32) → (64,10,vocab_size)
  │
  ├──► (作用:映射到词表空间,得到每个词的概率)
  │
  ▼
logits

Softmax + Loss
  │
  ├──► MaskedSoftmaxCELoss
  │
  ├──► (作用:计算预测与真实token的差异)
  │
  ▼
loss

==================================================================

======================== 反向传播 ========================

loss
  │
  ├──► 反向传播梯度
  │
  ├──► 更新:
  │     - embedding
  │     - attention权重
  │     - FFN参数
  │
  ▼
模型参数更新(学习完成一步)

==================================================================

热力图

demo_03_attention

训练开始时的热力图

demo_03_attention_first

训练结束时的热力图

demo_03_attention_last

预测结果已经完全正确


pred: j'ai froid . <eos>
target: j'ai froid . <eos>
loss: 0.0318

👉 说明:

✔ translation 已收敛 ✔ sequence-level correct ✔ EOS 处理稳定

attention 变“更尖锐”


“froid”这一行:
几乎完全指向 “am”
权重非常集中(接近 one-hot)

👉 说明:

模型已经从“分布式对齐” → “确定性对齐”

attention ≠ reasoning


❗ attention ≠ reasoning

你现在看到:

attention 很漂亮
对齐很稳定
loss 很低

但本质是:

模型在做 lookup table,而不是理解语言

其他曲线

demo_03_train_diagnostics

左上图:Train Loss per Epoch

  • 纵轴:loss/token(每个 token 的平均损失)
  • 含义:整体训练效果的第一指标。如果曲线持续下降,说明模型正在学习;如果进入平台期或反弹,说明可能过拟合或学习卡住了。
  • 理想状态:单调下降(越陡越好),最后趋于平缓且稳定。

右上图:Probe Attention Entropy

  • 纵轴:entropy(0~1)
  • 含义:固定”I am cold”样本的 cross-attention 熵。熵越小,说明模型对这个样本的注意力越集中;熵越大,说明注意力分散。
  • 理想状态:从高向低递减。早期模型不知道看哪儿,注意力很平均(熵大);训练后逐渐学会集中关注(熵小)。
  • 提示:与 loss 曲线要相关。如果 loss 下降但熵没下降,说明模型是靠”死记”一些高频词;如果熵下降但 loss 反弹,说明”过度集中”导致泛化变差。

左下图:Probe Token Accuracy

  • 纵轴:accuracy(0~1)
  • 含义:固定样本在每个 epoch 的逐 token 命中率(预测=真实标签 的比例)。
  • 理想状态:从 0 上升到接近 1。这是最直观的翻译质量指标。
  • 提示:这条线通常在整个图中涨幅最明显。如果它一直在低位,说明模型可能无法很好地学会这个具体翻译对。

右下图:Probe Edit Distance

  • 纵轴:distance(编辑距离,取值为整数)
  • 含义:预测序列与真实序列的 Levenshtein 距离(最少需要多少次插入/删除/替换才能把预测改成真实)。
  • 理想状态:从高向低递减,最后降到 0(预测 = 真实)。
  • 提示:这比逐 token 准确率更宽松:允许位置偏移、少个几个词等,但整体衡量句子级别差距。

模型保存

文件/内容 是否必须 文件形式示例 作用 推理阶段用在哪 备注
模型权重(state_dict) ⭐⭐⭐⭐⭐ 必须 model_weights.pth 存储所有层参数 load_state_dict() 不能缺
模型结构配置(hparams) ⭐⭐⭐⭐⭐ 必须 dict / json 重建模型结构 初始化 Encoder/Decoder 层数/hidden/head
源语言词表(src_vocab) ⭐⭐⭐⭐⭐ 必须 pickle / torch 文本 → token id 输入编码 不可缺
目标语言词表(tgt_vocab) ⭐⭐⭐⭐⭐ 必须 pickle / torch token id → 文本 输出解码 不可缺
tokenizer规则 ⭐⭐⭐⭐ 建议 json 分词规则 文本预处理 你现在是简单split
特殊token定义 ⭐⭐⭐⭐ 建议 json <bos>/<eos>/<pad> 编码/解码控制 很重要
完整checkpoint ⭐⭐⭐⭐⭐ 推荐 xxx_inference_ckpt.pth 一次性打包所有内容 一步加载 你已经实现 👍
ONNX模型 ⭐⭐⭐(部署) model.onnx 跨平台推理 嵌入式/加速 RK3568用
推理脚本 ⭐⭐⭐⭐ 建议 infer.py 串起整个流程 实际调用 工程必须

模型推理


import argparse
from pathlib import Path

import torch
from d2l import torch as d2l

from demo_04 import TransformerEncoder, TransformerDecoder


def resolve_ckpt_path(user_ckpt_arg):
    """解析checkpoint路径:优先用户指定,其次脚本目录与当前目录兜底。

    路径优先级:
    1) --ckpt 显式传入的路径(最优先)
    2) 推理脚本同目录下的 demo_04_inference_ckpt.pth
    3) 当前工作目录下的 demo_04_inference_ckpt.pth
    """
    if user_ckpt_arg:
        return Path(user_ckpt_arg)

    candidates = [
        Path(__file__).with_name("demo_04_inference_ckpt.pth"),
        Path.cwd() / "demo_04_inference_ckpt.pth",
    ]
    for p in candidates:
        if p.exists():
            return p
    return candidates[0]


def build_model_from_hparams(src_vocab, tgt_vocab, hparams):
    """按checkpoint中的超参数重建网络结构。

    hparams 字段逐项说明(与训练侧保持一致):
    - key_size: 注意力中 Key 的投影维度
    - query_size: 注意力中 Query 的投影维度
    - value_size: 注意力中 Value 的投影维度
    - num_hiddens: Transformer隐藏维度(embedding和各层主通道维度)
    - norm_shape: LayerNorm 的归一化维度,通常是 [num_hiddens]
    - ffn_num_input: PositionWiseFFN 输入维度(一般与 num_hiddens 一致)
    - ffn_num_hiddens: PositionWiseFFN 中间隐藏层维度
    - num_heads: 多头注意力头数
    - num_layers: 编码器/解码器堆叠层数
    - dropout: dropout 概率(推理时虽然eval会关闭dropout,但结构需一致)
    """
    # 1) 构建编码器:超参数必须与训练时一致,否则权重shape无法匹配
    encoder = TransformerEncoder(
        len(src_vocab),
        hparams["key_size"],
        hparams["query_size"],
        hparams["value_size"],
        hparams["num_hiddens"],
        hparams["norm_shape"],
        hparams["ffn_num_input"],
        hparams["ffn_num_hiddens"],
        hparams["num_heads"],
        hparams["num_layers"],
        hparams["dropout"],
    )
    # 2) 构建解码器:同样要求结构完全一致
    decoder = TransformerDecoder(
        len(tgt_vocab),
        hparams["key_size"],
        hparams["query_size"],
        hparams["value_size"],
        hparams["num_hiddens"],
        hparams["norm_shape"],
        hparams["ffn_num_input"],
        hparams["ffn_num_hiddens"],
        hparams["num_heads"],
        hparams["num_layers"],
        hparams["dropout"],
    )
    # 3) 封装为 D2L 的 EncoderDecoder 总体模型
    return d2l.EncoderDecoder(encoder, decoder)


def load_model_and_vocab(ckpt_path, device):
    """加载推理checkpoint,返回模型与词表。

    返回值:
    - net: 已加载权重并切到 eval() 的模型
    - src_vocab: 源语言词表(英文)
    - tgt_vocab: 目标语言词表(法文)
    - hparams: 训练时保存的模型超参数
    """
    # PyTorch 2.6 起 torch.load 默认 weights_only=True。
    # 本地训练生成的checkpoint包含词表对象(如 d2l.Vocab),
    # 推理时需要显式关闭 weights_only 才能反序列化完整内容。
    try:
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    except TypeError:
        # 兼容旧版本PyTorch(没有 weights_only 参数)
        ckpt = torch.load(ckpt_path, map_location=device)
    src_vocab = ckpt["src_vocab"]
    tgt_vocab = ckpt["tgt_vocab"]
    hparams = ckpt["hparams"]

    net = build_model_from_hparams(src_vocab, tgt_vocab, hparams)
    net.load_state_dict(ckpt["model_state_dict"])
    net.to(device)
    net.eval()
    return net, src_vocab, tgt_vocab, hparams

    
def run_d2l_demo(ckpt_path, device):
    """按D2L章节示例风格执行多句推理并计算BLEU。

    说明:
    - engs: 待翻译的英文句子(源语言)
    - fras: 对应参考法语(用于计算BLEU,不参与模型推理)
    - d2l.predict_seq2seq:
        输入英文句子后,按自回归方式逐token生成法语结果
    - d2l.bleu(..., k=2):
        计算2-gram BLEU,数值越高表示与参考翻译越接近
    """
    net, src_vocab, tgt_vocab, hparams = load_model_and_vocab(ckpt_path, device)
    engs = ["go .", "i lost .", "he's calm .", "i'm home ."]
    fras = ["va !", "j'ai perdu .", "il est calme .", "je suis chez moi ."]

    for eng, fra in zip(engs, fras):
        translation, _ = d2l.predict_seq2seq(
            net,
            eng,
            src_vocab,
            tgt_vocab,
            hparams["num_steps"],
            device,
            save_attention_weights=True,
        )
        print(f"{eng} => {translation}, bleu {d2l.bleu(translation, fra, k=2):.3f}")


def print_hparams_guide(hparams):
    """把checkpoint中的超参数逐项打印,便于理解推理模型配置。"""
    desc = {
        "num_hiddens": "隐藏维度/主通道维度(embedding与Transformer内部特征维)",
        "num_layers": "编码器层数与解码器层数(均为该值)",
        "dropout": "dropout概率(训练中生效,推理eval后自动关闭)",
        "batch_size": "训练时batch大小(推理单句时不强依赖)",
        "num_steps": "最大时间步/最大序列长度(推理解码上限)",
        "lr": "训练学习率(仅记录,不影响推理)",
        "num_epochs": "训练轮数(仅记录,不影响推理)",
        "ffn_num_input": "前馈网络输入维度",
        "ffn_num_hiddens": "前馈网络隐藏层维度",
        "num_heads": "多头注意力头数",
        "key_size": "注意力Key投影维度",
        "query_size": "注意力Query投影维度",
        "value_size": "注意力Value投影维度",
        "norm_shape": "LayerNorm归一化维度",
    }
    print("[INFER] ====== checkpoint 超参数说明 ======")
    for k, v in hparams.items():
        meaning = desc.get(k, "(未定义说明)")
        print(f"[INFER] {k} = {v}  # {meaning}")
    print("[INFER] =================================")


def main():
    parser = argparse.ArgumentParser(description="demo_04 推理脚本")
    parser.add_argument(
        "--ckpt",
        type=str,
        default=None,
        help="推理checkpoint路径",
    )
    parser.add_argument(
        "--text",
        type=str,
        default="i am cold .",
        help="待翻译英文句子(建议按训练数据风格:小写+空格分词)",
    )
    parser.add_argument(
        "--demo",
        action="store_true",
        help="使用D2L示例句批量推理并输出BLEU",
    )
    parser.add_argument(
        "--show-hparams",
        action="store_true",
        help="启动时打印checkpoint中的超参数及中文解释",
    )
    args = parser.parse_args()

    ckpt_path = resolve_ckpt_path(args.ckpt)
    if not ckpt_path.exists():
        raise FileNotFoundError(
            f"checkpoint不存在: {ckpt_path};请用 --ckpt 指定,例如 "
            f"/home/lxg/code/AI/code/20260425/demo_04_inference_ckpt.pth"
        )

    device = d2l.try_gpu()
    print(f"[INFER] 使用设备: {device}")
    print(f"[INFER] 加载checkpoint: {ckpt_path}")

    # 可选:先打印超参数说明,帮助理解当前推理模型的配置
    if args.show_hparams:
        _, _, _, hparams = load_model_and_vocab(str(ckpt_path), device)
        print_hparams_guide(hparams)

    run_d2l_demo(str(ckpt_path), device)


if __name__ == "__main__":
    main()

模型推理过程


┌──────────────────────────────────────────────┐
│                输入英文句子                  │
│           "i am cold ."                     │
└──────────────────────────────────────────────┘
                      │
                      ▼
┌──────────────────────────────────────────────┐
│           文本预处理(tokenizer)            │
│  lower + split → ["i","am","cold","."]      │
│  + <eos> → ["i","am","cold",".","<eos>"]    │
└──────────────────────────────────────────────┘
                      │
                      ▼
┌──────────────────────────────────────────────┐
│         token → id(src_vocab)              │
│  ["i","am","cold","."]                      │
│      ↓                                      │
│  [17, 25, 98, 4, 2]                         │
└──────────────────────────────────────────────┘
                      │
                      ▼
┌──────────────────────────────────────────────┐
│              padding + 截断                  │
│  → 固定长度 num_steps (如10)                 │
│  [17,25,98,4,2,0,0,0,0,0]                   │
│  valid_len = 5                              │
└──────────────────────────────────────────────┘
                      │
                      ▼
==================== Encoder ====================

┌──────────────────────────────────────────────┐
│           Embedding(词嵌入)                │
│  (1,10) → (1,10,32)                         │
│  每个token变成向量                           │
└──────────────────────────────────────────────┘
                      │
                      ▼
┌──────────────────────────────────────────────┐
│         加位置编码 PositionalEncoding        │
│  加入位置信息(顺序感)                      │
└──────────────────────────────────────────────┘
                      │
                      ▼
┌──────────────────────────────────────────────┐
│      Encoder Block × N(比如2层)            │
│                                              │
│  每一层:                                    │
│   1️⃣ 自注意力(Self-Attention)              │
│      👉 每个词看所有词                        │
│   2️⃣ Add & Norm                             │
│   3️⃣ FFN(逐位置MLP)                       │
│   4️⃣ Add & Norm                             │
└──────────────────────────────────────────────┘
                      │
                      ▼
┌──────────────────────────────────────────────┐
│            Encoder输出 enc_outputs           │
│         shape: (1,10,32)                     │
└──────────────────────────────────────────────┘

==================== Decoder ====================

初始输入:
┌──────────────────────────────────────────────┐
│  dec_input = ["<bos>"]                      │
└──────────────────────────────────────────────┘
                      │
                      ▼

🔁 逐token生成(循环 num_steps 次)

┌──────────────────────────────────────────────┐
│           Embedding + 位置编码               │
└──────────────────────────────────────────────┘
                      │
                      ▼
┌──────────────────────────────────────────────┐
│       Decoder Block × N                      │
│                                              │
│  每一层:                                    │
│                                              │
│  1️⃣ Masked Self-Attention                   │
│     👉 只能看“已经生成的词”                   │
│                                              │
│  2️⃣ Cross Attention(重点!)               │
│     👉 decoder 看 encoder 输出               │
│     👉 学习翻译对齐关系                      │
│                                              │
│  3️⃣ FFN                                     │
└──────────────────────────────────────────────┘
                      │
                      ▼
┌──────────────────────────────────────────────┐
│           线性层(投影到词表)               │
│      (1, t, 32) → (1, t, vocab_size)        │
└──────────────────────────────────────────────┘
                      │
                      ▼
┌──────────────────────────────────────────────┐
│            Softmax + argmax                  │
│         选出概率最大token                    │
└──────────────────────────────────────────────┘
                      │
                      ▼
┌──────────────────────────────────────────────┐
│        输出一个token(例如 "j'")           │
└──────────────────────────────────────────────┘
                      │
                      ▼
┌──────────────────────────────────────────────┐
│   拼接到输入 → ["<bos>", "j'"]              │
│   继续下一轮生成                             │
└──────────────────────────────────────────────┘

🔁 重复直到:

┌──────────────────────────────────────────────┐
│         遇到 <eos> 或 达到最大长度           │
└──────────────────────────────────────────────┘

==================== 输出 ====================

┌──────────────────────────────────────────────┐
│      token id → 文本(tgt_vocab)           │
│  [5, 88, 23, 2] → "j'ai froid ."            │
└──────────────────────────────────────────────┘
                      │
                      ▼
┌──────────────────────────────────────────────┐
│                最终翻译输出                  │
│              "j'ai froid ."                 │
└──────────────────────────────────────────────┘