FP8 Training Recipes, Performance and Convergence
FP8 Training Recipes, Performance and Convergence
Xin Yao, DevTech | AI Open Day/May 30, 2025
目录
FP8 混合精度训练概述
为什么使用 FP8?
-
加速计算密集型操作
- FP8 Tensor Core 的速度是 16 位 Tensor Core 的 2 倍。
-
减少内存占用
- FP8 显著减少了权重和激活值的内存消耗。
-
加速通信
- 在 FP8 中进行非归约通信是无开销的,例如 AllGather, AllToAll。
-
缩小训练与推理之间的差距
- 训练好的模型已经是量化过的。
-
许多工作可以被复用
- 例如,DeepGEMM, DeepEP。
-
来自业界的成功案例
- DeepSeek-V3/R1
- Nemotron-H
- Llama4 family
FP8 格式
-
开放计算项目 8 位浮点规范 (Open Compute Project 8-bit Floating Point Specification)*
- 由 AMD, Arm, Google, Intel, Meta, 和 NVIDIA 联合撰写。
- 1 位符号位:最高有效位。
- e 位有偏指数。
- m 位尾数。
-
E4M3
- 4 位指数,3 位尾数。
- 更精确。
torch.float8_e4m3fn
-
E5M2
- 5 位指数,2 位尾数。
- 动态范围更广。
torch.float8_e5m2
BF16 混合精度训练
下图展示了标准的 BF16 混合精度训练流程,其中权重更新在 FP32 中进行以保持精度。
- 前向传播 (Forward): 输入 (Input)、权重 (Weight) 和输出 (Output) 均为 BF16 格式。前向传播计算 (Fprop) 使用 BF16 GEMM。
- 反向传播 (Backward): 输入梯度 (Input Gradient) 和输出梯度 (Output Gradient) 为 BF16 格式。权重梯度 (Weight Gradient) 为 FP32 格式。梯度计算 (Dgrad, Wgrad) 使用 BF16 GEMM。
- 优化器 (Optimizer): 优化器状态 (Optimizer States) 和主权重 (Master Weights) 保持为 FP32 格式,更新后的权重再转换回 BF16。
BF16 混合精度 + FP8
在 BF16 混合精度训练的基础上引入 FP8,以进一步加速计算密集型的 GEMM 操作。
- 前向传播 (Forward): 输入从 BF16 转换到 FP8,权重从 BF16 转换到 FP8。Fprop 计算使用 FP8 GEMM。输出仍为 BF16。
- 反向传播 (Backward): 输入梯度从 BF16 转换到 FP8,输出梯度从 BF16 转换到 FP8。Dgrad 和 Wgrad 计算使用 FP8 GEMM。
- 优化器 (Optimizer): 流程与 BF16 混合精度训练类似,主权重保持 FP32。
FP8 训练方案 (Training Recipes)
缩放 (Scaling)
-
缩放粒度 (Scaling granularity)
- 按张量 (Per-tensor)
- 按分块 (Per-tile),例如块(block)、组(group)、子通道(sub-channel)
-
缩放方法 (Scaling method)
- 将张量/分块的绝对值最大值缩放到 FP8 可表示的最大值。
- 通过
s * x的操作将张量的值范围“移动”到与 E4M3 可表示范围有更好重叠的区域。 - 量化函数 (Quantize Function):
scale = FP8_MAX / amax(window)
Xq = static_cast<fp8_type>(x * scale) - 反量化函数 (Dequantize Function):
X' = Xq / scale
FP8 方案概览 (FP8 Recipes Overview)
什么是 FP8 方案?
-
使用哪种 FP8 格式?
- E4M3
- 混合 (Hybrid):输入和权重使用 E4M3,梯度使用 E5M2。
-
使用何种缩放粒度?
- 按张量 (Per-tensor)
- 按块/组 (Per-block/group)
- 块大小 (Block size)
- 1D 或 2D
-
模型的哪一部分使用 FP8?
- 线性层 (Linear layers),除了嵌入层 (embedding) 和语言模型头 (LM head)。
- AllGather/AllToAll 通信。
- 第一层和最后一层使用 BF16。
- FP8 注意力机制 (FP8 attention)。
按张量缩放 (Per-Tensor Scaling)
延迟缩放 (Delayed) 与 当前缩放 (Current)
- FP8 格式: 混合 (Hybrid) - 输入和权重为 E4M3,梯度为 E5M2。
- 缩放粒度: 按张量 (Per-tensor)。
-
模型的 FP8 部分:
- 线性层,除了嵌入和语言模型头。
- AllGather/AllToAll 通信。
- 第一层和最后一层在 BF16 中。
-
当前/动态/在线 (Current/Dynamic/Online/Live) 缩放:
- 从当前张量计算
amax值。
- 从当前张量计算
-
延迟 (Delayed) 缩放:
- 使用来自历史缓冲区 (history buffer) 的
amax值(近似值)。 - 不推荐用于大规模训练。
- 使用来自历史缓冲区 (history buffer) 的
-
TE (v2.2 及更高版本)
with fp8_autocast(fp8_recipe=Float8CurrentScaling()):
model()
- MCore (v0.12 及更高版本)
--fp8-format hybrid
--fp8-recipe tensorwise
按张量缩放:Hopper 架构
- Hopper 仅支持 TN 布局的 FP8 GEMM。
-
公式:
Y[m, n] = W[n, k] @ X[m, k]dX[m, k] = W^T[k, n] @ dY[m, n]dW[n, k] = X^T[k, m] @ dY^T[n, m]
-
流程:
- 在前向和反向传播中,输入、权重和梯度在转换为 FP8 格式之前需要进行转换和转置 (Cast and Transpose) 操作以满足 TN 布局要求。
按张量缩放:Blackwell 架构
- Blackwell 支持所有布局的 FP8 GEMM。
- 权重张量每个参数节省 1 字节。
-
公式:
Y[m, n] = W[n, k] @ X[m, k]dX[m, k] = W^T[k, n] @ dY[m, n]dW[n, k] = X^T[k, m] @ dY^T[n, m]
-
流程:
- 流程更简化,因为不再需要特定的转置操作来适应硬件。输入、权重和梯度可以直接转换为 FP8 格式。
分块缩放 (Blockwise Scaling)
Groupwise/Sub-channel
-
缩放粒度: 按分块/组/块 (Per-tile/group/block)。
- 激活/梯度使用 1x128 的 1D 分块。
- 权重使用 128x128 的 2D 分块。
-
FP8 格式: E4M3。
-
模型的 FP8 部分:
- 线性层,除了嵌入和语言模型头。
- AllGather/AllToAll 通信。
-
TE (v2.3 及更高版本)
with fp8_autocast(fp8_recipe=Float8BlockScaling()):
model()
- MCore (v0.13 及更高版本)
--fp8-format e4m3 --fp8-recipe blockwise
分块缩放:Hopper 架构
- Hopper 仅支持 TN 布局的 FP8 GEMM。
- 为 1x128 和 128x128 提供了量化核函数 (Quantization kernels)。
- 支持 2D1D 和 1D1D 的 FP8 GEMM (CUDA 12.9)。
-
公式:
Y[m, n][BF16] = W[n, k][128x128] @ X[m, k][1x128]dX[m, k][BF16] = W^T[k, n][128x128] @ dY[m, n][1x128]dW[n, k][FP32] = X^T[k, m][1x128] @ dY^T[n, m][1x128]
-
流程:
- 同样需要转换和转置操作。
- 反向传播中的梯度计算 (Dgrad, Wgrad) 分别使用 2D1D 和 1D1D 的 FP8 GEMM。
MXFP8 (仅限 Blackwell 架构)
MXFP8 是 OCP (Open Compute Project) 的微缩放格式 (Microscaling Formats) 之一,专为 Blackwell 架构设计。
- FP8 格式: E4M3
- 缩放粒度 (Scaling granularity): Per-tile (1x32),即每个 1x32 的数据块共享一个缩放因子。
-
模型中采用 FP8 的部分:
- 线性层,不包括 embedding 和 LM head 层。
- AllGather/AllToAll 通信(目前暂不支持)。
-
Transformer Engine (TE) 集成 (v2.0 及更高版本):
通过fp8_autocast(fp8_recipe=MXFP8BlockScaling())与模型进行集成。 - MCore 集成 (v0.12 及更高版本):
--fp8-format e4m3--fp8-recipe mxfp8
下图展示了 MXFP8 的数据结构,其中 k 个标量元素共享一个缩放因子 X。每个元素 P_i 有 d 个比特。表格总结了 MXFP8 的格式细节。
MXFP8 训练流程
在仅支持 Blackwell 架构的 MXFP8 训练中,激活值和权重在前向和后向传播中以不同方式进行量化。这要求数据同时支持行优先(rowwise)和列优先(colwise)的格式。
- Y[m, n][BF16] = W[n, k][1x32] @ X[m, k][1x32]
- dX[m, k][BF16] = W^T[k, n][1x32] @ dY[m, n][1x32]
- dW[n, k][FP32] = X^T[k, m][1x32] @ dY^T[n, m][1x32]
下图详细描述了 MXFP8 在前向传播、后向传播和优化器更新过程中的数据流:
FP8 存储:FP8 激活与 FP8 主权重
为了优化内存使用,可以调整 FP8 的存储策略。
- FP8 激活: 在后向传播中,仅保存列优先(colwise)的 FP8 输入。图中虚线框表示的数据不会被保存。
- 权重转换流程: 主权重(Master weight)以 FP32 格式存储,在计算前转换为 BF16,再进一步转换为 FP8 权重。
一个核心问题是:我们能否只保留 FP8 权重,以进一步节省内存?
下图展示了此存储策略下的数据流,其中虚线框部分不被存储。
FP8 存储:FP8 主权重所面临的挑战
仅存储 FP8 主权重会带来一些挑战,尤其是在与 MCore 的分布式优化器(如 ZeRO-1)和分布式检查点(checkpoint)协同工作时。
以下是跨 DP Rank(数据并行等级)同步权重的流程:
- 步骤 1: 优化器(optimizer.step())更新各 DP Rank 上的 FP32 主权重(master weights)。
- 步骤 2: 从更新后的主权重中计算局部 amax 值(absolute maximum,绝对值最大值)。
- 步骤 3: 通过 AllReduce max 操作,在所有 DP Rank 间同步,得到全局 amax 值。
下图详细描绘了在 DP=2,Rank 0 和 Rank 1 之间的这个同步过程。
FP8 存储:FP8 主权重更新流程
在获取全局 amax 值之后,流程继续进行:
- 步骤 4: 将全局 amax 值和 FP32 主权重一起转换为 FP8 格式的模型权重(model weights)。
- 步骤 5: 在所有 DP Rank 间使用 AllGather 操作分发 FP8 格式的模型权重。
此流程需要为每种 FP8 配方(recipe)进行单独的实现。在 MCore (v0.12 及更高版本) 中,可以通过 --fp8-param-gather 参数来启用此功能。
下图展示了从全局 amax 到最终 FP8 模型权重的转换和分发过程。
FP8 存储:内存占用概览
下表总结了不同训练精度和存储策略下的内存占用情况,此分析不包括激活值所占用的内存,这部分内存在 MoE 模型中尤为重要。
- * 我们不考虑 BF16 优化器状态,因为它与 BF16 或 FP8 训练是正交的。
- ** 从 TE 2.1 开始,
store_param_remainders默认开启,FP32 主权重可以与模型权重共享前 16 个比特。因此,主权重只需要额外的 16 比特。
从表中可以看出,使用 FP8 主权重(FP8 Primary Weights)可以显著减少内存占用,尤其是在 ZeRO-1 优化下(对比 BF16 的 6+10/d,FP8 Primary Weights 在 Blackwell 上可以达到 5+12/d)。
FP8 通信:非归约(Non-reduction)通信
-
DP (数据并行):
- 前述的 FP8 主权重策略使得为数据并行(DP)启用 FP8 AllGather 成为可能。
-
TP (张量并行):
- CPU 等待: TE 的 Userbuffer TP-GEMM 重叠,PyTorch 的 Async-TP。
- GPU 等待: TE 的原子 GEMM,ByteDance 的 FLUX/COMET/Triton-Distributed。
- AG+GEMM 和 GEMM+RS 均支持 FP8。
下图对比了 split_overlap_rs 和 atomic_gemm_overlap_rs (atomic) 两种方式。atomic_gemm 通过将多个小 GEMM 操作合并为一个原子操作,并使用计数器进行同步,减少了计算流(compute stream)和通信流(comm stream)之间的等待事件,从而改善了计算和通信的重叠效率。
- EP (专家并行):
- 在 FP8 中分发 AlltoAll 通信。
- 利用 DeepEP。
下图展示了 FP8 A2A(All-to-All)的数据流:
Quantize into FP8 → FP8 A2A → FP8 Permute → [GroupedGEMM, BF16 output]
- 在 MCore 中实现:
Quantize into FP8 → FP8 A2A → FP8 Permute - 在 TE 中实现:
GroupedGEMM, BF16 output(FP8 Permute Kernels,GroupedLinear 的 FP8 输入)
FP8 训练性能
FP8 训练加速:Llama 3 系列/Nemotron/Mixtral
下图展示了在不同硬件平台(GB200, B200, H100)上,使用 FP8-CS (常规缩放) 和 FP8-MX (微缩放) 相对于 BF16 的训练加速比。
- 在 Llama3、Nemotron 和 Mixtral 等模型上,FP8 均展现出显著的加速效果。
- MXFP8 的性能优化工作仍在进行中。
FP8 训练加速:DeepSeek-V3
- 在 H100 平台上使用 DeepEP 进行 FP8 训练,实现了非常高的 TFLOPS。
- DeepSeek-V3 的设计旨在充分利用硬件性能。
- 在如下所示的并行映射配置下,使用 BF16 会导致内存溢出(OOM),因此没有直接的 BF16 性能对比。这表明 FP8 能够支持更优的并行策略。
下表展示了 DeepSeek-V3 训练的详细配置和性能数据:
FP8 训练加速:展望
未来的性能优化方向包括:
- 内存优化: 将张量并行度(TP)从 2 降至 1,以节省内存。
- 减少 CPU 开销:
- Kernel 融合: 将多个小算子融合成一个大算子,减少启动开销。
- 减少 PyTorch 算子: 使用更少的底层 PyTorch 操作。
下图的性能剖析图显示,在路由(router)层有“过多的 kernels”(Too many kernels),导致了显著的 TP2 开销和 TE 开销,凸显了 Kernel 融合的必要性。
FP8 训练收敛性
验收标准
- 损失曲线/损失差异:
- 一些客户遵循 DeepSeek 的标准,采用低于 0.25% 的相对误差作为接受标准。
下图比较了 BF16 和 FP8 训练的损失曲线。结果使用系数为 0.9 的指数移动平均(EMA)进行了平滑处理。从图中可以看出,FP8 的损失曲线与 BF16 高度吻合,表明其具有良好的收敛性。
图表来源:DeepSeek-V3 技术报告。
- 下游任务评估
- Nemotron-H(一个使用 per-tensor current scaling 训练的 56B 混合 mamba-transformer 模型)
下图展示了在 MMLU 和平均常识理解任务上,FP8 与 BF16 的训练收敛情况对比。结果显示,两种精度格式下的模型准确率曲线非常接近。
-
我们的经验
- 使用 1T tokens 的 8B 密集模型。
- 损失和下游任务都应匹配。
-
没有统计标准可以表明收敛。
下图展示了在 MMLU 和平均常识理解任务上,BF16 与 FP8 的进一步比较,图中还区分了 "Per-tensor current scaling" 和 "Blockwise scaling" 两种缩放方式。实验结果表明,FP8 和 BF16 的性能表现依然非常接近。