文章标题: FlashAttention-2:通过更好的并行性和工作分区实现更快的注意力机制
作者: Tri Dao
机构: 1. 普林斯顿大学计算机科学系 2. 斯坦福大学计算机科学系

A1 主要贡献

本文的核心问题是解决 Transformer 模型在处理长序列时的性能瓶颈。注意力层的运行时间和内存占用随序列长度呈二次方增长,这限制了模型处理长文档、高分辨率图像或长视频的能力。尽管 FlashAttention 通过优化 GPU 内存使用,将内存占用从二次方降低到线性,并实现了2-4倍的加速,但其计算效率(FLOPs/s)仍远低于优化后的矩阵乘法(GEMM)操作,仅达到理论峰值的25-40%。

本文通过分析发现,FlashAttention 效率不高的主要原因是 GPU 上不同线程块和 warp 之间的工作分区不理想,导致了低占用率或不必要的共享内存读写。

为解决上述问题,本文提出了 FlashAttention-2,其研究目标是通过改进并行策略和工作分区来进一步提升注意力计算的效率。主要创新点和贡献如下:

  1. 算法调整以减少非矩阵乘法浮点运算:对 FlashAttention 算法进行了微调,减少了非矩阵乘法(non-matmul)的浮点运算次数。由于 GPU 上的专用计算单元(如 Tensor Cores)使得矩阵乘法的吞吐量远高于非矩阵乘法(可达16倍),这一优化能让计算时间更多地用于高效的矩阵乘法上。
  2. 增强并行性以提升GPU占用率:除了在批次大小和头数量维度上进行并行化,FlashAttention-2 还增加了沿序列长度维度的并行化。这在处理长序列(此时批次大小通常较小)的场景下,能显著提高 GPU 资源的利用率(即占用率)。
  3. 优化线程块内部工作分区:在每个线程块(thread block)内部,重新设计了不同 warp 之间的工作分配方式,以减少它们之间通过共享内存进行的通信和数据读写。

这些改进使得 FlashAttention-2 相比于 FlashAttention 实现了约2倍的速度提升,在 A100 GPU 上的前向传播计算效率达到了理论峰值的50-73%。在端到端的 GPT-style 模型训练中,每块 A100 GPU 的训练速度高达 225 TFLOPs/s,模型 FLOPs 利用率达到72%。

A3 背景知识

2.1 硬件特性

  • GPU性能特点。GPU 由计算单元(如浮点算术单元)和内存层次结构组成。大多数现代 GPU 包含专门用于加速低精度矩阵乘法的单元(例如,Nvidia GPU 上的 Tensor Cores 用于 FP16/BF16 矩阵乘法)。内存层次结构包括高带宽内存(HBM)和片上 SRAM(也称为共享内存)。以 A100 GPU 为例,它拥有 40-80GB 的高带宽内存(HBM),带宽为 1.5-2.0TB/s,以及 108 个流式多处理器(SM),每个 SM 有 192KB 的片上 SRAM,其带宽估计约为 19TB/s 【索引6, Dissecting the Ampere GPU architecture via microbenchmarking,2021,GPU Technology Conference】【索引7, Dissecting the nvidia Volta GPU architecture via microbenchmarking,2018,arXiv】。由于 L2 缓存不能由程序员直接控制,本文的讨论主要集中在 HBM 和 SRAM 上。
  • 执行模型。GPU 拥有海量的线程来执行一个操作(称为核函数)。线程被组织成线程块,这些线程块被调度到流式多处理器(SMs)上运行。在每个线程块内部,线程被分组为 warps(一组32个线程)。一个 warp 内的线程可以通过快速的 shuffle 指令进行通信,或协作执行矩阵乘法。一个线程块内的 warps 可以通过读写共享内存进行通信。每个核函数从 HBM 加载输入到寄存器和 SRAM,进行计算,然后将输出写回 HBM。

2.2 标准注意力实现

  • 标准Attention实现。给定输入序列 $Q, K, V \in \mathbb{R}^{N \times d}$,其中 $N$ 是序列长度,$d$ 是头的维度,我们希望计算注意力输出 $O \in \mathbb{R}^{N \times d}$:
    公式1
    公式1

    其中 softmax 是按行应用的。对于多头注意力(MHA),这个相同的计算会在多个头上并行执行,并且在批次维度(一个批次中的输入序列数量)上也是并行的。
  • 反向传播。注意力的反向传播过程如下。设 $dO \in \mathbb{R}^{N \times d}$ 是 $O$ 相对于某个损失函数的梯度。那么根据链式法则(即反向传播):
    公式2
    公式2

    其中 $dsoftmax$ 是按行应用的 softmax 的梯度(反向传播)。可以推导出,如果对于某个向量 $z$ 和 $p$,有 $p = softmax(z)$,那么对于输出梯度 $dp$,输入梯度 $dz = (\text{diag}(p) - pp^\top)dp$。
  • 性能瓶颈。标准的注意力实现会将矩阵 S 和 P 物化到 HBM,这需要 $O(N^2)$ 的内存。通常 $N \gg d$($N$ 通常在 1k–8k 的量级,而 $d$ 约为 64–128)。标准的注意力实现 (1) 调用矩阵乘法(GEMM)子程序来计算 $S = QK^\top$,将结果写入 HBM,然后 (2) 从 HBM 加载 S 来计算 softmax 并将结果 P 写入 HBM,最后 (3) 调用 GEMM 得到 $O = PV$。由于大部分操作受限于内存带宽,大量的内存访问导致了较慢的执行时间。此外,由于必须物化 S 和 P,所需的内存是 $O(N^2)$。并且,必须保存 $P \in \mathbb{R}^{N \times N}$ 用于反向传播计算梯度。

2.3 FlashAttention

  • FlashAttention核心思想。为了在 GPU 等硬件加速器上加速注意力计算,【索引5, FlashAttention: Fast and memory-efficient exact attention with IO-awareness,2022,NeurIPS】提出了一种算法,通过减少内存读写来保持相同的输出(无近似)。
  • 2.3.1 前向传播。FlashAttention 应用了经典的 tiling(分块)技术来减少内存 IO,通过 (1) 将输入的块从 HBM 加载到 SRAM,(2) 计算与该块相关的注意力,然后 (3) 更新输出,而无需将大的中间矩阵 S 和 P 写入 HBM。由于 softmax 将整行或整块行耦合在一起,在线 softmax【索引11, Online normalizer calculation for softmax, 2018, arXiv】【索引13, Self-attention does not need O(n^2) memory, 2021, arXiv】可以将注意力计算分成块,并重新缩放每个块的输出,最终得到正确的结果(无近似)。通过显著减少内存读写量,FlashAttention 比优化的基线注意力实现快了2-4倍。
  • 在线Softmax技术。我们描述在线 softmax 技术【索引11, Online normalizer calculation for softmax, 2018, arXiv】以及它在注意力中的应用【索引13, Self-attention does not need O(n^2) memory, 2021, arXiv】。为简单起见,只考虑注意力矩阵 S 的一个行块,形式为 $[S^{(1)} S^{(2)}]$,其中 $S^{(1)}, S^{(2)} \in \mathbb{R}^{B_r \times B_c}$,$B_r$ 和 $B_c$ 分别是行和列的块大小。我们希望计算这个行块的 softmax 并与 V 相乘,V 的形式为 $[V^{(1)\top} V^{(2)\top}]^\top$,其中 $V^{(1)}, V^{(2)} \in \mathbb{R}^{B_c \times d}$。标准 softmax 的计算方式是:
    公式3
    公式3

    而在线 softmax 则是计算每个块的“局部” softmax,并通过重新缩放最终得到正确输出:
    公式4
    公式4

    FlashAttention 利用在线 softmax 实现分块(图1),以减少内存读写。
    图1:FlashAttention前向传播示意图,当键K和值V被划分为两个块时。通过计算每个块的注意力并重新缩放输出,我们最终得到正确答案,同时避免了对中间矩阵S和P的昂贵内存读写。图中简化了softmax中每个元素减去行最大值的步骤。
    图1:FlashAttention前向传播示意图,当键K和值V被划分为两个块时。通过计算每个块的注意力并重新缩放输出,我们最终得到正确答案,同时避免了对中间矩阵S和P的昂贵内存读写。图中简化了softmax中每个元素减去行最大值的步骤。
  • 2.3.2 反向传播。在反向传播过程中,当输入的 Q、K、V 块被加载到 SRAM 后,FlashAttention 通过重新计算注意力矩阵 S 和 P 的值,避免了存储这些大的中间值。由于不需要保存大小为 $N \times N$ 的大矩阵 S 和 P,FlashAttention 根据序列长度实现了 10-20 倍的内存节省(所需内存与序列长度 $N$ 成线性关系而非二次方关系)。由于减少了内存读写,反向传播也实现了 2-4 倍的执行速度提升。反向传播将分块技术应用于 2.2 节中的方程。尽管反向传播在概念上比前向传播简单(没有 softmax 的重新缩放),但实现上要复杂得多。这是因为在反向传播中需要在 SRAM 中保留更多的值来执行 5 次矩阵乘法,而前向传播中只有 2 次。

A2 方法细节

我们描述了 FlashAttention-2 算法,它包含了对 FlashAttention 的几处调整以减少非矩阵乘法 FLOPs。然后,我们描述了如何在不同的线程块上并行化计算以充分利用 GPU 资源。最后,我们描述了在一个线程块内如何在不同的 warp 之间划分工作以减少共享内存的访问量。这些改进带来了2-3倍的加速,这在第4节中得到了验证。

3.1 算法

  • 减少非矩阵乘法FLOPs。我们对 FlashAttention 算法进行了调整,以减少非矩阵乘法(non-matmul)的 FLOPs。这是因为现代 GPU 拥有专门的计算单元(例如 Nvidia GPU 上的 Tensor Cores),使得矩阵乘法运算速度快得多。例如,A100 GPU 的 FP16/BF16 矩阵乘法理论峰值吞吐量为 312 TFLOPs/s,但非矩阵乘法的 FP32 吞吐量仅为 19.5 TFLOPs/s。换个角度看,每个非矩阵乘法的 FLOPs 比矩阵乘法的 FLOPs 昂贵 16 倍。为了保持高吞吐量(例如,超过理论峰值 TFLOPs/s 的 50%),我们希望尽可能多地执行矩阵乘法 FLOPs。

3.1.1 前向传播

  • 算法微调。我们回顾了 2.3 节中展示的在线 softmax 技巧,并进行了两处微小的调整以减少非矩阵乘法 FLOPs:
    1. 我们不必用 $diag(l^{(2)})^{-1}$ 对输出更新的两个项都进行重新缩放:
      公式5
      公式5

      我们可以改为维护一个“未缩放”版本的 $O^{(2)}$ 并保留统计量 $l^{(2)}$:
      公式6
      公式6

      只有在循环的最后,我们才将最终的 $\tilde{O}^{(\text{last})}$ 乘以 $diag(l^{(\text{last})})^{-1}$ 来得到正确的输出。
    2. 我们不必为反向传播同时保存最大值 $m^{(j)}$ 和指数和 $l^{(j)}$。我们只需要存储 logsumexp $L^{(j)} = m^{(j)} + \log(l^{(j)})$。
  • 两块情况下的在线Softmax。在 2.3 节的简单两块情况下,在线 softmax 技巧现在变为:
    公式7
    公式7
  • FlashAttention-2 前向传播算法。我们在算法1中描述了完整的 FlashAttention-2 前向传播过程。
    • 算法1 FlashAttention-2 前向传播
    • 输入: 矩阵 $Q, K, V \in \mathbb{R}^{N \times d}$ 在HBM中,块大小 $B_r, B_c$。
    • 将 $Q$ 分成 $T_r = \lceil N/B_r \rceil$ 个块 $Q_1, \dots, Q_{T_r}$,每个块大小为 $B_r \times d$。将 $K, V$ 分成 $T_c = \lceil N/B_c \rceil$ 个块 $K_1, \dots, K_{T_c}$ 和 $V_1, \dots, V_{T_c}$,每个块大小为 $B_c \times d$。
    • 将输出 $O \in \mathbb{R}^{N \times d}$ 分成 $T_r$ 个块 $O_1, \dots, O_{T_r}$,每个块大小为 $B_r \times d$。将 logsumexp $L$ 分成 $T_r$ 个块 $L_1, \dots, L_{T_r}$,每个块大小为 $B_r$。
    • 对于 $1 \le i \le T_r$ 执行:
    • 从 HBM 加载 $Q_i$ 到片上 SRAM。
    • 在片上初始化 $O_i^{(0)} = \mathbf{0}_{B_r \times d} \in \mathbb{R}^{B_r \times d}$, $l_i^{(0)} = \mathbf{0}_{B_r} \in \mathbb{R}^{B_r}$, $m_i^{(0)} = (-\infty)_{B_r} \in \mathbb{R}^{B_r}$。
    • 对于 $1 \le j \le T_c$ 执行:
    • 从 HBM 加载 $K_j, V_j$ 到片上 SRAM。
    • 在片上计算 $S_{ij}^{(j)} = Q_i K_j^\top \in \mathbb{R}^{B_r \times B_c}$。
    • 在片上计算 $m_i^{(j)} = \max(m_i^{(j-1)}, \text{rowmax}(S_{ij}^{(j)})) \in \mathbb{R}^{B_r}$, $\tilde{P}_{ij}^{(j)} = \exp(S_{ij}^{(j)} - m_i^{(j)}) \in \mathbb{R}^{B_r \times B_c}$ (逐元素), $l_i^{(j)} = e^{m_i^{(j-1)} - m_i^{(j)}} l_i^{(j-1)} + \text{rowsum}(\tilde{P}_{ij}^{(j)}) \in \mathbb{R}^{B_r}$。
    • 在片上计算 $O_i^{(j)} = \text{diag}(e^{m_i^{(j-1)} - m_i^{(j)}})^{-1} O_i^{(j-1)} + \tilde{P}_{ij}^{(j)} V_j$。
    • 结束 for循环。
    • 在片上计算 $O_i = \text{diag}(l_i^{(T_c)})^{-1} O_i^{(T_c)}$。
    • 在片上计算 $L_i = m_i^{(T_c)} + \log(l_i^{(T_c)})$。
    • 将 $O_i$ 作为 $O$ 的第 $i$ 个块写回 HBM。
    • 将 $L_i$ 作为 $L$ 的第 $i$ 个块写回 HBM。
    • 结束 for循环。
    • 返回 输出 $O$ 和 logsumexp $L$。
  • Causal masking。注意力机制的一个常见用例是自回归语言建模,此时需要对注意力矩阵 S 应用因果掩码(即,任何满足 $j > i$ 的项 $S_{ij}$ 都被设为 $-\infty$)。
    1. 由于 FlashAttention 和 FlashAttention-2 已经是按块操作,对于那些所有列索引都大于行索引的块(在长序列中约占一半的块),我们可以跳过该块的计算。与不使用因果掩码的注意力相比,这带来了约 1.7-1.8 倍的加速。
    2. 对于行索引保证严格小于列索引的块,我们不需要应用因果掩码。这意味着对于每一行,我们只需要对一个块应用因果掩码(假设块是方形的)。
  • 正确性、运行时间和内存需求。与 FlashAttention 一样,算法1 返回正确的输出 $O = \text{softmax}(QK^\top)V$(无近似),使用 $O(N^2 d)$ 的 FLOPs,并且除了输入和输出外,需要 $O(N)$ 的额外内存(用于存储 logsumexp $L$)。其证明与【索引5, FlashAttention: Fast and memory-efficient exact attention with IO-awareness,2022,NeurIPS】的证明几乎相同,因此在此省略。

3.1.2 反向传播

  • 算法微调。FlashAttention-2 的反向传播过程与 FlashAttention 的几乎相同。我们做了一个微小的调整,只使用行向的 logsumexp $L$,而不是同时使用行向最大值和行向指数和。为完整起见,我们在算法2中包含了反向传播的描述。
  • FlashAttention-2 反向传播算法
    • 算法2 FlashAttention-2 反向传播
    • 输入: 矩阵 $Q, K, V, O, dO \in \mathbb{R}^{N \times d}$ 在HBM中,向量 $L \in \mathbb{R}^N$ 在HBM中,块大小 $B_r, B_c$。
    • 将 $Q$ 分成 $T_c = \lceil N/B_c \rceil$ 个块 $Q_1, \dots, Q_{T_c}$,大小为 $B_c \times d$。将 $K, V$ 分成 $T_r = \lceil N/B_r \rceil$ 个块 $K_1, \dots, K_{T_r}$ 和 $V_1, \dots, V_{T_r}$,大小为 $B_r \times d$。
    • 将 $O$ 分成 $T_c$ 个块 $O_1, \dots, O_{T_c}$,将 $dO$ 分成 $T_c$ 个块 $dO_1, \dots, dO_{T_c}$,大小为 $B_c \times d$。将 $L$ 分成 $T_c$ 个块 $L_1, \dots, L_{T_c}$,大小为 $B_c$。
    • 初始化 $dQ = \mathbf{0}_{N \times d}$ 在HBM中,并将其分成 $T_c$ 个块 $dQ_1, \dots, dQ_{T_c}$。将 $dK, dV \in \mathbb{R}^{N \times d}$ 分成 $T_r$ 个块 $dK_1, \dots, dK_{T_r}$ 和 $dV_1, \dots, dV_{T_r}$。
    • 计算 $D = \text{rowsum}(dO \circ O) \in \mathbb{R}^N$ (逐元素相乘),将 $D$ 写回 HBM 并分成 $T_c$ 个块 $D_1, \dots, D_{T_c}$。
    • 对于 $1 \le j \le T_r$ 执行:
    • 从 HBM 加载 $K_j, V_j$ 到片上 SRAM。
    • 在 SRAM 中初始化 $dK_j = \mathbf{0}_{B_r \times d}, dV_j = \mathbf{0}_{B_r \times d}$。
    • 对于 $1 \le i \le T_c$ 执行:
    • 从 HBM 加载 $Q_i, O_i, dO_i, dQ_i, L_i, D_i$ 到片上 SRAM。
    • 在片上计算 $S_{ij}^{(j)} = Q_i K_j^\top \in \mathbb{R}^{B_c \times B_r}$。
    • 在片上计算 $P_{ij}^{(j)} = \exp(S_{ij} - L_i) \in \mathbb{R}^{B_c \times B_r}$。
    • 在片上计算 $dV_j \leftarrow dV_j + (P_{ij}^{(j)})^\top dO_i \in \mathbb{R}^{B_r \times d}$。
    • 在片上计算 $dP_{ij}^{(j)} = dO_i V_j^\top \in \mathbb{R}^{B_c \times B_r}$。
    • 在片上计算 $dS_{ij}^{(j)} = P_{ij}^{(j)} \circ (dP_{ij}^{(j)} - D_i) \in \mathbb{R}^{B_c \times B_r}$。
    • 从 HBM 加载 $dQ_i$ 到 SRAM,然后在片上更新 $dQ_i \leftarrow dQ_i + dS_{ij}^{(j)} K_j \in \mathbb{R}^{B_c \times d}$,并写回 HBM。
    • 在片上计算 $dK_j \leftarrow dK_j + (dS_{ij}^{(j)})^\top Q_i \in \mathbb{R}^{B_r \times d}$。
    • 结束 for循环。
    • 将 $dK_j, dV_j$ 写回 HBM。
    • 结束 for循环。
    • 返回 $dQ, dK, dV$。
  • 多查询注意力和分组查询注意力。多查询注意力(MQA)【索引15, Fast transformer decoding: One write-head is all you need, 2019, arXiv】和分组查询注意力(GQA)【索引1, Gqa: Training generalized multi-query transformer models from multi-head checkpoints, 2023, arXiv】是注意力的变体,其中多个查询头关注同一个键和值头,以减少推理过程中的 KV 缓存大小。我们不是显式地复制键和值头进行计算,而是隐式地操纵头的索引来执行相同的计算。在反向传播中,我们需要对被隐式复制的不同头的梯度 dK 和 dV 进行求和。

3.2 并行性

  • FlashAttention的并行策略。第一版 FlashAttention 在批次大小和头数量上进行并行化。我们使用1个线程块处理一个注意力头,总共有 batch_size × num_heads 个线程块。每个线程块被调度到一个流式多处理器(SM)上运行,例如 A100 GPU 上有108个 SM。当线程块总数很大(比如 ≥ 80)时,这种调度是高效的,因为我们可以有效地利用 GPU 上几乎所有的计算资源。
  • FlashAttention-2的并行策略。在处理长序列(这通常意味着小批次大小或少头数)的情况下,为了更好地利用 GPU 上的多处理器,我们现在额外在序列长度维度上进行并行化。这在该场景下带来了显著的加速。
  • 前向传播并行化。我们看到外层循环(在序列长度上)是易于并行的,我们将其调度到不同的线程块上,这些线程块之间不需要通信。我们也在批次维度和头数量维度上进行并行化,就像 FlashAttention 中所做的那样。在序列长度上增加并行性有助于在批次大小和头数量较少时提高占用率(GPU 资源被使用的比例),从而在这种情况下实现加速。这些思想,即将循环顺序颠倒(外层循环遍历行块,内层循环遍历列块,而不是原始 FlashAttention 论文中的方式),以及在序列长度维度上并行化,最初是由 Phil Tillet 在 Triton【索引17, Triton: an intermediate language and compiler for tiled neural network computations, 2019, ACM SIGPLAN International Workshop on Machine Learning and Programming Languages】的实现中提出和实现的。
  • 反向传播并行化。注意到在算法2中,不同列块之间唯一的共享计算是更新 dQ,我们需要从 HBM 加载 $dQ_i$ 到 SRAM,然后在片上更新 $dQ_i \leftarrow dQ_i + dS_{ij}^{(j)} K_j$,并写回 HBM。因此,我们也在序列长度维度上进行并行化,并为反向传播的每个列块调度1个线程块。我们使用原子加法在不同线程块之间进行通信以更新 dQ。
  • 并行化方案示意图。我们在图2中描述了并行化方案。
    图2:在前向传播(左)中,我们将工作者(线程块)并行化,每个工作者负责注意力矩阵的一个行块。在反向传播(右)中,每个工作者负责注意力矩阵的一个列块。
    图2:在前向传播(左)中,我们将工作者(线程块)并行化,每个工作者负责注意力矩阵的一个行块。在反向传播(右)中,每个工作者负责注意力矩阵的一个列块。

3.3 Warp之间的工作分区

  • 线程块内部分工。第3.2节描述了我们如何调度线程块,即使在每个线程块内部,我们也必须决定如何在不同的 warp 之间划分工作。我们通常每个线程块使用4或8个 warp,分区方式如图3所示。
  • 前向传播分工。对于每个块,FlashAttention 将 K 和 V 划分给4个 warp,而 Q 对所有 warp 可见。每个 warp 进行乘法得到 QK^T 的一部分,然后它们需要与 V 的一部分相乘并通过通信相加结果。这被称为“split-K”方案。然而,这种方式效率低下,因为所有 warp 都需要将其-中间结果写出到共享内存,进行同步,然后相加中间结果。这些共享内存的读写减慢了 FlashAttention 的前向传播速度。
  • FlashAttention-2的分工方案。在 FlashAttention-2 中,我们改为将 Q 划分给4个 warp,而 K 和 V 对所有 warp 可见。在每个 warp 执行矩阵乘法得到 QK^T 的一部分后,它们只需与共享的 V 部分相乘即可得到对应的输出部分。warp 之间无需通信。减少共享内存的读写带来了加速(第4节)。
    图3:前向传播中不同warp之间的工作分区
    图3:前向传播中不同warp之间的工作分区
  • 反向传播分工。类似地,对于反向传播,我们选择划分 warp 以避免“split-K”方案。然而,由于所有不同的输入和梯度 Q, K, V, O, dO, dQ, dK, dV 之间更复杂的依赖关系,它仍然需要一些同步。尽管如此,避免“split-K”减少了共享内存的读写,并再次带来了加速(第4节)。
  • 调整块大小。增加块大小通常会减少共享内存的加载/存储,但会增加所需的寄存器数量和共享内存总量。超过某个块大小时,寄存器溢出会导致显著的减速,或者所需的共享内存量大于 GPU 可用量,导致核函数根本无法运行。我们通常选择大小为 {64, 128} × {64, 128} 的块,具体取决于头的维度 $d$ 和设备的共享内存大小。我们为每个头的维度手动调整,因为基本上只有4种块大小的选择,但这可以从自动调优中受益以避免这种手动劳动。我们将其留给未来的工作。

A4 实验环境

  • 硬件配置
    • 主要基准测试在单张 NVIDIA A100 80GB SXM4 GPU 上进行。
    • 端到端训练性能测试在一个包含 8块 A100 80GB SXM GPU 的节点上进行。
    • 还提供了一组在 H100 GPU 上的基准测试结果。
  • 软件配置
    • FlashAttention-2 的实现与多个基线进行了比较,包括:
      • PyTorch 中的标准注意力实现。
      • 原始的 FlashAttention。
      • 在 Triton 中实现的 FlashAttention。
      • 在 xformers 库中的 FlashAttention 实现("cutlass" 版本)。
    • 端到端训练的 FLOPs 计算遵循 Megatron-LM【索引16, Megatron-LM: Training multi-billion parameter language models using model parallelism, 2019, arXiv】库的公式。
  • 模型架构
    • 在端到端训练评估中,使用了参数量分别为 1.3B 和 2.7B 的 GPT-style 模型。
    • 在注意力基准测试中,模型隐藏层维度设置为2048,头维度设置为64或128(对应32个头或16个头)。
  • 数据集和工作负载
    • 注意力基准测试:序列长度从512变化到16k。批次大小被动态调整以使每个批次的总 tokens 数保持在16k。测试涵盖了有和没有因果掩码两种情况。
    • 端到端训练:使用了2k和8k两种序列长度进行训练。

A4 实验结果

注意力基准测试

  • 实验内容:在 A100 80GB SXM4 GPU 上,针对不同设置(有/无因果掩码,头维度为64或128),测量了不同注意力实现的运行时间。
  • 实验结果
    • 与FlashAttention对比:FlashAttention-2 比 FlashAttention 和 xformers 中的 FlashAttention("cutlass"实现)快约 2倍
    • 与Triton实现对比:在前向传播中,FlashAttention-2 比 Triton 实现的 FlashAttention 快约 1.3-1.5倍;在反向传播中快约 2倍
    • 与PyTorch标准实现对比:FlashAttention-2 的速度可达 PyTorch 标准实现的 10倍
    • 性能峰值:在 A100上,FlashAttention-2 的前向传播速度最高达到 230 TFLOPs/s,达到理论峰值的 73%(见图5),反向传播最高达到理论峰值的 63%(见图6)。
    • H100初步结果:在 H100 GPU 上,未使用任何新特性(如 TMA 和第四代 Tensor Cores)的相同实现,速度已高达 335 TFLOPs/s(见图7)。
  • 图表引用
    图4:A100 GPU 上的注意力前向+反向传播速度
    图4:A100 GPU 上的注意力前向+反向传播速度

    图5:A100 GPU 上的注意力前向传播速度
    图5:A100 GPU 上的注意力前向传播速度

    图6:A100 GPU 上的注意力反向传播速度
    图6:A100 GPU 上的注意力反向传播速度

    图7:H100 GPU 上的注意力前向+反向传播速度
    图7:H100 GPU 上的注意力前向+反向传播速度

端到端性能

  • 实验内容:在 8xA100 80GB SXM 平台上,训练 1.3B 和 2.7B 参数的 GPT-style 模型,测量训练吞吐量。
  • 实验结果
    • 与未使用 FlashAttention 的基线相比,FlashAttention-2 带来了高达 2.8倍 的速度提升。
    • 与使用原始 FlashAttention 相比,FlashAttention-2 带来了高达 1.3倍 的速度提升。
    • 在训练中,每块 A100 GPU 的速度最高达到 225 TFLOPs/s,实现了 72% 的模型 FLOPs 利用率
  • 结论:内核级别的优化成功转化为显著的端到端训练加速。
  • 图表引用
    表1:在8xA100 GPU上训练GPT-style模型的速度(TFLOPs/s/GPU)。FlashAttention-2达到了高达225 TFLOPs/s(72%的模型FLOPs利用率)。我们与一个没有运行FlashAttention的基线进行了比较。
    表1
    表1

A5 结论

FlashAttention-2 比 FlashAttention 快2倍,这意味着现在训练一个16k上下文长度的模型的成本与之前训练一个8k上下文长度的模型相当。这一进步有望推动模型在理解长篇书籍报告、高分辨率图像、音频和视频等领域的应用。同时,FlashAttention-2 也将加速现有模型的训练、微调和推理过程。

未来工作展望
1. 扩展硬件和数据类型支持:计划与研究人员和工程师合作,将 FlashAttention 推广到不同类型的设备(如 H100 GPU、AMD GPU)和新的数据类型(如 FP8)。
2. 针对 H100 的深度优化:下一步计划是优化 FlashAttention-2 以利用 H100 GPU 的新硬件特性(如 TMA、第四代 Tensor Cores、FP8)。
3. 结合高级算法:将 FlashAttention-2 的底层优化与高级算法(如局部注意力、扩张注意力、块稀疏注意力)相结合,可能使我们能够训练上下文更长的 AI 模型。
4. 提升可编程性:与编译器研究者合作,使这些优化技术更易于编程实现。

方法细节中的引用汇总

  • [1] Joshua Ainslie, et al. Gqa: Training generalized multi-query transformer models from multi-head checkpoints. arXiv preprint arXiv:2305.13245, 2023.
    • 引用位置: A2 方法细节 -> 3.1.2 反向传播 -> 多查询注意力和分组查询注意力。
    • 引用描述: 该文献提出了分组查询注意力(GQA),作为一种注意力变体,FlashAttention-2 对其提供了支持。
  • [5] Tri Dao, et al. FlashAttention: Fast and memory-efficient exact attention with IO-awareness. In Advances in Neural Information Processing Systems, 2022.
    • 引用位置: A3 背景知识 -> 2.3 FlashAttention;A2 方法细节 -> 3.1.1 前向传播 -> 正确性、运行时间和内存需求。
    • 引用描述: 该文献是 FlashAttention-2 的基础,提出了原始的 FlashAttention 算法。本文多次引用它作为比较基线和理论基础。
  • [6] Zhe Jia and Peter Van Sandt. Dissecting the Ampere GPU architecture via microbenchmarking. GPU Technology Conference, 2021.
    • 引用位置: A3 背景知识 -> 2.1 硬件特性。
    • 引用描述: 引用该文献来说明 A100 GPU 的硬件参数,特别是 HBM 和 SRAM 的带宽。
  • [7] Zhe Jia, et al. Dissecting the nvidia Volta GPU architecture via microbenchmarking. arXiv preprint arXiv:1804.06826, 2018.
    • 引用位置: A3 背景知识 -> 2.1 硬件特性。
    • 引用描述: 与[6]一同被引用,用于提供 GPU 硬件微基准测试信息。
  • [11] Maxim Milakov and Natalia Gimelshein. Online normalizer calculation for softmax. arXiv preprint arXiv:1805.02867, 2018.
    • 引用位置: A3 背景知识 -> 2.3.1 前向传播。
    • 引用描述: 该文献提出了在线 softmax 技术,这是 FlashAttention 实现分块计算的核心技术之一。
  • [13] Markus N Rabe and Charles Staats. Self-attention does not need O(n^2) memory. arXiv preprint arXiv:2112.05682, 2021.
    • 引用位置: A3 背景知识 -> 2.3.1 前向传播。
    • 引用描述: 该文献展示了在线 softmax 技术在注意力机制中的应用,为 FlashAttention 提供了思路。
  • [15] Noam Shazeer. Fast transformer decoding: One write-head is all you need. arXiv preprint arXiv:1911.02150, 2019.
    • 引用位置: A2 方法细节 -> 3.1.2 反向传播 -> 多查询注意力和分组查询注意力。
    • 引用描述: 该文献提出了多查询注意力(MQA),FlashAttention-2 对其提供了支持。
  • [16] Mohammad Shoeybi, et al. Megatron-LM: Training multi-billion parameter language models using model parallelism. arXiv preprint arXiv:1909.08053, 2019.
    • 引用位置: A4 实验环境。
    • 引用描述: 引用该文献的 FLOPs 计算公式,以确保与现有大型语言模型训练工作的评估标准保持一致。
  • [17] Philippe Tillet, et al. Triton: an intermediate language and compiler for tiled neural network computations. In Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages, pages 10–19, 2019.
    • 引用位置: A2 方法细节 -> 3.2 并行性 -> 前向传播并行化。
    • 引用描述: 该文献提出的 Triton 语言及其实现启发了 FlashAttention-2 在序列长度维度上进行并行化的思想。