Training LLMs with MXFP4
文章标题:使用 MXFP4 训练大语言模型
作者/机构:Albert Tseng (康奈尔大学), Tao Yu (AWS AI), Youngsuk Park (AWS AI)
A1 主要贡献
本文旨在解决使用如 MXFP4 等低精度数据类型加速大语言模型(LLM)训练时遇到的模型质量显著下降的问题。尽管 MXFP4 在支持的硬件上矩阵乘法(GEMMs)速度比 FP8 快 2 倍,但直接应用会导致训练效果不佳。
核心问题:如何利用 MXFP4 的计算速度优势进行 LLM 训练,同时避免模型性能的损失,实现与 BF16 混合精度训练相近的效果?
研究目标:提出一种近乎无损的 MXFP4 训练方案,使其能够在反向传播中显著加速训练,同时保持模型的最终质量。
核心创新点:
1. 采用随机舍入(Stochastic Rounding, SR)计算无偏梯度:通过随机舍入来生成对梯度的无偏估计,从而实现更精确的模型参数更新,这是维持模型质量的关键。
2. 使用随机哈达玛变换(Random Hadamard Transform, RHT)来约束随机舍入的方差:直接对 MXFP4 应用 SR 会因数据块级别的异常值导致高方差,损害收敛。本文利用 RHT 在理论上约束 SR 的方差,通过在量化前对梯度、激活和权重进行变换,集中其数值分布,从而稳定训练过程。
主要贡献概述:
* 本文提出了一种结合了随机舍入(SR)和随机哈达玛变换(RHT)的 MXFP4 训练方法,该方法能够在反向传播过程中计算出无偏且低方差的梯度估计。
* 通过预训练高达 6.7B 参数的 GPT 模型,实验证明该方法将 MXFP4 与 BF16 训练的验证困惑度差距缩小到 0.1 以内,实现了近乎无损的训练效果。
* 分析表明,本文提出的 RHT 和 SR 结构为 MXFP4 GEMM 带来的开销极小,在反向传播过程中,相比 FP8 和 BF16 分别能带来超过 1.3 倍和 1.7 倍的理论加速。
A3 背景知识与相关工作
2.1 低精度数据类型与 IEEE 754 浮点数
低精度数据类型的定义与硬件支持。传统上,低精度(LP)数据类型指的是使用远少于 32 位来表示一个数字的数据类型,例如 FP16。尽管存在多种低精度数据类型,包括有状态的数据类型【33,Qtip: Quantization with trellises and incoherence processing,2024】,但一部分已在 IEEE754 浮点(FP)标准【12,Ieee standard for floating-point arithmetic,2019】下被标准化。这些数据类型通常在计算密集型工作负载中得到硬件加速支持。
IEEE 浮点数的表示。IEEE 浮点数(如表1所示)由1个符号位、e 个指数位和 m 个尾数位定义。一个 1+m+e 位的数据类型简写为 EeMm。一个 IEEE 浮点数所代表的“正规”值由符号位 S、尾数 M 和指数 E 共同决定,其计算公式为:
$(-1)^S \times (1.M) \times 2^{E-\text{bias}}$
其中,bias 是一个依赖于数据类型的整数指数偏差偏移量,由 IEEE (2019) 标准指定【12,Ieee standard for floating-point arithmetic,2019】。这种指数-尾数结构意味着,在不发生上溢或下溢的情况下,浮点数据类型的量化信噪比(SNR)是尺度不变的【3,Unit scaling: out-of-the-box low-precision training,2023】。
2.2 LLM 训练
训练流程与瓶颈。训练 LLM 的最常见方法包括计算损失函数、计算损失函数关于模型参数的梯度,然后利用梯度信息更新参数。例如,在预训练一个仅解码器的 LLM 时,通常使用自回归的交叉熵损失和 AdamW 优化器【31,Llama 2: Open foundation and fine-tuned chat models,2023;9,The llama 3 herd of models,2024】。尽管具体的训练设置可能不同,但训练的核心瓶颈在于计算密集型的前向传播和反向传播,它们分别用于计算损失和梯度。在这两个部分中,绝大部分浮点运算(FLOPs)发生在全连接层——对于一个 30B 参数的模型,超过90%的 FLOPs 来自于全连接层【5,Transformer flops,2023】。
全连接层的计算。对于一个输入维度为 n、输出维度为 m 的全连接层,其前向传播计算为 $y = xW^T + b$,其中 $W \in R^{m \times n}$ 是参数矩阵, $b \in R^m$ 是可选的偏置项。为了通过全连接层进行反向传播,需要计算损失 L 关于 x、W 和 b 的梯度。这些梯度由以下公式给出:$\frac{dL}{dx} = \frac{dL}{dy}W$,$\frac{dL}{dW} = (\frac{dL}{dy})^T x$,以及 $\frac{dL}{db} = \mathbf{1}^T \frac{dL}{dy}$,其中 $\mathbf{1}$ 是全一向量,$\frac{dL}{dy}$ 是链式法则中前一个(反向传播方向)操作的输出【14,https://cs231n.stanford.edu/handouts/linear-backprop.pdf,2017】。每个全连接层需要三次计算密集型的矩阵乘法($xW^T$、$\frac{dL}{dx}$ 和 $\frac{dL}{dW}$),其中两次发生在反向传播过程中。
2.3 混合精度训练
混合精度训练的原理与挑战。加速训练的一种方法是采用“混合精度”(MP)训练。在 MP 训练中,参数以高精度保存,而 GEMM 的操作数则被转换为低精度数据类型以进行低精度 GEMM。由于量化通常开销很小,MP 是利用低精度数据类型吞吐量优势的一种简单方法。端到端来看,BF16 MP 训练通常比 FP32 训练快 70%以上【28,Megatron-lm: Training multi-billion parameter language models using model parallelism,2020】。然而,量化会在 GEMM 操作数中引入失真,从而影响输出。由于前向和反向传播都在低精度下进行,损失和模型更新都可能偏离其“真实”值。在远低于 16 位的低比特率下,失真会降低模型质量,甚至导致发散,因此需要更高级的训练策略。例如,FP8 MP 策略通常在前向传播中使用 E4M3(更高精度),在反向传播中使用 E5M2(更大范围),以适应梯度、权重和激活值的不同特性【23,Fp8-lm: Training fp8 large language models,2023;10,Transformer engine 1.11.0】。
4位精度训练的相关工作。在 4 位精度下,量化失真变得更难处理。Xi 等人【36,Training transformers with 4-bit integers,2023】通过在前向传播中使用非随机化的哈达玛变换,并在反向传播中利用杠杆分数采样(LSS)来训练较小的非 GPT Transformer 模型,其 GEMM 精度为 INT4。由于 LSS 引入了额外开销,他们最终只实现了比 FP16 快 30% 的端到端加速,这与 FP8 混合精度训练的水平相当【23,Fp8-lm: Training fp8 large language models,2023】。我们还注意到 Wang 等人【34,Optimizing large language model training using fp4 quantization,2025】的一项同期工作,他们使用 FP4 训练 LLMs。在该工作中,作者通过使用可微分梯度估计器,并将异常值保持在高精度,从而在十亿参数级别的 GPT 模型的前向和反向传播中都使用了 FP4,最终的困惑度差距大于 0.5。由于他们的工作是在我们论文评审之后发布的,我们将完整的比较留待未来工作。
2.4 随机舍入
随机舍入的无偏性。混合精度训练的每一步都需要将高精度张量量化为低精度张量,这为量化方式提供了灵活性。标准的“就近舍入”(NR)方法是将每个高精度数舍入到低精度数据类型中最接近的可表示值【12,Ieee standard for floating-point arithmetic,2019】。然而,NR 是有偏的,我们稍后将证明这对于低精度训练是有害的。实现无偏舍入的一种方法是“随机舍入”(SR),它将一个数随机地舍入到低精度数据类型的一个可表示值,使得其期望值等于原始数值【6,Stochastic rounding: implementation, error analysis and applications,2022】。
SR 的实现与应用。SR 可以通过“抖动”(dithering)高效实现,即在输入数值上添加均匀分布的随机噪声,然后执行 NR【6,Stochastic rounding: implementation, error analysis and applications,2022】。例如,亚马逊的 Trainium 系列芯片可以在为 BF16 GEMM 增加不到 2% 开销的情况下执行带抖动的 SR。公式 1 描述了带抖动的 SR 对于均匀整数分级器的实现;非均匀情况需要修改噪声尺度,但本质上是相同的。
$SR(x, \Delta) = \text{NearestRound}(x + u)$
其中 $u \sim \text{Uniform}(-\frac{\Delta}{2}, \frac{\Delta}{2})$,$\Delta$ 是量化步长。
SR 也可以应用于任何需要量化数字的场景。例如,在训练末期,模型更新的范数远小于参数的范数,低精度更新中的信息可能会“丢失”【37,Collage: Light-weight low-precision strategy for llm training,2024】。在这种情况下,可以使用随机舍入来在期望上保留更新信息,这比保存高精度参数副本更节省内存。
2.5 微缩放(MX)浮点格式
MX 格式的原理。最近引入的微缩放(microscaling)浮点数据类型系列是在 IEEE 浮点数的基础上,为一组浮点数添加了一个共享的尺度因子【24,Ocp microscaling formats (mx) v1.0 spec-final,2023】。这个尺度因子使得 MXFP 张量能够表示更广范围的值,而不会显著增加总比特率,但其前提是组内的数值大小应大致相同,这样尺度因子才能发挥作用。在实践中,随着基础数据类型比特率的降低,MX 缩放变得更为重要。例如,FP8 E4M3 的动态范围为 $448 \times 2^{-9} \approx 2.3 \times 10^6$,而 FP4 的动态范围仅为 $6 \times 2^1 = 12$。MX 缩放使得 MXFP4 能够跨块表示范围更广的值。
MX 格式的实现。核心的硬件支持的 MXFP 格式通常遵循相似的模式。尺度因子在内存中的连续条目(通常是32个)之间共享,将一个标量张量量化为 MX 张量的过程取决于每个组中的最大元素【24,Ocp microscaling formats (mx) v1.0 spec-final,2023;20,Ptx isa 8.7】。算法1描述了将标量张量量化为 MX 张量的“参考”算法,该算法可以在现代AI加速器上高效实现【30,CUTLASS, January 2023】。算法1基于每个组的最大幅值元素进行缩放,然后执行就近舍入得到 MX 张量。
Algorithm 1 OCP MX Quantization
Input: Scalar matrix V, MX group size g
Output: MX matrix
1: for each group Vi ∈ V do
2: m ← maxi(|Vi|)
3: shared_exp ← ⌊log2(m)⌋ − emaxelem
4: Vi ← Vi/2^shared_exp
5: Vi ← NearestRoundToBaseDatatype(Vi)
6: end for
7: return MX matrix from (V, shared_exp) pairs
A2 方法细节
本文的剩余部分描述了我们提出的方法,该方法能够在使用 MXFP4 加速的 GEMM 进行训练时达到近乎无损的效果。尽管本文重点关注 MXFP4,但我们的分析同样适用于其他低精度数据类型,如 MXINT4。我们选择 MXFP4 是因为它在最新的加速器上具有相关性和硬件支持。据我们所知,MXFP4 此前仅成功用于近乎无损的推理【26,Microscaling data formats for deep learning,2023;22,Nvidia blackwell platform sets new llm inference records in mlperf inference v4.1,2024】。虽然某些工作实现了使用 MXFP4 权重的近乎无损训练,但这些工作要求激活值和梯度保持在更高精度,导致其运行速度受限于高精度操作数,比纯 FP4 方案慢。
本文方法的核心。我们的方法关键在于在反向传播中使用纯 MXFP4 GEMM 获得无偏且低方差的梯度估计,从而实现更精确的模型更新。由于反向传播占据了超过一半的训练 FLOPs,我们的方法可以在不降低模型表示能力(即不使用低精度前向传播)的情况下显著加速训练【15,Scaling laws for precision,2025】。为此,我们首先修改了 OCP MX 量化算法,通过缩放和随机舍入实现无偏量化。然后,我们证明,在量化前通过一种内存受限的随机哈达玛变换(RHT)结构对 GEMM 操作数进行变换,可以约束 GEMM 输出的方差。我们的方法在增加极小开销的同时,显著提升了训练模型的质量,使得 MXFP4 在训练中变得实用。
3.1 对 MXFP4 的无偏量化
标准 MX 量化算法的分析。算法 1 描述了将标量矩阵转换为 MX 矩阵的“参考”MX 量化算法。对于每组 32 个条目,算法 1 找到最大幅值 $m = \max_i(|V_i|)$。然后,它根据 $m$ 和基础数据格式中正规数的最大指数 $e_{\text{maxelem}}$ 计算一个共享指数。例如,对于 FP4,其最大正规值为 $6 = 2^2 \times 1.5$,因此 $e_{\text{maxelem}} = 2$。
Algorithm 2 Unbiased MXFP4 Quantization
Input: Scalar matrix V
Output: MX matrix
1: for each group Vi ∈ V do
2: m ← maxi(|Vi|)
3: shared_exp ← ⌊log2(m)⌋ − 2
4: Vi ← (Vi/2^shared_exp) * 3/4
5: Vi ← StochasticRoundToFP4(Vi)
6: end for
7: return MX matrix from (V, shared_exp) pairs
最后,组内元素被共享指数归一化,并舍入到基础数据类型。
标准 MXFP4 量化的有偏性。对于 MXFP4,算法 1 的第 3 行返回 $shared\_exp \leftarrow \lfloor\log_2(m)\rfloor - 2$。观察到,在将整个组除以 $2^{shared\_exp}$ 后,$m$ 变为 $m' = m / 2^{\lfloor\log_2(m)\rfloor - 2}$。由于 $4 \leq m' < 8$,而 FP4 中可表示的最大正规值为 6,因此被缩放到 6 到 8 之间的值将被裁剪,这使得算法 1 具有内在的偏见。虽然被裁剪的比例取决于输入矩阵,但我们通过经验验证,对于广泛分布的矩阵,大约有 3% 的条目会被裁剪。
实现无偏量化的修改。我们可以通过两个简单的修改使算法 1 变得无偏,这两个修改都可以在硬件中高效实现。首先,我们将 $V_i / X$ 乘以 3/4 以防止裁剪。然后,我们使用随机舍入将结果 $Q'$ 量化为 FP4,这给出了 $Q'$ 的一个无偏估计。算法 2 总结了这些修改。最终得到的 MX 矩阵是原始矩阵乘以 3/4 的一个无偏估计。由于 SR 在硬件中是通过独立的均匀抖动实现的,最终的 GEMM 输出是正确输出乘以 $(3/4)^2 = 9/16$ 的一个无偏估计。为了得到无偏的输出,我们只需将高精度累加器输出乘以 16/9。
无偏性的理论保证。引理 3.1 指出,假设随机舍入是通过具有独立噪声的抖动实现的,那么算法 2 产生的 MXFP4 矩阵是其输入乘以 3/4 的一个无偏估计。此外,以算法 2 为子程序的算法 3 能够产生对 $\frac{dL}{dx}$ 和 $\frac{dL}{dW}$ 的无偏估计。
3.2 使用随机哈达玛变换约束 SR 的方差
低精度训练中的异常值问题。一个全连接层($y = xW^T$)的反向传播需要计算 $\frac{dL}{dx} = \frac{dL}{dy}W$ 和 $\frac{dL}{dW} = (\frac{dL}{dy})^T x$。已知 LLM 存在激活值(x)和权重(W)的“异常值”,以及稀疏的梯度($\frac{dL}{dy}$)【36,Training transformers with 4-bit integers,2023;32,QuIP#: Even better LLM quantization with hadamard incoherence and lattice codebooks,2024】。回顾一下,MXFP4 量化依赖于组级别的统计数据,例如最大幅值元素,因此含有异常值的块会遭受高量化失真和高随机舍入方差。
高方差对模型质量的影响。虽然引理 3.1 告诉我们算法 2 产生的是对真实 GEMM 的无偏估计,但高方差的估计仍然会通过向梯度估计中添加噪声来降低模型质量。为了解决这个问题,我们在量化前使用随机哈达玛变换(RHT)来集中梯度、激活值和权重,这能渐进地减少 GEMM 输出的方差。
随机哈达玛变换的定义。随机哈达玛变换执行 $x \leftarrow HSx$,其中 $x \in R^{j \times k}$,$S \in \{\pm 1\}^k$ 是一个随机符号向量,H 是 k 维的哈达玛矩阵【11,Finding structure with randomness: Probabilistic algorithms for constructing approximate matrix decompositions,2011】。哈达玛矩阵是递归定义的正交矩阵,满足以下性质:
$H_n = \frac{1}{2^{n/2}}\begin{pmatrix} H_{n-1} & H_{n-1} \\ H_{n-1} & -H_{n-1} \end{pmatrix}$
其中 $H_1 = [1]$。
RHT 的可逆性。由于 H 和 diag(S) 都是正交的,RHT 是完全可逆的。这意味着我们可以将 RHT 应用于 GEMM 的操作数而无需进行逆变换,即 $(HSA)^T(HSB) = A^T B$。
RHT 对方差的理论约束。定理 3.2 指出,设 A 和 B 是两个大小为 b 的向量,Q 执行算法 2。那么,$Q(A)^T Q(B)$ 的方差为 $O(b\Delta^4 \|A\|_\infty \|B\|_\infty)$,而 $Q(HSA)^T Q(HSB)$ 的方差以至少 $(1-\epsilon)^2$ 的概率为 $O(\Delta^4 \|A\| \|B\| \log(2b/\epsilon))$,其中 $\Delta$ 是 Q 的量化器中两个连续可表示点之间的最大间距。
RHT 的作用机制。定理 3.2 表明,MX 矩阵乘法关于随机舍入的方差与操作数中最大幅值元素的乘积成线性关系。对一个向量应用 RHT 能有效地将其集中,使其具有亚高斯尾分布。根据 Tseng 等人的研究【32,QuIP#: Even better LLM quantization with hadamard incoherence and lattice codebooks,2024】,我们知道 $\Pr(|(xHS)_i| \geq t) \leq 2\exp(-t^2 / (2\|x\|^2))$,这使我们能够约束定理 3.2 中 SR GEMM 的方差。具体来说,应用 RHT 将方差的依赖性从与块大小的线性关系减少到对数关系,尽管代价是依赖于输入的 $L_2$ 范数而不是 $L_\infty$ 范数。
RHT 效果的经验验证。我们可以通过测量有无 RHT 的 SR GEMM 的方差来凭经验验证这一点。图 2 显示了在包含比例为 p 的异常值情况下,$Q(A)^T Q(B)$ 与 $Q(HSA)^T Q(HSB)$ 的平均方差。正如定理 3.2 所预期的,使用 RHT 的方差随 b 的增长速度远慢于不使用 RHT 的情况。
RHT 的实际开销与挑战。然而,RHT 并非没有代价。首先,在计算 $dW \approx Q(HS\frac{dL}{dy})^T Q(HSx)$ 时,RHT 会沿批次维度“混合”数据。在数据并行设置(如 FSDP【38,Pytorch fsdp: Experiences on scaling fully sharded data parallel,2023】或 ZeRO-3【25,Zero: Memory optimizations toward training trillion parameter models,2020】)中,激活值在 GPU 间分片,完整的 RHT 需要昂贵的跨 GPU 通信,这将立即成为梯度计算的瓶颈。其次,尽管存在 $O(n \log n)$ 时间复杂度的快速算法,但 RHT 步骤是在高精度下进行的。减少此开销至关重要——如果 RHT 比 FP4 矩阵乘法慢,那么直接使用 FP8 会是更好的选择。
分块 RHT 作为解决方案。为解决这些问题,我们将 RHT 作为对少数 MX 块的密集矩阵乘法来应用,这使其在 GEMM 操作数上成为内存受限的操作(见表 5)。具体来说,设 RHT 块大小为 g,且 32 是 g 的因子。将这种分块 RHT 作为密集矩阵乘法应用,其运行时间为 $O((b+m)ng)$,I/O 成本为 $O(bn+nm+bm)$。由于现代 AI 加速器具有高计算内存比,当 $g \lessapprox 256$ 时,这种“分块” RHT 是内存受限的。算法 3 总结了我们如何在全连接层的反向传播中使用 RHT。由于 g 小于任何合理大模型的序列长度,算法 3 可以作为全连接层的直接替代品,即使在数据并行设置中也适用。此外,虽然算法中的第 3-6 行为了清晰而分开写出,但高效的实现可以将其融合到第 7 和第 8 行中,以减少昂贵的内存访问。
分块 RHT 的权衡与扩展。进行这种分块 RHT 的权衡在于,其集中程度依赖于 g——g 越大,集中效果越好。然而在实践中,我们观察到 g = 64 已足以获得紧密的分布,并且 MX 格式可以处理块间的尺度差异。最后,需要注意的是,这种构造也允许我们使用任何随机正交变换。我们选择 RHT 是因为它易于随机化(通过采样一个 g 维符号向量)且具有良好的集中特性,但其他矩阵也可以奏效。
A4 实验环境
- 数据集:
- 预训练:使用公开的 GPT2 Wikipedia 数据集【19,Gpt wikiedia dataset】。
- 微调:使用公开的 Tulu V2 数据集(6.57亿 tokens)【13,Camels in a changing climate: Enhancing lm adaptation with tulu 2,2023】进行模型微调测试。
- 模型架构:
- GPT 模型,参数规模分别为 345M、1.3B 和 6.7B。
- 超参数:使用 AdamW 优化器,学习率、权重衰减等参数依据模型规模调整(详见附录表格)。所有模型均使用 FlashAttention【8,Flashattention: Fast and memory-efficient exact attention with io-awareness,2022】。
- 硬件配置:
- 训练平台:AWS P4 和 G6e EC2 实例。
- 性能评测:NVIDIA A100 GPU 用于 Llama 2 70B 层的吞吐量测试;NVIDIA H100 GPU 用于 Triton 内核的 RHT 开销基准测试;Amazon Trainium 1 芯片(EC2 Trn1 实例)用于测量随机舍入的硬件开销。
- 软件配置:
- 代码库:基于 Megatron-LM【28,Megatron-lm: Training multi-billion parameter language models using model parallelism,2020】 (commit a4ad305d...) 和 Microsoft microxcaling 库【18,Mx pytorch emulation library,2024】 (commit 7bc41952...) 进行开发。
- 框架与依赖:使用 PyTorch【29,Pytorch】,NVIDIA Pytorch + Ubuntu 24.04 docker 镜像,其中包含 Transformer Engine 1.5【21,Transformer engine】用于 FP8 实验。使用了
torch.compile
、CUTLASS INT4/INT8 GEMM 内核。
A4 实验结果
实验主要集中于预训练 345M、1.3B 和 6.7B 参数的 GPT 模型,训练 token 数至少为 200 亿,以评估不同低精度策略的长期训练性能。
4.1 GPT 预训练结果
核心发现:
* RHT 和 SR 的必要性:
* 在短期训练(20-40B tokens)中,单独使用 RHT 或 SR 均能实现接近无损的训练效果(见表 2)。
* 在长期训练(1.3B 模型上训练 210B tokens)中,SR 的无偏性至关重要。仅使用 RHT 的模型最终与 BF16 存在约 0.1 的困惑度差距,而结合了 SR 的方法(MXFP4+RHT+SR)则无差距(见图 6)。
* 收敛动态:
* 在所有模型规模上,MXFP4+RHT+SR 的验证困惑度曲线都与 BF16 基线紧密贴合(见图 3, 4, 5)。
* 仅使用 SR(MXFP4+SR)的方法虽然最终性能与结合 RHT 的方法相当,但其初始收敛速度较慢。这可能是因为没有 RHT 的情况下,梯度信息会因随机下溢为零而丢失。
* 纯 MXFP4 的性能:不使用 RHT 和 SR 的纯 MXFP4 混合精度训练,在所有模型规模上都表现出显著的性能下降和巨大的困惑度差距,其实用性远低于 FP8。
* RHT 块大小的影响:消融实验表明,增加 RHT 的块大小(g)可以提升模型质量,验证了 RHT 在降低 SR 方差方面的作用(见表 4)。
下游任务与微调:
* 在 6.7B 模型上进行的零样本下游任务评估显示,使用 MXFP4+RHT+SR 训练的模型与 BF16 训练的模型性能相当。
* 使用 Tulu V2 数据集对这两个模型进行微调后,它们的性能依然相近(最终训练困惑度分别为 1.98 和 1.96),表明用本文方法训练的模型质量与 BF16 模型相当(见表 3)。
与 FP8 前向传播的兼容性:
* 实验证明,本文的 MXFP4 反向传播方法与 FP8(E4M3)前向传播兼容。在 1.3B 和 6.7B 模型上,FP8 前向 + MXFP4 反向的组合基本能够匹配纯 BF16 训练的性能,这为进一步提升训练速度提供了可能(见图 7)。
4.2 开销计算
由于没有可用的 FP4 硬件,本节通过代理基准测试来估算 RHT 和 SR 的开销。
- RHT 开销:
- 在 H100 GPU 上,通过 Triton 内核为 FP8 GEMM 添加 RHT,对 7B 和 70B 规模的矩阵分别引入了 9.7% 和 1.6% 的开销。若 MXFP4 吞吐量是 FP8 的两倍,开销将翻倍至 19.4% 和 3.2%,但仍比 FP8 更快。
- 在 A100 GPU 上,以 INT4 作为 MXFP4 的代理,对 Llama 2 70B 解码器层进行测试。结果显示(见表 5),使用 INT4+RHT 的反向传播比 FP16 快约 70%,比 INT8 快约 30%。RHT 本身带来的端到端开销小于 5%。
- SR 开销:
- 在具有专用 SR 硬件的 Amazon Trainium 1 芯片上测试,将 FP32 量化到 BF16 的 SR 操作开销低于 BF16 GEMM 本身的 2%。假设从 BF16 到 FP4 的 GEMM 吞吐量增加 4 倍,SR 带来的开销将低于 10%。
A5 结论
尽管硬件对低精度数据类型的支持不断进步,但在不牺牲模型质量的情况下使用这些数据类型进行训练变得越来越困难。本文展示了首个能够实现与 FP32/BF16 混合精度训练相媲美的近乎无损模型质量的 MXFP4 训练方法。该方法的核心在于为解码器的全连接层计算低方差、无偏的梯度估计,从而实现更精确的模型更新。为此,我们提出了使用随机舍入(SR)和随机哈达玛变换(RHT)。SR 产生无偏的梯度估计,而 RHT 则降低了 SR 的方差,并减少了因下溢而丢失梯度信息的风险。我们在高达 6.7B 参数的 GPT 模型上的预训练实验表明,RHT 和 SR 对于实现近乎无损的 MXFP4 训练至关重要。最后,我们的基准测试表明,该方法可以以极小的开销实现,在反向传播中相比 FP8 和 BF16 分别带来约 30% 和 70% 的加速。
A6 附录
6.1 FP8 前向传播结果
实验背景。本节包含使用混合精度 FP8 进行前向传播、MXFP4 进行反向传播的实验。已有研究表明,混合精度的 FP8 前向和反向传播可以实现接近无损于混合精度 BF16 的训练效果【23,Fp8-lm: Training fp8 large language models,2023;21,Transformer engine】。为了测试 MXFP4 反向传播在 FP8 前向传播下的实用性,我们使用 NVIDIA 的 TransformerEngine (TE) FP8 (E4M3) 实现【21,Transformer engine】进行前向传播,并结合我们的 MXFP4 公式进行反向传播,训练了 GPT 1.3B 和 6.7B 模型。我们没有测试 GPT 345M,因为在我们的实验中,TE FP8 与 BF16 已有超过 0.1 的验证困惑度差距。对于 GPT 6.7B,由于我们无法使用支持张量并行训练所需的高速互联的 FP8 硬件,我们通过将 FP8 GEMM 操作数反量化为 BF16 再执行 BF16 GEMM 来模拟 FP8 矩阵乘法。虽然这与真正的 FP8 GEMM 并非比特级精确,但对于随机高斯输入,输出的相对误差约为 0.3%,且这与 PyTorch 模拟 FP8 GEMM 的方式基本一致【29,Pytorch】。在两种模型规模下,我们发现 FP8 前向传播和 MXFP4 反向传播足以基本匹配 BF16 的训练效果。
6.2 仅使用随机舍入的 GPT 345M 验证曲线
实验描述。此图与正文中的图类似,但增加了一个仅使用随机舍入(无 RHT)的实验。与 1.3B 模型类似,SR 的初始表现比 RHT 变体“差”,但在训练运行结束时能够匹配它们的性能。
6.3 GPT 1.3B 训练曲线
6.4 GPT 1.3B 在 2000 亿 Tokens 上的训练困惑度
实验描述。本节包含正文第 4 节中引用的完整的 2100 亿 token GPT 1.3B 运行结果。仅使用 MXFP4+RHT(困惑度 10.02)与 BF16(困惑度 9.92)之间存在约 0.1 的验证困惑度差距,而 MXFP4+RHT+SR(困惑度 9.90)则与 BF16 匹配。这表明随机舍入对于近乎无损的全尺寸 FP4 训练至关重要。
6.5 GPT 6.7B 训练曲线
7. 实验设置细节
所有实验均在 AWS P4 和 G6e EC2 实例上运行。我们的代码基于 Github 提交号为 a4ad305d4b117217141730b9b18af52dda069450
的 MegatronLM 代码库和提交号为 7bc41952de394f5cc5e782baf132e7c7542eb4e4
的 Microsoft microxcaling 代码库。我们使用了 NVIDIA Pytorch + Ubuntu 24.04 docker 镜像,其中包含 Transformer Engine 1.5 用于 FP8 实验。所有模型均使用 AdamW 优化器、FlashAttention【7,FlashAttention-2: Faster attention with better parallelism and work partitioning,2024】和以下超参数进行训练:
超参数 | GPT 345M | GPT 1.3B | GPT 6.7B |
---|---|---|---|
隐藏层大小 | 1024 | 2048 | 4096 |
头数 | 16 | 32 | 32 |
层数 | 24 | 24 | 32 |
序列长度 | 2048 | 2048 | 4096 |
全局批大小 | 1024 | 1024 | 2048 |
学习率 | 3.0e-4 | 2.0e-4 | 1.6e-4 |
LR 衰减 | cosine | cosine | cosine |
最小 LR | 3.0e-5 | 2.0e-5 | 1.6e-5 |
权重衰减 | 0.1 | 0.1 | 0.1 |
梯度裁剪 | 1.0 | 1.0 | 1.0 |
预热 tokens | 3.75e8 | 3.75e8 | 3.75e8 |
8. 引理 3.1 的证明
证明过程。首先,我们证明算法 2 产生输入向量 v 的 3/4 的无偏 MXFP4 估计。设 $v \in R^g$,其中 g 是 MX 组大小。对 FP4 进行随机舍入的输入为 $w = \frac{3}{4} v/X$,其中 $X = 2^{\lfloor\log_2(\text{argmax}(|v|))\rfloor-2}$。设 $m = \text{argmax}(|v|)$。观察到 w 的最大幅值元素为:
$|w_m| = \frac{3}{4} |v_m| / 2^{\lfloor\log_2(|v_m|)\rfloor-2} \in [3, 6)$
根据定义,只要 x 在“范围内”(即不溢出 FP4 可表示值的范围),对 FP4 进行随机舍入会产生 x 的无偏估计。由于 FP4 中的最大正规值为 6,对 w 进行随机舍入将得到 $\frac{3}{4} v/X$ 的无偏 FP4 估计。最后,根据期望的线性性质,$X \times \text{StochasticRoundToFP4}(w)$ 得到 v 的 $\frac{3}{4}$ 的无偏估计,符合要求。
接下来,我们证明算法 3 产生 $\frac{dL}{dx}$ 和 $\frac{dL}{dW}$ 的无偏估计。设 $C = \text{MXFP4\_GEMM}(A, B^T)$,其中 $A \in R^{b \times n}$,$B \in R^{m \times n}$,且 g 是 n 的因子。我们有:
由于随机舍入是通过独立噪声实现的,$A^{FP4}$ 和 $B^{FP4}$ 是独立的随机变量。因此:
对于 $\frac{dL}{dx}$,我们有 $A = \frac{dL}{dy} \text{diag}(S)H$ 和 $B = W^T \text{diag}(S)H$。此时,$E[\text{MXFP4\_GEMM}(A, B^T)] = \frac{9}{16} \frac{dL}{dy} \text{diag}(S)HH^T \text{diag}(S)W = \frac{9}{16} \frac{dL}{dy} W$。对于 $\frac{dL}{dW}$,我们有 $A = (\frac{dL}{dy})^T \text{diag}(S)H$ 和 $B = x^T \text{diag}(S)H$。此时,$E[\text{MXFP4\_GEMM}(A, B^T)] = \frac{9}{16} (\frac{dL}{dy})^T \text{diag}(S)HH^T \text{diag}(S)x = \frac{9}{16} (\frac{dL}{dy})^T x$。最后,在第 10 和 11 行将两个值乘以 16/9,得到所需的无偏梯度估计器。
9. 使用 RHT 约束 SR 方差的界
定理 3.2 证明。略。
💬 评论讨论
欢迎在这里分享您的想法和见解!