Ring Attention with Blockwise Transformers for Near-Infinite Context

作者/机构: Hao Liu, Matei Zaharia, Pieter Abbeel, UC Berkeley hao.liu@cs.berkeley.edu

A1 主要贡献

核心问题: Transformer架构在处理长序列时面临严峻的内存挑战。其自注意力机制的内存成本与序列长度成二次方关系,导致难以扩展到长序列输入。此外,即使采用内存优化技术,存储每一层的输出激活值仍然是巨大的瓶셔颈,例如,处理一个包含1亿个token的序列,即使模型隐藏维度仅为1024,也需要超过1000GB的内存,这远超现代GPU和TPU不足100GB的高带宽内存(HBM)容量。

研究目标: 本文旨在消除单个设备内存容量对Transformer模型上下文长度的限制,提出一种新方法,使得上下文长度能够随着设备数量线性增长,从而处理极长的序列,同时不引入额外的通信和计算开销,也不对注意力机制进行近似处理。

创新点 (Ring Attention):
本文提出了一种名为环形注意力与分块Transformer(Ring Attention with Blockwise Transformers, 简称Ring Attention)的新方法,其核心创新在于:
1. 结合分块计算与设备间通信: 该方法利用了自注意力和前馈网络的分块计算(blockwise computation)特性。它将长序列沿序列维度切分成块,并均匀分布到多个设备上。
2. 环形通信拓扑: 所有参与计算的设备(host)在逻辑上形成一个环。在计算过程中,每个设备在处理其本地的查询块(query block)与一个键值块(key-value block)时,会同时将其正在使用的键值块发送给环中的下一个设备,并从上一个设备接收新的键值块。
3. 计算与通信的完全重叠: Ring Attention的关键优势在于能够将键值块在设备间的传输(通信)与分块注意力的计算完全重叠。只要分块计算的时间大于或等于块传输的时间,通信开销就可以被完全隐藏,从而实现零开销的上下文扩展。
4. 线性扩展上下文长度: 通过这种方式,Ring Attention将单个设备的内存瓶颈转移到了设备集群的总内存上。其上下文长度可以随设备数量线性增长,达到比现有内存效率最高的Transformer长设备数量倍的序列长度,从而能够训练和推理上下文长度超过数百万甚至上亿token的模型。

下图展示了Ring Attention在TPUv4-1024上进行端到端大规模训练时,相较于基线方法在最大上下文长度上的显著优势。

图1:在TPUv4-1024上进行端到端大规模训练的最大上下文长度。基线是vanilla transformers [37],内存高效transformers [30],以及内存高效注意力和前馈(分块并行transformers)[23]。我们提出的Ring Attention方法允许训练的序列长度比基线长达设备数量倍,并且能够训练长度超过百万的序列,而无需进行近似,也无需增加任何通信和计算开天销。
图1:在TPUv4-1024上进行端到端大规模训练的最大上下文长度。基线是vanilla transformers [37],内存高效transformers [30],以及内存高效注意力和前馈(分块并行transformers)[23]。我们提出的Ring Attention方法允许训练的序列长度比基线长达设备数量倍,并且能够训练长度超过百万的序列,而无需进行近似,也无需增加任何通信和计算开天销。

A3 背景知识/关键Observation/设计原则

标准Transformer的计算: 给定输入序列Q, K, V ∈ $R^{s \times d}$,其中s是序列长度,d是头维度,输出矩阵的计算公式为:

公式1
公式1

其中softmax是按行应用的。每个自注意力子层都伴随着一个前馈网络,它被分别且相同地应用于每个位置,包含两个线性变换,中间有一个ReLU激活函数。
公式2
公式2

分块并行Transformer(Blockwise Parallel Transformers): 先前的先进技术通过创新的方法大幅降低了内存使用,这些技术使得注意力计算无需完全实例化注意力矩阵,而是以分块方式进行【索引30,Self-attention does not need o(n2) memory,2021】【索引9,Flashattention: Fast and memory-efficient exact attention with io-awareness,2022,NeurIPS】【索引23,Blockwise parallel transformer for large context models,2023,NeurIPS】。这些进步将注意力的内存开销降低到每层$2bsh$字节,其中$b$是批量大小,$s$是序列长度,$h$是模型的隐藏大小。为了进一步减少内存使用,分块并行Transformer(BPT)【索引23,Blockwise parallel transformer for large context models,2023,NeurIPS】引入了一种策略,即与每个自注意力子层相关的前馈网络也以分块方式计算。这种方法有效地将前馈网络的最大激活大小从$8bsh$限制到$2bsh$。总的来说,最先进的Transformer层的激活内存成本为$2bsh$。

每层输出的巨大内存占用: 尽管BPT显著减少了Transformer的内存需求,但它在扩展上下文长度方面仍然面临重大挑战,因为它需要存储每一层的输出。这种存储是至关重要的,因为自注意力的内在性质涉及所有元素之间的交互(n到n的交互)。如果没有这些存储的输出,后续层的自注意力计算将变得不切实际,需要为每个序列元素重新计算,使得处理更长序列变得不可行。简单来说,即使批量大小为1,处理1亿个token也需要超过1000GB的内存,即使对于隐藏大小为1024的中等模型也是如此。相比之下,现代GPU和TPU通常提供不到100GB的高带宽内存(HBM),而HBM的显著扩展前景受到物理限制和高制造成本的阻碍。

A2 方法细节

目标与挑战: 我们的主要目标是通过在多个主机之间高效地分布长序列来消除单个设备造成的内存限制,且不增加额外开销。为实现此目标,我们提出对分块并行Transformer(BPT)框架【索引23,Blockwise parallel transformer for large context models,2023,NeurIPS】进行增强。当将输入序列分布到不同主机时,每个主机负责运行与其指定块相对应的分块注意力外层循环的一个元素,以及特定于该块的前馈网络。这些操作不需要与其他主机通信。然而,内层循环会出现一个挑战,它涉及键值块的交互,需要从其他主机获取块。由于每个主机仅拥有一个键值块,从其他主机获取块的朴素方法会导致两个重大问题。首先,它引入了计算延迟,因为系统需要等待接收必要的键值块。其次,键值块的累积导致内存使用增加,这违背了降低内存成本的初衷。

基于环形的分块注意力: 为了解决上述挑战,我们利用了内层循环中键值块操作的排列不变性。这一特性源于这样一个事实:一个查询块与一组键值块之间的自注意力可以按任何顺序计算,只要每个块的统计数据被正确地组合以进行重新缩放。我们利用这一特性,将所有主机概念化为一个环形结构:host-1, host-2, ..., host-N。当我们计算分块注意力和前馈网络时,每个主机通过并发地将在注意力计算中使用的键值块发送给下一个主机,同时从前一个主机接收键值块来进行高效协调,从而有效地将块的传输与分块计算重叠。具体来说,对于任何主机i,在计算其查询块与一个键值块之间的注意力时,它会同时将键值块发送给下一个主机-(i + 1),并从前一个主机-(i − 1)接收键值块。如果计算时间超过了传输键值块所需的时间,这就不会产生额外的通信成本。这种重叠机制适用于我们方法的前向和后向传播,因为可以使用相同的操作和技术。先前的工作也提出了利用环形拓扑来计算自注意力【索引21,Sequence parallelism: Long sequence training from system perspective,2023,ACL】,旨在降低通信成本。我们的工作不同之处在于利用分块并行Transformer来大幅降低内存成本。正如我们将在下一节展示的,这使得在训练和推理过程中上下文大小能够实现零开销扩展,并允许任意大的上下文大小。

图2:顶部(a):我们使用与原始Transformer相同的模型架构,但重新组织了计算。在图中,我们通过展示在一个主机环中,每个主机持有一个查询块,而键值块则以逐块的方式穿过主机环进行注意力和前馈计算来解释这一点。当我们计算注意力时,每个主机将键值块发送给下一个主机,同时从前一个主机接收键值块。通信与分块注意力和前馈的计算是重叠的。底部(b):我们逐块计算原始Transformer。每个主机负责查询外层循环的一次迭代,而键值块则在主机之间轮换。如图所示,一个设备从左边的第一个查询块开始;然后我们遍历水平放置的键值块序列。查询块与键值块结合,用于计算自注意力(黄色框),其输出传递给前馈网络(青色框)。
图2:顶部(a):我们使用与原始Transformer相同的模型架构,但重新组织了计算。在图中,我们通过展示在一个主机环中,每个主机持有一个查询块,而键值块则以逐块的方式穿过主机环进行注意力和前馈计算来解释这一点。当我们计算注意力时,每个主机将键值块发送给下一个主机,同时从前一个主机接收键值块。通信与分块注意力和前馈的计算是重叠的。底部(b):我们逐块计算原始Transformer。每个主机负责查询外层循环的一次迭代,而键值块则在主机之间轮换。如图所示,一个设备从左边的第一个查询块开始;然后我们遍历水平放置的键值块序列。查询块与键值块结合,用于计算自注意力(黄色框),其输出传递给前馈网络(青色框)。

主机间的算术强度: 为了确定重叠传输与计算所需的最小块大小,假设每个主机的浮点运算能力为$F$ FLOPS,主机间的带宽表示为$B$。值得注意的是,我们的方法仅与环形配置中紧邻的前一个和后一个主机进行交互,因此我们的分析适用于GPU的全连接(all-to-all)拓扑和TPU的环面(torus)拓扑。我们考虑以下变量:块大小表示为$c$,隐藏大小表示为$d$。在计算分块自注意力时,我们使用查询和键计算注意力分数需要$2dc^2$ FLOPs,将这些注意力分数乘以值需要另外$2dc^2$ FLOPs。总共,计算需求为$4dc^2$ FLOPs。我们排除了查询、键和值的投影以及分块前馈操作,因为它们只增加计算复杂度而没有任何主机间的通信成本。这种简化导致了更严格的条件,并且不损害我们方法的有效性。在通信方面,键和值块总共需要$2cd$字节,假设每个元素为bfloat16(2字节),所以总通信需求是$4cd$字节。为了实现通信和计算之间的重叠,必须满足以下条件:$4dc^2/F \ge 4cd/B$。这意味着块大小$c$应大于或等于$F/B$。实际上,这意味着块大小需要大于FLOPS与带宽的比率。

表1:不同Transformer架构的最大激活大小比较。其中,b是批量大小,h是隐藏维度,n是注意力头数,s是序列长度,c是块大小。块大小(c)与输入序列长度(s)无关。比较的对象包括vanilla Transformer [37]、内存高效注意力 [30]、内存高效注意力和前馈网络 [23],以及我们提出的Ring Attention方法。数值以每层的字节数表示,假设使用bfloat16精度。

表1
表1

内存需求: 一个主机需要存储多个块,包括一个块大小用于存储当前查询块,两个块大小用于当前的键和值块,以及两个块大小用于接收键和值块。此外,存储分块注意力和前馈的输出需要一个块大小,因为输出的形状与查询块相同。因此,总共需要六个块,这相当于$6bch$字节的内存。值得注意的是,分块前馈网络的最大激活大小为$2bch$【索引23,Blockwise parallel transformer for large context models,2023,NeurIPS】。因此,总的最大激活大小保持在$6bch$字节。表1详细比较了我们的方法与其他方法的内存成本。值得注意的是,我们的方法展现出内存随块大小$c$线性扩展的优势,并且与输入序列长度$s$无关。我们的分析表明,模型需要在每个设备上处理的序列长度为$s=6c$,即最小块大小的六倍。表2展示了在流行的计算服务器上的要求。每个主机所需的最小序列长度(最右列)在6K到10K之间,而每个主机的最小块大小(右二列)对于具有高带宽互连的TPU和GPU约为1K。对于通过InfiniBand连接的GPU,其带宽较低,要求更为严格。这些要求很容易通过数据并行、张量并行以及内存高效的分块注意力和前馈网络【索引30,Self-attention does not need o(n2) memory,2021】【索引9,Flashattention: Fast and memory-efficient exact attention with io-awareness,2022,NeurIPS】【索引23,Blockwise parallel transformer for large context models,2023,NeurIPS】来满足,我们将在实验第5节中展示这一点。

表2:每个设备上所需的最小序列长度。互连带宽是主机之间的单向带宽,即GPU之间的NVLink/InfiniBand带宽和TPU之间的ICI带宽。所需的最小块大小$c = FLOPS/Bandwidth$,最小序列长度$s = 6c$。

表2
表2

算法与实现: 算法1提供了该算法的伪代码。Ring Attention与现有的内存高效Transformer代码兼容:Ring Attention只需在每个主机上本地调用任何可用的内存高效计算,并将主机间的键值块通信与分块计算重叠。我们使用集合操作jax.lax.ppermute在相邻主机之间发送和接收键值块。附录A中提供了一个Jax实现。

算法1:Ring Attention算法伪代码
算法1:Ring Attention算法伪代码

A4 实验环境

  • 模型架构:
    • 语言模型: 基于LLaMA架构,涵盖3B, 7B, 13B, 和 30B四种模型尺寸。
    • 强化学习模型: 350M参数的Agentic Transformer (AT) 模型。
  • 数据集:
    • 语言模型微调: 使用从ShareGPT.com收集的用户共享对话数据集,经过清洗后包含125K个对话。
    • 语言模型评测: 使用line retrieval test来评估长上下文文本生成、检索和信息关联能力。
    • 强化学习: 使用ExoRL基准测试,该数据集由无监督强化学习方法收集,数据多样性高。
  • 硬件配置:
    • GPU: 单台DGX A100服务器(8x A100 80GB GPU)及分布式集群(32x A100 GPU)。
    • TPU: 使用了多个代次的TPU,包括TPUv3, TPUv4, 和 TPUv5e。
  • 软件配置:
    • 实现框架: JAX,利用其SPMD进行并行化。
    • 核心库/操作: 使用jax.lax.ppermute实现环形通信。
    • 并行策略: 实验中结合了全分片数据并行(Fully Sharded Data Parallelism, FSDP)和张量并行。所有方法均应用了完全梯度检查点(full gradient checkpointing)。
    • 精度: 在TPU上,矩阵乘法使用bfloat16,权重累积使用float32;在GPU上,所有操作均使用float32。
  • 基线模型:
    • Vanilla Transformer: 标准Transformer实现,会实例化完整的注意力矩阵【索引37,Attention is all you need,2017,NeurIPS】。
    • Memory Efficient Attention: 使用内存高效注意力机制的Transformer【索引30,Self-attention does not need o(n2) memory,2021】及其高效CUDA实现【索引9,Flashattention: Fast and memory-efficient exact attention with io-awareness,2022,NeurIPS】。
    • Blockwise Parallel Transformer (BPT): 同时对注意力和前馈网络进行分块计算的Transformer【索引23,Blockwise parallel transformer for large context models,2023,NeurIPS】。

A4 实验结果

5.1 最大上下文尺寸评估

  • 实验内容: 在FSDP的端到端训练设置下,比较Ring Attention与各基线模型在不同硬件(A100 GPU, TPUv3/v4)和模型规模(3B-30B)下所能支持的最大上下文长度。
  • 实验结果: Ring Attention显著超越所有基线。其支持的上下文长度与设备数量成线性关系。例如,在32个A100 GPU上,7B模型的上下文长度超过100万token,是之前最佳方法的32倍。在TPUv4-512上,上下文长度可超过3000万token,是之前最佳方法的256倍。这证实了如果基线方法能在n个设备上训练上下文长度为s的模型,Ring Attention则能训练上下文长度为ns的模型。
  • 相关图表: 表3详细列出了不同配置下的最大上下文长度对比。
    表3:在使用完全分片数据并行和各种Transformer架构的端到端训练中支持的最大上下文长度。我们展示了不同的模型大小和加速器。基线是vanilla transformer [37],带有内存高效注意力的transformer [30],以及带有内存高效注意力和前馈网络的transformer [23]。上下文大小以token(1e3)为单位报告。我们的Ring Attention显著优于基线,并使训练序列的长度比现有技术长达设备数量倍。
    表3:在使用完全分片数据并行和各种Transformer架构的端到端训练中支持的最大上下文长度。我们展示了不同的模型大小和加速器。基线是vanilla transformer [37],带有内存高效注意力的transformer [30],以及带有内存高效注意力和前馈网络的transformer [23]。上下文大小以token(1e3)为单位报告。我们的Ring Attention显著优于基线,并使训练序列的长度比现有技术长达设备数量倍。

5.2 模型FLOPS利用率(MFU)评估

  • 实验内容: 在标准训练设置(FSDP+张量并行)下,评估Ring Attention在处理极大上下文长度(超过4M tokens)时对模型FLOPS利用率(MFU)的影响。
  • 实验结果: 尽管长上下文增加了计算量更大但MFU较低的自注意力部分的比例,预期MFU会略有下降,但实验表明Ring Attention的实际MFU与预期值非常接近。这证明Ring Attention在实现极大上下文扩展的同时,几乎没有引入额外的计算开销,能够维持高MFU和高吞吐量。
  • 相关图表: 表4展示了不同模型、计算资源和上下文长度配置下的MFU。
    表4:不同训练配置下的模型FLOPS利用率(MFU):模型大小、计算资源和上下文长度。Ring Attention能够在大型输入上下文(超过4M)上训练大型模型(7B-65B),且开销可忽略不计。
    表4:不同训练配置下的模型FLOPS利用率(MFU):模型大小、计算资源和上下文长度。Ring Attention能够在大型输入上下文(超过4M)上训练大型模型(7B-65B),且开销可忽略不计。

5.3 对上下文强化学习(In-Context RL)性能的影响

  • 实验内容: 将Ring Attention应用于Agentic Transformer (AT)模型,在ExoRL基准上进行评估。通过利用Ring Attention的长上下文能力,将模型在上下文中处理的轨迹数量从32条增加到128条。
  • 实验结果: AT + Ring Attention在所有六个任务上都一致地优于使用BPT的原始AT模型。总平均回报从111.13提高到113.66。这表明更长的上下文(更多的历史经验)能够提升Transformer在强化学习任务中的表现。
  • 相关图表: 表5比较了不同方法在ExoRL任务上的累积回报。
    表5:Ring Attention在提升Transformer在强化学习中应用的效果。BC和DT使用vanilla attention。AT + ME表示使用内存高效注意力,AT + BPT表示使用分块并行transformer。AT + RA表示使用Ring Attention。
    表5:Ring Attention在提升Transformer在强化学习中应用的效果。BC和DT使用vanilla attention。AT + ME表示使用内存高效注意力,AT + BPT表示使用分块并行transformer。AT + RA表示使用Ring Attention。

5.4 对大语言模型(LLM)性能的影响

  • 实验内容: 使用ShareGPT数据集将LLaMA-13B模型微调至512K的上下文长度,并在长程行检索(line retrieval)任务上进行评测。该任务要求模型从长文档中精确检索一个数字。
  • 实验结果: 微调后的Ring Attention-13B-512K模型即使在非常长的上下文中也能保持高检索准确率。相比之下,其他模型如GPT-3.5-turbo-16K和Claude-2-100K虽然在短上下文中表现良好,但无法处理如此长的上下文。
  • 相关图表: 图3展示了不同模型在不同上下文长度下的行检索任务准确率。
    图3:不同模型在长程行检索任务上的比较。
    图3:不同模型在长程行检索任务上的比较。

A7 补充细节

相关工作分类与对比:
- 内存高效的精确注意力: 本文的工作属于这一类别,建立在分块技术【索引24,Online normalizer calculation for softmax,2018】、内存高效注意力【索引30,Self-attention does not need o(n2) memory,2021】、【索引9,Flashattention: Fast and memory-efficient exact attention with io-awareness,2022,NeurIPS】以及分块并行Transformer(BPT)【索引23,Blockwise parallel transformer for large context models,2023,NeurIPS】等进展之上,通过计算和通信的重叠进一步扩展了上下文处理能力。
- 近似注意力机制: 论文提到,近似注意力的研究虽然存在,但这些方法在扩大规模时常遇到困难或导致次优结果。
- 并行化方法: 现有并行方法包括数据并行【索引10,Large scale distributed deep networks,2012,NeurIPS】、张量并行【索引34,Megatron-lm: Training multi-billion parameter language models using model parallelism,2019】、流水线并行【索引27,Pipedream: Generalized pipeline parallelism for dnn training,2019,SOSP】【索引15,Gpipe: Efficient training of giant neural networks using pipeline parallelism,2019,NeurIPS】【索引28,Memory-efficient pipeline-parallel dnn training,2021,ICML】、序列并行【索引21,Sequence parallelism: Long sequence training from system perspective,2023,ACL】【索引18,Reducing activation recomputation in large transformer models,2022】【索引17,Deepspeed ulysses: System optimizations for enabling training of extreme long sequence transformer models,2023】和FSDP【索引11,Fully Sharded Data Parallel: faster AI training with fewer GPUs — http://engineering.fb.com,2023】【索引31,Zero: Memory optimizations toward training trillion parameter models,2020,SC20】。
- 与序列并行的区别: 论文特别指出,张量并行只能减少部分激活内存,而现有的序列并行方法会引入显著的、无法与计算完全重叠的通信开销。例如,先前有工作【索引17,Deepspeed ulysses: System optimizations for enabling training of extreme long sequence transformer models,2023】通过优化的all-to-all拓扑来分片序列和注意力头,但受限于注意力头数且需要在每个设备上收集完整序列。另一项使用环形拓扑的序列并行工作【索引21,Sequence parallelism: Long sequence training from system perspective,2023,ACL】虽然降低了通信成本,但由于算术强度的限制,重叠通信与计算仍然具有挑战性,导致在长上下文场景下不可行。相比之下,Ring Attention利用分块并行Transformer的计算模式,实现了通信与计算的完全重叠,从而解决了这些问题,实现了零开销扩展。
- 计算与通信重叠: 该思想在高性能计算(HPC)领域已有研究【索引7,Transformations to parallel codes for communication-computation overlap,2005,SC’05】【索引38,Overlap communication with dependent computation via decomposition in large deep learning models,2022,ASPLOS】【索引8,Mpi-aware compiler optimizations for improving communication-computation overlap,2009,ICS】,环形通信也在其他并行计算场景中得到应用【索引2,Parallel computing: Architectures, algorithms, and applications,2008】【索引16,Building a fault tolerant mpi application: A ring communication example,2011,IPDPSW】【索引14,Bringing hpc techniques to deep learning,2017】【索引33,Horovod: fast and easy distributed deep learning in tensorflow,2018】。但本文是首次证明该技术可以高效地应用于Transformer中的自注意力,并通过重叠分块计算与通信,使其在不增加显著开销的情况下,无缝融入Transformer的训练和推理。

A5 结论

本文提出了一种名为Ring Attention的内存高效方法,旨在解决Transformer模型在处理长序列时遇到的内存瓶颈。该方法通过将分块计算与环形通信相结合,并完全重叠计算与通信过程,成功地使模型的上下文长度能够随设备数量线性增长。实验证明,Ring Attention能够训练比现有SOTA内存高效Transformer长设备数量倍的序列,上下文长度超过1亿个token,且无需对注意力机制进行任何近似。

未来展望: “近乎无限”的上下文能力为AI领域开辟了广阔的新机遇,包括:
- 开发大规模的视频-音频-语言多模态模型。
- 通过学习更长的反馈和试错序列来提升强化学习性能。
- 理解和生成复杂的代码库。
- 使AI模型能够理解科学数据,如基因序列。
- 通过从互联的链接数据中收集信息,发展出更强的推理能力。

A6 附录

A. 代码实现: Ring Attention的Jax实现代码如Figure 4所示。该实现使用defvjp函数定义前向和后向传播,并利用jax.lax.ppermute集合操作来促进键值块在主机环中的交换。对于大规模端到端训练,建议使用FSDP分片大模型,并结合Ring Attention实现长上下文。如果总批量大小过大,可以增加张量并行来减小全局批量大小。并行度可以通过代码库中的mesh_dim参数调整。例如,在512个A100设备上,对于30B模型,可使用8路FSDP分片模型,剩下的64个设备用于Ring Attention,使上下文扩大64倍。对于7B或3B这样较小的模型,则无需FSDP,所有512个设备都可用于Ring Attention,将上下文扩大512倍。基于8个A100 GPU可达256K上下文的结果,推断512个A100 GPU可将上下文扩展至1600万。

1  def _ring_attention_fwd(q, k, v, attn_bias, axis_name, float32_logits, blockwise_kwargs):
2      if float32_logits:
3          q, k = q.astype(jnp.float32), k.astype(jnp.float32)
4      batch, q_len, num_heads, dim_per_head = q.shape
5      batch, kv_len, num_heads, dim_per_head = k.shape
6      numerator = jnp.zeros((batch, q_len, num_heads, dim_per_head)).astype(q.dtype)
7      denominator = jnp.zeros((batch, num_heads, q_len)).astype(q.dtype)
8      axis_size = lax.psum(1, axis_name)
9      block_size = q_len # assumes this function is pre-sharded inside shard_map
10     query_chunk_size = blockwise_kwargs["query_chunk_size"]
11     key_chunk_size = blockwise_kwargs["key_chunk_size"]
12     def scan_kv_block(carry, idx):
13         prev_max_score, numerator, denominator, k, v = carry
14         attn_bias_slice = lax.dynamic_slice_in_dim(attn_bias,
15             (lax.axis_index(axis_name) - idx) % axis_size * kv_len, kv_len, axis=-1)
16         q_block_idx = lax.axis_index(axis_name)
17         k_block_idx = (lax.axis_index(axis_name) - idx) % axis_size
18         q_chunk_idx_start = q_block_idx * (block_size // query_chunk_size)
19         k_chunk_idx_start = k_block_idx * (block_size // key_chunk_size)
20         numerator, denominator, max_score = _blockwise_attention_fwd(q, k, v,
21             (numerator, denominator, prev_max_score), q_chunk_idx_start, k_chunk_idx_start,
22             bias=attn_bias_slice, **blockwise_kwargs)
23         k, v = map(lambda x: lax.ppermute(x, axis_name, perm=[(i, (i + 1) % axis_size)
24             for i in range(axis_size)]), (k, v))
25         return (max_score, numerator, denominator, k, v), None
26     prev_max_score = jnp.full((batch, num_heads, q_len), -jnp.inf).astype(q.dtype)
27     (max_score, numerator, denominator, _, _), _ = lax.scan(scan_kv_block,
28         init=(prev_max_score, numerator, denominator, k, v), xs=jnp.arange(0, axis_size))
29     output = numerator / rearrange(denominator, 'b h q -> b q h')[..., None]
30     return output.astype(v.dtype), (output, q, k, v, attn_bias, denominator, max_score)
31 
32 def _ring_attention_bwd(axis_name, float32_logits, blockwise_kwargs, res, g):
33     output, q, k, v, attn_bias, denominator, max_score = res
34     batch, kv_len, num_heads, dim_per_head = k.shape
35     axis_size = lax.psum(1, axis_name)
36     dq = jnp.zeros_like(q, dtype=jnp.float32)
37     dk = jnp.zeros_like(k, dtype=jnp.float32)
38     dv = jnp.zeros_like(v, dtype=jnp.float32)
39     query_chunk_size = blockwise_kwargs["query_chunk_size"]
40     key_chunk_size = blockwise_kwargs["key_chunk_size"]
41     block_size = q.shape[1] # assumes this function is pre-sharded inside shard_map
42     def scan_kv_block(carry, idx):
43         dq, dk, dv, k, v = carry
44         attn_bias_slice = lax.dynamic_slice_in_dim(attn_bias,
45             (lax.axis_index(axis_name) - idx) % axis_size * kv_len, kv_len, axis=-1)
46         q_block_idx = lax.axis_index(axis_name)
47         k_block_idx = (lax.axis_index(axis_name) - idx) % axis_size
48         q_chunk_idx_start = q_block_idx * (block_size // query_chunk_size)
49         k_chunk_idx_start = k_block_idx * (block_size // key_chunk_size)
50         dq, dk, dv = _blockwise_attention_bwd(q, k, v, g, (dq, dk, dv, output, denominator, max_score),
51             q_chunk_idx_start, k_chunk_idx_start, bias=attn_bias_slice, **blockwise_kwargs)
52         k, v, dk, dv = map(lambda x: lax.ppermute(x, axis_name, perm=[(i,
53             (i + 1) % axis_size) for i in range(axis_size)]), (k, v, dk, dv))
54         return (dq, dk, dv, k, v), None
55     (dq, dk, dv, k, v), _ = lax.scan(scan_kv_block, init=(dq, dk, dv, k, v), xs=jnp.arange(0, axis_size))
56     dq, dk, dv = dq.astype(q.dtype), dk.astype(k.dtype), dv.astype(v.dtype)
57     return dq, dk, dv, None
58 
59 @partial(jax.custom_vjp, nondiff_argnums=[4, 5, 6])
60 def ring_attention(q, k, v, attn_bias, axis_name, float32_logits, blockwise_kwargs):
61     y, _ = _ring_attention_fwd(q, k, v, attn_bias, axis_name, float32_logits, blockwise_kwargs)
62     return y
63 
64 ring_attention.defvjp(_ring_attention_fwd, _ring_attention_bwd)

B. 实验细节:
- B.1 上下文长度评估: 在5.1节的实验中,我们使用FSDP将模型划分到GPU或TPU设备上。评估重点是确定在常用FSDP训练场景下的最大可实现序列长度。TPU使用默认训练配置,即以bfloat16格式进行矩阵乘法,以float32累积权重。GPU则采用默认设置,所有运算均以float32进行。
- B.2 MFU评估: 在5.2节的评估中,GPU上的批量大小为每批200万个token,TPU上为每批400万个token。训练使用FSDP和Jax SPMD进行。对于梯度检查点,我们对注意力和前馈网络(FFN)使用了nothing_saveable作为检查点策略。
- B.3 行检索评估: 在5.4节的评估中,我们微调了LLaMA-13B模型,由于云端计算预算限制,将上下文长度限制在512K。训练在32个A100 80GB云GPU上进行。我们使用从ShareGPT.com公共API收集的用户共享对话进行微调,遵循了先前工作的方法。ShareGPT是一个用户分享ChatGPT对话的网站。为确保数据质量,我们将HTML转换回markdown并过滤掉一些不当或低质量的样本,数据清洗后得到125K个对话。

C. 推理需求: 表2中提供了训练期间重叠通信与计算所需的最小序列长度。Ring Attention使得上下文大小能够随设备数量线性扩展,轻松进行训练。虽然我们主要介绍训练,因为它比自回归推理(查询token数为1)对内存的要求更高,但Ring Attention同样适用于推理。例如,在32个TPUv5e上部署一个LLaMa 7B模型,传统方法是沿注意力头维度分布模型,每个设备计算一个注意力头。假设批量大小为1,由于键值缓存激活大小的限制,最多可服务256K的上下文长度。Ring Attention可以通过在设备环之间循环键值缓存,允许32倍大的上下文。为了重叠通信与计算,需要满足$d^2/F \ge 2d/B$,即$B/F \ge 2/d$。对于带宽为186 GB/s,算力为196 TFLOPs的设备,并假设这种长上下文的MFU高达不合理的40%,那么$B/F \approx 2.4$,这意味着Ring Attention可以在不增加开销的情况下,为推理提供32倍大的上下文。

D. 训练FLOPS随上下文大小的扩展: 鉴于我们提出的方法开启了训练超过1亿token上下文的可能性,并允许上下文大小随设备数量线性扩展,理解每个数据集的训练FLOPS如何随上下文大小扩展至关重要。虽然更大的上下文大小会导致更高的FLOPS,但由于token数量固定,增加的比例并非二次方。我们在图5中展示了这些结果,图中展示了不同模型大小和上下文长度(代表不同计算预算)下,相对于一个短的4K上下文大小的模型的FLOPs比率。我们使用公式$(24bsh^2 + 4bs^2h)n$计算每个序列的FLOPs,其中$h$是模型隐藏维度,$b$是批量大小,$s$是总序列长度,$n$是层数。每个数据集的FLOPs比率则由$((24bs_2h^2 + 4bs_2^2h)/(24bs_1h^2 + 4bs_1^2h))/(s_2/s_1) = (6h + s_2)/(6h + s_1)$给出,其中$s_2$和$s_1$分别是新的和旧的上下文长度。模型大小及其隐藏维度如下:LLaMA-7B (4096), LLaMA-13B (5140), LLaMA-33B (7168), LLaMA-65B (8192), GPT3-175B (12288), 和 1TB (36864)。这些模型配置来自LLaMA和GPT-3的论文,除了1TB模型的大小和维度是我们自己定义的。

图5:相对于4k上下文大小的每个数据集训练FLOPs成本比率,考虑了不同的模型维度。x轴是上下文长度,例如,32x(128k)表示上下文长度为128k,是同一模型4k上下文长度的32倍。
图5:相对于4k上下文大小的每个数据集训练FLOPs成本比率,考虑了不同的模型维度。x轴是上下文长度,例如,32x(128k)表示上下文长度为128k,是同一模型4k上下文长度的32倍。

如图5所示,将小模型扩展到1M上下文大小会导致大约20-40倍的FLOPs,对于10M和100M token上下文大小则更多。然而,随着模型大小的增加,成本比率下降。例如,将170B模型从4K扩展到10M,每个数据集的FLOPs增加了162.6倍,尽管上下文大小长了3072倍。