LASP-2: Rethinking Sequence Parallelism for Linear Attention and its Hybrid

作者/机构: Weigao Sun 1, Disen Lan 1 2, Yiran Zhong 1, Xiaoye Qu 1, Yu Cheng 3

A1 主要贡献

核心问题: Transformer模型的核心自注意力机制具有关于序列长度的二次复杂度,导致训练成本高昂。虽然线性注意力等方法将复杂度降至线性,但现有的序列并行(Sequence Parallelism, SP)方法并未充分利用线性注意力的“右乘优先”特性,或者采用环式(ring-style)通信策略,这导致了较低的计算并行度,限制了其在分布式系统中处理更长序列时的可扩展性。

研究目标: 本文旨在提出一种新的SP方法LASP-2,以提升训练具有超长输入序列的线性注意力Transformer模型时的通信和计算并行性。同时,将该方法扩展到包含线性和标准注意力层的混合模型中。

创新点:
* 重构通信与计算工作流: 本文重新思考了线性注意力上SP的最小通信需求,并重组了整个通信与计算工作流。LASP-2采用单一的AllGather集体通信操作来处理中间内存状态,这些状态的大小与序列长度无关。这种设计显著提升了SP在线性注意力上的通信和计算并行性。
* 提出LASP-2H以支持混合模型: 将LASP-2的设计理念扩展,提出了LASP-2H,为同时包含线性和标准注意力层的混合模型提供了一个高效的SP解决方案。该方案采用统一的基于AllGather的通信设计。
* 构建并验证Linear-Llama3模型: 本文构建了一系列Linear-Llama3模型,包括纯线性和混合版本。通过在序列长度高达2048K的实验中,验证了LASP-2和LASP-2H的效率提升和性能。实验结果表明,在64个GPU上处理2048K序列长度时,LASP-2的训练速度比LASP提升了15.2%,比Ring Attention提升了36.6%。

A3 背景知识

符号表示
本文确保使用一致的符号以提高清晰度。表1提供了全文使用的所有符号的完整列表,包括索引、常量、向量和矩阵。向量和矩阵以粗体表示。为简化起见,我们省略了张量形状中与批大小和头数相关的维度。

Table 1: 符号表。论文中使用的索引、操作、常量、向量和矩阵。

线性注意力
“注意力”一词通常指一种计算过程,它为序列中的位置对分配分数,使每个元素能够“关注”其他元素。该机制最广泛使用且最重要的变体是softmax自注意力,它是标准Transformer模型的核心【53, Vaswani, A. et al., Attention is all you need, 2017, Advances in neural information processing systems】。在训练期间,为简单起见假设只有一个注意力头,softmax自注意力计算如下:

$$\begin{aligned} \begin{aligned} \mathbf{Q}, \mathbf{K}, \mathbf{V} & =\mathbf{X} \mathbf{W}_Q, \mathbf{X} \mathbf{W}_K, \mathbf{X} \mathbf{W}_V, \\ \mathbf{O} & =\operatorname{Softmax}\left(\mathbf{Q K}^{\top}\right) \mathbf{V} . \end{aligned} \end{aligned}$$

由物化$QK^\top$引起的成对比较机制导致了softmax自注意力特有的二次训练成本。最近,线性注意力【15, Katharopoulos, A. et al., Transformers are RNNs: Fast autoregressive transformers with linear attention, 2020, International Conference on Machine Learning】、【39, Shen, X. et al., Scaling laws for linear complexity language models, 2024】、【28, Qin, Z. et al., Unlocking the secrets of linear complexity sequence model from a unified perspective, 2024a】作为softmax自注意力的潜在替代方案而受到关注,它有两个关键区别。首先,它移除了Softmax(·)操作,将其整合到一个核特征图中。其次,它利用矩阵乘法的结合律将$(QK^\top)V$重构为$Q(K^\top V)$。

这些调整将注意力计算的计算和内存复杂度从$O(N^2d)$降低到$O(Nd^2)$。这种技术通常被称为右乘核技巧,因为它优先执行右侧的乘法。

在推理过程中,softmax自注意力和线性注意力在每次迭代中都处理单个token。给定第$s$个token $x_s \in R^{1 \times d}$,softmax自注意力的计算需要存储一组不断扩大的键$\{k_1, \dots, k_s\}$和值$\{v_1, \dots, v_s\}$,即“KV缓存”,这在处理长输入序列时带来了巨大的内存负担。在线性注意力中,研究人员尝试使用各种非线性核来替代公式2中的$exp(\cdot)$函数。

$$\begin{aligned} \begin{aligned} \mathbf{q}_s, \mathbf{k}_s, \mathbf{v}_s &= \mathbf{x}_s \mathbf{W}_Q, \mathbf{x}_s \mathbf{W}_K, \mathbf{x}_s \mathbf{W}_V, \\ \mathbf{o}_s &= \frac{\sum_{i=1}^s \exp(\mathbf{q}_s \mathbf{k}_i^\top) \mathbf{v}_i}{\sum_{i=1}^s \exp(\mathbf{q}_s \mathbf{k}_i^\top)}. \end{aligned} \end{aligned}$$

然而,最近的研究【48, Sun, Y. et al., Retentive network: A successor to transformer for large language models, 2023】、【56, Yang, S. et al., Gated linear attention transformers with hardware-efficient training, 2023】、【30, Qin, Z. et al., Various lengths, constant speed: Efficient language modeling with lightning attention, 2024c】发现,使用线性核(即使用恒等函数)且不带归一化分母在实践中效果很好。这产生了一种非归一化的线性注意力形式,如下所示:

$$\mathbf{o}_{s}=\sum_{i=1}^{s} \mathbf{q}_{s}(\mathbf{k}_{i}{ }^{\top} \mathbf{v}_{i})=\mathbf{q}_{s} \sum_{i=1}^{s}(\mathbf{k}_{i}{ }^{\top} \mathbf{v}_{i})=\mathbf{q}_{s} \mathbf{M}_{s},$$

其中$M_s = \sum_{i=1}^{s} k_i^\top v_i$是$k_i^\top v_i$从$i=1$到$s$的前缀和,这在线性注意力中也被称为内存状态。这种重构使得线性注意力具有一种循环结构,类似于RNN的行为:

$$\mathbf{M}_s = \mathbf{M}_{s-1} + \mathbf{k}_s^\top \mathbf{v}_s, \quad \mathbf{o}_s = \mathbf{q}_s \mathbf{M}_s.$$

A2 方法细节

3.1. 无掩码的LASP-2

序列并行方法概述
SP方法通过将长输入序列划分为几个较小的块,然后将这些块分发到多个计算设备上。每个设备独立地并行处理其分配到的块的查询、键和值。为了完成整个序列的注意力计算,需要执行必要的通信步骤,以收集所有设备的结果或在它们之间交换信息。LASP【46, Sun, W. et al., Linear attention sequence parallelism, 2024a】是一种专为线性注意力模块设计的序列并行技术。

分布式计算设置
让我们考虑一个分布式计算设置,其中有$W$个设备,输入序列被划分为$T$个块,称为序列并行大小。通常情况下,$T$可以被$W$整除,我们经常假设$W=T$。这意味着每个块被分配给一个设备,确保每个块在分布式系统中并行处理。这种情况是纯粹的序列并行。此外,在附录A.4.1节,我们将探讨$W \neq T$的情况,这代表了一种结合序列并行和数据并行的混合方法。

LASP-2并行计算流程
在LASP-2中,输入序列$X$被划分为$T$个较小的块,表示为$[X_t]_1^T$,每个块被分发到分布式系统中的设备上。对于每个块$X_t$,其对应的查询、键、值和线性注意力内存状态可以在所有块上并行计算。这种并行计算如下进行:

$$\begin{aligned} \begin{aligned} \mathbf{Q}_{t}, \mathbf{K}_{t}, \mathbf{V}_{t} & =\mathbf{X}_{t} \mathbf{W}_{Q}, \mathbf{X}_{t} \mathbf{W}_{K}, \mathbf{X}_{t} \mathbf{W}_{V}, \\ \mathbf{M}_{t} & =\mathbf{K}_{t}^{\top} \mathbf{V}_{t} . \end{aligned} \end{aligned}$$

通过对每个块执行这种并发计算,LASP-2在分布式环境中高效地处理长输入序列。查询$Q_t$、键$K_t$、值$V_t$和内存状态$M_t$都是为序列的每个块单独计算的,确保没有单个设备因处理整个序列而负担过重。这种分布式方法有助于更好地进行内存管理和提高计算效率,尤其是在处理极长序列时。因此,LASP-2利用序列分区的能力来优化分布式框架中线性注意力的计算。

通信与计算细节
值得注意的是,在LASP-2中,前向传递期间只需要一个单一的all-gather集体通信操作。这个all-gather操作作用于与每个序列块相关联的内存状态$[M_t]_1^T$,确保系统中的每个设备都能访问到完整的内存状态集$[M_t]_1^T$。一旦所有块的内存状态被收集起来,它们将在所有设备上并发累加以计算对应于整个输入序列的内存状态。这个过程表示如下:

Algorithm 1 LASP-2 w/o Masking

1: Input: input sequence X, distributed world size W , sequence parallel size T = W .
2: Distribute X = [Xt]T1 .
3: for chunk t ∈ {1, · · · , T } on ranks {1, · · · , W } in parallel do
4: Calculate Qt = XtWQ, Kt = XtWK, Vt = XtWV .
5: Compute Mt = K⊤t Vt.
6: Communicate [Mt]T1 = AllGather([Mt]T1 ).
7: Compute M1:T = Sum([Mt]T1 ).
8: Compute Ot = QtM1:T .
9: end for
10: return O = [Ot]T1 .

$$\begin{aligned} \begin{aligned} {[\mathbf{M}_t]_1^T} &= \text{AllGather}([\mathbf{M}_t]_1^T), \\ \mathbf{M}_{1:T} &= \text{Sum}([\mathbf{M}_t]_1^T). \end{aligned} \end{aligned}$$

本地输出计算
最后,与本地查询$Q_t$相对应的线性注意力输出可以计算为:

$$\mathbf{O}_t = \mathbf{Q}_t \mathbf{M}_{1:T}.$$

优化策略
重要的是,累加步骤$Sum([M_t]_1^T)$可以通过递归方式高效执行,即依次添加每个内存状态$M_{1:t-1} + M_t$。这避免了重复计算早期块内存状态总和的需求,提高了计算效率。为了进一步优化性能,我们将累加结果$M_{1:T}$缓存在高带宽内存(HBM)中。这种缓存策略通过避免在计算梯度时重新计算$M_{1:T}$,从而加速了反向传播过程。这种方法类似于激活检查点(activation checkpointing)的概念,即保存中间激活以避免重新计算。

通信成本分析
需要指出的是,每个内存状态$M_t$的维度是$d \times d$,这意味着all-gather操作的通信成本与序列或块的长度无关。相反,该成本与参与SP通信组的设备数量成线性关系。为清晰起见,我们在算法1中总结了不考虑注意力掩码的LASP-2方法。在反向传播过程中,需要对内存状态的梯度$dM_t$执行类似的all-gather通信操作。无掩码的反向传播细节可在附录A.1的算法3中找到以供参考。

3.2. 带掩码的LASP-2

自回归任务中的挑战
在自回归任务中,掩码矩阵$\Psi \in \{-\infty, 1\}^{N \times N}$通常是一个下三角矩阵,其中当$i \ge j$时$\Psi_{ij} = 1$,当$i < j$时$\Psi_{ij} = -\infty$。这种结构在计算中强制执行了因果约束。具体来说,在计算$O = Softmax(QK^\top \odot \Psi)V$时,无法利用矩阵乘法的结合律将计算复杂度从二次降为线性的并行形式。

计算分解方法
为了应对带因果掩码的线性注意力所面临的这一挑战,我们采用了计算分解的方法,该方法在早期工作中被提出【56, Yang, S. et al., Gated linear attention transformers with hardware-efficient training, 2023】、【46, Sun, W. et al., Linear attention sequence parallelism, 2024a】。图1展示了线性注意力中块内(intra-chunk)和块间(inter-chunk)计算的区别。块间计算不依赖于跨设备的其他块,可以被视为没有因果掩码。因此,这些计算可以在分布式设置中的所有设备上并行化。相比之下,块内计算考虑了先前块(1到t-1)对第t个块的影响。这些块内操作受到掩码矩阵的影响,因此需要专门处理以遵守因果约束。


图1:带掩码的LASP-2中的计算分解。彩色块表示块间(inter-chunks)。

块内(intra-chunk)计算
对于块内的线性注意力计算,给定与块$X_t$对应的查询、键和值矩阵$Q_t$、$K_t$和$V_t$,输出计算如下:

$$\mathbf{O}_{t, \text { intra }}=\left[\left(\mathbf{Q}_t \mathbf{K}_t^{\top}\right) \odot \boldsymbol{\Psi}\right] \mathbf{V}_t,$$

该公式遵循标准的左乘矩阵乘法。尽管计算可以在设备间并行执行,但在训练期间它仍然保留了传统注意力机制通常具有的二次复杂度。这个限制源于逐元素的掩码操作($\odot \Psi$),它在块内强制执行因果约束,从而阻止了使用能将计算成本降至线性的优化方法。

块间(inter-chunk)计算
对于跨块的线性注意力计算,我们遵循与无掩码的LASP-2相似的流程。首先,每个块的内存状态在不同设备上并发计算为$M_t = K_t^\top V_t$。这些对应于每个块的内存状态最初分布在不同的设备上。为了同步结果,执行一个AllGather集体通信操作。这一步确保所有设备都持有所有块的内存状态,从而能够进行进一步的并行处理。

前缀和与缓存
一旦内存状态被收集,我们将在所有设备上进行并发的PrefixSum操作。该操作累加从第1个块到第$(t-1)$个块的内存状态,有效地构建了必要的中间状态。这可以表示为:

$$\begin{aligned} \begin{aligned} {[\mathbf{M}_t]_1^T} & = \text{AllGather}([\mathbf{M}_t]_1^T), \\ \mathbf{M}_{1:t-1} & = \text{PrefixSum}([\mathbf{M}_t]_1^{t-1}). \end{aligned} \end{aligned}$$

PrefixSum操作可以通过递归实现来优化,利用存储在HBM上的缓存内存状态。具体来说,内存状态的累积计算如下:

$$\mathbf{M}_{1:t-1} = \mathbf{M}_{1:t-2} + \mathbf{M}_{t-1}.$$

通过缓存$M_{1:t-1}$,反向传播的计算得以简化,因为这个缓存值是计算梯度所必需的激活值。这种方法不仅加快了反向传播的速度,还减少了计算负载,因为缓存的内存状态消除了重复计算的需要。

最终输出计算
在计算完内存状态之后,可以轻松地推导出对应于块间部分的输出以及第$t$个token的最终输出。第$t$个token的整体输出是通过将块内和块间的输出相加得到的。

算法与优化
我们在算法2中提供了带掩码的LASP-2的完整算法,其反向传播过程在附录A.1的算法4中。需要注意的是,在算法2中,第7行的通信操作(洋红色)和第8行的$O_{t,intra}$计算(青色)可以通过在不同线程上执行来实现重叠。这种并发执行有助于提高整体效率,因为它允许通信和计算的重叠。

Algorithm 2 LASP-2 w/ Masking

1: Input: input sequence X, distributed world size W , sequence
parallel size T = W .
2: Distribute X = [Xt]T1 .
3: Initialize mask matrix Ψ, where Ψij = 1 if i ≥ j and Ψij =
−∞ if i < j.
4: for chunk t ∈ {1, · · · , T } on ranks {1, · · · , W } in parallel
do
5: Calculate Qt = XtWQ, Kt = XtWK, Vt = XtWV .
6: Compute Mt = (Kt)⊤Vt.
7: Communicate [Mt]T1 = AllGather([Mt]T1 ).
8: Compute Ot,intra = [(QtK⊤t ) ⊙ Ψ]Vt.
9: Compute prefix sum M1:t−1 = PrefixSum([Mt]t−11 ).
10: Compute Ot,inter = QtM1:t−1.
11: Compute Ot = Ot,intra + Ot,inter.
12: end for
13: return O = [Ot]T1 .

3.3. LASP-1 vs LASP-2

核心区别
LASP-2及其先前版本LASP-1都旨在实现线性注意力上的高效SP。尽管理论上,LASP-1和LASP-2在通信KV激活(大小为$d \times d$,与序列或块长度无关)方面有相似之处,但它们在通信方式和计算顺序重组方面存在根本区别,具体阐述如下:

LASP-1的通信方式
LASP-1利用环式(ring-style)点对点(P2P)通信,需要在设备之间启动许多发送和接收操作,以在设备间逐一顺序传输KV激活。这使得通信过程相对缓慢,并且难以与块内计算充分重叠。

LASP-2的通信方式
而LASP-2使用单一的AllGather集体通信操作来在所有设备间并发交换KV激活。这提供了实际优势:(1)只需启动一个经过优化的集体通信操作,所有设备上的KV激活交换可以一次性并发完成;(2)集体通信更容易与计算重叠。例如,在带掩码的LASP-2中,AllGather通信能够与块内输出计算重叠。此外,LASP-2重组了整个计算顺序,使得基于AllGather的通信策略变得可行和高效。

算法对比
为了方便与LASP-2在算法上的差异进行比较,我们还在附录A.2中用相同的数学符号写下了LASP-1(带掩码和不带掩码)的算法。

3.4. 理论成本分析

分析背景
为了更好地理解LASP-2的优势,我们对LASP-1和LASP-2进行了理论成本分析。我们考虑纯SP场景,即分布式世界大小为$W$,长度为$N$的输入序列被划分为$T=W$个块,因此该世界中的所有设备都需要参与通信。下面$B$表示批大小,$H$表示头数。


图2:LASP-2H在线性注意力和标准注意力混合模型上的可视化。我们以一个同时使用TP和SP(维度均为2)的线性和标准注意力混合层为例展示LASP-2H。黄色和绿色的通信操作分别用于TP和SP。AG/RS:前向为all-gather,后向为reduce-scatter,反之亦然。AG/No:前向为all-gather,后向无操作,反之亦然。请注意,线性注意力的SP通信操作作用于内存状态$M_t \in R^{d \times d}$,而标准注意力的SP通信操作作用于状态$K_t, V_t \in R^{C \times d}$。

每步通信流量
LASP-1:$BHd^2$,LASP-2:$BHd^2$。这是因为LASP-1和LASP-2都在设备间传输线性注意力的内存状态(而非键和值)。对应于每个块(位于每个设备上)的内存状态张量形状为$[B, H, d, d]$。因此,在每个通信步骤中,它们的通信流量均为$BHd^2$。对于一个B=16, H=16, d=2048的Linear-Llama3-1B模型,每个内存状态约有10.7亿个参数,在FP16下占用约2.14GB内存。对于一个B=16, H=32, d=4096的Linear-Llama3-8B模型,每个内存状态约有85.9亿个参数,在FP16下占用约17.18GB内存。

每次迭代的通信步数
LASP-1:$2(W-1)$,LASP-2:$2$。这取决于这两种算法的不同通信方式。在一次迭代的前向传播中,LASP-2启动一个单一的all-gather操作,将所有内存状态$M_t$收集到所有设备上,即$[M_t]_1^T = AllGather([M_t]_1^T)$。这个集体操作在所有设备上并发执行。而在反向传播中,对$M_t$的梯度执行另一次all-gather,即$[dM_t]_1^T = AllGather([dM_t]_1^T)$。因此,在每次迭代中,LASP-2有2个通信步骤。而LASP-1使用一对发送和接收操作来顺序地将内存状态从一个设备交换到另一个设备。在前向传播中,设备i将其内存状态发送给设备i+1,设备i+1接收来自设备i的内存状态,依此类推。$O_{t,inter}$、$O_t$的计算和$M_t$的更新都在该设备上的每次接收操作之后进行。因此,在前向传播过程中,LASP-1有$W-1$个通信步骤。在反向传播中,这个过程从最后一个设备反向重复到设备0。因此,在每次迭代中,LASP-1总共有$2(W-1)$个通信步骤。

总通信成本模型
假设LASP-1和LASP-2都执行总共$I$次迭代,它们的通信流量模型可以表示为:LASP-1:$2(W-1)IBHd^2$和LASP-2:$2IBHd^2$。理想情况下,LASP-2的通信流量将比LASP-1减少$W-1$倍。然而,实际的通信成本取决于实际因素,如通信带宽(通常在节点内更快,跨节点更慢)和通信稳定性。因此,LASP-2的优势在互连较慢的集群中更为明显,反之亦然。需要注意的是,该成本模型只考虑了通信,不包括计算或数据加载。在实践中,通信占总成本的比例较小,因此LASP-2实现的整体训练加速小于$W-1$倍。LASP-2在涉及长序列、大集群、慢通信链路以及高效数据加载和计算的场景中表现最佳。

3.5. 混合模型序列并行

LASP-2H概述
混合模型结合了线性的Transformer层和使用softmax自注意力的标准Transformer层,已被证明能有效增强长上下文能力,特别是在回忆和检索等任务中。为了优化这类混合模型中的SP,我们提出了LASP-2的扩展版本,称为LASP-2H。该方法通过将SP整合到线性和标准注意力模块中,提供了一个全面的解决方案。LASP-2H的结构如图2所示。

在线性注意力模块上的应用
如算法1和算法2所述,LASP-2H通过对内存状态$M_t \in R^{d \times d}$执行单一的all-gather通信操作来处理线性注意力模块。通信复杂度仍然与序列或块长度无关,仅与SP大小$T$线性相关,使得该方法在分布式集群中非常高效。

在标准注意力模块上的应用
上下文并行(Context Parallelism, CP)是Megatron-LM中的一种SP技术,它沿着序列维度划分网络输入和所有激活。该方法专为标准softmax注意力量身定制。虽然Megatron-LM中传统的CP实现依赖于环状结构中的通信和计算重叠【21, Liu, H. et al., Ring attention with blockwise transformers for near-infinite context, 2023】,我们的LASP-2H采用了不同的方法,遵循了Llama3的最佳实践【10, Dubey, A. et al., The Llama 3 herd of models, 2024】。LASP-2H没有使用环式策略,而是在标准注意力上采用基于AllGather的通信,其中$K_t$和$V_t$张量首先在设备间收集,之后在本地为$Q_t$张量块计算注意力输出。尽管all-gather通信的延迟比基于环的方法更高,但它在处理各种类型的注意力掩码(如文档级掩码)时提供了更大的便利性和灵活性。这种灵活性在需要不同注意力模式的场景中尤其有益。此外,由于$K_t$和$V_t$张量远小于$Q_t$张量,特别是在使用分组查询注意力(GQA)【2, Ainslie, J. et al., GQA: Training generalized multi-query transformer models from multi-head checkpoints, 2023】时,all-gather的延迟被最小化。因此,计算注意力输出的时间复杂度远超过all-gather操作的复杂度。我们在附录A.3的算法7中介绍了基于AllGather的上下文并行。

A7 补充细节

4.5. 相关工作

4.5.1. 线性序列建模

线性注意力
原始的线性注意力【15, Katharopoulos, A. et al., Transformers are RNNs: Fast autoregressive transformers with linear attention, 2020, International Conference on Machine Learning】引入了使用核方法替代Softmax注意力【53, Vaswani, A. et al., Attention is all you need, 2017, Advances in neural information processing systems】,从而将计算复杂度降低到与序列长度呈线性关系。此后,提出了多种线性注意力的变体。TransNormerLLM【26, Qin, Z. et al., TransNormerLLM: A faster and better large language model with improved transnormer, 2023a】、【27, Qin, Z. et al., Scaling transnormer to 175 billion parameters, 2023b】提出了Lightning Attention,这是一种通过优化IO交互来加速处理的精炼线性注意力机制。Lightning Attention-2【29, Qin, Z. et al., Lightning Attention-2: A free lunch for handling unlimited sequence lengths in large language models, 2024b】通过分别处理块内和块间计算,进一步实现了线性注意力的理论优势。RetNet【48, Sun, Y. et al., Retentive network: A successor to transformer for large language models, 2023】引入了一种结合了循环和注意力的retention机制,兼具并行训练和线性推理的优点。门控线性注意力(GLA)【56, Yang, S. et al., Gated linear attention transformers with hardware-efficient training, 2023】在线性注意力框架中加入了与数据无关的门控机制,并提出了一种高效的训练算法。DeltaNet【37, Schlag, I. et al., Linear transformers are secretly fast weight programmers, 2021】及其并行化版本【57, Yang, S. et al., Parallelizing linear transformers with the delta rule over sequence length, 2024】使用类似delta规则的更新来增强线性注意力在长上下文场景中的性能。最后,受GLA启发,门控槽注意力(GSA)【61, Zhang, Y. et al., Gated slot attention for efficient linear-time sequence modeling, 2024】引入了一种带有限定内存槽控制的门控线性注意力机制,以进一步提高效率。

状态空间建模
状态空间模型(SSM)是一个强大的框架,用于表示动态系统中序列的行为,并且它在线性序列建模领域显示出巨大的潜力。Mamba【11, Gu, A. and Dao, T., Mamba: Linear-time sequence modeling with selective state spaces, 2023】引入了一种状态选择机制,从而促进了线性序列长度的扩展。该架构在Mamba-2【7, Dao, T. and Gu, A., Transformers are SSMs: Generalized models and efficient algorithms through structured state space duality, 2024】中得到进一步增强,其中引入了状态空间对偶(SSD)框架来优化其性能。

线性RNN
传统的RNN在处理长上下文序列建模方面面临重大挑战,主要原因是其在训练过程中的固有序列依赖性,这使得它们无法充分利用规模法则【48, Sun, Y. et al., Retentive network: A successor to transformer for large language models, 2023】。为了解决这些限制,RWKV【22, Peng, B. et al., RWKV: Reinventing RNNs for the transformer era, 2023】、【23, Peng, B. et al., Eagle and Finch: RWKV with matrixvalued states and dynamic recurrence, 2024】被提出,它是一个基于线性RNN的大型语言模型,旨在有效管理长期依赖。此外,HGRN【32, Qin, Z. et al., Hierarchically gated recurrent neural network for sequence modeling, 2024e】强调了数据依赖的衰减机制在提升线性RNN性能中的关键作用,展示了调整衰减参数如何能改善长上下文任务中的学习。其增强版本HGRN2【31, Qin, Z. et al., HGRN2: Gated linear rnns with state expansion, 2024d】通过引入一个利用外积操作的状态扩展机制来扩展此方法,从而在更长的序列上实现更大的可扩展性和更强的建模能力。RWKV和HGRN系列都试图克服RNN在高效长序列建模方面的弱点。

4.5.2. 序列并行

SP【19, Li, S. et al., Sequence parallelism: Long sequence training from system perspective, 2022】是一种为更高效地训练语言模型而设计的分布式技术,它通过将长序列划分为多个较短的子序列,并在多个计算设备上并行处理这些子序列来实现。现有的SP方法【17, Korthikanti, V. et al., Reducing activation recomputation in large transformer models, 2022】、【13, Jacobs, S. A. et al., Deepspeed Ulysses: System optimizations for enabling training of extreme long sequence transformer models, 2023】的并行度不能超过注意力头的数量,这限制了它们的可扩展性。Ring Attention【21, Liu, H. et al., Ring attention with blockwise transformers for near-infinite context, 2023】被提出来解决长序列建模中的高内存成本问题,它通过在不同设备上分布子序列并重叠KV块的通信来实现。LASP【46, Sun, W. et al., Linear attention sequence parallelism, 2024a】提出了一种新的、为线性注意力量身定制的SP策略,该策略基于GPU友好的实现,利用了P2P环式通信策略,但仍然缺乏对混合模型架构的优化。

A4 实验环境

模型架构:
* 基础模型: Linear-Llama3,一个基于Llama3【10, Dubey, A. et al., The Llama 3 herd of models, 2024】的模型,其中标准softmax注意力被多种线性注意力模块替代(包括原始线性注意力、Lightning Attention、Retention、GLA、Based、Rebased)。模型包含16个线性Transformer层,总计10亿参数。
* 混合模型: 在Linear-Llama3的每第四层保留标准softmax注意力的Transformer层,构成一个1/4混合架构。

数据集:
* 名称: SlimPajama【45, Soboleva, D. et al., SlimPajama: A 627B token cleaned and deduplicated version of RedPajama, 2023】。
* 规模: 完整数据集包含6270亿个token,实验中使用的是从训练集第一个块派生的500亿token的子集。
* 用途: 用于GPT风格的自回归语言模型任务。

硬件配置:
* 平台: 多达16台DGX-A100服务器。
* GPU: 每台服务器配备8个A100 GPU,通过NVSwitch连接。
* 互联带宽: GPU间带宽为600 GBps。

软件配置:
* 框架与库: PyTorch 2.3.1, CUDA 12.1, cuDNN 8.9.2, NCCL 2.20.5。
* 基础代码: 算法在NVIDIA的Megatron-Core 0.9.0【40, Shoeybi, M. et al., Megatron-LM: Training multibillion parameter language models using model parallelism, 2019】之上开发。
* 加速库: 使用Triton 2.3.1【52, Tillet, P. et al., Triton: an intermediate language and compiler for tiled neural network computations, 2019】加速GPU上的线性注意力计算,使用FlashAttention2【6, Dao, T., Flashattention-2: Faster attention with better parallelism and work partitioning, 2023】作为标准注意力的实现。

A4 实验结果

速度

  • 实验内容: 将LASP-2与现有的SP方法(Megatron-SP、Ring Attention和LASP-1)进行速度性能比较。实验在一个纯Linear-Llama3-1B模型上进行,使用64个A100 GPU,SP大小也设为64。
  • 实验结果: 如图3所示,LASP-2表现出优越的吞吐量,特别是在序列长度超过64K时。当序列长度为512K时,LASP-2比Ring Attention快17.8%,比LASP-1快7.3%。当序列长度达到2048K时,LASP-2的吞吐量增益分别比Ring Attention和LASP-1高出36.6%和15.2%。
  • 分析结论: LASP-2在处理超长序列时具有显著的速度优势,且优势随序列长度增加而愈发明显。


图3:速度比较(tokens/s)。实验在一个纯Linear-Llama3-1B模型上进行,使用基础线性注意力模块。共使用64个A100 GPU,SP大小T也设为64。为了适应超长序列长度(如2048K),本实验中批大小固定为1。

可扩展性

  • 实验内容: 通过调整序列长度和GPU数量,评估LASP-2在GPU内存使用和吞吐量方面的可扩展性。
  • 实验结果: 如图4所示,LASP-2能够通过增加GPU数量来线性扩展输入序列长度。例如,在保持每个GPU内存成本相同的情况下,8个GPU可训练长度达128K的序列,而128个GPU可训练长达2048K的序列。同时,增加序列长度和设备数量会导致更高的吞吐量。
  • 分析结论: LASP-2展示了高效的通信和线性可扩展性。更详细的量化可扩展性结果见附录A.5中的表6。


图4:可扩展性结果。实验在一个纯Linear-Llama3-1B模型上进行,使用基础线性注意力模块。SP大小T始终等于GPU数量。批大小固定为1以适应超长序列长度,例如2048K。带有虚线的“×”符号表示发生内存不足(OOM)。

收敛性能

  • 实验内容: 在SlimPajama语料库上,使用500亿个token、16K序列长度和全局批大小为8的设置,在8个A100 GPU上评估LASP-2在多种Llama-3变体(包括标准softmax注意力、多种线性注意力以及它们的1/4混合模型)上的预训练收敛性能。
  • 实验结果: 如表2所示,对于使用不同线性注意力模块的纯Linear-Llama3模型,LASP-2在保持更高吞吐量的同时,取得了可比但略高的损失值。在1/4混合模型上,损失结果普遍优于纯线性模型,其中Lightning Attention、Retention和GLA甚至达到了与基线相当或更低的损失值。Based注意力模块因其混合了线性和滑动窗口注意力的设计,表现出强大的吞吐量和损失性能。
  • 分析结论: 1/4混合模型在吞吐量和收敛性能之间取得了良好的平衡,与基线及其纯线性版本相比具有竞争力。

Table 2: 收敛性能结果。所有实验均使用8个A100 GPU,序列长度16K,批大小为8,在SlimPajama语料库的500亿个token上训练。
Linear-Llama3-1B

A5 结论

本文提出了LASP-2,一种新的SP方法,解决了现有SP方法在线性序列建模中的效率低下问题。通过重新设计整个算法工作流并利用单一的all-gather通信策略,与先前的工作LASP-1相比,LASP-2显著增强了通信和计算的并行性,并使得通信-计算重叠更容易实现。我们的结果表明,LASP-2在速度和可扩展性方面提供了显著的改进,尤其是在处理超长序列长度的情况下。此外,扩展到LASP-2H使得混合模型中(集成了线性和标准注意力模块)的SP变得高效,两者都利用了统一的基于all-gather的通信原语。在Linear-Llama3模型上的实验评估验证了这些进步,LASP-2在性能上超过了之前的方法,如LASP-1和Ring Attention,特别是在极端序列长度下,差距显著。这些发现证实了LASP-2在大型分布式系统中的实用性,使其成为未来长序列线性Transformer模型应用的一个有前途的方法。

A6 附录

A.1. LASP-2 算法(反向传播)

见算法3和算法4。

Algorithm 3 LASP-2 w/o Masking (Backward Pass)

1: Input: distributed world size W , sequence parallel size T = W , Qt, Kt, Vt, Ot, dOt ∈ RC×d for chunk t ∈ {1, · · · , T }.
2: for chunk t ∈ {1, · · · , T } on ranks {1, · · · , W } in parallel do
3: Compute dMt = (Qt)⊤dOt.
4: Communicate [dM]T1 = AllGather([dM]T1 ).
5: Compute dM1:T = Sum([dM]Tt+1).
6: Compute dQt = dOtM⊤1:T .
7: Compute dKt = VtdM⊤1:T .
8: Compute dVt = KtdM1:T .
9: end for
10: return dQ = [dQt]T1 , dK = [dKt]T1 , dV = [dVt]T1 .
Algorithm 4 LASP-2 w/ Masking (Backward Pass)

1: Input: distributed world size W , sequence parallel size T = W , Qt, Kt, Vt, Ot, dOt ∈ RC×d for chunk t ∈ {1, · · · , T }.
2: for chunk t ∈ {1, · · · , T } on ranks {1, · · · , W } in parallel do
3: Compute dMt = (Qt)⊤dOt.
4: Communicate [dM]T1 = AllGather([dM]T1 ).
5: Compute dQt,intra = [(dOtV⊤t ) ⊙ Ψ]Kt.
6: Compute dKt,intra = [(dOtV⊤t ) ⊙ Ψ]⊤Qt.
7: Compute dVt,intra = [(QtK⊤t ) ⊙ Ψ]⊤dOt.
8: Compute dQt,inter = dOtM⊤1:t−1.
9: Compute suffix sum dMt+1:T = SuffixSum([dM]Tt+1).
10: Compute dKt,inter = VtdM⊤t+1:T .
11: Compute dVt,inter = KtdMt+1:T .
12: Combine intra- and inter-chunk parts of dQt, dKt, dVt

13: end for
14: return dQ = [dQt]T1 , dK = [dKt]T1 , dV = [dVt]T1 .

A.2. LASP-1 算法

见算法5和算法6。

Algorithm 5 LASP-1 w/o Masking

1: Input: input sequence X, distributed world size W , sequence parallel size T = W .
2: Distribute input X = [Xt]T1 .
3: for chunk t ∈ {1, · · · , T } at rank i ∈ {1, · · · , W } in parallel do
4: Compute Qt = XtWQ, Kt = XtWK, Vt = XtWV .
5: Compute Mt = K⊤t Vt.
6: end for
7: for chunk t ∈ {1, · · · , T } at rank i ∈ {1, · · · , W } sequentially do
8: Recv activation Mt−1 from rank (i − 1). Save Mt−1 in memory for backward computation.
9: Compute Ot = QtMt−1.
10: Update Mt = Mt−1 + K⊤t Vt.
11: Send activation Mt to rank (i + 1).
12: end for
13: return O = [Ot] with t ∈ {1, · · · , T }.
Algorithm 6 LASP-1 w/ Masking

1: Input: input sequence X, distributed world size W , sequence parallel size T = W .
2: Distribute input X = [Xt]T1 .
3: Initialize mask matrix Ψ, where Ψij = 1 if i ≥ j, and Ψij = −∞ if i < j.
4: for chunk t ∈ {1, · · · , T } at rank i ∈ {1, · · · , W } in parallel do
5: Compute Qt = XtWQ, Kt = XtWK, Vt = XtWV .
6: Compute Mt = (Kt)⊤Vt.
7: Compute Ot,intra = [(QtK⊤t ) ⊙ Ψ]Vt.
8: end for
9: for chunk t ∈ {1, · · · , T } at rank i ∈ {1, · · · , W } sequentially do
10: Recv activation Mt−1 from rank (i − 1). Save Mt−1 in memory for backward computation.
11: Compute Ot,inter = QtMt−1.
12: Compute Ot = Ot,intra + Ot,inter.
13: Update Mt = Mt−1 + K⊤t Vt.
14: Send activation Mt to rank (i + 1).
15: end for
16: return O = [Ot] with t ∈ {1, · · · , T }.

A.3. 基于AllGather的上下文并行

见算法7。

Algorithm 7 AllGather-based Context Parallelism

1: Input: input sequence X, distributed world size W , sequence parallel size T = W .
2: Distribute X = [Xt]T1 .
3: for chunk t ∈ {1, · · · , T } on ranks {1, · · · , W } in parallel do
4: Calculate Qt = XtWQ, Kt = XtWK, Vt = XtWV .
5: Communicate [Kt]T1 = AllGather([Kt]T1 ) and [Vt]T1 = AllGather([Vt]T1 ).
6: Concatenate K = Concat([Kt]T1 ) and V = Concat([Vt]T1 ).√
7: Compute Ot = Softmax(QtK⊤/ d)V.
8: end for
9: return O = [Ot]T1 .

A.4. 兼容性

A.4.1. 混合并行

数据-序列混合并行
LASP-2允许选择一个小于分布式世界大小且可被其整除的序列并行大小。这种设置沿批次和序列两个维度分割输入数据,这是一种被称为数据-序列混合并行的策略。ZeRO系列优化器【34, Rajbhandari, S. et al., Zero: Memory optimizations toward training trillion parameter models, 2020】和FSDP【62, Zhao, Y. et al., Pytorch FSDP: experiences on scaling fully sharded data parallel, 2023】提供了将模型状态(如优化器状态、梯度和模型参数)分布到分布式系统中所有GPU的方法。由于这些技术是数据并行的变体,它们可以与LASP无缝集成。它们最小化模型状态内存占用的主要目标与LASP-2减少每个GPU上激活内存的特定重点相辅相成,使得训练处理长序列长度的大规模模型变得更加易于管理。

张量与流水线并行
LASP-2也支持张量并行(TP)和流水线并行(PP)。在TP的情况下,其与LASP-2的集成是直接且高效的。线性注意力层应用TP来分解块内和块间计算中的矩阵运算。同时,MLP层在TP下照常处理,无需任何修改。当LASP-2与PP配对时,它使用从迷你批次中提取的子序列来替代传统的微批次。与标准PP的一个关键区别是,每个设备在本地专门存储中间状态,即前向传播中的$M_t$和反向传播中的$dM_t$,而无需将这些状态通信给其他设备。

A.4.2. 可变长度

处理可变长度序列
在预训练期间,批次通常包含长度统一的序列。然而,在微调或推理时,模型可能会遇到不同长度的输入序列。一个直接的解决方案是将批次中的所有序列向右填充,以匹配最长序列的长度。不幸的是,当序列长度差异很大时,这种方法可能效率低下。对于标准Transformer,已经开发了更复杂的方法来应对这一挑战。这些技术包括在不填充的情况下跨GPU进行负载均衡【58, Zeng, J. et al., Boosting distributed training performance of the unpadded bert model, 2022】、【59, Zhai, Y. et al., ByteTransformer: A highperformance transformer boosted for variable-length inputs, 2023】,或将多个序列打包到一个批次中并相应地调整注意力掩码【9, Ding, H. et al., Fewer truncations improve language modeling, 2024】、【24, Pouransari, H. et al., Dataset decomposition: Faster llm training with variable sequence length curriculum, 2024】。LASP-2可以通过将整个批次视为一个单一的长序列来高效地管理可变序列长度,从而简化了过程,无需填充。

A.5. 附加实验结果

A.5.1. 双向语言建模任务

实验设置与结果
为了在双向语言建模任务上进行评估,我们以RoBERTa为基础模型,并将其标准注意力模块替换为基础线性注意力,在4个A100 GPU上使用总输入序列长度为2048进行5万次迭代训练。如表3所示,使用基础线性注意力的LASP-2能够达到与基于标准注意力的Ring Attention模型相近的收敛性能。

Table 3: 双向语言建模任务的收敛性能。报告了训练和验证损失值。

A.5.2. 混合比例的消融研究

实验设置与结果
我们提供了关于混合模型混合比例的消融结果。设“L”表示线性Transformer层,“N”表示普通Transformer层。评估的混合模型架构有:0 Hybrid: "LLLL LLLL LLLL LLLL";1/8 Hybrid: "LLLL LLLN LLLL LLLN";1/4 Hybrid: "LLLN LLLN LLLN LLLN";1/2 Hybrid: "LNLN LNLN LNLN LNLN"。与使用标准注意力的Llama3-1B基线(损失值为2.759)相比,结果表明更高的混合比例倾向于带来更好的收敛性能,但有时适中的混合比例可能达到更好的结果。

Table 4: 混合模型中混合比例的消融研究。表中报告了损失值。请注意,纯线性模型使用LASP-2,而混合模型使用LASP-2H。

A.5.3. 不同收集大小的消融研究

实验设置与结果
我们对不同大小的内存状态收集进行了消融研究。考虑批大小为1,在Linear-Llama3-1B模型(16个头,隐藏维度2048)中,每个内存状态的张量形状为[1, 16, 2048, 2048]。我们使用64个GPU和1024K的序列长度,每个测试重复10次并报告其平均值。我们改变了收集内存状态的分割大小,并在表5中展示了LASP-2的吞吐量结果。可以看出,较小的分割大小(即更多的分割次数)倾向于导致轻微较慢的吞吐量。结果表明,all-gather操作的利用并非效率提升的唯一原因。通信方式以及计算工作流的重组起着重要作用。

Table 5: 不同收集分割大小下的吞吐量结果(tokens/sec)。使用Linear-Llama3-1B模型(16个头,隐藏维度2048)。

A.5.4. 定量可扩展性结果

见下页表6。

Table 6: LASP-2在吞吐量(tokens/sec)和每GPU内存使用(GB)方面的定量可扩展性结果。实验在Linear-Llama3-1B上进行,序列长度从2K扩展到4096K。