FP8 Training Recipes, Performance and Convergence

Xin Yao, DevTech | AI Open Day/May 30, 2025

目录

  1. FP8 混合精度训练概述
  2. FP8 训练方案 (Training Recipes
  3. FP8 训练性能
  4. FP8 训练收敛性
  5. 致谢

FP8 混合精度训练概述

为什么使用 FP8?

  • 加速计算密集型操作

    • FP8 Tensor Core 的速度是 16 位 Tensor Core 的 2 倍。
  • 减少内存占用

    • FP8 显著减少了权重和激活值的内存消耗。
  • 加速通信

    • 在 FP8 中进行非归约通信是无开销的,例如 AllGather, AllToAll。
  • 缩小训练与推理之间的差距

    • 训练好的模型已经是量化过的。
  • 许多工作可以被复用

    • 例如,DeepGEMM, DeepEP。
  • 来自业界的成功案例

    • DeepSeek-V3/R1
    • Nemotron-H
    • Llama4 family

FP8 格式

Page 5
Page 5
  • 开放计算项目 8 位浮点规范 (Open Compute Project 8-bit Floating Point Specification)*

    • 由 AMD, Arm, Google, Intel, Meta, 和 NVIDIA 联合撰写。
    • 1 位符号位:最高有效位。
    • e 位有偏指数。
    • m 位尾数。
  • E4M3

    • 4 位指数,3 位尾数。
    • 更精确。
    • torch.float8_e4m3fn
  • E5M2

    • 5 位指数,2 位尾数。
    • 动态范围更广。
    • torch.float8_e5m2

参考链接:https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1

BF16 混合精度训练

下图展示了标准的 BF16 混合精度训练流程,其中权重更新在 FP32 中进行以保持精度。

Page 6
Page 6
  • 前向传播 (Forward): 输入 (Input)、权重 (Weight) 和输出 (Output) 均为 BF16 格式。前向传播计算 (Fprop) 使用 BF16 GEMM。
  • 反向传播 (Backward): 输入梯度 (Input Gradient) 和输出梯度 (Output Gradient) 为 BF16 格式。权重梯度 (Weight Gradient) 为 FP32 格式。梯度计算 (Dgrad, Wgrad) 使用 BF16 GEMM。
  • 优化器 (Optimizer): 优化器状态 (Optimizer States) 和主权重 (Master Weights) 保持为 FP32 格式,更新后的权重再转换回 BF16。

BF16 混合精度 + FP8

在 BF16 混合精度训练的基础上引入 FP8,以进一步加速计算密集型的 GEMM 操作。

Page 7
Page 7
  • 前向传播 (Forward): 输入从 BF16 转换到 FP8,权重从 BF16 转换到 FP8。Fprop 计算使用 FP8 GEMM。输出仍为 BF16。
  • 反向传播 (Backward): 输入梯度从 BF16 转换到 FP8,输出梯度从 BF16 转换到 FP8。Dgrad 和 Wgrad 计算使用 FP8 GEMM。
  • 优化器 (Optimizer): 流程与 BF16 混合精度训练类似,主权重保持 FP32。

FP8 训练方案 (Training Recipes)

缩放 (Scaling)

Page 9
Page 9
  • 缩放粒度 (Scaling granularity)

    • 按张量 (Per-tensor)
    • 按分块 (Per-tile),例如块(block)、组(group)、子通道(sub-channel)
  • 缩放方法 (Scaling method)

    • 将张量/分块的绝对值最大值缩放到 FP8 可表示的最大值。
    • 通过 s * x 的操作将张量的值范围“移动”到与 E4M3 可表示范围有更好重叠的区域。
    • 量化函数 (Quantize Function):
      scale = FP8_MAX / amax(window)
      Xq = static_cast<fp8_type>(x * scale)
    • 反量化函数 (Dequantize Function):
      X' = Xq / scale

FP8 方案概览 (FP8 Recipes Overview)

什么是 FP8 方案?

Figure from Insights into DeepSeek-V3: Scaling Challenges and Reflections on Hardware for AI Architectures. https://arxiv.org/pdf/2505.09343v1
Figure from Insights into DeepSeek-V3: Scaling Challenges and Reflections on Hardware for AI Architectures. https://arxiv.org/pdf/2505.09343v1
  • 使用哪种 FP8 格式?

    • E4M3
    • 混合 (Hybrid):输入和权重使用 E4M3,梯度使用 E5M2。
  • 使用何种缩放粒度?

    • 按张量 (Per-tensor)
    • 按块/组 (Per-block/group)
      • 块大小 (Block size)
      • 1D 或 2D
  • 模型的哪一部分使用 FP8?

    • 线性层 (Linear layers),除了嵌入层 (embedding) 和语言模型头 (LM head)。
    • AllGather/AllToAll 通信。
    • 第一层和最后一层使用 BF16
    • FP8 注意力机制 (FP8 attention)

按张量缩放 (Per-Tensor Scaling)

延迟缩放 (Delayed) 与 当前缩放 (Current)

Page 11
Page 11
  • FP8 格式: 混合 (Hybrid) - 输入和权重为 E4M3,梯度为 E5M2。
  • 缩放粒度: 按张量 (Per-tensor)。
  • 模型的 FP8 部分:

    • 线性层,除了嵌入和语言模型头。
    • AllGather/AllToAll 通信。
    • 第一层和最后一层在 BF16 中。
  • 当前/动态/在线 (Current/Dynamic/Online/Live) 缩放:

    • 从当前张量计算 amax 值。
  • 延迟 (Delayed) 缩放:

    • 使用来自历史缓冲区 (history buffer) 的 amax 值(近似值)。
    • 不推荐用于大规模训练。
  • TE (v2.2 及更高版本)

with fp8_autocast(fp8_recipe=Float8CurrentScaling()):
    model()
  • MCore (v0.12 及更高版本)
--fp8-format hybrid
--fp8-recipe tensorwise

按张量缩放:Hopper 架构

  • Hopper 仅支持 TN 布局的 FP8 GEMM。
Page 12
Page 12
  • 公式:

    • Y[m, n] = W[n, k] @ X[m, k]
    • dX[m, k] = W^T[k, n] @ dY[m, n]
    • dW[n, k] = X^T[k, m] @ dY^T[n, m]
  • 流程:

    • 在前向和反向传播中,输入、权重和梯度在转换为 FP8 格式之前需要进行转换和转置 (Cast and Transpose) 操作以满足 TN 布局要求。

按张量缩放:Blackwell 架构

  • Blackwell 支持所有布局的 FP8 GEMM。
  • 权重张量每个参数节省 1 字节。
Page 13
Page 13
  • 公式:

    • Y[m, n] = W[n, k] @ X[m, k]
    • dX[m, k] = W^T[k, n] @ dY[m, n]
    • dW[n, k] = X^T[k, m] @ dY^T[n, m]
  • 流程:

    • 流程更简化,因为不再需要特定的转置操作来适应硬件。输入、权重和梯度可以直接转换为 FP8 格式。

分块缩放 (Blockwise Scaling)

Groupwise/Sub-channel

Page 14
Page 14
  • 缩放粒度: 按分块/组/块 (Per-tile/group/block)。

    • 激活/梯度使用 1x128 的 1D 分块。
    • 权重使用 128x128 的 2D 分块。
  • FP8 格式: E4M3。

  • 模型的 FP8 部分:

    • 线性层,除了嵌入和语言模型头。
    • AllGather/AllToAll 通信。
  • TE (v2.3 及更高版本)

with fp8_autocast(fp8_recipe=Float8BlockScaling()):
    model()
  • MCore (v0.13 及更高版本)

--fp8-format e4m3 --fp8-recipe blockwise

分块缩放:Hopper 架构

  • Hopper 仅支持 TN 布局的 FP8 GEMM。
  • 为 1x128 和 128x128 提供了量化核函数 (Quantization kernels)。
  • 支持 2D1D 和 1D1D 的 FP8 GEMM (CUDA 12.9)。
Page 15
Page 15
  • 公式:

    • Y[m, n][BF16] = W[n, k][128x128] @ X[m, k][1x128]
    • dX[m, k][BF16] = W^T[k, n][128x128] @ dY[m, n][1x128]
    • dW[n, k][FP32] = X^T[k, m][1x128] @ dY^T[n, m][1x128]
  • 流程:

    • 同样需要转换和转置操作。
    • 反向传播中的梯度计算 (Dgrad, Wgrad) 分别使用 2D1D 和 1D1D 的 FP8 GEMM。

MXFP8 (仅限 Blackwell 架构)

MXFP8 是 OCP (Open Compute Project) 的微缩放格式 (Microscaling Formats) 之一,专为 Blackwell 架构设计。

  • FP8 格式: E4M3
  • 缩放粒度 (Scaling granularity): Per-tile (1x32),即每个 1x32 的数据块共享一个缩放因子。
  • 模型中采用 FP8 的部分:

    • 线性层,不包括 embedding 和 LM head 层。
    • AllGather/AllToAll 通信(目前暂不支持)。
  • Transformer Engine (TE) 集成 (v2.0 及更高版本):
    通过 fp8_autocast(fp8_recipe=MXFP8BlockScaling()) 与模型进行集成。

  • MCore 集成 (v0.12 及更高版本):
    • --fp8-format e4m3
    • --fp8-recipe mxfp8

下图展示了 MXFP8 的数据结构,其中 k 个标量元素共享一个缩放因子 X。每个元素 P_id 个比特。表格总结了 MXFP8 的格式细节。

Page 16
Page 16

MXFP8 训练流程

在仅支持 Blackwell 架构的 MXFP8 训练中,激活值和权重在前向和后向传播中以不同方式进行量化。这要求数据同时支持行优先(rowwise)和列优先(colwise)的格式。

  • Y[m, n][BF16] = W[n, k][1x32] @ X[m, k][1x32]
  • dX[m, k][BF16] = W^T[k, n][1x32] @ dY[m, n][1x32]
  • dW[n, k][FP32] = X^T[k, m][1x32] @ dY^T[n, m][1x32]

下图详细描述了 MXFP8 在前向传播、后向传播和优化器更新过程中的数据流:

Page 17
Page 17

FP8 存储:FP8 激活与 FP8 主权重

为了优化内存使用,可以调整 FP8 的存储策略。

  • FP8 激活: 在后向传播中,仅保存列优先(colwise)的 FP8 输入。图中虚线框表示的数据不会被保存。
  • 权重转换流程: 主权重(Master weight)以 FP32 格式存储,在计算前转换为 BF16,再进一步转换为 FP8 权重。

一个核心问题是:我们能否只保留 FP8 权重,以进一步节省内存?

下图展示了此存储策略下的数据流,其中虚线框部分不被存储。

Page 18
Page 18

FP8 存储:FP8 主权重所面临的挑战

仅存储 FP8 主权重会带来一些挑战,尤其是在与 MCore 的分布式优化器(如 ZeRO-1)和分布式检查点(checkpoint)协同工作时。

以下是跨 DP Rank(数据并行等级)同步权重的流程:

  1. 步骤 1: 优化器(optimizer.step())更新各 DP Rank 上的 FP32 主权重(master weights)。
  2. 步骤 2: 从更新后的主权重中计算局部 amax 值(absolute maximum,绝对值最大值)。
  3. 步骤 3: 通过 AllReduce max 操作,在所有 DP Rank 间同步,得到全局 amax 值。

下图详细描绘了在 DP=2,Rank 0 和 Rank 1 之间的这个同步过程。

Page 19
Page 19

FP8 存储:FP8 主权重更新流程

在获取全局 amax 值之后,流程继续进行:

  1. 步骤 4: 将全局 amax 值和 FP32 主权重一起转换为 FP8 格式的模型权重(model weights)。
  2. 步骤 5: 在所有 DP Rank 间使用 AllGather 操作分发 FP8 格式的模型权重。

此流程需要为每种 FP8 配方(recipe)进行单独的实现。在 MCore (v0.12 及更高版本) 中,可以通过 --fp8-param-gather 参数来启用此功能。

下图展示了从全局 amax 到最终 FP8 模型权重的转换和分发过程。

Page 20
Page 20

FP8 存储:内存占用概览

下表总结了不同训练精度和存储策略下的内存占用情况,此分析不包括激活值所占用的内存,这部分内存在 MoE 模型中尤为重要。

Page 21
Page 21
  • * 我们不考虑 BF16 优化器状态,因为它与 BF16 或 FP8 训练是正交的。
  • ** 从 TE 2.1 开始,store_param_remainders 默认开启,FP32 主权重可以与模型权重共享前 16 个比特。因此,主权重只需要额外的 16 比特。

从表中可以看出,使用 FP8 主权重(FP8 Primary Weights)可以显著减少内存占用,尤其是在 ZeRO-1 优化下(对比 BF16 的 6+10/d,FP8 Primary Weights 在 Blackwell 上可以达到 5+12/d)。

FP8 通信:非归约(Non-reduction)通信

  • DP (数据并行):

    • 前述的 FP8 主权重策略使得为数据并行(DP)启用 FP8 AllGather 成为可能。
  • TP (张量并行):

    • CPU 等待: TE 的 Userbuffer TP-GEMM 重叠,PyTorch 的 Async-TP。
    • GPU 等待: TE 的原子 GEMM,ByteDance 的 FLUX/COMET/Triton-Distributed。
    • AG+GEMM 和 GEMM+RS 均支持 FP8。

下图对比了 split_overlap_rsatomic_gemm_overlap_rs (atomic) 两种方式。atomic_gemm 通过将多个小 GEMM 操作合并为一个原子操作,并使用计数器进行同步,减少了计算流(compute stream)和通信流(comm stream)之间的等待事件,从而改善了计算和通信的重叠效率。

Page 23
Page 23
  • EP (专家并行):
    • 在 FP8 中分发 AlltoAll 通信。
    • 利用 DeepEP。

下图展示了 FP8 A2A(All-to-All)的数据流:
Quantize into FP8 → FP8 A2A → FP8 Permute → [GroupedGEMM, BF16 output]

  • 在 MCore 中实现:Quantize into FP8 → FP8 A2A → FP8 Permute
  • 在 TE 中实现:GroupedGEMM, BF16 output(FP8 Permute Kernels,GroupedLinear 的 FP8 输入)
Page 24
Page 24

FP8 训练性能

FP8 训练加速:Llama 3 系列/Nemotron/Mixtral

下图展示了在不同硬件平台(GB200, B200, H100)上,使用 FP8-CS (常规缩放) 和 FP8-MX (微缩放) 相对于 BF16 的训练加速比。

Page 26
Page 26
  • 在 Llama3、Nemotron 和 Mixtral 等模型上,FP8 均展现出显著的加速效果。
  • MXFP8 的性能优化工作仍在进行中。

FP8 训练加速:DeepSeek-V3

  • 在 H100 平台上使用 DeepEP 进行 FP8 训练,实现了非常高的 TFLOPS。
  • DeepSeek-V3 的设计旨在充分利用硬件性能。
  • 在如下所示的并行映射配置下,使用 BF16 会导致内存溢出(OOM),因此没有直接的 BF16 性能对比。这表明 FP8 能够支持更优的并行策略。

下表展示了 DeepSeek-V3 训练的详细配置和性能数据:

Page 27
Page 27

FP8 训练加速:展望

未来的性能优化方向包括:

  • 内存优化: 将张量并行度(TP)从 2 降至 1,以节省内存。
  • 减少 CPU 开销:
    • Kernel 融合: 将多个小算子融合成一个大算子,减少启动开销。
    • 减少 PyTorch 算子: 使用更少的底层 PyTorch 操作。

下图的性能剖析图显示,在路由(router)层有“过多的 kernels”(Too many kernels),导致了显著的 TP2 开销和 TE 开销,凸显了 Kernel 融合的必要性。

Page 28
Page 28

FP8 训练收敛性

验收标准

  • 损失曲线/损失差异:
    • 一些客户遵循 DeepSeek 的标准,采用低于 0.25% 的相对误差作为接受标准。

下图比较了 BF16 和 FP8 训练的损失曲线。结果使用系数为 0.9 的指数移动平均(EMA)进行了平滑处理。从图中可以看出,FP8 的损失曲线与 BF16 高度吻合,表明其具有良好的收敛性。

Figure 10 | Loss curves comparison between BF16 and FP8 training. Results are smoothed by Exponential Moving Average (EMA) with a coefficient of 0.9. (Page 30)
Figure 10 | Loss curves comparison between BF16 and FP8 training. Results are smoothed by Exponential Moving Average (EMA) with a coefficient of 0.9. (Page 30)

图表来源:DeepSeek-V3 技术报告。

  • 下游任务评估
    • Nemotron-H(一个使用 per-tensor current scaling 训练的 56B 混合 mamba-transformer 模型)

下图展示了在 MMLU 和平均常识理解任务上,FP8 与 BF16 的训练收敛情况对比。结果显示,两种精度格式下的模型准确率曲线非常接近。

Page 31 - FP8 与 BF16 在 MMLU 和平均常识理解任务上的收敛曲线对比
Page 31 - FP8 与 BF16 在 MMLU 和平均常识理解任务上的收敛曲线对比
  • 我们的经验

    • 使用 1T tokens 的 8B 密集模型。
    • 损失和下游任务都应匹配。
  • 没有统计标准可以表明收敛。

下图展示了在 MMLU 和平均常识理解任务上,BF16 与 FP8 的进一步比较,图中还区分了 "Per-tensor current scaling" 和 "Blockwise scaling" 两种缩放方式。实验结果表明,FP8 和 BF16 的性能表现依然非常接近。

Page 32 - BF16 与 FP8 在不同缩放方式下的收敛曲线对比
Page 32 - BF16 与 FP8 在不同缩放方式下的收敛曲线对比

致谢

Page 33 - 致谢页面
Page 33 - 致谢页面