DeepGEMM 2.0 Technical Overview

Ray Wang (王辉) NVIDIA GPU加速计算专家团队 高级工程师

目录

  • DeepGEMM 2.0 技术概述
    • 背景 (Background)
    • 输入规范和用法 (Input Specification and Usage)
    • Blackwell 实现 (Blackwell Implementation)
    • 性能基准 (Performance Benchmarks)
    • 未来路线图 (Future Roadmap)

DeepGEMM 2.0 技术概述

背景 (Background)

缩放 FP8 GEMM (Scaled FP8 GEMM)

在大型语言模型 (LLM) 中,一个标准的 FP8 GEMM 运算如下图所示,输入矩阵 (FP8) 与权重矩阵 (FP8) 相乘,得到输出矩阵 (BF16)。

Page 5 - 标准FP8 GEMM示意图
Page 5 - 标准FP8 GEMM示意图

FP8 矩阵是从 BF16 矩阵转换而来的。由于 BF16 矩阵中的值可能超出 FP8 的可表示范围,因此需要一个缩放因子 (scaling factor) 将 BF16 的值压缩到 FP8 的范围内。这也有助于充分利用 FP8 的全部数值容量。该过程可参考 Transformer Engine。

DeepSeek V3 之前的 FP8 实现方案

在 DeepSeek V3 之前使用的 FP8 实现方案如下:
Page 6 - 旧版FP8 GEMM方案
将经过缩放的 FP8 输入矩阵与经过缩放的 FP8 权重矩阵相乘后,再乘以两个缩放因子 Scale A 和 Scale B,以恢复原始值范围,最终得到 BF16 输出。

  • 优点: 速度快。在 A×B 之后进行单次缩放步骤,使得运算非常迅速。
  • 缺点: 每个张量只有一个缩放因子,这意味着异常值 (outliers) 会影响准确性。许多客户尝试在 LLM 训练中使用此方法,但失败了。

DeepSeek 的 FP8 实现方案:1D-2D 细粒度缩放

DeepSeek 采用了 1D-2D 细粒度缩放的 FP8 实现方案,如下图所示:
Page 7 - DeepSeek的FP8 GEMM方案

该方案采用块缩放 (block scaling)
* 输入矩阵:一个 token 中的 128 个元素共享一个缩放因子(1D scale)。
* 权重矩阵:一个 128x128 的元素块 (tile) 共享一个缩放因子(2D scale)。

  • 优点: 精度高,使得 DeepSeek V3 的训练成为可能。
  • 缺点: 速度慢。沿 K 维度的多个缩放因子会减慢累加过程。
  • 展望: 在 Blackwell 架构上,UTCMMA 支持块缩放的矩阵乘累加(MMA),因此不会再有速度慢的问题。

什么是 MoE GEMM

MoE (Mixture of Experts) GEMM 内核执行多个矩阵 A 与多个形状相同的矩阵 B 的矩阵乘法,并将它们融合成一个单一的 GEMM 内核。
* 每次单独的乘法被称为一个 "problem"。
* MoE GEMM 是分组 GEMM (grouped GEMM) 的一种特殊情况,因为每个矩阵 B (Expert Weight) 都具有相同的形状。

Page 8 - MoE GEMM 示意图
Page 8 - MoE GEMM 示意图

什么是 DeepGEMM

DeepGEMM 是一个开源、高性能的 GEMM 库,专为 DeepSeek 模型的训练和推理而定制。

  • 主要功能:

    • FP8 块缩放 GEMM (FP8 block scale GEMM)
    • FP8 块缩放 MoE GEMM (FP8 block scale MoE GEMM)
  • 亮点:

    • 提供简单且开箱即用的 PyTorch 接口。
    • 单个 CUDA 内核同时支持标准 GEMM 和 MoE GEMM。
    • 性能与 cuBLAS 和 CUTLASS 相当。
  • 峰值性能 (在 DeepSeek V3 中最大的 GEMM,M×N×K = 4096x7168x16384):

    • Hopper GPU: 1,556 TFLOPS (对比 cuBLAS 12.9 的 1,550 TFLOPS)
    • Blackwell GPU: 3,200 TFLOPS (对比 cuBLAS 12.9 的 3,070 TFLOPS)
    • cuBLAS 和 CUTLASS 的基准测试使用的是非块缩放的 FP8 GEMM 实现,但使用了相同的脚本。

注:仅供技术讨论和参考,性能可能因不同的产品组合而异。

输入规范和用法 (Input Specification and Usage)

DeepGEMM 的输入和用法是为满足 DeepSeek 的实际需求而量身定制的。主要有三个 API:

  1. fp8_gemm_nt

    • 用例: DeepSeek 模型中 MLA (一种注意力变体) 的上/下投影 (Up-/down-projection)。
    • 功能: 标准的 1×128 和 128×128 FP8 缩放 GEMM。
  2. m_grouped_fp8_gemm_nt_masked

    • 用例: 解码 (Decoding),低延迟推理。
    • 功能: MoE GEMM,其中 M 因任务而异,N 和 K 固定。
    • 要求: 将每个 problem 的 M 填充 (Pad) 到最大可能的 M (每个 expert 的最大 token 数)。
  3. m_grouped_fp8_gemm_nt_contiguous

    • 用例: 预填充 (Prefilling) 和训练。
    • 功能: MoE GEMM,其中 M 因任务而异,N 和 K 固定。
    • 要求: 将每个 problem 的 M 填充 (Pad) 到 128 的倍数。

API 1: fp8_gemm_nt

此 API 执行缩放 GEMM 运算 D = C + A x BT

  • 输入: A [M×K] 及 scales_A, B [N×K] 及 scale_B, 可选的 C [M×N], 输出 D [M×N]。
  • 缩放因子规范:

    • Hopper: 1D2D FP32 缩放因子。
      • A 的每个 1×128 tile 共享一个 FP32 scale, scales_A.shape == [M, K / 128]
      • B 的每个 128×128 tile 共享一个 FP32 scale, scales_B.shape == [N / 128, K / 128]
    • Blackwell: 1D1D E8M0 缩放因子。
      • A 和 B 的每个 1×128 tile 共享一个 E8M0 scale, scales_A.shape == [M, K / 128], scales_B.shape == [N, K / 128]
  • 注意 1: 当使用 1D scale (1×128 scale) 时,必须转置 scale 张量,使其 scale.stride(0) == 1 (即列主序/M主序布局)。

  • 注意 2: 对于 Blackwell 的 E8M0 scale,与 cuBLAS 和 CUTLASS 不同,DeepGEMM 不需要特殊的输入格式,这使其更加用户友好。DeepGEMM 在内核内部转换 scale 的布局,对性能几乎没有影响。
注 1: 缩放因子转置 (Scales Transpose)

为了提高内存访问效率,需要对缩放因子张量进行转置。原始布局沿 K 维度是内存连续的,但在每个计算阶段加载时是非连续的。转置后,布局沿 M 维度连续,使得在一个阶段内读取的所有缩放因子在内存中都是连续的。

Page 13 - 缩放因子转置示意图
Page 13 - 缩放因子转置示意图
注 2: E8M0 (Blackwell) 缩放因子输入

cuBLAS/CUTLASS 要求用户将缩放因子重排为一种特定的平铺布局 (tiled layout),这种布局不直观且元素在行主序下不连续。

Page 14 - E8M0缩放因子内存布局对比
Page 14 - E8M0缩放因子内存布局对比

DeepGEMM 的处理方式:
* 将每四个 E8M0 缩放因子打包成一个 uint32_t 整数。
* 然后将得到的 uint32_t 数组以列主序 (column-major order) 存储。

这简化了用户的使用,无需手动进行复杂的内存重排。更多细节请参阅 https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

API 2: m_grouped_fp8_gemm_nt_masked

  • 缩放因子规范: 与 fp8_gemm_nt 相同。不同 problem 的缩放因子沿 num_problems 轴堆叠。
  • 输入张量 A 的布局:
    • 所有 problem 的缓冲区大小相同,均被填充到最大可能的 M(即最坏情况)。
    • 输入张量 B 的形状为 [num_problems, N, K]
      Page 15 - 输入张量A的布局示意图
  • 优点: 先前的内核不再需要为 GEMM 提供每个 problem 的偏移量,从而减少了辅助内核造成的延迟。
  • 缺点: 存在显著的空间浪费。

API 3: m_grouped_fp8_gemm_nt_contiguous

  • 尺度规范 (Scales specification): 与 fp8_gemm_nt_nt 相同。尺度因子 (Scale factors) 沿着 num_problems 轴堆叠。

  • 输入A张量布局:

  • Token 缓冲区被填充到128的倍数。
  • 输入张量B的形状为 [num_problems, N, K]
  • 经过填充后,来自不同问题 (problem) 的 token 不会在同一个 GEMM tile 中。
输入张量布局图示 (Page 16)
输入张量布局图示 (Page 16)
API 3: m_grouped_fp8_gemm_nt_contiguous (续)
  • 我们只使用一个 TMA (Tensor Memory Accelerator) 张量描述符来处理所有问题。
  • TMA 存储的 token 数量是固定的——如果我们将其固定为128,那么它每条指令必须存储128个 token。
  • 如果缓冲区没有填充到128的倍数,问题之间就会发生数据竞争(data races)。如下图所示,Problem 0 的输出有64个 token,Problem 1的输出紧随其后。当 TMA 尝试为 Problem 0 写入128个 token 时,会覆盖 Problem 1 的部分数据区域,造成写-写冲突(Write-Write Conflict)。
写-写冲突图示 (Page 17)
写-写冲突图示 (Page 17)

Blackwell 实现 (Blackwell Implementation)

Blackwell 实现章节封面 (Page 18)
Blackwell 实现章节封面 (Page 18)

下表比较了 Ada、Hopper 和 Blackwell 架构在数据加载、FP8 MMA (Matrix Multiply-Accumulate) 和缩放 (Scaling) 方面的实现方式。

  • 数据加载 (Data Load): Blackwell 和 Hopper 都使用 cp.async.bulk (TMA),而 Ada 使用 cp.async
  • FP8 MMA: Blackwell 使用 tcgen05.mma (m128nNk32, 2CTA),Hopper 使用 wgmma.mma_async (m64nNk32),而 Ada 使用 mma.m16n8k32
  • 缩放 (Scaling): Blackwell 将缩放操作与 mma 融合 (Fused with mma),而 Hopper 和 Ada 则使用 CUDA Core 并与 MMA 重叠执行。
Ada, Hopper, Blackwell 实现对比 (Page 19)
Ada, Hopper, Blackwell 实现对比 (Page 19)

概述 (Overview)

对于用户提供的形状 (M, N, K),DeepGEMM 执行以下步骤:
1. 自动调整内核配置 (Auto tune kernel config)
2. JIT 内核生成 (JIT kernel generation)
3. 调度 CTA tiles (Schedule CTA tiles)
4. TMA 引擎加载 (TMA engine loads)
5. 张量核心计算 (Tensor cores compute)
6. TMA 引擎存储 (TMA engine stores)

自动调整内核参数 (Auto-tune Kernel Parameter)

该过程位于 csrc/jit_kernels/heuristics/common.hppget_best_config 函数中。
在当前实现中,DeepGEMM 只调整两个参数:
1. BLOCK_N (决定了 UTCMA 的 N 维度大小)
2. 流水线阶段数 (Number of stages)

1. Tile 大小选择 (CTaTile)
  • block_m ∈ {64,128,256} (对于Blackwell,block_m 目前固定为128)
  • block_n ∈ {16,32,48,...,256}
  • block_k = 128 / sizeof(AB element) // 使用FP8时,block_k = 128
2. Wave 数量与利用率
  • total_ctas = ceil(M/block_m) * ceil(N/block_n) * num_problems
  • waves = ceil(total_ctas / num_sms)
  • 选择 block_n 的策略:
  • 最小化 wave 的总数。
  • 在 wave 数量相同的情况下,最大化最后一个 wave 的 SM 利用率。
  • 如果 wave 数量和最后一个 wave 的 SM 利用率都相同,则选择较大的 block_n
3. 流水线阶段 (Pipeline Stages)
  1. 给定 block_n,计算每个 SMEM 的 tile (A tile, B tile, SF, barrier, pointers, CD tile)。
  2. 选择能装入 SMEM 的最大阶段数。

JIT 实现 (JIT Implementation)

典型的 JIT 工作流程如下:
配置 (Config) → 生成 (Generate) → 哈希 (Hash) → 缓存未命中/命中 (Cache Miss/Hit) → 加载与启动 (Load & Launch)
↳ (若未命中) → 编译与保存 (Compile & Save) & 加载 (Load)

关键基类:
  • LaunchRuntime<Derived> (kernel_runtime.hpp) (生成, 启动)
  • Derived::generate_impl(args...) 产生 CUDA C++ 内核源码。
  • Derived::launch_impl(kernel, config, args...) 在当前 CUDA 流上调用内核。

  • KernelRuntime (kernel_runtime.hpp) (加载)

  • 加载 "kernel.cubin",提取 cuobjdump 的单个内核符号。
  • 在 Driver API (< CUDA 12.8) 和 Runtime API (≥ CUDA 12.8) 之间切换以加载模块和内核句柄。

  • Compiler (compiler.hpp) (哈希, 缓存, 编译)

  • build(name, code) 从内核名称、DeepGEMM 库、编译器版本、标志和代码生成哈希签名。
  • 如果缓存未命中,则调用 NVCC 或 NVRTC 后端来编译一个新的 CUBIN 并将其插入缓存中。

Tile 调度器 (Tile Scheduler)

CtaTile 调度器负责向 TMA 引擎提供 block_ablock_b 的 ID。
它必须满足两个要求:
- 支持 MoE GEMM
- 调度器首先遍历 block_a 的 ID。
- 然后通过全局内存中的连续布局索引,查找 block_a 的问题 ID。
- 在选定的问题中,遍历 block_b 的 ID。
- 将 block_ablock_b 的 ID 提供给 TMA 引擎。
- 通过线程块乱序 (thread-block swizzling) 提高 L2 缓存命中率

CtaTile 调度器图示 (Page 23)
CtaTile 调度器图示 (Page 23)
线程块乱序 (Thread Block Swizzle) 以提高 L2 缓存命中率

为简单起见,假设:
- block_ablock_b 大小相同。
- L2 缓存可以容纳8个块 (block_ablock_b)。

存在两种针对矩阵 C 的遍历策略:
1. 行主序 (Row-major order) (从左到右,从上到下):
- L2 缓存中的块 (灰色蒙版) 为 block_a_0block_b_0...block_b_6
- 对于第八次乘法 (block_a_0 × block_b_7),其中一个块必须从 L2 中被驱逐 (block_b_0)。

  1. 将矩阵 C 划分为 4x4 的块组:
    • L2 缓存中的块 (灰色蒙版) 为 block_a_0...block_a_3block_b_0...block_b_3 (共8个块)。
    • 这使得可以进行 4x4 = 16 次连续的 block_a × block_b 乘法而无需驱逐。

第二种方法被称为 线程块乱序 (Thread Block Swizzle)。在 DeepGEMM 中,大型 GEMM 使用 16x16 的块组,而小型 GEMM 使用 8x8 的块组。

线程块乱序策略图示1 (Page 24)
线程块乱序策略图示2 (Page 25)
线程块乱序策略图示3 (Page 26)

流水线安排 (前言、主循环、结语) (Pipeline Arrangement (prologue, mainloop, epilogue))

下图展示了 Blackwell 实现中的流水线结构,包含不同功能的 Warp 组:DMA Warp、SF (Scaling Factor) Transpose Warp、MMA Warp 和 4个 Epilogue Warps。

流水线结构图 (Page 27)
流水线结构图 (Page 27)
流水线各阶段任务:
  • TMA: 使用 TMA 从全局内存加载 tiles 和 scales 到共享内存。
  • SF Transpose: 将 scales 转换为 UTCCP 指令所需的布局。
  • MMA: 使用 UTCIMMA 执行矩阵乘法和累加,将结果写入张量内存。
  • Epilogue: 处理张量内存中的结果,并将它们存储回全局内存。
流水线各阶段任务解释 (Page 28)
流水线各阶段任务解释 (Page 28)
流水线同步机制:

每个任务都由一对屏障 (barriers) 进行同步:
- 一个 起始屏障 (start-barrier) 解锁任务以开始执行。
- 一个 结束屏障 (end-barrier) 发出任务已完成执行的信号。

下表展示了不同任务在执行前后所需的屏障状态。

任务 起始屏障 结束屏障
TMA Empty Bar Full Bar
SF Transpose Full Bar Scaling Factor Full Bar
MMA Scaling Factor Full Bar Empty Bar
Mainloop Tensor Memory Empty Bar Tensor Memory FULL Bar
Epilogue Tensor Memory FULL Bar Tensor Memory Empty Bar
流水线同步屏障机制 (Page 29)
流水线同步屏障机制 (Page 29)
DeepGEMM 中的两种流水线:
  1. 主循环 (Mainloop) 的多阶段流水线:
    流水线的阶段数由 Shape_K / TILE_K 决定。每个阶段包含 TMA, FB, SF Trans, SFB, MMA 步骤。

  2. 2阶段的 Epilogue 流水线:
    Epilogue 过程自身也构成一个流水线,与主循环交错执行。

DeepGEMM 中的两种流水线图示 (Page 30)
DeepGEMM 中的两种流水线图示 (Page 30)

实用性能优化:一个 TN GEMM 案例研究 (Practical Performance Optimization: A TN GEMM Case Study)

形状(Shape): m=4096, n=7168, k=2048。内核(Kernel): 用于 Wgrad 的 TN GEMM。

  • 步骤 1:假设瓶颈:在 TN GEMM 中,K 维度在内存中是不连续的,这会降低效率。
  • 步骤 2:验证假设:移除所有代码,只保留内存读取并测量内存吞吐量;观察到与硬件峰值之间存在巨大差距。
  • 步骤 3:提出优化方案:增加连续内存读取的粒度,即选择 M 和 N 维度的 tile 大小,最好是 128 的倍数。
    • BLOCK_MBLOCK_N 均为 256:共享内存不足,且阶段数(number of stages)非常少。❌
    • BLOCK_M = 256 且 BLOCK_N 较小:BLOCK_N 要么太小 (128,这样配对的 CTA 只能读取 64 字节,降低效率),要么仍然超出共享内存。❌
    • BLOCK_M = 128 且 BLOCK_N = 256:在这种情况下,A 和 B 都是从内存中以 128 字节的粒度读取。✅
  • 步骤 4:验证优化:隔离测试内存读取;使用 BLOCK_M=128BLOCK_N=256,内存读取速度得到提升。
  • 步骤 5:实施优化:为 DeepGEMM 添加 BLOCK_N=256 支持。
    • 需要超过 512 个 Tensor Memory 列(512 列用于存储结果,外加用于缩放因子的附加列),这在 Blackwell 上是不可行的。
    • 从内存中读取并使用 BLOCK_N=256 馈送 MMA,但当 MMA 写入 Tensor Memory 时,仅使用前 240 列。
    • TMEM 布局:| Out 1 | Out 2 | SF A | SF B | == | 240 | 240 | 4 | 8 |。因为 UMMA_N = 256,Out2 可能会在回写过程中被 TMEM 中的 MMA 破坏,所以我们等待 Out2 写回 16 列后才释放 Out1。

优化结果:1450 TFLOPS -> 1710 TFLOPS

Page 31 - TN GEMM 优化案例研究步骤
Page 31 - TN GEMM 优化案例研究步骤

性能基准 (Performance Benchmarks)

Page 32 - 性能基准测试标题页
Page 32 - 性能基准测试标题页

基准测试结果

下表展示了在不同矩阵维度 (m, n, k)下,DeepGEMM、cuBLAS 和 CUTLASS 的性能对比。所有单位均为 TFLOPS。

Page 33 - DeepGEMM、cuBLAS 和 CUTLASS 性能对比表格
Page 33 - DeepGEMM、cuBLAS 和 CUTLASS 性能对比表格
  • DeepGEMM 使用 1x128x128 的块规模 GEMM。
  • cuBLAS 和 CUTLASS 使用非块规模(non-block-scale)的 FP8 GEMM 实现,并使用相同的脚本。
  • cuBLAS 通过 torch._scaled_mm 调用,CUTLASS 通过 vllm.model_executor.layers.quantization.utils.fp8_utils.cutlass_scaled_mm 调用。

*注:仅供技术讨论和参考,性能可能因不同的产品组合而异。

未来路线图 (Future Roadmap)

Page 34 - 未来路线图标题页
Page 34 - 未来路线图标题页

我们计划按以下优先顺序为 DeepGEMM 添加优化和功能:
1. 支持 BLOCK_M = 64,以提升当 M 较小(例如 M = 64 或 128)时的性能。
2. 支持额外的架构(例如 Ada)。
3. 支持 NVFP4。
4. 支持 TMA gather-4 和 scatter-4,以实现各种 MoE 算子的融合。

如果您有兴趣贡献,请随时提交 PR!

Page 35 - DeepGEMM 未来路线图详情
Page 35 - DeepGEMM 未来路线图详情