Scaling FP8 Training to Trillion-Token LLMs
Scaling FP8 Training to Trillion-Token LLMs
作者/机构: Maxim Fishman, Brian Chmiel, Ron Banner (Intel, Israel); Daniel Soudry (Technion, Haifa, Israel)
A1 主要贡献
本文在FP8训练大型语言模型(LLM)方面取得了显著进展,成功将训练数据集规模扩展至2万亿(trillion)词元,这比以往的研究(最多1000亿词元)提高了20倍。通过这种大规模、长时间的训练,研究揭示了在早期短时研究中未能观察到的FP8训练中的关键不稳定性问题。
核心问题与发现:
* 问题识别:研究发现,在使用FP8精度进行长时间训练时,Llama2 7B模型在处理约2200亿词元后,训练损失会出现急剧发散(见Fig. 2(a))。
* 根源分析:通过深入分析,本文将这种不稳定性追溯到一个先前未被发现的现象:SwiGLU激活函数(【索引 Shazeer, 2020a, Glu variants improve transformer】)对异常值的放大作用,这种放大效应在长时间训练后变得尤为突出。从理论和实证上,研究将这种异常值放大与训练过程中发生的权重对齐现象联系起来,尤其是在使用ℓ2正则化且输入较大的SwiGLU神经元中。
创新点与解决方案:
* Smooth-SwiGLU:为解决上述问题,本文提出了一种名为Smooth-SwiGLU的新型激活函数。它通过对SwiGLU的线性分支应用一个缩放因子,并在后续的线性层之后将其重新缩放回来,从而在不改变函数行为的前提下,有效抑制了异常值的放大,确保了FP8训练在长时间内的稳定性。
* FP8优化器:本文首次成功地将Adam优化器的两个动量(moments)都量化到FP8精度,进一步减少了训练过程中的内存使用,提升了大规模LLM开发的效率。
* 大规模验证:结合Smooth-SwiGLU和FP8优化器动量量化这两项创新,本文在256个Intel Gaudi2加速器上成功训练了一个7B参数模型。实验结果表明,该方法在下游任务上取得了与BF16基线相当的性能,同时吞吐量提升高达约34%。
主要贡献概览:
* 首次成功在高达2万亿词元的数据集上使用FP8训练LLM,揭示了长期训练中的关键不稳定性。
* 识别出不稳定的根本原因是SwiGLU激活函数在长期训练中对异常值的放大。
* 从理论和实证上将异常值放大与ℓ2正则化下的权重对齐过程联系起来。
* 提出了Smooth-SwiGLU,一种新颖的激活函数,可在不改变模型行为的情况下确保FP8训练的稳定性。
* 首次实现了对Adam优化器两个动量的FP8量化,进一步优化了内存使用。
* 在Intel Gaudi2加速器上,实现了与BF16基线相当的结果,并显著提升了吞吐量,证明了该方法的实际可行性。
A3 背景知识与关键观察
2 FP8训练的背景与挑战
FP8的出现与挑战。大型语言模型(LLM)的计算需求推动了从传统的FP32向低精度格式的转变。虽然FP16和BF16已成为许多训练任务的标准(【索引 Micikevicius et al., 2017, Mixed precision training】; 【索引 Scao et al., 2022, Bloom: A 176b-parameter open-access multilingual language model】; 【索引 Smith et al., 2022, Using deepspeed and megatron to train megatron-turing nlg 530b, a large-scale generative language model】),但FP8是向更低精度发展的下一步。Micikevicius等人(【索引 Micikevicius et al., 2022, Fp8 formats for deep learning】)为深度学习标准化了两种FP8格式:E4M3(4位指数,3位尾数),优化用于权重和激活值;以及E5M2(5位指数,2位尾数),适用于梯度。FP8在大规模训练中显示出巨大潜力,尤其是在现代硬件如NVIDIA的H100和Intel的Gaudi2的支持下。然而,其有限的数值范围要求必须采用精细的缩放技术来维持数值稳定性和模型性能。
FP8的缩放技术。FP8训练LLM的主要挑战源于其有限的动态范围。为了解决这个问题,研究人员开发了各种缩放技术。全局损失缩放(Global loss scaling)(【索引 Micikevicius et al., 2017, Mixed precision training】)将整个损失乘以一个常数因子,以防止反向传播过程中的梯度下溢。逐张量缩放(Per-tensor scaling)(【索引 Sun et al., 2019, Hybrid 8-bit floating point (hfp8) training and inference for deep neural networks】)采用更细粒度的方法,根据每个张量的特定数值范围进行独立缩放。这些技术有助于更好地利用FP8格式有限的范围。然而,Lee等人(【索引 Lee et al., 2024, To fp8 and back again: Quantifying the effects of reducing precision on llm training stability】)的研究表明,即使使用了这些技术,FP8训练仍可能导致显著的不稳定性,凸显了低精度LLM训练中持续存在的挑战。
缩放技术的实现方式。这些缩放技术的实现通常遵循两种方法之一:即时缩放(just-in-time scaling)或延迟缩放(delayed scaling)。即时缩放根据当前数据分布动态调整缩放因子。然而,由于需要多次数据传递,这在FP8训练中通常不切实际,因为它可能会抵消使用FP8带来的性能优势。另一方面,延迟缩放根据前几次迭代的数据分布来选择缩放因子。虽然这种方法更实用,但它假设了迭代间的统计特性是一致的,这使得它容易受到异常值的影响,这些异常值会破坏这种一致性,并可能破坏训练过程的稳定性。
先前工作的局限性与本文的扩展。最近,Peng等人(【索引 Peng et al., 2023, Fp8-lm: Training fp8 large language models】)展示了使用FP8格式训练LLM的首个实证结果,数据集规模高达1000亿词元,验证了FP8在大规模训练中的潜力,但在通常需要更大规模训练的真实场景中适用性有限。我们的研究扩展了这项工作,成功地将FP8训练规模扩展到高达2万亿词元的数据集。我们引入了新的技术来应对FP8有限动态范围的挑战以及在长期训练场景中出现的稳定性问题。我们的方法不仅克服了这些限制,还在内存使用和训练速度上取得了显著提升,证明了FP8训练在真正大规模LLM开发中的可行性。
3 大规模FP8训练中的异常值放大
训练后期出现激活值异常值。许多研究都观察到了异常值的存在(【索引 Yang et al., 2024, Mitigating quantization errors due to activation spikes in glu-based llms】; 【索引 Bondarenko et al., 2023, Quantizable transformers: Removing outliers by helping attention heads do nothing】),尤其是在推理过程中的激活值中。先前为应对推理中的这些异常值而提出的解决方案之一,是对应激活值进行旋转(【索引 Liu et al., 2024, Spinquant: Llm quantization with learned rotations】)。这些异常值会引入极值,在像FP8这样的低精度格式有限的动态范围内难以处理,从而严重影响模型的稳定性和性能。我们的工作揭示,在使用大规模数据集训练大型语言模型(LLM)的后期阶段,这些异常值变得尤为突出。
异常值的出现时间点与挑战。如图1所示,这些异常值仅在训练过程中处理了大约2000亿词元后才出现。这种现象对维持数值稳定性和模型性能构成了重大挑战,尤其是在使用那些假设迭代间一致性的方法时。这些对模型性能至关重要的异常值的突然出现(【索引 Sun et al., 2024, Massive activations in large language models】),破坏了FP8训练技术所依赖的统计假设,可能导致训练过程的不稳定或发散。
FP8有限动态范围的问题。这些在训练后期出现的异常值对于FP8格式尤其成问题,因为其动态范围有限。与FP32甚至BF16等更高精度的格式不同,FP8的可表示值范围要窄得多。当异常值超出这个范围时,它们可能导致上溢或下溢,从而造成关键信息的丢失,并可能破坏整个训练过程的稳定性。
异常值的零星性质。此外,从图1b中可以明显看出,这些异常值的零星性质使其难以预测和管理。传统的缩放技术依赖于迭代间一致的统计特性,难以适应这些突然出现的极值。这种不可预测性进一步加剧了在FP8训练中维持数值稳定性的任务难度,特别是随着数据集规模和训练时长的增加。
A2 方法细节
4 SWIGLU与异常值放大
虽然前一节强调了大规模FP8训练中出现异常值的普遍问题,但我们的调查显示,SwiGLU(Swish Gated Linear Unit)激活函数在放大这些异常值中扮演了关键角色。本节将探讨SwiGLU的结构,并展示其独特属性如何导致异常值的产生和放大。
4.1 SWIGLU结构
SwiGLU在Transformer架构中的应用。作为现代LLM基础的Transformer架构(【索引 Vaswani et al., 2017, Attention is all you need】)经历了多次修改以提升性能和效率。一个显著的例子是在像LLaMA(【索引 Touvron et al., 2023, Llama 2: Open foundation and fine-tuned chat models】)和PaLM(【索引 Chowdhery et al., 2022, Palm: Scaling language modeling with pathways】)等模型中引入了SwiGLU(Swish Gated Linear Unit)(【索引 Shazeer, 2020b, Glu variants improve transformer】)激活函数。
SwiGLU的数学定义。设 $x \in R^d$ 是来自前一层的输入向量。对于两个权重向量 $w_1, w_2 \in R^d$,SwiGLU神经元的定义如下:
$$\operatorname{SwiGLU}_{\mathbf{w}_1, \mathbf{w}_2}(\mathbf{x}) \triangleq\left(\mathbf{x}^{\top} \mathbf{w}_1\right) \operatorname{Swish}\left(\mathbf{x}^{\top} \mathbf{w}_2\right) \triangleq\left(\mathbf{x}^{\top} \mathbf{w}_1\right)\left(\mathbf{x}^{\top} \mathbf{w}_2\right) \sigma\left(\mathbf{x}^{\top} \mathbf{w}_2\right),$$其中 $\sigma(z) \triangleq 1/(1 + e^{-z})$ 是sigmoid函数。
SwiGLU的二次方增长特性。虽然其他标准神经元类型,如ReLU、GeLU和Swish,在输入量级较大时至多是线性的(即 $\lim_{u \to \pm\infty} |f(u)/u| \le 1$),但SwiGLU激活函数是二次的,如果 $w_1$ 和 $w_2$ 充分对齐(例如,如果 $w_1 = w_2$ 且 $w_1^T x = 1$),则可以达到更大的值(并导致非常强的异常值),即 $\lim_{c \to \infty} \text{SwiGLU}_{w_1,w_2}(cx)/c^2 = 1$)。正如我们接下来将展示的,对于具有足够大输入的神经元,在训练期间恰好会发生这种对齐。
4.2 SWIGLU中权重相关性的理论分析
理论分析目标。接下来,我们分析SwiGLU神经元在训练期间的行为,并证明如果其输入的大小增加到某个阈值以上,其权重向量会趋于完全对齐。这导致SwiGLU的输出大小在训练期间显著增加,可能产生异常值。
训练设置与L2正则化。为了证明这一点,我们假设SwiGLU神经元嵌入在一个具有k个参数的神经网络中。网络中其余的参数表示为 $\theta \in R^{k-2d}$。我们使用某种ℓ2正则化来训练该神经网络:
$$\min _{\mathbf{w}_{1}, \mathbf{w}_{2}, \theta} \sum_{n=1}^{N} \ell_{n}\left(\operatorname{SwiGLU}_{\mathbf{w}_{1}, \mathbf{w}_{2}}\left(\mathbf{x}_{n}(\theta)\right), \theta\right)+\frac{\mu}{2}\left(\left\|\mathbf{w}_{1}\right\|^{2}+\left\|\mathbf{w}_{2}\right\|^{2}+\|\theta\|^{2}\right),$$其中 $\mu > 0$ 是正则化强度,N是训练样本数,而 $l_n (u_n, \theta)$ 是关于SwiGLU输出和神经网络其余参数的单样本损失。
定理1:权重对齐的条件。我们发现:定理1。假设我们收敛到损失函数的一个驻点 $(w_1, w_2, \theta)$,并且对于所有样本n,$\sigma'(x_n^T (\theta) w_2) \to 0$。那么,在这个驻点上,有 $w_1 \to w_2$ 或 $w_1 \to -w_2$。
定理证明过程。证明。在一个驻点 $(w_1, w_2, \theta)$ 处,我们有对于所有 $i \in \{1, 2\}$:
$$\sum_{n=1}^{N} \nabla_{\mathbf{w}_{i}} \ell_{n}\left(\operatorname{SwiGLU}_{\mathbf{w}_{1}, \mathbf{w}_{2}}\left(\mathbf{x}_{n}(\theta)\right), \theta\right)+\mu \mathbf{w}_{i}=0 .$$使用链式法则,我们得到以下两个方程:
$$\begin{aligned} \begin{aligned} 0 & = \sum _ { n } \mathbf { x } _ { n } \mathbf { x } _ { n } ^ { \top } \mathbf { w } _ { 2 } \sigma ( \mathbf { w } _ { 2 } ^ { \top } \mathbf { x } _ { n } ) \delta _ { n } + \mu \mathbf { w } _ { 1 } \\ 0 & = \sum _ { n } \mathbf { x } _ { n } \mathbf { x } _ { n } ^ { \top } \mathbf { w } _ { 1 } ( \sigma ( \mathbf { w } _ { 2 } ^ { \top } \mathbf { x } _ { n } ) + ( \mathbf { x } _ { n } ^ { \top } \mathbf { w } _ { 2 } ) \sigma ^ { \prime } ( \mathbf { x } _ { n } ^ { \top } \mathbf { w } _ { 2 } ) ) \delta _ { n } + \mu \mathbf { w } _ { 2 } \end{aligned} \end{aligned}$$这里我们定义了 $\delta_n(w_1, w_2, \theta) \triangleq \frac{\partial l_n(u_n, \theta)}{\partial u}$,并且为了简化符号,我们用 $u_n = \text{SwiGLU}_{w_1,w_2} (x_n(\theta))$ 来表示,省略了对 $(w_1, w_2, \theta)$ 和 $x_n(\theta)$ 对 $\theta$ 的依赖。给定假设:
$$\forall n : \sigma'(\mathbf{x}_n^\top \mathbf{w}_2) \to 0$$在这个极限下,我们得到:
$$0=\sum_{n} \mathbf{x}_{n} \mathbf{x}_{n}^{\top} \mathbf{w}_{2} \sigma\left(\mathbf{w}_{2}^{\top} \mathbf{x}_{n}\right) \delta_{n}+\mu \mathbf{w}_{1} \quad ; \quad 0=\sum_{n} \mathbf{x}_{n} \mathbf{x}_{n}^{\top} \mathbf{w}_{1} \sigma\left(\mathbf{w}_{2}^{\top} \mathbf{x}_{n}\right) \delta_{n}+\mu \mathbf{w}_{2} .$$现在,定义 $\lambda_n = -\mu^{-1}\delta_n\sigma(w_2^T x_n)$,上述方程变为:
$$\mathbf{w}_{1}=\sum_{n} \lambda_{n} \mathbf{x}_{n} \mathbf{x}_{n}^{\top} \mathbf{w}_{2}=A \mathbf{w}_{2} \quad ; \quad \mathbf{w}_{2}=\sum_{n} \lambda_{n} \mathbf{x}_{n} \mathbf{x}_{n}^{\top} \mathbf{w}_{1}=A \mathbf{w}_{1} .$$其中 $A = \sum_n \lambda_n x_n x_n^T$ 是一个对称矩阵。这意味着:
$$\mathbf{w}_{1}=A \mathbf{w}_{2}=A^{2} \mathbf{w}_{1} \quad ; \quad \mathbf{w}_{2}=A \mathbf{w}_{1}=A^{2} \mathbf{w}_{2}$$由于A是对称的,这意味着 $w_1$ 和 $w_2$ 都是A的特征向量,且具有相同的特征值:1或-1。将此代入方程2,我们得到 $w_1 = w_2$ 或 $w_1 = -w_2$。
定理1的实际意义。请注意,这个结果也适用于将SwiGLU中的swish激活替换为其他GLU变体(【索引 Shazeer, 2020a, Glu variants improve transformer】),因为我们没有使用Swish的任何特定属性。因此,实际上,在正则化和充分训练的情况下,如果 $\sigma'(x_n^T(\theta)w_2)$ 通常很小,权重向量 $w_1$ 和 $w_2$ 必须收敛到相同或相反的方向,即 $w_1 \approx w_2$ 或 $w_1 \approx -w_2$。由于 $\sigma'(z)$ 随着 $|z|$ 的增加呈指数级快速衰减,这仅仅意味着神经元的输入通常不会太小。这种情况可能发生在 $\|w_2\|$ 足够大,并且通常 $|w_2^T x_n(\theta)| > 0$ 的情况下。当我们将神经网络拟合到一个大小为N的大数据集,其中 $N \gg k$(即我们处于欠参数化状态)且无法达到零损失时,后一个条件应该是普遍情况——因为那时神经网络没有多余的能力将特定的神经元输入设置为零(除了拟合数据)。事实上,我们观察到(附录中的图9),在训练后,对于异常神经元,大约99%的词元的 $|w_2^T x_n| > 1$。有趣的是,这种对齐现象是由于ℓ2正则化引起的,即使它非常弱。实际上,弱正则化会导致更大的权重范数,从而增强这种效应。
4.3 观察权重相关性增长与训练不稳定性
权重相关性增长与训练不稳定的观察。在我们的实验中,我们观察到权重矩阵 $w_1$ 和 $w_2$ 之间日益增长的相关性与FP8训练损失最终发散之间存在明确的关系。
权重对齐过程的经验验证。在图2中,我们看到权重对齐过程及其对FP8训练的影响完全符合我们理论的预测。随着训练的进行,异常值通道中的 $\|w_2\|$ 增长,最终超过了满足我们定理假设的临界阈值。因此,这些通道中的权重向量 $w_1$ 和 $w_2$ 开始迅速对齐——即 $w_1$ 和 $w_2$ 之间的相关性最初很低,但在1250亿到2100亿词元之间急剧增加。有趣的是,这种对齐似乎与范数的进一步增长同时发生。高相关性和增加的权重范数相结合,为产生极端激活值或“尖峰”创造了理想条件。
对齐现象的进一步说明。这些激活尖峰反过来导致了FP8训练的发散,如图2a所示。重要的是,虽然我们在这个例子中主要观察到强正相关,但我们的理论也预测了强负相关存在的可能性。我们也观察到了这种情况,如附录中的图7所示。
SwiGLU带来的独特挑战。这一现象突显了SwiGLU在大型语言模型FP8训练中带来的独特挑战。权重的逐渐对齐,加上范数的增长,创造了一种情景,即随着训练的进行,异常值变得越来越可能且越来越严重。因此,在较短的训练运行中,不稳定性可能不明显,但在大规模、长时间的训练场景中则会成为关键问题。这解释了为什么在之前规模较小的FP8训练研究中没有观察到此类问题。
4.4 SMOOTH-SWIGLU
问题的根源与初步解决方案。如前所示,SwiGLU激活函数可能导致MLP组件最后一个线性层输入中出现异常值。当使用依赖于层间统计特性保持一致这一假设的延迟缩放FP8精度时,这些异常值构成了一个重大挑战。SwiGLU引起的数值突然飙升破坏了这种连续性,导致训练过程不稳定。在图3中,我们展示了禁用MLP组件中最后一个线性层(SwiGLU的输出)的量化,可以使Llama2 FP8在大型数据集上成功收敛,解决了之前观察到的发散问题。这表明Llama架构中的其他组件,如RMS Norm或MHA,并非FP8训练不稳定的原因。
Smooth-SwiGLU的提出。虽然禁用SwiGLU输出的量化能有效防止发散,但它降低了FP8可能带来的加速效益。为了在解决异常值问题的同时保持完整的FP8加速,我们提出了一种名为Smooth-SwiGLU的新颖修改。图4阐释了Smooth-SwiGLU背后的核心思想:对SwiGLU函数的线性分支应用一个缩放因子,并在最后一个线性层之后将其重新缩放回去。这种方法在保留SwiGLU激活函数整体功能的同时,防止了最后一个线性层输入的量化过程中出现异常值,使我们能够充分利用整个网络的FP8精度。
Smooth-SwiGLU的数学表达式。在数学上,我们将每个通道i的量化Smooth-SwiGLU函数表示为:
$$\text{Smooth-SwiGLU}_{\hat{\mathbf{w}}_{1,i}, \hat{\mathbf{w}}_{2,i}}(\mathbf{x}) = s_i^{-1} \cdot Q(s_i \cdot (\hat{\mathbf{w}}_{1,i}^\top Q(\mathbf{x})) \text{Swish}(\hat{\mathbf{w}}_{2,i}^\top Q(\mathbf{x}))))$$其中 $s_i$ 是逐通道的缩放因子,$\hat{w} = Q(w)$ 表示权重向量 $w$ 的量化版本,Q是量化函数(为简化符号,我们省略了Q对张量的依赖性)。
缩放因子的计算方法。为了最小化计算开销,我们采用一种高效的并行方法来计算缩放因子 $s_i$:
1. 将张量分割成块,每个块对应一个通道。
2. 对每个块(通道),并行计算其最大值。
3. 使用这些逐通道的最大值来确定每个通道的独立缩放因子 $s_i$。
这种方法可以实现高效的逐通道缩放,因为每个通道的缩放因子都是独立且并行计算的。与线性层中的矩阵乘法相比,这种方法的计算成本是适度的,特别是考虑到并行化,即使我们的实现未经优化。
推理阶段的零开销。在推理期间,缩放因子可以合并到包含SwiGLU层和随后的线性层的MLP层的第一个和第三个线性层的权重中(见图4),即:
$$\sum_{i} \hat{\mathbf{w}}_{3,i} \text{Smooth-SwiGLU}_{\hat{\mathbf{w}}_{1,i}, \hat{\mathbf{w}}_{2,i}}(\mathbf{x}) = \sum_{i} s_{i}^{-1} \cdot \hat{\mathbf{w}}_{3,i} Q(s_{i} \cdot (\hat{\mathbf{w}}_{1,i}^{\top} Q(\mathbf{x})) \text{Swish}(\hat{\mathbf{w}}_{2,i}^{\top} Q(\mathbf{x}))))$$因此,我们可以通过重新定义 $\tilde{w}_{1,i} \triangleq Q(s_i \cdot w_{1,i})$ 和 $\tilde{w}_{3,i} \triangleq Q(s_i^{-1} \cdot w_{3,i})$ 来吸收这个标量。因此,这个过程在推理时没有额外的成本。
5 FP8优化器
FP8优化器的目标。Adam优化器及其变体因其在处理各种训练挑战方面的有效性而在深度学习中被广泛使用。Adam优化器的一个关键特性是它存储两个动量,传统上使用高精度(FP32)。这显著增加了内存使用,特别是对于大规模模型。虽然之前的研究(【索引 Peng et al., 2023, Fp8-lm: Training fp8 large language models】)已经表明将第一动量降至FP8精度的可行性,但他们将第二动量保留在FP16。我们的工作通过成功地将两个动量都量化到FP8,进一步推动了边界,显著提高了大型语言模型的优化器效率。
5.1 挑战
Adam优化器的矩和挑战。Adam优化器使用两个动量来为每个参数调整学习率:
1. 第一动量是梯度的均值估计。
2. 第二动量是梯度的未中心化方差估计。
第二矩的精度要求。Adam优化器的一个关键方面是在参数更新步骤中使用第二动量的平方根倒数。这个操作对精度要求有重要影响。由于这个平方根倒数操作,第二动量中的最小值在决定参数更新时变得最为重要。这一特性在考虑降低第二动量的精度时带来了独特的挑战。
5.2 方法论
第一和第二矩的FP8格式选择。我们进行了广泛的实验来确定两个动量的最佳FP8格式。我们的调查揭示了每个动量存在不同的精度要求:
1. 第一动量:E4M3格式(4位指数,3位尾数)提供了足够的精度。该格式在表示梯度均值时,在范围和准确性之间取得了良好的平衡。
2. 第二动量:必须使用E5M2格式(5位指数,2位尾数)。该格式提供了更高的动态范围,对于保留第二动量中最小值的信息至关重要。额外的指数位确保了我们能够准确表示非常小和中等大的值,这对于应用于该动量的平方根倒数操作至关重要。
实验验证。在图5的实验中,我们表明,虽然第一动量能够使用E4M3收敛,但用于估计梯度平方的第二动量需要更宽的动态范围,并且只能使用E5M2格式收敛。
与先前工作的比较。在表1中,我们将提出的优化器动量量化方案与Peng等人(【索引 Peng et al., 2023, Fp8-lm: Training fp8 large language models】)提出的方案进行了比较。我们的方案首次展示了使用标准FP8格式量化两个动量的能力。
A4 实验
实验环境
- 模型架构:实验使用Llama2模型(【索引 Touvron et al., 2023, Llama 2: Open foundation and fine-tuned chat models】)作为基线。该模型是一个仅解码器的Transformer(【索引 Brown et al., 2020, Language models are few-shot learners】),采用了预归一化(pre-normalization)的RMSNorm(【索引 Zhang & Sennrich, 2019, Root mean square layer normalization】)、SwiGLU激活函数(【索引 Shazeer, 2020b, Glu variants improve transformer】)和旋转位置嵌入(rotary positional embeddings)(【索引 Su et al., 2024, Roformer: Enhanced transformer with rotary position embedding】)。
- 数据集:模型在开源的Red Pajama数据集(【索引 Computer, 2023, Redpajama: an open dataset for training large language models】)上进行了2万亿词元的训练。
- 超参数:训练过程中的超参数与Llama2的原始论文(【索引 Touvron et al., 2023, Llama 2: Open foundation and fine-tuned chat models】)保持一致。
- 硬件配置:所有训练均在256个Intel Gaudi2设备上进行。
实验结果
不同优化器动量数据类型的比较。
表1:两种Adam优化器动量的不同数据类型比较。
训练稳定性。图6展示了使用所提出方案(包括Smooth SwiGLU和FP8 Adam优化器)的Llama2训练损失。该方案成功克服了标准FP8训练的发散点。FP8模型使用了标准格式(【索引 Micikevicius et al., 2022, Fp8 formats for deep learning】),包括保存一个高精度权重矩阵,前向传播时量化为E4M3,后向传播时量化为E5M2,并采用延迟缩放,这与Nvidia的Transformer Engine类似。该模型在256个Intel Gaudi2上训练了15天。
零样本性能。表2比较了BF16基线和我们的FP8模型在下游任务上的零样本性能(准确率和困惑度)。结果表明,我们的FP8方法在所有测试指标上都达到了与BF16基线相当的性能。其中FP8(1)指的是FP8 + SwiGLU输出为BF16,FP8(2)指的是FP8 + Smooth SwiGLU + FP8优化器。
表2:BF16基线和提出的FP8方案在零样本准确率和困惑度上的比较。注意两种模型在所有测试中都取得了相当的结果。FP8(1)指FP8 + SwiGLU输出为BF16,FP8(2)指FP8 + Smooth SwiGLU + FP8优化器。
性能增益。表3展示了不同配置在Intel Gaudi2硬件上的性能。虽然完全的FP8量化实现了最高的加速(约37%),但它会导致训练发散(如图2a所示)。为w3层禁用量化可以实现收敛(图3),并带来约27%的加速。我们提出的Smooth-SwiGLU方案不仅能收敛且结果与BF16基线相当(图6),还提供了高达约34%的加速。
表3:在Llama2 7B模型中,我们未优化的实现下不同配置的性能加速。测量在8个Intel Gaudi2设备上完成。
内存减少。在表4中,我们展示了将优化器动量从标准的FP32更改为FP8所实现的内存减少。此外,我们还将主权重减少到FP16,如Peng等人(【索引 Peng et al., 2023, Fp8-lm: Training fp8 large language models】)所示。可以看出,我们可以将内存消耗减少约30%。
表4:应用所提出的FP8优化器(第5节)时的内存减少。测量在8个Intel Gaudi2设备上完成,使用Deepspeed Zero-1。
A5 结论
本文成功地在高达2万亿词元的数据集上展示了FP8训练,显著超过了先前1000亿词元的限制(【索引 Peng et al., 2023, Fp8-lm: Training fp8 large language models】),并取得了与BF16基线相当的结果。重要的是,我们发现早期的FP8训练尝试时间不够长,未能揭示由异常值引起的关键不稳定性。通过分析方法和模拟,我们表明这些异常值会随着时间的推移而出现,特别是在长时间的训练中。我们的调查揭示,SwiGLU激活函数会放大这些异常值,从而在大规模场景中破坏FP8训练的稳定性。
为解决此问题,我们对SwiGLU激活函数应用了逐通道量化,我们称之为Smooth-SwiGLU。尽管在功能上与SwiGLU相同,但该方法有效减少了异常值的放大,确保了FP8训练的稳定性,对训练期间的模型性能影响适中,且对推理没有任何影响。此外,我们引入了首个对Adam优化器两个动量进行FP8量化的实现,进一步优化了内存使用。
我们提出的方法,结合了Smooth-SwiGLU和FP8优化器动量,在下游任务上取得了与BF16基线相当的性能,同时提供了显著的吞吐量提升。该方法成功克服了在大型数据集上进行标准FP8训练时通常遇到的发散挑战。
可复现性:论文摘要提供了一个指向匿名GitHub仓库的链接,其中包含复现实验所需的所有代码和必要细节。
伦理考量:LLM的训练需要巨大的计算资源,这显著增加了碳排放。这种环境成本已成为人工智能领域日益关注的问题。像FP8这样的低精度格式提供了一个有前途的解决方案,因为它在不牺牲模型准确性的前提下显著减少了计算开销。通过采用FP8进行训练,我们不仅可以提高训练效率,还可以减轻与大规模LLM训练相关的碳足迹,为更可持续的人工智能发展铺平道路。
💬 评论讨论
欢迎在这里分享您的想法和见解!