文章标题:FlashAttention:具有IO感知能力、快速且内存高效的精确注意力机制
作者/机构:Tri Dao†, Daniel Y. Fu†, Stefano Ermon†, Atri Rudra‡, and Christopher Ré† (†斯坦福大学计算机科学系, ‡纽约州立大学布法罗分校计算机科学与工程系)

A1 主要贡献

本文旨在解决Transformer模型在处理长序列时速度慢、内存消耗大的核心问题,该问题源于自注意力机制的时间和内存复杂度与序列长度成二次方关系。许多现有的近似注意力方法虽然降低了计算复杂度,但由于忽略了内存访问(IO)开销,并未实现实际的运行时间加速。

本文认为,一个被忽略的关键原则是使注意力算法具备IO感知能力 (IO-aware),即精确地考虑GPU不同内存层级(如高速的片上SRAM和相对较慢的HBM)之间的数据读写。基于此,本文提出了一种名为 FlashAttention 的新算法,它是一种IO感知的精确注意力算法,其核心贡献和创新点如下:

  1. 提出IO感知的精确注意力算法FlashAttention:该算法通过分块(tiling)技术,将输入Q、K、V矩阵分割成块,从HBM加载到SRAM中进行计算,从而减少了GPU高带宽内存(HBM)和片上SRAM之间的内存读写次数。这避免了在HBM中实例化和存取巨大的 N×N 注意力矩阵。

  2. 实现技术

    • 分块计算:通过分块并在输入块上进行多次传递,增量式地执行Softmax计算,而无需一次性访问整个输入。
    • 重计算:在反向传播中,不从HBM读取为反向传播存储的巨大中间注意力矩阵,而是利用前向传播中保存的Softmax归一化因子,在SRAM上快速重新计算注意力矩阵。这虽然增加了浮点运算次数(FLOPs),但由于大幅减少了HBM访问,反而加速了反向传播。
  3. 理论分析与最优性证明

    • 本文分析了FlashAttention的IO复杂度,证明其需要 $O(N^2d^2M^{-1})$ 次HBM访问(其中N为序列长度,d为头维度,M为SRAM大小),而标准注意力需要 $Ω(Nd + N^2)$ 次。对于典型硬件配置,FlashAttention的HBM访问次数显著减少。
    • 本文还提供了一个IO复杂度的下界,证明在所有SRAM大小范围内,没有精确注意力算法能渐进地超越FlashAttention的HBM访问次数。
  4. 扩展至块稀疏注意力:将FlashAttention扩展为块稀疏FlashAttention,这是一种比现有任何近似注意力方法都更快的近似注意力算法,其IO复杂度比FlashAttention更优,改进程度与稀疏度成正比。

  5. 显著的实践效果

    • 更快的模型训练:在BERT-large(序列长度512)上比MLPerf 1.1的训练速度记录快15%;在GPT-2(序列长度1K)上提速3倍;在长距离竞技场(LRA)基准测试(序列长度1K-4K)上提速2.4倍。
    • 更高质量的模型:通过支持更长的上下文,提升了模型质量。在GPT-2上困惑度降低0.7,在长文档分类任务上性能提升6.4个点。
    • 实现全新能力:使Transformer首次在Path-X(序列长度16K)和Path-256(序列长度64K)挑战上取得超越随机猜测的性能。

A3 背景知识

2.1 硬件性能

GPU内存层级。GPU的内存层级(图1左)由多种不同大小和速度的内存组成,其中较小的内存速度更快。例如,A100 GPU拥有40-80GB的高带宽内存(HBM),带宽为1.5-2.0TB/s,同时每个流式多处理器(共108个)拥有192KB的片上SRAM,其带宽估计约为19TB/s【44, Zhe Jia and Peter Van Sandt, Dissecting the Ampere GPU architecture via microbenchmarking, 2021, GPU Technology Conference】【45, Zhe Jia et al., Dissecting the nvidia Volta GPU architecture via microbenchmarking, 2018, arXiv】。片上SRAM比HBM快一个数量级,但容量小好几个数量级。由于计算速度的增长超过了内存速度【61, NVIDIA, Nvidia Tesla V100 GPU architecture, 2017】【62, NVIDIA, Nvidia A100 tensor core GPU architecture, 2020】【63, NVIDIA, Nvidia H100 tensor core GPU architecture, 2022】,越来越多的操作受限于内存(HBM)访问。因此,利用高速SRAM变得越来越重要。

执行模型。GPU拥有海量线程来执行一个操作(称为核函数,kernel)。每个核函数从HBM加载输入到寄存器和SRAM,进行计算,然后将输出写回HBM。

性能特征。根据计算和内存访问的平衡,操作可分为计算密集型或内存密集型。这通常通过算术强度(arithmetic intensity)【85, Samuel Williams et al., Roofline: an insightful visual performance model for multicore architectures, 2009, Communications of the ACM】,即每字节内存访问所对应的算术操作次数来衡量。1. 计算密集型:操作时间由算术操作的数量决定,而访问HBM的时间要小得多。典型例子是内积维度大的矩阵乘法和通道数多的卷积。2. 内存密集型:操作时间由内存访问次数决定,而计算时间要小得多。例子包括大多数其他操作:逐元素操作(如激活、dropout)和归约操作(如求和、softmax、批归一化、层归一化)。

核函数融合(Kernel fusion)。加速内存密集型操作最常用的方法是核函数融合:如果多个操作应用于同一输入,可以将输入从HBM加载一次,而不是为每个操作加载多次。编译器可以自动融合许多逐元素操作【53, Mingzhen Li et al., The deep learning compiler: A comprehensive survey, 2020, IEEE Transactions on Parallel and Distributed Systems】【65, Adam Paszke et al., Pytorch: An imperative style, highperformance deep learning library, 2019, Advances in neural information processing systems】【75, Amit Sabne, XLA: Compiling machine learning for peak performance, 2020】。然而,在模型训练的背景下,中间值仍需写入HBM以便为反向传播保存,这降低了朴素核函数融合的有效性。

2.2 标准注意力实现

注意力计算公式。给定输入序列 Q, K, V ∈ RN×d,其中 N 是序列长度,d 是头维度,我们希望计算注意力输出 O ∈ RN×d


其中softmax是逐行应用的。

标准实现的问题。标准的注意力实现会将矩阵S和P物化到HBM中,这需要 O(N2) 的内存。通常 N >> d(例如,对于GPT2,N=1024,d=64)。我们在算法0中描述了标准的注意力实现。由于其中一些或大部分操作是内存密集型的(例如softmax),大量的内存访问导致了较慢的运行时间。这个问题因应用于注意力矩阵的其他逐元素操作(如应用于S的掩码或应用于P的dropout)而加剧。因此,已有许多尝试融合多个逐元素操作,例如将掩码与softmax融合【77, Mohammad Shoeybi et al., Megatron-LM: Training multi-billion parameter language models using model parallelism, 2019, arXiv】。

算法0:标准注意力实现
- 输入: 矩阵 Q, K, V ∈ RN×d 位于HBM中。
- 步骤:
1. 从HBM分块加载Q, K,计算 S = QKT,将S写入HBM。
2. 从HBM读取S,计算 P = softmax(S),将P写入HBM。
3. 从HBM分块加载P和V,计算 O = PV,将O写入HBM。
- 返回: O。

在3.2节中,我们将展示标准注意力实现的HBM访问次数与序列长度N成二次方关系。我们还将比较标准注意力与我们的方法(FlashAttention)的FLOPs数量和HBM访问次数。

A2 方法细节

3.1 一种利用分块和重计算的高效注意力算法

核心目标。给定位于HBM的输入Q, K, V ∈ RN×d,我们的目标是计算注意力输出O ∈ RN×d并将其写回HBM,同时减少HBM的访问量(使其在N上为亚二次方)。我们应用了两种成熟的技术(分块和重计算)来克服在亚二次方HBM访问下计算精确注意力的技术挑战,详见算法1。核心思想是将输入Q, K, V分割成块,从慢速的HBM加载到快速的SRAM,然后计算这些块的注意力输出。通过在累加前用正确的归一化因子缩放每个块的输出,我们最终能得到正确的结果。


图1:左图:FlashAttention 使用分块技术来避免在(相对)慢速的 GPU HBM 上物化庞大的 N × N 注意力矩阵(虚线框)。在外层循环中(红色箭头),FlashAttention 遍历 K 和 V 矩阵的块,并将它们加载到快速的片上 SRAM。在每个块中,FlashAttention 遍历 Q 矩阵的块(蓝色箭头),将它们加载到 SRAM,并将注意力计算的输出写回 HBM。右图:在 GPT-2 上,相较于 PyTorch 的注意力实现,FlashAttention 的加速效果。FlashAttention 不会将庞大的 N × N 注意力矩阵读写到 HBM,从而在注意力计算上实现了 7.6 倍的加速。

分块(Tiling)。我们按块计算注意力。由于Softmax操作耦合了K的列,我们采用带缩放的大型Softmax分解方法【51, Nikita Kitaev et al., Reformer: The efficient transformer, 2020, The International Conference on Machine Learning (ICML)】【60, Maxim Milakov and Natalia Gimelshein, Online normalizer calculation for softmax, 2018, arXiv】【66, Markus N Rabe and Charles Staats, Self-attention does not need O(n^2) memory, 2021, arXiv】。为了数值稳定性,向量 $x \in R^B$ 的softmax计算如下:


对于向量 $x^{(1)}, x^{(2)} \in R^B$,我们可以将拼接后向量 $x = [x^{(1)} x^{(2)}] \in R^{2B}$ 的softmax分解为:

因此,如果我们跟踪一些额外的统计量($m(x), \ell(x)$),我们就可以逐块计算softmax。我们因此将输入Q, K, V分割成块(算法1第3行),计算softmax值以及额外的统计量(算法1第10行),并合并结果(算法1第12行)。

重计算(Recomputation)。我们的目标之一是不为反向传播存储 $O(N^2)$ 的中间值。反向传播通常需要矩阵S, P ∈ RN×d来计算关于Q, K, V的梯度。然而,通过存储输出O和softmax归一化统计量(m, ℓ),我们可以很方便地在反向传播过程中,从SRAM中的Q, K, V块重新计算注意力矩阵S和P。这可以看作是一种选择性梯度检查点技术【10, Tianqi Chen et al., Training deep nets with sublinear memory cost, 2016, arXiv】【34, Andreas Griewank and Andrea Walther, Evaluating derivatives: principles and techniques of algorithmic differentiation, 2008, SIAM】。虽然梯度检查点技术被提出用于减少所需的最大内存量【66, Markus N Rabe and Charles Staats, Self-attention does not need O(n^2) memory, 2021, arXiv】,但据我们所知,所有实现都不得不用速度换取内存。相比之下,我们的重计算方法即使有更多的FLOPs,由于减少了HBM访问,也能加速反向传播(图2)。完整的反向传播描述在附录B中。

实现细节:核函数融合(Kernel fusion)。分块使我们能够在一个CUDA核函数中实现我们的算法:从HBM加载输入,执行所有的计算步骤(矩阵乘法、softmax、可选的掩码和dropout、矩阵乘法),然后将结果写回HBM(掩码和dropout的细节在附录B)。这避免了从HBM重复读写输入和输出。

算法1:FlashAttention

定理1。算法1返回 $O = \text{softmax}(QK^T)V$,其浮点运算次数(FLOPs)为 $O(N^2d)$,并且除了输入和输出之外,需要 $O(N)$ 的额外内存。

3.2 分析:FlashAttention的IO复杂度

IO复杂度分析。我们分析了FlashAttention的IO复杂度,结果显示它相比标准注意力显著减少了HBM访问。我们还提供了一个下界,证明没有精确注意力算法能在所有SRAM大小上渐进地改进HBM访问次数。证明见附录C。

定理2。设N为序列长度,d为头维度,M为SRAM大小,且 $d \le M \le Nd$。标准注意力(算法0)需要 $Θ(Nd + N^2)$ 次HBM访问,而FlashAttention(算法1)需要 $Θ(N^2d^2M^{-1})$ 次HBM访问。对于典型的d(64-128)和M(约100KB),$d^2$ 远小于M,因此FlashAttention需要的HBM访问次数比标准实现少很多倍。这带来了更快的执行速度和更低的内存占用,我们将在第4.3节中验证这一点。

证明思路。证明的主要思想是,给定大小为M的SRAM,我们可以加载大小为 $Θ(M)$ 的K和V块(算法1第6行)。对于每个K和V块,我们遍历所有Q的块(算法1第8行)来计算中间值,这导致对Q进行 $Θ(NdM^{-1})$ 次遍历。每次遍历加载 $Θ(Nd)$ 个元素,总共产生 $Θ(N^2d^2M^{-1})$ 次HBM访问。我们同样证明了标准注意力的反向传播需要 $Θ(Nd + N^2)$ 次HBM访问,而FlashAttention的反向传播需要 $Θ(N^2d^2M^{-1})$ 次HBM访问(附录B)。

IO复杂度下界。我们证明了一个下界:在计算精确注意力时,对于所有SRAM大小M的值,不可能在HBM访问次数上实现渐进式改进。

命题3。设N为序列长度,d为头维度,M为SRAM大小,且 $d \le M \le Nd$。不存在一个算法能够在 $[d, Nd]$ 范围内的所有M值上,以 $o(N^2d^2M^{-1})$ 的HBM访问次数计算精确注意力。证明依赖于这样一个事实:当 $M = Θ(Nd)$ 时,任何算法都必须执行 $Ω(N^2d^2M^{-1}) = Ω(Nd)$ 次HBM访问。这种在M的子区间上的下界在流式算法文献中很常见【88, David P Woodruff, Optimal space lower bounds for all frequency moments, 2004, SODA】。

实验验证。我们验证了HBM访问次数是决定注意力运行时的主要因素。在图2(左)中,我们看到尽管FlashAttention的FLOPs计数高于标准注意力(由于反向传播中的重计算),但它的HBM访问次数少得多,从而导致更快的运行时间。在图2(中),我们改变FlashAttention的块大小 $B_c$,这导致了不同数量的HBM访问,并测量了前向传播的运行时间。随着块大小增加,HBM访问次数减少(因为我们对输入的遍历次数减少),运行时间也随之减少。当块大小足够大(超过256)时,运行时间则受到其他因素(如算术运算)的瓶颈。此外,更大的块大小将无法装入有限的SRAM中。


图2:左图:在 A100 GPU 上,GPT-2 medium(序列长度 1024,头维度 64,16 个头,批量大小 64)的标准注意力和 FlashAttention 的前向+后向运行时间。HBM 访问是影响运行时的主要因素。中图:在 A100 GPU 上,FlashAttention 的前向运行时间(序列长度 1024,头维度 64,16 个头,批量大小 64)。更少的 HBM 访问导致更快的运行时间,直到某个点为止。右图:块稀疏 FlashAttention 的运行时间(序列长度 4K)比 FlashAttention 更快,其速度提升与稀疏度成正比。

3.3 扩展:块稀疏FlashAttention

块稀疏FlashAttention。我们将FlashAttention扩展到近似注意力:我们提出了块稀疏FlashAttention,其IO复杂度比FlashAttention小,其减小因子与稀疏度成正比。给定输入Q, K, V ∈ RN×d和一个掩码矩阵 $\tilde{M} \in \{0, 1\}^{N \times N}$,我们希望计算:


其中 $(S \odot \tilde{M})_{ij} = S_{ij}$ 如果 $\tilde{M}_{ij} = 1$,否则为 $-\infty$。我们要求 $\tilde{M}$ 具有块形式:对于某个块大小 $B_r, B_c$,对所有的 $i, j$,$\tilde{M}_{i,j} = M_{\lfloor i/B_r \rfloor, \lfloor j/B_c \rfloor}$,其中 $M \in \{0, 1\}^{N/B_r \times N/B_c}$。给定一个预定义的块稀疏掩码 $M \in \{0, 1\}^{N/B_r \times N/B_c}$,我们可以轻松地调整算法1,只计算注意力矩阵的非零块。该算法与算法1相同,只是跳过了零块。我们在附录B的算法5中重现了该算法的描述。

IO复杂度分析。我们还分析了块稀疏FlashAttention的IO复杂度。

命题4。设N为序列长度,d为头维度,M为SRAM大小,且 $d \le M \le Nd$。块稀疏FlashAttention(算法5)需要 $Θ(Nd + N^2d^2M^{-1}s)$ 次HBM访问,其中s是块稀疏掩码中非零块的比例。我们看到,应用块稀疏性对IO复杂度中较大的项带来了与稀疏度成正比的直接改进。对于大的序列长度N,s通常设置为 $N^{-1/2}$【11, Rewon Child et al., Generating long sequences with sparse transformers, 2019, arXiv】或 $N^{-1}\log N$【3, Iz Beltagy et al., Longformer: The long-document transformer, 2020, arXiv】【17, Tri Dao et al., Pixelated butterfly: Simple and efficient sparse training for neural network models, 2022, ICLR】【92, Manzil Zaheer et al., Big bird: Transformers for longer sequences, 2020, Advances in Neural Information Processing Systems】,这导致了 $Θ(N\sqrt{N})$ 或 $Θ(N \log N)$ 的IO复杂度。对于下游实验,我们使用了固定的蝶形稀疏模式(fixed butterfly sparsity pattern)【17, Tri Dao et al., Pixelated butterfly: Simple and efficient sparse training for neural network models, 2022, ICLR】,该模式已被证明能够近似任意稀疏模式【16, Tri Dao et al., Kaleidoscope: An efficient, learnable representation for all structured linear maps, 2020, ICLR】。

实验验证。在图2(右)中,我们验证了随着稀疏度的增加,块稀疏FlashAttention的运行时间成比例地改善。在LRA基准测试中,块稀疏FlashAttention实现了2.8倍的加速,同时性能与标准注意力相当(第4节)。

A4 实验环境

  • 数据集
    • BERT: Wikipedia数据集,用于预训练,遵循MLPerf 1.1的设定。
    • GPT-2: OpenWebtext数据集,用于语言建模。
    • LRA (Long-range Arena): 一个包含多个任务的基准测试,序列长度从1024到4096不等,用于评估长序列处理能力。
    • 长文档分类: MIMIC-III (重症监护病房出院小结) 和 ECtHR (欧洲人权法院案例)。
    • 长程挑战: Path-X (序列长度16K) 和 Path-256 (序列长度64K),用于测试模型处理超长上下文的能力。
  • 模型架构
    • BERT-large: 遵循【22, Jacob Devlin et al., BERT: Pre-training of deep bidirectional transformers for language understanding, 2019】的配置。
    • GPT-2: small和medium两种尺寸,遵循【67, Alec Radford et al., Language models are unsupervised multitask learners, 2019, OpenAI blog】的配置。
    • RoBERTa: 作为长文档分类的预训练模型骨干。
  • 硬件配置
    • 训练: 主要在8xA100-80GB或8xA100-40GB GPU节点上进行。
    • 基准测试: 在单张A100-40GB GPU上进行,同时也在RTX 3090和T4 GPU上进行了性能分析。
  • 软件配置
    • 实现: FlashAttention使用CUDA实现。
    • 框架与库: PyTorch, HuggingFace Transformers, Megatron-LM, Apex AMP (用于混合精度训练)。
    • 基线: 与PyTorch、Megatron-LM中的标准注意力实现,以及Reformer、Longformer、BigBird等多种近似/稀疏注意力方法的公开实现进行比较。

A4 实验结果

更快的模型训练

  • BERT:在BERT-large的训练中,FlashAttention比Nvidia在MLPerf 1.1中创造的训练速度记录快15%。在8xA100 GPU上,达到72.0%的目标准确率平均耗时17.4分钟,而Nvidia的记录为20.4分钟(见表1)。
  • GPT-2:与广泛使用的HuggingFace和Megatron-LM实现相比,FlashAttention在训练GPT-2上实现了显著加速。相比HuggingFace,端到端速度提升高达3倍;相比Megatron-LM,速度提升高达1.7倍(见表2)。同时,FlashAttention达到了与基线相同的困惑度,数值稳定性得到保证。
  • 长距离竞技场 (LRA):在LRA基准测试中,使用FlashAttention的普通Transformer相比标准注意力实现平均提速2.4倍。而块稀疏FlashAttention比所有测试过的近似注意力方法都要快(见表3)。

表1:BERT-large的训练时间
表1:BERT-large的训练时间

表2:GPT-2 small和medium的训练时间
表2:GPT-2 small和medium的训练时间

表3:LRA基准测试上的性能
表3:LRA基准测试上的性能

更高质量的模型

  • 长上下文语言建模:FlashAttention使得GPT-2的上下文长度能扩展4倍(至4K),同时训练速度仍比使用1K上下文的Megatron-LM实现快30%。更长的上下文带来了0.7的困惑度提升(见表4)。
  • 长文档分类:在MIMIC-III和ECtHR数据集上,增加序列长度显著提升了性能。在MIMIC-III上,使用16K序列长度比512长度的micro F1分数高4.3点;在ECtHR上,8K长度比512长度高8.5点(见表5)。
  • Path-X和Path-256:FlashAttention使Transformer首次在Path-X(序列长度16K)任务上取得了61.4%的准确率,远超随机猜测。块稀疏FlashAttention进一步将序列长度扩展到64K,在Path-256任务上取得了63.1%的准确率,这也是首次有序列模型在该任务上取得超越随机猜测的性能(见表6)。

表4:GPT-2 small长上下文训练效果
表4:GPT-2 small长上下文训练效果

表5:长文档分类性能
表5:长文档分类性能

表6:Path-X和Path-256上的性能
表6:Path-X和Path-256上的性能

注意力基准测试

  • 运行时间:FlashAttention比标准注意力实现快达3倍(序列长度高达2K)。对于更长的序列(超过512-1024),一些近似注意力方法开始变得更快。然而,块稀疏FlashAttention在所有序列长度上都比所有已知的精确、稀疏和近似注意力实现更快(见图3左)。
  • 内存占用:FlashAttention和块稀疏FlashAttention的内存占用随序列长度线性增长。它比标准注意力实现内存效率高20倍,并且也优于大多数近似注意力基线。在所有方法中,除了Linformer,其他方法在序列长度达到64K之前都在A100 GPU上耗尽内存,而FlashAttention的内存效率仍然是Linformer的两倍(见图3右)。


图3:左图:前向传播+反向传播的运行时间。右图:注意力内存使用情况。

A5 结论

本文提出了FlashAttention,一种基于IO感知原则的快速且内存高效的精确注意力算法。通过分块和重计算技术,FlashAttention显著减少了GPU HBM和SRAM之间的内存读写次数,从而在不牺牲模型质量的前提下,实现了端到端的训练加速,并使得Transformer能够处理更长的序列。

局限性
1. CUDA编程:当前方法需要为每种新的注意力实现编写专门的CUDA核函数,这需要大量的底层编程和工程努力,并且可能不具备跨GPU架构的可移植性。未来需要一种能将高级语言(如PyTorch)编写的注意力算法编译为IO感知CUDA实现的方法,类似于图像处理领域的Halide【70, Jonathan Ragan-Kelley et al., Halide: a language and compiler for optimizing parallelism, locality, and recomputation in image processing pipelines, 2013, Acm Sigplan Notices】。

未来方向
1. IO感知的深度学习:IO感知的思想可以扩展到注意力之外。深度网络中的每一层都会与GPU HBM交互。本文的工作希望能启发社区为更多模块开发IO感知的实现。
2. 多GPU的IO感知方法:FlashAttention在单GPU上是近似最优的。然而,注意力计算可以在多个GPU上并行【72, Benjamin Recht and Christopher Ré, Parallel stochastic gradient algorithms for large-scale matrix completion, 2013, Mathematical Programming Computation】。这引入了新的IO分析层次——GPU之间的数据传输。本文希望能激发这方面的未来研究。

A6 附录

B 算法细节

B.1 内存高效的前向传播

核心挑战与方法。实现注意力机制内存高效的主要挑战在于Softmax操作,它耦合了K和V的列。我们的方法是分开计算Softmax的归一化常数来解耦这些列。这种技术【60, Maxim Milakov and Natalia Gimelshein, Online normalizer calculation for softmax, 2018, arXiv】在文献中已被用于证明注意力计算不需要二次方的额外内存【51, Nikita Kitaev et al., Reformer: The efficient transformer, 2020, The International Conference on Machine Learning (ICML)】【66, Markus N Rabe and Charles Staats, Self-attention does not need O(n^2) memory, 2021, arXiv】。为简化说明,我们在此省略了Softmax中的最大值移位步骤。

公式推导。注意力输出 O ∈ RN×d 的计算公式为:


我们定义Softmax的归一化常数 $D_i$:

输出的第i列 $O_i$ 可以表示为:

可以看出,一旦 $D_i$ 计算完毕,我们就可以通过累加 $\exp(Q_i^T K_j)V_j$ 来计算 $O_i$,而无需额外的内存。因此,前向传播可以用 O(N) 的额外内存完成:
1. 根据公式(1)计算所有i的 $D_i$,需要 O(N) 额外内存。
2. 根据公式(2)计算所有i的 $O_i$,需要 O(1) 额外内存。

B.2 内存高效的反向传播

反向传播推导。我们推导了注意力的反向传播过程,并证明它也可以在线性内存下完成。Rabe和Staats【66, Markus N Rabe and Charles Staats, Self-attention does not need O(n^2) memory, 2021, arXiv】提出通过对内存高效的前向传播应用梯度检查点来实现这一点。我们则直接推导反向传播过程。假设存在一个标量损失函数 L,输出梯度为 dO ∈ RN×d,我们需要计算输入梯度 dQ, dK, dV ∈ RN×d

dV的梯度。dV 的梯度很容易得到:$dV = P^T dO$。因此:


由于我们已经计算了 $D_i$,因此 $dV_j$ 可以通过重复求和来计算,无需额外内存。

dQ和dK的梯度。dQ和dK的梯度计算稍微复杂一些。我们首先看dP和dS的梯度。从公式(2)我们得到 $dP = dOV^T$,所以:


利用 $P_{i:} = \text{softmax}(S_{i:})$ 以及softmax的雅可比矩阵是 $\text{diag}(p) - pp^T$ 的事实,我们有:

其中 ◦ 表示逐点相乘。定义 $D'_{i} = \sum_j dP_{ij} P_{ij}$,则 $dS_{ij} = P_{ij} (dP_{ij} - D'_{i})$。因此,我们可以得到dQ和dK的梯度。由于 $S_{ij} = Q_i^T K_j$,所以:

以及

因此,反向传播也可以用 O(N) 的额外内存完成。

B.3 FlashAttention: 前向传播

完整算法描述。我们描述FlashAttention前向传播的完整细节(算法2)。该算法考虑了softmax缩放、掩码和dropout。

B.4 FlashAttention: 反向传播

标准反向传播。我们首先在算法3中描述了标准注意力的反向传播作为参考。
算法3:标准注意力反向传播
-

FlashAttention反向传播。FlashAttention的反向传播(算法4)有两大特点:
1. 无需存储dropout掩码: 我们不存储 $O(N^2)$ 大小的dropout掩码,而是保存前向传播中的伪随机数生成器状态,并在反向传播中重新生成它,仅需 O(N) 额外内存。
2. 高效计算softmax梯度: 我们不直接计算 $D'_{i} = \sum_j dP_{ij} P_{ij}$,而是重写为 $D'_{i} = (dO_{i:})^T O_{i:}$,这样就可以在SRAM中计算两个大小为d的向量的点积。

IO复杂度分析。与前向传播类似,我们分析反向传播的IO复杂度。
定理5。标准注意力的反向传播需要 $Θ(Nd + N^2)$ 次HBM访问,而FlashAttention的反向传播需要 $Θ(N^2d^2M^{-1})$ 次HBM访问。

B.5 与Rabe和Staats [66] 的比较

相似之处。FlashAttention和Rabe and Staats [66]都使用了分块技术来处理注意力矩阵,并通过在反向传播中重计算来减少内存占用。

主要区别
1. 优化目标不同:Rabe and Staats [66] 关注减少GPU所需的最大总内存,而FlashAttention关注减少内存访问次数。减少内存访问是提升运行速度的关键,并自然地减少了总内存需求。因此,FlashAttention比标准注意力快,而Rabe and Staats [66]的速度相当或稍慢。
2. 块间信息传递方式不同:Rabe and Staats [66] 为每个块保存临时输出和softmax统计量,最后合并。FlashAttention则在处理每个块后增量更新唯一的输出副本,因此总内存需求更小。
3. 反向传播计算方式不同:Rabe and Staats [66] 使用梯度检查点重计算每个块的注意力和临时输出。FlashAttention则通过解析简化反向传播过程,只重计算注意力矩阵,不重计算临时输出,这减少了反向传播的内存需求并带来了速度提升。

C 证明

  • 定理1的证明 (正确性、FLOPs和内存):通过对块索引j进行归纳,证明在每个外循环迭代后,HBM中存储的O, m, ℓ是处理了前j个K, V块后的正确结果。FLOPs主要来自矩阵乘法,总计为 $O(N^2d)$。额外内存主要用于存储统计量ℓ和m,为 $O(N)$。

  • 定理2的证明 (IO复杂度)

    • 标准注意力:每一步(计算S,计算P,计算O)都需要读/写大小为 $Θ(N^2)$ 或 $Θ(Nd)$ 的矩阵到HBM,总访问量为 $Θ(Nd + N^2)$。
    • FlashAttention:K和V的每个元素从HBM加载一次。对Q和O进行了 $T_c = N/B_c$ 次遍历。总HBM访问量为 $Θ(Nd + N \cdot T_c \cdot d)$。通过根据SRAM大小M优化块大小 $B_c$ 和 $B_r$ (设 $B_c \approx M/d, B_r \approx M/d, B_rB_c \approx M$),得到 $B_c = \Theta(M/d)$,最终HBM访问量为 $Θ(N^2d^2M^{-1})$。
  • 命题3的证明 (IO复杂度下界):采用反证法。假设存在一个算法在所有 $M \in [d, Nd]$ 范围内HBM访问次数为 $o(N^2d^2M^{-1})$。当 $M = Θ(Nd)$ 时,访问次数为 $o(Nd)$。然而,输入Q, K, V和输出O的大小均为 $Θ(Nd)$,它们最初都在HBM中,因此任何算法至少需要 $Ω(Nd)$ 次HBM访问才能读入数据和写回结果,产生矛盾。

  • 定理5的证明 (反向传播IO复杂度):证明过程与定理2类似。标准反向传播的每一步都需要 $Θ(Nd+N^2)$ 的HBM访问。FlashAttention反向传播的访问模式与前向传播相似,通过对块大小进行类似的约束分析,得出其HBM访问次数为 $Θ(N^2d^2M^{-1})$。

D 扩展细节

D.1 块稀疏FlashAttention

算法描述。完整的块稀疏FlashAttention算法在算法5中给出。该算法与算法2(FlashAttention前向传播)基本相同,唯一的区别是在内层循环中增加了一个条件判断 if M_ij != 0,只有当对应的块掩码为非零时,才执行该块的加载和计算。

IO复杂度证明 (命题4):证明与定理2类似。对于块稀疏情况,我们只需要加载与非零块对应的输入块。因此,HBM访问次数被非零块的比例s所缩放。然而,无论s多小,我们仍然需要写入完整的输出O ∈ RN×d。因此,HBM访问次数为 $Θ(Nd + s \cdot N^2d^2M^{-1})$。

D.2 潜在的扩展

  • 多GPU注意力:大型语言模型训练通常在多个GPU上进行,这引入了另一层内存层级(其他GPU的HBM)。对于超长序列,可以设计一种协同计算注意力的方案,考虑不同内存层级的不对称性。
  • 稀疏MLP层:典型的密集MLP层是计算密集型的。使用稀疏权重矩阵可以提高效率,但很多稀疏MLP实现反而变成了内存密集型,其加速效果与稀疏度不成正比。IO感知的实现可以缓解此问题。
  • 核机器学习:FlashAttention的思路可以应用于核方法。核矩阵 $K_{ij}$ 是由两个大小为d的向量 $x_i, x_j$ 计算得出的。与注意力矩阵类似,我们可以通过重复加载输入并重计算所需的核矩阵块来显著减少HBM访问。KeOps库【8, Benjamin Charlier et al., Kernel operations on the gpu, with autodiff, without memory overflows, 2021, Journal of Machine Learning Research】是这方面的一个成功例子。

E 完整的实验结果

BERT:遵循MLPerf 1.1的参考实现进行训练,使用LAMB优化器,批量大小448,在8xA100-80GB GPU上训练。

GPT-2:使用Huggingface和Megatron-LM的标准实现作为基线,遵循Megatron-LM的训练流程,在OpenWebtext数据集上训练,使用AdamW优化器和混合精度训练。图4显示FlashAttention与基线实现具有相同的验证困惑度曲线,验证了其数值稳定性。


图4:GPT-2 small/medium 使用两种实现的验证困惑度。我们确认 FlashAttention 产生了与 HuggingFace 基线实现相同的验证曲线。

LRA细节:遵循【80, Yi Tay et al., Long range arena: A benchmark for efficient transformers, 2020, ICLR】和【90, Yunyang Xiong et al., Nyströmformer: A nystöm-based algorithm for approximating self-attention, 2021, AAAI Conference on Artificial Intelligence】中的超参数设置。

与Apex FMHA的比较:Apex FMHA是当时针对短序列(<=512)最快的注意力实现之一。FlashAttention以FMHA代码为起点,应用分块和重计算来处理长序列并节省内存。在短序列上,FlashAttention与FMHA性能相当,前向稍快,后向稍慢,总体速度差异在±8%以内(见表7)。

表7:FlashAttention与FMHA的运行时间对比
表7:FlashAttention与FMHA的运行时间对比

在不同硬件上的加速效果
- A100 (d=64):通常有2-4倍的加速,尤其在有dropout和掩码时,核函数融合效果更佳(见图5)。
- A100 (d=128):头维度增加,需要更小的块来适应SRAM,导致整体加速比下降,但在因果掩码下仍有高达3倍的加速(见图6)。
- RTX 3090:由于其内存带宽低于A100,观察到稍高的加速比(2.5-4.5倍)(见图7)。
- T4:T4的SRAM更小,块大小受限,因此加速比较低,这与IO复杂度分析一致(见图8)。

图5:在A100上不同序列长度的加速比
图5:在A100上不同序列长度的加速比

图6:在A100上,头维度128时的加速比
图6:在A100上,头维度128时的加速比

图7:在RTX 3090上不同序列长度的加速比
图7:在RTX 3090上不同序列长度的加速比

图8:在T4上不同序列长度的加速比
图8:在T4上不同序列长度的加速比

完整的基准测试结果:附录E中包含了详尽的基准测试表格(表9至表21),涵盖了多种注意力实现(精确、近似、稀疏)在不同配置(有/无dropout/掩码)下的前向、后向和总运行时间,以及内存使用情况。这些表格系统地展示了FlashAttention和块稀疏FlashAttention相对于其他方法的性能优势。