Flash-Decoding for long-context inference

文章标题: Flash-Decoding for long-context inference
作者/机构: Tri Dao, Daniel Haziza, Francisco Massa, Grigory Sizov

A1 主要贡献

核心问题:
大型语言模型(LLM)的推理(或称“解码”)成本高昂。解码过程是迭代的,每生成一个新词元(token)都需要模型进行一次前向传播。虽然可以通过缓存(KV-cache)来优化,但注意力(attention)操作的计算开销仍会随着上下文长度的增加而显著增长,成为性能瓶瓶颈。这一问题在长上下文场景(如处理长文档、长对话、代码库)和高并发(大批量大小)场景中尤为突出。现有的优化方法存在局限性:
1. 标准矩阵乘法实现: 能够充分利用GPU,但会启动大量内核(kernel),导致中间结果频繁读写全局内存,效率不高。
2. FlashAttention: 为训练场景优化,主要沿批次大小(batch size)和查询长度(query length)维度并行。在推理场景中,查询长度通常为1,若批次大小小于GPU的流式多处理器(SM)数量(如A100为108),会导致GPU利用率极低(例如,批次大小为1时,GPU利用率不足1%),尤其是在长上下文场景下,为了节省内存通常需要设置更小的批次。

研究目标:
本文旨在解决LLM在长上下文推理场景下注意力计算的性能瓶颈,提出一种新的注意力计算技术,以显著提升生成速度,降低推理成本,从而更好地支持需要长上下文的LLM应用。

创新点 (Flash-Decoding):
本文提出了 Flash-Decoding 技术,这是一种专为解码(推理)场景设计的、更快的注意力计算方法。其核心思想是在FlashAttention的基础上,增加了一个新的并行维度——键/值(keys/values)的序列长度维度。
具体而言,Flash-Decoding通过以下步骤实现加速:
1. 分块: 将键(K)和值(V)张量沿序列长度维度分割成多个小块(chunk)。
2. 并行计算: 并行地、独立地计算查询(Q)与每个K/V块的注意力。这一步利用了FlashAttention的高效计算。除了注意力输出外,每个块还会额外计算并写入一个标量:注意力值的对数指数和(log-sum-exp)。
3. 合并与重缩放: 最后,通过一个独立的归约(reduction)步骤,将所有块的输出合并。利用上一步保存的log-sum-exp值对每个块的贡献进行正确缩放,从而得到最终的、精确的注意力输出。

该方法结合了现有方法的优点:像FlashAttention一样,它最大限度地减少了对全局内存的额外读写;同时,它能像标准矩阵乘法一样充分利用GPU,即使在批次大小很小的情况下(只要上下文足够长)。实验表明,该技术可为极长序列的解码带来高达8倍的速度提升。

A3 背景知识/关键Observation/设计原则

多头注意力解码

解码过程与注意力瓶颈。在解码过程中,每生成一个新词元都需要对所有先前的词元进行注意力计算,其核心操作为 softmax(queries @ keys.transpose) @ values。尽管FlashAttention(v1和最近的v2)已在训练场景下对此操作进行了优化,解决了中间结果(如 Q @ K^T)读写带来的内存带宽瓶颈,但这些优化无法直接适用于推理场景,因为两者的瓶颈不同。

FlashAttention在推理中的局限性。对于训练,FlashAttention主要沿批次大小和查询长度维度进行并行化。然而,在推理过程中,查询长度通常为1。这意味着如果批次大小小于GPU上的流式多处理器(SM)数量(例如A100上有108个),该操作将只能利用GPU的一小部分算力。这种情况在使用长上下文时尤为常见,因为长上下文需要较小的批次大小才能装入GPU内存。当批次大小为1时,FlashAttention将只使用不到1%的GPU!

FlashAttention仅在查询块和批次大小上进行并行化,在解码过程中无法占满整个GPU
FlashAttention仅在查询块和批次大小上进行并行化,在解码过程中无法占满整个GPU

矩阵乘法原语的局限性。注意力也可以通过矩阵乘法原语来完成——不使用FlashAttention。在这种情况下,操作能够完全占据GPU,但会启动许多内核来写入和读取中间结果,这并非最优选择。

A2 方法细节

一种更快的解码注意力机制:Flash-Decoding

Flash-Decoding的核心思想。我们的新方法Flash-Decoding基于FlashAttention,并增加了一个新的并行维度:键/值(keys/values)的序列长度。它结合了上述两种方法的优点。与FlashAttention一样,它只向全局内存存储极少的额外数据;然而,只要上下文长度足够大,即使批次大小很小,它也能充分利用GPU。

Flash-Decoding也在键和值上进行并行化,代价是一个小的最终归约步骤
Flash-Decoding也在键和值上进行并行化,代价是一个小的最终归约步骤

Flash-Decoding的工作步骤。Flash-Decoding分三步工作:
1. 首先,我们将键/值分割成更小的块。
2. 我们使用FlashAttention并行计算查询与每个分割块的注意力。我们还为每行和每个分割块额外写入一个标量:注意力值的对数指数和(log-sum-exp)。
3. 最后,我们通过对所有分割块进行归约来计算实际输出,使用log-sum-exp来缩放每个分割块的贡献。

实现原理。这一切之所以成为可能,是因为注意力/softmax可以迭代计算。在Flash-Decoding中,这种迭代计算被用在两个层面:在分割块内部(如同FlashAttention),以及跨分割块以执行最终的归约。

具体实现细节。在实践中,步骤(1)不涉及任何GPU操作,因为键/值块是完整键/值张量的视图(view)。然后,我们有两个独立的内核分别执行步骤(2)和(3)。

使用 Flash-Decoding

可用性。Flash-Decoding已在以下库中提供:
- 在FlashAttention包中,从2.2版本开始。
- 通过xFormers,从0.0.22版本开始,通过 xformers.ops.memory_efficient_attention 调用。分发器(dispatcher)会根据问题规模自动选择使用Flash-Decoding或FlashAttention方法。当这些方法不受支持时,它可以分发到一个实现了Flash-Decoding算法的高效Triton内核。

示例代码。使用LLaMa v2 / CodeLLaMa进行解码的完整示例可在FlashAttention仓库(https://github.com/Dao-AILab/flash-attention/blob/main/examples/inference/decode_llama.py">链接)和xFormers仓库(https://github.com/facebookresearch/xformers/blob/main/examples/Llama/decode_llama.py">链接)中找到。我们还提供了一个用于LLaMa v1/v2模型的高效解码代码的最小示例,旨在实现快速、易读、具有教育意义且易于修改。

A4 实验环境

  • 模型架构:
    • 主要模型:CodeLLaMa-34B,其架构与Llama 2相同。
    • 微基准测试配置:模拟在4个GPU上运行的CodeLLaMa-34B,注意力头配置为16个查询头(维度128)和2个键/值头(分组查询注意力,Grouped-Query Attention)。
  • 硬件配置:
    • GPU:NVIDIA A100。
  • 软件配置与基线:
    • Pytorch: 使用纯PyTorch原语实现的注意力。
    • FlashAttention v2: 业界领先的训练优化注意力库。
    • FasterTransformer: 使用FasterTransformer的注意力内核。
    • Flash-Decoding: 本文提出的方法。
    • Upper bound: 一个理论上限,计算为从内存中读取整个模型及KV缓存所需的时间。
  • 实验任务与数据:
    • 任务:评估不同注意力实现在不同序列长度(从512到64k)下的解码吞吐量(单位:tok/s)。
    • 数据类型:微基准测试使用f16(半精度浮点数)作为输入。
    • 批次大小(Batch Size):实验设置批次大小为1,以模拟长上下文推理的典型场景。

A4 实验结果

CodeLlama 34B 上的基准测试

  • 实验内容: 对比了Pytorch、FlashAttention v2、FasterTransformer和Flash-Decoding在CodeLlama-34B模型上,从512到64k不同序列长度下的解码速度(tok/s)。
  • 实验结果: 对于短序列,所有方法的性能相似。但随着序列长度增加,除Flash-Decoding外的所有方法性能都急剧下降。Flash-Decoding在极长序列上实现了高达8倍的解码速度提升,并且其生成速度受序列长度增长的影响很小。
  • 分析结论: Flash-Decoding在长上下文推理场景下展现出卓越的扩展性,远超其他替代方案。
    除了Flash-Decoding,所有方法在序列长度从512增加到64k时扩展性都很差。在batch size为1的这种情况下,使用Flash-Decoding时,扩展序列长度对生成速度影响很小。
    除了Flash-Decoding,所有方法在序列长度从512增加到64k时扩展性都很差。在batch size为1的这种情况下,使用Flash-Decoding时,扩展序列长度对生成速度影响很小。

组件级微基准测试

  • 实验内容: 在A100上对多头注意力组件本身进行了微基准测试,测量了在不同序列长度下(批次大小为1,输入为f16)的运行时间(单位:微秒)。
  • 实验结果: Flash-Decoding在注意力组件层面比FlashAttention快了高达50倍。在序列长度达到32k之前,Flash-Decoding的运行时间几乎保持不变。
  • 分析结论: 端到端高达8倍的加速得益于注意力组件本身实现的巨大性能提升。运行时间在长序列下保持恒定,证明了Flash-Decoding能够成功地充分利用GPU资源,克服了传统方法因序列增长而导致的性能瓶颈。

A5 结论

本文提出了一种名为Flash-Decoding的新技术,有效解决了大型语言模型在长上下文推理场景下面临的注意力计算瓶颈。传统方法或因GPU利用率不足(如FlashAttention),或因内存读写开销过大(如标准矩阵乘法)而效率低下。Flash-Decoding通过沿序列长度维度进行并行化,将K/V缓存分块处理并高效合并,成功地在保持内存效率的同时充分利用了GPU算力。在CodeLlama-34B模型上的实验表明,对于超长序列,Flash-Decoding可带来高达8倍的端到端解码速度提升,且性能随序列长度增长的衰减远小于其他方法。该技术已集成到主流库FlashAttention和xFormers中,为开发更高效、更强大的长上下文LLM应用提供了关键支持。

A6 附录

致谢

感谢Erich Elsen、Ashish Vaswani和Michaël Benesty提出了分割KV缓存加载的想法。我们还要感谢Jeremy Reizenstein、Patrick Labatut和Andrew Tulloch进行的有价值的讨论,以及Quentin Carbonneaux为xFormers贡献的高效解码示例。我们同时感谢Geeta Chauhan和Gregory Chanan在撰写和推动本文在PyTorch博客上发表方面提供的帮助。