文章标题: FlashMLA-ETAP: 用于加速NVIDIA H20 GPU上MLA推理的高效转置注意力流水线
作者: Pengcuo Dege, Qiuming Luo, Rui Mao, Chang Kong
机构: 腾讯 (Tencent), 深圳大学计算机与软件工程学院 (Shenzhen University), 深圳职业技术大学人工智能学院 (Shenzhen Polytechnic University)

A1 主要贡献

核心问题: Transformer架构中的注意力机制(如多头注意力MHA和多头潜在注意力MLA)具有与序列长度平方成正比的计算复杂度,这在处理长上下文任务时构成了严重的性能瓶颈。这一挑战在中端GPU(如NVIDIA H20)上尤为突出,其FP16计算能力仅为148 TFLOPS,远低于H100/H800等高端GPU。在单台8卡H20服务器上部署像DeepSeek-R1 671B这样的大模型时,由于内存限制,必须将128个注意力头拆分,导致每块GPU仅处理16个头。这使得WarpGroup矩阵乘累加(WGMMA)指令的M维度(16)低于其效率最优的最小要求(64),从而引发了大量的冗余填充(padding),导致在解码阶段计算利用率常低于25%。

研究目标与创新点: 为解决上述问题,本文提出了FlashMLA-ETAP,一个专为提升中端GPU(如NVIDIA H20)上MLA推理效率而设计的先进框架。其核心创新是高效转置注意力流水线(Efficient Transpose Attention Pipeline, ETAP)。ETAP通过引入策略性转置操作,重构了传统的注意力计算流程。在解码场景中,查询(Query)长度通常很短,而键值(Key-Value)上下文长度很长。ETAP利用这一不对称性,将长的KV上下文长度与WGMMA操作中的M维度对齐,从而消除了因查询维度过小而产生的冗余填充,极大地优化了H20有限计算资源的利用率,并减少了内存访问瓶颈。

主要贡献如下:
1. 提出高效转置注意力流水线 (ETAP):引入了一种新颖的计算模式,利用转置操作来减少冗余计算,该模式专为NVIDIA H20等中端GPU的架构约束而设计。
2. 开发 FlashMLA-ETAP 框架: 将ETAP集成到现有的FlashMLA框架中,显著提升了中端GPU上的MLA推理效率。ETAP的设计具有通用性,可以方便地集成到FlashInfer和FlashAttention-3等其他注意力框架中。
3. 实验验证: 实验表明,在NVIDIA H20 GPU上,FlashMLA-ETAP框架相较于基线FlashMLA框架,在64K序列长度、批大小为16时,实现了高达2.78倍的加速。

A3 背景知识与设计动机

现有框架的局限性与设计动机。先前章节中强调,尽管像FlashAttention-3和FlashMLA这样的现有注意力优化框架在高端GPU上取得了显著效率,但由于NVIDIA H20的FP16计算能力仅为148 TFLOPS,这些框架在其上的表现不佳。我们的分析表明,FlashMLA虽然为MLA推理采用了低秩压缩技术,但由于WGMMA指令中的冗余填充和低效的内存访问模式,其计算利用率常常低于25%。这些低效问题在推理过程中尤其严重,因为长的KV上下文长度和短的查询长度加剧了工作负载与H20架构之间的不匹配。受此差距的启发,我们提出了一种新颖的解决方案,该方案通过重新配置注意力计算过程以适应H20的能力,为显著的性能提升奠定了基础。本章介绍了为减少冗余计算而设计的新计算模式——高效转置注意力流水线(ETAP),并详细说明了其如何集成到FlashMLA中以创建FlashMLA-ETAP,从而实现更高的推理效率。

A2 方法细节

3.1 高效转置注意力流水线 (ETAP)

原始MLA推理计算模式。在推理阶段,MLA的每个头都遵循源自注意力机制的标准计算流程。给定查询(Q)、键(K)、值(V)矩阵 $Q, K, V \in R^{N \times d}$,其中N是序列长度,d是头维度,注意力输出O的计算公式如下:

在实践中,MLA对K和V应用低秩联合压缩以减小KV缓存大小,这对于像DeepSeek-V3这样的大型模型至关重要。然而,在推理过程中,特别是在解码阶段,查询长度通常很短(例如1或2个token),而KV缓存的上下文长度可能非常长(例如1K到32K个token),这加剧了计算上的不平衡。

原始模式的挑战。原始的计算模式在NVIDIA H20 GPU上带来了重大挑战,这主要是由于工作负载与GPU架构约束之间的不匹配。首先,为Hopper架构优化的WGMMA指令要求M维度至少为64才能高效执行。在解码场景中,由于内存不足,DeepSeek-R1 671B模型的128个头需要被拆分到8个GPU上,导致每个GPU的头数量减少到16个,这低于WGMMA的最低要求,从而引起冗余填充,计算利用率仅为25%或更低。其次,长的KV上下文长度导致K和V矩阵变大,增加了内存访问延迟和对HBM带宽的压力,而H20缺乏像H100/H800等高端GPU那样丰富的资源。这些低效问题表现为推理时间延长,在像DeepSeek-V3这样的模型中,MLA约占解码前向传播的30%(例如,BS=16, ContextLength=16K),严重影响了资源受限环境下的实时性能。

ETAP:一种新颖的计算模式。为了应对这些挑战,我们提出了高效转置注意力流水线(ETAP),这是一种利用转置来减少冗余计算的新颖计算模式。ETAP通过转置操作重新配置了标准的注意力计算,利用了推理过程中查询和KV长度之间的不对称性。具体来说,ETAP不再计算:

而是计算:

其中 $N_q$ 是查询长度,$N$ 是KV上下文长度。通过将KV上下文长度视为WGMMA操作中的M维度,并将查询头的数量视为N维度,ETAP使工作负载与H20的架构约束对齐,从而消除了在查询维度上进行冗余填充的需要。最终的转置 ($O = O^T$) 只执行一次,而内部的点积计算则在整个长上下文长度上反复受益于效率的提升。

ETAP优势的理论分析。ETAP的主要优势在于其能够减少填充开销并提高H20的148 TFLOPS FP16计算能力的利用率。在原始模式中,当M维度(查询长度)很小时,需要进行填充以满足WGMMA的要求,这引入了一个与查询长度成反比的填充因子。ETAP通过将M维度重新定义为KV上下文长度来缓解此问题,因为KV上下文长度本身就很大,不需要填充,而N维度(与查询相关)则可以在没有额外约束的情况下被高效处理。这种重新定位减少了冗余计算的比例,从理论上提高了计算效率。确切的改进取决于查询长度和KV上下文长度的相对大小,当上下文长度相对于查询长度增加时,会观察到更大的收益。这种方法为填充问题提供了一个可扩展的解决方案,为计算资源受限的中端GPU优化了注意力机制。

3.2 FlashMLA-ETAP 算法

集成到MLA框架。ETAP的简洁性和通用性使其成为一种多功能的计算模式,可以无缝集成到各种MLA框架中,包括FlashAttention-3、FlashInfer和FlashMLA。通过转置来重新定向注意力计算,ETAP对现有流水线仅需做最小的改动,主要是在矩阵乘法和softmax阶段,同时保持了数值稳定性和与MLA中的低秩压缩等技术的兼容性。这种兼容性源于这些框架共享的注意力计算模式——$S = Q \cdot K^T$, $P = \text{softmax}(S)$, $O = P \cdot V$——ETAP将其重构为一种转置形式,以符合H20的WGMMA约束。本节末尾提供了ETAP集成到FlashAttention-3和FlashInfer的详细理论分析,重点介绍了它们的架构相似之处以及所需的最小修改。

FlashMLA-ETAP:算法与实现。FlashMLA-ETAP通过集成ETAP计算模式来扩展FlashMLA,为H20的硬件约束量身定制了注意力流水线。核心修改在于注意力计算的转置,我们在FlashMLA的协作线程阵列(CTA)框架内实现这一点,利用生产者-消费者同步和分块处理。下面,我们展示了FlashMLA-ETAP前向传播的伪代码,并附有详细解释。

Algorithm 1 FlashMLA-ETAP 前向传播,带消费者内部重叠 – CTA视角

Require: 矩阵 Qi ∈ R^{Br×d} 和 K, V ∈ R^{N×d} 位于HBM, 键块大小为Bc,Tc = N/Bm。V = [V0, V1], V0, V1 ∈ R^{N/2×d}
1: 初始化流水线对象,管理具有s阶段循环SMEM缓冲区的屏障同步。
2: 在片上,初始化 Oi = 0 ∈ R^{d×Br} 和 ℓi, mi = 0, (−∞) ∈ R^{Br} , Oi = [Oi0, Oi1], Oi0, Oi1 ∈ R^{d/2×Br}
3: if 在消费者 warpgroup0 中 then
4:   根据消费者 warpgroup0 的数量重新分配预定数量的寄存器。
5:   等待 Qi 加载到共享内存中。
6:   for 0 ≤ j < Tc do
7:     Syncthreads。等待 Kj 加载到共享内存中。
8:     计算 S(j)i = Kj * Q⊤i (SS-GEMM)。提交并等待。
9:     存储 moldi = mi 并计算 mi = max(moldi , rowmax(S(j)i ))。
10:    计算 Pe(j)i = exp(S(j)i − mi) 和 ℓi = exp(moldi − mi) * ℓi + colsum(Pe(j)i )。
11:    等待 Vj 加载到共享内存中。
12:    计算 Ri = diag(exp(moldi − mi))⁻¹
13:    将 Ri 存储到共享内存。
14:    计算 Oi0 = Ri * Oi0 + V⊤j0 * Pe(j)i (SS-GEMM)。提交并等待。
15:    释放缓冲区的第 (j mod s) 阶段供生产者使用。
16:  end for
17: else in 生产者 warpgroup1
18:   根据生产者 warpgroup1 的数量重新分配预定数量的寄存器。
19:   发出从 HBM 加载 Qi 到共享内存的指令。
20:   发出从 HBM 加载 K0 到共享内存的指令。
21:   for 1 ≤ j < Tc do
22:     CopyAsyncWait and Syncthreads
23:     发出将 Kj, Vj 从 HBM 加载到缓冲区第 (j mod s) 阶段的共享内存的指令。
24:     NamedBarrier.Sync warpgroup0 and warpgroup1
25:     从共享内存复制 Ri。
26:     计算 Oi1 = Ri * Oi1 + V⊤j1 * Pe(j)i (SS-GEMM)。提交并等待。
27:   end for
28: end if
29: 计算 Oi = diag(ℓi)⁻¹ * Oi 和 Li = mi + log(ℓi)。
30: 转置 Oi = O⊤i。
31: 将 Oi 和 Li作为 O 和 L 的第 i 个块写入 HBM。

FlashMLA-ETAP伪代码解释。FlashMLA-ETAP算法在CTA内运行,利用生产者-消费者同步来管理数据移动和计算。流水线首先初始化一个循环共享内存(SMEM)缓冲区来处理分块操作(第1行)。输出矩阵 $O_i$ 和softmax统计数据 ($ℓ_i, m_i$) 在片上初始化,其中 $O_i$ 被分成两个部分 ($O_{i0}, O_{i1}$) 以实现消费者内部的重叠计算(第2行)。计算被分配给两个warp组:消费者warp组0和生产者warp组1。
* 消费者 Warpgroup0 (第4-14行): 该warp组使用转置形式 $S^{(j)}_i = K_j Q^T_i$ 计算注意力分数,这反映了ETAP的方法(第7行)。它计算softmax统计数据(第8-9行)并更新输出的第一部分 $O_{i0}$(第10-12行),利用共享内存存储中间结果 ($R_i$) 以进行同步。
* 生产者 Warpgroup1 (第15-23行): 该warp组负责将数据从HBM加载到SMEM(第16-19行),通过命名屏障与消费者同步(第20行),并使用共享的 $R_i$ 计算第二部分 $O_{i1}$(第21-22行)。ETAP固有的转置操作确保了维度与WGMMA指令的高效对齐。
* 收尾阶段 (第24-26行): 最后步骤对输出进行重新缩放,应用ETAP所需的转置 ($O_i = O^T_i$),并将结果写回HBM。

这种实现通过将KV上下文长度与M维度对齐,最小化了填充开销,从而在NVIDIA H20上实现了显著的性能提升,同时保持了与FlashMLA的低秩压缩和推理优化的兼容性。

ETAP集成到FlashAttention-3和FlashInfer的理论分析。为了证实ETAP可以无缝集成到如FlashAttention-3【10, Shah, J., et al., Flashattention3: Fast and accurate attention with asynchrony and low-precision, 2024】和FlashInfer【13, Ye, Z., et al., Flashinfer: Efficient and customizable attention engine for llm inference serving, 2025】等其他注意力框架中的论断,我们对其与这些框架计算流水线的兼容性进行了理论分析,重点关注影响集成可行性的结构相似点和不同点。FlashAttention-3和FlashInfer与FlashMLA一样,都旨在优化现代GPU上的注意力机制,但它们在实现策略和目标硬件优化上有所不同,因此需要仔细研究ETAP基于转置的方法如何与它们的架构相协调。

FlashAttention-3 是FlashAttention系列的演进版本,它采用I/O感知的切片(tiling)、在线softmax计算和异步执行来最小化高带宽内存(HBM)访问,并利用了张量内存加速器(TMA)等高级GPU特性。其核心流水线计算注意力分数 $S = Q \cdot K^T$,应用softmax得到概率 $P = \text{softmax}(S)$,并计算输出 $O = P \cdot V$,这与FlashMLA中的标准MLA计算相似。然而,FlashAttention-3的切片策略将计算划分为块以优化内存访问,当查询长度(例如1或2个token)与KV上下文长度(例如16K个token)相比很小时,这个过程会在像H20这样的中端GPU上引入填充开销。ETAP的转置方法将计算重构为 $S^T = K \cdot Q^T$,$P^T = \text{softmax}(S^T)$,以及 $O = (V^T \cdot P^T)^T$,通过将KV上下文长度与WGMMA操作中的M维度对齐,直接解决了这种低效问题。将ETAP集成到FlashAttention-3中,需要修改其矩阵乘法和softmax阶段以采用这种转置计算流。由于FlashAttention-3已经采用了分块处理,转置可以在块级别应用,从而保留其I/O感知的切片策略。主要的调整在于重新定义块维度,优先将KV上下文长度作为M维度,这只需对现有流水线进行微小改动即可实现,因为softmax和输出计算阶段在转置后功能上是等价的。此外,FlashAttention-3使用TMA进行异步内存传输,不受ETAP修改的影响,确保了其与高端GPU优化的兼容性,同时将其效率扩展到了中端GPU。

FlashInfer 是另一个优化的注意力框架,专注于长上下文场景下的高效推理,通常采用核函数融合和自定义CUDA实现来减少开销。与FlashMLA包含用于KV缓存的低秩联合压缩不同,FlashInfer通常处理完整的注意力矩阵,但共享相同的基础计算模式:$S = Q \cdot K^T$,$P = \text{softmax}(S)$,以及 $O = P \cdot V$。FlashInfer的设计通过融合核函数来强调延迟降低,但这在查询长度较小时会加剧H20上的填充效率低下问题,与FlashAttention-3类似。ETAP的转置策略可以通过修改FlashInfer的融合核函数以计算 $S^T = K \cdot Q^T$ 及后续的转置形式步骤来集成。此调整需要重构核函数以处理转置后的维度,但核心计算——矩阵乘法后接softmax——在结构上保持一致,确保了FlashInfer的核函数融合优势得以保留。此外,由于FlashInfer不依赖低秩压缩,ETAP与此类技术的兼容性(如在FlashMLA中所示)不成问题,从而简化了集成过程。转置可以作为核函数内的预处理步骤来实现,使工作负载与H20的WGMMA约束对齐,而不改变FlashInfer的整体优化策略。

从理论上讲,ETAP集成到这两个框架中的优势在于,它专注于解决注意力机制中的一个普遍挑战:在短查询和长KV上下文的解码场景中的填充开销。这个问题并非FlashMLA独有,而是影响任何在具有类似架构约束的中端GPU上运行的框架。通过重新定向计算以最小化填充,ETAP提高了各种框架的计算利用率,主要的集成成本在于调整矩阵乘法和softmax阶段以适应转置计算。ETAP的数值稳定性,如我们在FlashMLA-ETAP实验中所验证的,进一步支持了其适用性,因为转置不会引入显著的浮点误差,这对于支持如FP8等低精度格式的FlashAttention-3等框架来说是一个关键考虑因素。此外,ETAP对WGMMA指令的依赖,虽然是为H20优化的,但与其它基于Hopper架构的GPU兼容,这表明其优势可以扩展到除H20之外具有类似指令集的中端平台。这一分析凸显了ETAP作为一种通用解决方案的设计,能够提升各种注意力框架在中端GPU上的效率,从而支持其在硬件感知推理优化中更广泛应用。

A4 实验环境与结果

实验环境

  • 硬件配置:
    • GPU: 单块 NVIDIA H20 GPU,基于NVIDIA Hopper架构。
    • GPU内存: 96GB HBM3,内存带宽4.0TB/s。
    • GPU计算能力: 148 TFLOPS (BF16/FP16)。
  • 模型与软件配置:
    • 模型: DeepSeek-R1。
    • 模型参数: 16个注意力头,头维度为576。
    • 软件框架: 对比了FlashMLA-ETAP、FlashAttention-3、FlashInfer和FlashMLA。
    • 精度: FP16。
  • 实验任务与数据集:
    • 任务: 模拟自回归解码,每个前向传播过程生成一个token。
    • 工作负载: 序列长度覆盖512、1K、2K、4K、8K、16K、32K到64K,批大小(batch size)分别为16和32。
    • 评估指标: 推理性能,以TFLOPS/s为单位,结果为5次运行的平均值,以减少运行时波动。实验排除了其他无关优化,专注于核心注意力机制的比较。

实验结果

推理性能对比:
实验结果(如图1所示)表明,FlashMLA-ETAP在所有测试的序列长度和批处理大小上均表现出卓越的性能。
* 批大小为16时 (图1a): 在64K序列长度下,FlashMLA-ETAP达到89 TFLOPS/s的峰值性能。
* 相较于FlashMLA(32 TFLOPS/s),实现了2.78倍的加速。且加速比随序列长度增长而增大(从512长度的1.44倍增至64K的2.78倍)。
* 相较于FlashAttention-3(17 TFLOPS/s)和FlashInfer(18 TFLOPS/s),分别实现了5.24倍4.94倍的性能提升。
* 批大小为32时 (图1b): 在32K和64K序列长度下,FlashMLA-ETAP的峰值性能达到87 TFLOPS/s。
* 相较于FlashMLA(32 TFLOPS/s),实现了2.72倍的增益。
* 相较于FlashAttention-3(21 TFLOPS/s)和FlashInfer(23 TFLOPS/s),在64K序列长度下分别有4.14倍3.78倍的优势。

结论分析:
* FlashMLA-ETAP的性能优势在长序列(16K至64K)下最为显著,这验证了ETAP通过将KV上下文长度与WGMMA的M维度对齐,有效减少了填充开销。
* FlashAttention-3和FlashInfer的性能曲线较为平坦,表明它们针对高端GPU的优化策略在H20这样的中端硬件上因计算能力受限而表现不佳。
* 性能在序列长度超过32K后趋于平稳,这可能表明达到了计算饱和点,而非内存带宽瓶颈。

图1. 在NVIDIA H20 GPU上使用DeepSeek-R1模型(头数16,头维度576)的推理性能比较(TFLOPS/s)。(a) 批大小16。(b) 批大小32。
图1. 在NVIDIA H20 GPU上使用DeepSeek-R1模型(头数16,头维度576)的推理性能比较(TFLOPS/s)。(a) 批大小16。(b) 批大小32。

数值误差验证:
为了确保FlashMLA-ETAP的数值稳定性,我们将其FP16输出与双精度(FP64)参考实现的均方根误差(RMSE)进行了比较。
* FlashMLA-ETAP的RMSE: $1.25 \times 10^{-5}$
* FlashAttention-3的RMSE: $1.9 \times 10^{-4}$

如表1所示,FlashMLA-ETAP的RMSE比FlashAttention-3低约15.2倍。这一结果表明,尽管ETAP引入了转置等优化,并结合了FlashMLA的低秩压缩,但它依然保持了更高的数值精度,证明了其在要求精确度的推理任务中的可靠性。

表1. FlashAttention-3与FlashMLA-ETAP在FP16精度下的RMSE比较。

讨论:
实验结果有力地证实了FlashMLA-ETAP的有效性,其在64K序列长度下2.78倍的加速验证了ETAP是为H20的148 TFLOPS FP16计算能力量身定制的优化方案。随着序列长度增加,性能差距拉大,凸显了ETAP缓解中端硬件上填充开销瓶颈的能力。与FlashAttention-3和FlashInfer的显著性能差距表明,ETAP比那些为高端GPU优化的框架更有效地解决了H20的特定约束。此外,优异的数值精度(RMSE低15.2倍)增强了其在对精度敏感应用中的可靠性。

A5 结论

本文介绍了FlashMLA-ETAP,一个旨在显著提升在NVIDIA H20 GPU(其FP16计算能力受限于148 TFLOPS)上多头潜在注意力(MLA)推理效率的新型框架。通过提出高效转置注意力流水线(ETAP),该框架通过转置操作重新配置注意力计算,以减少填充开销。实验证明,在64K序列长度(批大小16)下,FlashMLA-ETAP相比FlashMLA实现了2.78倍的加速,并分别比FlashAttention-3和FlashInfer快5.24倍和4.94倍。此外,FlashMLA-ETAP的均方根误差(RMSE)比FlashAttention-3低15.2倍(为$1.25 \times 10^{-5}$),确保了在长上下文场景下的数值稳定性。ETAP的设计在理论上也与FlashAttention-3和FlashInfer兼容,显示了其在硬件感知推理优化中广泛应用的潜力。

局限性: 本研究的评估仅限于NVIDIA H20 GPU、特定的自回归解码范式(每步前向传播生成一个token)以及特定的模型配置(16个头,576头维度,最高64K序列长度)。这可能限制了其在不同GPU、推理模式或更大模型上的普适性。

未来工作: 未来的研究应致力于通过实验验证ETAP在FlashAttention-3和FlashInfer中的集成效果,为其他基于Hopper架构的中端GPU进行优化,并探索更广泛的序列长度和配置,以增强其可扩展性和适用性。