Flashmask: Efficient and Rich Mask Extension of Flashattention
Flashmask: Efficient and Rich Mask Extension of Flashattention
作者/机构: Guoxia Wang, Jinle Zeng, Xiyuan Xiao, Siming Wu, Jiabin Yang, Lujing Zheng, Zeyu Chen, Jiang Bian, Dianhai Yu, Haifeng Wang (Baidu Inc.)
A1 主要贡献
核心问题: 传统(vanilla)注意力机制的计算和内存需求随序列长度 N 呈二次方($O(N^2)$)增长,给处理长序列带来了巨大挑战。FlashAttention 通过消除 $O(N^2)$ 的内存依赖和IO感知的内存优化,缓解了这些问题,但其对特定注意力掩码类型的原生支持有限,无法内在支持更复杂的掩码需求。以往的方法采用内存复杂度为 $O(N^2)$ 的密集掩码,导致效率低下。
研究目标: 本文旨在扩展 FlashAttention 的能力,提出一种名为 FLASHMASK 的新方法,以高效地表示和处理更广泛、更复杂的注意力掩码类型,同时保持线性内存复杂度和高计算效率,特别适用于长上下文序列建模。
创新点:
1. 提出了一种新颖的列式稀疏掩码表示法:该方法可以高效地支持更广泛的掩码类型,从而实现更灵活的注意力机制。它将二维的密集掩码压缩为一维表示,实现了 $O(N)$ 的线性内存复杂度。
2. 扩展了 FlashAttention 的掩码能力:通过集成优化的核函数实现,确保了高计算效率,同时不牺牲计算精度。这种方法利用注意力掩码中的稀疏性来跳过不必要的计算。
3. 在多种场景下验证了有效性:通过在 SFT、LoRA、DPO 和 RM 等大语言模型(LLM)微调和对齐训练中的评估,证明了 FLASHMASK 的多功能性和鲁棒性。实验表明,与现有的 FlashAttention 密集掩码方法相比,端到端速度提升了1.65倍至3.22倍,显著减少了计算和内存开销,非常适合长上下文建模。
A3 背景知识
2.1 注意力掩码类型
Transformer模型中的多种注意力掩码类型。Transformer 模型在各种任务中表现出卓越的通用性,每种任务都受益于不同的注意力掩码类型,如图1所示。因果掩码 (Causal Mask) 主要用于自回归模型中预测序列中的下一个词元,确保每个词元只关注前面的词元,避免未来信息泄露【Vaswani et al. (2017), Attention is all you need, 2017, NIPS】。滑动窗口掩码 (Sliding Window Mask) 通过允许词元关注一个固定大小的邻近词元窗口来捕捉局部上下文,平衡了计算效率与捕捉局部依赖的能力【Beltagy et al. (2020), Longformer: The long-document transformer, 2020, arXiv】。因果文档掩码 (Causal Document Mask) 应用于高效序列打包和批内/词元内技术等方法中,通过确保词元只关注同一文档内的先前词元,在不降低性能的情况下加速大语言模型【Krell et al. (2021), Efficient sequence packing without cross-contamination: Accelerating large language models without impacting performance, 2021, arXiv; Iyer et al. (2022), Opt-iml: Scaling language model instruction meta learning through the lens of generalization, 2022, arXiv; Dubey et al. (2024), The llama 3 herd of models, 2024, arXiv】。文档掩码 (Document Mask),或称双向注意力,允许词元关注同一文档内的所有其他词元,有助于从两个方向学习上下文,广泛应用于BERT和像NaViT这样的视觉Transformer模型中【Devlin et al. (2018), Bert: Pre-training of deep bidirectional transformers for language understanding, 2018, arXiv; Dehghani et al. (2024), Patch n’ pack: Navit, a vision transformer for any aspect ratio and resolution, 2024, NeurIPS】。
图 1: FLASHMASK 概览。(a) FLASHMASK 支持的掩码类型,(b) FLASHMASK 中的列式稀疏表示法,(c) FLASHMASK 的高效实现。
特定任务的复杂掩码类型。共享问题掩码 (Shared Question Mask) 用于奖励模型(RM)和直接偏好优化(DPO)模型中,允许多个答案共享同一个问题,从而消除冗余计算并加速训练【Ouyang et al. (2022), Training language models to follow instructions with human feedback, 2022, NeurIPS】。全局+滑动窗口掩码 (Global + Sliding Window Mask) 结合了全局注意力和滑动窗口注意力,其中全局词元关注所有词元,而其他词元使用滑动窗口掩码,有效处理需要全局上下文和局部细节的任务【Zaheer et al. (2020), Big bird: Transformers for longer sequences, 2020, NeurIPS】。
用于长上下文和稀疏优化的掩码类型。因果块状掩码 (Causal BlockWise Mask) 主要用于上下文学习,将序列划分为块,其中示例只关注小块内的附近示例,而测试示例可以关注所有示例,从而研究模型在长上下文任务中的性能提升【Bertsch et al. (2024), In-context learning with long-context models: An in-depth exploration, 2024, arXiv】。前缀LM因果掩码 (Prefix LM Causal Mask) 专为语言建模任务定制,允许前缀关注所有词元,以根据前缀生成连贯的文本【Raffel et al. (2020), Exploring the limits of transfer learning with a unified text-to-text transformer, 2020, JMLR】。前缀文档掩码 (Prefix Document Mask) 将此概念扩展到多个文档,其中每个文档中的前缀关注该文档内的所有词元,但不跨文档。
用于性能优化的稀疏和随机掩码类型。QK-稀疏掩码 (QK-Sparse Mask) 通过稀疏化查询-键对来优化自注意力,减少计算负载同时保持性能,这对于大规模模型尤其有益【Kitaev et al. (2020), Reformer: The efficient transformer, 2020, arXiv】。哈希稀疏掩码 (Hash-Sparse Mask) 采用局部敏感哈希将序列划分为更小的块,为长序列实现高效的稀疏注意力【Kitaev et al. (2020), Reformer: The efficient transformer, 2020, arXiv】。最后,随机驱逐掩码 (Random Eviction Mask) 通过在训练期间随机屏蔽词元来引入随机性,有助于泛化并模拟键值(KV)缓存驱逐过程,以在不发生内存溢出的情况下处理长序列【Chen et al. (2024), NACL: A general and effective KV cache eviction framework for LLM at inference time, 2024, ACL】。
2.2 支持的注意力掩码
不同注意力实现对掩码的支持。注意力机制是Transformer模型的基础,各种掩码类型支持不同的注意力模式。如公式2所示,传统的注意力机制通过一个密集的掩码矩阵支持任意掩码类型:
$$S = \frac{QK^\top}{\sqrt{d_k}} \in \mathbb{R}^{N \times N}, \quad P = \text{Softmax}(S + M) \in \mathbb{R}^{N \times N}, \quad O = PV \in \mathbb{R}^{N \times d},$$其中 $Q, K, V \in \mathbb{R}^{N \times d}$ 是输入序列,$M \in \mathbb{R}^{N \times N}$ 是注意力掩码,$N$ 是序列长度,$d$ 是头维度。掩码 $M$ 通过与 $S$ 进行逐元素相加来调节词元的可见性。虽然这种方法支持任意掩码类型,但它会产生 $O(N^2)$ 的内存复杂度,限制了其对长序列的可扩展性。
现有优化方法的掩码支持。FlashAttention【Dao et al. (2022), Flashattention: Fast and memory-efficient exact attention with io-awareness, 2022, NeurIPS; Dao (2023), Flashattention-2: Faster attention with better parallelism and work partitioning, 2023, arXiv】通过IO感知的读/写操作和平铺技术解决了这一限制,无需中间矩阵 $S \in \mathbb{R}^{N \times N}$ 和显式掩码 $M$。然而,FlashAttention仅在其内核内支持预定的掩码模式,如因果掩码、滑动窗口掩码、因果文档掩码和文档掩码,如图1所示。
对更灵活掩码支持的探索。xFormers【Lefaudeux et al. (2022), xformers: A modular and hackable transformer modelling library, 2022, GitHub】扩展了FlashAttention的功能,支持带有对角线偏移的掩码。它使用累积序列长度来表示文档掩码,实现了 $O(B)$ 的内存复杂度,其中 $B$ 是文档数量。
最新的灵活掩码实现。FlexAttention【He et al. (2024), FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention, 2024, PyTorch Blog】引入了一种基于深度学习编译器技术的更灵活的掩码描述方法。通过将块掩码与基于表达式的描述相结合,它可以支持任意掩码类型。虽然这种方法通过基于块的处理显著减少了内存开销,但其内存复杂度仍然是 $O(N_{blocks})$。
FLASHMASK的提出。我们提出的方法FLASHMASK扩展了FlashAttention的掩码支持能力。它引入了一种灵活的、列式的稀疏掩码表示法,涵盖了主流Transformer建模的大部分需求。如图1(b)所示,FLASHMASK以每列为基础表示需要屏蔽的区间,实现了 $O(N)$ 的内存复杂度。这种方法弥合了掩码灵活性和计算效率之间的差距,为大规模Transformer模型中的注意力机制提供了一个更通用的解决方案。
2.3 注意力优化技术
注意力的计算与内存挑战。如公式2所述,注意力机制在计算 $QK^T$ 时面临显著的计算和内存挑战。随着序列长度 $N$ 的增加,得到的注意力分数矩阵呈二次方增长,导致复杂度为 $O(N^2)$。为了解决这个可扩展性问题,研究人员提出了各种优化技术,重点关注内存效率和计算速度。
内存效率的提升。内存高效注意力(MEA)【Rabe & Staats (2021), Self-attention does not need o (n2) memory, 2021, arXiv】是模型训练优化的一个显著进步。通过利用在线Softmax【Milakov & Gimelshein (2018), Online normalizer calculation for softmax, 2018, arXiv】和分块技术,MEA将内存需求从 $O(N^2)$ 降低到 $O(\sqrt{N})$,从而能够在现有硬件限制内使用更大的模型或更长的序列长度。在此基础上,FlashAttention【Dao et al. (2022), Flashattention: Fast and memory-efficient exact attention with io-awareness, 2022, NeurIPS; Dao (2023), Flashattention-2: Faster attention with better parallelism and work partitioning, 2023, arXiv】通过IO感知的内存读/写优化专注于减少注意力延迟。在计算过程中利用平铺技术,FlashAttention实现了 $O(N)$ 的内存开销,在没有自定义掩码需求的任务中特别有效。此外,FlashAttention扩展到块稀疏FlashAttention,引入了一个二维块掩码矩阵表示来指示被掩码的平铺块。这一创新允许跳过对被掩码块的计算,从而加速了处理过程。
针对特定掩码的解决方案。对于需要特定注意力掩码的场景,出现了一些量身定制的解决方案。稀疏因果Flash Attention (SCFA)【Pagliardini et al. (2023), Faster causal attention over large sequences through sparse flash attention, 2023, NeurIPS】扩展了FlashAttention,以优化因果注意力结构中的QK-稀疏和哈希稀疏场景。SCFA使用原始未压缩张量中查询和键的索引来描述掩码,从而能够省略对被掩码块的计算并提高计算效率。FlexAttention【He et al. (2024), FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention, 2024, PyTorch Blog】利用编译器技术简化掩码注意力的实现,利用注意力掩码中的稀疏性跳过某些被掩码的块,以实现更快的速度。然而,特别是在复杂的掩码模式下,仍有优化的空间。
FLASHMASK的贡献。我们提出的方法FLASHMASK在这些进步的基础上,支持自定义的复杂注意力掩码。FLASHMASK将内存复杂度从 $O(N^2)$ 降低到 $O(N)$,同时利用注意力掩码中的稀疏性跳过被掩码的块。通过严格的工程优化,FLASHMASK实现了比FlexAttention更优的计算速度,特别是在具有复杂掩码需求的任务中。通过综合现有方法的优点与新颖的优化技术,FLASHMASK代表了注意力机制效率的重大进步,解决了大规模Transformer中的内存限制和计算速度问题。
3 观察
注意力掩码的连续性特征。在当前训练Transformer模型的范式中,注意力机制可以根据其因果性进行广泛分类,如第2.1节介绍和图1(a)所示。这些表示涵盖了训练场景中遇到的大多数掩码类型。考虑注意力分数矩阵 $S$,其中每个元素 $S_{ij}$ 代表第 $i$ 个查询词元对第 $j$ 个键词元的关注度。从键词元的角度,我们观察到一个关键模式,即查询词元如何关注每个键词元。
核心观察与一维表示法。我们的关键观察是,查询词元无法关注某些键词元的现象表现出一种连续性。这种连续性使我们能够将二维的密集掩码 $M$ 转换为一种更紧凑的一维表示,使用行索引区间,如图1(b)所示。形式上,我们可以将这种转换表示为:
$$M_j = [start_j, end_j), \quad \forall j \in \{1, \dots, N\}$$其中 $M_j$ 表示第 $j$ 个键词元被掩码的行索引区间,$N$ 是序列长度。
表示法的优势。虽然这种列式的一维区间表示法可能无法捕捉任意的掩码模式,但它有效地涵盖了实践中遇到的主要掩码类型。此外,这种表示法提供了一个显著的优势:它有助于轻松地转换为基于平铺计算中的掩码块。这种转换能够消除不必要的计算,从而提高注意力核的计算效率。
表示法的通用性。这种区间表示的概念可以推广到其他形式的连续区间。例如,通过转置注意力矩阵,我们可以得到使用列索引区间的行式表示。这种方法的灵活性允许高效处理各种注意力模式,同时保持一种紧凑的表示形式,有利于优化计算。
A2 方法细节
4.1 列式掩码表示法
FLASHMASK的列式稀疏表示法。为了在因果和双向注意力场景中高效处理复杂的掩码模式,我们为FLASHMASK提出了一种新颖的列式稀疏表示法。注意力分数矩阵被划分为相对于对角线的左下和右上三角部分。FLASHMASK使用四个一维向量来表示掩码:
- $LTS$: 左下三角起始 (Lower Triangular Start) - 左下三角中掩码的起始行。
- $LTE$: 左下三角结束 (Lower Triangular End) - 左下三角中掩码的结束行。
- $UTS$: 右上三角起始 (Upper Triangular Start) - 右上三角中掩码的起始行。
- $UTE$: 右上三角结束 (Upper Triangular End) - 右上三角中掩码的结束行。
掩码区间的定义与示例。在左下三角部分需要被掩码的行索引由 $[LTS_j, LTE_j)$ 给出,在右上三角部分由 $[UTS_j, UTE_j)$ 给出。具体来说,每一列由两个掩码区间描述。对于第 $j$ 个词元,区间 $[LTS_j, LTE_j) \cup [UTS_j, UTE_j)$ 内的词元不能关注它。例如,如图1(b)(6)所示,对于第五列,$[LTS_5, LTE_5) \cup [UTS_5, UTE_5) = [7, 10) \cup [2, 4)$ 表示第2到4行和第7到9行被掩码。
该表示法的优势。这种表示法具有几个优点:
1. 紧凑性:它将一个密集的二维掩码缩减为一个更高效的一维表示。
2. 灵活性:它可以捕捉广泛的实际掩码模式,包括因果、双向和更复杂的注意力机制。
3. 计算效率:它便于在基于平铺的计算中轻松转换为掩码块,从而能够消除不必要的计算。
4.2 扩展FLASHATTENTION以支持复杂掩码
集成与核函数实现。我们将FLASHMASK的列式掩码表示法集成到FlashAttention-2算法中,扩展了其掩码支持能力。FLASHMASK的高性能核函数实现包括两个关键步骤:
预处理步骤。给定输入的列式稀疏掩码向量,我们首先在高带宽内存(HBM)中沿列维度将它们划分为 $T_c$ 个块。对于每个掩码向量,我们计算每个块内的最大值和最小值。这会产生8个中间向量:$LTS_j^{min}, LTS_j^{max}, LTE_j^{min}, LTE_j^{max}, UTS_j^{min}, UTS_j^{max}, UTE_j^{min}, UTE_j^{max}$,每个向量的大小为 $T_c$。
实时块跳过计算。使用这些最小-最大值向量,我们可以在核函数计算期间将注意力分数矩阵的每个平铺块分为三类。块掩码类型 BlockMaskType 的确定方式如下:
- 完全掩码,如果 $i \times B_r \ge LTStart_{j}^{min}$ 且 $i \times B_r < LTEnd_{j}^{max}$ 同时 $i \times B_r \ge UTStart_{j}^{min}$ 且 $i \times B_r < UTEnd_{j}^{max}$
- 部分掩码,如果 $i \times B_r < LTEnd_{j}^{min}$ 或 $i \times B_r > LTStart_{j}^{max}$ 同时 $i \times B_r < UTEnd_{j}^{min}$ 或 $i \times B_r > UTStart_{j}^{max}$
- 无掩码,其他情况
这种分类使我们能够跳过完全掩码的块,减少对无掩码块的计算,并仅对部分掩码的块应用逐元素掩码。图1(c)展示了在左下三角区域具有 $LTS$ 和 $LTE$ 的因果场景的整个核函数计算过程。算法1详细描述了从FlashAttention2扩展而来的FLASHMASK前向计算过程,蓝色阴影部分表示FLASHMASK的计算。
反向传播的优势。对于反向传播,FLASHMASK的列稀疏表示法特别有利。$dK$ 和 $dV$ 的计算是列并行的,允许将最大值和最小值高效地加载到寄存器中,以便在块计算期间进行广泛的数据重用,如附录中的算法2所示。
4.3 复杂度分析
块稀疏度定义。我们将注意力掩码中的块稀疏度定义为 $\rho = \frac{M_{masked}}{N/B_r \times N/B_c}$,其中 $B_r, B_c$ 是块大小,$M_{masked}$ 是完全被掩码的块的数量。
空间复杂度。密集掩码需要 $O(N^2)$ 的空间,而FLASHMASK对 $LTS, LTE, UTS, UTE \in \mathbb{R}^N$ 使用 $O(N)$ 的空间,以及8个预计算的最小-最大值向量 $\in \mathbb{R}^{T_c}$。这种内存使用上的显著减少使得在更长的序列上进行训练成为可能。
内存访问复杂度。密集掩码需要在HBM上进行 $O(N^2)$ 的内存访问。FLASHMASK从HBM读取 $LTS, LTE, UTS, UTE \in \mathbb{R}^N$ 向量,如算法1的第16和19行所示,每个 $Q_{block}$ 读取整个 $LTS, LTE, UTS, UTE$,总共 $4 \times \frac{N}{B_r} \times N$ 次内存访问。这大大减少了内存访问,显著提升了性能。此外,FLASHMASK的紧凑表示允许将掩码向量预加载到SRAM中,进一步提高了内存访问效率。
计算复杂度。虽然标准注意力计算的复杂度为 $O(N^2)$,但FLASHMASK利用注意力掩码中的稀疏性,通过跳过完全被掩码的块,将其降低到 $O((1 - \rho)N^2)$。
性能提升总结。这些在空间、内存访问和计算复杂度方面的改进,共同促成了FLASHMASK在处理复杂注意力模式时的卓越性能和效率。
4.4 正确性分析
算法的数值等价性。如公式2所示,注意力矩阵 $O = \text{Softmax}(S + M)V$ 的计算涉及用掩码 $M$ 来增强注意力分数 $S$,其中被掩码的元素被设置为 $-\infty$。此操作确保了在这些被掩码位置的softmax输出为零,从而有效地将它们从注意力计算中排除。因此,如果整个块被完全掩码,该块的最终输出将全为零。FLASHMASK利用注意力掩码中的稀疏性,通过跳过涉及这些完全被掩码块的计算,从而减少计算开销而不改变结果。
精度保证。重要的是,FLASHMASK与在FlashAttention中使用密集掩码执行的计算保持比特级别的数值等价性,确保没有精度损失。这种精确性在我们的实验评估中得到了证实,我们验证了端到端训练的损失收敛曲线在比特级别上完全一致(见第5.2节)。
A4 实验
实验环境
-
硬件配置:
- GPU: NVIDIA A100-SXM 80G
- CPU: Intel(R) Xeon(R) Platinum 8350C
-
软件配置:
- CUDA: 12.0
- 驱动版本: 525.125.06
- 代码实现: 基于 PaddlePaddle【Ma et al. (2019), Paddlepaddle: An open-source deep learning platform from industrial practice, 2019, Frontiers of Data and Domputing】,并已集成到 PaddleNLP 中。
-
模型与数据集:
- 模型: Llama-2 (7B, 13B, 70B) 和 Llama 3.1 (8B)
- 数据集:
- 端到端吞吐量测试使用合成数据来模拟长序列训练。
- 收敛性验证使用了
allenai/tulu-v2-sft-mixture【Ivison et al. (2023), Camels in a changing climate: Enhancing lm adaptation with tulu 2, 2023】和HuggingFaceH4/ultrafeedback_binarized【Tunstall et al. (2023), Zephyr: Direct distillation of lm alignment, 2023】。
-
超参数: 详细信息见附录A。
实验结果
5.1 端到端训练吞吐量
- 实验内容: 在Llama-2模型的三个规模(7B, 13B, 70B)上,针对四种下游任务(SFT, LoRA, DPO, RM)和不同序列长度,评估了端到端训练吞吐量,并与两种密集掩码方法进行了比较。
-
实验结果:
- 更高吞吐量: 与具有二次内存复杂度的密集掩码方法相比,FLASHMASK实现了更高的吞吐量。具体来说,在FlashAttention密集掩码支持的最大序列长度上,FLASHMASK实现了1.65倍至3.22倍的提升。
- 线性内存开销: FLASHMASK的线性内存开销使其能够支持更长的序列长度。在Llama-2 7B LoRA训练中,FLASHMASK支持高达544K的序列长度,而其他方法仅限于64K。
-
分析结论: FLASHMASK能显著提高大语言模型的训练吞吐量,降低训练成本,并通过线性内存开销支持更长的上下文。
- 图表引用: 实验结果见图2。内存开销曲线见图4(b)。
图 2: 在SFT、LoRA、DPO和RM四个下游训练任务中,针对不同Llama2模型规模,评估了不同序列长度下的端到端训练吞吐量。
5.2 端到端训练收敛性验证
- 实验内容: 在Llama 3.1 8B模型上,针对四种微调和对齐任务,进行了端到端训练,以验证FLASHMASK在不牺牲收敛精度的情况下提升速度。实验在确定性和非确定性控制条件下进行。
-
实验结果:
- 在启用确定性控制时,FLASHMASK和FlashAttention密集掩码的损失曲线完全重合,表现出相同的数值行为。
- 在禁用确定性控制时,两种方法都表现出相同的损失收敛趋势。
-
分析结论: 实验结果确凿地证明FLASHMASK是一种精确算法,能够保持收敛精度。
- 图表引用: 损失曲线如图3所示。
图 3: 端到端训练损失。
5.3 稀疏性相关实验
- 实验内容: 为了验证FLASHMASK的计算复杂度与块稀疏度 $\rho$ 呈 $O((1-\rho)N^2)$ 的关系,在三种不同的掩码情况下(序列长度32K,头维度128,32个头),测试了不同稀疏度水平下的核函数执行延迟。
- 实验结果: 如图4(a)所示,核函数执行延迟与稀疏度之间存在线性关系。
- 分析结论: 实验结果证实了FLASHMASK的计算复杂度与注意力掩码的块稀疏度成正比,从而验证了其通过利用稀疏性来跳过计算的有效性。
- 图表引用: 实验结果见图4(a)。
图 4: (a) 不同稀疏度下的核函数执行延迟,(b) 内存使用情况,Y轴使用以10为底的对数刻度。
5.4 核函数性能比较
- 实验内容: 为了全面评估FLASHMASK在常见注意力掩码模式下的表达能力和计算效率,与FlexAttention进行了核函数级别的比较。实验涵盖了12种不同的掩码情况,序列长度为8K、32K和128K,头维度为64和128。
- 实验结果: 如图5所示,在所有情况下,FLASHMASK在前后向传播的总TFLOPs/s方面均优于FlexAttention,提升幅度从12.1%到60.7%不等。FLASHMASK在A100 GPU上达到了理论最大FLOPs/s的37.8%到62.3%。
- 分析结论: FLASHMASK在核函数层面具有卓越的计算效率,显著优于最新的同类方法FlexAttention。
- 图表引用: 性能比较见图5。
图 5: 在A100-SXM 80G GPU上的核函数前向和后向速度(头维度128,BF16)。FlexAttention使用PyTorch 2.6.0.dev20240920+cu124。
A7 补充细节
6 局限性与未来方向
局限性。尽管FLASHMASK在处理长序列的注意力机制效率方面取得了显著进展,但它也存在局限性。其列式掩码表示法将内存复杂度从 $O(N^2)$ 降低到 $O(N)$,为长序列训练节省了大量内存,并有效地捕捉了最常见的掩码模式。然而,它无法表示任意掩码,特别是那些在单列内具有不规则掩码区域的掩码。极端情况,如完全随机的掩码,对表示和高效计算都构成了挑战。
未来方向。未来的研究应专注于开发更复杂的稀疏表示法,以同时最大化表达能力和计算效率,特别是那些适合采用平铺技术以实现高性能核函数的表示法。将FLASHMASK扩展以利用更新架构(如NVIDIA的Hopper)的特性,可以进一步提升性能。此外,虽然我们当前的实现是基于PaddlePaddle【Ma et al. (2019), Paddlepaddle: An open-source deep learning platform from industrial practice, 2019, Frontiers of Data and Domputing】框架,但将FLASHMASK集成到其他流行的深度学习框架中可以扩大其影响力和可及性。这些努力旨在解决当前的局限性,同时将FLASHMASK的应用范围扩展到更广泛的任务中,为高效Transformer模型处理长序列的持续发展做出贡献。
A5 结论
本文介绍了FLASHMASK,这是FlashAttention算法的一项创新性扩展。它引入了一种列式稀疏掩码表示法,以高效处理Transformer模型中广泛的注意力掩码模式。我们的方法将内存复杂度从 $O(N^2)$ 降低到 $O(N)$,从而能够处理更长的序列,这对于现代大语言模型至关重要。通过将此表示法集成到FlashAttention算法中并实现优化的核函数,FLASHMASK利用注意力掩码中的稀疏性来跳过对完全掩码块的计算,而不牺牲计算精度。这种策略性方法实现了显著的计算加速,与现有的FlashAttention密集方法相比,在大语言模型的微调和对齐训练中,端到端性能提升了1.65倍至3.22倍。此外,FLASHMASK显著降低了与注意力掩码存储相关的内存开销,从而扩展了对更长序列建模的支持。另外,FLASHMASK在核函数TFLOPs/s方面比最新的同类产品FlexAttention快12.1%至60.7%,在A100 GPU上达到了理论最大FLOPs/s的37.8%至62.3%。我们的方法已在大语言模型的下游任务上得到验证,我们期待它在工业界得到广泛应用。
A6 附录
A.1 反向传播算法细节
FLASHMASK反向传播的实现细节。FLASHMASK反向传播的详细实现在算法2中给出。与前向传播类似,我们预先计算LTS、LTE、UTS和UTE的最大值和最小值。这些预计算的值 $LTS_j^{min}, LTS_j^{max}, LTE_j^{min}, LTE_j^{max}, UTS_j^{min}, UTS_j^{max}, UTE_j^{min}, UTE_j^{max}$ 可以直接加载到寄存器中并保持常驻,因为反向计算以列并行模式运行。此外,$LTS_j, LTE_j, UTS_j$ 和 $UTE_j$ 可以在内层循环(第10-11行)之外加载到SRAM中,从而将对HBM的访问次数减少到 $4 \times T_c$。在内层循环中,FLASHMASK的计算逻辑与前向传播的计算逻辑保持一致。
A.2 端到端训练吞吐量
长序列训练的必要性与数据构建。诸如Llama 3.1、GPT-4、Claude系列和谷歌的Gemini等近期模型支持超过128K词元的序列建模。FLASHMASK凭借其减少的内存开销,有助于使用更长的上下文进行训练。然而,现有的公开DPO和RM数据集缺乏超过128K词元的训练数据。为了全面评估FLASHMASK,我们构建了合成数据来模拟长序列训练,并验证端到端吞吐量的提升。我们在涉及大语言模型微调和对齐训练的四个下游任务中验证了我们的方法:SFT、LoRA、DPO和RM。
A.2.1 数据构建方法
合成数据的构建过程。对于端到端训练,为了真实地模拟真实数据集的分布,我们需要在文档的序列长度内区分源词元和目标词元。此外,数据构建方法与核函数实验中使用的方法不同。给定最大训练序列长度 $N$ 和文档数量范围 $D \in [1, 10]$,我们首先随机抽样文档数量,然后抽样每个文档的序列长度,使总序列长度等于 $N$。最后一个文档被视为填充。对于RM训练,应用了特殊约束:对于 $N \in (0, 4096]$,$D \in [1, 3]$;对于 $N \in (4096, 8192]$,$D \in [1, 4]$。在抽样期间,我们为SFT、LoRA和DPO设置的最小文档长度为128,为RM设置为512。SFT、LoRA、DPO的填充长度不超过128,RM不超过512。
源/目标词元的划分。假设一个文档的序列长度为 $L$,它被进一步划分为源词元和目标词元。SFT和LoRA表示为(问题,答案)对,DPO表示为(问题,答案1,答案2),RM有2到6个答案,我们将其标准化为6个答案:(问题,答案1,...,答案6)。因此,根据训练任务,$L$ 被划分为一个问题和 $k$ 个答案,其中SFT和LoRA的 $k$ 为1,DPO为2,RM为6。每个答案的长度是从范围 $[\frac{0.1 \times L}{k} + 0.1, \frac{0.2 \times L}{k} + 0.2]$ 中随机确定的,使得每个答案的长度约为问题长度的10%到20%。因此,问题的长度计算为 $L$ 减去所有答案的总长度。对于每个序列长度 $N$,我们收集了240个有效样本,并根据稀疏度 $\rho$ 将它们分为10个区间,如图6所示。
图 6: 用于端到端训练吞吐量测试的合成数据集的稀疏度分布。
表 1: Llama2不同规模模型的超参数和分布式配置。
A.2.2 实验配置和分布式策略
实验设置。我们评估了Llama2不同模型规模(7B、13B、70B)在4K到544K序列长度范围内的性能,并与两种密集方法进行了比较:Vanilla Attention和FlashAttention DenseMask。所有端到端吞吐量实验都在四台服务器上进行,每台服务器配备八个NVIDIA A800-SXM 80G GPU,共计32个GPU。实验目标不是为每个配置优化峰值性能,而是评估随着序列变长的可扩展性;因此,我们统一启用了完全重计算。模型参数和计算使用BFloat16数据类型,而梯度累积和通信使用Float32。不同规模的超参数和分布式策略详见表1。
A.2.3 端到端训练内存消耗
内存消耗分析。主论文中的图2报告了端到端训练吞吐量。我们还记录了峰值内存消耗,如图7所示。值得注意的是,FLASHMASK的内存使用增长速度明显慢于密集方法。然而,该图也表明,随着序列长度的增加,FLASHMASK的内存消耗仍然迅速攀升,这主要是由于更长序列导致的激活内存增加,如表2所示。“Param & Opt State”列表示参数、梯度和优化器状态的内存消耗,应用了分片阶段1。“Activations”指32个解码器层输入的内存消耗。“Peak Mem One Layer”表示启用完全重计算时的峰值内存使用量。“Total”表示不带注意力掩码的FlashAttention的总内存消耗。
图 7: 在SFT、LoRA、DPO和RM四个下游训练任务中,不同Llama2模型规模下,随序列长度变化的端到端训练峰值内存消耗。
表 2: 在Llama-2 7B模型上,不带注意力掩码的FlashAttention与FLASHMASK的内存消耗比较。观察到的总内存占用差异归因于内存碎片效应。
A.3 端到端训练收敛性验证
实验配置。我们选择了Llama 3.1 8B模型来验证在四个涉及大语言模型微调和对齐训练的下游任务中的收敛性:SFT、LoRA、DPO和RM。SFT和LoRA使用相同的数据集,使用allenai/tulu-v2-sft-mixture【Ivison et al. (2023), Camels in a changing climate: Enhancing lm adaptation with tulu 2, 2023】进行验证。对于都使用(问题,答案)数据格式的DPO和RM,我们使用HuggingFaceH4/ultrafeedback_binarized【Tunstall et al. (2023), Zephyr: Direct distillation of lm alignment, 2023】数据集进行验证。我们统一应用了线性学习率衰减策略,预热步数设为总训练步数的3%。AdamW优化器的参数为$\beta_1 = 0.9$和$\beta_2 = 0.999$。模型参数和计算使用BFloat16数据类型,梯度累积和通信使用Float32。最大训练序列长度设为8K。分布式并行结合了分片和张量并行。其他超参数列于表3。
A.4 稀疏性相关实验
测试方法。每次运行前,我们进行10次预热迭代,然后运行100次核函数计算,使用CUDA Events记录平均执行时间(毫秒)。我们的目标是验证FLASHMASK的计算复杂度与注意力掩码中的块稀疏度成线性关系。我们报告了核函数前向和后向传播的总延迟。此验证在序列长度为32K的情况下,针对三种常见的掩码类型进行:因果文档掩码、共享问题掩码和文档掩码,分别对应于大语言模型中的SFT、DPO/RM等下游训练任务,以及像NaViT这样的视觉模型的预训练。
表 3: 端到端训练损失收敛性验证的配置。
A.4.1 数据构建方法
不同掩码类型的数据生成。因果文档掩码和共享问题掩码是因果注意力类型,其块稀疏度值在[0.5, 1.0]范围内,而文档掩码是双向注意力类型,其块稀疏度值在[0.0, 1.0]范围内。我们将稀疏度值划分为多个桶:因果类型为10个桶,双向类型为20个桶,每个桶的间隔为0.05,确保每个桶的样本数量在10到20之间。
因果文档掩码与文档掩码数据生成。对于因果文档掩码,给定最大序列长度,我们将文档数量限制在[2, 20]。我们随机抽样文档数量,然后抽样每个文档的长度,使得总序列长度等于最大序列长度。每次抽样后,我们计算块稀疏度并将其分配到相应的桶中,直到每个桶都满足所需的样本数量。文档掩码的数据抽样过程类似;为确保覆盖所有稀疏度级别,文档数量限制在[2, 10]。
共享问题掩码数据生成。共享问题掩码的数据抽样略有不同。文档数量限制在[1, 5]。我们首先抽样每个文档的长度,使其总和等于给定的最大序列长度。然后将每个文档划分为一个问题和2到6个答案。同样,在每次抽样后,我们计算块稀疏度并将其分配到适当的桶中,直到所有桶都得到充分填充。
样本总量。总共,我们为因果文档掩码抽样了182个样本,为共享问题掩码抽样了175个样本,为文档掩码抽样了374个样本。
A.5 核函数性能比较
A.5.1 测试方法
性能评估指标。FLASHMASK和FlexAttention都利用注意力掩码中的稀疏性来跳过完全被掩码的块,从而减少冗余计算。为了提供直观的比较,我们采用TFLOPs/s指标进行评估。对于每个测试案例,我们评估前向和后向计算。每次运行前,我们进行10次预热迭代,然后运行100次核函数计算,使用CUDA Events记录平均执行时间(毫秒)。根据注意力掩码中的块稀疏度,我们计算单次运行的FLOPs,随后计算TFLOPs/s。
A.5.2 数据构建方法
测试配置。我们在不同的批处理大小和序列长度(8K、32K、128K)、不同的头维度(64、128)和头数量下进行了详细比较。我们将总词元数固定在128K;通过改变序列长度,我们计算出相应的批处理大小。隐藏层大小固定为4096,通过改变头维度可以确定头数量。
数据生成。为了在给定序列长度下涵盖更广泛的注意力掩码块稀疏度情况,我们使用构建的数据进行测试。给定一个测试序列长度,我们将文档计数范围定义为$D \in [Doc_{min}, Doc_{max}]$。我们首先抽样文档数量,然后抽样每个文档的长度,使得总长度等于测试序列长度。对于共享问题掩码类型,我们进一步将每个文档划分为一个问题和2到6个答案。8K序列长度的文档计数范围是[3, 7],32K是[10, 14],128K是[11, 15]。对于每个序列长度,我们生成了五个测试数据样本。
A.5.3 实验结果
性能对比。图8展示了当头维度为64时,FLASHMASK和FlexAttention在前向和后向传播总TFLOPs/s的比较。主论文中的图5展示了头维度为128时的类似结果。在所有情况下,FLASHMASK在前向和后向传播的总TFLOPs/s方面均优于FlexAttention,提升幅度从4.2%到53.6%不等。FLASHMASK在A100 GPU上达到了理论最大FLOPs/s的33.6%到55.1%。表4至9详细列出了每个测试掩码案例的稀疏度,以及前向和后向计算的延迟、TFLOPs和TFLOPs/s。
图 8: 在A100-SXM 80G GPU上的核函数前向和后向速度(头维度64,BF16)。FlexAttention使用PyTorch 2.6.0.dev20240920+cu124。
表 4: 核函数速度详情 (8K, 头维度 128, BF16) on A100-SXM 80G GPU.
表 5: 核函数速度详情 (32K, 头维度 128, BF16) on A100-SXM 80G GPU.
表 6: 核函数速度详情 (128K, 头维度 128, BF16) on A100-SXM 80G GPU.
表 7: 核函数速度详情 (8K, 头维度 64, BF16) on A100-SXM 80G GPU.
表 8: 核函数速度详情 (32K, 头维度 64, BF16) on A100-SXM 80G GPU.
表 9: 核函数速度详情 (128K, 头维度 64, BF16) on A100-SXM 80G GPU.
B FLASHMASK在推理中的应用
推理阶段的有效性。在我们论文的主体部分,我们专注于FLASHMASK在大型模型训练阶段的应用。然而,需要强调的是,FLASHMASK在推理阶段同样有效。在本附录中,我们提供了详细的实验结果,展示了FLASHMASK在推理中的功效,并与包括FlashInfer【Ye et al. (2025), Flashinfer: Efficient and customizable attention engine for llm inference serving, 2025, arXiv】在内的最先进的注意力实现进行了比较。
B.1 实验设置
实验环境。我们的实验在一台NVIDIA A100-SXM 80G GPU上进行,使用FlashInfer 0.1.6版本、CUDA 12.1、PyTorch 2.4和BF16数据类型。我们将批处理大小设为1,查询/输出头为32个,键/值头为8个,每个头的维度为128。评估包括典型的注意力掩码,如因果文档掩码、文档掩码和共享问题掩码。
数据适配。为了与FlashInfer的稀疏掩码表示(其中掩码块大小 $R=64$)兼容,我们调整了A.5.2节的数据集,使每个子文档序列长度都能被64整除。我们根据FlashInfer的块稀疏行(BSR)API参数 $R$ 和 $C$ 定义了掩码块大小,并将平铺块大小与核函数的操作维度相匹配。
B.2 与FlashInfer的比较
性能对比。我们在不同序列长度(8K、32K和128K词元)下,将FLASHMASK与FlashInfer的密集掩码API(single_prefill_with_kv_cache)和稀疏掩码API(BlockSparseAttentionWrapper)进行了比较。结果总结在表10至14中。
效率分析。FLASHMASK在TFLOPs/s方面始终优于FlashInfer的密集和稀疏实现,特别解决了FlashInfer密集掩码API观察到的低效问题。虽然使用稀疏掩码的FlashInfer随着掩码块大小的增加($R, C \ge 16$)显示出性能提升,但由于现实世界应用中注意力模式的性质,这样大的块大小很少实用。
性能瓶颈分析。在FlashInfer的single_prefill_with_kv_cache实现中(见prefill.cuh第1234–1241行),逐词元的密集掩码通过在完全被掩码的块上执行不必要的计算导致了显著的低效。此外,在FlashInfer的BlockSparseAttentionWrapper中,较小的掩码块大小会增加填充后的批处理大小(nblks(padded_batch_size, 1, num_kv_heads)),由于核函数超参数调整不佳而对性能产生负面影响。相比之下,FLASHMASK只高效地计算所需的平铺块,避免了冗余计算,因此实现了更高的TFLOPs/s。
表 10: 在8K、32K和128K序列长度下因果文档掩码的性能比较
表 11: 在8K、32K和128K序列长度下共享问题掩码的性能比较
B.3 掩码块大小的影响
掩码块大小对性能的影响。对于文档掩码,我们研究了改变掩码块大小($R/C$)对性能的影响。尽管FlashInfer DenseMask和FLASHMASK不使用特定的 $R/C$ 值,但为完整起见,我们将它们包含在比较中。序列长度为8K、32K和128K词元的结果分别呈现在表12、13和14中。
表 12: 在8K序列长度下,不同掩码块大小的文档掩码性能
表 13: 在32K序列长度下,不同掩码块大小的文档掩码性能
表 14: 在128K序列长度下,不同掩码块大小的文档掩码性能
💬 评论讨论
欢迎在这里分享您的想法和见解!