FLEX ATTENTION: A PROGRAMMING MODEL FOR GENERATING OPTIMIZED ATTENTION KERNELS
文章作者: Juechu Dong, Boyuan Feng, Driss Guessous, Yanbo Liang, Horace He
A1 主要贡献
本文旨在解决当前深度学习中注意力机制优化所面临的一个核心问题:高性能的注意力核(如FlashAttention)虽然效率极高,但其单一固化的设计(monolithic nature)导致其灵活性不足,难以支持研究人员提出的各种新型注意力变体。这造成了一种“软件彩票”(software lottery)现象,即如果一个新的注意力变体不被现有核所支持,其进一步的研究和应用将受到缓慢运行时间和内存限制的严重阻碍。此外,传统的编译器方法难以自动生成高效的融合注意力核,因为这涉及到复杂的代数重写和对块稀疏性的处理。
为应对这一挑战,本文提出了FlexAttention,一个新颖的、由编译器驱动的编程模型,其核心贡献和创新点如下:
-
灵活的编程模型:FlexAttention提供了一个通用的注意力编程模型,而不是一个为特定变体设计的受限接口。通过观察,研究者发现许多注意力变体可以被定义为在进行softmax之前对中间分数矩阵进行的分数修改。基于此,FlexAttention引入了两个核心构件:一个分数修改可调用函数(
score_mod
)和一个注意力掩码可调用函数(mask_mod
)。用户只需通过编写几行符合PyTorch习惯的代码,定义如何根据位置信息更新分数或计算布尔掩码,即可实现新的注意力变体。该模型不仅支持了如Alibi、文档掩码(Document Masking)、PagedAttention等多种现有变体,还通过嵌套score_mod
和对mask_mod
进行布尔运算,轻松实现了注意力变体的组合,解决了变体数量组合爆炸的问题。 -
基于手写模板的代码生成:FlexAttention利用
torch.compile
框架,在编译期间将用户定义的score_mod
和mask_mod
函数降级(lower)并生成代码,然后将这些代码块嵌入到一个手写的高度优化的注意力核模板中。这种方法结合了手写核的极致性能和前端编程的灵活性。torch.compile
的降级框架与PyTorch原生兼容,能够灵活处理用户定义的各种mod
函数,并支持通过torch.autograd
自动生成反向传播过程。 -
块稀疏性优化:为了利用掩码(masking)引入的稀疏性,FlexAttention引入了一种名为
BlockMask
的预计算数据结构。BlockMask
是一个小矩阵,用于追踪分块(tiled)后的分数矩阵中哪些块被完全掩码。借助torch.vmap
,系统可以从用户定义的mask_mod
自动生成BlockMask
。这使得计算可以跳过被完全掩码的分数矩阵块,从而节省计算量,而无需加载一个庞大的逐元素注意力掩码。对于部分掩码的块,系统仍然可以利用mask_mod
进行逐元素的掩码,以保持语义的正确性。此外,BlockMask
被实现为一个索引向量,这种设计还可以直接服务于PagedAttention中的地址映射需求。
A2 背景知识
2.1 注意力变体
注意力机制的核心作用。注意力机制是Transformer模型的核心,扮演着至关重要的角色。每个注意力层接收三个输入张量:查询Q ∈ RB×H×Q_LEN×D,以及键K和值V ∈ RB×H×KV_LEN×D,其中B是批量大小,H是注意力头的数量,D是特征维度,Q_LEN和KV_LEN分别代表查询和键值的长度。注意力机制首先计算一个分数矩阵S ∈ RB×H×Q_LEN×KV_LEN,该矩阵通过让每个查询词元(token)关注到每个键词元来编码上下文信息。
通过分数矩阵修改增强注意力。接着,注意力机制计算输出SV ∈ RB×H×Q_LEN×D,通过分数矩阵对值特征进行加权。基于这个基础算法,机器学习研究人员正在探索通过修改分数矩阵S来增强其功能,以实现更有效和高效的上下文提取。例如,邻域注意力(Neighborhood Attention)【Hassani & Shi, 2022, Dilated neighborhood attention transformer】和滑动窗口注意力(Sliding Window Attention)【Beltagy et al., 2020b, Longformer: The long-document transformer】使每个查询词元只关注其邻近的键词元,以降低处理大图像和长序列时的计算和内存复杂性。Softcapping【Team et al., 2024, Gemma 2: Improving open language models at a practical size】增加了一个tanh层来防止logits过度增长。Alibi【Press et al., 2022, Train short, test long: Attention with linear biases enables input length extrapolation】在分数矩阵中嵌入了一个逐元素的偏置,该偏置惩罚距离较远的词元,从而使在短输入上训练的模型能够在长提示上表现良好。这个巨大的设计空间激发了为注意力变体的训练和推理设计高效编译器的需求。
2.2 最先进的注意力实现
FlashAttention的成功与局限。为了高效支持注意力机制,许多核函数被手动实现。其中,具备IO感知能力的FlashAttention【Dao, 2024, Flashattention-2: Faster attention with better parallelism and work partitioning】已成为最广泛采用的解决方案之一。FlashAttention通过避免物化庞大的分数矩阵S并即时计算它,显著减少了内存访问,实现了大幅度的加速。最近发布的FlashAttention v3利用先进的硬件特性和手动性能调优进一步加速了注意力计算。尽管FlashAttention提供了最先进的性能,但它专门为有限的几种注意力变体量身定制(如表1所示)。最近的一项工作FlashMask【Wang et al., 2024, Flashmask: Efficient and rich mask extension of flashattention】通过列式稀疏表示扩展了FlashAttention,以支持更多的掩码设计。然而,它在分数修改方面仍然缺乏灵活性,并且对于复杂的掩码会增加较大的开销。
手动适配的困境。鉴于注意力变体表现出多样化的计算特性,如稀疏性和局部性,需要大量的手动工作来调整现有的注意力实现以支持众多的变体。这种灵活性和高效核函数的缺乏,已经成为机器学习研究人员探索新型注意力变体的一大障碍。
2.3 机器学习编译器
现有编译器的挑战。许多编译器被构建用于加速机器学习工作负载。torch.compile
【Ansel et al., 2024, Pytorch 2: Faster machine learning through dynamic python bytecode transformation and graph compilation】使用TorchDynamo从任意代码中捕获计算图,并用TorchInductor优化该图。TVM【Chen et al., 2018, Tvm: an automated end-to-end optimizing compiler for deep learning】允许用户描述核的计算并生成优化代码。Mirage【Wu et al., 2024, A multi-level superoptimizer for tensor programs】利用µgraph探索算子级别的优化机会。然而,这些机器学习编译器在注意力变体上未能提供令人满意的性能,原因在于注意力中特殊的计算模式。首先,注意力需要融合两个矩阵-矩阵乘法(即QKT和SV),而现有的机器学习编译器通常专注于优化单个矩阵-矩阵乘法。其次,如FlashAttention【Dao, 2024, Flashattention-2: Faster attention with better parallelism and work partitioning】所示,在线softmax(online softmax)显著减少了内存访问并提高了性能。由于在线softmax是为注意力设计量身定制的,现有通用的机器学习编译器对其支持不佳。据我们所知,我们是第一个在为任意注意力变体提供编译支持的同时,还能达到最先进性能的工作。
A2 方法细节
3 前端设计与实现
本节中,我们提出了一个能够统一各种注意力变体的抽象。这种抽象使程序员能够轻松地表达注意力语义,而无需担心实现细节和核函数性能。
3.1 统一抽象
两种通用模式的归纳。尽管机器学习研究人员设计了众多的注意力变体,我们将这些变体统一为两种普遍共享的模式。第一种模式是通过掩蔽(mask out)来自特定词元的特征来管理上下文关系。例如,因果掩码(causal mask)强制模型根据先前的词元进行预测,而忽略未来的词元。滑动窗口掩码(sliding window mask)【Beltagy et al., 2020b, Longformer: The long-document transformer】则强制模型专注于附近词元的局部上下文。
精细化的分数调整。第二种模式以更精细的方式进一步调整词元间的注意力分数。例如,Alibi【Press et al., 2022, Train short, test long: Attention with linear biases enables input length extrapolation】根据两个词元的相对位置调整注意力分数,这有助于模型更好地理解上下文并处理长序列。它还根据注意力头的维度为分数添加偏置,使不同的头专注于不同类型的信息。
score_mod
与mask_mod
的提出。基于这些洞察,我们设计了一个统一的抽象,包括一个分数修改可调用函数(score_mod
)和一个注意力掩码可调用函数(mask_mod
)。它使用户能够通过定义如何根据位置信息更新分数或计算布尔掩码,在符合PyTorch习惯的代码中构建新的注意力变体。
def mask_mod(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool
def score_mod(score: T, batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> T
具体来说,考虑一个分数矩阵S ∈ RB×H×M×N。mask_mod
接收输入位置,并指定相应的分数标量是否被设置为-inf。score_mod
额外接收任意类型T(例如,bfloat16或float32)的分数标量,并根据其位置对分数进行更新。
映射到具体示例。我们的统一抽象捕捉了注意力变体中的两个重要模式,并实现了自动化的编译器优化。图1展示了4个注意力变体的示例,每个都已被广泛研究和优化。然而,目前缺乏一个适用于所有这些变体的统一优化方法。相反,我们的抽象在指定掩码和分数矩阵修改方面提供了完全的灵活性,使机器学习研究人员能够轻松探索新颖的注意力变体。
图1. 注意力变体及mask_mod
示例,包括因果掩码、滑动窗口掩码、文档掩码和PrefixLM掩码。
示例详解。正如我们前面讨论的,因果掩码强制一个词元的查询特征只关注来自先前词元的键和值特征。这可以简单地表示为q_idx ≥ kv_idx
。由于因果掩码与批次和头的维度无关,mask_mod
不涉及b
和h
输入。滑动窗口掩码要求一个词元只关注某个滑动窗口大小内的邻近词元。我们可以轻松地将此要求描述为q_idx − kv_idx ≤ window
。文档掩码用于同时训练多个文档,并要求一个词元只关注同一文档中的其他词元。我们可以将“文档ID”记录为一个从词元索引到文档索引的映射,并且仅当document_id[q_idx] == document_id[kv_idx]
时保留分数。Alibi偏置通过添加一个偏置标量来修改每个分数标量,该偏置标量随查询索引和键值索引之间的距离而缩放。我们可以轻松地将其描述为score + alibi_bias[h] * (q_idx − kv_idx)
。请注意,我们可以轻松支持其他分数修改,如缩放分数标量甚至非线性变换。
为何需要区分mask_mod
。mask_mod
是score_mod
的一个特例,任何mask_mod
在语义上都可以轻松转换为score_mod
。然而,我们观察到区分这两个API有两个根本原因。首先,score_mod
对每个分数标量都应用了昂贵的修改。当将mask_mod
转换为score_mod
时,我们需要将每个分数与1或0相乘,这会带来额外的开销。其次,与score_mod
相比,mask_mod
提供了额外的语义信息,即某些分数计算可以被跳过。我们利用这一信息来进一步加速注意力计算,具体细节在4.2小节中阐述。
3.2 通过逻辑融合实现可组合性
组合现有变体的趋势。随着许多注意力变体的设计问世,最近的一个趋势是组合现有的注意力变体以进一步提取语义信息。我们支持通过and_mask
和or_mask
进行逻辑融合,以实现掩码设计的可组合性。通过接收两个mask_mod
函数,我们可以通过应用逐元素的逻辑运算来自动组合它们。生成的mask_mod
可以进一步与其他mask_mod
结合,这显著减轻了设计注意力变体时的编程负担。
PrefixLM示例。图1展示了一个逻辑融合的例子。PrefixLM【Raffel et al., 2023, Exploring the limits of transfer learning with a unified text-totext transformer】对前缀输入执行完全的双向注意力,对其余部分执行因果注意力。我们无需使用复杂的条件分支,而是可以构建一个简单的前缀掩码,并通过or_mask
将其与因果掩码组合起来。
4 后端设计与实现
在本节中,我们提出了一个后端设计,该设计能将注意力变体编译成高效的核函数。
4.1 基于模板的降级(Lowering)
结合灵活性与性能。我们构建了一个基于模板的降级系统,用于生成优化的GPU核函数,该系统将灵活性与最先进的注意力优化技术相结合。我们方法的关键创新在于认识到注意力变体主要区别在于其逐点的分数修改,这使我们能够将通用的计算模式模板化。为此,我们构建了基于模板的降级来捕获通用模式,并在分数矩阵上逐元素地应用score_mod
。我们将在4.2小节中介绍利用mask_mod
稀疏性的BlockSparsity设计。
Triton模板与优化技术。基于模板的降级首先利用TorchDynamo【Ansel et al., 2024, Pytorch 2: Faster machine learning through dynamic python bytecode transformation and graph compilation】来捕获score_mod
和mask_mod
的计算图。如图1所示,这两个函数在计算和内存访问方面通常是轻量级的,并且可以与注意力中的其他计算融合。然后,我们构建了一个高度优化的Triton模板以实现高性能。该Triton核函数设计采用了已有的融合注意力核优化技术,包括在线softmax(online softmax)、精细的GPU占用管理、通过分区和广播实现的高效内存处理,以及对分组查询注意力(GQA)【Ainslie et al., 2023, Gqa: Training generalized multi-query transformer models from multi-head checkpoints】的专门支持。
图2. 从用户定义的注意力模式到优化的Triton核函数的编译过程。上图:用户定义的PyTorch函数指定了自定义的掩码(因果掩码)和分数修改(相对位置)操作。下图:torch.compile
捕获这些操作并将其降级为Triton原语。右图:生成的Triton代码块被集成到我们的注意力模板中,该模板处理核心的注意力计算,同时保留用户定义的修改。
动态代码注入。我们在三个手写的注意力核模板(前向、后向和解码)中捕获了这些先进技术。这些模板被设计为可以接受自定义的分数修改代码块,这些代码块是使用torch.compile
从捕获的score_mod
和mask_mod
操作生成的。TorchInductor将这些子图翻译成Triton代码,在运行时将前向和后向的分数修改操作动态注入到预定义的模板中。这种方法产生的注意力核函数在性能上与手动优化的、针对特定变体的实现非常接近。图2展示了我们的降级流水线,说明了用户定义的PyTorch操作如何被翻译成优化的Triton代码并集成到我们的注意力模板中。
4.2 块稀疏性
利用稀疏性加速的挑战。注意力变体通常表现出高度的稀疏性,例如因果掩码带来的50%稀疏性,以及滑动窗口掩码带来的显著更高的稀疏性。一个自然的问题是,我们的编译器能否利用这种稀疏性进行加速?
朴素方法的缺陷。一种朴素的方法是在运行时检查一个位置是否被掩蔽并跳过计算。然而,这会带来巨大的运行时开销,因为它需要遍历所有标量,即使它们被掩蔽了。另一种方法是预先计算一个形状为B × H × Q_LEN × KV_LEN的掩码张量,并根据该张量计算分数标量。然而,这个掩码张量会带来显著的内存开销,这与FlashAttention避免实现分数矩阵的原则相矛盾。
图3. 滑动窗口注意力的BlockMask。左图:分数矩阵被划分为完全可见块(绿色)、部分掩码块(黄色)和忽略块(白色)。右图:BlockMask被编码为一个索引矩阵加上每行的块计数。上图:score_mod
应用于完整和部分块。下图:mask_mod
应用于部分块。
BlockMask
数据结构。FlexAttention实现了一个BlockMask
数据结构来利用稀疏性,同时只增加可忽略的内存开销。我们在编译时预先计算一个块掩码,以消除运行时开销。BlockMask
首先将分数矩阵沿着Q_LEN和KV_LEN维度分割成块。然后,如果一个块中的所有分数标量都被掩码为-inf,BlockMask
就将其指定为非计算块。通过在块级别记录稀疏性,我们无需实现一个大的稀疏矩阵,从而显著减少了内存开销。
BlockMask
的具体实现。具体来说,BlockMask
包含两个张量,一个形状为B × H × Num_Row的kv_num_block
和一个形状为B × H × Num_Row × Num_Col的kv_indices
。这里,Num_Row和Num_Col分别是沿查询和键值维度的行和列的块数。kv_num_block
存储每行的非零块数量,而kv_indices
存储这些非零块的索引。通过访问kv_num_block
,FlexAttention可以跳过被掩码的块,并实现成比例的性能加速。
全块优化。我们还将BlockMask
分为完整块(full blocks)和部分块(partial blocks),以进一步提高性能。我们的关键思想是最小化因应用mask_mod
而产生的运行时开销,并在可能的情况下跳过它。我们根据稀疏性识别出两种类型的块:
1. 部分块:其中一些分数标量被掩码为-inf,需要在运行时逐元素应用mask_mod
。
2. 完整块:其中没有分数标量被掩码,这使我们能够跳过mask_mod
,只应用score_mod
。
对于像因果掩码这样的常见模式,这种优化带来了大约15%的性能提升。图3展示了滑动窗口注意力的BlockMask
。滑动窗口掩码仅应用于部分块,而相对位置分数修改则逐元素地应用于完整块和部分块。
BlockMask
引导的间接内存访问。在注意力中利用稀疏性可以带来显著的性能加速。现有的解决方案,如FlashAttention,会遍历KV_LEN维度,并根据注意力变体手动指定迭代的开始和结束索引。这带来了巨大的编程负担。此外,如果查询词元索引小于键词元索引,还需要额外的努力来确保最后一个块中的注意力分数被正确掩码。
自动化稀疏性利用。FlexAttention利用BlockMask
信息来自动利用注意力变体中的稀疏性。它首先根据kv_num_block
(指定未被掩码的分数矩阵块的数量)来调整每个GPU块中的工作负载。然后,它使用KV_Indices
来映射到下一个要处理的块。在这里,我们利用了一种间接内存访问策略,因为索引不必指向序列中的连续词元。这为编译各种注意力模式(如滑动窗口注意力、局部-全局注意力或自定义稀疏模式)提供了灵活性,而无需修改核函数。
图4. 将完整块和部分块调度到SM。
数据预取流水线。FlexAttention沿着Q_LEN维度将分数矩阵划分为瓦片(tiles),以便在多个流处理器(SM)上并行计算。每个SM沿着KV_LEN维度迭代并处理一行块(图4)。需要特别注意对这些块的数据获取和计算进行流水线化。当第一个分数块正在由第一个kv瓦片计算时,第二个kv瓦片(图中虚线边框所示)正在从HBM预取到SRAM。我们的BlockMask
访问策略消除了检查分数标量是否被掩码的条件分支,从而实现了高效的流水线处理并隐藏了迭代之间的数据访问延迟。
开销分析。BlockMask
表示通过在粒度和开销之间取得谨慎平衡,实现了内存效率。尽管我们存储了辅助张量来编码稀疏模式,但它们的内存占用与O(⌈Q_LEN/BS⌉ × ⌈KV_LEN/BS⌉)成比例,其中BS是块大小(默认为128)。这远小于完整分数矩阵所需的O(M × N)。
5 案例研究
5.1 分页注意力(Paged Attention)变体
PagedAttention的背景与挑战。PagedAttention【Kwon et al., 2023a, Efficient memory management for large language model serving with pagedattention】已被广泛部署于处理一批句子请求的推理场景,以减少GPU内存消耗。然而,当前的PagedAttention设计需要手动重写注意力核函数,以支持其不规则内存访问的页表查找。此外,它与特定的注意力掩码设计紧密耦合,限制了与其他注意力变体的可组合性。在本小节中,我们将解释FlexAttention如何为任意注意力变体启用分页支持,而无需手动重写核函数,同时仍能提供高性能和低内存使用。
页表机制。PagedAttention引入了一个页表来有效管理KV缓存并减少其内存使用。如图5(a)所示,一个逻辑KV缓存存储每个序列的键和值特征,其形状通常为B × Max_len × D,其中B是批量大小,Max_len是最大上下文长度,D是特征维度。由于句子之间的序列长度差异很大,并且可能远短于Max_len,逻辑KV缓存的很大一部分仍未被使用。为了解决这个碎片化问题,PagedAttention转而分配一个形状为1 × Max_token × D的物理KV缓存,其中Max_token是所有句子组合起来的最大词元数。将句子存储在物理KV缓存中减少了碎片,并实现了诸如句子间KV块共享之类的内存优化。
页表映射。PagedAttention维护一个作为二维矩阵的页表,该页表将批次索引和逻辑KV索引映射到相应的物理KV索引。在注意力核函数内部,使用一次分散(scatter)操作通过该矩阵有效地转换索引,同时最小化内存访问开销。在运行时,一旦一个句子到达其结束词元,我们就会清除页表中相应的行,并用新的请求更新它。
图5. FlexAttention与分页注意力的图示。a) 使用页表更新物理KV缓存;b) 转换一个块掩码,特别是KV索引;c) 用于转换mask_mod
和score_mod
的物理到逻辑映射。
融合的间接内存访问。让注意力核函数在物理KV缓存而不是逻辑KV缓存上计算,带来了几个挑战。首先,页表在内存访问中增加了一个额外的间接层,现有设计通常通过手动重写的CUDA核函数来处理。其次,对于已经很复杂的支持多种注意力变体的实现来说,页表施加了巨大的编程负担。
通过BlockMask
转换解决。FlexAttention通过BlockMask
转换解决了这个问题,无需重写核函数。我们的关键思想是,由于BlockMask
已经包含了一层间接内存访问以避免不必要的计算(4.2小节),我们可以将其与来自页表的间接内存访问合并,以进一步跳过不必要的内存访问。具体来说,如图5(b)所示,给定一个注意力变体的BlockMask
,我们取其kv_index
并根据页表将逻辑块索引映射到相应的物理块索引。在运行时,FlexAttention依赖于转换后的kv_index
来从物理KV缓存中访问特定句子的词元。请注意,我们保持kv_num_blocks
不变,因为分页注意力不改变未掩码块的数量。
mask_mod
和score_mod
的转换。FlexAttention支持带有mask_mod
和score_mod
的注意力变体,这些函数利用位置信息(例如,逻辑KV索引)来指定掩码和分数修改。当使用分页注意力并在物理KV缓存上操作时,此位置信息会发生变化。一种朴素的方法是手动重写mask_mod
和score_mod
以适应位置信息的变化。然而,考虑到众多的注意力变体甚至这些变体的逻辑融合(3.2小节),这会增加巨大的编程负担。
自动化编译。我们自动编译mask_mod
和score_mod
以支持分页注意力。
图6. mask_mod
转换
def converted_score_mod(score: T, batch_idx : int, head_idx: int, q_idx: int, physical_kv_idx: int) -> T
具体来说,我们在更新页表时维护一个将物理块索引映射到逻辑块索引的向量,其开销为O(1)。给定物理KV词元索引,我们首先可以计算出物理KV块索引和由于固定块大小产生的偏移量。然后,我们可以查找相应的逻辑KV块索引并重新生成逻辑KV词元索引。最后,我们用生成的逻辑KV词元索引调用用户提供的mask_mod
和score_mod
。
5.2 推理的修改转换
推理时的上下文需求。在注意力变体中,掩码和分数修改可能会随着查询索引的变化而变化,这是通过在mask_mod
和score_mod
中接收q_idx
来实现的。然而,LLM推理通常一次迭代地处理一个查询词元,并需要额外的上下文信息,即已经处理的查询词元数量,我们称之为偏移量(offset)。图6(a)显示了训练和推理期间因果掩码的差异。
自动化转换。为了解决这个问题,我们自动将mask_mod
和score_mod
转换为它们的推理对应版本,如图6(b)所示。具体来说,我们提供一个装饰器(decorator),它接收用户定义的mask_mod
和一个offset
,并生成一个转换后的mask_mod
,该函数会消费偏移量信息。
A3 实验环境
我们评估了FlexAttention在7种流行的注意力修改上的性能(见表1),包括经典的noop(无操作)和causal(因果),流行的位置嵌入alibi【Press et al., 2022, Train short, test long: Attention with linear biases enables input length extrapolation】,局部注意力sliding window(滑动窗口)【Beltagy et al., 2020a, Longformer: The long-document transformer】,完全可见前缀的prefixLM,添加tanh层的soft cap,以及用于批处理不同长度输入的document mask(文档掩码)。这些注意力变体在多头注意力(MHA)和分组查询注意力(GQA)下进行了评估。
这些注意力变体在我们评估的五个基线中得到了不同程度的支持:
* FlashAttention-v2 (FAv2)【Dao, 2024, Flashattention-2: Faster attention with better parallelism and work partitioning】:一个最先进的高性能注意力核,通过核融合、分块和重计算来减少HBM读写。
* FlashAttention-v3 (FAv3)【Shah et al., 2024, Flashattention-3: Fast and accurate attention with asynchrony and low-precision】:一个实验性的核,利用Hopper GPU的新硬件特性进一步提升性能。
* FlashDecoding (FAKV)【Dao et al., 2023, Flashdecoding for long-context inference】:FlashAttention提供的一个为支持kv-cache的推理场景优化的核。
* Scale dot-product attention (SDPA):PyTorch原生的函数式注意力API,支持数学实现、内存高效注意力(mem efficient)【Rabe & Staats, 2022, Self-attention does not need o(n2) memory】、FAv2和cuDNN后端。SDPA能平滑地集成到PyTorch框架中,支持NestedTensor、torch.compile
和CUDA图等优化。
我们使用gpt-fast
【gpt-fast maintainers & contributors, 2023, Accelerating generative ai with pytorch ii: Gpt, fast】和torchtune
【torchtune maintainers & contributors, 2024, torchtune: Pytorch’s finetuning library】在LLaMa3和LLaMa3.1模型【Dubey et al., 2024, The llama 3 herd of models】上评估了端到端的训练和推理性能。gpt-fast
和torchtune
是原生的PyTorch库,为流行模型提供了易于使用的微调和推理方案。它们默认依赖torch.compile
来启用优化(包括CUDA图、核融合和内联、matmul模板、参数冻结等),并使用SDPA进行注意力计算。
我们的实验在Nvidia H100 GPU(功率上限650W,内存带宽限制2.4TB/s)、Nvidia A100 GPU(功率上限330W)和Nvidia A6000 GPU上进行。
A4 实验结果
6.2 注意力核性能
我们通过在不同的序列长度、注意力变体和头数下进行基准测试来表征FlexAttention核的性能,其中kv大小固定为256 MiB,头维度为64,数据类型为bfloat16。
训练性能。如图7所示,在因果掩码的假设下,与FAv2相比,FlexAttention在前向传播中实现了1.00倍至1.22倍的稳定加速,在后向传播中实现了0.86倍至1.05倍的加速。对于我们评估的7种注意力变体,当FAv2支持该变体时,FlexAttention相对于FAv2实现了0.68倍至1.43倍的加速。对于FAv2缺乏原生支持的变体,FlexAttention通过在运行时从mask_mod
计算掩码,避免了实现和加载逐项掩码的需要,与使用逐项注意力掩码的SDPA核相比,实现了5.49倍至8.00倍的加速。
推理性能。如图8所示,对于查询长度为1的情况,FlexAttention的解码性能与FlashDecoding(FAKV)相当,实现了0.93倍至1.45倍的加速。一个例外是:当使用带alibi的GQA时,FlexAttention比FAKV快5.37倍。这是“软件彩票”的一个例子,FAKV对带alibi的GQA缺乏手动优化,其回退方案只提供了最优性能的1/5。相比之下,FlexAttention无需手动调优即可为这种组合保持一致的性能。由于页面限制,我们将邻域注意力的结果放在附录A.1中。
数值精度。与我们的基线相比,FlexAttention没有引入额外的数值误差(图9)。
Flex Attention
表1. 测试的注意力变体及其在FlashAttention、SDPA和FlexAttention中的支持水平
图7. 注意力核速度:前向和后向。上图:在1k到64k的QKV长度上使用/不使用GQA的因果掩码。下图:在16k token长度的QKV上使用/不使用GQA的不同注意力变体。
图8. 注意力核速度:对1个查询词元进行解码。左图:在1k到132k的KV长度上使用/不使用GQA的经典注意力。右图:在16k token长度的KV缓存上使用/不使用GQA的不同注意力变体。
图9. bfloat16和fp16注意力输出与fp64黄金结果的均方根误差(RMSE)。
6.3 端到端性能
FlexAttention将端到端训练性能提升超过2.4倍,推理性能最高提升2.04倍,并且这种加速随着序列长度的增加而表现出良好的扩展性。我们在gpt-fast
和torchtune
库中用FlexAttention替换了SDPA,并在LLaMa3和LLaMa3.1模型上评估了性能。我们表明,FlexAttention能与gpt-fast
和torchtune
中的PyTorch框架很好地集成,实现了与SDPA相同的优化,如CUDA图、参数冻结和核融合。
torchtune的训练性能。我们设置torchtune
在Alpaca【Taori et al., 2023, Stanford alpaca: An instruction-following llama model】数据集上微调LLaMa3-8B。为了高效处理不同长度的输入序列,torchtune
将它们连接成固定长度的锯齿状长序列。这种方法需要一个文档掩码,允许每个输入序列关注自身,同时忽略其邻居。SDPA利用一个预先计算的、大小为B × N × N的布尔掩码,其中B是批大小,N是序列长度。如图10所示,访问这个布尔掩码的成本呈二次方增长,导致当序列长度从2k增加到8k时,训练吞吐量下降了25%。相比之下,FlexAttention使用一个BlockMask
和一个大小为B × N的文档ID张量,并能有效地随序列长度扩展。
图10. 在llama3-8B上的torchtune训练吞吐量。
gpt-fast的推理性能。在图11中,我们展示了FlexAttention将LLaMa3.1-8B的服务性能提升了1.22倍至2.04倍,将LLaMa3.1-70B的性能提升了0.99倍至1.66倍,均与SDPA相比。随着上下文长度的增长,注意力核在每次迭代中的计算占比越来越大,加速效果也随之增加。
图11. 在LLaMa3.1-8B和70B上的gpt-fast推理速度。
6.4 案例研究:分页注意力(Paged Attention)
运行时开销评估。虽然分页注意力将不同长度的句子请求存储在紧凑的物理KV缓存中,一个主要问题是它是否会引入高昂的运行时开销。图12(a)显示了使用和不使用分页注意力的FlexAttention以及FlashAttn-v2的运行时延迟。我们在改变序列长度的情况下展示了运行时延迟,同时保持其他维度不变,批量大小为32,头维度为64,头数为16。总体而言,我们观察到在使用带分页注意力的FlexAttention时,平均运行时开销不到1%,这远低于vLLM【Kwon et al., 2023b, Efficient memory management for large language model serving with pagedattention】报告的20-26%的更高注意力核开销。原因在于我们没有引入任何核函数更改,而是依赖于融合的间接内存访问来支持带分页注意力的FlexAttention。令人惊讶的是,我们甚至观察到在长序列上,带分页注意力的FlexAttention比不带分页注意力的FlashAttn-v2更快,这证明了我们设计的可扩展性。
图12. 使用和不使用分页注意力的运行时(越低越好)。上图:不同序列长度下的延迟。下图:不同页面大小下的延迟。
页面大小的影响。图12(b)展示了页面大小从16到256变化带来的影响。总的来说,我们没有观察到改变页面大小会对性能产生显著影响。请注意,我们在GPU全局内存中管理物理KV缓存,并且不将内存交换到主机磁盘,这减轻了磁盘访问开销。
A5 结论
在本文中,我们提出了FlexAttention,一个用于生成优化注意力核的编程模型。尽管研究人员不断设计新的注意力变体,但他们常常受到缺乏手动调优核的限制,这导致了巨大的编程和性能负担。我们希望FlexAttention能让研究人员探索新的注意力变体,而不再受限于手写核所支持的功能。
A6 附录
A.1 邻域注意力(Neighborhood Attention, NA)
邻域注意力的实现挑战。邻域注意力(NA)【Hassani & Shi, 2022, Dilated neighborhood attention transformer】是一种用于2D图像的局部注意力模式,其中每个像素关注其最近的邻近像素。由于需要将2D邻域嵌入展开为1D,NA掩码非常复杂(图13左上),这使得NA的计算具有挑战性。虽然已经提出了各种先进的展开策略【Hassani et al., 2023, Neighborhood attention transformer; Hassani et al., 2024, Faster neighborhood attention: Reducing the o(n2) cost of self attention at the threadblock level】以提高其块稀疏性,但在高性能注意力核中手动实现这些优化是困难的。
FlexAttention的简化实现。相反,我们展示了通过FlexAttention,平铺NA(Tiled NA)和莫顿曲线NA(Morton curve NA)可以用少于10行的PyTorch代码实现,从而利用NA的稀疏性(图13)并享受其性能优势(图14)。
图13. NA逐项掩码和对应的块掩码。上:朴素的NATTEN掩码。中:2D平铺NA掩码。下:莫顿曲线NA掩码。
图14. 不同映射方式下邻域注意力的性能,随画布大小和核大小变化。上:掩码稀疏度。下:速度。
💬 评论讨论
欢迎在这里分享您的想法和见解!