Best Practice of MLA Kernel Optimization on Blackwell
Best Practice of MLA Kernel Optimization on Blackwell
王泽宇, NVIDIA GPU加速计算专家团队 高级工程师 | November 7, 2025
目录 (Agenda)
-
DSA 核优化 (DSA Kernel Optimization)
- V3.2 和 DSA 介绍 (V3.2 and DSA Introduction)
- DSA 核的挑战 (Challenges of DSA Kernel)
- Blackwell 平台概述 (Blackwell Platform Overview)
- 初尝:DSA 稀疏预填充 (First Taste: DSA Sparse Prefill)
- 第二项任务:DSA 稀疏解码 (Second Task: DSA Sparse Decoding)
-
MLA 反向核优化 (MLA Backward Kernel Optimization)
- 背景 (Background)
- Blackwell 中的反向注意力机制 (Backward Attention in Blackwell)
- 针对 MLA 的优化 (Optimization for MLA)
- 针对 Blackwell 中不平衡数据的优化 (Optimization for Unbalanced Data in Backward)
-
未来工作 (Future Works)
DSA 核优化 (DSA Kernel Optimization)
DeepSeek V3.2 介绍
DeepSeek V3.2 是一个新的“实验性”模型,旨在提升长上下文效率。其基准测试性能和 API 定价如下:
该模型的一个关键特性是稀疏注意力(Sparse Attention)。
DSA: DeepSeek 稀疏注意力机制 (DeepSeek Sparse Attention)
DSA 的核心思想是在推理过程中仅选择 TopK 个 KV Token 进行注意力计算,以降低延迟。
标准的注意力机制(如 Multi-Head Attention, MHA)需要新的查询(Q)Token 与所有的键值(KV)Token 进行计算。
DSA 引入了一个 TopK 索引器(TopK Indexer),它会从大量的 KV Token 中筛选出最相关的 TopK 个。
通过这种方式,注意力核(Attention Kernel)只需在选定的 TopK KV Token 子集上进行计算,从而显著减少计算量和延迟。与传统的 MLA(Multi-Layer Attention)相比,DSA 的延迟大幅降低。本次演讲的重点在于对此过程中的注意力核进行优化。
DSA 核的挑战 (Challenges of DSA Kernel)
DSA 本质是带稀疏性的 MLA
DSA 可以被理解为带有稀疏性的 MLA (Multi-Query Attention)。下图对比了 MHA (Multi-Head Attention) 和 MQA (Multi-Query Attention) 的张量结构。
-
MHA (Prefill 阶段)
- Head_num(QKV) = 128
- QK_dim = 192
- VO_dim = 128
-
MQA (Decoding 阶段)
- Head_num(Q) = 128
- QK_dim = 576
- VO_dim = 512
回顾:MLA 注意力核 (基于 Hopper 的 FMHA_V2 优化)
在讨论 DSA 的挑战之前,先回顾一下之前在 Hopper 架构上使用 FMHA_V2 对 MLA 注意力核的优化经验。
-
挑战:
- 单个 warpgroup 无法存储所有的累加器(Accumulators)。计算公式为
64 x 512 / 128 = 256 regs/thread。
- 单个 warpgroup 无法存储所有的累加器(Accumulators)。计算公式为
-
有利条件:
- K 和 V 张量是相同的。
-
解决方案:
- 使用 1 个工作组(WG)处理 TMA (Tensor Memory Accelerator)。
- 仅加载 K 的部分数据是可行的,因为
V = K[:, :512]。 - 使用 2 个工作组(WG)处理 TensorCore 计算。
下图展示了该优化方案的流水线操作:
该优化带来了显著的性能提升,如下方的 TFLOPS 性能图所示,FMHA Opt 的性能远超其他方法。
FP8 KV 缓存布局
上图展示了DeepSeek稀疏注意力(DSA)中FP8 KV(键值)缓存的内存布局。每个令牌(token)占用656字节。
- 主要部分:
512 e4m3(512字节)用于存储没有位置编码(NoPE)的KV缓存,其构成是512xFP8 + 4xFP32。这里的4xFP32是缩放因子(Scaling factor)。 - 附加部分:64字节用于存储旋转位置编码(RoPE),格式为
64xbf16。 - 计算精度:尽管KV缓存使用了FP8和BF16格式,但所有的注意力(Attention)核函数计算仍然在BF16精度下进行。
DSA 带来的额外挑战
在 MLA 优化的基础上,DSA 引入了新的挑战:
-
稀疏性处理 (Sparsity):
- 张量内存访问(TMA)难以高效处理稀疏数据,这会导致带宽的浪费。
-
预填充 (Prefill) 阶段的挑战 (在长上下文场景中):
- 在长上下文场景中,标准的多头注意力(MHA)表现不佳。
- MHA方法:采用自定义掩码(Custom Mask)。
- 不同的查询(Q)令牌会关注不同的top-K键值(KV)令牌。
- 示例:当序列长度(Seq len)为64K,top-K值为2K时,对于每个Q,在每128个KV TILE(块)中仅选择4个令牌。
- 因此,在预填充阶段,需要一个稀疏的矩阵乘法累加(MLA)核函数。
-
解码 (Decoding) 阶段的挑战:
- FP8 的反量化 (Dequantization):将FP8数据转换回更高精度的格式会带来计算开销。
- 高 CUDA 核心压力:反量化等操作会给CUDA核心带来很大的计算压力。
Blackwell 平台概述
本节介绍NVIDIA Blackwell平台的概览。
Blackwell平台的优势与机遇
-
张量内存(Tensor Memory)
- 该内存用于替代Hopper架构中的寄存器(regs),以存储
matrix_A和matrix_O。 - 容量为256KB。
- 数据移动:
- 与寄存器(REG)交互:
tcgen05.ld/st - 与共享内存(SMEM)交互:
tcgen05.cp
- 与寄存器(REG)交互:
上图展示了张量内存的布局和寻址方式,它由128个Lane组成,每个Lane为2KB。 - 该内存用于替代Hopper架构中的寄存器(regs),以存储
-
2-CTA Tensorcore GEMM
- 这是一种使用两个协作线程块(CTA)执行通用矩阵乘法(GEMM)的技术。
- 优势:可以减少共享内存(SMEM)的带宽需求。
- 由一个集群(cluster)中的一个线程束(warp)启动。
上图展示了2-CTA GEMM的概念,其中A和B矩阵被分配到两个不同的CTA(CTA0和CTA1)中进行计算。 -
TMA.Gather4
- 这是一种稀疏张量内存访问(TMA)功能。
- 引入了额外的
int4 idx输入,用于指定需要收集的数据索引。
上图(图10)展示了tiled::scatter4/tiled::gather4模式下边界框(bounding box)的一个示例,说明了如何根据不同的起始坐标从全局内存中收集(gather)数据到共享内存。
初尝:DSA 稀疏预填充(Sparse Prefill)
本节将初步探讨DSA在预填充阶段的实现。
DSA稀疏预填充概述
-
QKV 配置:
- Q: [seq_len, 128, 576]
- K: [seq_len, 576]
- V: [seq_len, 512] == K[:, 0: 512]
-
核函数设置 (Kernel Setup):
- CTA维度 (CTA dim): [2*q_len, 1, 1]
-
2-CTA集群 (2-CTA Cluster): 采用两个协作线程块(CTA)组成的集群。
- 每个Q令牌的128个头(heads)被分配到集群中,每个CTA处理64个头。
- 两个CTA共享相同的topK KV令牌。
-
图中提出了一个关键问题:“如何在2-SM GEMM中处理V?”
-
架构图:
- Q矩阵(128个头)被分割成两个Q Tile(Q CTA0, Q CTA1)。
- K/V矩阵(1个头)通过TMA.Gather4被加载到两个KV Tile(KV CTA0, KV CTA1)。
- 整个计算由一个2-CTA集群完成。
DSA稀疏预填充内存布局
-
共享内存 (Shared Memory):
- 同时存储K和V(受益于L2缓存)。
- K 形状: [64, 576],占用72KB。
- V 形状: [2*64, 256],占用64KB。
- Q 形状: [64, 576],占用72KB。
- S 形状: [64, 128],占用16KB。
- 总计: 224KB (总限制为228KB)。
- 最终输出: [64, 512] BF16 (重用内存)。
-
张量内存 (Tensor Memory):
- P = QK^T: (64 * 128 FP32 = 32KB)。
- O = SV: (64 * 512 FP32 = 128KB)。
- 总计: 160KB。
-
Q Offloading to TMEM: Q矩阵被卸载到张量内存中。
DSA稀疏预填充流水线与基准测试
-
流水线 (Pipeline):
上图展示了TMA、Tensorcore和CUDA Core之间的并行执行流水线。TMA负责加载数据,Tensorcore执行矩阵乘法(QK, SV),CUDA Core执行数学函数(MUFU)和缩放(Scale)。 -
线程配置:
- 总共512个线程(4个Warpgroup)。
- 128个线程用于TMA-K,128个线程用于TMA-V。
- 128个线程用于softmax和重缩放(re-scale)。
- 32个线程用于MMA(矩阵乘法累加),32个线程用于KV topK索引。
-
基准测试 (Benchmark):
- 性能达到 1400 Tflops。
- 这相当于理论峰值性能(2000 Tflops)的 70%。
第二项任务:DSA 稀疏解码(Sparse Decoding)
本节介绍DSA在解码阶段的实现。
DSA稀疏解码的挑战:FP8布局与反量化
- 配置:
Q_LEN = 1,如果启用了多令牌并行(MTP),则可能大于1。使用splitKV来平衡SM的工作负载。 - FP8 KV布局:
-
反量化 (Dequantization) of FP8:
- FP8 -> FP16: 64 ops/clk/SM
- FP16 -> FP32: 64 ops/clk/SM
- FP32 -> BF16: 16 ops/clk/SM
- 在Blackwell上,BF16 Scale操作为128 ops/clk/SM。
- 将64512的FP8数据转换为BF16需要 3328 CLK/CTA*。
-
GEMM:
- BMM1(QK): 26464*576 = 4.6M ops/CTA
- BMM2(SV): 26464*512 = 4Mops/CTA
- Tensorcor性能: 8192 ops/CLK/SM
- 总计GEMM时钟周期: 1088 CLK/CTA。
-
解决方案: 使用带有分布式共享内存(DSMEM)的2-CTA对 (2-CTA pair with DSMEM)。
稀疏解码中的2-CTA反量化
该方案通过2-CTA集群和分布式共享内存(DSMEM)来解决反量化带来的高CUDA核心压力。
* 工作原理: 每个CTA负责一半的反量化工作,并通过DSMEM进行多播(Multi-cast)。
* 流程: 512字节的e4m3数据被分成两部分,分别加载到两个CTA的本地共享内存(SHM0的FP8-CTA0和SHM1的FP8-CTA1)中。经过反量化后,结果通过DSMEM共享到2-CTA集群的共享内存中,形成BF16-CTA0和BF16-CTA1。
* 优势:
* 两个CTA共享相同的TopK KV令牌。
* 将CUDA核心压力减半,降至 1664 CLK/CTA。
DSA稀疏解码内存布局
-
共享内存 (Shared Memory) (双缓冲):
- Q_Part0: (64*256 BF16 = 32KB)
- K_FP8[2]: (642562 FP8 = 32KB)
- K_BF16[2]: (645762 BF16 = 144KB)
- S[2]: (64642 BF16 = 16KB)
- 总计: 224KB
-
张量内存 (Tensor Memory):
- P[2] = QK^T: (64642 FP32 = 32KB)
- O = SV: (64*512 FP32 = 128KB)
- Q_Part1: (64*320 BF16 = 40KB)
- 总计: 200KB
DSA稀疏解码 - 进行中:流水线与当前进展
-
理想流水线 (Ideal Pipeline):
上图展示了稀疏解码阶段的理想化流水线,显示了TMA、Tensorcore和CUDA Core(执行反量化、MUFU、Scale)之间的并行工作流。 -
CTA维度: [2, q_len, sm_parts]
-
线程配置:
- 384个线程 (3个Warpgroup)。
- 128个线程用于TMA。
- 128个线程用于CUDA Core (反量化 & softmax & re-scale)。
- 32个线程用于MMA,32个线程用于KV topK索引。
-
当前进展 (Current Progress):
- 性能达到 400 Tflops。
- 相比FlashMLA有 10% 的增益。
- 仍有进一步优化的空间。
假设:FP8->BF16 是一条单一指令
下图展示了一种假设情况的流水线,即如果从 FP8 到 BF16 的转换可以由一条单一指令完成。在这种优化下,原本在 CUDA Core 上执行的多个反量化(Dequantization)、乘法融合(MUFU)和缩放(Scale)操作可以被整合,从而简化执行流程,提高效率。TMA(Tensor Memory Accelerator)负责加载数据,Tensor Core 执行核心的矩阵运算,而 CUDA Core 的负担减轻。
MLA 反向核优化 (MLA Backward Kernel Optimization)
背景:注意力的前向与反向传播
前向传播
标准注意力机制的前向传播过程如下图所示。它主要由两个通用矩阵乘法(GEMM)操作和一个 Softmax 操作组成:
- 查询(Q)与键的转置(Kᵀ)进行矩阵相乘,得到分数矩阵 P。
- 对 P 应用 Softmax 函数,得到注意力权重矩阵 S。
- S 与值(V)进行矩阵相乘,得到最终输出 O。
反向传播
反向传播过程计算输出 O 对输入 Q、K、V 的梯度(分别为 dO、dQ、dK、dV)。梯度流与前向传播的计算图方向相反。
反向传播的具体计算公式如下:
* $P = Q * K^T$
* $S = Softmax(P) = exp(P - lse)$
* $dV = S^T * dO$
* $dS = dO * V^T$
* $dP = S \circ (dS - sum(O \circ dO))$ (其中 $\circ$ 表示逐元素相乘)
* $dQ = dP * K$
* $dK = dP^T * Q$
从计算角度看,反向传播过程主要包含 5 个 GEMM 操作和 2 个由 CUDA Core 执行的操作。
注意力反向核函数流程
注意力反向传播的计算流程涉及以下几个关键步骤。
为了优化计算,一些中间值可以预先计算或在前向传播时计算并保存下来:
* lse (log-sum-exp) 在前向传播时计算。
* sum(O ◦ dO) 可以在反向传播主循环开始前预先计算。
在实现核函数时,循环的顺序是一个关键的设计选择。两种常见的策略是:
1. 外层 KV,内层 QO: 外层循环遍历 KV 的分块(tile),内层循环遍历 Q 的分块。在内层核函数中累加 dQ。
2. 外层 QO,内层 KV: 外层循环遍历 Q 的分块,内层循环遍历 KV 的分块。在内层核函数中累加 dK 和 dV。
下图展示了注意力反向传播的数据流。首先计算 sumOdO,然后将其与 dO, Q, K, V 一同输入到主计算模块 Backward Attn 中,得到 dK 和 dV,并累加生成最终的 dQ。
Blackwell 架构上的反向注意力
本节将讨论在 Blackwell 架构上实现反向注意力的具体细节。
传统注意力反向核函数
传统的注意力反向核函数实现中,内存布局和流水线设计如下:
-
配置参数:
- Q TILE Step: 128
- KV TILE Step: 128
- Head_dim: 128
-
资源分配:
- MMA: 1 Warp
- TMA Load: 1 Warp
- CUDA Core: 8 Warps
- Store dQ: 4 Warps
下图展示了在共享内存(Shared Memory)和张量内存(Tensor Memory)中的数据布局。
其计算流水线大致如下,TMA 负责加载数据,Tensor Core 和 CUDA Core 交替执行计算。注意,此图可能不完全代表真实的流水线。
MLA 注意力反向核函数
对于 MLA,实现上存在一些差异和挑战。
- 差异:
QK_head_dim从 128 增加到 192。 - 挑战: 当
Q_STEP和KV_STEP仍为 128 时,由于head_dim增大,会导致内存溢出。
- 解决方案: 为了解决内存溢出问题,将
Q_Tile_STEP从 128 减小到 64。
这一调整改变了共享内存和张量内存中与 Q 相关的张量(如 Q, dO, S, dQ)的分块大小,从而适应了更大的 head_dim,避免了内存问题。下图展示了调整后的内存布局。
针对 MLA 的优化
Attention 反向核 (MLA) 基准测试
-
性能: 达到默认反向核(d_qk=d_vo=128)性能的 85%。
-
基准测试结果
针对反向传播中非均衡数据的优化
反向 Attention 中的非均衡数据
当数据分布不均衡时,会出现性能下降问题。例如,一个批次(Batch 0)包含大量数据,而其他批次(Batch 1-7)数据量很小。这种不均衡会导致性能下降至 300 TFLOPS。
在 ComputeSumOdO 计算中,对于非均衡数据 Data = [10000] + 99*[1],延迟高达 10.3ms,而对于均衡数据 Data = [100]*100,延迟仅为 0.382ms。
ComputeSumOdO 核布局
原始布局
GridDim:[ceil(q_max_len / Q_BLOCK), head_num, batch_size]BlockDim:[8, 16, 1]- 逻辑: 原始设计假设一个 CTA(Cooperative Thread Array)内的所有 token 属于同一个批次(基于 blockIdx.x)。所有具有相同
threadIdx.y的threadIdx.x计算一个Sum(OdO)。
目标布局(优化后)
GridDim:[ceil(total_q_len / Q_BLOCK), head_num]- 逻辑: 根据
q_idx在总长度中的位置来确定其batchID,使用二分搜索(Binary Search)在cu_q_len[]中查找。
优化后的 ComputeSumOdO 核布局
优化后的内核在循环的第一次迭代中使用二分搜索确定 bs_id,在后续迭代中如果 q_idx 超出了当前批次的长度,则查找新的批次。
性能
优化后,ComputeSumOdO 在处理非均衡数据 Data: [10000] + 99*[1] 时的延迟从基线的 10.3ms 显著降低到 0.26ms。相关的 PR 将于本月提交至 FlashMLA。
未来工作 (Future Works)
- 继续在 Blackwell 平台上优化 DSA 稀疏解码核(DSA Sparse Decoding Kernel)。
- MLA 反向核(Backward Kernel)的优化。
- 尝试设计一个框架,以简化 Attention 核的开发,涵盖:
- Attention
- MLA Attention
- Backward Attention
- Sparse Attention 等。