AI

注意力机制 Transformer

动手深度学习

Posted by LXG on April 23, 2026

Bahdanau 注意力

原始 Seq2Seq 的问题

在没有注意力的时候:

输入一句话 → 编码器 → 一个固定向量 → 解码器 → 输出一句话

👉 问题:所有信息被压成一个向量(信息瓶颈)

🧠 d2l里的关键一句话总结:

“不是所有输入词对当前输出都有用”

Bahdanau注意力在做什么?

👉 核心思想:每生成一个词,都重新去“看一遍输入句子”

模型训练代码


# -*- coding: utf-8 -*-
import torch
import logging
import os
from torch import nn
from d2l import torch as d2l

# 配置基础日志。若外部项目已经配置日志,这里不会重复覆盖。
if not logging.getLogger().handlers:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )
logger = logging.getLogger(__name__)


def setup_writable_d2l_data_dir():
    """将 d2l 默认数据目录重定向到当前脚本目录下的可写路径。"""
    data_dir = os.path.join(os.path.dirname(__file__), "data")
    os.makedirs(data_dir, exist_ok=True)

    original_download = d2l.download

    def download_to_local(url, folder=data_dir, sha1_hash=None):
        # 统一把下载目录固定到 data_dir,避免写入只读目录 ../data。
        return original_download(url, folder=folder, sha1_hash=sha1_hash)

    d2l.download = download_to_local
    logger.info("d2l 数据目录已重定向到: %s", data_dir)

#@save
class AttentionDecoder(d2l.Decoder):
    """带有注意力机制解码器的基本接口"""
    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    @property
    def attention_weights(self):
        raise NotImplementedError
    
class Seq2SeqAttentionDecoder(AttentionDecoder):
    """带加性注意力(Additive Attention)的Seq2Seq解码器。

    主要流程:
    1. 先将输入token id做embedding;
    2. 每个时间步使用上一时刻隐藏状态作为query,和编码器输出做注意力计算;
    3. 将注意力上下文向量与当前输入embedding拼接后送入GRU;
    4. 通过全连接层映射到词表维度,得到每一步的预测分布。
    """
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
        # 兼容不同版本 d2l 的 AdditiveAttention 构造函数:
        # 新版常见签名: (num_hiddens, dropout)
        # 旧版常见签名: (key_size, query_size, num_hiddens, dropout)
        try:
            self.attention = d2l.AdditiveAttention(num_hiddens, dropout)
        except TypeError:
            self.attention = d2l.AdditiveAttention(
                num_hiddens, num_hiddens, num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(
            embed_size + num_hiddens, num_hiddens, num_layers,
            dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        # 是否打印 init_state 的详细日志(训练时建议关闭,避免每个 batch 都打印)
        self.enable_state_logging = False
        # 是否打印解码器每一步的详细 shape 日志(训练时建议关闭,避免日志过多)
        self.enable_step_logging = False
        # 是否打印“首轮首个 batch”的向量变化(默认关闭,需要时手动打开)
        self.enable_first_epoch_vector_logging = False
        self._has_logged_first_train_batch = False
        logger.info(
            "初始化解码器: vocab_size=%d, embed_size=%d, num_hiddens=%d, num_layers=%d, dropout=%s",
            vocab_size, embed_size, num_hiddens, num_layers, dropout
        )

    def init_state(self, enc_outputs, enc_valid_lens, *_args):
        # 编码器会返回两部分:
        # 1) 每个时间步的输出 outputs,原始形状是 (num_steps, batch_size, num_hiddens)
        # 2) 最后一层(及各层)的隐藏状态 hidden_state,形状是 (num_layers, batch_size, num_hiddens)
        # 保留可变参数是为了兼容上层接口,此处不使用。
        _ = _args
        outputs, hidden_state = enc_outputs
        if self.enable_state_logging:
            logger.info(
                "init_state输入: enc_outputs=%s, hidden_state=%s, enc_valid_lens=%s",
                tuple(outputs.shape), tuple(hidden_state.shape), enc_valid_lens
            )
        # 后续注意力计算希望编码器输出按 batch 放在第一维,
        # 所以把 outputs 从 (num_steps, batch_size, num_hiddens)
        # 调整成 (batch_size, num_steps, num_hiddens)。
        state = (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)
        if self.enable_state_logging:
            logger.info(
                "init_state输出: permuted_enc_outputs=%s, hidden_state=%s",
                tuple(state[0].shape), tuple(state[1].shape)
            )
        return state

    def forward(self, X, state):
        """单次前向传播。

        参数:
            X: 解码器输入token id,形状(batch_size, num_steps)
            state: (enc_outputs, hidden_state, enc_valid_lens)
        返回:
            outputs: 词表预测,形状(batch_size, num_steps, vocab_size)
            new_state: 更新后的解码器状态
        """
        # state 中包含:
        # enc_outputs: 编码器所有时间步输出,形状 (batch_size, num_steps, num_hiddens)
        # hidden_state: 当前解码器隐藏状态,形状 (num_layers, batch_size, num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        if self.enable_step_logging:
            logger.info(
                "forward开始: X=%s, enc_outputs=%s, hidden_state=%s",
                tuple(X.shape), tuple(enc_outputs.shape), tuple(hidden_state.shape)
            )
        # 先把 token id 映射成词向量,再转成“按时间步遍历”更方便的形状:
        # (batch_size, num_steps) -> embedding -> (batch_size, num_steps, embed_size)
        # -> permute -> (num_steps, batch_size, embed_size)
        X = self.embedding(X).permute(1, 0, 2)
        if self.enable_step_logging:
            logger.info("embedding后X=%s", tuple(X.shape))

        # 只在训练开始阶段打印一次“向量变化”,帮助理解训练中的数值流动。
        trace_vectors = (
            self.training
            and self.enable_first_epoch_vector_logging
            and not self._has_logged_first_train_batch
        )
        if trace_vectors:
            logger.info("首轮首batch向量追踪开始(仅打印一次)")
        outputs, self._attention_weights = [], []
        for step, x in enumerate(X):
            token_embed = x
            # 用“当前时刻的解码器隐藏状态”作为 query 去做注意力检索。
            # 这里只取最后一层隐藏状态 hidden_state[-1],并扩维成 (batch_size, 1, num_hiddens)。
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            prev_h = hidden_state[-1]
            # context 是从编码器输出中“加权汇总”得到的上下文向量,
            # 形状为 (batch_size, 1, num_hiddens)。
            context = self.attention(
                query, enc_outputs, enc_outputs, enc_valid_lens)
            # 把上下文向量和当前输入词向量拼接,作为 GRU 的输入特征。
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            # GRU 期望输入是 (num_steps, batch_size, input_size),
            # 这里单步输入,所以 num_steps=1,形状变为 (1, batch_size, embed_size + num_hiddens)。
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
            if self.enable_step_logging:
                logger.info(
                    "step=%d: query=%s, context=%s, rnn_out=%s, hidden_state=%s",
                    step, tuple(query.shape), tuple(context.shape),
                    tuple(out.shape), tuple(hidden_state.shape)
                )
            if trace_vectors:
                embed_sample = token_embed[0].detach()
                context_sample = context[0, 0].detach()
                prev_h_sample = prev_h[0].detach()
                new_h_sample = hidden_state[-1, 0].detach()
                logger.info(
                    (
                        "vector_step=%d | "
                        "embed_norm=%.4f, context_norm=%.4f, "
                        "h_prev_norm=%.4f, h_new_norm=%.4f, "
                        "delta_h_norm=%.4f | "
                        "embed_head=%s | context_head=%s | h_new_head=%s"
                    ),
                    step,
                    embed_sample.norm().item(),
                    context_sample.norm().item(),
                    prev_h_sample.norm().item(),
                    new_h_sample.norm().item(),
                    (new_h_sample - prev_h_sample).norm().item(),
                    embed_sample[:4].cpu().tolist(),
                    context_sample[:4].cpu().tolist(),
                    new_h_sample[:4].cpu().tolist(),
                )
        # 把每个时间步的 GRU 输出拼起来,再通过线性层映射到词表大小,
        # 得到每个时间步对每个词的打分 logits,形状是 (num_steps, batch_size, vocab_size)。
        outputs = self.dense(torch.cat(outputs, dim=0))
        final_outputs = outputs.permute(1, 0, 2)
        if self.enable_step_logging:
            logger.info(
                "forward结束: logits=%s, final_outputs=%s, attention_steps=%d",
                tuple(outputs.shape), tuple(final_outputs.shape),
                len(self._attention_weights)
            )
        if trace_vectors:
            self._has_logged_first_train_batch = True
            logger.info("首轮首batch向量追踪结束")
        return final_outputs, [enc_outputs, hidden_state, enc_valid_lens]

    @property
    def attention_weights(self):
        return self._attention_weights


class TrainSeq2SeqNetAdapter(nn.Module):
    """兼容不同 d2l 版本返回值差异的训练适配器。

    某些版本的 EncoderDecoder.forward 只返回 Y_hat,
    但 train_seq2seq 期望拿到 (Y_hat, state)。
    这里统一转换为二元组,避免训练阶段解包报错。
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, enc_X, dec_X, *args):
        out = self.model(enc_X, dec_X, *args)
        if isinstance(out, tuple):
            return out
        return out, None
    
# ===== 训练配置 =====
# embed_size: 词向量维度;num_hiddens: 隐藏层维度;num_layers: RNN层数;dropout: 随机失活比例
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
# batch_size: 每次喂给模型多少样本;num_steps: 每个样本截断/填充后的时间步长度
batch_size, num_steps = 64, 10
# lr: 学习率;num_epochs: 训练轮数;device: 自动选择 GPU(可用时) 否则 CPU
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()
logger.info(
    "训练配置: embed_size=%d, num_hiddens=%d, num_layers=%d, dropout=%.3f, "
    "batch_size=%d, num_steps=%d, lr=%.4f, num_epochs=%d, device=%s",
    embed_size, num_hiddens, num_layers, dropout,
    batch_size, num_steps, lr, num_epochs, device
)
setup_writable_d2l_data_dir()

# 读取机器翻译数据集,返回:
# train_iter: 训练数据迭代器(按 batch 提供 src/tgt)
# src_vocab: 源语言词表;tgt_vocab: 目标语言词表
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
logger.info("数据集加载完成: src_vocab_size=%d, tgt_vocab_size=%d", len(src_vocab), len(tgt_vocab))
sample_batch = next(iter(train_iter))
logger.info(
    "首个batch形状: src=%s, src_valid_len=%s, tgt=%s, tgt_valid_len=%s",
    tuple(sample_batch[0].shape), tuple(sample_batch[1].shape),
    tuple(sample_batch[2].shape), tuple(sample_batch[3].shape)
)

# 编码器负责把源语言序列编码成上下文表示。
encoder = d2l.Seq2SeqEncoder(
    len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
# 解码器基于注意力机制,按时间步生成目标语言序列。
decoder = Seq2SeqAttentionDecoder(
    len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
# 需要观察 state 初始化细节时可打开(训练时默认关闭,避免日志过于庞大):
# decoder.enable_state_logging = True
# 需要观察单步注意力细节时可打开(训练时默认关闭,避免日志过于庞大):
# decoder.enable_step_logging = True
# 打开“首轮首个 batch”向量变化打印(推荐:理解训练过程时开启)
decoder.enable_first_epoch_vector_logging = True

# 将编码器和解码器封装成完整的 Encoder-Decoder 网络。
net = d2l.EncoderDecoder(encoder, decoder)
train_net = TrainSeq2SeqNetAdapter(net)
# 某些模块(如 LazyModule)在首个 forward 前参数尚未初始化,直接 numel() 会报错。
# 这里做“安全统计”:只统计已初始化参数,并把未初始化参数个数单独打日志。
num_params = 0
uninitialized_count = 0
for p in train_net.parameters():
    if not p.requires_grad:
        continue
    try:
        num_params += p.numel()
    except ValueError:
        uninitialized_count += 1
logger.info(
    "模型构建完成: initialized_trainable_params=%d, uninitialized_params=%d",
    num_params, uninitialized_count
)
# 启动训练:内部会完成前向、损失计算、反向传播和参数更新。
logger.info("开始训练...")
d2l.train_seq2seq(train_net, train_iter, lr, num_epochs, tgt_vocab, device)
d2l.plt.show()
logger.info("训练结束。")

时间步理解


时间步 t=0        t=1        t=2        t=3
--------------------------------------------------
输入    <bos>      我         爱         你
        │          │         │         │
        ▼          ▼         ▼         ▼
      emb        emb       emb       emb
        │          │         │         │
        ├────┐     ├────┐    ├────┐    ├────┐
        ▼    │     ▼    │    ▼    │    ▼    │
     Attention   Attention   Attention   Attention
        │          │         │         │
     context₀    context₁   context₂   context₃
        │          │         │         │
        └──concat──┴──concat─┴──concat─┴──concat
                    │
                   GRU(共享参数)
                    │
          hidden₁ → hidden₂ → hidden₃ → hidden₄
                    │
                  dense
                    │
输出     我         爱         你       <eos>

对比没有注意力的seq2seq

维度 无注意力 Seq2Seq 带注意力(Bahdanau)Seq2Seq
信息来源 只用 encoder 最后一个 hidden 每个时间步动态读取所有 encoder 输出
上下文(context) 固定一个向量 每一步重新计算
Query ❌ 没有 decoder 当前 hidden
Key/Value ❌ 没有 encoder 全部输出
是否有对齐能力 ❌ 没有 ✅ 有(attention权重)
长句效果 ❌ 容易崩 ✅ 明显更好
计算量 ✅ 低 ❌ 略高
参数量 稍多(attention MLP)
工业使用 ❌ 基本淘汰 ⚠️ 过渡方案(之后是 Transformer)

Attention 的本质,就是让模型从“记忆驱动”变成“检索驱动”

保存训练结束后的模型


# 保存到 checkpoint 的训练/模型配置,后续可用于复现实验或重建模型结构。
config = {
    "embed_size": embed_size,  # 词向量维度
    "num_hiddens": num_hiddens, # 隐藏层维度
    "num_layers": num_layers,   # RNN层数
    "dropout": dropout,          # 随机失活比例
    "batch_size": batch_size,    # 每次喂给模型多少样本
    "num_steps": num_steps,      # 每个样本截断/填充后的时间步长度
    "lr": lr,                    # 学习率
    "num_epochs": num_epochs,    # 训练轮数
    "device": str(device),       # 设备类型(GPU 或 CPU)
}

# 训练完成后保存 checkpoint,便于后续继续训练或直接做推理。
torch.save(
    {
        # 仅保存参数权重(推荐做法),加载时需先构建同结构模型再 load_state_dict。
        "model": net.state_dict(),
        # 同时保存源语言和目标语言词表,保证推理时 token/id 映射一致。
        "src_vocab": src_vocab,
        "tgt_vocab": tgt_vocab,
        # 保存关键超参数配置,方便恢复训练环境与模型定义。
        "config": config,
    },
    # checkpoint 文件名;可按需改成带时间戳或轮次的名字。
    "checkpoint.pth",
)
logger.info("模型已保存到 checkpoint.pth")

使用保存后的模型推理


# -*- coding: utf-8 -*-
import argparse
import logging
from typing import List

import torch
from torch import nn
from d2l import torch as d2l


if not logging.getLogger().handlers:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )
logger = logging.getLogger(__name__)


class AttentionDecoder(d2l.Decoder):
    """带有注意力机制解码器的基础接口。"""

    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    @property
    def attention_weights(self):
        raise NotImplementedError


class Seq2SeqAttentionDecoder(AttentionDecoder):
    """与训练脚本一致的注意力解码器定义,用于加载 checkpoint。"""

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
        # 兼容不同版本 d2l 的 AdditiveAttention 构造参数。
        try:
            self.attention = d2l.AdditiveAttention(num_hiddens, dropout)
        except TypeError:
            self.attention = d2l.AdditiveAttention(
                num_hiddens, num_hiddens, num_hiddens, dropout
            )
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(
            embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout
        )
        self.dense = nn.Linear(num_hiddens, vocab_size)
        self._attention_weights = []

    def init_state(self, enc_outputs, enc_valid_lens, *_args):
        _ = _args
        outputs, hidden_state = enc_outputs
        return outputs.permute(1, 0, 2), hidden_state, enc_valid_lens

    def forward(self, X, state):
        enc_outputs, hidden_state, enc_valid_lens = state
        X = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x in X:
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]

    @property
    def attention_weights(self):
        return self._attention_weights


def load_model_from_checkpoint(ckpt_path: str, device: torch.device):
    """从 checkpoint 加载词表、配置和模型参数。"""
    # 兼容 PyTorch 2.6+ 默认 weights_only=True 的行为变化。
    # 该 checkpoint 来自本地训练流程,属于可信来源,因此允许完整反序列化。
    try:
        checkpoint = torch.load(ckpt_path, map_location=device)
    except Exception:
        checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    config = checkpoint["config"]
    src_vocab = checkpoint["src_vocab"]
    tgt_vocab = checkpoint["tgt_vocab"]

    encoder = d2l.Seq2SeqEncoder(
        len(src_vocab),
        config["embed_size"],
        config["num_hiddens"],
        config["num_layers"],
        config["dropout"],
    )
    decoder = Seq2SeqAttentionDecoder(
        len(tgt_vocab),
        config["embed_size"],
        config["num_hiddens"],
        config["num_layers"],
        config["dropout"],
    )
    net = d2l.EncoderDecoder(encoder, decoder)
    net.load_state_dict(checkpoint["model"])
    net.to(device)
    net.eval()
    logger.info("已加载模型: %s", ckpt_path)
    return net, src_vocab, tgt_vocab, config


def run_inference(
    ckpt_path: str,
    sentences: List[str],
    device: torch.device,
    num_steps: int = None,
):
    net, src_vocab, tgt_vocab, model_config = load_model_from_checkpoint(ckpt_path, device)
    infer_num_steps = num_steps if num_steps is not None else model_config["num_steps"]
    logger.info("checkpoint配置: num_steps=%s", model_config.get("num_steps"))
    logger.info("开始推理: device=%s, num_steps=%d, 样本数=%d", device, infer_num_steps, len(sentences))

    for text in sentences:
        translation, _ = d2l.predict_seq2seq(
            net, text, src_vocab, tgt_vocab, infer_num_steps, device, True
        )
        print(f"{text} => {translation}")


def parse_args():
    parser = argparse.ArgumentParser(description="加载 checkpoint.pth 并执行翻译推理")
    parser.add_argument(
        "--ckpt",
        default="/home/lxg/code/AI/code/20260423/checkpoint.pth",
        help="checkpoint 文件路径",
    )
    parser.add_argument(
        "--num_steps",
        type=int,
        default=None,
        help="推理时最大时间步;不传则使用 checkpoint 中保存的配置",
    )
    parser.add_argument(
        "--text",
        nargs="*",
        default=None,
        help="待翻译英文句子,可传多个;不传则使用内置样例",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    device = d2l.try_gpu()
    texts = args.text if args.text else ["go .", "i lost .", "he's calm .", "i'm home ."]
    run_inference(args.ckpt, texts, device, args.num_steps)

运行结果


2026-04-23 14:53:39,874 | INFO | __main__ | 已加载模型: /home/lxg/code/AI/code/20260423/checkpoint.pth
2026-04-23 14:53:39,874 | INFO | __main__ | checkpoint配置: num_steps=10
2026-04-23 14:53:39,874 | INFO | __main__ | 开始推理: device=cuda:0, num_steps=10, 样本数=4
go . => va !
i lost . => j'ai perdu .
he's calm . => il est riche .
i'm home . => je suis chez moi .

最小推理依赖清单


project/
│
├── checkpoint.pth        ✅ 模型权重+词表+配置
├── model.py              ✅ 模型结构(Encoder + Decoder)
├── inference.py          ✅ 推理脚本(你这份)
│
└── requirements.txt

开源模型开源的是什么

模块 是否必须开源 作用
model ⭐ 必须 定义网络结构
weights ⭐ 必须 模型参数
tokenizer ⭐ 必须 文本处理
config ⭐ 必须 重建模型
inference ⭐ 建议 快速使用
training 🟡 可选 复现训练
evaluation 🟡 可选 测试模型
data 🔴 通常不 数据太大/版权
export 🟡 可选 部署
README ⭐ 必须 使用说明
requirements ⭐ 必须 环境

多头注意力

多头注意力 = 同一段话,让多个“不同专家”同时去理解,然后把他们的理解合起来

每个专家都在看同一句话,但关注点不同:

专家 关注点
专家1 谁爱谁(语义)
专家2 主谓关系(语法)
专家3 从句结构(who is singing)
专家4 远距离依赖
专家5 局部词关系

👉 最后:把所有人的理解拼起来 → 得到更完整的理解

代码


import math
import torch
from torch import nn
from d2l import torch as d2l

#@save
class MultiHeadAttention(nn.Module):
    """多头注意力"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        # 头数h。多头注意力的核心思想:把同一个表示空间拆成h个子空间并行计算,
        # 让不同头关注不同位置/关系,再把结果拼回去。
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        # 下面四个线性层的作用:
        # 1) W_q/W_k/W_v:把输入映射到同一隐藏维num_hiddens,方便做点积注意力
        # 2) W_o:把多头拼接后的结果再做一次融合映射
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # queries,keys,values的形状:
        # (batch_size,查询或者“键-值”对的个数,num_hiddens)
        # valid_lens 的形状:
        # (batch_size,)或(batch_size,查询的个数)
        #
        # 下面三行先做线性投影,再把“多头维”拆出来:
        # 原始: (B, N, H)
        # 拆头后: (B*h, N, H/h)
        # 这样可以把每个头当作一个独立样本,直接并行送进同一个点积注意力模块。
        # 经过变换后,输出的queries,keys,values 的形状:
        # (batch_size*num_heads,查询或者“键-值”对的个数,
        # num_hiddens/num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,
            # 然后如此复制第二项,然后诸如此类。
            # 原因:我们把batch维从B扩成了B*h,每个样本对应h个头,
            # 所以mask长度也要复制h份,保证每个头用到同样的有效长度约束。
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # output的形状:(batch_size*num_heads,查询的个数,
        # num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # output_concat的形状:(batch_size,查询的个数,num_hiddens)
        # 将各头结果从(B*h, N, H/h)还原并拼接到最后一维,得到(B, N, H)。
        output_concat = transpose_output(output, self.num_heads)
        # 最后再过W_o,让不同头的信息进行线性融合。
        return self.W_o(output_concat)

#@save
def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""
    # 记号说明:
    # B = batch_size, N = 序列长度(查询数或键值对数), H = num_hiddens, h = num_heads
    # 目标:把(B, N, H)变成(B*h, N, H/h),便于并行计算每个头。

    # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
    # num_hiddens/num_heads)
    # 这一步把最后一维H拆成(h, H/h)。
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    # 交换维度,把“头维”提前,方便下一步与batch合并。
    X = X.permute(0, 2, 1, 3)

    # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    # 合并(B, h) -> (B*h),让每个头都像一个独立样本进入attention。
    return X.reshape(-1, X.shape[2], X.shape[3])


#@save
def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""
    # 输入X: (B*h, N, H/h)
    # 先还原成(B, h, N, H/h)...
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    # ...再交换成(B, N, h, H/h)...
    X = X.permute(0, 2, 1, 3)
    # ...最后拼接头维,得到(B, N, H)。
    return X.reshape(X.shape[0], X.shape[1], -1)


#@save
def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""
    # 与上方同名函数一致:将(B, N, H)拆成(B*h, N, H/h)以并行计算每个头。
    # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
    # num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)

    # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


#@save
def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""
    # 与上方同名函数一致:把(B*h, N, H/h)还原为(B, N, H)。
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
# 这里是“自注意力”形式:Q来自X,K/V都来自Y(若X==Y则是标准自注意力)。
# 输出形状应为(B, 查询数, H) => (2, 4, 100)
attention(X, Y, Y, valid_lens).shape

多头注意力参数

参数 单头 多头
W_q d×d d×d
W_k d×d d×d
W_v d×d d×d
W_o d×d
总计 3个矩阵 4个矩阵

代码里确实只有一个 MultiHeadAttention 对象,但“多头”是通过张量变形(reshape)实现的,而不是多个对象

理论和实现的差异

维度 理论多头(概念层) 实际代码实现(你这段) 关键意义
Head数量 h 个独立 attention 1 个 attention + reshape 用“数据维度”模拟多个头
Q/K/V 投影 每个 head 一套 (W_q^i, W_k^i, W_v^i) 一套大矩阵 (W_q, W_k, W_v),再切分 参数共享 + 分块
输入数据 每个 head 输入同一 X 同一 X → 线性变换 → reshape 同源不同视角
数据拆分 手动分给每个 head reshape + permute 自动并行拆分
attention计算 h 次独立计算(for循环) 一次计算(batch×head) 🚀 并行加速
head并行方式 逻辑并行(概念上) 物理并行(矩阵运算/GPU) 真正利用硬件
中间shape h 个:(batch, seq, d/h) (batch×h, seq, d/h) 核心技巧
结果合并 concat(h₁,…,hₕ) transpose_output 恢复原维度
输出融合 可能有融合层 必有 (W_o) 融合多头信息
参数规模 看起来是 h 倍 实际 ≈ 1 倍(+W_o) 高效设计
计算复杂度 h × Attention ≈ 单次大矩阵运算 更高吞吐
内存访问 多次调用 连续内存块操作 cache友好
代码复杂度 高(多模块) 低(统一模块) 易实现
可扩展性 增head需加模块 改 num_heads 即可 灵活
GPU利用率 低(循环) 高(并行) 核心优势

自注意力和位置编码

自注意力

自注意力(Self-Attention)= 句子里的每个词,都会去“看”同一句子里的其他词,决定该关注谁


import math
import torch
from torch import nn
from d2l import torch as d2l

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                   num_hiddens, num_heads, 0.5)
attention.eval()

batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
attention(X, X, X, valid_lens).shape

这段代码其实就是一个标准“多头自注意力(Multi-Head Self-Attention)”的最小示例。

为什么这是“自注意力”?


关键在这里👇

attention(X, X, X, valid_lens)

❗Q、K、V 全是 X
Q = X  
K = X  
V = X

👉 含义:

每个词都在“看自己这句话里的所有词”

这段代码在干嘛


🧾 输入
X = torch.ones((2, 4, 100))

👉 表示:

batch = 2(两句话)
每句 4 个词
每个词 100维
🧠 自注意力在做什么?

对每一句话里的每个词:

“我应该关注句子里的谁?”

👉 举个直觉例子:

第1个词:

看第1个词 ✔  
看第2个词 ✔  
看第3个词 ✔  
看第4个词 ✔  

👉 然后加权融合

shape流动


Step1️⃣ 输入
(2, 4, 100)

Step2️⃣ 多头拆分(内部发生)
num_heads = 5

👉 变成:

(2 × 5, 4, 20) = (10, 4, 20)

👉 含义:

5个头 → 变成10个“伪batch”
Step3️⃣ 每个 head 做 attention
(10, 4, 20) → (10, 4, 20)
Step4️⃣ 拼回来
(2, 4, 100)

自注意力 (Self-Attention)

  • 代码特征:attention(X, X, X)
  • 信息源:自己看自己。
  • 目的:为了让词与词之间发生关系。比如处理“苹果”时,看看句子后面有没有出现“好吃”或“手机”。
  • 场景:Transformer 的编码器(Encoder)内部,或者解码器(Decoder)的开头。

自注意力学到的参数是什么含义

参数 学到的“现实意义”
W_q 我想找什么信息
W_k 我提供什么信息
W_v 信息内容本身
attention权重 谁应该关注谁
W_o 如何整合多种理解

比较卷积神经网络、循环神经网络和自注意力

用一句话理解三者差异

🟩 CNN : “我只看你附近发生了什么”

🟦 RNN : “我按时间顺序,一步一步记住过去”

🟨 自注意力 : “我直接看全局,决定谁重要”

cnn-rnn-self-attention

案例

示例句子: The animal didn’t cross the street because it was tired

CNN(局部窗口滑动)


[The animal] → 提取局部特征
     ↓
   [animal didn't]
        ↓
     [didn't cross]
           ↓
        [cross the]
              ↓
           [the street]
                ↓
         [street because]
                ↓
           [because it]
                ↓
            [it was]
                ↓
           [was tired]

🧠 信息流特点 局部 → 局部 → 局部 → … → 全局(需要多层)

👉 ❗信息是“慢慢扩散”的 👉 ❗远距离关系很难直接捕捉

RNN(按时间顺序流动)


The → animal → didn't → cross → the → street → because → it → was → tired
  ↓      ↓        ↓        ↓       ↓       ↓         ↓       ↓     ↓
 h1 →   h2 →     h3 →     h4 →    h5 →    h6 →      h7 →    h8 →  h9

🧠 信息流特点 前一个状态 → 传给后一个

👉 ❗信息是“链式传递” 👉 ❗距离越远,信息越弱(梯度消失)

自注意力(全连接关系)


            ┌───────────────┐
The  ──────▶│               │◀────── animal
            │               │
animal ────▶│               │◀────── didn't
            │   Attention   │
didn't ────▶│   Matrix      │◀────── cross
            │               │
cross ─────▶│               │◀────── street
            │               │
street ────▶│               │◀────── because
            │               │
because ───▶│               │◀────── it
            │               │
it ────────▶│               │◀────── was
            │               │
was ───────▶│               │◀────── tired
            └───────────────┘


或者更直观一点👇

每个词 ↔ 所有词(全连接)

it → The
it → animal   ✅(重点)
it → didn't
it → cross
it → street
it → because
it → was
it → tired

🧠 信息流特点 任意词 ↔ 任意词(一步完成)

👉 ✅ 全局感知 👉 ✅ 无距离限制 👉 ✅ 并行计算

理解

模型类型 核心比喻 它是怎么看世界的? 优点 局限
🟩 卷积神经网络(CNN) 扫描仪 拿一个小窗口在局部滑动,只关注“附近的信息”,逐层扩大视野 ✅ 高效
✅ 擅长局部模式(边缘、纹理、n-gram)
✅ 并行性强
❌ 看不远(需要很多层)
❌ 长距离关系弱
🟦 循环神经网络(RNN) 传声筒 按顺序读,每次记一点,把“过去的信息”传给未来 ✅ 顺序建模强
✅ 适合时间序列
❌ 不能并行
❌ 长距离信息容易丢(梯度消失)
🟨 自注意力(Self-Attention) 雷达 一次扫描全局,直接计算“谁和谁相关” ✅ 全局建模
✅ 长距离依赖强
✅ 并行计算
❌ 计算量大(O(n²))
❌ 长序列成本高

自注意力机制通过一种“空间换时间”的策略:它牺牲了更多的计算内存(用来存储所有词两两相关的矩阵),换取了极强的并行能力和近乎完美的长程记忆。这正是大规模数据训练(如 GPT 系列)所最需要的特质。

当上下文长度 $n$ 增加时,自注意力的算力和内存开销是按平方比例($O(n^2)$)增长的。这也就是为什么很多 AI 模型(如早期的 GPT)会有“上下文窗口限制”(比如 2k、4k 或 8k 个 token)的根本原因。

为什么开销会呈“平方”增长?

自注意力的核心是计算注意力权重矩阵(Attention Matrix)。

  • 如果你有 10 个词,你需要算一个 $10 \times 10 = 100$ 的关系矩阵。
  • 如果你有 10,000 个词,这个矩阵就会变成 $10,000 \times 10,000 = 1 亿$ 个元素。
  • 后果:即便是一个不算太长的文档,这个巨大的矩阵也会迅速吃光显卡的显存(VRAM)。

算力开销(算不动)

  • 为了填满这个 $n \times n$ 的矩阵,模型必须让每一个词去和所有词做一次点积运算。
  • 后果:序列长度翻倍,计算量会变成原来的 4 倍。这会导致推理(生成回答)的速度变慢,训练时间也会大幅拉长。

训练 vs 推理:谁更压力大?

  • 训练(Training):由于训练时需要并行处理整个序列,并且要保存所有中间状态用于反向传播,显存压力极大。为了训练超长文本,通常需要动用成百上千块高性能显卡(如 H100)。
  • 推理(Inference):虽然推理时可以利用缓存(KV Cache)来避免重复计算之前的词,但处理超长上下文时,KV 缓存本身也会占用巨大的显存。如果你发现模型回复越来越慢,或者直接报错“Out of Memory”,通常就是因为上下文太长,显存撑不住了。

大模型训练之前需要先定义上下文大小

大模型在设计阶段就“固定了最大上下文长度(context window)”,训练和推理都必须遵守这个上限。

位置编码

核心问题:Self-Attention 根本不懂顺序


Self-Attention 本身是“无序”的(permutation invariant)

也就是说:

输入:我 爱 你
输入:你 爱 我

在 self-attention 眼里,这两句话本质一样(只是一个集合)

因为它做的是:

所有 token 两两计算关系(Q·K)
完全不关心谁在前谁在后

👉 这就出大问题:语言是有顺序的!

经典位置编码(正弦/余弦)

核心直观:把位置映射到圆周上

想象一个半径为 1 的圆。在三角函数中,一对 $(sin(\theta), cos(\theta))$ 实际上代表了圆周上的一个点(或者说是一个角度)。

  • 第 0 个位置:对应角度 $0^\circ$。
  • 第 1 个位置:我们将角度旋转一个固定的步长(比如 $30^\circ$)。
  • 第 2 个位置:再旋转 $30^\circ$。

为什么这比单纯用数字 $1, 2, 3$ 好?因为圆周运动是有界的。无论你的句子多长,坐标永远在 $[-1, 1]$ 之间波动,不会像直线增长的数字那样让神经元的数值溢出或失去平衡。

相对位置的几何本质:旋转

这是最精妙的地方。在几何上,从位置 $i$ 移动到位置 $i+\delta$,本质上就是旋转了一个固定的角度。

对于模型来说,判断两个词的距离,不再是做减法($100 - 98 = 2$),而是看这两个词在圆周上的夹角。

由于旋转操作在数学上可以通过“线性变换”(旋转矩阵)来实现,Transformer 的注意力机制只需要通过简单的矩阵运算,就能感知到:“哦,这两个词之间的相对夹角很小,说明它们离得很近。”

多维度的“齿轮组”

如果只用一个圆,当句子很长时,圆周上的点会挤在一起分不清。所以 Transformer 用了一组圆(就像机械表里的齿轮组):

  • 快齿轮(低维度):角度随位置变化极快。第一个词在 $0^\circ$,第二个词可能就转到了 $180^\circ$。它负责捕捉相邻词的微小差异。
  • 慢齿轮(高维度):角度随位置变化极慢。可能句子过了一百个词,它才转了 $1^\circ$。

它负责捕捉长距离的宏观位置信息。当你把这些快慢不一的“齿轮状态”组合在一起时,就为句子中的每一个位置生成了一个唯一的、具有几何规律的坐标。

为什么是“加”在词向量上?

从几何上理解,词向量(Embedding)是高维空间里的一个方向(代表词义)。 位置编码则是给这个方向微调了一个旋转偏置。

相加之后,向量在空间里的指向变了。模型在后续处理时,既能识别出它的“主方向”(词义:苹果),又能识别出它被微调的“偏移量”(位置:句首)。

有无位置编码对比

❌ 没有位置编码


        我   爱   你   吗
我      0.25 0.25 0.25 0.25
爱      0.25 0.25 0.25 0.25
你      0.25 0.25 0.25 0.25
吗      0.25 0.25 0.25 0.25


👉 解释
每个词对所有词一视同仁
完全不知道:
谁在前 ❌
谁在后 ❌
谁更近 ❌

👉 本质:

Self-Attention 退化成“全连接平均池化”

✅ 加了 sin/cos 位置编码


        我   爱   你   吗
我      0.4  0.3  0.2  0.1
爱      0.3  0.4  0.2  0.1
你      0.1  0.2  0.4  0.3
吗      0.05 0.1  0.3  0.55

👉 解释

你能看到明显规律:

对角线强(自己最重要)
邻近词权重大
远距离权重低

位置编码源码


import math
import torch
from torch import nn
from d2l import torch as d2l


def log_tensor_info(name, tensor, max_items=8):
    """打印张量的形状、dtype和前几个元素,便于跟踪中间结果。"""
    flat = tensor.reshape(-1)
    preview = flat[:max_items].detach().cpu().tolist()
    print(f"[LOG] {name}: shape={tuple(tensor.shape)}, dtype={tensor.dtype}")
    print(f"[LOG] {name}{len(preview)}个元素: {preview}")


#@save
class PositionalEncoding(nn.Module):
    """位置编码"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        # dropout作用:训练时对“token表示 + 位置编码”的结果做随机失活,
        # 防止过拟合;推理模式下(eval)会自动关闭。
        self.dropout = nn.Dropout(dropout)

        # 创建一个足够长的位置编码表 P,形状为 (1, max_len, num_hiddens)。
        # 第0维保留batch维,方便后续与输入X做广播相加。
        self.P = torch.zeros((1, max_len, num_hiddens))

        # X中每个元素对应公式里的 pos / (10000^(2i/num_hiddens))。
        # 其中:
        # - 行索引是位置pos(0 ~ max_len-1)
        # - 列索引是偶数维对应的频率索引i
        # 得到X形状:(max_len, num_hiddens/2)
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)

        # 偶数维用sin,奇数维用cos。
        # 这样每个位置都对应一组周期不同的三角函数值,模型可据此感知相对/绝对位置信息。
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

        # 初始化日志:帮助你确认位置编码表已正确构建。
        print("[LOG] PositionalEncoding 初始化完成")
        print(f"[LOG] num_hiddens={num_hiddens}, dropout={dropout}, max_len={max_len}")
        log_tensor_info("self.P", self.P, max_items=10)

    def forward(self, X):
        # X形状通常为 (batch_size, num_steps, num_hiddens)。
        print("\n[LOG] ===== 进入 PositionalEncoding.forward =====")
        log_tensor_info("输入X", X)

        # 截取与当前序列长度匹配的位置编码:(1, num_steps, num_hiddens)
        pos_slice = self.P[:, :X.shape[1], :].to(X.device)
        log_tensor_info("截取的位置编码pos_slice", pos_slice)

        # 把位置编码加到输入上(按batch维广播)。
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        log_tensor_info("相加后的X", X)

        # 经过dropout后返回。eval()模式下这里不会随机置零。
        out = self.dropout(X)
        log_tensor_info("dropout后的输出", out)
        print("[LOG] ===== 退出 PositionalEncoding.forward =====\n")
        return out
    
# 示例参数
encoding_dim, num_steps = 32, 60

# 构造位置编码模块并切到推理模式(便于稳定观察数值)
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()

# 用全零输入,只保留“纯位置编码”效果,便于可视化观察
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]

# 打印关键张量信息,帮助理解形状流转
log_tensor_info("示例输出X", X)
log_tensor_info("用于绘图的P", P)

# 画部分维度(第6~9维)随位置变化的曲线
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])

d2l.plt.show()

# 打印前8个位置编号的二进制表示:
# 常用于类比“不同频率基函数组合后可编码位置信息”这一直觉。
for i in range(8):
    print(f'{i}的二进制是:{i:>03b}')

# d2l.show_heatmaps 需要四维输入:(batch, heads, query_len, key_len)
# 这里把二维位置编码扩成(1, 1, num_steps, encoding_dim)做热力图展示。
P = P[0, :, :].unsqueeze(0).unsqueeze(0)
log_tensor_info("heatmap输入P", P)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
                  ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
d2l.plt.show()

Step 0:你最终想得到什么?


目标其实很简单:

👉 给每个位置(pos),生成一个 32维向量

比如:

position	向量
0	[0.0, 1.0, 0.0, 1.0, ...]
1	[0.84, 0.54, 0.01, 0.99, ...]
2	[...]

Step 1:初始化位置编码表


self.P = torch.zeros((1, max_len, num_hiddens))

假设:

max_len = 1000
num_hiddens = 32

👉 那就是:

P.shape = (1, 1000, 32)

理解成:

batch维(占位)
   ↓
[ 
  [pos0的32维向量],
  [pos1的32维向量],
  ...
]

Step 2:生成“位置矩阵”


torch.arange(max_len).reshape(-1, 1)

得到:

[[0],
 [1],
 [2],
 ...
 [999]]

👉 shape:

(1000, 1)

👉 含义:

每一行 = 一个位置 pos

Step 3:生成“频率”


torch.arange(0, num_hiddens, 2)

👉 得到:

[0, 2, 4, 6, ..., 30]

👉 一共 16 个(因为步长是2)

然后:

torch.pow(10000, index / num_hiddens)

👉 举个例子:

index	计算结果
0	10000^(0/32)=1
2	10000^(2/32)
30	很大

👉 关键结论:

维度	分母大小	变化速度
前面维度	小	变化快(高频)
后面维度	大	变化慢(低频)

Step 4:组合成 X


X = pos / frequency

👉 shape:

(1000, 16)

👉 每个元素:

X[pos, i] = pos / (10000^(2i/d))


👉 shape / frequency

(1000, 1)  / (1, 16)  = (1000, 16)

X 形状: torch.Size([1000, 16])
tensor([[    0.0000,     0.0000,     0.0000,  ...,     0.0000,     0.0000,  0.0000],
        [    1.0000,     0.5623,     0.3162,  ...,     0.0006,     0.0003,  0.0002],
        [    2.0000,     1.1247,     0.6325,  ...,     0.0011,     0.0006,  0.0004],
        ...,
        [  997.0000,   560.6543,   315.2791,  ...,     0.5607,     0.3153,  0.1773],
        [  998.0000,   561.2166,   315.5953,  ...,     0.5612,     0.3156,  0.1775],
        [  999.0000,   561.7790,   315.9115,  ...,     0.5618,     0.3159,  0.1777]])

看行(位置 pos):随着行数增加(从 0 到 999),数字在变大。这意味着位置越往后的词,旋转的角度越大。

看列(维度 dim):

  • 左边的列(低维):数字增加得非常快。比如第二行(pos=1)第一列是 1.0000,而最后一行(pos=999)第一列变成了 999.0000。这说明“秒针”转得飞快。
  • 右边的列(高维):数字增加得非常慢。比如最后一行最后两列才 0.1777。这说明“时针”几乎没怎么动。

Step 5:应用 sin / cos


self.P[:, :, 0::2] = sin(X)
self.P[:, :, 1::2] = cos(X)

👉 拆开看:

维度	内容
0	sin
1	cos
2	sin
3	cos

👉 举个具体例子(pos=0):

sin(0) = 0
cos(0) = 1

👉 所以:

pos=0 → [0,1,0,1,0,1,...]

👉 pos=1:

sin(1/1) ≈ 0.84
cos(1/1) ≈ 0.54
sin(1/100) ≈ 很小


初始Position位置编码: torch.Size([1, 1000, 32])
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

Position位置编码后: torch.Size([1, 1000, 32])
tensor([[[     0.0000,      1.0000,      0.0000,  ...,      1.0000,  0.0000,      1.0000],
         [     0.8415,      0.5403,      0.5332,  ...,      1.0000,  0.0002,      1.0000],
         [     0.9093,     -0.4161,      0.9021,  ...,      1.0000,  0.0004,      1.0000],
         ...,
         [    -0.8980,     -0.4401,      0.9928,  ...,      0.9507,  0.1764,      0.9843],
         [    -0.8555,      0.5178,      0.9038,  ...,      0.9506,  0.1765,      0.9843],
         [    -0.0265,      0.9996,      0.5363,  ...,      0.9505,  0.1767,      0.9843]]])

想象你面前有一个非常复杂的表,它不是只有时针、分针、秒针,而是有 16 根针。

理解坐标的概念

位置编码的理解:圆周(坐标)当我们把 $sin$ 和 $cos$ 配对时,我们实际上是把位置映射到了一个单位圆上。

  • $sin(\theta)$:对应圆上点的 横坐标 (x)。
  • $cos(\theta)$:对应圆上点的 纵坐标 (y)。

每一对 $(sin, cos)$ 就是圆周上的一个“点的坐标”:

  • 位置 0:角度是 0,坐标是 $(0, 1)$。
  • 位置 1:角度变大,坐标沿着圆周移动到了 $(sin(1), cos(1))$。
  • 位置 $n$:坐标继续旋转。

为什么说它是“独一无二”的?(坐标的概念)

在你的 Position编码后 矩阵中,每一行对应一个单词。每一对 (sin, cos) 其实就是一根指针在表盘上的位置。

  • 位置 0:所有 16 根针都指在正北方(12点钟方向)。
  • 位置 1:所有针都开始转动,但速度不一样。
  • 位置 999:所有针都转了很多圈了。

虽然每根针都在转,但由于它们转速不同,在 1000 个位置里,绝对不会出现两次“所有针指向完全相同”的情况。这就是“独一无二”的坐标。

当你把这组数字(位置编码)加到词向量上时:

  • 低维度波形:给词向量加上了细微的、快速变化的“指纹”,用来分清谁是左邻右舍。
  • 高维度波形:给词向量加上了宏观的、缓慢变化的“背景色”,用来感知词在整句话中的大体位置。

这就是几何直观: 位置编码把每一个位置,都变成了一个由“16个不同转速齿轮状态”组成的精密密码。模型只需要看一眼这个密码,就能算出来这两个词到底离了多远。

position_encode

👉 类比(非常关键)

就像:音频合成

  • 一个音 = 多个频率叠加
  • 不同组合 → 不同音

Step 6:forward 时发生了什么?


X = X + pos_encoding

👉 关键点:

embedding(语义)
position(位置)

👉 合在一起:

👉 “这个词 + 它在哪”

多个钟表动画理解位置编码

动画源码


import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import matplotlib.gridspec as gridspec
import matplotlib.font_manager as fm


def configure_matplotlib_chinese_font():
    """自动选择可用中文字体,避免图中中文显示异常。"""
    candidates = [
        "Noto Sans CJK SC",
        "Noto Sans CJK",
        "WenQuanYi Micro Hei",
        "WenQuanYi Zen Hei",
        "SimHei",
        "Microsoft YaHei",
        "PingFang SC",
        "Source Han Sans CN",
        "Arial Unicode MS",
    ]
    installed = {f.name for f in fm.fontManager.ttflist}
    for name in candidates:
        if name in installed:
            plt.rcParams["font.family"] = "sans-serif"
            plt.rcParams["font.sans-serif"] = [name, "DejaVu Sans"]
            break
    # 避免负号显示为方块
    plt.rcParams["axes.unicode_minus"] = False


configure_matplotlib_chinese_font()

# ===============================
# 1. 位置编码生成函数
# ===============================
def get_pe(seq_len, dim):
    pe = np.zeros((seq_len, dim))
    position = np.arange(seq_len)[:, np.newaxis]
    # 频率指数:低维度对应高频率,高维度对应低频率
    div_term = np.exp(np.arange(0, dim, 2) * -(np.log(10000.0) / dim))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe

num_steps = 200    # 足够长的步数
encoding_dim = 32  # 编码总维度
pe_matrix = get_pe(num_steps, encoding_dim)

# ===============================
# 2. 画布初始化(使用 GridSpec 布局)
# ===============================
fig = plt.figure(figsize=(10, 12))
# 创建 3x2 的布局,最后一行用来展示向量
gs = gridspec.GridSpec(3, 2, height_ratios=[1, 1, 0.8])

# A. 四个圆周表盘
axes_circles = []
dims_to_show = [0, 8, 16, 30] # 选择四个观察的维度起始点
colors = ['red', 'blue', 'green', 'orange']
pointers = []

for i in range(4):
    row, col = i // 2, i % 2
    ax = fig.add_subplot(gs[row, col])
    d = dims_to_show[i]
    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.set_aspect('equal')
    ax.grid(True, linestyle='--', alpha=0.5)
    
    # 绘制背景圆圈
    circle = plt.Circle((0, 0), 1, fill=False, color='gray', alpha=0.3)
    ax.add_artist(circle)
    
    # 设置标题和速度描述
    speed_desc = ["Fast (区分邻居)", "Medium", "Medium-Slow", "Slow (感知大局)"][i]
    ax.set_title(f"Compass {i+1}: Dim {d}-{d+1}\nSpeed: {speed_desc}")
    
    # 初始化指针
    p, = ax.plot([], [], 'o-', lw=3, color=colors[i], markersize=8)
    pointers.append(p)
    axes_circles.append(ax)

# B. 位置向量可视化区域 (占据最后一整行)
ax_vec = fig.add_subplot(gs[2, :])
ax_vec.set_xlim(-0.5, encoding_dim - 0.5)
ax_vec.set_ylim(-1.1, 1.1)
ax_vec.set_title(f"Current Position Vector P (Dim 0-{encoding_dim-1})")
ax_vec.set_xticks(np.arange(encoding_dim))
ax_vec.set_xticklabels([f"D{i}" for i in range(encoding_dim)], rotation=45, fontsize=8)
ax_vec.set_ylabel("PE Value")
ax_vec.grid(axis='y', linestyle='--', alpha=0.5)

# 初始化向量柱状图
bars = ax_vec.bar(np.arange(encoding_dim), np.zeros(encoding_dim), color='purple', alpha=0.7)

# 全局位置和文本显示
info_text = fig.text(0.5, 0.015, '', ha='center', fontsize=12, fontweight='bold', color='purple')

# ===============================
# 3. 动画更新
# ===============================
def update(frame):
    artists = []
    current_pe_vector = pe_matrix[frame, :]
    
    # 3.1 更新四个圆周表盘的指针
    for i, p in enumerate(pointers):
        d = dims_to_show[i]
        # x 为 sin, y 为 cos
        x = current_pe_vector[d]
        y = current_pe_vector[d+1]
        p.set_data([0, x], [0, y])
        artists.append(p)
    
    # 3.2 更新向量柱状图
    for i, bar in enumerate(bars):
        val = current_pe_vector[i]
        bar.set_height(val)
        # 根据数值给柱子点色(正数绿色,负数红色)
        bar.set_color('green' if val > 0 else 'red')
        artists.append(bar)
    
    # 3.3 更新文本信息
    info_text.set_text(f"--- Position Index: {frame} ---")
    artists.append(info_text)
    
    return artists

ani = FuncAnimation(fig, update, frames=num_steps, interval=60, blit=True)

# 保存为GIF(保存在当前脚本同目录)
gif_path = os.path.join(os.path.dirname(__file__), "demo_07_animation.gif")
ani.save(gif_path, writer=PillowWriter(fps=15))
print(f"GIF已保存: {gif_path}")

plt.tight_layout(rect=[0, 0.04, 1, 0.96]) # 留出文本空间
plt.show()

location_encode_animation

音频合成角度理解位置编码

源码


import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import matplotlib.font_manager as fm
import os


def configure_matplotlib_chinese_font():
    """自动选择可用中文字体,避免中文显示异常。"""
    candidates = [
        "Noto Sans CJK SC",
        "Noto Sans CJK",
        "WenQuanYi Micro Hei",
        "WenQuanYi Zen Hei",
        "SimHei",
        "Microsoft YaHei",
        "PingFang SC",
        "Source Han Sans CN",
        "Arial Unicode MS",
    ]
    installed = {f.name for f in fm.fontManager.ttflist}
    for name in candidates:
        if name in installed:
            plt.rcParams["font.family"] = "sans-serif"
            plt.rcParams["font.sans-serif"] = [name, "DejaVu Sans"]
            break
    plt.rcParams["axes.unicode_minus"] = False


configure_matplotlib_chinese_font()

# ===============================
# 1. 模拟参数
# ===============================
num_positions = 100   # 序列长度(相当于时间轴)
dim = 32              # 向量维度(位置矩阵每行32维)
x_axis = np.linspace(0, 10, dim)
# 位置纹理由多个正弦分量叠加而成(可视化重点)
# 第3个分量频率不宜过高,否则在32个采样点下会出现“看起来不像正弦”的锯齿感。
frequencies = [1.0, 2.5, 4.0]   # 不同频率(平滑且层次分明)
amplitude = 0.3                 # 每个分量的振幅

# 模拟一个“词向量”:比如它是一个平滑的低频波
def get_word_vector():
    # 模拟“猫”这个词的语义信号(相对稳定)
    return 0.5 * np.sin(x_axis * 0.8) + 0.2 * np.cos(x_axis * 0.3)

# 模拟位置编码:随位置变化的三角函数
def get_pe_components(pos):
    """返回当前位置下的各正弦分量列表。"""
    components = []
    for f in frequencies:
        # 相位随位置变化:位置越靠后,波形整体相位越向前推进
        comp = amplitude * np.sin(x_axis * f + pos * (f / 5.0))
        components.append(comp)
    return components


def get_pe_signal(pos, dim):
    # 位置纹理信号 = 多个正弦分量逐维相加
    components = get_pe_components(pos)
    return np.sum(components, axis=0)

word_vec = get_word_vector()

# ===============================
# 2. 画布初始化
# ===============================
fig, (ax1, ax2, ax3, ax4) = plt.subplots(
    4, 1, figsize=(10, 14), gridspec_kw={"height_ratios": [1, 1, 1, 0.9]}
)
# 增加子图间距与底部留白,避免第2图图例与第3图重叠。
plt.subplots_adjust(hspace=0.72, bottom=0.10, top=0.95)

# 词向量图
line1, = ax1.plot(x_axis, word_vec, lw=2, color='blue')
ax1.set_title("1. 原始语义信号 (Word Embedding: 'CAT')", fontsize=12)
ax1.set_ylim(-1.5, 1.5)
ax1.grid(True, alpha=0.3)

# 位置编码拆解图:多个正弦分量 + 总叠加曲线
component_colors = ['tab:blue', 'tab:green', 'tab:red']
component_lines = []
for idx, f in enumerate(frequencies):
    comp_line, = ax2.plot(
        x_axis,
        np.zeros(dim),
        lw=1.5,
        color=component_colors[idx % len(component_colors)],
        alpha=0.9,
        linestyle='--',
        label=f"分量{idx+1}: sin(freq={f})",
    )
    component_lines.append(comp_line)

line2, = ax2.plot(x_axis, np.zeros(dim), lw=2.6, color='orange', label="分量叠加结果")
ax2.set_title("2. 位置纹理信号拆解:多个正弦波叠加", fontsize=12)
ax2.set_ylim(-1.5, 1.5)
ax2.grid(True, alpha=0.3)
ax2.legend(loc='upper center', bbox_to_anchor=(0.5, -0.16), ncol=4, frameon=False, fontsize=8)

# 叠加结果图
line3, = ax3.plot(x_axis, np.zeros(dim), lw=3, color='purple')
ax3.set_title("3. 最终合成信号 (Embedding + PE)", fontsize=12)
ax3.set_ylim(-1.5, 1.5)
ax3.grid(True, alpha=0.3)

# 位置矩阵行向量图(参考 demo_07)
# 含义:在当前位置frame,取位置矩阵P的第frame行(即当前PE向量)并显示各维度数值。
ax4.set_title("4. 当前位置 PE 向量(位置矩阵的一行)", fontsize=12)
ax4.set_xlim(-0.5, dim - 0.5)
ax4.set_ylim(-1.5, 1.5)
ax4.set_ylabel("PE值")
ax4.set_xticks(np.arange(dim))
ax4.set_xticklabels([f"D{i}" for i in range(dim)], rotation=45, fontsize=8)
ax4.grid(axis='y', linestyle='--', alpha=0.3)
bars = ax4.bar(np.arange(dim), np.zeros(dim), color='purple', alpha=0.7)

# 文字提示(使用整张图坐标,而不是某个子图坐标)
pos_text = fig.text(0.5, 0.03, '', ha='center', fontsize=16, fontweight='bold', color='darkred')

# ===============================
# 3. 动画更新
# ===============================
def update(frame):
    # 获取当前位置的PE信号
    components = get_pe_components(frame)
    current_pe = np.sum(components, axis=0)
    
    # 更新PE拆解图(分量 + 总和)
    for comp_line, comp in zip(component_lines, components):
        comp_line.set_ydata(comp)
    line2.set_ydata(current_pe)
    
    # 更新叠加图(相加操作)
    combined = word_vec + current_pe
    line3.set_ydata(combined)

    # 更新“当前位置PE向量”柱状图
    for i, bar in enumerate(bars):
        val = current_pe[i]
        bar.set_height(val)
        bar.set_color('green' if val > 0 else 'red')
    
    pos_text.set_text(f"当前句子位置 (Position Index): {frame}")

    return [*component_lines, line2, line3, pos_text, *bars]

ani = FuncAnimation(fig, update, frames=num_positions, interval=100, blit=True)

# 保存为GIF(保存在当前脚本同目录)
gif_path = os.path.join(os.path.dirname(__file__), "demo_08_animation.gif")
ani.save(gif_path, writer=PillowWriter(fps=10))
print(f"GIF已保存: {gif_path}")

plt.show()

location_encode_aduio

在 标准 Sin/Cos 位置编码 里:

  • d_model = 32 时,一共是 32 个维度分量 。
  • 其中按 (sin, cos) 成对出现,所以是 16 对 。
  • 每一对对应一个频率尺度,因此可理解为 16 组频率分量 。

也就是说:

  • 从“维度数”看: 32 个分量
  • 从“频率对(sin/cos对)”看: 16 个分量组

图形2中只画出了三个频率分量,理论上有16个