FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision

Tri Dao¹ and Jay Shah²
¹ Together AI / Princeton University, tri@tridao.me
² Colfax Research, jayhshah@colfaxhttp://-intl.com

目录

挑战:为现代硬件(H100 GPU)优化 FlashAttention

FlashAttention-2 在 A100 GPU 上经过高度优化,达到了 70% 的利用率。然而,在 H100 GPU 上,FA-2 的利用率仅为 35-40%。下图展示了在 H100 80GB SXM5 上,随着序列长度的增加,PyTorch、FlashAttention 和 FlashAttention-2 在 Attention 前向传播速度上的对比。FlashAttention-2 的性能远未达到 H100 的理论峰值。

图表展示了在不同序列长度下,Pytorch、FlashAttention和FlashAttention-2的速度对比。
图表展示了在不同序列长度下,Pytorch、FlashAttention和FlashAttention-2的速度对比。

FlashAttention-3: 为 Hopper 架构优化 FlashAttention

FlashAttention-3 旨在解决上述挑战,主要通过以下三个方面进行优化,最终实现了 1.6-3倍 的加速。

  1. H100 上的新指令:

    • WGMMA (Warpgroup Matrix-Multiply-Accumulate): 一种更高吞吐量的 MMA 原语,由一个 warpgroup(= 4个连续的 warp)异步地、集体地执行。
    • TMA (Tensor Memory Accelerator): 加速从全局内存(gmem)到共享内存(smem)的数据加载,是异步的,并且能节省寄存器。
  2. 异步性 (Asynchrony):

    • 基于异步的 WGMMA 和 TMA 构建。
    • Warpgroup 间重叠: 采用 warp 专用化和乒乓调度(pingpong)。
    • Warpgroup 内重叠: 实现 softmax 和异步矩阵乘法的重叠。
  3. 低精度 (Low-precision):

    • FP8: 解决了布局整合(layout conformance)问题,并在核函数(in-kernel)内进行 V 矩阵的转置。

新指令:WGMMA & TMA

H100 架构引入的 WGMMA 和 TMA 是性能提升的关键。

  • WGMMA: 对于实现峰值吞吐量是必需的。传统的 mma.sync 指令只能达到峰值吞吞吐量的 2/3。
  • TMA: 能够加速全局内存到共享内存的数据复制,并节省寄存器使用。

一个关键特性是,WGMMA 和 TMA 都是异步指令:线程在发出这些指令后,可以继续执行其他工作,而指令则在后台执行。

图片对比了A100 FP16和H100 FP16的计算单元,并图示了TMA的工作原理。
图片对比了A100 FP16和H100 FP16的计算单元,并图示了TMA的工作原理。

Warp 专用化核函数设计:CTA 视角

WGMMA 和 TMA 被整合到一个生产者-消费者(producer-consumer)模型的 warp 专用化流水线设计中。

  • 生产者 (Producer): 负责从全局内存加载数据。
  • 消费者 (Consumer): 负责执行计算。

关键设计:
* 使用 warpgroup 范围的寄存器重分配,让消费者 warp 获得更多的寄存器份额。
* 使用共享内存(SMEM)双缓冲技术,将 K 和 V 的加载操作流水线化。
* 使用基于共享内存的屏障(SMEM-resident barriers)和线程本地的流水线状态对象来同步生产者和消费者。
* GEMM0 (Q * K^T) 为 SS 布局(Shared input, Shared output),GEMM1 (P * V) 为 RS 布局(Row input, Shared output)。

Warp专用化核函数设计的伪代码
Warp专用化核函数设计的伪代码

异步性:重叠 GEMM 和 Softmax

为何需要重叠?
* 专用函数单元(Special Function Units, SFU)的吞吐量远低于张量核心(Tensor Cores)。SFU 用于计算 softmax 中的指数函数(exp)。

示例:
* 假设 headdim 为 128,块大小为 128 x 192。
* FP16 WGMMA: 2 x 2 x 128 x 192 x 128 = 12.6 MFLOPS。在 4096 FLOPS/cycle 的速率下,需要 3072 个周期。
* MUFU.EX2 (用于 exp): 128 x 192 = 24.6k OPS。在 16 OPS/cycle 的速率下,需要 1536 个周期。

结论是,MUFU.EX2 的执行时间占到了 WGMMA50%。在 FP8 模式下情况更糟,两者都需要 1536 个周期。因此,我们希望在张量核心忙于 WGMMA 计算的同时,执行 EX2(softmax)操作。

Warpgroup 间 GEMM 与 Softmax 的重叠

一个简单的解决方案是依赖 warp 调度器自行处理,这在一定程度上有效,但我们可以做得更好。

通过使用同步屏障(bar.sync)实现乒乓调度(Pingpong scheduling),可以更主动地管理和重叠不同 warpgroup 的计算任务,从而提升性能。

  • 性能提升: 从 580 TFLOPS 提升到 640 TFLOPS。
Warpgroup间重叠示意图
Warpgroup间重叠示意图

Warpgroup 内 GEMM 与 Softmax 的重叠

在每个 warpgroup 内部,可以进一步利用 WGMMA 的异步特性。
* 将第 k 次迭代的 GEMM1 计算与第 k+1 次迭代的 softmax 计算进行重叠。
* 这种方法会使用更多寄存器,因为下一次 GEMM0 的累加器和当前 GEMM1 的操作数需要同时存在。

通过两阶段的 warpgroup 内重叠,性能可以进一步提升。
* 性能提升: 从 640 TFLOPS 提升到 670 TFLOPS。

Warpgroup内重叠示意图
Warpgroup内重叠示意图

结合重叠策略的消费者主循环

我们将乒乓调度warpgroup内重叠相结合,定义了 FA-3 的计算路径。

  • 使用 bar.syncbar.arrive 以及两个命名屏障来实现乒乓调度。
  • WGMMAX 涉及一组 wgmma.mma_async 指令,这些指令被作为一个组提交。
  • 为 GEMM1 的操作数 A 创建一个持久的、基于寄存器的张量,并在每次迭代中将在线 softmax 的降精度结果复制到其中。
消费者主循环的伪代码
消费者主循环的伪代码

低精度:FP8

FP8 张量核心(Tensor Cores)可以将 WGMMA 的吞吐量翻倍,但代价是精度的损失。

  • 相较于 A100 的 FP16,H100 的 FP8 提供了 6倍 的理论吞吐量。
A100 FP16 与 H100 FP8 计算单元对比
A100 FP16 与 H100 FP8 计算单元对比

FP8 Attention 与非相干处理 (Incoherent Processing)

为了解决 FP8 带来的精度损失问题,特别是异常值(outliers)导致的量化误差,可以采用非相干处理技术。
* 方法: 将 Q 和 K 乘以一个随机正交矩阵,以“分散”异常值。
* 原理: 对于正交矩阵 J (即 J * J^T = I),S = Q * K^T = (QJ)(KJ)^T
* 效果: 在包含 0.1% 大幅值条目(模拟异常值)的正态分布 QKV 数据上,该方法可将量化误差降低 2.6倍

下表对比了不同方法下的均方根误差(RMSE):
不同方法下的RMSE对比表格

FP8 的布局整合挑战 (1)

  • 挑战: FP8 WGMMA 要求共享内存中的操作数张量在内维度(k-major)上是内存连续的。而标准的 QKV 格式(BSHD)在头维度(head dimension)上是连续的。
  • 问题: 对于 GEMM0 (Q * K^T),K tile 的主次序是正确的。但对于 GEMM1 (P * V),V tile 需要进行转置。
  • 融合解决方案: 在生产者 warpgroup 中,于 TMA 加载之后,在核函数内部对 V 进行转置。这使用了带有自定义布局的 LDSM/STSM 指令和字节置换(byte permute)。
FP8布局挑战示意图
FP8布局挑战示意图

FP8 的布局整合挑战 (2)

  • 挑战: 我们还需要为 GEMM1 重塑得分累加器(scores accumulator)的布局。
  • 原因: FP32 累加器的布局与 FP8 操作数 A 的布局不同。
  • 下图展示了 FP32 累加器和 FP8 操作数 A 在 WGMMA 中的寄存器布局差异。
  • 注意: 可以在核函数内部使用“转置”将 V 的行置换写出,这样可以避免使用 shuffle 指令(但不能避免字节置换)来进行重塑。
FP32累加器与FP8操作数A的寄存器布局对比
FP32累加器与FP8操作数A的寄存器布局对比

FlashAttention-3 中的持久化核函数 (Persistent Kernels)

  • 动机: 对于较短的 KV 序列长度,核函数的收尾(epilogue)延迟成为总核函数时间的不可忽略部分。我们希望将收尾阶段的延迟隐藏在后续工作的开销(prologue)中。
  • 思路: 将物理 CTA(Cooperative Thread Arrays)与逻辑工作块(work tiles)解耦。启动与 SM(Streaming Multiprocessors)数量相等的固定数量的 CTA。这些 CTA 在处理多个工作块期间是持久的。

    • 利用 TMA 存储的异步性来重叠收尾和开销阶段。
    • 模仿了优化 GEMM 核函数中的相同设计。
  • 示例:

    • Seqlen = 4096, Heads = 8, Batch = 4Fix BlockM = 128so mblocks = 32
    • 总共有 32*8*4 = 1024 个工作块需要处理。
    • 在 H100 SXM5 GPU 上,FA3 核函数启动 132 个 CTA(等于 SM 的数量)。
    • 由于 [1024/132] = 8,每个 CTA 在其生命周期内将运行 7 或 8 个工作块。

CUTLASS 抽象与 Tile 调度器工作循环

CUTLASS 核函数是使用三个主要类组合构建的:

  1. CollectiveMainloop: 用于加载和 mma 计算。
  2. CollectiveEpilogue: 用于存储。
  3. TileScheduler: 用于管理工作循环。

由于这种组合式设计,将核函数更改为持久化核函数变得非常简单:只需编写一个不同的 TileScheduler 即可。同时,在加载和 mma 方法中也需要额外的屏障逻辑。

CUTLASS工作循环伪代码
CUTLASS工作循环伪代码

持久化内核(Persistent Kernels):隐藏前序/后序开销并平衡负载

在使用因果掩码(causal masking)时,工作块(work tiles)的主循环(mainloops)迭代次数不同。通过使用最长处理时间优先(Longest-processing-time-first, LPT)算法(Graham, 1969)来进行负载均衡。

这种优化将因果注意力(causal attention)的速度从 670 TFLOPS 提升到 710 TFLOPS。

Page 16:持久化内核的执行流程图
Page 16:持久化内核的执行流程图

解码推理优化

新旧方法对比

在解码(decoding)过程中,查询(query)长度很短(通常只有几个 token),而上下文(context)长度很长(例如,128k)。

Page 17:解码推理中查询、键和值的结构示意图
Page 17:解码推理中查询、键和值的结构示意图

从 FlashAttention-2 (FA-2) 开始,引入了 Flash Decoding:沿着 KV 序列长度进行拆分,以便为 GPU 提供足够的工作量来充分利用其计算能力。

GQA 打包(GQA Packing):为每个 KV 头计算多个查询头

WGMMA(Warp-Group Matrix Multiply Accumulate)块在 M 维度上的宽度是 64。对于短查询长度来说,这会造成浪费。然而,我们可以通过打包多个查询头来填满 WGMMA 块,这适用于 MQA/GQA(多查询注意力/分组查询注意力)。

Page 18:GQA 打包示意图,展示了未打包和打包的查询块
Page 18:GQA 打包示意图,展示了未打包和打包的查询块
  • FA-2 已经针对单个查询 token 的情况实现了这一点,这只是一个简单的重塑(reshape)操作。
  • FA-3 将此方法扩展到更复杂的任意查询长度的情况。
  • 这种方法也对某些受 tile 量化效应影响的计算密集型(compute-bound)场景有所裨益。

多头潜在注意力(Multi-head Latent Attention, MLA)

针对大头维度的 Warp 专职化

DeepSeek 的 MLA 具有很大的头维度(head dim),为 576 / 512。标准的拆分方法没有足够的寄存器。

Page 19:MLA 的 Warp 专职化示意图
Page 19:MLA 的 Warp 专职化示意图

为此,采用了 Warp 专职化(Warp specialization)策略:
- WG1 (工作组1):同时执行 QK 矩阵乘法和 PV 矩阵乘法。每个线程需要 160 个累加寄存器。
- WG2 (工作组2):仅执行 PV 矩阵乘法。每个线程需要 128 个累加寄存器。

MLA 性能

下图展示了 MLA 在批量大小为 128、查询头为 128 的解码速度测试(在 H100 80GB SXM5 上进行)。

Page 20:MLA解码速度对比图,比较了DeepSeek FlashMLA和FlashAttention-3
Page 20:MLA解码速度对比图,比较了DeepSeek FlashMLA和FlashAttention-3
  • 即使序列长度 seqlen_q = 1(即解码单个 token),也已经达到了计算密集型(compute-bound)的状态。

性能基准测试

BF16 性能基准测试:1.6-2.0 倍加速

下图展示了在 H100 80GB SXM5 上,头维度为 128 时的前向注意力计算速度。

Page 21:BF16 性能基准测试图,头维度为 128,分别在有无因果掩码的情况下进行对比
Page 21:BF16 性能基准测试图,头维度为 128,分别在有无因果掩码的情况下进行对比
  • 对比了 FlashAttention-3 与标准注意力(Standard attention)、FlashAttention-2、Triton 和 cuDNN。
  • 测试环境:CUDA tool kit 12.8, Triton 3.1, cuDNN 9.7。

BF16 性能基准测试:最高达到 850 TFLOPS

下图展示了在 H100 80GB SXM5 上,头维度为 256 时的前向注意力计算速度。

Page 22:BF16 性能基准测试图,头维度为 256,分别在有无因果掩码的情况下进行对比
Page 22:BF16 性能基准测试图,头维度为 256,分别在有无因果掩码的情况下进行对比
  • 在头维度增加到 256 时,FlashAttention-3 的性能最高可达到 850 TFLOPS,显著优于其他实现。
  • 测试环境:CUDA tool kit 12.8, Triton 3.1, cuDNN 9.7。

FP8 性能基准测试:最高达到 1.3 PFLOPS

下图展示了在 H100 80GB SXM5 上,头维度为 256 时的 FP8 前向注意力计算速度。

Page 23:FP8 性能基准测试图,头维度为 256,分别在有无因果掩码的情况下进行对比
Page 23:FP8 性能基准测试图,头维度为 256,分别在有无因果掩码的情况下进行对比
  • 使用 FP8 精度,FlashAttention-3 的性能最高可达到 1.3 PFLOPS。
  • 测试环境:CUDA tool kit 12.8, Triton 3.1, cuDNN 9.7。

MQA 的 BF16 解码基准测试:时间越低越好!

下图展示了在 H100 80GB PCIe 上,使用 BF16 精度、头维度为 128、MQA 16、查询序列长度为 4 的解码性能。

Page 24:MQA 的 BF16 解码性能对比图,比较了 FlashAttention-2 和 FlashAttention-3
Page 24:MQA 的 BF16 解码性能对比图,比较了 FlashAttention-2 和 FlashAttention-3
  • FlashAttention-3 在不同上下文长度和批量大小组合下的解码时间均显著低于 FlashAttention-2。

总结 - FlashAttention-3

  • 快速且精确的注意力机制,为现代硬件优化。
  • 关键算法思想:异步(asynchrony)低精度(low-precision)
  • 针对推理(inference)的优化:拆分 KV (Flash-Decoding)GQA 打包 (GQA packing)
  • 最终效果:更快的训练速度,以及能够支持更长序列的更好模型

代码:https://github.com/Dao-AILab/flash-attention