Best Practice of MLA Kernel Optimization on Blackwell

王泽宇, NVIDIA GPU加速计算专家团队 高级工程师 | November 7, 2025

目录 (Agenda)

  • DSA 核优化 (DSA Kernel Optimization)

    • V3.2 和 DSA 介绍 (V3.2 and DSA Introduction)
    • DSA 核的挑战 (Challenges of DSA Kernel)
    • Blackwell 平台概述 (Blackwell Platform Overview)
    • 初尝:DSA 稀疏预填充 (First Taste: DSA Sparse Prefill)
    • 第二项任务:DSA 稀疏解码 (Second Task: DSA Sparse Decoding)
  • MLA 反向核优化 (MLA Backward Kernel Optimization)

    • 背景 (Background)
    • Blackwell 中的反向注意力机制 (Backward Attention in Blackwell)
    • 针对 MLA 的优化 (Optimization for MLA)
    • 针对 Blackwell 中不平衡数据的优化 (Optimization for Unbalanced Data in Backward)
  • 未来工作 (Future Works)

DSA 核优化 (DSA Kernel Optimization)

DeepSeek V3.2 介绍

DeepSeek V3.2 是一个新的“实验性”模型,旨在提升长上下文效率。其基准测试性能和 API 定价如下:

Page 5, DeepSeek V3.2 在各项基准测试中的表现,以及其API的定价信息。
Page 5, DeepSeek V3.2 在各项基准测试中的表现,以及其API的定价信息。

该模型的一个关键特性是稀疏注意力(Sparse Attention)。

DSA: DeepSeek 稀疏注意力机制 (DeepSeek Sparse Attention)

DSA 的核心思想是在推理过程中仅选择 TopK 个 KV Token 进行注意力计算,以降低延迟。

标准的注意力机制(如 Multi-Head Attention, MHA)需要新的查询(Q)Token 与所有的键值(KV)Token 进行计算。

Page 6, 标准注意力机制示意图,新的Q Token与所有KV Tokens进行计算。
Page 6, 标准注意力机制示意图,新的Q Token与所有KV Tokens进行计算。

DSA 引入了一个 TopK 索引器(TopK Indexer),它会从大量的 KV Token 中筛选出最相关的 TopK 个。

Page 8, DSA示意图,TopK索引器从KV Tokens中选择一部分。
Page 8, DSA示意图,TopK索引器从KV Tokens中选择一部分。

通过这种方式,注意力核(Attention Kernel)只需在选定的 TopK KV Token 子集上进行计算,从而显著减少计算量和延迟。与传统的 MLA(Multi-Layer Attention)相比,DSA 的延迟大幅降低。本次演讲的重点在于对此过程中的注意力核进行优化。

Page 10, DSA机制完整流程图,展示了TopK索引器选择KV Tokens子集,以及与MLA相比的延迟优势,并强调了对注意力核的优化。
Page 10, DSA机制完整流程图,展示了TopK索引器选择KV Tokens子集,以及与MLA相比的延迟优势,并强调了对注意力核的优化。

DSA 核的挑战 (Challenges of DSA Kernel)

DSA 本质是带稀疏性的 MLA

DSA 可以被理解为带有稀疏性的 MLA (Multi-Query Attention)。下图对比了 MHA (Multi-Head Attention) 和 MQA (Multi-Query Attention) 的张量结构。

  • MHA (Prefill 阶段)

    • Head_num(QKV) = 128
    • QK_dim = 192
    • VO_dim = 128
  • MQA (Decoding 阶段)

    • Head_num(Q) = 128
    • QK_dim = 576
    • VO_dim = 512
Page 12, MHA 和 MQA 的张量结构对比图。
Page 12, MHA 和 MQA 的张量结构对比图。

回顾:MLA 注意力核 (基于 Hopper 的 FMHA_V2 优化)

在讨论 DSA 的挑战之前,先回顾一下之前在 Hopper 架构上使用 FMHA_V2 对 MLA 注意力核的优化经验。

  • 挑战:

    • 单个 warpgroup 无法存储所有的累加器(Accumulators)。计算公式为 64 x 512 / 128 = 256 regs/thread
  • 有利条件:

    • K 和 V 张量是相同的。
  • 解决方案:

    • 使用 1 个工作组(WG)处理 TMA (Tensor Memory Accelerator)。
    • 仅加载 K 的部分数据是可行的,因为 V = K[:, :512]
    • 使用 2 个工作组(WG)处理 TensorCore 计算。

下图展示了该优化方案的流水线操作:
Page 13, MLA 注意力核在 Hopper 上的优化流水线示意图。

该优化带来了显著的性能提升,如下方的 TFLOPS 性能图所示,FMHA Opt 的性能远超其他方法。
Page 14, 在 H100 80G HBM3 上的 TFLOPS 性能对比图,批大小为128。

FP8 KV 缓存布局

Page 16
Page 16

上图展示了DeepSeek稀疏注意力(DSA)中FP8 KV(键值)缓存的内存布局。每个令牌(token)占用656字节。

  • 主要部分512 e4m3(512字节)用于存储没有位置编码(NoPE)的KV缓存,其构成是 512xFP8 + 4xFP32。这里的 4xFP32 是缩放因子(Scaling factor)。
  • 附加部分:64字节用于存储旋转位置编码(RoPE),格式为64xbf16
  • 计算精度:尽管KV缓存使用了FP8和BF16格式,但所有的注意力(Attention)核函数计算仍然在BF16精度下进行。

DSA 带来的额外挑战

在 MLA 优化的基础上,DSA 引入了新的挑战:

  • 稀疏性处理 (Sparsity):

    • 张量内存访问(TMA)难以高效处理稀疏数据,这会导致带宽的浪费。
  • 预填充 (Prefill) 阶段的挑战 (在长上下文场景中):

    • 在长上下文场景中,标准的多头注意力(MHA)表现不佳。
    • MHA方法:采用自定义掩码(Custom Mask)。
    • 不同的查询(Q)令牌会关注不同的top-K键值(KV)令牌。
    • 示例:当序列长度(Seq len)为64K,top-K值为2K时,对于每个Q,在每128个KV TILE(块)中仅选择4个令牌。
    • 因此,在预填充阶段,需要一个稀疏的矩阵乘法累加(MLA)核函数。
  • 解码 (Decoding) 阶段的挑战:

    • FP8 的反量化 (Dequantization):将FP8数据转换回更高精度的格式会带来计算开销。
    • 高 CUDA 核心压力:反量化等操作会给CUDA核心带来很大的计算压力。

Page 15, DSA 额外挑战的总结,并附有一个FP8数据块的示意图。
Page 17

Blackwell 平台概述

Page 18
Page 18

本节介绍NVIDIA Blackwell平台的概览。

Blackwell平台的优势与机遇

  • 张量内存(Tensor Memory)

    • 该内存用于替代Hopper架构中的寄存器(regs),以存储matrix_Amatrix_O
    • 容量为256KB。
    • 数据移动
      • 与寄存器(REG)交互:tcgen05.ld/st
      • 与共享内存(SMEM)交互:tcgen05.cp

    Page 19
    上图展示了张量内存的布局和寻址方式,它由128个Lane组成,每个Lane为2KB。

  • 2-CTA Tensorcore GEMM

    • 这是一种使用两个协作线程块(CTA)执行通用矩阵乘法(GEMM)的技术。
    • 优势:可以减少共享内存(SMEM)的带宽需求。
    • 由一个集群(cluster)中的一个线程束(warp)启动。

    Page 20
    上图展示了2-CTA GEMM的概念,其中A和B矩阵被分配到两个不同的CTA(CTA0和CTA1)中进行计算。

  • TMA.Gather4

    • 这是一种稀疏张量内存访问(TMA)功能。
    • 引入了额外的int4 idx输入,用于指定需要收集的数据索引。

    Page 21
    上图(图10)展示了tiled::scatter4/tiled::gather4模式下边界框(bounding box)的一个示例,说明了如何根据不同的起始坐标从全局内存中收集(gather)数据到共享内存。

初尝:DSA 稀疏预填充(Sparse Prefill)

Page 22
Page 22

本节将初步探讨DSA在预填充阶段的实现。

DSA稀疏预填充概述

Page 23
Page 23
  • QKV 配置:

    • Q: [seq_len, 128, 576]
    • K: [seq_len, 576]
    • V: [seq_len, 512] == K[:, 0: 512]
  • 核函数设置 (Kernel Setup):

    • CTA维度 (CTA dim): [2*q_len, 1, 1]
    • 2-CTA集群 (2-CTA Cluster): 采用两个协作线程块(CTA)组成的集群。

      • 每个Q令牌的128个头(heads)被分配到集群中,每个CTA处理64个头。
      • 两个CTA共享相同的topK KV令牌。
    • 图中提出了一个关键问题:“如何在2-SM GEMM中处理V?”

  • 架构图:

    • Q矩阵(128个头)被分割成两个Q Tile(Q CTA0, Q CTA1)。
    • K/V矩阵(1个头)通过TMA.Gather4被加载到两个KV Tile(KV CTA0, KV CTA1)。
    • 整个计算由一个2-CTA集群完成。

DSA稀疏预填充内存布局

Page 24
Page 24
  • 共享内存 (Shared Memory):

    • 同时存储K和V(受益于L2缓存)。
    • K 形状: [64, 576],占用72KB。
    • V 形状: [2*64, 256],占用64KB。
    • Q 形状: [64, 576],占用72KB。
    • S 形状: [64, 128],占用16KB。
    • 总计: 224KB (总限制为228KB)。
    • 最终输出: [64, 512] BF16 (重用内存)。
  • 张量内存 (Tensor Memory):

    • P = QK^T: (64 * 128 FP32 = 32KB)。
    • O = SV: (64 * 512 FP32 = 128KB)。
    • 总计: 160KB。
  • Q Offloading to TMEM: Q矩阵被卸载到张量内存中。

DSA稀疏预填充流水线与基准测试

Page 25
Page 25
  • 流水线 (Pipeline):
    上图展示了TMA、Tensorcore和CUDA Core之间的并行执行流水线。TMA负责加载数据,Tensorcore执行矩阵乘法(QK, SV),CUDA Core执行数学函数(MUFU)和缩放(Scale)。

  • 线程配置:

    • 总共512个线程(4个Warpgroup)。
    • 128个线程用于TMA-K,128个线程用于TMA-V。
    • 128个线程用于softmax和重缩放(re-scale)。
    • 32个线程用于MMA(矩阵乘法累加),32个线程用于KV topK索引。
  • 基准测试 (Benchmark):

    • 性能达到 1400 Tflops
    • 这相当于理论峰值性能(2000 Tflops)的 70%

第二项任务:DSA 稀疏解码(Sparse Decoding)

Page 26
Page 26

本节介绍DSA在解码阶段的实现。

DSA稀疏解码的挑战:FP8布局与反量化

Page 27
Page 27
  • 配置: Q_LEN = 1,如果启用了多令牌并行(MTP),则可能大于1。使用splitKV来平衡SM的工作负载。
  • FP8 KV布局:
  • 反量化 (Dequantization) of FP8:

    • FP8 -> FP16: 64 ops/clk/SM
    • FP16 -> FP32: 64 ops/clk/SM
    • FP32 -> BF16: 16 ops/clk/SM
    • 在Blackwell上,BF16 Scale操作为128 ops/clk/SM。
    • 将64512的FP8数据转换为BF16需要 3328 CLK/CTA*。
  • GEMM:

    • BMM1(QK): 26464*576 = 4.6M ops/CTA
    • BMM2(SV): 26464*512 = 4Mops/CTA
    • Tensorcor性能: 8192 ops/CLK/SM
    • 总计GEMM时钟周期: 1088 CLK/CTA
  • 解决方案: 使用带有分布式共享内存(DSMEM)的2-CTA对 (2-CTA pair with DSMEM)

稀疏解码中的2-CTA反量化

Page 28
Page 28

该方案通过2-CTA集群和分布式共享内存(DSMEM)来解决反量化带来的高CUDA核心压力。
* 工作原理: 每个CTA负责一半的反量化工作,并通过DSMEM进行多播(Multi-cast)。
* 流程: 512字节的e4m3数据被分成两部分,分别加载到两个CTA的本地共享内存(SHM0的FP8-CTA0和SHM1的FP8-CTA1)中。经过反量化后,结果通过DSMEM共享到2-CTA集群的共享内存中,形成BF16-CTA0和BF16-CTA1。
* 优势:
* 两个CTA共享相同的TopK KV令牌。
* 将CUDA核心压力减半,降至 1664 CLK/CTA

DSA稀疏解码内存布局

Page 29
Page 29
  • 共享内存 (Shared Memory) (双缓冲):

    • Q_Part0: (64*256 BF16 = 32KB)
    • K_FP8[2]: (642562 FP8 = 32KB)
    • K_BF16[2]: (645762 BF16 = 144KB)
    • S[2]: (64642 BF16 = 16KB)
    • 总计: 224KB
  • 张量内存 (Tensor Memory):

    • P[2] = QK^T: (64642 FP32 = 32KB)
    • O = SV: (64*512 FP32 = 128KB)
    • Q_Part1: (64*320 BF16 = 40KB)
    • 总计: 200KB

DSA稀疏解码 - 进行中:流水线与当前进展

Page 30
Page 30
  • 理想流水线 (Ideal Pipeline):
    上图展示了稀疏解码阶段的理想化流水线,显示了TMA、Tensorcore和CUDA Core(执行反量化、MUFU、Scale)之间的并行工作流。

  • CTA维度: [2, q_len, sm_parts]

  • 线程配置:

    • 384个线程 (3个Warpgroup)。
    • 128个线程用于TMA。
    • 128个线程用于CUDA Core (反量化 & softmax & re-scale)。
    • 32个线程用于MMA,32个线程用于KV topK索引。
  • 当前进展 (Current Progress):

    • 性能达到 400 Tflops
    • 相比FlashMLA有 10% 的增益。
    • 仍有进一步优化的空间。

假设:FP8->BF16 是一条单一指令

下图展示了一种假设情况的流水线,即如果从 FP8 到 BF16 的转换可以由一条单一指令完成。在这种优化下,原本在 CUDA Core 上执行的多个反量化(Dequantization)、乘法融合(MUFU)和缩放(Scale)操作可以被整合,从而简化执行流程,提高效率。TMA(Tensor Memory Accelerator)负责加载数据,Tensor Core 执行核心的矩阵运算,而 CUDA Core 的负担减轻。

Page 31 展示了将 FP8 到 BF16 转换作为单一指令的潜在流水线优化
Page 31 展示了将 FP8 到 BF16 转换作为单一指令的潜在流水线优化

MLA 反向核优化 (MLA Backward Kernel Optimization)

Page 32 标题页:MLA Backward Kernel Implementation
Page 32 标题页:MLA Backward Kernel Implementation

背景:注意力的前向与反向传播

前向传播

标准注意力机制的前向传播过程如下图所示。它主要由两个通用矩阵乘法(GEMM)操作和一个 Softmax 操作组成:

  1. 查询(Q)与键的转置(Kᵀ)进行矩阵相乘,得到分数矩阵 P。
  2. 对 P 应用 Softmax 函数,得到注意力权重矩阵 S。
  3. S 与值(V)进行矩阵相乘,得到最终输出 O。
Page 33 注意力机制前向传播流程图
Page 33 注意力机制前向传播流程图

反向传播

反向传播过程计算输出 O 对输入 Q、K、V 的梯度(分别为 dO、dQ、dK、dV)。梯度流与前向传播的计算图方向相反。

Page 34 注意力机制反向传播的梯度流图
Page 34 注意力机制反向传播的梯度流图

反向传播的具体计算公式如下:
* $P = Q * K^T$
* $S = Softmax(P) = exp(P - lse)$
* $dV = S^T * dO$
* $dS = dO * V^T$
* $dP = S \circ (dS - sum(O \circ dO))$ (其中 $\circ$ 表示逐元素相乘)
* $dQ = dP * K$
* $dK = dP^T * Q$

Page 35 注意力机制反向传播的计算公式
Page 35 注意力机制反向传播的计算公式

从计算角度看,反向传播过程主要包含 5 个 GEMM 操作和 2 个由 CUDA Core 执行的操作。

Page 36 注意力机制反向传播的计算开销分析
Page 36 注意力机制反向传播的计算开销分析

注意力反向核函数流程

注意力反向传播的计算流程涉及以下几个关键步骤。

Page 37 注意力反向传播计算公式列表
Page 37 注意力反向传播计算公式列表

为了优化计算,一些中间值可以预先计算或在前向传播时计算并保存下来:
* lse (log-sum-exp) 在前向传播时计算。
* sum(O ◦ dO) 可以在反向传播主循环开始前预先计算。

Page 38 反向传播中的预计算与前向计算值
Page 38 反向传播中的预计算与前向计算值

在实现核函数时,循环的顺序是一个关键的设计选择。两种常见的策略是:
1. 外层 KV,内层 QO: 外层循环遍历 KV 的分块(tile),内层循环遍历 Q 的分块。在内层核函数中累加 dQ
2. 外层 QO,内层 KV: 外层循环遍历 Q 的分块,内层循环遍历 KV 的分块。在内层核函数中累加 dKdV

Page 39 注意力反向传播的两种循环策略
Page 39 注意力反向传播的两种循环策略

下图展示了注意力反向传播的数据流。首先计算 sumOdO,然后将其与 dO, Q, K, V 一同输入到主计算模块 Backward Attn 中,得到 dKdV,并累加生成最终的 dQ

Page 40 注意力反向传播的数据流图
Page 40 注意力反向传播的数据流图

Blackwell 架构上的反向注意力

本节将讨论在 Blackwell 架构上实现反向注意力的具体细节。

Page 41 标题页:Backward Attention in Blackwell
Page 41 标题页:Backward Attention in Blackwell

传统注意力反向核函数

传统的注意力反向核函数实现中,内存布局和流水线设计如下:

  • 配置参数:

    • Q TILE Step: 128
    • KV TILE Step: 128
    • Head_dim: 128
  • 资源分配:

    • MMA: 1 Warp
    • TMA Load: 1 Warp
    • CUDA Core: 8 Warps
    • Store dQ: 4 Warps

下图展示了在共享内存(Shared Memory)和张量内存(Tensor Memory)中的数据布局。

Page 42 传统注意力反向核函数的内存布局
Page 42 传统注意力反向核函数的内存布局

其计算流水线大致如下,TMA 负责加载数据,Tensor Core 和 CUDA Core 交替执行计算。注意,此图可能不完全代表真实的流水线。

Page 43 传统注意力反向核函数的计算流水线示意图
Page 43 传统注意力反向核函数的计算流水线示意图

MLA 注意力反向核函数

对于 MLA,实现上存在一些差异和挑战。

  • 差异: QK_head_dim 从 128 增加到 192。
  • 挑战: 当 Q_STEPKV_STEP 仍为 128 时,由于 head_dim 增大,会导致内存溢出。
Page 44 MLA 反向核函数面临的挑战及内存布局
Page 44 MLA 反向核函数面临的挑战及内存布局
  • 解决方案: 为了解决内存溢出问题,将 Q_Tile_STEP 从 128 减小到 64。

这一调整改变了共享内存和张量内存中与 Q 相关的张量(如 Q, dO, S, dQ)的分块大小,从而适应了更大的 head_dim,避免了内存问题。下图展示了调整后的内存布局。

Page 45 针对 MLA 调整 Q_Tile_STEP 后的解决方案及内存布局
Page 45 针对 MLA 调整 Q_Tile_STEP 后的解决方案及内存布局

针对 MLA 的优化

Attention 反向核 (MLA) 基准测试

  • 性能: 达到默认反向核(d_qk=d_vo=128)性能的 85%。

  • 基准测试结果

Page 46
Page 46
  • 实现细节:
    • 这是首个在 Blackwell 平台上针对 MLA 反向传播的开源实现(2025年7月25日)。
    • 已合并到 CUTLASS 示例(PR链接)和 FlashMLA(PR链接)。

针对反向传播中非均衡数据的优化

反向 Attention 中的非均衡数据

当数据分布不均衡时,会出现性能下降问题。例如,一个批次(Batch 0)包含大量数据,而其他批次(Batch 1-7)数据量很小。这种不均衡会导致性能下降至 300 TFLOPS。

ComputeSumOdO 计算中,对于非均衡数据 Data = [10000] + 99*[1],延迟高达 10.3ms,而对于均衡数据 Data = [100]*100,延迟仅为 0.382ms。

Page 48
Page 48

ComputeSumOdO 核布局

原始布局
  • GridDim: [ceil(q_max_len / Q_BLOCK), head_num, batch_size]
  • BlockDim: [8, 16, 1]
  • 逻辑: 原始设计假设一个 CTA(Cooperative Thread Array)内的所有 token 属于同一个批次(基于 blockIdx.x)。所有具有相同 threadIdx.ythreadIdx.x 计算一个 Sum(OdO)
目标布局(优化后)
  • GridDim: [ceil(total_q_len / Q_BLOCK), head_num]
  • 逻辑: 根据 q_idx 在总长度中的位置来确定其 batchID,使用二分搜索(Binary Search)在 cu_q_len[] 中查找。
Page 49
Page 49

优化后的 ComputeSumOdO 核布局

优化后的内核在循环的第一次迭代中使用二分搜索确定 bs_id,在后续迭代中如果 q_idx 超出了当前批次的长度,则查找新的批次。

Page 50
Page 50

性能

优化后,ComputeSumOdO 在处理非均衡数据 Data: [10000] + 99*[1] 时的延迟从基线的 10.3ms 显著降低到 0.26ms。相关的 PR 将于本月提交至 FlashMLA。

Page 51
Page 51

未来工作 (Future Works)

  • 继续在 Blackwell 平台上优化 DSA 稀疏解码核(DSA Sparse Decoding Kernel)。
  • MLA 反向核(Backward Kernel)的优化。
  • 尝试设计一个框架,以简化 Attention 核的开发,涵盖:
    • Attention
    • MLA Attention
    • Backward Attention
    • Sparse Attention 等。

谢谢!