Towards Fully FP8 GEMM LLM Training at Scale

  • 作者/机构: Alejandro Hernández-Cano (EPFL), Dhia Garbaya (EPFL), Imanol Schlag (ETHZ), Martin Jaggi (EPFL)

A1 主要贡献

本文旨在解决大型语言模型(LLM)预训练中采用FP8数据格式所面临的稳定性挑战,从而充分发挥其潜力。现有方法因依赖于次优的细粒度FP8内核,或在注意力投影等敏感组件中回退到更高精度的矩阵乘法(GEMMs),从而限制了吞吐量的提升。

核心问题与研究目标
- 核心问题: 在大规模LLM训练中,FP8格式因其动态范围有限,难以处理激活值中出现的大幅离群特征,导致训练不稳定甚至发散。现有FP8训练方法为了维持稳定,要么采用开销较大的细粒度缩放策略,要么在关键计算(如注意力)中保留BF16精度,这都削弱了FP8带来的性能优势。
- 研究目标: 设计一种新的LLM架构,使其能够在Transformer块内的所有GEMMs(包括线性投影和注意力机制)中,在整个前向和后向传播过程中都使用FP8计算,从而在不牺牲模型性能的前提下,实现前所未有的训练吞吐量提升。

创新点与主要贡献
本文的主要贡献如下:
- 提出FOG架构集: 引入了FOG(Fast and Outlier-Guarded)系列LLM架构,该架构专门设计用于在训练过程中最小化离群特征。该方法能够在Transformer块内的所有GEMMs中稳定地使用FP8计算,与标准的BF16方法相比,吞吐量提升高达43%。
- 性能与稳定性验证: FOG架构在达到与BF16基线相当的下游任务质量的同时,显著提升了训练速度。本文通过在多种模型尺寸(0.4B、1.5B、8B)和高达Chinchilla最佳数据预算15倍的数据量下进行实验,经验性地证明了其性能和稳定性【12, Jordan Hoffmann, et al. Training Compute-Optimal Large Language Models. NeurIPS 2022】。
- 设计的灵活性: 证明了FOG设计的灵活性,它可以适应多种架构,包括使用不同激活函数的变体,甚至可以应用于专家混合(MoE)设置。
- 提出基于峰度的稳定性诊断方法: 使用峰度(kurtosis)作为关键指标,提出了一套方法,可以通过短期运行的诊断数据来判断一个架构在长数据量下进行FP8训练的鲁棒性。该方法不仅解释了先前观察到的规模化训练中的发散行为,还为社区未来开发的Transformer变体提供了评估FP8训练适用性的工具,无需进行昂贵的全面实验。

A3 背景知识

FP8格式与缩放策略: 由于动态范围有限,FP8张量在表示极端值时特别容易发生上溢和下溢。FP8格式有两种标准形式【20, Paulius Micikevicius, et al. FP8 Formats for Deep Learning. arXiv 2022】:E4M3和E5M2,各有优劣。E4M3(4位指数,3位尾数)提供更高的精度,而E5M2(5位指数,2位尾数)则以牺牲精度为代价提供更宽的动态范围。现有的大规模分布式训练框架,如DeepSpeed或Megatron【29, Mohammad Shoeybi, et al. Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. arXiv 2020】,利用这种差异,在前向传播中使用E4M3以保持精度,在后向传播中使用E5M2以有效处理梯度更宽的动态范围。然而,这两种格式的表示能力都远低于半精度或单精度格式。因此,在将张量转换为FP8时,会应用各种缩放策略以更有效地利用其有限的范围。这些策略主要是张量级别的,分为两大类:延迟缩放(delayed scaling)和即时缩放(just-in-time scaling, JIT)。延迟缩放利用先前训练迭代的信息来确定当前迭代的张量缩放因子,只需要对数据进行单次遍历,并存储过去一段时间内观察到的有用指标的简短历史。相比之下,JIT缩放会阻碍使用FP8带来的增益,因为它需要先以更高精度生成张量,计算其分布以确定缩放因子,然后再将输入转换为FP8并执行GEMM,这至少需要两次数据遍历。一种更新的方法旨在通过为单个张量使用多个缩放因子来使缩放更鲁棒,允许不同的张量块有不同的缩放因子【27, Bita Darvish Rouhani, et al. Microscaling Data Formats for Deep Learning. arXiv 2023; 3, DeepSeek-AI, et al. DeepSeek-V3 Technical Report. arXiv 2025】。这种更精细的缩放策略自然会相对于张量级的延迟缩放方法在GEMM内核上引入更大的开销。

训练稳定性的挑战与现有解决方案: 确保FP8训练的稳定性仍然具有挑战性。当某些激活值在训练过程中产生大的离群值时,问题就变得尤为突出,这使得低精度表示变得不可行,并导致快速发散。先前的工作引入了“巨量激活”(massive activations)这一术语,这一现象类似于离群特征,并显示了它们在LLM能力中的关键作用【30, Mingjie Sun, et al. Massive Activations in Large Language Models. CoLM 2024】。理解这些离群值的动态对于解释FP8发散现象和识别网络中导致这些问题的组件至关重要。一个显著的离群值放大源头已被确定为广泛采用的SwiGLU(Swish Gated-Linear-Unit)激活函数【28, Noam Shazeer. GLU Variants Improve Transformer. arXiv 2020】。通过将其替换为一个缩放变体SmoothSwiGLU,可以调节大的离群值,并被证明可以稳定先前发散的FP8训练运行并确保其收敛【6, Maxim Fishman, et al. Scaling FP8 Training to Trillion-Token LLMs. ICLR 2025】。

离群值的根源与架构改进: 进一步的研究表明,不仅SwiGLU是离群值的放大器,门控线性单元(GLUs)和预归一化层(pre-normalization layers)通常也是,这表明不当的信号传播是离群值的根本原因【9, Bobby He, et al. Understanding and Minimising Outlier Features in Transformer Training. NeurIPS 2024】。移除这些组件并为Transformer配备QK熵正则化机制,如QK RMS归一化【10, Alex Henry, et al. Query-Key Normalization for Transformers. EMNLP 2020 findings】,产生了离群值保护(Outlier Protected, OP)架构【9】,该架构已被证明可以数量级地减少后期出现的离群值,同时提供相当的预测质量。虽然OP架构被证明对训练后量化有益,但其在FP8预训练中的应用仍未被探索。最后,预归一化层的一个替代方案是后归一化层(post-normalization layers)【18, Ze Liu, et al. Swin Transformer V2: Scaling Up Capacity and Resolution. CVPR 2022】。长数据范围的训练已经证实了它们在使用标准BF16混合精度训练时在训练稳定性方面的优越性【21, Team OLMo, et al. 2 OLMo 2 Furious. arXiv 2025】。

A2 方法细节

FOG架构的核心设计: 我们的基础架构,如Figure 1所示,对广泛使用的Transformer网络【32, Ashish Vaswani, et al. Attention Is All You Need. NeurIPS 2017】做出了关键性改变。我们移除了注意力机制和FFN之前的预归一化模块。此外,在注意力机制中增加了一个归一化机制,以防止熵坍塌(transformer训练中的一个关键不稳定性问题【34, Shuangfei Zhai, et al. Stabilizing Transformer Training by Preventing Attention Entropy Collapse. ICLR 2023】)的发生。这个机制可以采用QK RMSNorm模块的形式【10, Alex Henry, et al. Query-Key Normalization for Transformers. EMNLP 2020 findings】:

$$\mathcal{N}_{\gamma}(\boldsymbol{x}) := \frac{1}{\mathrm{rms}(\mathbf{x})} \boldsymbol{\gamma} \odot \mathbf{x}, \quad \mathrm{rms}(\mathbf{x}) := \frac{\|\mathbf{x}\|_2}{\sqrt{D}},$$

其中$x \in R^D$, $\gamma \in R^N$是可学习的增益向量,$\odot$是哈达玛积,$\|\cdot\|_2$是ℓ2范数。或者,也可以对query和key张量应用逐元素的激活函数$tanh_{\alpha}(x) := tanh(\alpha x)$,其中$\alpha \in R$是可训练的。这个激活函数已被证明具有类似于RMS归一化模块的正则化效果【35, Jiachen Zhu, et al. Transformers without Normalization. CVPR 2025】,同时计算效率更高。

Figure 1: FOG transformer。
Figure 1: FOG transformer。

性能增强与初始化: 此外,第一个transformer块的输入会乘以$\sigma^{-1}$进行缩放,以在初始化时保持激活值的单位方差,其中$\sigma$是网络随机初始化选择的标准差。最后,为了提升性能,在残差连接之前应用了一个可学习的归一化模块。这可以是一个LayerScale【31, Hugo Touvron, et al. Going Deeper with Image Transformers. ICCV 2021】模块,$LayerScale(x) := \gamma \odot x$,其中$\gamma \in R^D$是一个可学习的增益向量;或者一个RMSNorm模块,从而形成一个后归一化架构【18, Ze Liu, et al. Swin Transformer V2: Scaling Up Capacity and Resolution. CVPR 2022】。在这两种情况下,可学习的增益向量都初始化为$1/\sqrt{\text{num\_layers}}$,并保持残差分支不被归一化,从而允许适当的信号传播【9, Bobby He, et al. Understanding and Minimising Outlier Features in Transformer Training. NeurIPS 2024】。我们的架构套件在Table 1中有详细说明,更多细节见附录B。

从OP到FOG的演进: 尽管OP架构已经提供了多种防止大离群值出现的保护措施,但我们观察到它对于FP8DPA(FP8 Dot Product Attention)训练仍然是一个不切实际的选择。在5.1节中,我们展示了它像所有其他测试的架构一样,在训练早期就遭遇了致命的损失发散。我们分离出导致OP与FP8DPA训练不兼容的两个组件:可训练的QK RMSNorm增益向量$\gamma$和缺乏任何(除LayerScale外的)归一化。我们发现,使用后归一化不会像预归一化网络那样容易产生离群值。

Table 1: FOG架构套件与OP的比较。标有*的正则化表示增益不可训练。每个变体提供不同的权衡,FOG-flash具有更高的吞吐量,而FOG-max被观察到具有更好的下游质量。
Table 1: FOG架构套件与OP的比较。标有*的正则化表示增益不可训练。每个变体提供不同的权衡,FOG-flash具有更高的吞吐量,而FOG-max被观察到具有更好的下游质量。
Figure 2: 从OP到FOG-opt的逐步演变。比较了390M模型在FP8DPA训练下的表现。第一个发散的架构是OP,而冻结了QK RMSNorm增益的OP模型在训练的稳定阶段存活了下来。然而,它仍然在学习率冷却期间(约42B tokens开始)经历了显著的发散。最终收敛的架构FOG-opt是在之前方法的基础上增加了后归一化的结果。报告的梯度范数是200个滚动窗口的均值和5%-95%的分位数带。
Figure 2: 从OP到FOG-opt的逐步演变。比较了390M模型在FP8DPA训练下的表现。第一个发散的架构是OP,而冻结了QK RMSNorm增益的OP模型在训练的稳定阶段存活了下来。然而,它仍然在学习率冷却期间(约42B tokens开始)经历了显著的发散。最终收敛的架构FOG-opt是在之前方法的基础上增加了后归一化的结果。报告的梯度范数是200个滚动窗口的均值和5%-95%的分位数带。

冻结QK RMSNorm增益的重要性: Figure 2对从OP到FOG-opt的演变组件进行了消融研究。我们可以看到,冻结可训练的QK RMSNorm增益导致了显著更稳定的训练。我们将OP的早期发散归因于在低精度训练时,不受控制的QK归一化导致其增益爆炸。需要注意的是,这些增益通常不进行权重衰减。我们通过实验观察到了这种爆炸现象,并确认使用L2正则化有助于延迟发散。我们最终选择将增益冻结为一个常数,因为这种方法更简单、足够有效,不影响性能,并且还能提供微小的速度提升。我们的消融实验表明,将增益设置为一个略大于1的常数值可以改善损失。因此,为了在移除$\gamma$增益向量后保留其益处,我们增大了标准的$s = 1/\sqrt{D_{qk}}$ attention softmax缩放因子——这是一个微小的优化技巧,能提供等效的注意力得分矩阵$S$:

$$\mathbf{S}=\frac{1}{\sqrt{D_{q k}}} \mathcal{N}_{\gamma_{0}}(\mathbf{Q}) \mathcal{N}_{\gamma_{0}}(\mathbf{K})^{\top}=\frac{1}{\sqrt{D_{q k}}}\left(\frac{\gamma_{0} \mathbf{Q}}{\operatorname{rms}(\mathbf{Q})}\right)\left(\frac{\gamma_{0} \mathbf{K}}{\operatorname{rms}(\mathbf{K})}\right)^{\top}=\frac{\gamma_{0}^{2}}{\sqrt{D_{q k}}} \mathcal{N}_{1}(\mathbf{Q}) \mathcal{N}_{1}(\mathbf{K})^{\top},$$

后归一化对收敛的关键作用: 最后,我们通过经验证明,在学习率衰减阶段,添加后归一化对于确保FP8DPA的收敛至关重要。先前的工作也倾向于后归一化而非预归一化【21, Team OLMo, et al. 2 OLMo 2 Furious. arXiv 2025】,为它们在BF16训练中更好的稳定性提供了证据。我们将这一观察扩展到我们的FP8设置,并确认仅使用可学习的LayerScale模块,即使有受控的QK正则化,也无法确保在最后阶段的收敛。我们将OP的这种后期发散归因于没有归一化的LayerScale模块不足以处理FP8离群值,这可能是由于学习率衰减后模型统计数据的巨大变化,这些变化在残差连接中累积,导致最后几层的激活值出现巨大的离群值,正如Figure 3中每个transformer块输出的离群值递增模式所强调的那样。我们注意到,最初我们测试了在学习率冷却阶段同时降低之前恒定的权重衰减值的想法,旨在保持模型权重的范数【16, Atli Kosson, et al. Rotational Equilibrium: How Weight Decay Balances Learning Across Neural Networks. ICML 2024】。附录G显示,这种干预对稳定性或性能都没有明显影响。为了在消融实验中保持一致性和公平比较,我们决定在所有FOG运行中保留这一设置。

Figure 3: 在FP8DPA学习率冷却期间,使用OP+frozenQK架构的QKV张量的峰度。后层表现出显著更大的激活离群值。
Figure 3: 在FP8DPA学习率冷却期间,使用OP+frozenQK架构的QKV张量的峰度。后层表现出显著更大的激活离群值。

A3 关键Observation

使用峰度作为离群值度量: 为了分析神经网络激活值中存在的离群值,我们使用峰度(kurtosis)作为衡量激活值偏差极端程度(例如由离群值引起的偏差)的指标。我们将向量$x \in R^D$的峰度$kurt(x)$定义为标量:

$$\operatorname{kurt}(\mathbf{x}) := \frac{\mu[\mathbf{x}^4]}{\sigma^2[\mathbf{x}^2]},$$

其中$\mu$和$\sigma^2$分别是样本均值和方差,指数运算是逐元素进行的。给定一个激活张量$X \in R^{N \times C \times D}$,其中N、C和D分别是批量大小、序列长度和隐藏大小,我们将其峰度定义为平均峰度$kurt(X) := \frac{1}{NC} \sum_{n=1}^{N} \sum_{c=1}^{C} kurt(x_{nc})$。

峰度定义与应用: 在这个定义下,当$x$的少数元素达到相对于整个向量方差极大的值时,即存在大的离群特征时,$kurt(x)$达到最大值。该定义已在先前工作中用于分析BF16训练中的离群值【9, Bobby He, et al. Understanding and Minimising Outlier Features in Transformer Training. NeurIPS 2024】,并且与概率论文献中峰度的标准定义【22, Karl Pearson. "Das Fehlergesetz Und Seine Verallgemeiner-Ungen Durch Fechner Und Pearson." a Rejoinder. Biometrika 1905】不同,此定义不会将$x$中心化为零均值。对于我们的用途而言,这与FP8核在缩放和转换前不对其输入进行移位的事实是一致的。我们跟踪关键激活值中峰度的动态变化,即FFN中第二个投影的输入、QKV矩阵以及每个transformer块的输出。除非特别说明,我们报告的是所有层激活峰度的平均值。

FOG架构的低峰度特性: 利用这些激活值,我们可以分析训练不同阶段大离群特征的出现情况。Figure 4展示了与基线相当的损失进展,同时在某些激活值上的峰度低了几个数量级。请注意,与之前的FP8方法不同,FOG架构使用FP8注意力计算进行训练,这引入了更多的量化误差。因此,key、query和value投影的峰度变得尤为重要。

激活函数与稳定性: Baseline Llama在FP8训练(注意力为BF16)期间表现出后期发散,这被完全归因于其门控激活函数的二次行为——当权重在训练后期变得足够对齐时出现【6, Maxim Fishman, et al. Scaling FP8 Training to Trillion-Token LLMs. ICLR 2025】。在我们使用FOG-max架构进行的扩展的450B token运行中,我们采用了固有的二次xIELU激活函数(见公式(2)),并观察到稳定的训练,其峰度水平比基线Llama低几个数量级。实际上,将FOG-max架构修改为使用SwiGLU激活函数也导致了稳定的FP8DPA训练行为,如附录D所述。这些结果强烈表明,在训练期间倾向于低峰度激活的架构使得二次激活函数的稳定使用成为可能,并挑战了先前解释的完整性。这特别有趣,因为这类激活函数已知会产生线性梯度,这对反向传播有益——这可能有助于FOG-max在性能上优于基于GeLU的变体,如5.3节所示。

离群值的长期增长趋势: 这些架构在QKV离群值的长期增长中表现出次线性到对数的趋势,如Figure 4中的峰度所示。这种行为支持了它们对FP8DPA的鲁棒性,因为它表明需要更长的训练时间才能看到峰度的显著增加。我们的扩展运行与这一假设一致,因为它没有表现出任何发散的迹象。

Figure 4: 1.5B FOG-max和Llama3模型使用BF16精度训练超过100B tokens的损失和峰度训练动态。报告的损失是200个滚动窗口的均值和5%-95%的分位数带。
Figure 4: 1.5B FOG-max和Llama3模型使用BF16精度训练超过100B tokens的损失和峰度训练动态。报告的损失是200个滚动窗口的均值和5%-95%的分位数带。

峰度作为早期预警信号: Figure 5展示了一个发散的FP8DPA运行示例,并与成功的FOG-max训练进行比较。这强调了跟踪张量级指标(如峰度)的重要性,以便在常见的全局指标(如损失和梯度范数)显示任何发散症状之前,潜在地预测后期的发散。在这个例子中,虽然损失在约15B token时明确发散,梯度范数不早于12B tokens时才持续飙升,但QKV峰度早在3B token时就已经偏离了在不同架构中一致观察到的预期次线性增长,这提供了一个潜在的早期发散信号。

Figure 5: 一次失败和一次成功的FP8DPA运行的训练动态。峰度在损失发散之前很早就表现出异常行为。报告的梯度范数是200个滚动窗口的均值和5%-95%的分位数带。
Figure 5: 一次失败和一次成功的FP8DPA运行的训练动态。峰度在损失发散之前很早就表现出异常行为。报告的梯度范数是200个滚动窗口的均值和5%-95%的分位数带。

A4 实验环境

  • 数据集:

    • 名称: FineWeb-Edu【23, Guilherme Penedo, et al. The FineWeb Datasets: Decanting the Web for the Finest Text Data at Scale. NeurIPS 2024】。
    • 规模与用途: 这是一个大规模文本语料库,通过robots.txt过滤掉了选择退出的网站域,形成了一个严格遵守数据规范的语料库【5, Dongyang Fan, et al. Can Performant LLMs Be Ethical? Quantifying the Impact of Web Crawling Opt-Outs. CoLM 2025】。数据使用一个131K词汇量的BPE分词器进行分词。所有主要实验的上下文长度保持在4096。
  • 模型架构:

    • 规模: 训练了390M、1.5B和8B参数的模型。
    • 基线: 遵循Llama3 8B模型的设计【7, Aaron Grattafiori, et al. The Llama 3 Herd of Models. arXiv 2024】,390M和1.5B模型也相应调整。
    • FOG/OP: 由于Llama3使用门控线性单元,FOG和OP变体的FFN尺寸被增加以保持参数数量相等。
  • 硬件配置:

    • GPU: 每个节点配备4块Nvidia Grace Hopper GPU。
  • 软件配置:

    • 框架: 分布式训练框架改编自Megatron-LM【29, Mohammad Shoeybi, et al. Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. arXiv 2020】,并使用Transformer Engine【2, NVIDIA Corporation. NVIDIA/TransformerEngine. URL https://github.com/NVIDIA/ TransformerEngine】的FP8方案。
    • 优化器: 使用AdamW【19, Ilya Loshchilov and Frank Hutter. Decoupled Weight Decay Regularization. ICLR 2019】,采用默认超参数。
    • 学习率调度: 采用WSD(Warm-up, Steady, Decay)三阶段调度方案【8, Alexander Hägele, et al. Scaling Laws and Compute-Optimal Training Beyond Fixed Training Durations. NeurIPS 2024】。
    • FP8策略: 所有实验均使用FP8延迟缩放策略,margin为0,历史长度为1024步。

A4 实验结果

5.1 FP8稳定性

  • 实验内容: 比较了FOG架构与文献中其他几种架构(OP, OLMo2, Llama3, Llama3+SmoothSwiGLU【6, Maxim Fishman, et al. Scaling FP8 Training to Trillion-Token LLMs. ICLR 2025】)在FP8DPA(全FP8注意力)训练下的稳定性。测试了390M、1.5B和8B三种模型规模。
  • 实验结果: 如Figure 6所示,所有其他被测试的现有架构在FP8DPA训练下均发生发散。即使是拥有注意力离群值缓解策略(QK RMSNorm)的OP和OLMo2架构也未能幸免。
  • 分析结论: 现有架构不适合进行FP8DPA训练。此外,实验观察到一个有趣的现象:在FP8DPA训练下,较大模型比较小模型倾向于在训练的更后期才发散(见Table 2)。这一现象之前未被报道过,可能因为在BF16注意力的FP8设置下,发散需要更长时间。
Figure 6: 不同架构在FP8DPA训练下的交叉熵损失图。在任何规模下,没有其他测试架构能够不发散地超过20B token的训练量。
Figure 6: 不同架构在FP8DPA训练下的交叉熵损失图。在任何规模下,没有其他测试架构能够不发散地超过20B token的训练量。
Table 2: 观察到损失发散时的Token数量。
Table 2: 观察到损失发散时的Token数量。

5.2 效率

  • 实验内容: 在标准上下文长度(4096)和长上下文场景下,评估FOG架构在1.5B和8B模型规模下的训练吞吐量,并与BF16基线和稳定的Llama FP8(使用SmoothSwiGLU)进行比较。
  • 实验结果:

    • 标准上下文: 如Table 3所示,随着模型规模增大,吞吐量增益更加明显。在8B规模下,所有FOG变体的性能都优于其他架构,FOG-flash相比BF16基线吞吐量提升了40.2%。
    • 长上下文: 如Table 4所示,由于点积注意力的二次复杂度成为瓶颈,FP8DPA在长上下文场景下优势更大。随着序列长度增加,FOG与先前FP8方法的速度差距也随之增大。例如在16384序列长度下,FOG-flash比Llama+SmoothSwiGLU的领先优势达到7.7%。
  • 分析结论: FOG架构通过启用FP8DPA,在各种模型规模和上下文长度下都实现了显著的训练效率提升,尤其是在计算密集型的大模型和长上下文场景中。

Table 3: FOG与其他基线的训练吞吐量测量。8B模型使用8个GH200节点和Zero-1分片【26, Samyam Rajbhandari, et al. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models. SC20 2020】,1.5B模型使用单个GH200节点。值得注意的是,在8B规模下,所有FOG变体的性能都优于其他架构。
Table 3: FOG与其他基线的训练吞吐量测量。8B模型使用8个GH200节点和Zero-1分片【26, Samyam Rajbhandari, et al. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models. SC20 2020】,1.5B模型使用单个GH200节点。值得注意的是,在8B规模下,所有FOG变体的性能都优于其他架构。
Table 4: 在不同序列长度下的训练吞吐量增益(相对于Llama BF16),实验在8B规模下使用8个GH200节点,全局批量大小为1024。增加到更长上下文需要启用张量并行(TP)。原始吞吐量值在Table 11中报告。
Table 4: 在不同序列长度下的训练吞吐量增益(相对于Llama BF16),实验在8B规模下使用8个GH200节点,全局批量大小为1024。增加到更长上下文需要启用张量并行(TP)。原始吞吐量值在Table 11中报告。

5.3 下游性能

  • 实验内容: 在一系列标准基准测试上,比较FOG架构(使用FP8DPA训练)与更高精度的Llama3基线(使用BF16训练)的下游性能。评估时推理使用BF16精度。1.5B模型训练了125B tokens,较小模型训练了50B tokens。
  • 实验结果: 如Table 5所示,所有FOG变体都提供了与更高精度Llama3基线相当的下游性能,其中FOG-max架构甚至在平均分上超过了基线。
  • 分析结论: FOG架构在实现训练加速的同时,没有牺牲模型的下游任务表现,证明了其方法的有效性。
Table 5: 各项任务的性能。对于每个任务和模型大小,第一个分数是BF16消融实验的结果,第二个是FP8DPA的结果。平均分*是基于附录G中展示的更大数据集上的平均值。
Table 5: 各项任务的性能。对于每个任务和模型大小,第一个分数是BF16消融实验的结果,第二个是FP8DPA的结果。平均分*是基于附录G中展示的更大数据集上的平均值。

5.4 长数据量训练

  • 实验内容: 为了进一步验证FOG进行FP8DPA长程训练的可行性,将一个1.5B参数的FOG-max模型训练了450B tokens,远超之前发现的Llama2-7B在200B tokens处的发散点【6, Maxim Fishman, et al. Scaling FP8 Training to Trillion-Token LLMs. ICLR 2025】。
  • 实验结果: 如Figure 7所示,训练过程保持稳定,语言建模损失曲线的平滑度与Llama基线相当甚至更优。在130B tokens后,切换到FP16优化器状态和BF16梯度以节省内存。
  • 分析结论: FOG架构在超长数据量训练下表现出强大的稳定性,进一步证实了其作为FP8训练鲁棒性设计的成功。
Figure 7: 长数据量训练。FOG-max 1.5B FP8DPA模型在450B tokens上进行训练。更高精度的Llama3实验作为参考。请注意,每个实验的最后25B tokens会根据WSD计划进行学习率冷却。报告的损失是200个滚动窗口的均值和5%-95%的分位数带。
Figure 7: 长数据量训练。FOG-max 1.5B FP8DPA模型在450B tokens上进行训练。更高精度的Llama3实验作为参考。请注意,每个实验的最后25B tokens会根据WSD计划进行学习率冷却。报告的损失是200个滚动窗口的均值和5%-95%的分位数带。

5.5 附加结果

  • SwiGLU: 实验证明,FOG架构可以与门控激活函数(如SwiGLU)结合,并保持FP8DPA训练的稳定性(详见附录D)。
  • FP8优化器动量: 在一个390M规模的FOG-flash实验中,使用8位优化器动量,模型收敛到了与高精度动量几乎相同的损失值(2.645 vs 2.649),展示了进一步降低内存占用的潜力(详见附录E)。
  • MoE: FOG架构在专家混合(MoE)设置下进行了测试,训练过程始终保持稳定,证明了FOG的鲁棒性和泛化能力(详见附录D)。

A7 补充细节

限制: 尽管FOG架构具有鲁棒性、创纪录的吞吐量提升和灵活性,但它仍然存在一个限制。最后的投影层(LM head)仍然是在BF16精度下执行的。这个操作符对离群值非常敏感,在包括本文在内的前向-后向FP8训练方法中都使用了半精度。由于计算资源的限制,作者决定将对这一限制的研究留给未来的工作。

A5 结论

本文首次展示了在Transformer模块内完全使用FP8矩阵乘法(包括注意力机制)进行稳定的LLM训练,且没有牺牲性能。我们测试了多种先前提出的架构,发现在FP8DPA训练下它们都会在早期稳定地发散,这凸显了FP8DPA训练的难度和我们成果的新颖性。此外,与其他细粒度缩放方案相比,我们使用了低开销的延迟缩放FP8策略。我们的设计提供了与更高精度基线相当的下游质量,同时在8B规模下训练速度提高了43%。我们将1.5B模型扩展到了450B tokens的训练量,是其规模下Chinchilla最优数据预算的15倍。我们的工作使社区向着完全FP8 GEMM规模化训练(即包括语言建模头)又迈进了一步。通过使用峰度观察关键激活值的离群值训练动态,我们进一步证明了我们架构的长期稳定性。使用峰度来跟踪训练中存在的离群值被证明能够为选择某些架构组件或预测未来的不稳定性提供有意义的见解,因为它是一个衡量离群值的量化指标。

A6 附录

A 超参数

超参数选择: Table 6详细列出了实验中使用的超参数。对于FOG-flash,所有模型尺寸的$tanh_{\alpha}$熵正则化的$\alpha_0$初始值均为0.5。所有模型均采用线性预热(warmup)和1-sqrt冷却(cooldown)的学习率调度。长数据量训练的1.5B FOG-max实验总共训练了430,000步,消耗约450.9B tokens,使用了与短时运行相同的超参数,包括预热步数。

Table 6: 实验中使用的超参数。请注意,FFN隐藏大小指的是门控激活函数中每个线性投影的维度;没有GLU的网络使用该值的1.5倍以匹配参数数量。指定的Softmax scale仅适用于FOG模型,所有其他模型遵循标准的s = 1/√DQK。
Table 6: 实验中使用的超参数。请注意,FFN隐藏大小指的是门控激活函数中每个线性投影的维度;没有GLU的网络使用该值的1.5倍以匹配参数数量。指定的Softmax scale仅适用于FOG模型,所有其他模型遵循标准的s = 1/√DQK。

B 架构

架构细节: 我们为本文中出现的所有架构提供详细的公式。我们的transformer架构按顺序由以下组件构成:
1. 输入token嵌入
2. 一个输入缩放因子 $u \in (0, \infty)$,可能等于1
3. 一系列L个如下所述的transformer块
4. 一个最终的归一化函数 $N^{(\text{final})}$,可能为恒等函数
5. 一个线性输出层

Transformer块定义: Transformer块定义如下:

$$\operatorname{block}(\mathbf{X}):=\hat{\mathbf{X}}+\left(N_{2}^{(\text {post })} \circ \mathrm{FFN} \circ N_{2}^{(\text {pre })}\right)(\hat{\mathbf{X}}), \quad \hat{\mathbf{X}}:=\mathbf{X}+\left(N_{1}^{(\text {post })} \circ \mathrm{GQA} \circ N_{1}^{(\text {pre })}\right)(\mathbf{X}) .$$


其中$N_i^{(*)}$是可能为恒等函数的归一化层,FFN(X)是一个带有非线性激活函数$\phi$且无偏置的双层前馈网络。GQA(分组查询注意力)遵循标准的分组查询自注意力定义,带有softmax缩放因子s和旋转位置嵌入(Rotary Position Embeddings)。每个注意力头的定义如下:

$$\operatorname{attnhead}(\mathbf{X}):=\operatorname{selfattn}\left(N_{Q}^{(\mathrm{QK})}\left(\mathbf{X} \mathbf{W}^{(Q)}\right), N_{K}^{(\mathrm{QK})}\left(\mathbf{X} \mathbf{W}^{(K)}\right), \mathbf{X} \mathbf{W}^{(V)}\right)$$
其中$N^{(QK)}$是熵正则化机制,而$selfattn = PV$。P是注意力概率矩阵:
$$\mathbf{P}:=\operatorname{Softmax}\left(s \mathbf{Q} \mathbf{K}^{\top}+\mathbf{M}\right)$$
使用这些符号,Table 7详细说明了项目中使用的架构系列。
Table 7: 所用模型的架构细节。带有(a)的模型以后归一化增益$\gamma_0 = 1/\sqrt{L}$进行初始化。带有(b)的模型在QK熵正则化$N^{(QK)}$中具有冻结的增益。id是恒等函数,$\sigma_0$是选择的初始化标准差,N是RMS归一化。输入缩放u是不可训练的。

xIELU激活函数: xIELU激活函数在【13, Allen Hao Huang and Imanol Schlag. Deriving Activation Functions Using Integration. arXiv 2025】中被引入,其逐元素定义为:

$$\begin{aligned} \text{xIELU}(x) := \begin{cases} \alpha_p x^2 + 0.5x & \text{if } x > 0, \\ \alpha_n(e^x - 1) - \alpha_n x + 0.5x & \text{if } x \le 0. \end{cases} \end{aligned}$$


其中$\alpha_p$和$\alpha_n$是每层可训练的标量。xIELU是Squared ReLU的扩展,并已在大规模模型中得到应用和验证【11, Alejandro Hernández-Cano, et al. Apertus: Democratizing Open and Compliant LLMs for Global Language Environments. arXiv 2025】。

C 离群值对量化的影响

离群值定义: 首先给出一个有用的定义。
$\tau$-离群值: 给定$x \in R^d$和$\sigma = rms(x)$,$x$的元素$x_j$是一个$\tau$-离群值,如果$|x_j| \ge \tau\sigma$。
随着$\tau$的增加,$x_j$成为一个更大的离群值($\sigma$代表$x$的自然幅度)。在实践中,$\tau \gg 1$。在FP8量化之前,每个张量都通过$s(x) := \text{MaxFP8Value}/\text{absmax}(x)$进行缩放,以更好地利用有限的FP8动态范围。

定理1: 设$x \in R^d$有一个$\tau$-离群值$x_j$,而$x' \in R^d$有一个$\tau'$-离群值$x'_j$,其中$\tau' > \tau$。那么对于任何子集$T \subseteq \{1 \dots d\} \setminus j$,向量$x'_T$的量化精度将低于$x_T$。
换句话说,更大的离群值会导致FP8量化结果的精度降低。

证明:
令$r = \text{FP8MaxValue}$,$s = r/\text{absmax}(x)$,以及$s' = s(x')$。
由于$\text{absmax}(x) \ge |x_j| \ge \tau\sigma$,那么$s \le r/(\tau\sigma)$(类似地,$s' \le r/(\tau'\sigma)$)。令$m = \text{absmax}(x_T)$且$m' = \text{absmax}(x'_T)$。
$x_T$的元素位于$[-m, m]$区间内。缩放后,在$sx_T$中范围变为$[-rm/(\tau\sigma), rm/(\tau\sigma)]$,在$s'x'_T$中范围变为$[-rm/(\tau'\sigma), rm/(\tau'\sigma)]$。由于$\tau' > \tau$,$[ -rm/(\tau'\sigma), rm/(\tau'\sigma)] \subsetneq [-rm/(\tau\sigma), rm/(\tau\sigma)]$,所以$x'_T$的范围更小。
这个变窄的范围包含的n-bit可表示数字更少,从而证明了该定理。
该证明适用于任何子集$T$,包括“典型”值集合(例如,90%分位数)。定理1保证了大的离群值会恶化张量中90%元素的量化效果。

经验验证: 我们在训练中期测量了Llama和FOG-max 1.5B模型在一个微批次数据上的激活值(精确地说是第二个FFN层的输入,量化前)。
观察: Llama呈现一个688-离群值,而FOG-max仅显示一个183-离群值。使用90%分位数,我们发现Llama 90%的激活系数被缩放到$[-0.289, 0.289]$范围内,而FOG-max则允许一个更宽的范围$[-2.084, 2.984]$。

D FOG扩展

FOG-SwiGLU: 除了我们的主要实验外,我们还训练了一个使用SwiGLU激活函数的1.5B FOG模型,我们称之为FOG-SwiGLU。该架构从FOG-max改编而来,将激活函数更改为SwiGLU,并调整FFN隐藏层大小以匹配参数数量。Figure 8显示了该模型在FP8DPA训练下的损失进展。这个实验得到了完全稳定的训练,并进一步证明了我们设计中适用激活函数的灵活性。

Figure 8: FOG-SwiGLU 1.5B FP8DPA运行。FOG-max作为参考。我们观察到两种方法都有稳定的训练动态。报告的损失是200个滚动窗口的均值和5%-95%的分位数带。
Figure 8: FOG-SwiGLU 1.5B FP8DPA运行。FOG-max作为参考。我们观察到两种方法都有稳定的训练动态。报告的损失是200个滚动窗口的均值和5%-95%的分位数带。

MoE扩展: 我们将FOG-flash架构调整为遵循MoE设计,保留了390M模型的主干配置(隐藏层大小、层数等),但按照【17, Dmitry Lepikhin, et al. GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding. ICLR 2021】的方法,使用8个FFN专家(2个激活)进行扩展,从而得到一个1.8B的模型,该模型在与所有390M模型相同的配置下从头开始训练。此外,我们使用了系数为0.01的z-loss【17】,没有显式的负载均衡损失,而是采用了更新率为0.01的专家偏置(expert bias)【33, Lean Wang, et al. Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts. arXiv】。我们在BF16和FP8DPA训练下训练了这个模型,最终收敛损失分别为2.477和2.483,如Figure 9所示。这个稳定的训练结果表明FOG在FP8DPA训练下对其他MoE设计具有良好的泛化能力。我们进一步调整了我们的1.5B模型,并类似地扩展到8个专家(2个激活)来衡量更大规模下的吞吐量增益。Table 8总结了我们的结果,其中FOG仍然提供了最大的吞吐量增益。

Table 8: MoE训练吞吐量。41B-8E MoE的测量使用4个GH200节点,专家并行和流水线并行大小为4,批次大小为512。与密集模型一样,FOG-flash的性能优于所有其他架构。
Table 8: MoE训练吞吐量。41B-8E MoE的测量使用4个GH200节点,专家并行和流水线并行大小为4,批次大小为512。与密集模型一样,FOG-flash的性能优于所有其他架构。
Figure 9: FOG-flash-MoE 1.8B-8E FP8DPA损失进展。图中显示了BF16和FP8DPA的训练。FP8DPA训练在整个训练期间保持稳定,两种精度下最终收敛的损失值在±0.005范围内,表明下游能力相当。报告的损失是200个滚动窗口的均值和5%-95%的分位数带。
Figure 9: FOG-flash-MoE 1.8B-8E FP8DPA损失进展。图中显示了BF16和FP8DPA的训练。FP8DPA训练在整个训练期间保持稳定,两种精度下最终收敛的损失值在±0.005范围内,表明下游能力相当。报告的损失是200个滚动窗口的均值和5%-95%的分位数带。

E FP8训练

延迟缩放: 在我们所有的实验中,我们使用了Transformer Engine的延迟缩放实现,历史长度$\ell = 1024$,边际$m = 0$。数学上,给定一个张量X的绝对最大值历史记录,表示为$H = \{h_t\}_{t=1}^\ell \subseteq [0, \infty)$,我们将其缩放因子定义为:

$$\rho(\mathbf{X}):=\frac{\text{FP8MaxValue}}{2^m\max H}$$


其中$\text{FP8MaxValue} \in (0, \infty)$是所用FP8格式可表示的最大值。我们使用$H \leftarrow \{\max_{x \in X} |x|\} \cup \{h_t\}_{t=2}^\ell$来更新历史记录,以便在下一次迭代中用于此激活。端到端的FP8矩阵乘法为:

$$\mathrm{GEMM}(\mathbf{X}, \mathbf{Y}):=\frac{1}{\rho(\mathbf{X}) \rho(\mathbf{Y})} \mathrm{FP8GEMM}(\mathrm{FP8cast}(\rho(\mathbf{X}) \mathbf{X}), \mathrm{FP8cast}(\rho(\mathbf{Y}) \mathbf{Y})),$$
其中FP8GEMM接收FP8张量并返回BF16结果。我们在Table 9中进一步详细说明了FP8和FP8DPA实验期间每次矩阵乘法使用的精度。
Table 9: FP8方法之间的比较。FP8DPA方法允许所有GEMM计算——除了输出头——都用FP8精度完成。相比之下,FP8训练对核心注意力计算使用更高的精度。线性算子是形如LinearW(X) = XW的线性层:即FFN线性层、QKV投影和注意力输出投影。注意力概率矩阵P的定义见公式(1)。

FP8优化器动量: 为了进一步减少内存使用,我们在通常的FP8DPA设置下测试了FOG-flash,并增加了一个约束:FP8优化器动量。这扩展了我们大多数实验中使用的半精度梯度和动量的典型设置。Figure 10显示了三种不同设置下的训练损失以供比较。

细粒度缩放方案: 最近的FP8训练成就,如DeepSeek的DeemGEMM内核,涉及使用细粒度的FP8缩放方案以提供更鲁棒的训练机制。虽然这些选项在张量级缩放方案发散时可能使FP8训练成为可能,但它带来了显著的开销。我们使用TransformerEngine的块级缩放(Blockwise scaling)在8B规模上验证了这一说法。训练吞吐量报告在Table 10中。使用延迟缩放方案的FP8DPA训练在所有测试方法中提供了最高的提升。

Figure 10: FOG-flash 390M损失曲线比较不同训练精度。我们的设计在用任一精度训练时都能达到相似的损失。报告的损失是200个滚动窗口的均值。
Figure 10: FOG-flash 390M损失曲线比较不同训练精度。我们的设计在用任一精度训练时都能达到相似的损失。报告的损失是200个滚动窗口的均值。
Table 10: 训练吞吐量。测量使用8个4xGH200节点,采用Zero-1分片【26】,无模型并行,批次大小为1024。值得注意的是,FOG-flash的性能优于所有其他架构。
Table 10: 训练吞吐量。测量使用8个4xGH200节点,采用Zero-1分片【26】,无模型并行,批次大小为1024。值得注意的是,FOG-flash的性能优于所有其他架构。

F 长上下文

吞吐量分析: 请注意,当全局批次大小(GBS)增加时——微批次大小固定——计算时间会超过通信时间。因此,GBS越大,所有方法的吞吐量增益就越高:FOG-flash在GBS=1k时达到+42.6%,而在GBS=512时为+40.2%。此外,在注意力中启用FP8计算为长上下文训练带来了独特的吞吐量优势。因此,FOG-flash FP8DPA与Llama3+SmoothSwiGLU在Table 11中实现了巨大的效率差距。

Table 11: 在不同序列长度下的训练吞吐量,实验在8B规模下使用8个4xGH200节点,全局批量大小为1024。TP指张量并行。
Table 11: 在不同序列长度下的训练吞吐量,实验在8B规模下使用8个4xGH200节点,全局批量大小为1024。TP指张量并行。

G 评估

基准测试集: 我们选择了以下一组基准测试:ARC-Easy, CommonsenseQA, HellaSwag, LAMBADA-OpenAI, LAMBADA-standard, OpenBookQA, PIQA, SocialIQA, 和 WinoGrande。我们使用一个标准的开源LLM评估包进行这些评估,该包在代码仓库 https://github.com/anonymous4375934/FOG 中被引用。

详细结果: 在Table 5中,我们报告了三个关键基准测试的原始准确率分数(百分比)以及上述完整任务集的平均分。在Table 12中,我们提供了1.5B模型规模的所有分数及其估计误差,表明在许多数值上观察到的微小差异在统计上并不显著。

Table 12: 1.5B规模下更详细的结果。对于每个模型和每个任务,第一个分数是BF16训练的结果,第二个是FP8DPA训练的结果。
Table 12: 1.5B规模下更详细的结果。对于每个模型和每个任务,第一个分数是BF16训练的结果,第二个是FP8DPA训练的结果。

权重衰减冷却: 正如第3节所述,我们实验了冷却权重衰减(通常用作等于0.1的常数值,与学习率耦合)的方法,以观察它是否能解决OP+frozenQK架构在学习率衰减阶段的持续发散问题。我们也在其他架构上测试了它,并且为了优化资源使用,我们不得不在后续的最终实验中保留它。这个技巧确实帮助稳定了权重的范数,但未能解决发散问题。此外,它对最终性能和稳定性没有影响。Table 13突显了在1.5B规模下这种无效果的结论。

Table 13: 学习率(LR)衰减阶段的权重衰减(WD)。如果恒定,则等于0.1。否则,从0.1开始并与LR成正比。
Table 13: 学习率(LR)衰减阶段的权重衰减(WD)。如果恒定,则等于0.1。否则,从0.1开始并与LR成正比。

H 计算资源

硬件与资源消耗: 我们的实验在配备4个Grace Hopper (GH200) GPU的节点上进行。我们通常分别为390M、1.5B和8B参数的实验使用4、8和16个节点,不同运行之间略有变化。重要的是,所有吞吐量测量都是在相同的硬件配置下进行的。Table 14详细说明了我们主要实验结果所需的计算资源(以GPU小时,GPUh为单位)。这包括训练所有在FP8DPA训练期间发散的架构、我们三个主要架构的FP8DPA和BF16稳定训练运行,以及BF16 Llama3基线的计算成本。该总和包括节点启动时间、由于节点故障而损失的计算时间,以及计算和记录峰度指标的开销。整个研究项目所需的计算资源超出了表中指定的范围,因为我们进行了大量初步实验并探索了未出现在最终论文中的想法。

Table 14: 主要实验使用的GPU小时数。
Table 14: 主要实验使用的GPU小时数。