Bahdanau 注意力
原始 Seq2Seq 的问题
在没有注意力的时候:
输入一句话 → 编码器 → 一个固定向量 → 解码器 → 输出一句话
👉 问题:所有信息被压成一个向量(信息瓶颈)
🧠 d2l里的关键一句话总结:
“不是所有输入词对当前输出都有用”
Bahdanau注意力在做什么?
👉 核心思想:每生成一个词,都重新去“看一遍输入句子”
模型训练代码
# -*- coding: utf-8 -*-
import torch
import logging
import os
from torch import nn
from d2l import torch as d2l
# 配置基础日志。若外部项目已经配置日志,这里不会重复覆盖。
if not logging.getLogger().handlers:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger(__name__)
def setup_writable_d2l_data_dir():
"""将 d2l 默认数据目录重定向到当前脚本目录下的可写路径。"""
data_dir = os.path.join(os.path.dirname(__file__), "data")
os.makedirs(data_dir, exist_ok=True)
original_download = d2l.download
def download_to_local(url, folder=data_dir, sha1_hash=None):
# 统一把下载目录固定到 data_dir,避免写入只读目录 ../data。
return original_download(url, folder=folder, sha1_hash=sha1_hash)
d2l.download = download_to_local
logger.info("d2l 数据目录已重定向到: %s", data_dir)
#@save
class AttentionDecoder(d2l.Decoder):
"""带有注意力机制解码器的基本接口"""
def __init__(self, **kwargs):
super(AttentionDecoder, self).__init__(**kwargs)
@property
def attention_weights(self):
raise NotImplementedError
class Seq2SeqAttentionDecoder(AttentionDecoder):
"""带加性注意力(Additive Attention)的Seq2Seq解码器。
主要流程:
1. 先将输入token id做embedding;
2. 每个时间步使用上一时刻隐藏状态作为query,和编码器输出做注意力计算;
3. 将注意力上下文向量与当前输入embedding拼接后送入GRU;
4. 通过全连接层映射到词表维度,得到每一步的预测分布。
"""
def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
dropout=0, **kwargs):
super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
# 兼容不同版本 d2l 的 AdditiveAttention 构造函数:
# 新版常见签名: (num_hiddens, dropout)
# 旧版常见签名: (key_size, query_size, num_hiddens, dropout)
try:
self.attention = d2l.AdditiveAttention(num_hiddens, dropout)
except TypeError:
self.attention = d2l.AdditiveAttention(
num_hiddens, num_hiddens, num_hiddens, dropout)
self.embedding = nn.Embedding(vocab_size, embed_size)
self.rnn = nn.GRU(
embed_size + num_hiddens, num_hiddens, num_layers,
dropout=dropout)
self.dense = nn.Linear(num_hiddens, vocab_size)
# 是否打印 init_state 的详细日志(训练时建议关闭,避免每个 batch 都打印)
self.enable_state_logging = False
# 是否打印解码器每一步的详细 shape 日志(训练时建议关闭,避免日志过多)
self.enable_step_logging = False
# 是否打印“首轮首个 batch”的向量变化(默认关闭,需要时手动打开)
self.enable_first_epoch_vector_logging = False
self._has_logged_first_train_batch = False
logger.info(
"初始化解码器: vocab_size=%d, embed_size=%d, num_hiddens=%d, num_layers=%d, dropout=%s",
vocab_size, embed_size, num_hiddens, num_layers, dropout
)
def init_state(self, enc_outputs, enc_valid_lens, *_args):
# 编码器会返回两部分:
# 1) 每个时间步的输出 outputs,原始形状是 (num_steps, batch_size, num_hiddens)
# 2) 最后一层(及各层)的隐藏状态 hidden_state,形状是 (num_layers, batch_size, num_hiddens)
# 保留可变参数是为了兼容上层接口,此处不使用。
_ = _args
outputs, hidden_state = enc_outputs
if self.enable_state_logging:
logger.info(
"init_state输入: enc_outputs=%s, hidden_state=%s, enc_valid_lens=%s",
tuple(outputs.shape), tuple(hidden_state.shape), enc_valid_lens
)
# 后续注意力计算希望编码器输出按 batch 放在第一维,
# 所以把 outputs 从 (num_steps, batch_size, num_hiddens)
# 调整成 (batch_size, num_steps, num_hiddens)。
state = (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)
if self.enable_state_logging:
logger.info(
"init_state输出: permuted_enc_outputs=%s, hidden_state=%s",
tuple(state[0].shape), tuple(state[1].shape)
)
return state
def forward(self, X, state):
"""单次前向传播。
参数:
X: 解码器输入token id,形状(batch_size, num_steps)
state: (enc_outputs, hidden_state, enc_valid_lens)
返回:
outputs: 词表预测,形状(batch_size, num_steps, vocab_size)
new_state: 更新后的解码器状态
"""
# state 中包含:
# enc_outputs: 编码器所有时间步输出,形状 (batch_size, num_steps, num_hiddens)
# hidden_state: 当前解码器隐藏状态,形状 (num_layers, batch_size, num_hiddens)
enc_outputs, hidden_state, enc_valid_lens = state
if self.enable_step_logging:
logger.info(
"forward开始: X=%s, enc_outputs=%s, hidden_state=%s",
tuple(X.shape), tuple(enc_outputs.shape), tuple(hidden_state.shape)
)
# 先把 token id 映射成词向量,再转成“按时间步遍历”更方便的形状:
# (batch_size, num_steps) -> embedding -> (batch_size, num_steps, embed_size)
# -> permute -> (num_steps, batch_size, embed_size)
X = self.embedding(X).permute(1, 0, 2)
if self.enable_step_logging:
logger.info("embedding后X=%s", tuple(X.shape))
# 只在训练开始阶段打印一次“向量变化”,帮助理解训练中的数值流动。
trace_vectors = (
self.training
and self.enable_first_epoch_vector_logging
and not self._has_logged_first_train_batch
)
if trace_vectors:
logger.info("首轮首batch向量追踪开始(仅打印一次)")
outputs, self._attention_weights = [], []
for step, x in enumerate(X):
token_embed = x
# 用“当前时刻的解码器隐藏状态”作为 query 去做注意力检索。
# 这里只取最后一层隐藏状态 hidden_state[-1],并扩维成 (batch_size, 1, num_hiddens)。
query = torch.unsqueeze(hidden_state[-1], dim=1)
prev_h = hidden_state[-1]
# context 是从编码器输出中“加权汇总”得到的上下文向量,
# 形状为 (batch_size, 1, num_hiddens)。
context = self.attention(
query, enc_outputs, enc_outputs, enc_valid_lens)
# 把上下文向量和当前输入词向量拼接,作为 GRU 的输入特征。
x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
# GRU 期望输入是 (num_steps, batch_size, input_size),
# 这里单步输入,所以 num_steps=1,形状变为 (1, batch_size, embed_size + num_hiddens)。
out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
outputs.append(out)
self._attention_weights.append(self.attention.attention_weights)
if self.enable_step_logging:
logger.info(
"step=%d: query=%s, context=%s, rnn_out=%s, hidden_state=%s",
step, tuple(query.shape), tuple(context.shape),
tuple(out.shape), tuple(hidden_state.shape)
)
if trace_vectors:
embed_sample = token_embed[0].detach()
context_sample = context[0, 0].detach()
prev_h_sample = prev_h[0].detach()
new_h_sample = hidden_state[-1, 0].detach()
logger.info(
(
"vector_step=%d | "
"embed_norm=%.4f, context_norm=%.4f, "
"h_prev_norm=%.4f, h_new_norm=%.4f, "
"delta_h_norm=%.4f | "
"embed_head=%s | context_head=%s | h_new_head=%s"
),
step,
embed_sample.norm().item(),
context_sample.norm().item(),
prev_h_sample.norm().item(),
new_h_sample.norm().item(),
(new_h_sample - prev_h_sample).norm().item(),
embed_sample[:4].cpu().tolist(),
context_sample[:4].cpu().tolist(),
new_h_sample[:4].cpu().tolist(),
)
# 把每个时间步的 GRU 输出拼起来,再通过线性层映射到词表大小,
# 得到每个时间步对每个词的打分 logits,形状是 (num_steps, batch_size, vocab_size)。
outputs = self.dense(torch.cat(outputs, dim=0))
final_outputs = outputs.permute(1, 0, 2)
if self.enable_step_logging:
logger.info(
"forward结束: logits=%s, final_outputs=%s, attention_steps=%d",
tuple(outputs.shape), tuple(final_outputs.shape),
len(self._attention_weights)
)
if trace_vectors:
self._has_logged_first_train_batch = True
logger.info("首轮首batch向量追踪结束")
return final_outputs, [enc_outputs, hidden_state, enc_valid_lens]
@property
def attention_weights(self):
return self._attention_weights
class TrainSeq2SeqNetAdapter(nn.Module):
"""兼容不同 d2l 版本返回值差异的训练适配器。
某些版本的 EncoderDecoder.forward 只返回 Y_hat,
但 train_seq2seq 期望拿到 (Y_hat, state)。
这里统一转换为二元组,避免训练阶段解包报错。
"""
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, enc_X, dec_X, *args):
out = self.model(enc_X, dec_X, *args)
if isinstance(out, tuple):
return out
return out, None
# ===== 训练配置 =====
# embed_size: 词向量维度;num_hiddens: 隐藏层维度;num_layers: RNN层数;dropout: 随机失活比例
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
# batch_size: 每次喂给模型多少样本;num_steps: 每个样本截断/填充后的时间步长度
batch_size, num_steps = 64, 10
# lr: 学习率;num_epochs: 训练轮数;device: 自动选择 GPU(可用时) 否则 CPU
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()
logger.info(
"训练配置: embed_size=%d, num_hiddens=%d, num_layers=%d, dropout=%.3f, "
"batch_size=%d, num_steps=%d, lr=%.4f, num_epochs=%d, device=%s",
embed_size, num_hiddens, num_layers, dropout,
batch_size, num_steps, lr, num_epochs, device
)
setup_writable_d2l_data_dir()
# 读取机器翻译数据集,返回:
# train_iter: 训练数据迭代器(按 batch 提供 src/tgt)
# src_vocab: 源语言词表;tgt_vocab: 目标语言词表
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
logger.info("数据集加载完成: src_vocab_size=%d, tgt_vocab_size=%d", len(src_vocab), len(tgt_vocab))
sample_batch = next(iter(train_iter))
logger.info(
"首个batch形状: src=%s, src_valid_len=%s, tgt=%s, tgt_valid_len=%s",
tuple(sample_batch[0].shape), tuple(sample_batch[1].shape),
tuple(sample_batch[2].shape), tuple(sample_batch[3].shape)
)
# 编码器负责把源语言序列编码成上下文表示。
encoder = d2l.Seq2SeqEncoder(
len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
# 解码器基于注意力机制,按时间步生成目标语言序列。
decoder = Seq2SeqAttentionDecoder(
len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
# 需要观察 state 初始化细节时可打开(训练时默认关闭,避免日志过于庞大):
# decoder.enable_state_logging = True
# 需要观察单步注意力细节时可打开(训练时默认关闭,避免日志过于庞大):
# decoder.enable_step_logging = True
# 打开“首轮首个 batch”向量变化打印(推荐:理解训练过程时开启)
decoder.enable_first_epoch_vector_logging = True
# 将编码器和解码器封装成完整的 Encoder-Decoder 网络。
net = d2l.EncoderDecoder(encoder, decoder)
train_net = TrainSeq2SeqNetAdapter(net)
# 某些模块(如 LazyModule)在首个 forward 前参数尚未初始化,直接 numel() 会报错。
# 这里做“安全统计”:只统计已初始化参数,并把未初始化参数个数单独打日志。
num_params = 0
uninitialized_count = 0
for p in train_net.parameters():
if not p.requires_grad:
continue
try:
num_params += p.numel()
except ValueError:
uninitialized_count += 1
logger.info(
"模型构建完成: initialized_trainable_params=%d, uninitialized_params=%d",
num_params, uninitialized_count
)
# 启动训练:内部会完成前向、损失计算、反向传播和参数更新。
logger.info("开始训练...")
d2l.train_seq2seq(train_net, train_iter, lr, num_epochs, tgt_vocab, device)
d2l.plt.show()
logger.info("训练结束。")
时间步理解
时间步 t=0 t=1 t=2 t=3
--------------------------------------------------
输入 <bos> 我 爱 你
│ │ │ │
▼ ▼ ▼ ▼
emb emb emb emb
│ │ │ │
├────┐ ├────┐ ├────┐ ├────┐
▼ │ ▼ │ ▼ │ ▼ │
Attention Attention Attention Attention
│ │ │ │
context₀ context₁ context₂ context₃
│ │ │ │
└──concat──┴──concat─┴──concat─┴──concat
│
GRU(共享参数)
│
hidden₁ → hidden₂ → hidden₃ → hidden₄
│
dense
│
输出 我 爱 你 <eos>
对比没有注意力的seq2seq
| 维度 | 无注意力 Seq2Seq | 带注意力(Bahdanau)Seq2Seq |
|---|---|---|
| 信息来源 | 只用 encoder 最后一个 hidden | 每个时间步动态读取所有 encoder 输出 |
| 上下文(context) | 固定一个向量 | 每一步重新计算 |
| Query | ❌ 没有 | decoder 当前 hidden |
| Key/Value | ❌ 没有 | encoder 全部输出 |
| 是否有对齐能力 | ❌ 没有 | ✅ 有(attention权重) |
| 长句效果 | ❌ 容易崩 | ✅ 明显更好 |
| 计算量 | ✅ 低 | ❌ 略高 |
| 参数量 | 少 | 稍多(attention MLP) |
| 工业使用 | ❌ 基本淘汰 | ⚠️ 过渡方案(之后是 Transformer) |
Attention 的本质,就是让模型从“记忆驱动”变成“检索驱动”
保存训练结束后的模型
# 保存到 checkpoint 的训练/模型配置,后续可用于复现实验或重建模型结构。
config = {
"embed_size": embed_size, # 词向量维度
"num_hiddens": num_hiddens, # 隐藏层维度
"num_layers": num_layers, # RNN层数
"dropout": dropout, # 随机失活比例
"batch_size": batch_size, # 每次喂给模型多少样本
"num_steps": num_steps, # 每个样本截断/填充后的时间步长度
"lr": lr, # 学习率
"num_epochs": num_epochs, # 训练轮数
"device": str(device), # 设备类型(GPU 或 CPU)
}
# 训练完成后保存 checkpoint,便于后续继续训练或直接做推理。
torch.save(
{
# 仅保存参数权重(推荐做法),加载时需先构建同结构模型再 load_state_dict。
"model": net.state_dict(),
# 同时保存源语言和目标语言词表,保证推理时 token/id 映射一致。
"src_vocab": src_vocab,
"tgt_vocab": tgt_vocab,
# 保存关键超参数配置,方便恢复训练环境与模型定义。
"config": config,
},
# checkpoint 文件名;可按需改成带时间戳或轮次的名字。
"checkpoint.pth",
)
logger.info("模型已保存到 checkpoint.pth")
使用保存后的模型推理
# -*- coding: utf-8 -*-
import argparse
import logging
from typing import List
import torch
from torch import nn
from d2l import torch as d2l
if not logging.getLogger().handlers:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger(__name__)
class AttentionDecoder(d2l.Decoder):
"""带有注意力机制解码器的基础接口。"""
def __init__(self, **kwargs):
super(AttentionDecoder, self).__init__(**kwargs)
@property
def attention_weights(self):
raise NotImplementedError
class Seq2SeqAttentionDecoder(AttentionDecoder):
"""与训练脚本一致的注意力解码器定义,用于加载 checkpoint。"""
def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
# 兼容不同版本 d2l 的 AdditiveAttention 构造参数。
try:
self.attention = d2l.AdditiveAttention(num_hiddens, dropout)
except TypeError:
self.attention = d2l.AdditiveAttention(
num_hiddens, num_hiddens, num_hiddens, dropout
)
self.embedding = nn.Embedding(vocab_size, embed_size)
self.rnn = nn.GRU(
embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout
)
self.dense = nn.Linear(num_hiddens, vocab_size)
self._attention_weights = []
def init_state(self, enc_outputs, enc_valid_lens, *_args):
_ = _args
outputs, hidden_state = enc_outputs
return outputs.permute(1, 0, 2), hidden_state, enc_valid_lens
def forward(self, X, state):
enc_outputs, hidden_state, enc_valid_lens = state
X = self.embedding(X).permute(1, 0, 2)
outputs, self._attention_weights = [], []
for x in X:
query = torch.unsqueeze(hidden_state[-1], dim=1)
context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
outputs.append(out)
self._attention_weights.append(self.attention.attention_weights)
outputs = self.dense(torch.cat(outputs, dim=0))
return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]
@property
def attention_weights(self):
return self._attention_weights
def load_model_from_checkpoint(ckpt_path: str, device: torch.device):
"""从 checkpoint 加载词表、配置和模型参数。"""
# 兼容 PyTorch 2.6+ 默认 weights_only=True 的行为变化。
# 该 checkpoint 来自本地训练流程,属于可信来源,因此允许完整反序列化。
try:
checkpoint = torch.load(ckpt_path, map_location=device)
except Exception:
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
config = checkpoint["config"]
src_vocab = checkpoint["src_vocab"]
tgt_vocab = checkpoint["tgt_vocab"]
encoder = d2l.Seq2SeqEncoder(
len(src_vocab),
config["embed_size"],
config["num_hiddens"],
config["num_layers"],
config["dropout"],
)
decoder = Seq2SeqAttentionDecoder(
len(tgt_vocab),
config["embed_size"],
config["num_hiddens"],
config["num_layers"],
config["dropout"],
)
net = d2l.EncoderDecoder(encoder, decoder)
net.load_state_dict(checkpoint["model"])
net.to(device)
net.eval()
logger.info("已加载模型: %s", ckpt_path)
return net, src_vocab, tgt_vocab, config
def run_inference(
ckpt_path: str,
sentences: List[str],
device: torch.device,
num_steps: int = None,
):
net, src_vocab, tgt_vocab, model_config = load_model_from_checkpoint(ckpt_path, device)
infer_num_steps = num_steps if num_steps is not None else model_config["num_steps"]
logger.info("checkpoint配置: num_steps=%s", model_config.get("num_steps"))
logger.info("开始推理: device=%s, num_steps=%d, 样本数=%d", device, infer_num_steps, len(sentences))
for text in sentences:
translation, _ = d2l.predict_seq2seq(
net, text, src_vocab, tgt_vocab, infer_num_steps, device, True
)
print(f"{text} => {translation}")
def parse_args():
parser = argparse.ArgumentParser(description="加载 checkpoint.pth 并执行翻译推理")
parser.add_argument(
"--ckpt",
default="/home/lxg/code/AI/code/20260423/checkpoint.pth",
help="checkpoint 文件路径",
)
parser.add_argument(
"--num_steps",
type=int,
default=None,
help="推理时最大时间步;不传则使用 checkpoint 中保存的配置",
)
parser.add_argument(
"--text",
nargs="*",
default=None,
help="待翻译英文句子,可传多个;不传则使用内置样例",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
device = d2l.try_gpu()
texts = args.text if args.text else ["go .", "i lost .", "he's calm .", "i'm home ."]
run_inference(args.ckpt, texts, device, args.num_steps)
运行结果
2026-04-23 14:53:39,874 | INFO | __main__ | 已加载模型: /home/lxg/code/AI/code/20260423/checkpoint.pth
2026-04-23 14:53:39,874 | INFO | __main__ | checkpoint配置: num_steps=10
2026-04-23 14:53:39,874 | INFO | __main__ | 开始推理: device=cuda:0, num_steps=10, 样本数=4
go . => va !
i lost . => j'ai perdu .
he's calm . => il est riche .
i'm home . => je suis chez moi .
最小推理依赖清单
project/
│
├── checkpoint.pth ✅ 模型权重+词表+配置
├── model.py ✅ 模型结构(Encoder + Decoder)
├── inference.py ✅ 推理脚本(你这份)
│
└── requirements.txt
开源模型开源的是什么
| 模块 | 是否必须开源 | 作用 |
|---|---|---|
| model | ⭐ 必须 | 定义网络结构 |
| weights | ⭐ 必须 | 模型参数 |
| tokenizer | ⭐ 必须 | 文本处理 |
| config | ⭐ 必须 | 重建模型 |
| inference | ⭐ 建议 | 快速使用 |
| training | 🟡 可选 | 复现训练 |
| evaluation | 🟡 可选 | 测试模型 |
| data | 🔴 通常不 | 数据太大/版权 |
| export | 🟡 可选 | 部署 |
| README | ⭐ 必须 | 使用说明 |
| requirements | ⭐ 必须 | 环境 |
多头注意力
多头注意力 = 同一段话,让多个“不同专家”同时去理解,然后把他们的理解合起来
每个专家都在看同一句话,但关注点不同:
| 专家 | 关注点 |
|---|---|
| 专家1 | 谁爱谁(语义) |
| 专家2 | 主谓关系(语法) |
| 专家3 | 从句结构(who is singing) |
| 专家4 | 远距离依赖 |
| 专家5 | 局部词关系 |
👉 最后:把所有人的理解拼起来 → 得到更完整的理解
代码
import math
import torch
from torch import nn
from d2l import torch as d2l
#@save
class MultiHeadAttention(nn.Module):
"""多头注意力"""
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
# 头数h。多头注意力的核心思想:把同一个表示空间拆成h个子空间并行计算,
# 让不同头关注不同位置/关系,再把结果拼回去。
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
# 下面四个线性层的作用:
# 1) W_q/W_k/W_v:把输入映射到同一隐藏维num_hiddens,方便做点积注意力
# 2) W_o:把多头拼接后的结果再做一次融合映射
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
def forward(self, queries, keys, values, valid_lens):
# queries,keys,values的形状:
# (batch_size,查询或者“键-值”对的个数,num_hiddens)
# valid_lens 的形状:
# (batch_size,)或(batch_size,查询的个数)
#
# 下面三行先做线性投影,再把“多头维”拆出来:
# 原始: (B, N, H)
# 拆头后: (B*h, N, H/h)
# 这样可以把每个头当作一个独立样本,直接并行送进同一个点积注意力模块。
# 经过变换后,输出的queries,keys,values 的形状:
# (batch_size*num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
# 在轴0,将第一项(标量或者矢量)复制num_heads次,
# 然后如此复制第二项,然后诸如此类。
# 原因:我们把batch维从B扩成了B*h,每个样本对应h个头,
# 所以mask长度也要复制h份,保证每个头用到同样的有效长度约束。
valid_lens = torch.repeat_interleave(
valid_lens, repeats=self.num_heads, dim=0)
# output的形状:(batch_size*num_heads,查询的个数,
# num_hiddens/num_heads)
output = self.attention(queries, keys, values, valid_lens)
# output_concat的形状:(batch_size,查询的个数,num_hiddens)
# 将各头结果从(B*h, N, H/h)还原并拼接到最后一维,得到(B, N, H)。
output_concat = transpose_output(output, self.num_heads)
# 最后再过W_o,让不同头的信息进行线性融合。
return self.W_o(output_concat)
#@save
def transpose_qkv(X, num_heads):
"""为了多注意力头的并行计算而变换形状"""
# 记号说明:
# B = batch_size, N = 序列长度(查询数或键值对数), H = num_hiddens, h = num_heads
# 目标:把(B, N, H)变成(B*h, N, H/h),便于并行计算每个头。
# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
# num_hiddens/num_heads)
# 这一步把最后一维H拆成(h, H/h)。
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
# 交换维度,把“头维”提前,方便下一步与batch合并。
X = X.permute(0, 2, 1, 3)
# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
# 合并(B, h) -> (B*h),让每个头都像一个独立样本进入attention。
return X.reshape(-1, X.shape[2], X.shape[3])
#@save
def transpose_output(X, num_heads):
"""逆转transpose_qkv函数的操作"""
# 输入X: (B*h, N, H/h)
# 先还原成(B, h, N, H/h)...
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
# ...再交换成(B, N, h, H/h)...
X = X.permute(0, 2, 1, 3)
# ...最后拼接头维,得到(B, N, H)。
return X.reshape(X.shape[0], X.shape[1], -1)
#@save
def transpose_qkv(X, num_heads):
"""为了多注意力头的并行计算而变换形状"""
# 与上方同名函数一致:将(B, N, H)拆成(B*h, N, H/h)以并行计算每个头。
# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
# num_hiddens/num_heads)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
X = X.permute(0, 2, 1, 3)
# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
# num_hiddens/num_heads)
return X.reshape(-1, X.shape[2], X.shape[3])
#@save
def transpose_output(X, num_heads):
"""逆转transpose_qkv函数的操作"""
# 与上方同名函数一致:把(B*h, N, H/h)还原为(B, N, H)。
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
attention.eval()
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
# 这里是“自注意力”形式:Q来自X,K/V都来自Y(若X==Y则是标准自注意力)。
# 输出形状应为(B, 查询数, H) => (2, 4, 100)
attention(X, Y, Y, valid_lens).shape
多头注意力参数
| 参数 | 单头 | 多头 |
|---|---|---|
| W_q | d×d | d×d |
| W_k | d×d | d×d |
| W_v | d×d | d×d |
| W_o | ❌ | d×d |
| 总计 | 3个矩阵 | 4个矩阵 |
代码里确实只有一个 MultiHeadAttention 对象,但“多头”是通过张量变形(reshape)实现的,而不是多个对象
理论和实现的差异
| 维度 | 理论多头(概念层) | 实际代码实现(你这段) | 关键意义 |
|---|---|---|---|
| Head数量 | h 个独立 attention | 1 个 attention + reshape | 用“数据维度”模拟多个头 |
| Q/K/V 投影 | 每个 head 一套 (W_q^i, W_k^i, W_v^i) | 一套大矩阵 (W_q, W_k, W_v),再切分 | 参数共享 + 分块 |
| 输入数据 | 每个 head 输入同一 X | 同一 X → 线性变换 → reshape | 同源不同视角 |
| 数据拆分 | 手动分给每个 head | reshape + permute |
自动并行拆分 |
| attention计算 | h 次独立计算(for循环) | 一次计算(batch×head) | 🚀 并行加速 |
| head并行方式 | 逻辑并行(概念上) | 物理并行(矩阵运算/GPU) | 真正利用硬件 |
| 中间shape | h 个:(batch, seq, d/h) | (batch×h, seq, d/h) | 核心技巧 |
| 结果合并 | concat(h₁,…,hₕ) | transpose_output |
恢复原维度 |
| 输出融合 | 可能有融合层 | 必有 (W_o) | 融合多头信息 |
| 参数规模 | 看起来是 h 倍 | 实际 ≈ 1 倍(+W_o) | 高效设计 |
| 计算复杂度 | h × Attention | ≈ 单次大矩阵运算 | 更高吞吐 |
| 内存访问 | 多次调用 | 连续内存块操作 | cache友好 |
| 代码复杂度 | 高(多模块) | 低(统一模块) | 易实现 |
| 可扩展性 | 增head需加模块 | 改 num_heads 即可 | 灵活 |
| GPU利用率 | 低(循环) | 高(并行) | 核心优势 |
自注意力和位置编码
自注意力
自注意力(Self-Attention)= 句子里的每个词,都会去“看”同一句子里的其他词,决定该关注谁
import math
import torch
from torch import nn
from d2l import torch as d2l
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
attention.eval()
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
attention(X, X, X, valid_lens).shape
这段代码其实就是一个标准“多头自注意力(Multi-Head Self-Attention)”的最小示例。
为什么这是“自注意力”?
关键在这里👇
attention(X, X, X, valid_lens)
❗Q、K、V 全是 X
Q = X
K = X
V = X
👉 含义:
每个词都在“看自己这句话里的所有词”
这段代码在干嘛
🧾 输入
X = torch.ones((2, 4, 100))
👉 表示:
batch = 2(两句话)
每句 4 个词
每个词 100维
🧠 自注意力在做什么?
对每一句话里的每个词:
“我应该关注句子里的谁?”
👉 举个直觉例子:
第1个词:
看第1个词 ✔
看第2个词 ✔
看第3个词 ✔
看第4个词 ✔
👉 然后加权融合
shape流动
Step1️⃣ 输入
(2, 4, 100)
Step2️⃣ 多头拆分(内部发生)
num_heads = 5
👉 变成:
(2 × 5, 4, 20) = (10, 4, 20)
👉 含义:
5个头 → 变成10个“伪batch”
Step3️⃣ 每个 head 做 attention
(10, 4, 20) → (10, 4, 20)
Step4️⃣ 拼回来
(2, 4, 100)
自注意力 (Self-Attention)
- 代码特征:attention(X, X, X)
- 信息源:自己看自己。
- 目的:为了让词与词之间发生关系。比如处理“苹果”时,看看句子后面有没有出现“好吃”或“手机”。
- 场景:Transformer 的编码器(Encoder)内部,或者解码器(Decoder)的开头。
自注意力学到的参数是什么含义
| 参数 | 学到的“现实意义” |
|---|---|
| W_q | 我想找什么信息 |
| W_k | 我提供什么信息 |
| W_v | 信息内容本身 |
| attention权重 | 谁应该关注谁 |
| W_o | 如何整合多种理解 |
比较卷积神经网络、循环神经网络和自注意力
用一句话理解三者差异
🟩 CNN : “我只看你附近发生了什么”
🟦 RNN : “我按时间顺序,一步一步记住过去”
🟨 自注意力 : “我直接看全局,决定谁重要”
案例
示例句子: The animal didn’t cross the street because it was tired
CNN(局部窗口滑动)
[The animal] → 提取局部特征
↓
[animal didn't]
↓
[didn't cross]
↓
[cross the]
↓
[the street]
↓
[street because]
↓
[because it]
↓
[it was]
↓
[was tired]
🧠 信息流特点 局部 → 局部 → 局部 → … → 全局(需要多层)
👉 ❗信息是“慢慢扩散”的 👉 ❗远距离关系很难直接捕捉
RNN(按时间顺序流动)
The → animal → didn't → cross → the → street → because → it → was → tired
↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
h1 → h2 → h3 → h4 → h5 → h6 → h7 → h8 → h9
🧠 信息流特点 前一个状态 → 传给后一个
👉 ❗信息是“链式传递” 👉 ❗距离越远,信息越弱(梯度消失)
自注意力(全连接关系)
┌───────────────┐
The ──────▶│ │◀────── animal
│ │
animal ────▶│ │◀────── didn't
│ Attention │
didn't ────▶│ Matrix │◀────── cross
│ │
cross ─────▶│ │◀────── street
│ │
street ────▶│ │◀────── because
│ │
because ───▶│ │◀────── it
│ │
it ────────▶│ │◀────── was
│ │
was ───────▶│ │◀────── tired
└───────────────┘
或者更直观一点👇
每个词 ↔ 所有词(全连接)
it → The
it → animal ✅(重点)
it → didn't
it → cross
it → street
it → because
it → was
it → tired
🧠 信息流特点 任意词 ↔ 任意词(一步完成)
👉 ✅ 全局感知 👉 ✅ 无距离限制 👉 ✅ 并行计算
理解
| 模型类型 | 核心比喻 | 它是怎么看世界的? | 优点 | 局限 |
|---|---|---|---|---|
| 🟩 卷积神经网络(CNN) | 扫描仪 | 拿一个小窗口在局部滑动,只关注“附近的信息”,逐层扩大视野 | ✅ 高效 ✅ 擅长局部模式(边缘、纹理、n-gram) ✅ 并行性强 |
❌ 看不远(需要很多层) ❌ 长距离关系弱 |
| 🟦 循环神经网络(RNN) | 传声筒 | 按顺序读,每次记一点,把“过去的信息”传给未来 | ✅ 顺序建模强 ✅ 适合时间序列 |
❌ 不能并行 ❌ 长距离信息容易丢(梯度消失) |
| 🟨 自注意力(Self-Attention) | 雷达 | 一次扫描全局,直接计算“谁和谁相关” | ✅ 全局建模 ✅ 长距离依赖强 ✅ 并行计算 |
❌ 计算量大(O(n²)) ❌ 长序列成本高 |
自注意力机制通过一种“空间换时间”的策略:它牺牲了更多的计算内存(用来存储所有词两两相关的矩阵),换取了极强的并行能力和近乎完美的长程记忆。这正是大规模数据训练(如 GPT 系列)所最需要的特质。
当上下文长度 $n$ 增加时,自注意力的算力和内存开销是按平方比例($O(n^2)$)增长的。这也就是为什么很多 AI 模型(如早期的 GPT)会有“上下文窗口限制”(比如 2k、4k 或 8k 个 token)的根本原因。
为什么开销会呈“平方”增长?
自注意力的核心是计算注意力权重矩阵(Attention Matrix)。
- 如果你有 10 个词,你需要算一个 $10 \times 10 = 100$ 的关系矩阵。
- 如果你有 10,000 个词,这个矩阵就会变成 $10,000 \times 10,000 = 1 亿$ 个元素。
- 后果:即便是一个不算太长的文档,这个巨大的矩阵也会迅速吃光显卡的显存(VRAM)。
算力开销(算不动)
- 为了填满这个 $n \times n$ 的矩阵,模型必须让每一个词去和所有词做一次点积运算。
- 后果:序列长度翻倍,计算量会变成原来的 4 倍。这会导致推理(生成回答)的速度变慢,训练时间也会大幅拉长。
训练 vs 推理:谁更压力大?
训练(Training):由于训练时需要并行处理整个序列,并且要保存所有中间状态用于反向传播,显存压力极大。为了训练超长文本,通常需要动用成百上千块高性能显卡(如 H100)。推理(Inference):虽然推理时可以利用缓存(KV Cache)来避免重复计算之前的词,但处理超长上下文时,KV 缓存本身也会占用巨大的显存。如果你发现模型回复越来越慢,或者直接报错“Out of Memory”,通常就是因为上下文太长,显存撑不住了。
大模型训练之前需要先定义上下文大小
大模型在设计阶段就“固定了最大上下文长度(context window)”,训练和推理都必须遵守这个上限。