FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
文章标题:FlashAttention-3: 利用异步和低精度实现快速准确的注意力机制
作者/机构:Jay Shah∗1, Ganesh Bikshandi∗1, Ying Zhang2, Vijay Thakkar3,4, Pradeep Ramani3, and Tri Dao5,6 (1: Colfax Research, 2: Meta, 3: NVIDIA, 4: Georgia Tech, 5: Princeton University, 6: Together AI)
A1 主要贡献
本文的核心问题是,尽管FlashAttention-2通过最小化内存读写加速了GPU上的注意力计算,但在最新的H100 GPU上其利用率仅达到35%,未能充分利用新硬件的异步计算和低精度计算能力。为了解决这一问题,研究目标是重新设计FlashAttention算法,以充分挖掘Hopper架构GPU的潜力。
本文提出的FlashAttention-3包含并综合了三项主要的技术创新,以提升在新型GPU架构上的性能:
1. 生产者-消费者异步机制:通过定义一种经warp专业化的软件流水线方案,将数据的生产者和消费者拆分到不同的warp中。这利用了数据移动(通过TMA)和张量核心(Tensor Cores)计算的异步执行特性,从而增强了算法隐藏内存和指令分发延迟的能力。
2. 在异步块状GEMM下隐藏softmax计算:本文将吞吐量较低的非GEMM操作(如softmax中的浮点乘加和指数运算)与用于GEMM的异步WGMMA指令进行重叠。为此,作者重构了FlashAttention-2算法,以规避softmax和GEMM之间的某些顺序依赖。例如,在算法的2阶段版本中,当softmax在一个分数矩阵块上执行时,WGMMA可以在异步代理中计算下一个块。
3. 硬件加速的低精度GEMM:本文调整了前向传播算法,使其能够利用FP8张量核心进行GEMM计算,从而将实测的TFLOPs/s(每秒万亿次浮点运算数)几乎翻倍。这需要解决WGMMA在FP32累加器和FP8操作数矩阵内存布局方面的不同要求。作者采用了块量化(block quantization)和非相干处理(incoherent processing)技术,以减轻转向FP8精度所带来的精度损失。
A3 背景知识
2.1 多头注意力(Multi-Head Attention)
注意力机制的计算过程。给定查询Q、键K和值V矩阵(均为$R^{N_{seq} \times d_{head}}$),注意力输出O的计算公式如下:
其中,softmax是按行应用的,缩放因子$c$通常设为$1/\sqrt{d_{head}}$。为防止数值不稳定,实践中会从S中减去每行的最大值。对于多头注意力(MHA),每个头都有自己的Q, K, V投影,并且计算在多个头和批次之间并行进行。
反向传播的梯度计算。给定输出梯度dO,dQ、dK和dV的计算遵循链式法则,公式如下:
这里的dsoftmax(dP)表示将公式$d\mathbf{p} = (\text{diag}(\mathbf{p}) - \mathbf{p}\mathbf{p}^\top) d\mathbf{s}$(其中$\mathbf{p} = \text{softmax}(\mathbf{s})$)按行应用于dP。反向传播的计算同样在头和批次维度上并行。
2.2 GPU硬件特性与执行模型
内存层次结构。GPU内存被组织成一个层次结构,容量与带宽成反比。全局内存(GMEM/HBM)是所有流式多处理器(SM)均可访问的片外DRAM。数据从GMEM加载时会透明地缓存到片上的L2缓存中。每个SM包含一个由程序员管理的小型、高带宽的片上缓存,称为共享内存(SMEM)。最内层是每个SM内的寄存器文件(RMEM)。
线程层次结构。GPU的编程模型围绕线程的逻辑分组构建。从细到粗的层次结构包括:线程、线程束(warp,32个线程)、线程束组(warpgroup,4个连续的warp)、线程块(CTA)、线程块集群(Hopper架构中引入),以及网格(grid)。这两个层次结构紧密相连:同一CTA中的线程被共同调度到同一个SM上,SMEM可被该CTA内的所有线程直接寻址,而寄存器对每个线程是私有的。
异步性与warp专业化。GPU是吞吐量处理器,依赖并发和异步来隐藏内存和执行延迟。Hopper架构配备了张量内存加速器(TMA)作为专用硬件单元,用于GMEM和SMEM之间的异步内存拷贝【38, §7.29】。与Ampere等先前架构不同,Hopper的张量核心(通过warpgroup范围的WGMMA指令【40, §9.7.14】暴露)也是异步的,可以直接从共享内存获取输入。硬件对异步的支持催生了warp专业化的核函数,其中CTA的warps被划分为生产者(只发出数据移动指令)或消费者(只发出计算指令)角色。这种划分能帮助编译器生成最优的指令调度【4, CudaDMA: Optimizing GPU Memory Bandwidth via Warp Specialization, 2011, SC】。此外,Hopper支持通过setmaxnreg
指令【40, §9.7.17.1】在warpgroups之间动态重新分配寄存器,使执行矩阵乘法(MMA)的warps能获得比仅发出TMA指令的warps更多的寄存器。
低精度数值格式。现代GPU拥有专门的硬件单元来加速低精度计算。例如,WGMMA指令可以调用Hopper上的FP8张量核心,与FP16或BF16相比,每个SM的吞吐量可提升2倍。然而,正确调用FP8 WGMMA需要理解其操作数的布局约束。对于一个$M \times K$矩阵A和$N \times K$矩阵B的GEMM调用(计算$A \times B^\top$),如果操作数在外部M或N维度上是连续的,则称为mn-major;如果在内部K维度上是连续的,则称为k-major。FP16 WGMMA接受SMEM中mn-major和k-major两种格式的输入,但FP8 WGMMA只支持k-major格式。这在需要融合背靠背GEMM的注意力计算中,会因FP32累加器和FP8操作数布局的冲突而构成障碍。这些布局限制要求对FP8算法的设计进行特定修改。
2.3 标准注意力与FlashAttention
两种注意力实现方式的对比。遵循Dao等人【17, FlashAttention: Fast and memory-efficient exact attention with IO-awareness, 2022, NeurIPS】的定义,标准注意力指在GPU上实现时,会将中间矩阵S和P物化到HBM中的方法。而FlashAttention的核心思想是利用softmax归约的局部性,避免这些昂贵的中间读写,并将所有注意力操作融合成一个单一的GPU核函数。局部softmax的计算对应于算法1中消费者主循环的第18-19行,以及对O块的重新缩放。
A2 方法细节
本节详细描述FlashAttention-3算法,重点关注前向传播,反向传播算法在附录B.1中描述。
3.1 通过warp专业化和乒乓调度的生产者-消费者异步机制
Warp专业化。与FlashAttention-2类似,FlashAttention-3的前向传播在批次大小、头数和查询序列长度上是易于并行的。因此,我们仅需从一个CTA(线程块)的视角来描述算法,该CTA处理查询矩阵的一个瓦片(tile)$Q_i$以计算对应的输出瓦片$O_i$。为了简化描述,我们首先给出一个带有循环SMEM缓冲区的warp专业化方案,该方案尚未加入GEMM-softmax重叠。算法流程(Algorithm 1)如下:
乒乓调度(Pingpong scheduling)。WGMMA和TMA的异步特性,结合warp专业化,为重叠softmax计算和GEMM计算提供了机会。其动机在于,非矩阵乘法操作(如softmax中的指数运算)在现代硬件上的吞吐量远低于矩阵乘法。例如,H100 SXM5 GPU的FP16矩阵乘法吞吐量为989 TFLOPS,而特殊函数(如指数)仅为3.9 TFLOPS。对于头维度为128的FP16注意力前向传播,虽然矩阵乘法的FLOPS是指数运算的512倍,但指数运算的吞吐量低256倍,导致指数运算可能占用高达50%的计算周期。为了解决这个问题,我们利用同步屏障(bar.sync
指令)来强制调度warpgroup 1的GEMM操作(一次迭代的PV矩阵乘和下一次迭代的$QK^\top$矩阵乘)在warpgroup 2的GEMM操作之前。这样一来,warpgroup 1的softmax计算就会在warpgroup 2执行GEMM时被调度。随后角色互换,形成“乒乓”调度。这种方法能显著提升性能(例如,对于FP16前向传播,头维度128,序列长度8192,性能从570 TFLOPS提升到620-640 TFLOPS)。
注意力变体。对于多查询注意力(MQA)【51, Fast transformer decoding: One write-head is all you need, 2019, arXiv】和分组查询注意力(GQA)【3, Gqa: Training generalized multi-query transformer models from multi-head checkpoints, 2023, arXiv】,我们遵循FlashAttention-2的方法,通过调整张量索引来避免在HBM中复制K和V。
3.2 Warpgroup内部重叠GEMM和softmax
两阶段GEMM-softmax流水线。即使在单个warpgroup内部,也可以重叠softmax和GEMM的部分指令。注意力算法内循环中的操作存在顺序依赖,例如softmax依赖于第一个GEMM($QK^\top$)的输出,而第二个GEMM($PV$)又依赖于softmax的结果。为了打破这种依赖,我们通过在寄存器中增加额外的缓冲区,在迭代之间构建流水线。我们提出了一个两阶段的GEMM-softmax流水线算法(Algorithm 2),其核心思想是在主循环(第8-16行)中,将第j
次迭代的第二个WGMMA操作(第11行,计算$O_i = O_i + \tilde{P}_{\text{cur}} V_{j-1}$)与第j+1
次迭代的softmax操作(第13行,基于$S_{\text{next}}$计算新的统计量)进行重叠。该算法取代了Algorithm 1中的消费者路径,构成了FP16精度的完整FlashAttention-3算法。
实践中的考虑因素:
- 编译器重排:伪代码表示了理想的执行顺序,但编译器(NVCC)可能会为了优化而重排指令。对SASS代码的分析表明,编译器确实按预期生成了重叠的代码(详见附录B.2)。
- 寄存器压力:为了维持高性能,应尽量减少寄存器溢出。两阶段流水线需要额外的寄存器来存储中间结果(如一个额外的$S_{\text{next}}$),这增加了寄存器需求,可能与使用更大块大小(同样需要大量寄存器)的优化策略相冲突。实践中需要根据性能分析结果进行权衡。
- 三阶段流水线:我们还提出了一个三阶段流水线的变体,旨在进一步重叠第二个WGMMA和softmax。虽然理论上能获得更高的张量核心利用率,但它需要更多寄存器,使得在瓦片大小和流水线深度之间的权衡更加困难。附录B.3提供了详细描述和评估结果。
3.3 使用FP8实现低精度计算
效率:布局转换。在FP8精度下计算FlashAttention-3的前向传播,在布局一致性方面带来了FP16中未遇到的挑战。
- 第一个挑战:V矩阵的布局。输入张量Q、K、V通常在头维度上是连续的,但为了满足第二个GEMM的FP8 WGMMA的k-major约束,加载到SMEM中的V瓦片需要在序列长度维度上是连续的。我们有两种选择:(1)在GMEM中对V进行预转置,或(2)在核函数内部,将V的瓦片加载到SMEM后再进行转置。选项(1a)融合到前置操作中难以集成到标准库,而(1b)独立的转置核在内存受限的推理场景中开销过大。因此,我们选择在核函数内进行转置(选项2)。我们利用LDSM/STSM指令,这些指令能让一个warp集体地在SMEM和RMEM之间高效地加载/存储数据,并且在内存拷贝时能够进行布局转置。此转置操作可以安排在涉及前一个V瓦片和当前K瓦片的两个WGMMA的计算阴影下执行。
- 第二个挑战:累加器与操作数布局不匹配。与FP16不同,FP8 WGMMA的FP32累加器在寄存器中的内存布局与其操作数A的布局不同。图3和图4展示了这两种布局的片段。我们通过使用字节置换指令,将第一个WGMMA的累加器(即P矩阵的瓦片)转换为适合第二个WGMMA的格式。这个操作在逻辑上置换了P瓦片的列。为了计算出正确的结果,我们可以相应地安排核函数内的转置操作,输出一个行序与之匹配的V瓦片。
精度:块量化与非相干处理。FP8(e4m3)格式仅用3位存储尾数,4位存储指数,导致比FP16/BF16更高的数值误差,尤其是在处理大模型中常见的异常值时。为降低FP8注意力的数值误差,我们采用了两种技术:
1. 块量化(Block quantization):我们为每个数据块保留一个缩放标量,即将Q、K、V张量划分为$B_r \times d$或$B_c \times d$大小的块并分别量化。这种量化可以无开销地融合到注意力之前的一个内存带宽受限的操作中(如旋转位置编码)。FlashAttention-3算法自然地在块上操作,因此可以无计算成本地对每个S块进行缩放以适应块量化。
2. 非相干处理(Incoherent processing):为了平均异常值,我们在量化到FP8之前,将Q和K乘以一个随机正交矩阵M。由于M是正交的($MM^\top = I$),$(QM)(KM)^\top = QK^\top$,因此这个操作不改变注意力输出。这起到了“摊平”异常值的作用,因为QM或KM的每个元素都是Q或K元素的随机和,从而减少了量化误差。实践中,我们遵循Chee等人【9, Quip: 2-bit quantization of large language models with guarantees, 2024, NeurIPS】和Tseng等人【58, Quip#: Even better llm quantization with hadamard incoherence and lattice codebooks, 2024, arXiv】的方法,选择M为随机对角矩阵(元素为±1)和哈达玛矩阵的乘积,这使得矩阵乘法复杂度从$O(d^2)$降至$O(d \log d)$,并且同样可以无额外计算开销地融合到旋转位置编码中。
A4 实验环境
- 硬件配置:实验在一台配备H100 80GB SXM5 GPU(700W)的服务器上进行。为了减少可变性,GPU时钟速度固定为1830MHz。
- 软件配置:
- CUDA 12.3
- cuDNN 9.1.1.17
- CUTLASS 3.5
- FlashAttention 2.5.8
- Triton nightly 3.0.0.post20240424212437
- PyTorch 2.3.0
- 模型/数据集配置:
- 实验为微基准测试,不涉及特定数据集。
- 测试参数:序列长度从512到16k变化,批大小动态设置以使总token数为16k。隐藏维度设为2048,头维度为64、128或256。
- 性能指标FLOPs的计算方式:前向传播为
4 * seqlen^2 * head_dim * num_heads
,带因果掩码时减半。反向传播FLOPs为前向的2.5倍。
A4 实验结果
基准测试
-
FP16/BF16性能(图5和图6):
- 前向传播:FlashAttention-3比FlashAttention-2快1.5-2.0倍,比Triton实现的FlashAttention-2快1.5倍。最高性能达到740 TFLOPs/s,是H100 GPU理论峰值的75%。对于中长序列(1k及以上),FlashAttention-3甚至超过了NVIDIA为H100优化的cuDNN库。
- 反向传播:FlashAttention-3比FlashAttention-2快1.5-1.75倍。相比标准注意力实现,速度提升可达3-16倍。
图5:H100 GPU上的注意力前向传播速度(FP16/BF16)
图6:H100 GPU上的注意力反向传播速度(FP16/BF16)
-
FP8性能(图7):
- 前向传播:FlashAttention-3在FP8精度下性能接近1.2 PFLOPs/s。
图7:H100 GPU上的注意力前向传播速度(FP8)
- 前向传播:FlashAttention-3在FP8精度下性能接近1.2 PFLOPs/s。
消融研究
- 2阶段流水线的效果(表2):实验证实了算法改进的有效性。在固定参数下,基线版本(仅使用Hopper指令)性能为570 TFLOPS。引入warp专业化后提升至621 TFLOPS。再加入2阶段WGMMA-softmax流水线后,性能进一步提升至661 TFLOPS,证明了异步和重叠设计带来了显著的加速。
表2:流水线消融实验测量结果
数值误差验证
- 实验设置:为了模拟大语言模型中的异常值特征,Q, K, V的生成分布中加入了0.1%的较大扰动项。所有方法的输出与FP64参考实现进行比较,计算均方根误差(RMSE)。
- 结果分析(表3):
- FP16精度:FlashAttention-2和FlashAttention-3的RMSE比标准实现低1.7倍,因为它们的中间结果(如softmax)保持在FP32精度。
- FP8精度:得益于块量化和非相干处理,FP8 FlashAttention-3的精度比使用每张量缩放的基线FP8注意力实现高2.6倍。
表3:FP16和FP8(e4m3)的数值误差比较
A5 结论
通过FlashAttention-3,本文展示了新的编程技术、硬件特性(如异步和低精度)能够显著提升注意力机制的效率和准确性。与FlashAttention-2相比,FlashAttention-3将注意力计算速度提升了1.5-2.0倍,并将FP8数值误差降低了2.6倍。
局限性与未来工作:
- 优化LLM推理场景。
- 为FP8核函数集成持久化核(persistent kernel)设计。
- 深入理解低精度注意力在大规模训练中的影响。
尽管本文聚焦于Hopper GPU,但所开发的技术有望应用于其他硬件加速器。一个更快、更准确的注意力原语将为长上下文任务解锁新的应用。
A6 附录
B.1 反向传播中的异步warp专业化
算法描述。与前向传播类似,反向传播也使用warp专业化来处理异步性。除了生产者和消费者角色外,还增加了一个dQ写入者(dQ writer)角色。这是因为需要将每个线程块计算出的dQ值累加到全局的dQ中,这会引入内存争用。设置一个专门的warp来处理dQ的原子加操作,可以避免阻塞其他warp执行后续的矩阵乘法计算。Algorithm 3展示了带有warp专业化的反向传播过程。其主要流程为:预处理计算D = rowsum(dO ◦ O)
;生产者warp负责异步加载Q、dO的块;消费者warp负责核心计算,包括多个GEMM来计算dP, dS, dV, dK以及局部的dQ,并将dQ写入smem;dQ写入者warp则负责将smem中的局部dQ原子地累加到全局内存的dQ中。
B.2 2阶段流水线SASS分析
代码分析。
- Softmax相关的指令被重排到循环体的最开始,甚至在第一个WGMMA之前。
- 第一个WGMMA的指令与softmax和FP32->FP16数据类型转换的指令交错执行,表明WGMMA和非WGMMA操作确实在并行执行。
- 指数计算、行求和、O的重缩放以及FP32->FP16转换指令本身也是交错的。
- 第二个WGMMA没有与其他指令重叠,这符合算法设计预期。
总体而言,SASS代码证实了2阶段流水线的设想按预期工作。
B.3 3阶段流水线算法
算法设计与问题。我们尝试了一个3阶段流水线算法,旨在并行化来自第j+2
次迭代的第一个WGMMA、来自第j+1
次迭代的softmax以及来自第j
次迭代的第二个WGMMA。Algorithm 4描述了这个算法。然而,实验表明其性能劣于2阶段流水线,原因如下:
- 重叠效果不佳:SASS代码显示,编译器只将第一个WGMMA与softmax重叠,而第二个WGMMA没有被重叠。编译器选择这样重排指令的原因尚不明确。
- 寄存器压力:该算法需要比2阶段算法更多的寄存器(需要额外存储一个$\tilde{P}_i$和$m_{rescale\_i}$),这导致必须选择更小的块大小,从而限制了性能。
C.2 FP8注意力完整结果
基准测试设置。FP8基准测试使用的序列长度为:512, 1024, 2048, 4224, 8448, 16896。当序列长度≥4k时,我们使其也能被132(H100 SXM5的SM数量)整除,以避免波前量化(wave quantization)效应。
💬 评论讨论
欢迎在这里分享您的想法和见解!