作者/机构: Vasudev Shyam (Zyphra), Jonathan Pilault (Zyphra), Emily Shepperd (EleutherAI), Quentin Anthony (Zyphra), Beren Millidge (Zyphra)

A1 主要贡献

本文旨在解决现代 Transformer 架构核心操作——自注意力机制的计算瓶颈问题。自注意力机制的计算复杂度随序列长度呈二次方增长,这使得处理长上下文的 Transformer 模型(尤其是大型语言模型 LLM)在训练和生成时需要消耗大量资源。尽管近期模型已能处理百万级别的超长上下文,并展现出新的能力,但其高昂的推理成本限制了实际应用。

为了应对这一挑战,本文做出了以下核心贡献:
1. 理论创新:首次推导出了一个标量能量函数,其梯度精确地计算了自注意力模块。这一理论阐明了自注意力的理论基础,并揭示了其内在的可并行化结构。
2. 算法设计:基于能量函数理论,提出了一种名为“树状注意力”(Tree Attention)的新算法。该算法利用 logsumexpmax 操作的结合律特性,将沿序列轴的归约(reduction)操作构建成一棵归约树,从而高效地并行化精确注意力计算。
3. 性能优势
* 渐进式加速:与当前最先进的序列并行方法(如环形注意力 Ring Attention)相比,Tree Attention 在多 GPU 解码中实现了渐进式的性能提升。其通信步数随设备数量呈对数增长,而非线性增长。实验表明,在128个GPU上处理512万长度的序列时,速度提升高达8倍。
* 资源节约:Tree Attention 需要的通信量显著减少,且峰值内存占用仅为 Ring Attention 的一半。
* 端到端加速:在 Llama 3.1-8B 模型上的实际应用中,Tree Attention 将解码速度提升了高达4倍。
4. 普适性:该算法可广泛应用于各种硬件和网络设置,包括 H100 DGX 节点、AMD MI300x 节点以及通过 PCIe 连接的 NVIDIA RTX 4090,证明了其通用性和实用性。

图1:环形与树状注意力拓扑结构
图1:环形与树状注意力拓扑结构

图1:由于树状注意力(图1(a))中logsumexp和max操作的结合律特性,可以将跨序列的归约构造为一棵树,这比环形注意力(图1(b))需要更少的通信步骤,并且内存和通信量也更少。

A3 背景知识与理论基础

相关工作

自注意力机制的计算挑战与现有解决方案。自注意力机制由【76, Ashish Vaswani, Noam Shazeer, Niki Parmar, et al. Attention is all you need. Advances in neural information processing systems, 2017】引入,其对序列长度的二次方依赖性($O(N^2 \cdot d)$)在处理长序列时带来了巨大挑战。为解决此问题,研究者提出了多种注意力近似机制,如使用低秩投影和核化近似的 Linformer【78, Sinong Wang, Belinda Z. Li, Madian Khabsa, et al. Linformer: Self-attention with linear complexity. CoRR, 2020】和 Performer【15, Krzysztof Choromanski, Valerii Likhosherstov, David Dohan, et al. Rethinking attention with performers. CoRR, 2020a】,它们将复杂度降低到线性 $O(N)$。此外,稀疏模型如 Longformer【8, Iz Beltagy, Matthew E. Peters, and Arman Cohan. Longformer: The long-document transformer. CoRR, 2020】和 BigBird【81, Manzil Zaheer, Guru Guruganesh, Avinava Dubey, et al. Big bird: Transformers for longer sequences. CoRR, 2020】通过将注意力限制在局部窗口或稀疏模式上,显著减少了资源需求。然而,这些方法都是对注意力机制的近似,而本文旨在并行化精确的注意力计算。

理论进展与精确注意力并行化。理论研究方面,基于核的方法【75, Yao-Hung Hubert Tsai, Shaojie Bai, Makoto Yamada, et al. Transformer dissection: An unified understanding for transformer’s attention via the lens of kernel. CoRR, 2019】为自注意力提供了计算上更高效的替代公式。同时,【22, Feyza Duman Keles, Pruthuvi Mahesakya Wijewardena, and Chinmay Hegde. On The Computational Complexity of Self-Attention. arXiv e-prints, 2022】的研究表明,除非强指数时间假说(SETH)为假,否则实现亚二次时间复杂度的自注意力是不太可能的。在精确注意力计算的并行化方面,FlashAttention【20, Tri Dao, Dan Fu, Stefano Ermon, et al. Flashattention: Fast and memoryefficient exact attention with io-awareness. Advances in Neural Information Processing Systems, 2022】通过重组注意力计算,利用GPU内存层级结构,实现了单设备上更快、内存效率更高的精确注意力计算,将内存复杂度从二次降为线性。其他技术则通过优化的矩阵运算和分块策略在核心或线程间分发计算【67, Zhuoran Shen, Mingyuan Zhang, Haiyu Zhao, et al. Efficient attention: Attention with linear complexities. In WACV, 2021】。本文的工作与 Ring Attention【52, Hao Liu, Matei Zaharia, and Pieter Abbeel. Ring Attention with Blockwise Transformers for NearInfinite Context. arXiv e-prints, 2023】最为相关,因为它同样旨在跨多设备并行化精确注意力。据我们所知,目前尚无其他技术像本文一样探索多设备并行解码。

自注意力机制

自注意力的核心计算。自注意力操作可以看作是查询(queries)和键(keys)之间的一系列点积相似性搜索。这些相似性得分随后沿序列轴进行归约和 softmax 操作,从而为每个查询生成一个关于所有键的相似性概率分布。最终,根据这个分布计算值(values)向量的期望。对于长度为 $N$ 的序列,我们表示查询为 $\{q_i, i = 1, \dots, N\}$,其中每个查询 $q_i \in \mathbb{R}^d$ 是一个大小为 $d$ 的向量($d$ 代表隐藏维度),同样地,键和值为 $\{(k_j, v_j), j = 1, \dots, N\}$。注意力可以写作:

内存高效的注意力计算。直接计算注意力需要实例化一个 $N \times N$ 的矩阵,其计算和内存成本在序列长度上是二次的。内存高效的注意力【62, Markus N. Rabe and Charles Staats. Self-attention Does Not Need O(N^2) Memory. arXiv e-prints, 2021】是一种迭代计算 softmax 相似性的方法,无需实例化完整的注意力矩阵。它一次处理一个查询(或一小块查询),执行以下操作:




一旦计算出值 $o$ 和 softmax 分母 $l$,就可以通过相除得到最终的 softmax 得分 $A(q_i) = o(q_i) / l(q_i)$。这种迭代方式显著降低了所需内存。

Flash Attention。【20, Tri Dao, Dan Fu, Stefano Ermon, et al. Flashattention: Fast and memoryefficient exact attention with io-awareness. Advances in Neural Information Processing Systems, 2022】采用类似方法来降低注意力的内存和计算成本,但其算法未针对多GPU计算进行调整。Flash Attention 以分块方式执行【62, Markus N. Rabe and Charles Staats. Self-attention Does Not Need O(N^2) Memory. arXiv e-prints, 2021】的迭代算法,利用了单个 GPU 张量核心内可用的块并行计算原语。此外,它精确地调整块的大小,使其能够容纳在 GPU 的 SRAM 中完成整个注意力计算,从而有效地实现了核函数融合,避免了许多不必要的 IO 操作。

作为能量函数梯度的自注意力

将注意力与能量模型联系起来。在 Transformer 架构取得巨大成功后,学术界投入大量精力从数学上理解注意力操作的本质,并将其与能量模型(如 Hopfield 网络)联系起来【46, Dmitry Krotov and John J Hopfield. Dense associative memory for pattern recognition. Advances in neural information processing systems, 2016】、【45, Dmitry Krotov. Hierarchical Associative Memory. arXiv e-prints, 2021】、【53, Beren Millidge, Tommaso Salvatori, Yuhang Song, et al. Universal hopfield networks: A general framework for single-shot associative memory models. In International Conference on Machine Learning, 2022】、【33, Benjamin Hoover, Duen Horng Chau, Hendrik Strobelt, et al. Dense associative memory through the lens of random features. In The Thirty-eighth Annual Conference on Neural Information Processing Systems, 2024】。【63, Hubert Ramsauer, Bernhard Schäfl, Johannes Lehner, et al. Hopfield Networks is All You Need. arXiv e-prints, 2020】开创性地将自注意力与现代 Hopfield 网络关联,提供了新颖的视角,认为自注意力是使用高能非线性相似性函数执行异构联想记忆查找。这项工作后被【32, Benjamin Hoover, Yuchen Liang, Bao Pham, et al. Energy Transformer. arXiv e-prints, 2023】扩展,他们基于能量函数推导出了一个修改版的 Transformer。然而,尽管 softmax 操作可以从标量函数 log-sum-exp 的梯度推导出来是众所周知的:


但自注意力模块的等效函数尚未被推导出来。本文通过引入一个辅助的源向量 $s$ 来建立注意力与能量函数之间的联系,该向量代表系统的“外部贡献”【34, J J Hopfield. Neural networks and physical systems with emergent collective computational abilities. Proceedings of the National Academy of Sciences, 1982】。源向量 $s$ 是我们计算标量能量函数梯度以获得自注意力操作的参数。我们需要这个源向量来写出分布矩的生成函数,因为对 $s$ 求导可以得到精确的自注意力操作。

核心观察。这一洞察使我们得出以下观察:
观察 1. 注意力可以表示为标量能量函数 $E(s)$ 相对于源 $s$ 的梯度,即:


其中,矩生成函数(即能量函数)$E(s)$ 定义为:

观察 1 的证明见附录 C.1。此公式也为附录 C.2 中的注意力贝叶斯解释提供了基础,并启发了下一节的树状注意力算法。

A2 方法细节

本节将展示如何将注意力操作表述为能量函数的梯度,从而提出一种高效的并行计算策略。关键思想是利用一种高效的算法来计算能量,然后对其进行微分,以获得计算注意力的高效算法。

5.1 高效的能量函数计算

解码场景下的能量函数。我们关注在因果语言模型中使用 KV 缓存进行解码的场景,此时我们有一个查询(query)和 $N$ 个键(keys)与值(values)。在这种情况下,能量函数为:

结合律的应用。一个关键事实是 logsumexpmax 都是结合运算:



我们可以证明,只要有足够多的并行工作单元,这种结合律性质就允许这些归约操作以对数时间复杂度高效地并行执行:

定理 1。使用 $P$ 个并行处理器对大小为 $N$ 的数组执行涉及结合函数(如 logsumexpmax)的归约操作,其时间复杂度为 $O(N/P + \log P)$。当处理器数量 $P$ 等于 $N$ 时,时间复杂度降为 $O(\log N)$。
定理 1 的证明在附录 E 中。

并行能量计算算法(算法1)。综合以上结果,对于块内索引 $\hat{i}, \hat{j} \in \{1, \dots, C\}$,我们得到以下高度并行的算法1:

5.2 高效的并行解码

自动微分与梯度计算。自动微分的核心思想之一是,计算函数梯度 $\nabla_x f(x)$ 的时间复杂度与计算 $f(x)$ 本身相同【77, Tim Vieira. Evaluating ∇ f(x) is as fast as f(x), 2016】。不过,如果函数的计算图很深,计算梯度的内存占用会随着深度增加而增长,因为反向传播需要存储中间张量的值。在我们的情况中,计算能量所涉及的计算图很浅,因此内存开销可以忽略不计。这意味着如果我们能高效地计算能量,就能自动获得一个计算其梯度(即自注意力操作)的高效算法。

手动实现梯度计算。在我们的场景中,我们希望计算能量函数相对于 $s_i$ 的梯度,然后将其设为零。这可以通过自动微分引擎实现,从一开始就将 $s$ 设置为零张量。不过,我们也可以手动实现上述算法1中对 $s$ 的梯度传递,如下面的算法2所示,该算法不实例化 $s$。特别要注意,当我们设置 $s_{i,j}=0$ 对所有 $j \in \{1, \dots, d_h\}$ 时,$lse$ 只涉及查询和键之间点积的 logsumexp

合并算法。通过为反向传播存储 $lse, m$,唯一需要执行的归约操作是上述算法第5行中的求和。这个单一归约在每个设备上计算局部和需要 $O(N/P)$ 时间,通信和合并部分结果需要 $\log P$ 时间,因此我们得到了与 logsumexp 计算相同的渐进复杂度。在实践中,我们将前向和对 $s$ 的梯度计算实现在一个函数中,该函数同时返回能量函数的值和梯度。因此,我们可以将算法1和算法2合并为以下高效的并行解码算法3:

5.3 使用拓扑感知的高效集合操作

通信开销。虽然上述理论分析表明,使用基于树的归约应该能看到加速,但在实践中由于各种潜在开销,这并不一定能得到保证。特别是,我们对所提出的树状解码算法时间复杂度的论证假设了部分结果的通信是瞬时的,这在实践中永远不会发生。事实上,随着我们扩展序列长度或GPU数量,特别是在多节点设置中,通信所花费的时间是总执行时间的主要贡献。然而,重要的是,除了其渐进优势外,树状注意力还受益于利用现代GPU集群中标准的双层拓扑结构。

与Ring Attention的对比。我们将我们的算法与之前提出的序列并行注意力算法 Ring Attention 进行了基准测试。与我们的算法一样,Ring Attention 假设序列在GPU之间分片,并在不将所有序列收集到单个设备上的情况下执行注意力计算。相反,它在逻辑上排列成环形拓扑的相邻GPU之间以点对点方式通信键和值的分片。这种通信与本地分片输出的计算重叠。与此策略相反,我们的算法在执行 AllReduce 操作时将查询分散并跨所有GPU通信部分结果,但不在GPU之间移动键和值的分片。因此,在解码情况下,我们的方法受益于较低的通信量,并且比 Ring Attention 更少地受到通信成本开销的影响。

NCCL Send/Recv 延迟
NCCL Send/Recv 延迟

图2:两个H100 GPU在节点内和节点间进行NCCL Send/Recv的延迟。GPU集群提供了双层拓扑,其中节点内带宽显著高于节点间带宽。像树状注意力这样的算法通过减少节点间通信需求来利用这种拓扑,从而实现通信与计算更好的重叠。

网络带宽层次结构的影响。Ring Attention 本质上不是拓扑感知的,只能在同构带宽的网络中扩展。然而,这与现代GPU集群的双层网络拓扑相冲突,后者在节点内使用高带宽互连(NVLINK或PCIe),而在节点间使用相对较低带宽的互连(InfiniBand或Ethernet)。这些互连在带宽和延迟上差异很大(见图2)。因此,Ring Attention 受限于最慢的互连,并且不能总是将注意力计算与通信重叠。我们在6.3节进一步讨论这一点。树状注意力通过使用网络拓扑感知的通信模式来改进 Ring Attention,以增加计算和通信的重叠,并减少分布式注意力计算中通信的可伸缩性瓶颈。

实践中的集合通信库。在实践中,像 NCCL 这样的集合通信库会尝试根据数据量和网络拓扑等因素自动检测正确的通信策略。在 DGX 集群中,对于节点内的集合操作,执行环形归约,而跨节点则执行树状归约。因此我们看到,使用内置的集合操作(如 Allreduce)比强制执行 Ring Attention 的点对点通信模式,在跨多个GPU进行长上下文解码时能获得更好的性能。我们展示了以下策略在跨多个GPU从非常长的上下文中解码时如何优于 Ring Attention。在我们的实证实验中,我们在每个设备内部都使用 Flash Attention 2【18, Tri Dao. Flashattention-2: Faster attention with better parallelism and work partitioning. arXiv preprint arXiv:2307.08691, 2023】,这既适用于我们的算法,也适用于 Ring Attention。我们在附录D中提供了一个简单的JAX实现。请注意,我们的方法与 Flash Decoding【2, Flash-decoding for long-context inference, 2024】类似,不同之处在于后者的并行化发生在GPU内的不同流式多处理器(SM)级别,而我们在不同GPU之间进行并行化。所有计算均以 BF16 格式执行。

A4 实验

实验环境

  • 硬件配置
    • 主集群:一个 DGX H100 集群,包含16个节点,每个节点有8个 H100 GPU。节点内 GPU 通过全连接的 NVLINK 4.0(900GB/s)连接。节点间通过每个节点8个 InfiniBand NDR 互连(每个GPU一个),每个提供400 Gbps带宽。
    • 其他平台:8个 AMD MI300X GPU(节点内使用 AMD Infinity Fabric,节点间使用 RoCE)以及2个通过 PCIe 连接的 RTX 4090 GPU。
  • 软件配置
    • 实现:使用 JAX 实现了 Tree Attention 算法。
    • 依赖库:在设备内部,Tree Attention 和 Ring Attention 的实现都使用了 Flash Attention 2 进行注意力计算。集合通信依赖于 NCCL 库。
    • 精度:所有计算均使用 BF16 浮点格式。
  • 模型架构
    • 微基准测试:使用一个标准的注意力块,包含16个头,每个头的维度为128。
    • 端到端测试:使用 Llama 3.1 8B 模型和 Llama 3.2 1B 模型。
  • 数据集与任务
    • 微基准测试:评估不同序列长度(从80k到5.12M)下的解码延迟、峰值内存和通信量。
    • 端到端测试:在不同长度(32k, 64k, 128k, 256k)的提示序列上,进行带预填充(prefill)的10个 token 解码,测量端到端吞吐量。

实验结果

与 Ring Attention 类似,Tree Attention 是对注意力的精确计算,因此训练和评估指标与标准注意力相同。实验结果主要关注延迟、峰值内存使用和通信量。

6.1 延迟

实验内容:通过测量在不同序列长度和不同数量 H100 节点上执行解码所需的时间来评估 Tree Attention 的延迟,并与 Ring Attention 进行比较。
实验结果:如图3所示,随着序列长度的增加(图3(a))或 GPU 数量的增加(图3(b)),Tree Attention 相对于 Ring Attention 的性能优势愈发明显。Tree Attention 的执行时间随着 GPU 数量的增加趋于平缓,而 Ring Attention 的执行时间则持续增长。当使用128个 GPU 处理5.12M 长度的序列时,Tree Attention 实现了接近8倍的速度提升,显示出其卓越的渐进可扩展性。

图3:16头 Tree Attention 与 Ring Attention 在不同规模GPU集群(1到16个H100 DGX节点)上的执行时间
图3:16头 Tree Attention 与 Ring Attention 在不同规模GPU集群(1到16个H100 DGX节点)上的执行时间

图3:相对执行时间以80k tokens序列长度下的Ring Attention时间为基准。

6.2 内存成本

实验内容:理论分析和实证测量 Tree Attention 和 Ring Attention 的峰值内存使用情况。Ring Attention 需要额外存储来自邻居设备的 $k', v'$ 分片和部分输出 $o$,而 Tree Attention 只需存储部分分子、分母和最大值。
实验结果:如图4所示,Tree Attention 的内存使用显著低于 Ring Attention。理论预测,只要 $2 \cdot B \cdot d_h \le 2 \cdot C \cdot d_h$(在实际场景中几乎总是成立),Tree Attention 的峰值内存就更低。实证结果也证实,Ring Attention 的峰值内存增长速度大约是 Tree Attention 的2倍。


图4:在两个RTX 4090之间分片时,单个注意力块使用Tree Attention与Ring Attention的峰值内存使用情况。结果是使用JAX内存分析器在一个GPU上测得的。峰值内存的差异随隐藏大小和序列长度而变化。

6.3 通信量

实验内容:理论分析 Ring Attention 的点对点通信策略和 Tree Attention 的 Allreduce 策略所需的通信数据量。
实验结果:理论分析表明,Ring Attention 的通信量与序列长度 $N$ 成正比($V_{Ring} = (P-1) \frac{2N d_h}{P}$),而 Tree Attention 的通信量与序列长度无关,仅与设备数 $P$ 和头数 $n_h$、头维度 $d_h$ 相关($V_{Tree} = 2 \frac{P-1}{P} (3 B n_h d_h)$)。在解码场景中,由于单GPU上的注意力计算速度远快于设备间通信延迟,Ring Attention 难以通过计算与通信重叠来隐藏其高昂的通信成本,而 Tree Attention 的低通信量优势更为突出。

6.4 Llama Transformer 模型性能

实验内容:在 Llama 3.1 8B 模型上进行端到端吞吐量测试,比较 Tree Attention 和 Ring Attention 在带预填充阶段解码10个 token 时的性能。实验在8个 H100 GPU(NVLink连接)和4个 MI300X GPU(AMD Infinity Fabric连接)上进行。
实验结果:如表1所示,在所有测试配置下,Tree Attention 的解码延迟都显著低于 Ring Attention。在8个 H100 上速度提升高达4倍,在4个 MI300X 上提升高达3倍。这些结果表明 Tree Attention 能很好地泛化到不同类型的系统、GPU数量、通信协议和网络拓扑中,具有很强的实际应用价值。
表1:使用8B Llama 3.1模型,在不同序列长度和GPU类型下,Tree Attention(本文)和Ring Attention(SOTA)在带预填充阶段的平均解码时间(秒)比较。平均结果和标准误差(±)由10次试验计算得出。

A5 结论

本文成功地推导出了自注意力机制的能量函数,并基于此函数梯度的计算,提出了一种新颖高效的并行注意力计算方法——Tree Attention。该方法在多设备解码场景下表现尤为突出,其算法具有渐进优势,在128个GPU、5.12M序列长度的实验中,实现了比当前最先进的 Ring Attention 高达8倍的速度提升。此外,Tree Attention 使用的 AllReduce 操作通过传递部分归约结果,极大地减少了通信数据量和峰值内存需求。在Llama 3.1模型的实际应用中,Tree Attention 在带预填充的解码阶段相比 Ring Attention 获得了3-5倍的加速。通过在包括 AMD MI300x 在内的多种GPU集群上进行测试,本文证明了 Tree Attention 能够很好地泛化到各种通信协议和网络拓扑中,展现了其广泛的适用性和巨大的应用潜力。

A6 附录

A 更多相关工作

贝叶斯重构与Hopfield网络。近期有工作尝试对注意力进行贝叶斯重构,通过推导一个概率生成模型来匹配自注意力操作【68, Ryan Singh and Christopher L Buckley. Attention: Marginal probability is all you need? arXiv preprint arXiv:2304.04556, 2023】。这延续了将自注意力与 Hopfield 网络架构联系起来的研究路线【63, Hubert Ramsauer, Bernhard Schäfl, Johannes Lehner, et al. Hopfield Networks is All You Need. arXiv e-prints, 2020】。其思想是,虽然 Hopfield 网络是根据能量景观上的动力学定义的,但通过将能量函数等同于变分自由能泛函,可以将其转化为贝叶斯解释,从而推导出自注意力更新中隐含的生成模型。

与现有能量函数方法的区别。在【32, Benjamin Hoover, Yuchen Liang, Bao Pham, et al. Energy Transformer. arXiv e-prints, 2023】中提出的能量函数是 logsumexp,但由于其更新规则中的梯度是相对于块的输入计算的,因此得到的函数是自注意力操作的修改版本。类似地,【63, Hubert Ramsauer, Bernhard Schäfl, Johannes Lehner, et al. Hopfield Networks is All You Need. arXiv e-prints, 2020】中的更新规则要求绑定某些权重(K和V),这限制了 Hopfield 推导只能模拟自联想检索网络,而 Transformer 的注意力是异构联想的。

基于结合律的并行化思想。另一个值得注意的相关工作是【23, Leo Feng, Frederick Tung, Hossein Hajimirsadeghi, et al. Attention as an RNN. arXiv e-prints, 2024】,作者们在我们第5节中也做出了类似的观察,即注意力计算中的结合律操作可以被高效地并行化,并以此为基础提出了一种基于注意力的改进 RNN 架构用于序列建模。

本文方法的独特性。尽管这个能量函数本身主要是一个数学和理论上的发现,但我们展示了当与自动微分结合时,我们的公式自然地导出了一种高效的并行算法,用于计算注意力和执行解码,尤其是在多设备上。

B 树状归约操作的更多背景

树状归约的定义与工作原理。树状归约(tree reduction)是一种分层策略,用于高效地对一组数据元素执行归约操作(如求和、乘积、最大值、最小值),尤其是在并行计算中。这种方法降低了整体计算复杂性,并能有效利用并行处理资源。其工作原理如下:
* 分解问题:将输入数据分成更小的块,并在这些块中的相邻元素之间成对地执行归约操作。
* 形成树状结构:第一级归约的结果在下一级中再次成对归约。这个过程持续进行,直到整个数据集被归约为单个结果。
* 迭代或递归聚合:聚合通常遵循二叉树模式,但也可以使用其他扇入数(例如,k-ary 树)。树中的每个节点代表一个部分归约结果,树的根节点持有最终结果。

树状归约的优势。由于树结构的总节点数与深度呈对数关系,当可以聚合部分结果时,树状归约可以从渐进上减少执行操作所需的总步数。此外,它也易于并行化,因为可以定义k-ary树来匹配可用的并行处理器数量。许多现有的网络拓扑,如Nvidia的NVLINK和Infiniband,由于树结构的天然优势,其设计本身就采用了这种拓扑结构,这意味着树状操作是自然且高效的。

C 作为能量函数梯度的注意力及其贝叶斯解释

C.1 观察1的证明

引入源向量s。我们展示自注意力操作如何可以写成一个能量函数的梯度。具体来说,我们定义一个标量函数,它依赖于键、查询、值,以及一个我们称之为源 $s$ 的辅助向量。源是我们计算该标量函数梯度以获得自注意力操作的参数。我们需要源来写出分布矩的生成函数。它也是我们可以用来对生成函数进行泰勒展开并提取矩作为 $s$ 的单项式系数的变量。我们想找到一个函数 $E(Q, K, V, s)$ 使得:


这个术语受到机器学习中能量模型工作的启发【7, Matthew James Beal. Variational algorithms for approximate Bayesian inference. University of London, 2003】、【48, Yann LeCun, Sumit Chopra, Raia Hadsell, et al. A tutorial on energy-based learning. Predicting structured data, 2006】、【69, Yang Song and Diederik P Kingma. How to train your energy-based models. arXiv preprint arXiv:2101.03288, 2021】。变量和索引的摘要在附录G中提供。

能量函数作为累积量生成函数。我们首先展示能量函数如何由与注意力分数给出的分布相关联的累积量生成函数给出。受统计力学的启发,其中一个类似的累积量生成函数定义了亥姆霍兹自由能【47, L. D. Landau and E. M. Lifshitz. Statistical Physics. Reading, MA: Addison-Wesley, 1958】,我们将我们的累积量生成函数称为自注意力的能量函数。

单查询情况的推导。让我们关注单个查询的情况。如上所述,我们利用注意力操作可以被看作是在由注意力分数 $A_j$ 设定的分布中计算向量 $v_j$ 的期望值:


概率密度由下式给出:

通常,分母或归一化因子被识别为所谓的配分函数:

现在,我们可以通过引入一个源 $s \in \mathbb{R}^{d_h}$ 来计算上述概率分布的一阶矩。在我们的例子中,有了 $s$,我们可以将配分函数扩展为:

现在,我们可以将分布的任意矩计算为 $Z(s)$ 的第 $n$ 个泰勒系数,对于所有的 $j_1, j_2, \dots \in \{1, \dots, d_h\}$:

换句话说,我们可以将 $Z(s)$ 写成:

因此,一阶矩可以写成:

这可以写成 $\log Z(s)$ 的梯度:

这个量就是生成函数,也就是自由能:

因果自注意力与多头注意力的推广。为了计算因果自注意力,我们引入 $N$ 个源 $s_i \in \mathbb{R}^{d_h}$ 并取:


内部和截断到索引 $i$ 是由于因果掩码。现在,为了计算因果自注意力的第 $i$ 个元素,我们对 $s_i$ 微分并将其设为零:

推广到多头注意力的情况是直接的。在这种情况下,每个头都有一个键、查询和值。对于总共 $n_h$ 个头,生成函数的形式为:

其中

输出投影权重包含在 $v_{\bar{j}}^{\bar{h}}$ 的定义中,意味着

其中 $W_O^{\bar{h}} \in \mathbb{R}^{d_h} \times \mathbb{R}^{d_{model}}$ 表示输出投影权重的一个头大小的切片,$\bar{j} \in \{1, \dots, d_h\}$ 跨越头内索引。在上面的索引表示法中,头索引带上划线,而嵌入空间索引不带。我们接下来将重点放在单头情况,因为这使得表述更简单,而多头的推广是直接的。请注意,我们证明了我们的能量函数方法也可以解释安全 softmax,见附录F。

C.2 贝叶斯解释

通过最大后验估计推导自注意力。能够将自注意力操作推导为能量函数的最小化,意味着可以通过识别一个似然函数,并证明我们可以从计算该似然的最大后验(MAP)估计中获得注意力块的前向传播,从而为自注意力提供一个贝叶斯解释。具体来说,我们为对数似然函数提出以下形式:


我们用 $x$ 表示自注意力块的输入,从中我们通过乘以权重 $W_Q, W_K, W_V$ 分别得到 $q, k, v$。让我们同时对 $o$ 和 $v$ 最小化上述函数:

这些条件明确写为:

将第一个条件代入第二个条件,得到注意力的前向传播:

总而言之,这意味着我们可以从以下似然的 MAP 估计中获得关于 $x$ 的梯度:

此外,这样的过程使我们能够识别与自注意力函数相关联的基于能量的模型。

C.3 Llama Transformer 模型更多性能结果

在消费级硬件上的性能。为了扩展我们在第6.4节的工作,并证明 Tree Attention 可以成功应用于一系列硬件设置,我们还实验在双 NVIDIA RTX 4090 配置上运行 Llama3.2-1B。这两块 4090 通过 PCIe 网络连接。即使在这种情况下,我们观察到 Tree Attention 在自回归解码方面相比 Ring Attention 有显著的4倍加速(在更长的序列长度下增长到5倍)。
表2:使用1B Llama 3.2模型,在4090s上不同序列长度下,Tree Attention(本文)和Ring Attention(SOTA)在带预填充阶段的平均解码时间(秒)比较。平均结果和标准误差(±)由10次试验计算得出。

D 附录:JAX 代码

以下是 tree_flash_decode 方法。我们的完整代码库可在此处获取:https://github.com/Zyphra/tree_attention


该函数使用 Flash Attention 2【18, Tri Dao. Flashattention-2: Faster attention with better parallelism and work partitioning. arXiv preprint arXiv:2307.08691, 2023】计算局部 numerators 和 denominators,这两者都使用 Allreduce(即 psum 和 pmax 的调用)在设备之间累积。NCCL 决定了这些结果以何种模式进行通信。

E 定理1证明

我们在此证明定理1。
证明
顺序情况:在单个GPU上,对大小为 $N$ 的数组进行归约操作的时间复杂度为 $O(N)$,因为处理器必须顺序处理每个元素。
使用P个处理器的并行处理:将大小为 $N$ 的数组分成 $P$ 个块,每个块的大小为 $N/P$。每个处理器独立地对其块执行归约操作。每个处理器的时间复杂度为 $O(N/P)$。
合并部分结果:来自 $P$ 个处理器的部分结果需要被合并。使用树状模式进行归约,部分结果可以在 $O(\log P)$ 步内完成。每一步都涉及合并成对的结果,每一步结果数量减半,直到只剩下一个结果。
总时间复杂度:总时间复杂度是处理块和合并结果的时间复杂度之和:


这证明了当使用 $P$ 个并行处理器时,对大小为 $N$ 的数组执行涉及结合律操作的归约的时间复杂度为 $O(N/P + \log P)$,当处理器数量等于数组大小时,复杂度降为 $O(\log N)$。□

F 计算安全 Softmax

数值稳定性问题。虽然从数学上讲,注意力利用了 softmax 操作,但在实践中,使用相对较低精度的操作时,这通常在数值上是不稳定的。为了解决这个问题,通常使用一个数学上等价的函数,即“安全 softmax”,它从指数中的所有点积中减去最大值。这确保了所有被求幂的值都小于1,从而减少了爆炸和导致数值不稳定的可能性。

能量函数与安全Softmax的兼容性。在这里,我们证明了我们的能量函数方法也可以解释安全 softmax。
假设我们比较我们的生成函数:


和一个稍微修改过的版本:

当我们对这两个量求导时,我们看到我们得到了相同的结果:

具体来看:

通常,在在线计算 softmax 时,会执行这个过程,其中 $m_i$ 是 $Q \cdot K_i^T$ 的行最大值。这种平移使得指数和不会导致溢出。

G 方程符号说明

以下是将在前面部分中使用的各种变量和索引的摘要:
表一:变量名


表二:索引名称和范围