Transformer简介

为什么需要Transformer模型

背景与动机

随着自然语言处理(NLP)技术的快速发展,传统的序列模型逐渐暴露出明显的局限性:

传统RNN/LSTM的瓶颈

  • 计算效率低:必须按时间步顺序处理,无法并行计算,训练速度慢
  • 长期依赖问题:虽然LSTM通过门控机制缓解了梯度消失,但对于超长序列仍然难以捕捉远距离依赖关系
  • 信息传递路径长:信息需要逐步传递,容易造成信息损失

Transformer的创新突破

2017年,Google团队在论文《Attention is All You Need》中提出了Transformer架构,完全抛弃了循环结构,转而采用:

  • 自注意力机制(Self-Attention):允许模型直接建模序列中任意两个位置之间的关系
  • 多头注意力(Multi-Head Attention):从多个角度并行捕捉不同的语义信息
  • 并行计算:所有位置可以同时处理,大幅提升训练效率

Transformer模型的核心特点

89a67ca3e34e8114dcf669f7c33fb3dd_720.png

Transformer模型的核心特点

整体架构

Transformer架构图

Transformer采用经典的编码器-解码器(Encoder-Decoder)架构:

  • 编码器(Encoder):负责理解输入序列,提取语义特征
  • 解码器(Decoder):基于编码器的输出,生成目标序列
  • 层数设计:原始论文中编码器和解码器各包含6层相同结构的子层

三大核心特性

1. 自注意力机制(Self-Attention)

这是Transformer最核心的创新。传统模型只能关注局部上下文或固定窗口内的信息,而自注意力机制允许:

  • 每个位置可以直接关注到序列中的所有其他位置
  • 动态计算不同位置之间的相关性权重
  • 无视距离远近,直接建模长距离依赖关系

2. 并行计算能力

与RNN/LSTM的顺序处理不同,Transformer具有天然的并行性:

  • 所有位置的表示可以同时计算,无需等待前一时刻的结果
  • 充分利用现代GPU的并行计算能力
  • 训练速度相比RNN提升数倍甚至数十倍

3. 灵活的任务适应性

Transformer的架构设计极具通用性:

  • 可变长度输入:通过位置编码支持任意长度的序列
  • 可扩展性强:通过增加层数和注意力头数提升模型容量
  • 多任务适用:不仅限于机器翻译,还广泛应用于文本生成、分类、问答、摘要等任务
  • 迁移学习基础:成为BERT、GPT等预训练模型的基础架构

内部机制概述

Transformer的整体工作流程

Transformer工作流程图1

上图展示了Transformer的基本输入输出流程。输入序列经过嵌入层和位置编码后,进入编码器进行特征提取。

编码器内部结构

编码器通过多头注意力机制捕捉序列内部的依赖关系,然后通过前馈网络进行特征变换。

解码器工作机制

解码器在生成目标序列时,不仅关注自身已生成的内容(自注意力),还需要关注编码器的输出(交叉注意力)。

注意力机制可视化

注意力权重的可视化展示了模型如何在不同位置之间建立关联。颜色越深表示注意力权重越大。

多层堆叠效果

通过堆叠多层编码器和解码器,模型能够学习到从低级到高级的特征表示。

完整架构总览

这是Transformer的完整架构图,清晰展示了数据流动的全过程:从输入嵌入、位置编码,经过多层编码器和解码器,最终输出预测结果。

Encoder(编码器)

编码器的组成与作用

编码器由多头注意力机制(Multi-Head Attention)前馈神经网络(Feed Forward Network)两个核心模块构成。原始Transformer使用6层编码器堆叠,每层都包含相同的结构但参数独立。

分层特征提取

  • 底层:关注词汇的基本关系和短期依赖,识别基本的语法模式
  • 中间层:识别更长范围的依赖关系,捕捉词语之间的语义关联
  • 高层:关注整个句子的结构和深层语义,如句子级别的语法关系或情感倾向

通过多层结构,Transformer在每一层中对信息进行渐进的抽象和加工,最终获得高层次的、能够适应各种任务的表示。

编码器结构示意图1

编码器数据流动

编码器层级结构

编码器输出特征

Decoder(解码器)

解码器的结构特点

解码器同样由6层堆叠而成,但每层包含三个子模块(比编码器多一个):

  1. 带掩码的多头注意力层:采用Masked操作,确保生成过程的自回归特性
  2. 编码器-解码器注意力层:K和V矩阵来自编码器输出,Q矩阵来自解码器自身
  3. 前馈神经网络层:与编码器中的结构相同

掩码机制(Mask)

Transformer中使用两种掩码:

  • Padding Mask:在较短序列后填充0,避免模型关注填充位置
  • Sequence Mask:确保解码时只能依赖当前时刻之前的输出,不能”看到未来”

编码器输出矩阵C

编码器的输出是一个 n×d 的矩阵,其中:

  • n:输入序列长度
  • d:特征维度(由隐藏层维度决定)

该矩阵包含了输入序列每个元素的上下文嵌入表示,作为解码器交叉注意力的键值对。

解码器结构图

Transformer模型实现

本节将详细介绍Transformer各个组件的PyTorch实现,包括输入嵌入、位置编码、注意力机制等核心模块。

输入嵌入(Input Embeddings)

功能说明

将离散的token ID转换为连续的向量表示。例如,将句子”Your cat is a lovely cat”转换为512维向量序列。

实现细节

  • 使用PyTorch的nn.Embedding层实现token到向量的映射
  • 每个token ID对应一个固定的512维向量(可学习参数)
  • 输出向量乘以√d_model进行缩放(论文3.4节要求)

缩放的数学原理

缩放操作使嵌入向量的L2范数与维度无关,避免在后续与位置编码相加时某个分量过大。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class InputEmbeddings(nn.Module):
"""词嵌入层:将token ID转换为连续向量"""

def __init__(self, d_model: int, vocab_size: int) -> None:
"""
参数:
d_model: 嵌入向量维度(如512)
vocab_size: 词汇表大小(如10000)
"""
super().__init__()
self.d_model = d_model
self.vocab_size = vocab_size
self.embedding = nn.Embedding(vocab_size, d_model)

def forward(self, x):
"""
前向传播
输入: (batch, seq_len) - token IDs
输出: (batch, seq_len, d_model) - 嵌入向量
"""
return self.embedding(x) * math.sqrt(self.d_model)

位置编码(Positional Encoding)

为什么需要位置编码

Transformer的注意力机制本身是位置无关的(permutation invariant),无法区分”I ate an apple”和”An apple ate me”这类语序不同的句子。因此需要显式地为模型注入位置信息。

实现方式

使用预定义的数学公式生成位置编码,而非可学习参数:

位置编码公式

其中:

  • pos:单词在序列中的位置(0-based索引)
  • i:维度索引(0 ≤ i < d_model/2)
  • 偶数维度使用正弦函数,奇数维度使用余弦函数

融合方式

位置编码与词嵌入逐元素相加

1
最终表示 = 词嵌入 + 位置编码

关键特性

  • 相对位置感知:正弦/余弦函数的周期性使模型能捕捉相对位置关系
  • 可扩展性:允许处理比训练时更长的序列
  • 数值平衡:确保位置编码值域与词嵌入值域相匹配

可视化示例

1
2
3
4
5
6
7
8
9
10
# 假设 d_model=4, seq_len=3
词嵌入 = [[0.2, 1.1, -0.5, 0.8], # "Hello"
[0.7, -0.3, 1.2, 0.4], # "World"
[0.9, 0.5, -0.1, 1.0]] # "!"

位置编码 = [[0.0, 1.0, 0.0, 1.0], # 位置0
[0.84, 0.54, 0.002, 1.0], # 位置1
[0.91, -0.42, 0.003, 0.99]] # 位置2

最终表示 = 词嵌入 + 位置编码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class PositionalEncoding(nn.Module):
"""位置编码层(基于Attention is All You Need论文3.5节)"""

def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
"""
参数:
d_model: 模型维度(必须与词嵌入维度相同)
seq_len: 预设的最大序列长度
dropout: dropout概率
"""
super().__init__()
self.d_model = d_model
self.seq_len = seq_len
self.dropout = nn.Dropout(dropout)

# 预计算位置编码矩阵 (seq_len, d_model)
pe = torch.zeros(seq_len, d_model)

# 生成位置索引 [0, 1, 2, ..., seq_len-1]
position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # (seq_len, 1)

# 计算频率项分母(对数空间计算,数值更稳定)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

# 偶数维度使用正弦,奇数维度使用余弦
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# 添加batch维度 (1, seq_len, d_model)
pe = pe.unsqueeze(0)

# 注册为buffer(随模型保存但不参与梯度计算)
self.register_buffer('pe', pe)

def forward(self, x):
"""
前向传播
输入/输出: (batch_size, seq_len, d_model)
"""
# 截取与输入长度匹配的位置编码并相加
x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
return self.dropout(x)

层归一化(Layer Normalization)

核心原理

层归一化对每个样本独立进行标准化处理,分为三步:

  1. 计算统计量:为每个样本单独计算所有特征的均值和方差
  2. 标准化:将特征值转换为均值为0、方差为1的标准分布
  3. 可学习变换:通过gamma(缩放)和beta(偏移)参数调整输出

关键特性

  • 处理不同长度文本时更稳定
  • 与Transformer的残差连接配合良好
  • 训练和推理时行为一致(不依赖batch统计量)

数值稳定性

使用epsilon(ε = 10⁻⁶)避免除零错误,确保计算稳定性。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class LayerNormalization(nn.Module):
"""层归一化(参考论文《Layer Normalization》)"""

def __init__(self, features: int, eps: float=10**-6) -> None:
"""
参数:
features: 输入特征维度(对应d_model)
eps: 防止除零的小常数
"""
super().__init__()
self.eps = eps
self.alpha = nn.Parameter(torch.ones(features)) # 缩放参数γ
self.bias = nn.Parameter(torch.zeros(features)) # 偏移参数β

def forward(self, x):
"""
前向传播: (x-μ)/σ * α + β
输入/输出: (batch_size, seq_len, features)
"""
mean = x.mean(dim=-1, keepdim=True)
std = x.std(dim=-1, keepdim=True)
return self.alpha * (x - mean) / (std + self.eps) + self.bias

前馈神经网络(Feed Forward Network)

结构说明

前馈网络由两个线性层组成,中间使用ReLU激活函数:

1
FFN(x) = ReLU(xW₁ + b₁)W₂ + b₂

维度变换

  • 第一层:d_model → d_ff(扩展,通常d_ff = 4 × d_model)
  • 第二层:d_ff → d_model(压缩回原维度)

论文中使用 d_model=512,d_ff=2048。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class FeedForwardBlock(nn.Module):
"""前馈神经网络块"""

def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
"""
参数:
d_model: 模型维度(输入输出维度)
d_ff: 中间层扩展维度(通常为d_model的4倍)
dropout: 随机失活概率
"""
super().__init__()
self.linear_1 = nn.Linear(d_model, d_ff)
self.dropout = nn.Dropout(dropout)
self.linear_2 = nn.Linear(d_ff, d_model)

def forward(self, x):
"""
前向传播: (batch, seq_len, d_model) → (batch, seq_len, d_ff) → (batch, seq_len, d_model)
"""
return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))

多头注意力机制(Multi-Head Attention)

工作原理

多头注意力机制将输入通过三个线性变换生成Q(查询)、K(键)、V(值)矩阵,然后:

  1. 线性投影:输入通过W_Q、W_K、W_V三个权重矩阵生成Q、K、V
  2. 分头:将Q、K、V沿特征维度切分为h个头(如512维切分为8个64维的头)
  3. 并行注意力:每个头独立计算缩放点积注意力
  4. 合并:拼接所有头的输出,通过W_O线性层融合

核心流程

多头注意力流程图

分头策略

  • 沿特征维度而非序列维度切分
  • 每个头访问完整句子,但只看到嵌入的不同部分
  • 例如:d_model=512, h=8 → 每个头的维度d_k=64

关键特性

  • 多头设计使模型同时关注不同语义关系(语法/语义/指代等)
  • 输入输出维度一致(都是d_model),便于堆叠
  • 编码器中Q=K=V(自注意力),解码器中K、V来自编码器(交叉注意力)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class MultiHeadAttentionBlock(nn.Module):
"""多头注意力机制(论文3.2.2节)"""

def __init__(self, d_model: int, h: int, dropout: float) -> None:
"""
参数:
d_model: 模型维度(必须能被h整除)
h: 注意力头数量
dropout: dropout概率
"""
super().__init__()
self.d_model = d_model
self.h = h

# 确保d_model能被h整除
assert d_model % h == 0, "d_model必须能被h整除"

self.d_k = d_model // h # 每个头的维度

# 定义Q、K、V和输出的线性变换
self.w_q = nn.Linear(d_model, d_model, bias=False)
self.w_k = nn.Linear(d_model, d_model, bias=False)
self.w_v = nn.Linear(d_model, d_model, bias=False)
self.w_o = nn.Linear(d_model, d_model, bias=False)

self.dropout = nn.Dropout(dropout)

@staticmethod
def attention(query, key, value, mask, dropout: nn.Dropout):
"""
计算缩放点积注意力(论文3.2.1节)
返回: (注意力输出, 注意力权重)
"""
d_k = query.shape[-1]

# 计算注意力分数: Q·K^T / √d_k
attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)

# 应用掩码(将无效位置设为极小值)
if mask is not None:
attention_scores.masked_fill_(mask == 0, -1e9)

# Softmax归一化
attention_scores = attention_scores.softmax(dim=-1)

# 应用dropout
if dropout is not None:
attention_scores = dropout(attention_scores)

# 注意力加权求和
return (attention_scores @ value), attention_scores

def forward(self, q, k, v, mask):
"""
前向传播
输入/输出: (batch_size, seq_len, d_model)
"""
# 线性投影
query = self.w_q(q) # (batch, seq_len, d_model)
key = self.w_k(k)
value = self.w_v(v)

# 分头: (batch, seq_len, d_model) → (batch, h, seq_len, d_k)
query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

# 计算注意力
x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

# 合并头: (batch, h, seq_len, d_k) → (batch, seq_len, d_model)
x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)

# 最终线性变换
return self.w_o(x)

掩码机制说明

掩码用于控制注意力的可见范围:

  • 在Softmax之前将不希望关注的位置设为-∞(实际使用-1e9)
  • Softmax后这些位置的权重接近0
  • 用途:隐藏padding token、防止解码器看到未来信息

残差连接(Residual Connection)

功能说明

残差连接(也称跳跃连接)将子层的输入直接加到输出上,缓解深层网络的梯度消失问题。

实现细节

1
输出 = LayerNorm(输入) → 子层 → Dropout → + 输入

注意:这里采用Pre-LN结构(先归一化再计算),与原论文的Post-LN略有不同,但训练更稳定。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class ResidualConnection(nn.Module):
"""残差连接 + 层归一化 + Dropout"""

def __init__(self, features: int, dropout: float) -> None:
super().__init__()
self.dropout = nn.Dropout(dropout)
self.norm = LayerNormalization(features)

def forward(self, x, sublayer):
"""
参数:
x: 输入
sublayer: 子层函数(如注意力层或前馈层)
"""
return x + self.dropout(sublayer(self.norm(x)))

编码器块(Encoder Block)

结构组成

每个编码器块包含两个子层:

  1. 多头自注意力 + 残差连接 + 层归一化
  2. 前馈网络 + 残差连接 + 层归一化
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class EncoderBlock(nn.Module):
"""单个编码器层"""

def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock,
feed_forward_block: FeedForwardBlock, dropout: float) -> None:
super().__init__()
self.self_attention_block = self_attention_block
self.feed_forward_block = feed_forward_block
self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

def forward(self, x, src_mask):
"""
前向传播
参数:
x: 输入 (batch, seq_len, d_model)
src_mask: 源序列掩码(隐藏padding)
"""
# 自注意力子层(Q=K=V=x)
x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
# 前馈子层
x = self.residual_connections[1](x, self.feed_forward_block)
return x

完整编码器(Encoder)

多层堆叠

将N个编码器块堆叠,前一层的输出作为下一层的输入。

1
2
3
4
5
6
7
8
9
10
11
12
13
class Encoder(nn.Module):
"""完整编码器(N层编码器块堆叠)"""

def __init__(self, features: int, layers: nn.ModuleList) -> None:
super().__init__()
self.layers = layers
self.norm = LayerNormalization(features)

def forward(self, x, mask):
"""逐层处理输入"""
for layer in self.layers:
x = layer(x, mask)
return self.norm(x)

解码器块(Decoder Block)

结构组成

解码器块包含三个子层(比编码器多一个交叉注意力层):

  1. 掩码自注意力:Q=K=V来自解码器自身,使用目标掩码
  2. 交叉注意力:Q来自解码器,K和V来自编码器输出
  3. 前馈网络:与编码器相同的结构

自注意力 vs 交叉注意力

  • 自注意力:同一句子内的词相互关注(Q=K=V)
  • 交叉注意力:解码器关注编码器的输出(Q≠K,V)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class DecoderBlock(nn.Module):
"""单个解码器层"""

def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock,
cross_attention_block: MultiHeadAttentionBlock,
feed_forward_block: FeedForwardBlock, dropout: float) -> None:
super().__init__()
self.self_attention_block = self_attention_block
self.cross_attention_block = cross_attention_block
self.feed_forward_block = feed_forward_block
self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

def forward(self, x, encoder_output, src_mask, tgt_mask):
"""
前向传播
参数:
x: 解码器输入
encoder_output: 编码器输出
src_mask: 源序列掩码
tgt_mask: 目标序列掩码(因果掩码)
"""
# 掩码自注意力(Q=K=V=x)
x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
# 交叉注意力(Q=x, K=V=encoder_output)
x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
# 前馈网络
x = self.residual_connections[2](x, self.feed_forward_block)
return x

完整解码器(Decoder)

1
2
3
4
5
6
7
8
9
10
11
12
13
class Decoder(nn.Module):
"""完整解码器(N层解码器块堆叠)"""

def __init__(self, features: int, layers: nn.ModuleList) -> None:
super().__init__()
self.layers = layers
self.norm = LayerNormalization(features)

def forward(self, x, encoder_output, src_mask, tgt_mask):
"""逐层处理输入"""
for layer in self.layers:
x = layer(x, encoder_output, src_mask, tgt_mask)
return self.norm(x)

投影层(Projection Layer)

功能说明

将解码器的输出(d_model维)投影到词汇表空间(vocab_size维),用于预测下一个token。

1
2
3
4
5
6
7
8
9
10
11
12
13
class ProjectionLayer(nn.Module):
"""线性投影层:d_model → vocab_size"""

def __init__(self, d_model, vocab_size) -> None:
super().__init__()
self.proj = nn.Linear(d_model, vocab_size)

def forward(self, x):
"""
输入: (batch, seq_len, d_model)
输出: (batch, seq_len, vocab_size)
"""
return self.proj(x)

完整Transformer模型

模型组装

将所有组件组装成完整的Transformer模型,包括编码器、解码器、嵌入层、位置编码和投影层。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Transformer(nn.Module):
"""完整的Transformer模型"""

def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings,
tgt_embed: InputEmbeddings, src_pos: PositionalEncoding,
tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.src_embed = src_embed
self.tgt_embed = tgt_embed
self.src_pos = src_pos
self.tgt_pos = tgt_pos
self.projection_layer = projection_layer

def encode(self, src, src_mask):
"""编码源序列"""
src = self.src_embed(src)
src = self.src_pos(src)
return self.encoder(src, src_mask)

def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor,
tgt: torch.Tensor, tgt_mask: torch.Tensor):
"""解码目标序列"""
tgt = self.tgt_embed(tgt)
tgt = self.tgt_pos(tgt)
return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

def project(self, x):
"""投影到词汇表空间"""
return self.projection_layer(x)

模型构建函数

功能说明

根据超参数构建完整的Transformer模型,并使用Xavier初始化参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, 
tgt_seq_len: int, d_model: int=512, N: int=6, h: int=8,
dropout: float=0.1, d_ff: int=2048) -> Transformer:
"""
构建Transformer模型

参数:
src_vocab_size: 源语言词汇表大小
tgt_vocab_size: 目标语言词汇表大小
src_seq_len: 源序列最大长度
tgt_seq_len: 目标序列最大长度
d_model: 模型维度(默认512)
N: 编码器/解码器层数(默认6)
h: 注意力头数(默认8)
dropout: dropout概率(默认0.1)
d_ff: 前馈网络中间层维度(默认2048)
"""

# 创建嵌入层
src_embed = InputEmbeddings(d_model, src_vocab_size)
tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

# 创建位置编码层
src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

# 创建编码器块
encoder_blocks = []
for _ in range(N):
encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
encoder_block = EncoderBlock(d_model, encoder_self_attention_block, feed_forward_block, dropout)
encoder_blocks.append(encoder_block)

# 创建解码器块
decoder_blocks = []
for _ in range(N):
decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
decoder_block = DecoderBlock(d_model, decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
decoder_blocks.append(decoder_block)

# 创建编码器和解码器
encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))

# 创建投影层
projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

# 组装Transformer
transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

# 使用Xavier均匀初始化参数
for p in transformer.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)

return transformer

Transformer训练实现

本节介绍如何训练Transformer模型,包括数据加载、训练循环、验证和推理。

导入依赖库

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 自定义模块
from model import build_transformer
from dataset import BilingualDataset, causal_mask
from config import get_config, get_weights_file_path, latest_weights_file_path

# PyTorch核心库
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import LambdaLR

# 数据处理
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

# 工具库
import warnings
from tqdm import tqdm
import os
from pathlib import Path
import torchmetrics
from torch.utils.tensorboard import SummaryWriter

贪心解码函数

功能说明

在推理阶段使用贪心算法逐个生成目标序列的token,每次选择概率最高的token。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
"""
贪心解码:逐token生成目标序列

参数:
model: 训练好的Transformer模型
source: 源序列 (已编码)
source_mask: 源序列掩码
tokenizer_src/tgt: 源/目标语言分词器
max_len: 生成序列最大长度
device: 计算设备

返回:
生成的目标序列 (token IDs)
"""
sos_idx = tokenizer_tgt.token_to_id('[SOS]')
eos_idx = tokenizer_tgt.token_to_id('[EOS]')

# 预计算编码器输出(只需计算一次)
encoder_output = model.encode(source, source_mask)

# 初始化解码器输入为[SOS]
decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

# 自回归生成
while True:
if decoder_input.size(1) == max_len:
break

# 创建因果掩码
decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)

# 解码
out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

# 投影到词汇表并选择概率最高的token
prob = model.project(out[:, -1])
_, next_word = torch.max(prob, dim=1)

# 将新token添加到序列
decoder_input = torch.cat([
decoder_input,
torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)
], dim=1)

# 遇到EOS则停止
if next_word == eos_idx:
break

return decoder_input.squeeze(0)

验证函数

功能说明

在验证集上评估模型性能,计算CER、WER和BLEU等指标。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, 
print_msg, global_step, writer, num_examples=2):
"""
运行模型验证

参数:
model: Transformer模型
validation_ds: 验证数据集
tokenizer_src/tgt: 分词器
max_len: 最大序列长度
device: 计算设备
print_msg: 打印函数
global_step: 当前训练步数
writer: TensorBoard写入器
num_examples: 验证样本数量
"""
model.eval()
count = 0

source_texts = []
expected = []
predicted = []

try:
with os.popen('stty size', 'r') as console:
_, console_width = console.read().split()
console_width = int(console_width)
except:
console_width = 80

with torch.no_grad():
for batch in validation_ds:
count += 1
encoder_input = batch["encoder_input"].to(device)
encoder_mask = batch["encoder_mask"].to(device)

assert encoder_input.size(0) == 1, "验证时batch_size必须为1"

# 贪心解码生成预测
model_out = greedy_decode(model, encoder_input, encoder_mask,
tokenizer_src, tokenizer_tgt, max_len, device)

source_text = batch["src_text"][0]
target_text = batch["tgt_text"][0]
model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

source_texts.append(source_text)
expected.append(target_text)
predicted.append(model_out_text)

# 打印示例
print_msg('-' * console_width)
print_msg(f"{f'SOURCE: ':>12}{source_text}")
print_msg(f"{f'TARGET: ':>12}{target_text}")
print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

if count == num_examples:
print_msg('-' * console_width)
break

# 计算评估指标
if writer:
# 字符错误率
metric = torchmetrics.CharErrorRate()
cer = metric(predicted, expected)
writer.add_scalar('validation cer', cer, global_step)
writer.flush()

# 词错误率
metric = torchmetrics.WordErrorRate()
wer = metric(predicted, expected)
writer.add_scalar('validation wer', wer, global_step)
writer.flush()

# BLEU分数
metric = torchmetrics.BLEUScore()
bleu = metric(predicted, expected)
writer.add_scalar('validation BLEU', bleu, global_step)
writer.flush()

数据处理函数

获取句子生成器

1
2
3
4
5
6
7
8
9
10
def get_all_sentences(ds, lang):
"""
从数据集提取指定语言的所有句子(生成器)

参数:
ds: 数据集对象
lang: 语言代码(如'en', 'fr')
"""
for item in ds:
yield item['translation'][lang]

构建或加载分词器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def get_or_build_tokenizer(config, ds, lang):
"""
获取或构建指定语言的分词器

参数:
config: 配置字典
ds: 数据集(用于训练分词器)
lang: 语言代码

返回:
Tokenizer对象
"""
tokenizer_path = Path(config['tokenizer_file'].format(lang))

if not Path.exists(tokenizer_path):
# 训练新分词器
tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = WordLevelTrainer(
special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"],
min_frequency=2
)
tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
tokenizer.save(str(tokenizer_path))
else:
# 加载已有分词器
tokenizer = Tokenizer.from_file(str(tokenizer_path))

return tokenizer

加载并准备数据集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def get_ds(config):
"""
加载数据集并创建数据加载器

返回:
train_dataloader: 训练集加载器
val_dataloader: 验证集加载器
tokenizer_src: 源语言分词器
tokenizer_tgt: 目标语言分词器
"""
# 加载原始数据集
ds_raw = load_dataset(
f"{config['datasource']}",
f"{config['lang_src']}-{config['lang_tgt']}",
split='train'
)

# 构建分词器
tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

# 划分训练集和验证集(90%/10%)
train_ds_size = int(0.9 * len(ds_raw))
val_ds_size = len(ds_raw) - train_ds_size
train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

# 创建BilingualDataset实例
train_ds = BilingualDataset(
train_ds_raw, tokenizer_src, tokenizer_tgt,
config['lang_src'], config['lang_tgt'], config['seq_len']
)
val_ds = BilingualDataset(
val_ds_raw, tokenizer_src, tokenizer_tgt,
config['lang_src'], config['lang_tgt'], config['seq_len']
)

# 统计最大句子长度
max_len_src = 0
max_len_tgt = 0
for item in ds_raw:
src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
max_len_src = max(max_len_src, len(src_ids))
max_len_tgt = max(max_len_tgt, len(tgt_ids))

print(f'源语言最大长度: {max_len_src}')
print(f'目标语言最大长度: {max_len_tgt}')

# 创建数据加载器
train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

构建模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def get_model(config, vocab_src_len, vocab_tgt_len):
"""
构建Transformer模型

参数:
config: 配置字典
vocab_src_len: 源语言词汇表大小
vocab_tgt_len: 目标语言词汇表大小
"""
model = build_transformer(
vocab_src_len, vocab_tgt_len,
config["seq_len"], config['seq_len'],
d_model=config['d_model']
)
return model

训练主函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
def train_model(config):
"""
Transformer模型训练主函数

参数:
config: 配置字典,包含所有训练超参数
"""
# 选择设备
device = "cuda" if torch.cuda.is_available() else "mps" if torch.has_mps or torch.backends.mps.is_available() else "cpu"
print("使用设备:", device)

if device == 'cuda':
print(f"设备名称: {torch.cuda.get_device_name(device.index)}")
print(f"显存大小: {torch.cuda.get_device_properties(device.index).total_memory / 1024 ** 3} GB")

device = torch.device(device)

# 创建权重保存目录
Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)

# 加载数据集
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)

# 构建模型
model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

# TensorBoard记录器
writer = SummaryWriter(config['experiment_name'])

# 优化器
optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

# 加载预训练模型(如果有)
initial_epoch = 0
global_step = 0
preload = config['preload']
model_filename = latest_weights_file_path(config) if preload == 'latest' else get_weights_file_path(config, preload) if preload else None

if model_filename:
print(f'加载模型: {model_filename}')
state = torch.load(model_filename)
model.load_state_dict(state['model_state_dict'])
initial_epoch = state['epoch'] + 1
optimizer.load_state_dict(state['optimizer_state_dict'])
global_step = state['global_step']
else:
print('从头开始训练')

# 损失函数(带标签平滑)
loss_fn = nn.CrossEntropyLoss(
ignore_index=tokenizer_src.token_to_id('[PAD]'),
label_smoothing=0.1
).to(device)

# 训练循环
for epoch in range(initial_epoch, config['num_epochs']):
torch.cuda.empty_cache()
model.train()
batch_iterator = tqdm(train_dataloader, desc=f"Epoch {epoch:02d}")

for batch in batch_iterator:
# 准备数据
encoder_input = batch['encoder_input'].to(device)
decoder_input = batch['decoder_input'].to(device)
encoder_mask = batch['encoder_mask'].to(device)
decoder_mask = batch['decoder_mask'].to(device)
label = batch['label'].to(device)

# 前向传播
encoder_output = model.encode(encoder_input, encoder_mask)
decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
proj_output = model.project(decoder_output)

# 计算损失
loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

# 记录到TensorBoard
writer.add_scalar('train loss', loss.item(), global_step)
writer.flush()

# 反向传播
loss.backward()

# 更新参数
optimizer.step()
optimizer.zero_grad(set_to_none=True)

global_step += 1

# 每个epoch后运行验证
run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt,
config['seq_len'], device, lambda msg: batch_iterator.write(msg),
global_step, writer)

# 保存模型
model_filename = get_weights_file_path(config, f"{epoch:02d}")
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'global_step': global_step
}, model_filename)

总结

本文详细介绍了Transformer模型的原理和PyTorch实现,主要内容包括:

核心组件

  • 输入嵌入与位置编码:将离散token转换为连续向量并注入位置信息
  • 多头注意力机制:从多个角度并行捕捉序列中的依赖关系
  • 前馈神经网络:对每个位置独立进行非线性变换
  • 层归一化与残差连接:稳定训练过程,缓解梯度消失

模型架构

  • 编码器:6层堆叠,每层包含自注意力和前馈网络
  • 解码器:6层堆叠,每层包含自注意力、交叉注意力和前馈网络
  • 投影层:将解码器输出映射到词汇表空间

训练流程

  • 数据处理:分词、批处理、掩码生成
  • 训练循环:前向传播、损失计算、反向传播、参数更新
  • 验证评估:使用CER、WER、BLEU等指标评估模型性能

Transformer的成功在于其完全基于注意力机制的架构设计,摒弃了传统的循环结构,实现了高效的并行计算和长距离依赖建模,为后续的BERT、GPT等预训练模型奠定了基础。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
    optimizer.step()

# 清空梯度,节省内存
optimizer.zero_grad(set_to_none=True)

# 更新步数
global_step += 1

# 每个epoch结束后运行验证逻辑
run_validation(
model,
val_dataloader,
tokenizer_src,
tokenizer_tgt,
config['seq_len'],
device,
lambda msg: batch_iterator.write(msg), # 用 tqdm 输出日志信息
global_step,
writer
)

# 每个epoch结束后保存模型状态
model_filename = get_weights_file_path(config, f"{epoch:02d}")
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'global_step': global_step
}, model_filename)
1
2
3
4
if __name__ == '__main__':  # 主程序入口,确保当前脚本被直接运行时才执行以下代码
warnings.filterwarnings("ignore") # 忽略所有警告信息,保持输出界面清爽
config = get_config() # 获取训练配置参数,通常来自配置文件或定义函数
train_model(config) # 调用训练函数,开始模型训练流程

Dataset

1
2
3
4
#处理双语数据集
import torch
import torch.nn as nn
from torch.utils.data import Dataset

import torch:导入 PyTorch 主库,用于张量运算。

import torch.nn as nn:导入神经网络模块并简写为 nn,方便后续如果需要网络层时使用。

from torch.utils.data import Dataset:从PyTorch数据工具中导入 Dataset 基类,用来构建自定义数据集。

1
2
3
4
5
6
7
8
9
10
11
12
#定义一个新的数据集类 BilingualDataset,继承自 PyTorch 的 Dataset,用于加载双语翻译对。
class BilingualDataset(Dataset):
def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
super().__init__()
self.seq_len = seq_len

self.ds = ds
self.tokenizer_src = tokenizer_src
self.tokenizer_tgt = tokenizer_tgt
self.src_lang = src_lang
self.tgt_lang = tgt_lang

init:初始化方法,接受以下参数:

  • ds:原始数据集(如从 HuggingFace Dataset 加载的翻译对)。
  • tokenizer_src / tokenizer_tgt:源语言和目标语言的分词器。
  • src_lang / tgt_lang:字符串,指明在每个数据项里使用哪个语言字段。
  • seq_len:固定的序列长度(包含特殊token)。
  • 将这些参数保存到实例属性,以便后续 __getitem__ 中使用。
1
2
3
self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

从目标分词器中获取特殊符号 [SOS][EOS][PAD] 的ID,并封装成形状为 (1,) 的整型张量,方便后面拼接。

1
2
3
#定义数据集长度
def __len__(self):
return len(self.ds)

len:返回数据集的条目数,使得 DataLoader 能够知道迭代上限。

1
2
3
4
5
#定义获取项目方法
def __getitem__(self, idx):
src_target_pair = self.ds[idx]
src_text = src_target_pair['translation'][self.src_lang]
tgt_text = src_target_pair['translation'][self.tgt_lang]

getitem:根据索引 idx 取出一条翻译对,分别抽取源语言文本 src_text 和目标语言文本 tgt_text

1
2
3
# Transform the text into tokens
enc_input_tokens = self.tokenizer_src.encode(src_text).ids
dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

分别对源文和目标文进行分词,得到 ID 列表 enc_input_tokensdec_input_tokens

1
2
3
4
# Add sos, eos and padding to each sentence
enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2 # We will add <s> and </s>
# We will only add <s>, and </s> only on the label
dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

计算要补多少个 [PAD]

  • 源端要加 [SOS]+[EOS] 共2个特殊符,故剩余长度为 seq_len - 原始长度 - 2
  • 目标端的 decoder 输入只加 [SOS],故剩余长度为 seq_len - 原始长度 - 1
1
2
if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
raise ValueError("Sentence is too long")

如果算出来的补齐长度为负,说明句子太长,超过了 seq_len,直接抛错提醒。

1
2
3
4
5
6
7
8
9
10
11
#切割三个张量
# Add <s> and </s> token
encoder_input = torch.cat(
[
self.sos_token,#首先是这个句子的开头标记,
torch.tensor(enc_input_tokens, dtype=torch.int64),#然后是源文本的标记
self.eos_token,#然后是句子的结尾标记
torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64),#然后是足够的填充标记以达到序列长度
],
dim=0,
)

构造 Encoder 输入序列:

  1. 添加 [SOS]

  2. 源语言分词 ID;

  3. 添加 [EOS]

  4. 补齐若干个 [PAD]

    最后拼成形状 (seq_len,) 的张量。

1
2
3
4
5
6
7
8
9
# Add only <s> token
decoder_input = torch.cat(
[
self.sos_token,
torch.tensor(dec_input_tokens, dtype=torch.int64),
torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
],
dim=0,
)

构造Decoder输入序列:只在最前面加 [SOS] 和尾部补齐 [PAD]

1
2
3
4
5
6
7
8
9
# Add only </s> token
label = torch.cat(
[
torch.tensor(dec_input_tokens, dtype=torch.int64),
self.eos_token,
torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
],
dim=0,
)

构造训练目标(标签)序列:紧跟分词 ID 后面加 [EOS],再补齐。

1
2
3
4
# Double check the size of the tensors to make sure they are all seq_len long
assert encoder_input.size(0) == self.seq_len
assert decoder_input.size(0) == self.seq_len
assert label.size(0) == self.seq_len

断言三者长度都等于 seq_len保证模型输入输出的一致性。

1
2
3
4
5
6
7
8
9
10
11
return {
"encoder_input": encoder_input, # (seq_len)编码器输入,是越位的序列长度
"decoder_input": decoder_input, # (seq_len)解码器输入,是一个序列长度的标记数
"encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, seq_len)
#编码器掩码,通过添加填充标记来增加编码器输入句子的大小,但是我们 我们不希望这些填充标记参与自注意力机制,所以我们需要构建一个掩码,不希望这些标记被自注意力机制看到。我们还会挤压以添加此序列维度,稍后还会添加批处理维度。然后我们将其转换为整数,因此这是一个序列长度。
"decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)), # (1, seq_len) & (1, seq_len, seq_len),
#解码器,我们需要一个特殊的掩码,即因果掩码。意味着每个单词只能查看前面的单词,每个单词只能查看未知的填充单词,因此我们不希望填充标记参与自注意力机制,我们只希望真实的单词参与其中,并且我们也不希望每个单词都关注其后面的单词,而只关注其前面的单词,因此我将在这里使用一种称为因果掩码的方法稍后我们会构建它。
"label": label, # (seq_len)
"src_text": src_text,
"tgt_text": tgt_text,
}

返回由编码器输入组成的字典,包含:

  • encoder_inputdecoder_input:前面拼好的整型张量;
  • encoder_mask:对 encoder_input 中非 [PAD] 的位置置 1,shape 为 (1,1,seq_len),用于 self-attention。
  • decoder_mask:先对非 [PAD] 位置置 1,得到 (1,seq_len),再与 causal_mask(下述函数生成的因果遮挡矩阵)按位 AND,得到 (1,seq_len,seq_len),用于 Transformer 解码器的自回归限制。
  • label:训练用的目标序列;
  • src_texttgt_text:原始文本,方便后续打印或调试。
1
2
3
4
def causal_mask(size):
mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
#这个方法将返回对角线上方的每个值,其他所有值都将变为零,所以我们想要对角线的一种类型,我们希望它是整数,我们要做的是返回掩码等于零,所以这将返回对角线上方的所有值,对角线下方的所有值都将变为零,我们实际上想要相反的结果,即零应该会在这个表达式中变为真,所有非0的值都会变为假
return mask == 0

causal_mask:生成一个上三角全 1、主对角线以上(diagonal=1)为 1 的张量,然后取反得到下三角及对角线位置为 True,上三角为 False,用于在解码时屏蔽未来位置。

这段代码实现了一个用于序列到序列(sequence‑to‑sequence)机器翻译任务的数据集类 BilingualDataset。它将原始的双语文本对:

  1. 分词 → 得到整数 ID 列表
  2. 添加特殊标记 [SOS], [EOS][PAD] → 统一成固定长度
  3. 构造注意力掩码 → Encoder 掩掉 PAD,Decoder 同时掩掉 PAD 和未来 token
  4. 返回模型所需的输入格式(包括 encoder_inputdecoder_input、注意力掩码、以及训练标签)

从而能够直接喂给基于Transformer的翻译模型进行训练或推理。

Config

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from pathlib import Path
def get_config():
return {
"batch_size": 8,
"num_epochs": 20,
"lr": 10**-4,
"seq_len": 350,
"d_model": 512,
"datasource": 'opus_books',
"lang_src": "en",
"lang_tgt": "it",
"model_folder": "weights",
"model_basename": "tmodel_",
"preload": "latest",
"tokenizer_file": "tokenizer_{0}.json",
"experiment_name": "runs/tmodel"
}

get_config:返回一个字典对象,包含模型训练所需的各种超参数和文件配置。

batch_size:每批次处理8条样本

num_epochs:训练轮数为20

lr:学习率设置为0.0001

seq_len:每个输入序列最大长度为350

d_model:Transformer模型的隐藏维度为512

datasource:数据源名,便于标识不同数据集(此处是 opus_books

lang_srclang_tgt:源语言和目标语言(如从英语翻译到意大利语)

model_folder:保存模型权重的文件夹(如 weights

model_basename:模型文件的前缀名(如 tmodel_5.pt

preload:加载哪个权重(”latest” 代表自动找最新的)

tokenizer_file:分词器的文件名模板

experiment_name:实验记录的路径(如TensorBoard的日志)

1
2
3
4
def get_weights_file_path(config, epoch: str):
model_folder = f"{config['datasource']}_{config['model_folder']}" # 拼接成目录名
model_filename = f"{config['model_basename']}{epoch}.pt" # 拼接模型文件名
return str(Path('.') / model_folder / model_filename) # 返回完整路径字符串

get_weights_file_path:根据配置和给定的epoch数,生成当前epoch模型文件的完整路径。

1
2
3
4
5
6
7
8
def latest_weights_file_path(config):
model_folder = f"{config['datasource']}_{config['model_folder']}" #权重文件所在的目录
model_filename = f"{config['model_basename']}*" #匹配所有模型文件
weights_files = list(Path(model_folder).glob(model_filename)) # 用glob匹配目录中所有模型文件
if len(weights_files) == 0: #如果没有模型文件
return None
weights_files.sort() #按名称排序
return str(weights_files[-1]) #返回最新的那个(排序最后一个)

latest_weights_file_path:查找给定目录下最新(最后一个按名字排序)的模型权重文件的完整路径。如果没有任何权重文件,则返回 None

这段代码提供了训练和管理模型的一套配置信息管理工具。主要实现了以下功能:

get_config() 函数集中定义训练参数(如batch size、学习率、语言设置、模型文件名格式等)。

提供 get_weights_file_path()latest_weights_file_path() 两个函数来动态生成模型权重文件的保存路径或加载路径,支持按照epoch命名和获取最新模型。

这个模块的设计非常适合用于训练循环中管理模型的保存和加载行为,是构建机器学习训练框架的重要一部分。