软件调优(四):避免训练不稳定

为了避免训练不稳定,需要着重注意以下要点。内容整理自 Stas Bekman 的 Machine Learning Engineering by Stas Bekman


术语表

训练相关内容中常包含大量缩写和专有名词。以下是本文涉及的一些重要概念:

  • BS:Batch Size(批次大小)—— 在这里通常指每个 GPU 上一次处理的样本数,也常称为 MBS(Micro Batch Size,微批次大小)。

  • GBS:Global Batch Size(全局批次大小)—— 指一次完整迭代(iteration)中实际参与参数更新的总样本数,通常等于所有 GPU、所有数据并行副本以及梯度累积步数共同作用后的总 batch size。

  • GAS:Gradient Accumulation Steps(梯度累积步数)—— 指在执行一次参数更新之前,需要累计多少次前向传播和反向传播的梯度。

  • TFLOPs:每秒万亿次浮点运算,用于衡量训练过程中的计算吞吐量。

  • PP:Pipeline Parallelism(流水线并行)—— 将模型的不同层切分到不同设备上,以流水线方式并行执行训练。


全局批次大小逐步提升(Global Batch Size Ramp Up)

如果计划使用非常大的 GBS 进行训练,例如 1024、2048,甚至更大,那么在训练一开始就直接使用如此大的 batch size 往往并不划算。训练初期模型参数基本仍处于随机初始化状态,对数据中的细微信息尚不敏感,因此过早使用大 batch 可能无法带来相应收益,反而会浪费数据和计算资源。为提高训练效率,实践中通常会在训练初期的一段时间内逐步增大全局批次大小

不过,初始 GBS 也不能设置得过小。过小的 GBS 会降低硬件利用率,使计算吞吐量(TFLOPs)偏低,从而拖慢整体训练速度。这一点在使用 流水线并行(PP) 时尤其明显:PP 的效率很大程度上取决于能否减少流水线中的 GPU 空闲气泡(bubble)。GBS 越小,可用于填充流水线的 micro-batch 数越少,气泡占比就越高,GPU 空闲时间也就越多。

以 BLOOM-176B 的训练为例:该模型训练使用了流水线并行。经过吞吐量基准测试后发现,如果从 GBS = 16 开始训练,速度会非常慢,吞吐量只有约 8 TFLOPs。因此,最终选择从 GBS = 192 起步,此时吞吐量约为 73 TFLOPs;随后再逐步提升到 GBS = 2048,最终吞吐量可达到约 150 TFLOPs。具体策略是:每处理 9,765,625 个样本,就将 GBS 增加 16


权重初始化标准差(std)的选择

权重初始化的标准差(std)不是固定的,必须根据隐藏层维度(hidden dimension)调整。 选错了,模型在训练初期就会崩溃。

背景故事

BLOOM 团队在训练 1040 亿参数 的预实验模型时,遇到了严重的训练不稳定问题:

  • 使用 Megatron-LM 框架的 默认初始化 std = 0.02
  • 结果:训练几千步后就崩溃(loss 爆炸、梯度消失/爆炸)
  • 排查后发现:0.02 对这个规模的模型来说太大了

两种初始化公式对比

来源 论文 公式 等价形式
Transformers without Tears arXiv:1910.05895 sqrt(2 / (NHIDDEN * 5)) sqrt(0.4000 / NHIDDEN)
530B 模型训练实践 arXiv:2201.11990 sqrt(1 / (NHIDDEN * 3)) sqrt(0.3333 / NHIDDEN)

BLOOM 选择了第二个(530B 的公式),因为它给出的初始化值更小、更保守。其中, NHIDDEN 是隐藏层维度。

因此,当 NHIDDEN = 14336 时,计算结果为 $\sqrt{1/(14336 \times 3)} = 0.00482$,这就是实际采用的标准差的值。作者认为,这当然不是 BLOOM-176B 训练过程中没有出现稳定性问题的唯一原因,但它是关键因素之一。

结论

虽然这不是 BLOOM-176B 训练稳定的唯一原因,但作者认为它是关键因素之一

这本质上是在控制初始权重的幅度:太大的初始化会让前向传播初期的激活值爆炸,太小则可能导致梯度消失。大模型的深度和宽度放大了这个问题。


避免数值溢出

在 FP16/BF16 混合精度训练中,Q @ K^T 可能先产生很大的中间结果;如果缩放因子 1 / self.norm_factor[1] 是在矩阵乘法之后才乘上去,那么中间结果可能已经溢出成 inf 或产生不稳定值,后面的缩放已经救不回来。

所以,纠正做法是:不要等 Q @ K^T 算完后再整体乘缩放因子,而是把缩放因子提前拆到 QK 上,让它们先变小,再做矩阵乘法。即:

(Q / sqrt(norm_factor)) @ (K^T / sqrt(norm_factor))
= (Q @ K^T) / norm_factor

这样可以有效避免数值溢出。

评论