Stable and low-precision training for large-scale vision-language models
文章标题: 稳定且低精度的大规模视觉语言模型训练
作者/机构: Mitchell Wortsman1, Tim Dettmers1, Luke Zettlemoyer12, Ari Morcos†2, Ali Farhadi†1, Ludwig Schmidt†134
A1 主要贡献
本文旨在解决大规模语言-视觉模型训练中的两个关键问题:训练速度和训练稳定性,并以对比语言-图像预训练(CLIP)模型为研究背景。
1. 提升训练速度(快速训练):
- 核心贡献: 提出了 SwitchBack,一个用于 int8 量化训练的线性层。
- 性能表现: 在训练10亿参数的 CLIP ViT-Huge 模型时,SwitchBack 相比 bfloat16 训练的性能差距在0.1个百分点以内,但端到端训练速度提升了13-25%。这是迄今为止规模最大的 int8 训练。
- 设计原理: SwitchBack 的设计基于一个观察:矩阵乘法中的量化噪声会随着内积维度的增大而增长。在CLIP训练中,权重梯度的计算涉及一个非常大的内积维度(批次大小 × 序列长度)。因此,SwitchBack 对前向传播和输入梯度计算使用 int8 矩阵乘法,而对权重梯度计算则“切换回”(Switch Back)16位精度,从而在保持精度的同时获得加速。该方法性能显著优于 LLM.int8()。
- Float8(fp8)训练探索: 尽管当前支持fp8的硬件稀少,作者通过模拟研究了fp8训练。SwitchBack 对fp8同样有效。此外,研究发现,如果通过特定方法(如使用零初始化的 layer-scale)抑制网络中特征的幅度,标准的fp8量化技术也能成功训练,避免了基线方法在大模型上发散的问题。
2. 提升训练稳定性(稳定训练):
- 核心观察: 分析发现,训练过程中的损失尖峰(loss spikes)总是在 AdamW 优化器的二阶矩估计量低估了梯度的平方值之后的1-8次迭代内发生。这种情况在 patch embedding 层尤为明显。
- 核心贡献: 提出了一种 AdamW-AdaFactor 混合优化器,命名为 StableAdamW。
- 设计原理: StableAdamW 借鉴了 AdaFactor 中的“更新裁剪”(update clipping)技术。该技术会追踪梯度平方与二阶矩估计量之间的平均比率,当该比率过大时(意味着二阶矩估计量已过时),便会降低学习率,从而避免因更新步长过大导致的损失尖峰。
- 性能表现: 在 CLIP ViT-Huge 模型的训练中,StableAdamW 成功消除了不稳定性,并且在测试的规模上性能优于梯度裁剪(gradient clipping)方法。
A3 背景知识与关键观察
2.1 预备知识与相关工作
16位训练格式与混合精度。当今神经网络训练通常使用16位操作,格式为 float16 或 bfloat16 【39,【63,Bfloat16: The secret to high performance on cloud tpus,2019,https://cloud.google. com/blog/products/ai-machine-learning/ bfloat16-the-secret-to-high-performance-on-c】】。浮点格式使用一部分位表示指数,其余表示小数(尾数)。float16 使用5位指数,而 bfloat16 使用8位,因此 bfloat16 的表示范围更广。多数浮点格式还支持非规格化数,以实现“软下溢”。为防止下溢,float16 混合精度训练【39,【39,Mixed precision training,2017,arXiv:1710.03740】】被开发出来:将一个小批次的损失乘以一个损失缩放因子(loss scalar),反向传播后使梯度落入 fp16 的可表示范围内;在优化器更新 fp32 主权重之前,通过重新缩放权重梯度来撤销此操作。在 PyTorch 【43,【43,Pytorch: An imperative style, highperformance deep learning library,2019,NeurIPS】】中,损失缩放因子初始值为65536,每次遇到 Inf/NaN 时更新会跳过且缩放因子减半,若连续2k次迭代无 Inf/NaN,则缩放因子加倍。
16位训练的局限性。在 float16 训练中,当损失缩放因子变得过低时,损失会缓慢发散。Cherti 等人【9,【9,Reproducible scaling laws for contrastive language-image learning,2022,arXiv:2212.07143】】在训练 ViT-Huge CLIP 模型时观察到此现象,并通过切换到 bfloat16 解决了问题。类似地,OPT【74,【74,Opt: Open pre-trained transformer language models,2022,arXiv:2205.01068】】和 BLOOM【52,【52,Bloom: A 176b-parameter openaccess multilingual language model,2022,arXiv:2211.05100】】模型的大规模训练也遇到了 float16 的问题,而 bfloat16 则表现稳定。然而,bfloat16 的快速支持仅限于 TPU 或 NVIDIA Ampere 系列(2021年或之后)的 GPU。
8位训练的现状。虽然16位训练是当前标准,但支持8位操作的硬件日益普遍。Hopper GPU 支持 float8 (fp8) 【40,【40,Fp8 formats for deep learning,2022,arXiv:2209.05433】】,Ampere GPU 支持 int8。但目前(2023年)获取 Hopper GPU 非常困难。尽管 int8 和 int4 已用于推理【17,【66,Smoothquant: Accurate and efficient post-training quantization for large language models,2022,arXiv:2211.10438】,【16,The case for 4-bit precision: k-bit inference scaling laws,2022,arXiv:2212.09720】】,并且已有早期工作探索了在卷积网络上进行8位训练【61,【79,Towards unified int8 training for convolutional neural network,2020,CVPR】,【10,Dkm: Differentiable k-means clustering layer for neural network compression,2021,arXiv:2108.12659】】,但这些格式在 Transformer 模型的大规模训练中并不常用。本文训练的 CLIP ViT-Huge 模型拥有10亿参数,比标准 ResNet-50(2300万参数)【28,【28,Deep residual learning for image recognition,2016,CVPR】】大40倍,而对大张量进行量化更具挑战性【17,【17,Llm. int8 (): 8-bit matrix multiplication for transformers at scale,2022,arXiv:2208.07339】】。更多关于大规模模型量化和低精度训练的相关工作见附录A。
量化技术。在8位精度的矩阵乘法中,本文使用了量化技术。作者发现,i) 对输入和梯度使用行式量化(row-wise quantization)【31,【31,Fbgemm: Enabling high-performance low-precision deep learning inference,2021,arXiv:2101.05615】】,以及 ii) 对权重使用张量式量化(tensor-wise quantization)是在简单性和性能之间的最佳权衡。以 int8(表示范围-127到127)为例,对于一个矩阵 $X$,其行为 $x_1, ..., x_b$,行式量化 $Q_{row}$ 定义如下:
而张量式量化 $Q_{tensor}$ 定义为:
其中 absmax
是绝对值的最大值。在应用 $Q_{row}$ 时,会保存每行的绝对值最大值用于后续的反量化,这被称为量化状态,即 $state_{row}(X) = [\text{absmax}(x_1), ..., \text{absmax}(x_b)]^T \in R^{b \times 1}$。对于张量式量化,只需存储整个张量的绝对值最大值,$state_{tensor}(X) = \text{absmax}(X) \in R$。
3.1 稳定性问题的预备知识与相关工作
损失尖峰问题。在扩大模型规模时,可能会出现损失尖峰【8,【25,Intriguing properties of transformer training instabilities,To appear】,【14,Scaling vision transformers to 22 billion parameters,2023,arXiv:2302.05442】,【69,Stabilizing transformer training by preventing attention entropy collapse,2023,arXiv:2303.06296】,【71,Sigmoid loss for language image pre-training,2023,arXiv:2303.15343】,【54,Adafactor: Adaptive learning rates with sublinear memory cost,2018,ICML】,【74,Opt: Open pre-trained transformer language models,2022,arXiv:2205.01068】】。这些不稳定性会减慢学习速度,甚至使训练完全失稳。已有的解决方案包括冻结嵌入层【8,【8,An empirical study of training self-supervised vision transformers,2021,ICCV】】、添加额外的层归一化【14,【25,Intriguing properties of transformer training instabilities,To appear】】或重参数化权重【69,【69,Stabilizing transformer training by preventing attention entropy collapse,2023,arXiv:2303.06296】】。
本文研究的稳定性问题。本文研究的是 CLIP 训练期间出现的不稳定性,与【14,【69,Stabilizing transformer training by preventing attention entropy collapse,2023,arXiv:2303.06296】】中导致缓慢发散的不稳定性不同,本文关注的是快速的损失尖峰。研究结果表明,这些尖峰是由于早期层的二阶矩估计器过时所致。
与AdaFactor的关系。本文的分析和方法直接建立在 Shazeer 和 Stern【54,【54,Adafactor: Adaptive learning rates with sublinear memory cost,2018,ICML】】 (AdaFactor) 的工作之上,但也存在重要差异。与他们只在没有预热(warmup)时观察到不稳定性不同,本文即使在有很长的预热期时仍观察到不稳定性。此外,本文发现二阶矩估计器过时主要影响(patch)嵌入层,并量化了该事件对损失尖峰的预测能力。最后,作者指出,研究人员已不再将原始形式的 AdaFactor 用于大规模训练【47,【11,Palm: Scaling language modeling with pathways,2022,arXiv:2204.02311】,【70,Scaling vision transformers,2021,https://arxiv.org/abs/2106.04560】】,因其性能不及 AdamW【47,【47,Scaling language models: Methods, analysis & insights from training gopher,2021,arXiv:2112.11446】】。这可能是由于其分解式二阶矩或缺少一阶矩。因此,本文重点关注当前 transformer 训练的标准优化器 AdamW【37,【37,Decoupled weight decay regularization,2019,ICLR】】。
A2 方法细节
2.2 SwitchBack
2.2.1 方法
SwitchBack概览。一个标准的线性层包含三次矩阵乘法:一次在前向传播中计算输出,两次在反向传播中计算输入和权重的梯度。SwitchBack 层对前两次矩阵乘法使用8位精度,但对权重梯度的计算则“切换回”更高的精度。
高精度计算权重梯度的原因。之所以在计算权重梯度时使用更高精度,是因为这次矩阵乘法涉及的向量点积长度为“批次大小 × 序列长度”。由于 CLIP 训练需要大批次【46,【44,Combined scaling for zeroshot transfer learning,2021,https://arxiv.org/ abs/2111.10050】】,这个内积维度远大于其他矩阵乘法。如附录C所示,量化引入的方差会随内积维度的增大而增加。这一修改是 SwitchBack 与 LLM.int8()【17,【17,Llm. int8 (): 8-bit matrix multiplication for transformers at scale,2022,arXiv:2208.07339】】的关键区别,使得 SwitchBack 能够匹配 bfloat16 基线的性能(图1)。
class SwitchBackMatmul(autograd.Function):
@staticmethod
def forward(ctx, X, W):
# X [b, n] 输入
# W [n, m] 权重
# 在ctx中保存张量
ctx.save_for_backward = X, W
X_int8, state_X = row-wise_quantize(X)
W_int8, state_W = tensor-wise_quantize(W)
# 返回输出
return matmul_int8_and_dequanitze(
X_int8, W_int8.t(), state_X, state_W
)
@staticmethod
def backward(ctx, G):
# G [b, m] 输出的梯度
# ...
# 仅对 X_gradient 使用 8bit matmul
X_gradient = matmul_int8_and_dequanitze(
G_int8, W_int8.t(), state_X, state_W
)
W_gradient = matmul_fp16(G.t(), X)
return X_gradient, W_gradient
class SwitchBackLinear(nn.Linear):
def forward(self, X):
return SwitchBackMatmul.apply(X, self.weight)
符号表示。一个标准线性层由输入 $X \in R^{b \times n}$,权重 $W \in R^{m \times n}$ 和输出 $Y \in R^{b \times m}$ 组成。前向传播中,输出计算为 $Y = XW^T$。反向传播中,该层接收损失对 $Y$ 的梯度,记为 $\dot{Y}$。然后,对输入的梯度 $\dot{X}$ 通过 $\dot{X} = \dot{Y}W$ 计算,对权重的梯度 $\dot{W}$ 通过 $\dot{W} = \dot{Y}^T X$ 计算。对于 Transformer【60,【60,Attention is all you need,2017,NeurIPS】】中的线性层,$b$ 是批次大小乘以序列长度,而 $n$ 和 $m$ 是嵌入维度的较小倍数。
量化与反量化细节。由于只有矩阵乘法在 int8 精度下进行,需要将输出反量化回原始的浮点精度。因此,带量化和反量化的前向传播过程变为:
其中 $*$ 表示逐元素乘法,此处使用广播机制,使得矩阵 $Q_{row}(X)Q_{tensor}(W)^T$ 的第 $i$ 行与 $state_{row}(X)$ 的第 $i$ 个元素相乘。如前所述,本文对输入和梯度使用行式量化,对权重使用张量式量化。作者发现对两个矩阵都使用行式量化会增加复杂性,但性能提升可忽略不计。
硬件相关的实现细节。本文使用的 NVIDIA GPU 不直接实现 int8/float8 的 $AB$ 矩阵乘法,只实现了 $AB^T$。因此,在反向传播中需要转置权重矩阵。为了减少转置和量化的开销,作者将这两个操作融合,即一次性将所需数据从慢速的 DRAM 加载到快速的 SRAM/共享内存,然后在这个缓存的内存中执行这两个操作,这对实现加速至关重要。作者称此操作为“张量式量化转置”(tensor-wise quantize transpose)。综合以上各点,最终得到算法1。
SwitchBack的变体。除了算法1所描述的最直接的版本,本文还提出了两个替代版本——SwitchBackM 和 SwitchBackQ,并通过 bitsandbytes 库【18,【18,8-bit optimizers via block-wise quantization,2022,ICLR】】发布了所有三个版本的 Triton【57,【57,Triton: an intermediate language and compiler for tiled neural network computations,2019,MAPL】】实现。附录B包含伪代码。SwitchBackM(算法3)是 SwitchBack 的内存高效版本,它只为反向传播保存8位张量,推荐在内存受限时使用,其微小缺点是在反向传播中需要额外的反量化操作,增加了运行时开销。对于 CLIP ViT-Huge,SwitchBack 和 SwitchBackM 之间的精度差异可忽略。此外,还提出了 SwitchBackQ(算法4),它对权重使用行式和列式量化而非张量式量化。尽管在本文考虑的规模下未观察到精度提升,但在更大规模下可能表现更优。SwitchBackQ 的前向传播公式为:
其中 $*$ 是逐元素乘积。同样,算法4中函数名后附加的 "transpose" 表示该操作与转置融合。
Float8(fp8)的实现。尽管以上解释以 int8 为例,但 SwitchBack 的 float8 (fp8) 代码几乎完全相同。唯一的修改是将 round(127x/absmax(x))
等操作替换为 float8cast(x/absmax(x))
。作者通过 bitsandbytes 库模拟 float8cast
,将值舍入到 float8 数据类型的精确值。这种模拟改进了【40,【40,Fp8 formats for deep learning,2022,arXiv:2209.05433】】的模拟方法,后者仅将输入张量裁剪到 float8 的可表示范围内。理论上,这种模拟与真实的 float8 训练一致,但由于缺乏支持 float8 运算的硬件,作者无法进行真实训练,而是在16位精度下使用精确的 float8 值进行算术运算。对于 int8 实验,作者在 A100 GPU 上进行了真实的 int8 乘法,没有任何模拟。
2.3 通过减小特征幅度进行Float8训练
不使用SwitchBack的FP8训练方法。研究发现,SwitchBack 对于高精度的 int8 训练是必要的。然而,本节开发了其他干预措施,使得在不使用 SwitchBack 的情况下也能进行 float8 训练。研究表明,只要网络在初始化和训练时被设计成抑制大的特征幅度,就可以通过对输入、权重和梯度都使用张量式量化来实现高精度的 float8 训练。作者通过使用初始化为零的 layer-scale【58,【58,Going deeper with image transformers,2021,ICCV】】来实现这一点。
模拟FP8训练的设置。作者使用 bitsandbytes 库【18,【18,8-bit optimizers via block-wise quantization,2022,ICLR】】来模拟 float8 训练,采用了 Micikevicius 等人【40,【40,Fp8 formats for deep learning,2022,arXiv:2209.05433】】的 fp8 类型。对输入、权重和梯度均使用张量式量化,使得所有操作都在模拟的 float8 中进行。在模拟中,每个值仅用 float8 可表示的精确值来表示,但计算在 float16 精度下进行。作者认为张量式量化近似于完全移除量化操作,因为如附录B.2(图14)所示,这些张量的最大值倾向于平滑演变,因此使用移动平均值作为最大值直接在矩阵乘法中进行除法,与张量式量化相似。
Layer-scale方法。Layer-scale 由 Touvron 等人【58,【58,Going deeper with image transformers,2021,ICCV】】提出,它将每个自注意力块和 MLP 块输出的隐藏状态乘以一个形状为 embed_dim
的可学习向量。一个带有 layer-scale 张量 $\gamma_1$ 和 $\gamma_2$ 的 pre-norm Transformer 块定义为:
其中 $*$ 是广播的逐元素乘法。通常,层被初始化以近似保持其输入的方差,且输入近似为单位方差【26,【27,Delving deep into rectifiers: Surpassing human-level performance on imagenet classification,2015,ICCV】】。然而,当与残差连接结合时,这可能导致更深网络中范数更高。因此,研究人员提出了解决此问题的初始化和缩放方案【1,【72,Fixup initialization: Residual learning without normalization,2019,arXiv:1901.09321】,【4,High-performance large-scale image recognition without normalization,2021,ICML】,【19,Effective theory of transformers at initialization,2023,arXiv:2304.02034】】。初始化为0的 Layer-scale 就是其中一种方案——在初始化时,Transformer 是一个恒等函数。虽然 $\gamma_1, \gamma_2$ 通常被初始化为 $10^{-4}$ 或 $10^{-6}$ 的向量,但为简单起见,本文使用0。
FP8训练结果。模拟 fp8 训练的结果如图5(左)所示(ViT-Large)。研究发现,除了使用初始化为零的 layer-scale 的情况外,所有 fp8 运行都发散了。具体来说,图5比较了:i) 使用 bfloat16 训练的基线;ii) 使用带张量式量化的 fp8 且无其他修改,其缓慢发散;iii) 在 ii) 的基础上添加梯度裁剪,也发散;iv) 在 ii) 的基础上添加 KQ layernorm【14,【14,Scaling vision transformers to 22 billion parameters,2023,arXiv:2302.05442】】,也发散;v) 使用零初始化的 layerscale,训练未发散。尽管 fp8 和 bfloat16 训练之间仍有差异,但这主要是由于 layerscale 本身。此外,作者相信通过超参数调整,layerscale 可以在精度上匹配标准训练。
3.3 损失尖峰随模型尺寸、批量和学习率的增加而增多
关键观察。本节研究了损失尖峰的出现如何随模型尺寸、批量大小和学习率的变化而变化。研究发现,当增加模型尺寸(图6)、批量大小(图7)或学习率(图8)时,损失尖峰会增多。然而,通过降低 AdamW 中的 $\beta_2$ 超参数可以避免损失尖峰。另一方面,如果 $\beta_2$ 降得太低,学习速度会减慢,导致性能下降【50,【50,On the convergence of adam and beyond,2019,arXiv:1904.09237】】。
3.4 关于 $\beta_2$ 和过时的二阶矩估计器
“卡在过去”的情景。像 AdaGrad【22,【22,Adaptive subgradient methods for online learning and stochastic optimization,2011,JMLR】】、Adam【33,【33,Adam: A method for stochastic optimization,2014,ICLR】】或 AdaFactor【54,【54,Adafactor: Adaptive learning rates with sublinear memory cost,2018,ICML】】这样的自适应优化器,会为每个参数不同地缩放更新量,这通常被概念化为逐参数的学习率。例如,在 Adam/AdamW 中,逐参数的更新量会通过平方梯度的指数移动平均值的逆平方根进行缩放。这种自适应性在加速训练方面非常有用,但当学习信号改变时也可能引发问题。具体来说,指数移动平均值可能会变得过时,导致更新被一个过大的值缩放。如算法2所示,$u_t$ 表示平方梯度 $g_t^2$ 的指数移动平均(EMA),更新步长按 $1/\sqrt{u_t+\epsilon}$ 缩放。当学习信号改变时,如果某些参数的历史梯度幅度很小,导致 $1/\sqrt{u_t+\epsilon}$ 很大,而这些参数在迭代 $t$ 时突然接收到大的梯度信号,更新步长可能会灾难性地增大。这个场景被称为“卡在过去”(stuck-in-the-past)。如果 $\beta_2$ 太小,收敛可能变慢;如果 $\beta_2$ 太大,$u_t$ 可能会过时,导致过大的参数更新。
量化“卡在过去”的情景。为了量化上述情景并寻找其与损失尖峰的预测关系,本文遵循 Shazeer 和 Stern【54,【54,Adafactor: Adaptive learning rates with sublinear memory cost,2018,ICML】】的方法,测量均方根量 $RMS_t = \sqrt{E[g_t^2/u_t]}$。如果 $u_t$ 是 $g_t^2$ 的一个好估计,那么 $RMS_t$ 会在1附近。$RMS_t \gg 1$ 则对应“卡在过去”的情景。
预测关系。图9展示了视觉 Transformer 的 patch 嵌入层(visual.conv1.weight)的 $RMS_t$。图中揭示了几个重要发现:i) 损失尖峰倾向于在 RMS 尖峰出现后1-8次迭代发生;ii) 损失尖峰会减慢学习,因为需要恢复时间;iii) 对于较低的 $\beta_2$,$RMS_t$ 始终保持在1附近。这表明 RMS 尖峰可以预测损失尖峰。附录D中的图16至21进一步详细探讨了这种预测关系,发现在检测到的30个损失尖峰中,有28个是在 RMS 尖峰后1-8次迭代发生的,而随机发生的概率仅为1%。
3.5 StableAdamW:带有AdaFactor更新裁剪的AdamW
方法。为了稳定训练,AdaFactor 优化器将第 $t$ 次迭代的学习率除以 $1/\max(RMS_t, 1)$,这种方法被称为“更新裁剪”(update clipping)。其效果是在 $u_t$ 不再是 $g_t^2$ 的良好估计时减慢训练速度。鉴于本文的稳定性问题也源于过时的 $u_t$,作者将更新裁剪技术移植到标准的 AdamW 优化器中,并将得到的 AdamW-AdaFactor 混合体称为 StableAdamW(算法2)。一个修改是,为了实现方便,为每个张量独立计算 $RMS_t$ 并除以 $\max(RMS_t, 1)$。
v0, u0 = 0
for t = 1 to T do
gt = ∇f (θt)
// 应用校正项以去偏移动平均
βˆ1 = β1 · (1−βt−1_1) / (1 − βt_1)
βˆ2 = β2 · (1−βt−1_2) / (1−βt_2)
// 更新移动平均
vt = βˆ1*vt−1 + (1 − βˆ1)*gt
ut = βˆ2*ut−1 + (1 − βˆ2)*g2t
// 为实现方便,以下步骤对每个张量独立进行
RMSt = sqrt(E [g2t /ut])
// 更新参数
ηt = αt/max (1, RMSt)
θt = θt−1 − ηt*λ*θt−1 − ηt*vt / (sqrt(ut) + ϵ)
结果。图10展示了 StableAdamW 与其他稳定性干预措施(如梯度裁剪或降低 $\beta_2$)的比较结果。对于 CLIP ViT-Huge,StableAdamW(即 AdamW + 更新裁剪)的性能优于其他干预措施。虽然梯度裁剪和更新裁剪都消除了不稳定性,但更新裁剪在 ImageNet 零样本准确率方面表现更佳。使用更新或梯度裁剪时,较高的 $\beta_2$(如0.99)往往表现更好。
A7 补充细节
3.6 损失尖峰和损失缩放因子
低精度训练与稳定性的联系。本节将低精度训练的结果与对稳定性的研究联系起来。研究发现,损失尖峰可能与大的激活值和梯度同时发生,而大的激活值和梯度由于表示范围有限,可能在低精度训练中引发问题。因此,减少损失尖峰是成功进行低精度训练的重要一步。
数据支持。图11展示了支持数据,其中一个 RMS 尖峰先于一个损失尖峰,后者又与激活值(特征)和梯度的尖峰同时发生。正如之前所见(图5),高特征幅度对低精度训练构成挑战。此外,梯度的尖峰非常大,以至于出现 Inf/NaN 值,导致损失缩放因子【40,【40,Fp8 formats for deep learning,2022,arXiv:2209.05433】】多次下降。
结论与启示。
1. 减少损失尖峰对实现低精度训练至关重要。
2. 梯度幅度的尖峰可能是暂时的,因此我们可能过于频繁地调整损失缩放因子。
3. Inf/NaN 经常发生在特定层(如 patch 嵌入层)。如果因为单个层的 Inf/NaN 而跳过整个网络的更新,效率会很低。
本文使用的损失缩放器。基于以上观察,本文在需要时(除图11外)使用了一个自定义的损失缩放器:i) 在单个张量级别检查 Inf/NaN 并仅跳过该张量的更新,而非全局跳过;ii) 在其初始值上保持固定。这个缩放器使得 CLIP 模型在 ViT-Huge 规模上能够进行 fp16 混合精度训练,而之前的方法中缩放器会变得过低导致训练发散【9,【9,Reproducible scaling laws for contrastive language-image learning,2022,arXiv:2212.07143】】。
A4 实验环境
- 数据集:
- 训练: LAION-2B 【53,【53,Laion-5b: An open largescale dataset for training next generation image-text models,2022,arXiv:2210.08402】】。为便于多次实验,仅使用了3.8亿图像样本进行训练。
- 评估: ImageNet 【15,【15,Imagenet: A large-scale hierarchical image database,2009,CVPR】】(零样本评估),使用 CLIP 的80个提示模板。
- 模型架构:
- CLIP 【46,【46,Learning transferable visual models from natural language supervision,2021,ICML】】 视觉 Transformer (ViT) 【20,【20,An image is worth 16x16 words: Transformers for image recognition at scale,2021,ICLR】】模型,规模包括 ViT-Base、ViT-Large 和 ViT-Huge(10亿参数)。
- 硬件配置:
- GPU: NVIDIA A100 GPU。端到端速度测试在单节点(4x A100)上进行。
- 软件配置:
- 框架与库: PyTorch 【43,【43,Pytorch: An imperative style, highperformance deep learning library,2019,NeurIPS】】, OpenCLIP 【29,【29,Openclip,2021,https: //doi.org/10.5281/zenodo.5143773】】, Triton 【57,【57,Triton: an intermediate language and compiler for tiled neural network computations,2019,MAPL】】 (用于自定义内核), bitsandbytes 【18,【18,8-bit optimizers via block-wise quantization,2022,ICLR】】。
- 训练超参数:
- 批量大小: 16384 (单GPU 256)。
- 迭代次数: 20k次 (前5k次线性预热,后15k次余弦衰减)。
- 学习率: 2e-3。
- 权重衰减: 0.2。
- 优化器: StableAdamW (Section 3.5)。
- 正则化: Patch-dropout 0.5 【35,【35,Scaling languageimage pre-training via masking,2022,arXiv:2212.00794】】。
A4 实验结果
1. SwitchBack 精度测试 (int8 & float8)
- 实验内容: 训练不同规模的 CLIP ViT 模型,比较 SwitchBack、bfloat16 基线以及其他低精度方法(LLM.int8() 或 tensor-wise fp8)的零样本 ImageNet 准确率。
- 实验结果:
- int8 (图1, 2): 对于 ViT-Huge 模型,SwitchBack 的准确率与 bfloat16 基线仅相差0.1个百分点。而 LLM.int8() 的准确率下降了5.9个百分点。
- float8 (图1, 2): 对于 ViT-Huge 模型,SwitchBack 同样与 bfloat16 性能相当(-0.1个百分点),而使用张量级量化的基线方法则训练发散。
- 分析结论: SwitchBack 能够在 int8 和 fp8 两种低精度设置下,成功地以接近16位基线的精度训练大规模 CLIP 模型,显著优于现有方法。
2. SwitchBack 速度测试
- 实验内容: 对 SwitchBack 线性层的各个操作(矩阵乘法、量化)进行性能剖析,并测量替换所有 Transformer 线性层后完整的端到端训练速度。
- 实验结果:
- 操作剖析 (图3): int8 矩阵乘法(已融合反量化)的耗时仅为标准 fp16 矩阵乘法的一半多。量化操作的耗时比矩阵乘法低一个数量级。
- 端到端速度 (图4): 对于 ViT-Huge,SwitchBack 相比标准的 PyTorch nn.Linear
层实现了13%的端到端加速,相比自定义的 autograd
线性层实现了25%的加速。
- 与LLM.int8()对比 (图13): 在测试的规模下,LLM.int8() 未提供加速,而 SwitchBack 提供了显著加速。
- 分析结论: SwitchBack 能够提供显著的实际训练加速,量化开销占比较小且随维度增大而减小。
3. 通过减小特征幅度进行 Float8 训练
- 实验内容: 在 ViT-Large 模型上测试模拟 fp8 训练,比较不同干预措施(无、梯度裁剪、KQ layernorm、零初始化 layerscale)的效果。
- 实验结果 (图5): 只有使用零初始化 layerscale 的方法成功训练且不发散。图5(右)显示,该方法有效抑制了模型深层特征幅度(激活值)的增长。
- 分析结论: 对于 fp8 训练,通过零初始化 layerscale 控制特征幅度是一种有效的替代方案,可以避免训练发散,这为低精度训练提供了除 SwitchBack 之外的另一条路径。
4. 训练稳定性分析
- 实验内容:
- 研究模型大小、批量大小、学习率和 AdamW $\beta_2$ 对损失尖峰的影响。
- 追踪 patch 嵌入层的 $RMS_t$ 指标与损失尖峰的关联。
- 比较 StableAdamW、梯度裁剪和降低 $\beta_2$ 对 ViT-Huge 训练稳定性和最终性能的影响。
- 实验结果:
- 影响因素 (图6, 7, 8): 损失尖峰随着模型大小、批量大小和学习率的增加而增多。降低 $\beta_2$ 可以缓解尖峰,但过低会损害性能。
- 预测关系 (图9, 16-21): 损失尖峰大概率在 patch 嵌入层的 $RMS_t$ 出现尖峰后的1-8次迭代内发生,表明二阶矩估计器过时是损失尖峰的直接前兆。
- 干预措施比较 (图10): StableAdamW 不仅能消除不稳定性,而且在零样本准确率上优于梯度裁剪。
- 分析结论: 损失尖峰的根本原因在于 AdamW 的二阶矩估计器过时,通过引入 AdaFactor 的更新裁剪机制(即 StableAdamW)是解决该问题的有效且性能更优的方法。
A5 结论
本文分享了在加速和稳定大规模多模态模型训练方面的实验,这些经验对社区具有实用价值。作者发布了 Triton 内核等资源,以便于他人在此工作基础上进行构建和改进。
主要局限性:
1. 工作并非详尽无遗,例如,float8 训练仅通过模拟进行。
2. 未研究初始化和训练超参数的宽度缩放(width scalings)对稳定性的影响。
3. 由于计算预算有限,训练运行时间较短,因此生成的模型检查点最终准确率不高。
尽管存在局限性,一个积极的方面是,这些早期的探索为训练出迄今为止准确率最高的开源 CLIP 模型【64,【64,Reaching 80% accuracy with openclip,2023,https://laion.ai/blog/ giant-openclip/】】提供了超参数的参考。
A6 附录
B. 附加代码和图表
B.1 附加代码。本节提供了额外的伪代码,包括:
- 算法3:SwitchBack 的内存高效变体。
- 算法4:对权重使用行式和列式量化的 SwitchBack 变体。
- 算法5:使用 torch.autograd
实现的标准线性层。
这些实现可在 bitsandbytes
库中找到,并提供了在 OpenCLIP 中使用的命令行参数。
B.2 附加图表。本节提供了附加图表:
- 图12:图3的更细粒度版本,展示了各操作的性能剖析。
- 图13:比较了 SwitchBack 和 LLM.int8() 的端到端训练加速效果。
- 图14:展示了训练过程中 MLP 权重梯度和 Transformer 块输出的均值和最大值。
- 图15:表明使用 $1 - \text{iteration}^{-\lambda}$ 形式的 $\beta_2$ 调度并未提高准确率。
C. 分析
量化误差与内积维度的关系。本节分析表明,量化导致的误差会随着矩阵乘法 $UV$($U \in R^{n \times k}, V \in R^{k \times m}$)的内积维度 $k$ 的增大而增加。这解释了为什么 SwitchBack 能获得高精度,因为它避免了对 $k$ 非常大的矩阵乘法(即权重梯度计算)进行量化。
C.1 量化内积的方差分析。
- 模型假设:将两个长度为 $k$ 的向量 $u$ 和 $v$ 的量化版本建模为 $\hat{u}_i = u_i + \epsilon_i$ 和 $\hat{v}_i = v_i + \xi_i$,其中 $\epsilon_i, \xi_i$ 是均值为0、方差为 $\sigma_q^2$ 的量化误差。
- 推导:通过计算,量化后内积 $\langle \hat{u}, \hat{v} \rangle$ 的方差为:
这个结果表明,由量化引起的方差随内积维度 $k$ 线性增长。
C.2 结论。由于量化导致的方差随内积维度 $k$ 增加,这会使得输出单元/特征的噪声更大。噪声在网络中逐层累积,最终会淹没有效信号,导致学习失效。
C.3 为什么LLM.int8()会失败:LLM vs CLIP 模型。
- 核心假设:基于上述分析,不同模型的量化失败模式可能不同。
- LLM vs CLIP 的关键差异:
- LLM(如GPT-3, LLaMA):权重矩阵相对较大,而批次大小相对较小。最嘈杂的操作是前向传播 $XW^T$ 和层间反向传播 $\dot{Y}_k W_k$,其内积维度由嵌入维度决定(例如 LLaMA 65B 中为8192或32768)。权重梯度的内积维度由单GPU批次大小决定(LLaMA 中为2048)。
- CLIP:为了获得最佳性能,批次大小非常大。对于 ViT-Huge,单GPU的有效批次大小(patches数量)可达65536。权重矩阵维度为1280×5120。
- 分析:因此,在 CLIP 训练中,权重梯度计算(内积维度为65536)是噪声最大的操作,其噪声比前向传播(内积维度为5120)大得多。相比之下,LLM 中最嘈杂的操作是前向传播。这解释了为什么 LLM.int8()(对所有矩阵乘法都使用int8)在 CLIP 训练中失败,因为其最关键的权重梯度计算被高度噪声污染,使得 SGD 无法有效学习。SwitchBack 通过对该特定操作使用高精度来解决此问题。
D. RMS 尖峰先于损失尖峰
分析方法。本节通过定义一个启发式规则来量化和分析 RMS 尖峰与损失尖峰之间的预测关系。
- RMS 尖峰定义: $RMS_t \ge 2.3$。
- 损失尖峰定义: 时刻 $t$ 的损失超过其滑动平均值3.2倍的滑动标准差。
- 数据去重: 10次迭代内的多个尖峰只计为一次。
分析结果。
- ViT-Huge (图16): 在15个检测到的损失尖峰中,有14个是在 patch 嵌入层 RMS 尖峰出现后的1-8次迭代内发生的。这种情况偶然发生的概率小于1%。
- ViT-Large (图17): 在15个损失尖峰中,有13个遵循此模式,偶然发生概率为1.0%。
- 控制实验 (图21): 当检查 Transformer 中间某个随机层的 RMS 时,这种预测关系消失了,没有一个损失尖峰遵循 RMS 尖峰。
- 失败案例分析 (图19, 20): 未被预测的损失尖峰大多是由于启发式规则识别错误(假阳性)导致的。
结论。这些数据有力地证明了,在 patch 嵌入层的二阶矩估计器过时(表现为 RMS 尖峰)是导致后续损失尖峰的直接原因。
E. StableAdamW 续
E.1 问答。
- 为什么不直接用 AdaFactor? 因为社区发现 AdaFactor 在大规模训练上性能不及 AdamW【47,【11,Palm: Scaling language modeling with pathways,2022,arXiv:2204.02311】】,这可能是由于其分解式矩估计。本文旨在倡导使用混合方法。
- 为什么不用 Lion【7,【7,Symbolic discovery of optimization algorithms,2023,arXiv:2302.06675】】这类对更新不做除法缩放的优化器? Lion 是一个有前景的方向,但在本文的 ViT-Huge 实验中,其性能仍略逊于 AdamW。
- 为什么计算 $RMS_t$ 时用 $g_t^2$ 而不是 $v_t^2$? 作者尝试过用 $v_t^2$,发现性能更差。
E.2 实现注意事项。
为了防止计算 $RMS_t$ 时出现除以零的问题,实际计算公式为 $RMS_t = \sqrt{E[g_t^2 / \text{maximum}(u_t, \epsilon^2)]}$,其中 $\epsilon$ 是 AdamW 的超参数(本文使用1e-6),maximum 是逐元素最大值操作。
💬 评论讨论
欢迎在这里分享您的想法和见解!