Linear Attention Sequence Parallelism

文章标题: 线性注意力序列并行 (Linear Attention Sequence Parallelism)
作者/机构: Weigao Sun (上海人工智能实验室), Zhen Qin (TapTap), Dong Li (上海人工智能实验室), Xuyang Shen (上海人工智能实验室), Yu Qiao (上海人工智能实验室), Yiran Zhong (上海人工智能实验室)

A1 主要贡献

本文针对线性序列建模方法(如线性注意力)提出了一种名为线性注意力序列并行(Linear Attention Sequence Parallelism, LASP)的高效序列并行(SP)方法,旨在解决现有SP方法未能利用线性注意力“右乘优先”特性而导致的通信效率和可用性欠佳的问题。

核心问题与研究目标
随着大型语言模型(LLM)的规模和序列长度不断增加,单个GPU的内存容量成为一个显著瓶颈。序列并行(SP)技术通过将长序列切分到不同设备上处理来解决此问题。然而,对于线性注意力这类具有线性复杂度的机制,现有的SP方法没有充分利用其右乘(right-product)优势,导致并行效率和可用性不理想。本文的目标是设计一种专门针对线性注意力模型的SP方法,以大幅提升通信效率,并支持在大型集群上进行超长序列的分布式训练。

创新点与主要贡献
1. 新型序列并行方法LASP:提出了一种专为线性序列建模方法设计的SP方法LASP。该方法能够在分布式训练中处理比现有SP方法长8倍的序列,同时速度显著更快。
2. 序列长度无关的通信开销:设计了一种点对点(P2P)环式通信策略,该策略利用了线性注意力的右乘核技巧(right-product kernel trick)。这使得中间状态的交换量与序列长度无关,从而大幅降低了通信开销。
3. GPU友好的实现:通过精细的系统工程优化,包括核函数融合(kernel fusion)KV状态缓存(KV State caching),提升了LASP在GPU硬件上的执行效率。
4. 数据并行兼容性:LASP与所有批次级别(batch-level)的分布式数据并行(DDP)方法兼容,包括PyTorch/Legacy DDP、FSDP和ZeRO系列优化器,这对于在大型集群上进行超长序列的训练至关重要。

A3 背景知识/关键Observation/设计原则

线性注意力 (Linear Attention)

线性Transformer模型通过采用各种近似方法绕过Softmax注意力,例如【11, Katharopoulos et al., Transformers are rnns: Fast autoregressive transformers with linear attention, 2020, International Conference on Machine Learning】 使用1 + elu激活函数,【30, Qin et al., cosformer: Rethinking softmax in attention, 2022, International Conference on Learning Representations】利用余弦函数模仿Softmax特性。其核心思想是使用“核技巧”,通过先计算键(keys)和值(values)的乘积来加速注意力矩阵的计算,避免了计算量巨大的n x n矩阵乘法。近期工作如TransNormerLLM 【34, Qin et al., Transnormerllm: A faster and better large language model with improved transnormer, 2024】提出了Lightning Attention以优化IO操作来加速线性注意力,而Lightning Attention-2 【36, Qin et al., Lightning attention-2: A free lunch for handling unlimited sequence lengths in large language models, 2024】通过分别处理块内和块间的计算来提高效率。其他相关模型包括RetNet 【53, Sun et al., Retentive network: A successor to transformer for large language models, 2023】、GLA 【64, Yang et al., Gated linear attention transformers with hardware-efficient training, 2023】、DeltaNet 【46, Schlag et al., Linear transformers are secretly fast weight programmers, 2021】及其并行版本【66, Yang et al., Parallelizing linear transformers with the delta rule over sequence length, 2024】和GSA 【69, Zhang et al., Gated slot attention for efficient linear-time sequence modeling, 2024】。尽管研究取得了显著进展,线性注意力在商业大规模模型中的应用仍然有限,但已有公司开始探索,例如Minimax-01【17, Li et al., Minimax-01: Scaling foundation models with lightning attention, 2025】、腾讯的混元Turbo-S【55, Tencent, Tencent hunyuan turbo-s, 2025】和Together.AI的StripedHyena【27, Poli et al., StripedHyena: Moving Beyond Transformers with Hybrid Signal Processing Models, 2023】。

内存高效的注意力机制 (Memory-Efficient Attention)

为了解决注意力机制的内存瓶颈,研究者们提出了多种优化方法。【42, Rabe & Staats, Self-attention does not need o(n2) memory, 2021】首次采用在线Softmax技术,以线性内存复杂度顺序计算数值稳定的注意力分数,但时间复杂度仍为二次。FlashAttention【4, Dao et al., Flashattention: Fast and memory-efficient exact attention with io-awareness, 2022】和FlashAttention-2【2, Dao, Flashattention-2: Faster attention with better parallelism and work partitioning, 2023】通过分块(tiling)技术最小化GPU高带宽内存(HBM)和片上SRAM之间的读写次数,从而减少训练时间和内存占用。PagedAttention【15, Kwon et al., Efficient memory management for large language model serving with pagedattention, 2023】则通过减少浪费并允许在批处理请求中自适应共享,优化了KV缓存的内存利用率。Ring Attention【21, Liu et al., Ring attention with blockwise transformers for near-infinite context, 2023】通过将序列分布在多个设备上,并将键值块的通信与分块注意力计算重叠,来降低处理长序列时Transformer模型的内存需求。

序列并行 (Sequence Parallelism)

序列并行(SP)作为训练长序列的常用方法,已被集成到许多大型模型训练框架中,如Megatron-LM、DeepSpeed和Colossal-AI。Megatron-LM【48, Shoeybi et al., Megatron-lm: Training multi-billion parameter language models using model parallelism, 2019】将SP与模型(张量)并行(MP)结合,以在GPU上执行大规模矩阵乘法。然而,MP需要划分注意力头,这使得最大并行度受限于注意力头的数量。DeepSpeed-Ulysses【10, Jacobs et al., Deepspeed ulysses: System optimizations for enabling training of extreme long sequence transformer models, 2023】使用all-to-all通信原语来减少通信量,但也划分注意力头,面临与Megatron-LM类似的问题。

A2 方法细节

3.1 预备知识

Softmax注意力。标准的Transformer架构中的因果掩码注意力计算公式为:


其中,$d$表示隐藏维度。矩阵$Q, K, V \in \mathbb{R}^{N \times d}$分别代表查询、键和值矩阵,它们是输入$X \in \mathbb{R}^{N \times d}$的线性投影,即$Q = XW_Q, K = XW_K, V = XW_V$。输出矩阵表示为$O \in \mathbb{R}^{N \times d}$,$M \in \mathbb{R}^{N \times N}$代表因果掩码矩阵。Softmax(·)操作引入了相对于输入序列长度$N$的二次时间复杂度,这限制了原生Transformer扩展到长输入序列的能力。

线性注意力。线性注意力最初在【11, Katharopoulos et al., Transformers are rnns: Fast autoregressive transformers with linear attention, 2020, International Conference on Machine Learning】中提出,它消除了Softmax操作。后续工作【29, Qin et al., The devil in linear transformer, 2022, EMNLP】、【34, Qin et al., Transnormerllm: A faster and better large language model with improved transnormer, 2024】提出用归一化操作Norm(·)替代Softmax,公式变为:


当考虑双向任务时,上述公式可以简化为:

然后,通过利用矩阵乘积的结合律,它可以被数学上等价地转换为右乘版本:

这种线性注意力公式使得循环预测的计算复杂度达到$O(N d^2)$。并且,其循环更新$K^\top V$的方式无需计算整个注意力矩阵,使得推理过程非常高效。

线性注意力的挑战。尽管线性复杂度在线性注意力计算效率和内存优化方面提供了显著优势,但随着序列长度$N$的增长,单个GPU上的计算和内存利用率仍然成比例增加。这可能导致单个GPU(如NVIDIA A100的80GB限制)在处理极长序列时出现内存瓶颈。在GPU集群上使用基于线性注意力的LLM实现序列级别的零冗余训练仍然是一个开放性问题。此外,在因果设置中解决此问题的复杂性进一步加剧了挑战。为解决此问题,我们提出了LASP,作为一种在序列级别上并行化线性注意力训练的解决方案,即使在因果设置中也同样适用。

图1:LASP的可视化。左图:带因果掩码的块级线性注意力计算可分为两类:块内计算和块间计算。块内计算对应于掩码矩阵的对角线元素(橙色对角框),使用传统的左乘方法。而块间计算对应于下三角框,采用高效的右乘方法进行计算。右图:展示了LASP采用的点对点通信机制。输入序列X被划分为多个子序列块{· · · , Xi, Xi+1, · · · },每个块由不同设备上的不同模型实例处理。对于每个设备i,从其各自的输入块Xi计算出Qi, Ki和Vi。值得注意的是,设备间的通信操作在前向和后向传播中是互补的。具体来说,在前向传播中,KV矩阵从设备i发送到设备(i + 1),在后向传播中,dKV矩阵从设备(i + 1)发送回设备i。
图1:LASP的可视化。左图:带因果掩码的块级线性注意力计算可分为两类:块内计算和块间计算。块内计算对应于掩码矩阵的对角线元素(橙色对角框),使用传统的左乘方法。而块间计算对应于下三角框,采用高效的右乘方法进行计算。右图:展示了LASP采用的点对点通信机制。输入序列X被划分为多个子序列块{· · · , Xi, Xi+1, · · · },每个块由不同设备上的不同模型实例处理。对于每个设备i,从其各自的输入块Xi计算出Qi, Ki和Vi。值得注意的是,设备间的通信操作在前向和后向传播中是互补的。具体来说,在前向传播中,KV矩阵从设备i发送到设备(i + 1),在后向传播中,dKV矩阵从设备(i + 1)发送回设备i。

3.2 LASP

LASP的分块思想。LASP在集群上对序列进行分块处理。遵循分块的思想,LASP将输入序列划分为多个子序列块,并将这些块单独分布到不同的GPU上。对于因果设置下的线性注意力,为了充分利用其右乘优势,我们将块的注意力计算分为两种不同类型:块内(intra-chunks)和块间(inter-chunks)。块内计算涉及常规的注意力计算,而块间计算则利用了与线性注意力右乘相关的核技巧。关于LASP在数据分发、前向传播和后向传播中的复杂机制的更多细节将在下文阐述。LASP的可视化展示在图1中。

数据分发机制。LASP专为在分布式环境中训练基于线性Transformer的长序列而设计,通过沿序列维度对输入数据进行分区来实现。在这种情况下,分布式环境中的每个GPU负责训练一部分子序列,这有助于减少训练长序列时激活值带来的巨大内存占用。GPU之间引入了通信操作来传输中间状态。最终训练好的模型将吸收来自整个长序列的知识。

数据分发过程。对于一个长度为$N$的输入序列,我们将其在嵌入空间中的表示记为$X \in \mathbb{R}^{N \times d}$,特征维度为$d$。在LASP框架中,$X$被均匀地划分为$T$个块,其中$T$称为序列并行大小,并且必须能被分布式世界大小$W$整除。这些分段的数据块随后被分配给各自的GPU。需要注意的是,不同的序列并行组接收不同批次的数据。然而,在同一组内,所有数据块都来自同一批次的数据。LASP中数据分发过程的完整描述见算法1。

图2:LASP数据分发。左图:一个包含两个输入序列和八个GPU的数据分发示例。右图:完整的数据分发算法。
图2:LASP数据分发。左图:一个包含两个输入序列和八个GPU的数据分发示例。右图:完整的数据分发算法。

算法1 LASP数据分发

1: 输入: 嵌入空间中的输入序列 X ∈ R^{N×d},序列长度为 N,隐藏维度为 d,分布式世界大小为 W,序列并行大小为 T。
2: 获取序列并行组的数量 G = W/T。
3: 获取子序列长度(或块大小)C = N/T。
4: 获取全局排名列表 R = get_global_rank()。
5: 获取序列并行源排名列表 R_src = ⌊R/T⌋ * T。
6: 沿序列维度,将 X 分割成 T 个块 {X1, X2, ..., XT},每个块的大小为 C × d。
7: 将数据块 {X1, X2, · · ·, XT} 的副本传输到排名索引在 R_src 中的 GPU。
8: 将 {X1, X2, · · ·, XT} 从 R_src 分散到各自序列并行组中的所有排名。

数据分发示例。此外,图2中展示了LASP数据分发的一个示例,其中分布式世界大小$W=8$,序列并行大小$T=4$,序列并行组数量$G=2$,序列并行源排名列表$R_{src} = [0, 4]$。对于第一个批次Seq0,输入序列$X$沿着序列维度被划分为$T$个块$\{X_1, X_2, ..., X_T\}$,随后传输到SP-Group0中的第一个排名(即全局排名0)。全局排名0上的数据块接着被分散到SP-Group0内的全局排名{0, 1, 2, 3},每个排名只保留一个块。后续批次Seq1以类似的方式,被分配到SP-Group1内的全局排名{4, 5, 6, 7}。

前向传播的简化设定。为了简化推导,我们暂时省略了公式(2)中的Norm(·)算子。此外,我们考虑一个常规情况,即$W=T$,这意味着$G = W/T = 1$。在这种情况下,排名为0的GPU整合一个批次中所有分割的子序列,随后将它们分发到整个分布式世界中的所有GPU。值得注意的是,序列并行大小不等于世界大小的情况将在第3.5节中讨论。

前向传播的递归形式。我们首先将kvKV分别定义为中间内存状态向量和矩阵。不失一般性,我们在带因果掩码的线性注意力中加入衰减率$\lambda$,选择$\lambda=1$即得到普通的线性注意力【34, Qin et al., Transnormerllm: A faster and better large language model with improved transnormer, 2024】、【53, Sun et al., Retentive network: A successor to transformer for large language models, 2023】。在带因果掩码的线性注意力计算的前向传播中,第$s$个输出可以计算为:


以递归形式重写,我们得到:

其中

是前向传播中第$s$个输入的激活内存状态。

分块计算。在序列并行(SP)中,给定排名$i$上的数据块$X_t$,对应的查询、键和值为$Q_t = X_tW_Q, K_t = X_tW_K, V_t = X_tW_V$。注意,这里我们假设$T=W$,因此它们的索引是等价的,即$t=i$。第$t$个块内的输出可以计算为:


块内计算与其他GPU上的其他块没有依赖关系,因此可以在分布式世界中的所有排名上并行计算。然而,这个结果没有考虑到前$1 \sim (t-1)$个块对第$t$个块的影响,这被称为块间计算。为了计算块间部分,我们重新排列公式(4)为:

公式(8)中的第一部分对应于对先前块的计算,第二部分对应于对当前块的计算。在SP中,公式(8)可以被重写为块的形式:

其中$KV_t = kv_{tC}$。请注意,第$t$个块的块间计算依赖于前一个$(t-1)$块的激活状态,即$KV_{t-1}$,它是在排名$(i-1)$上计算的。因此,需要执行一个P2P通信操作Recv,将$KV_{t-1}$从排名$(i-1)$拉取到排名$i$。然后,应该更新激活状态$KV_t$,以用于后续第$(t+1)$个块的块间注意力计算。第$t$个块的$KV_t$更新规则是:

与前面的Recv操作相对应,执行另一个P2P通信操作Send,将公式(10)中获得的$KV_t$传输到后续的排名$(i+1)$,用于其块间计算。

KV状态缓存。值得注意的是,在反向传播中,第$t$个块需要$KV_{t-1}$作为激活值来计算梯度。为了最小化通信操作,我们将$KV_{t-1}$缓存在高带宽内存(HBM)中以加速计算。

最终前向输出。结合块内和块间部分,最终的前向输出如下:


我们在算法2中展示了$W=T$时LASP的完整前向传播过程。

算法2 LASP前向传播

1: 输入: 嵌入空间中的输入序列 X ∈ R^{N×d},序列长度 N,隐藏维度 d,分布式世界大小 W,序列并行大小 T = W,衰减率 λ ∈ R+。
2: 根据算法1分发输入序列 X。
3: 获取子序列长度(或块大小)C = N/T。
4: 初始化掩码 M ∈ R^{C×C},其中 M_{ij} = λ^{i-j} (如果 i ≥ j),否则 M_{ij} = 0。
5: 初始化 Λ = diag{λ, λ², · · ·, λ^C} ∈ R^{C×C}。
6: 初始化激活状态 KV = 0 ∈ R^{d×d}。
7: for 块 t ∈ {1, · · ·, T} 在排名 i ∈ {1, · · ·, W} 上并行 do
8:   根据自己的数据块计算 Qt = XtWQ, Kt = XtWK, Vt = XtWV,每个大小为 C × d。
9:   计算 O_{t,intra} = [(Qt K_t^⊤) ⊙ M]Vt。
10: end for
11: for 块 t ∈ {1, · · ·, T} 在排名 i ∈ {1, · · ·, W} do
12:   从排名 (i - 1) 接收激活 KVt-1。
13:   保存 KVt-1 为 KVi 以用于反向计算。
14:   计算 O_{t,inter} = ΛQtKVt-1。
15:   计算 Ot = O_{t,intra} + O_{t,inter}。
16:   更新 KVt = λ^C KVt-1 + (λ^C Λ^{-1} Kt)^⊤ Vt。
17:   将激活 KVt 发送到排名 (i + 1)。
18: end for
19: return O = [Ot], 其中 t ∈ {1, · · ·, T}。

反向传播的梯度公式。对于反向传播,给定$do_s$,我们有(引自【11, Katharopoulos et al., Transformers are rnns: Fast autoregressive transformers with linear attention, 2020, International Conference on Machine Learning】):


将$dkv_s$写成递归形式,我们有:

反向传播的分块设定。在SP中,我们有$\{Q_t, K_t, V_t, O_t, dO_t\}$,对应于排名$i$上的第$t$个子序列块,其中$t \in \{1, \cdots, T\}$,$i \in \{1, \cdots, W\}$。与前向传播相同,以下推导假设$t=i, T=W$。

dQt的计算。我们首先计算关于第$t$个数据块的$dQ$,得到:


由于$dQ_{t,intra}$的计算是独立的,其计算可以在所有GPU上并行化。而$dQ_{t,inter}$的计算反映了块1到$t-1$对块$t$的相互依赖。为了计算块间部分,我们将公式(12)转换为:

公式(15)的第一部分对应于块内计算,而第二部分对应于块间计算。在SP中,我们可以计算$dQ_{t,inter}$为:

注意,$KV_t$已在前向传播中计算并缓存,因此这里不需要通信来获取$KV_t$。得益于KV状态缓存, $dQ_{t,inter}$的计算也可以并行执行。

dKt的计算。接下来,第$t$个块内的$dK$可以并行计算为:


然后我们转换公式(12)为:

其中加号前的项对应于块内计算,加号后的项对应于块间计算。上述方程可以按块的形式重写如下:

这里需要一个Recv操作来从第$(t+1)$个块拉取$dKV_{t+1}$。然后为了计算第$(t-1)$个块的$dKV$,$dKV$应更新为:

接着执行一个Send操作将$dKV_t$推送到排名$(i-1)$。

dVt的计算。最后,对于$dV$,其块内部分可以计算为$dV_{t,intra} = [(Q_t K_t^\top) \odot M]^\top dO_t$。我们再次转换公式(12)为:


第一项和第二项分别对应于块内和块间计算。在SP中,$dV_{t,inter}$可以计算为:

结合块内和块间部分,我们得到$d_Qt, d_Kt$和$d_Vt$的最终结果。完整的LASP反向传播过程见附录A.1中的算法3。

3.3 对比

通信量对比。在LASP中,需要注意的是,前向传播中每个线性注意力模块层都需要为$KV \in \mathbb{R}^{d \times d}$状态进行通信。通信量由$Bd^2/h$决定,其中$B$是批量大小,$h$是头数。相比之下,Ring Attention也采用P2P环式通信,但作用于状态$K, V \in \mathbb{R}^{N \times d}$,导致通信量为$BNd/h$。Megatron-LM中的SP在每个Transformer层内的两个层归一化后使用两次all-gather操作,并在注意力和前馈神经网络(FFN)层后使用一次reduce-scatter操作。这导致通信量为$2BNd + 4BNd/T$。DeepSpeed对每个注意力模块层的输入$Q, K, V$和输出$O$使用all-to-all集体通信【56, Thakur et al., Optimization of collective communication operations in mpich, 2005】,导致通信量为$4BNd/T$。

表1:通信量对比。简化公式:为便于比较,我们消除了共同因子Bd。

表1:通信量对比。简化公式:为便于比较,我们消除了共同因子Bd。
表1:通信量对比。简化公式:为便于比较,我们消除了共同因子Bd。

与Ring Attention的详细对比。在实际应用中,当$N/T \geq 32$时,LASP能够实现理论上最低的通信量。此外,LASP的通信量不受序列长度$N$或子序列长度$C$变化的影响,这对于在大型集群上进行超长序列的SP来说是一个巨大优势。值得注意的是,尽管Ring Attention和LASP都使用P2P环式通信,但它们在通信和计算方面都存在差异。通信:在前后向传播中,Ring Attention涉及通信两个状态$K, V \in \mathbb{R}^{N \times d}$。相比之下,LASP只通信一个单一状态$KV \in \mathbb{R}^{d \times d}$,这个状态不依赖于序列长度。LASP具有更低的理论通信复杂度,这使得它更高效,特别是在互连速度较慢、通信-计算重叠可能不是最优的环境中。计算:Ring Attention专为标准注意力设计,采用左乘方式,即$((QK^\top)V)$。而LASP专为类线性注意力的序列建模方法量身定制,它利用右乘核技巧$(Q(K^\top V))$来实现线性时间复杂度。

3.4 系统工程优化

核函数融合。为了提高LASP在GPU上的效率,我们在块内和块间计算中都执行了核函数融合,并且还将$KV$和$dKV$的更新融合到块内和块间计算中。

KV状态缓存。为了避免在反向传播期间重新计算激活值$KV$,我们选择在前向传播计算后立即将其存储在GPU的HBM中。在随后的反向传播中,LASP直接访问$KV$使用。需要注意的是,缓存在HBM中的$KV$激活值大小为$d \times d$,不受序列长度$N$的影响。当输入序列长度$N$非常大时,$KV$的内存使用变得可以忽略不计。

3.5 混合并行

数据-序列混合并行。如图2所示,LASP允许指定一个能被分布式世界大小整除的较小序列并行大小。这种配置导致输入数据同时在批次和序列维度上进行分割,这是一种称为数据-序列混合并行的混合并行类型。DeepSpeed中的ZeRO系列优化器【43, Rajbhandari et al., Zero: Memory optimizations toward training trillion parameter models, 2020】和PyTorch中的FSDP【70, Zhao et al., Pytorch fsdp: experiences on scaling fully sharded data parallel, 2023】提出将模型状态(包括优化器状态、梯度和模型参数)分布到分布式环境中的所有GPU上。作为数据并行的变体,这些技术与LASP无缝对齐。此外,它们专注于最小化模型状态的内存,这与LASP旨在减少每个GPU上激活内存的目标相辅相成。通过整合这些技术,使得处理长序列长度的大模型训练变得更加实用。

与张量并行和流水线并行的兼容性。LASP支持张量并行(TP)和流水线并行(PP)。在PP中,如GPipe【13, Kim et al., torchgpipe: On-the-fly pipeline parallelism for training giant models, 2020】调度方法所示,模型首先被划分到多个设备上,每个设备持有一部分模型。一个小批量内的数据被分成微批量,并顺序送入包含第一部分的设备。每个设备处理其微批量并将输出转发到序列中的下一个设备,同时准备接收和处理来自前一个设备的后续微批量。这种流水线输入的方法有效地最小化了设备空闲时间。当LASP与PP集成时,微批量被替换为来自一个小批量的子序列。与标准PP不同,每个设备在本地保留中间状态(前向的$KV$和后向的$dKV$),而不是像单独使用LASP时那样将它们传输到下一个设备。对于TP,与LASP的集成是流畅的。线性注意力层利用TP在块内和块间计算中分割矩阵操作。

层间混合模型上的混合SP。对于层间混合模型,混合SP方法在Softmax注意力Transformer层应用已有的Ring Attention SP,同时在线性注意力Transformer层使用LASP。由于这两种策略在各自的层类型中独立运作,它们不会相互干扰。这种直接的方法主要是将LASP实际应用于混合模型。我们在附录A.5.1中进行了实验,以证明混合SP的可行性。

A4 实验环境

  • 模型架构: 实验在两种代表性的线性注意力模型上进行评估:TransNormerLLM (TNL) 【33, Qin et al., Scaling transnormer to 175 billion parameters, 2023】、【34, Qin et al., Transnormerllm: A faster and better large language model with improved transnormer, 2024】和Linear Transformer 【11, Katharopoulos et al., Transformers are rnns: Fast autoregressive transformers with linear attention, 2020, International Conference on Machine Learning】。TNL是最新纯基于线性注意力的大语言模型,而Linear Transformer是社区公认的经典线性Transformer模型。
  • 数据集: 收敛性实验使用的数据集是The Pile 【6, Gao et al., The pile: An 800gb dataset of diverse text for language modeling, 2020】。
  • 硬件配置:
    • GPU: 最多16台DGX-A100服务器,每台配备8块A100 80G GPU,共计128块GPU。
    • 互联: 节点内GPU通过NVSwitch连接,带宽600GBps。节点间使用RoCE(RDMA over Converged Ethernet)技术,每台服务器配备8个RoCE RDMA适配器,通信带宽800Gbps。
  • 软件配置:
    • 框架与库: PyTorch 2.1.1, Triton 2.0.0, CUDA 11.7, cuDNN 8.0, NCCL 2.14.3。
    • 代码实现: 基于Metaseq 【68, Zhang et al., Opt: Open pre-trained transformer language models, 2022】(集成了FairScale【5, FairScale authors, Fairscale: A general purpose modular pytorch library for high performance and large scale training, 2021】)和DeepSpeed开发。
  • 训练配置:
    • 超参数: 学习率0.0005,总更新步数50,000,预热步数2,000,权重衰减0.01。
    • 优化器: Adam,beta值为0.9和0.999。
    • 并行后端: 实验中使用了多种DDP后端,包括PyTorch DDP、Legacy DDP、FSDP、ZeRO系列,以交叉验证与LASP的兼容性。

A4 实验结果

可扩展性与速度比较

  • 可扩展性: 实验结果(图3)表明,使用LASP,在128个GPU上,TNL-1B模型能够成功将序列长度扩展到4096K(使用FSDP后端)和2048K(使用DDP后端)。最大序列长度与GPU数量成线性关系(例如,16个GPU可训练512K序列,64个GPU可训练2048K序列)。即使增加GPU数量,LASP也能保持高吞吐量。该可扩展性在FSDP和DDP后端下均表现一致。更详细的定量结果见附录A.5的表10。
    图3:LASP在吞吐量(tokens/sec)和内存使用上的可扩展性评估。左:LASP与FSDP后端集成;右:LASP与DDP后端集成。使用TNL-1B模型,批大小为1,在多达128个A100 80GB GPU上进行。带虚线的"×"表示发生内存不足(OOM)。
    图3:LASP在吞吐量(tokens/sec)和内存使用上的可扩展性评估。左:LASP与FSDP后端集成;右:LASP与DDP后端集成。使用TNL-1B模型,批大小为1,在多达128个A100 80GB GPU上进行。带虚线的"×"表示发生内存不足(OOM)。
  • 速度比较: 在64个GPU上,将LASP与现有的SP方法(Ring Attention, DeepSpeed-Ulysses, Megatron-SP)在TNL 1B和7B模型上进行了比较(图4)。由于其高效的通信设计,LASP在吞吐量上表现出显著提升。
    图4:LASP与Ring Attention、DeepSpeed-Ulysses和Megatron-SP的速度比较(tokens/sec)。带虚线的"×"表示发生内存不足(OOM)。评估使用TNL-1B和7B模型,批大小为1,在64个A100 80GB GPU上进行。这三种方法的并行大小均配置为64。
    图4:LASP与Ring Attention、DeepSpeed-Ulysses和Megatron-SP的速度比较(tokens/sec)。带虚线的"×"表示发生内存不足(OOM)。评估使用TNL-1B和7B模型,批大小为1,在64个A100 80GB GPU上进行。这三种方法的并行大小均配置为64。

收敛性

  • 收敛性测试: 实验在TNL (0.4B)、Linear Transformer (0.4B)以及一个使用Softmax注意力的LLaMA模型上进行,训练5万步。结果(表2)显示,无论是否使用LASP,以及在不同的DDP后端下,两个线性模型的损失值都保持一致,证明LASP不会对模型收敛产生负面影响。
    表2:LASP的收敛性能。所有实验使用8个A100 80G GPU,序列长度16K,批大小1。结果涵盖了与LASP结合的各种DDP后端。我们探讨了两种线性注意力模型:TransNormerLLM(TNL)和Linear Transformer,以及一个带Softmax注意力的Transformer模型(LLaMA),参数均为0.4B,进行了50K次更新。
    表2:LASP的收敛性能。所有实验使用8个A100 80G GPU,序列长度16K,批大小1。结果涵盖了与LASP结合的各种DDP后端。我们探讨了两种线性注意力模型:TransNormerLLM(TNL)和Linear Transformer,以及一个带Softmax注意力的Transformer模型(LLaMA),参数均为0.4B,进行了50K次更新。

消融研究

  • 系统工程优化消融: 对核函数融合(Kernel Fusion)和KV状态缓存(KV State Caching)进行了消融研究(表3)。在2个A100 GPU上训练一个1B参数的TNL模型(批大小2,序列长度8K),结果表明这两种技术都显著提高了训练吞吐量,而对内存使用的影响极小。
    表3:对系统工程优化技术“核函数融合”和“KV状态缓存”的消融研究。实验在TNL-1B模型上进行,批大小为2,序列长度为8K,使用2个A100 GPU。
    表3:对系统工程优化技术“核函数融合”和“KV状态缓存”的消融研究。实验在TNL-1B模型上进行,批大小为2,序列长度为8K,使用2个A100 GPU。
  • 激活值减少方法消融: 比较了激活检查点(AC)、LASP及其组合的效果(表4)。在单个节点(8个GPU)上,仅使用DDP和FSDP可训练的最大序列长度分别为12K和16K。AC和LASP都能显著扩展最大序列长度。LASP与GPU数量成比例扩展,而AC则不具备此特性。通过将LASP与AC结合,使用DDP和FSDP后端分别实现了496K和768K的最大序列长度,这得益于线性注意力、AC和LASP三者互补的优势。
    表4:对激活值减少方法的消融研究。测试了DDP和FSDP后端。使用一个配备8个A100 80G GPU的单节点训练一个TNL-1B模型。
    表4:对激活值减少方法的消融研究。测试了DDP和FSDP后端。使用一个配备8个A100 80G GPU的单节点训练一个TNL-1B模型。

A7 补充细节

5 讨论

线性复杂度序列模型的兴起。线性复杂度序列建模方法正成为下一代基础模型中传统Transformer(使用Softmax注意力)的重要替代品,因为它们具有显著更快的训练和推理速度,并且性能与传统方法相当。最近,AI社区见证了新型线性复杂度模型的快速发展,并获得了相当大的关注。例子包括线性注意力模型如TransNormerLLM,状态空间模型(SSM)如Mamba和Jamba,以及线性RNN模型包括RWKV、HGRN和Griffin。我们认为LASP的设计可以无缝集成到大多数线性复杂度模型中。为了强调LASP的泛化能力,我们在附录A.4中使用了一种广义形式的线性注意力【35, Qin et al., Unlocking the secrets of linear complexity sequence model from a unified perspective, 2024】,证明了其他线性复杂度模型也可以被纳入LASP框架中。

与Softmax注意力的兼容性。此外,探索线性注意力及LASP与广泛采用的Softmax注意力的兼容性也很重要。这两种机制在不同场景下各有优势。Softmax注意力在建模短序列方面非常有效,性能强大,但受限于序列长度的二次复杂度,这限制了其在长上下文任务中的可扩展性。另一方面,线性注意力及类似的线性序列建模方法为长序列提供了显著更高的效率,但在捕捉复杂依赖关系方面可能效果稍差。一个利用两者优势的实用解决方案是采用混合架构,在Transformer模型中交替使用Softmax注意力和线性注意力层。这一策略已在大型商业模型中实现,如Minimax-01【17, Li et al., Minimax-01: Scaling foundation models with lightning attention, 2025】和腾讯混元Turbo-S【55, Tencent, Tencent hunyuan turbo-s, 2025】,以及在较小的混合模型中,如Samba【44, Ren et al., Samba: Simple hybrid state space models for efficient unlimited context language modeling, 2024】、Jamba【20, Lieber et al., Jamba: A hybrid transformer-mamba language model, 2024】和GatedDeltaNet【65, Yang et al., Gated delta networks: Improving mamba2 with delta rule, 2024】。

A5 结论

本文提出了LASP,通过利用线性复杂度序列建模方法的右乘特性,有效解决了现有SP方法的局限性,显著提升了通信和并行效率。通过设计高效的P2P环式通信机制和精细的工程优化(包括核函数融合和KV状态缓存),LASP实现了通信流量的显著减少和GPU集群硬件利用率的提高。与所有类型的批处理级DDP方法的兼容性确保了LASP在超长序列大规模分布式训练中的实用性。我们的实验突出了LASP在可扩展性、速度、内存使用和收敛性能方面的优势。在特定的实验设置下,LASP在处理最大长8倍的序列时,实现了比现有开箱即用SP方法快得多的序列级分布式训练速度。

A6 附录

A.1 反向传播算法

算法3 LASP反向传播

1: 输入: 序列长度 N, 分布式世界大小 W, 序列并行大小 T, 衰减率 λ ∈ R+
   Qt, Kt, Vt, Ot, dOt ∈ R^{C×d} for t ∈ {1, 2, · · ·, T}。
2: 获取子序列长度(或块大小)C = N/T。
3: 初始化掩码 M ∈ R^{C×C},其中 M_{ij} = λ^{i-j} (如果 i ≥ j),否则 M_{ij} = 0。
4: 初始化 Λ = diag{λ, λ², · · ·, λ^C} ∈ R^{C×C}。
5: 初始化 dKV = 0 ∈ R^{d×d}。
6: for t ∈ {1, 2, · · ·, T} 在排名 i ∈ {1, 2, · · ·, W} 上并行 do
7:   计算 dQ_{t,intra} = [(dOt V_t^⊤) ⊙ M]Kt。
8:   计算 dQ_{t,inter} = ΛdOt KV_{t-1}^⊤。
9:   计算 dK_{t,intra} = [(dOt V_t^⊤) ⊙ M]^⊤ Qt。
10:  计算 dV_{t,intra} = [(Qt K_t^⊤) ⊙ M]^⊤ dOt。
11: end for
12: for t ∈ {T, · · ·, 2, 1} 在排名 i ∈ {W, · · ·, 2, 1} do
13:   从排名 (i + 1) 接收激活 dKV_{t+1}。
14:   计算 dK_{t,inter} = (λ^C Λ^{-1} Vt) dKV_{t+1}^⊤。
15:   计算 dV_{t,inter} = (λ^C Λ^{-1} Kt) dKV_{t+1}。
16:   在排名 i 上加载 KVi 作为 KVt。
17:   结合 dQ, dKt, dVt 的块内和块间部分。
18:   计算 dKVt = λ^C dKV_{t+1} + (ΛQt)^⊤ dOt。
19:   将激活 dKVt 发送到排名 i。
20: end for
21: return dQ = [dQt], dK = [dKt], dV = [dVt], with t ∈ {1, 2, · · ·, T}。

A.2 硬件和软件

硬件。我们的实验配置最多涉及16台DGX-A100服务器,每台服务器配备8个A100 GPU,这些GPU通过NVSwitch互连,确保GPU间带宽为600GBps。对于节点间通信,我们采用RoCE(RDMA over Converged Ethernet)技术,每台服务器使用8个RoCE RDMA适配器。此设置有助于实现高效的服务器间通信,带宽容量为800Gbps。

软件。实验在PyTorch 2.1.1和Triton 2.0.0中实现,使用CUDA 11.7、cuDNN 8.0和NCCL 2.14.3。我们的算法基于Metaseq和DeepSpeed开发。

A.3 实验设置

训练配置。训练配置设定了特定的超参数:学习率为0.0005以控制优化步长,更新次数上限为50,000次以定义训练时长,以及2,000次更新的预热期以通过逐渐调整学习率来稳定早期训练。此外,使用0.01的权重衰减率进行正则化以避免过拟合【49, Sun et al., Co2: Efficient distributed training with full communication-computation overlap, 2024】。选择Adam优化器,其beta值为0.9和0.999,用于管理梯度的动量和缩放,有助于有效和稳定的训练收敛【71, Zhou et al., pbsgd: Powered stochastic gradient descent methods for accelerated non-convex optimization, 2020】。实验中选择了不同的DDP后端,包括PyTorch DDP(简称DDP)、Legacy DDP、FSDP、ZeRO系列,用于交叉验证与LASP的兼容性。

A.4 LASP的泛化性

LASP的泛化性。虽然LASP最初受到线性注意力机制的启发,但我们的目标是展示其对各种线性序列建模方法的更广泛适用性。本节通过理论分析和实证验证来研究LASP的泛化能力。

递归记忆的通用形式。在理论方面,我们首先定义以下术语:记忆状态$m_t \in \mathbb{R}^{k \times d}$,输入状态$i_t \in \mathbb{R}^d$,扩展状态$e_t \in \mathbb{R}^k$,振荡状态$o_t \in \mathbb{R}^{k \times m}$,收缩状态$s_t \in \mathbb{R}^k$,并写出递归记忆的通用形式(引自【35, Qin et al., Unlocking the secrets of linear complexity sequence model from a unified perspective, 2024】):


这是公式(5)中线性注意力递归形式的通用形式,其中指定了$o_t$和$e_t$:

适用模型列表。LASP的设计可以无缝应用于能够用公式(24)普遍表示的模型。这些模型包括:S4【8, Gu et al., Efficiently modeling long sequences with structured state spaces, 2022】、S5【49, Smith et al., Simplified state space layers for sequence modeling, 2022】、DSS【9, Gupta et al., Diagonal state spaces are as effective as structured state spaces, 2022】、TNN【31, Qin et al., Toeplitz neural network for sequence modeling, 2023】、线性注意力【11, Katharopoulos et al., Transformers are rnns: Fast autoregressive transformers with linear attention, 2020】、TNL【34, Qin et al., Transnormerllm: A faster and better large language model with improved transnormer, 2024】、RetNet【53, Sun et al., Retentive network: A successor to transformer for large language models, 2023】、Mamba【7, Gu & Dao, Mamba: Linear-time sequence modeling with selective state spaces, 2023】、RWKV-4【25, Peng et al., RWKV: Reinventing RNNs for the transformer era, 2023】、Cosformer【30, Qin et al., cosformer: Rethinking softmax in attention, 2022】、Lrpe【33, Qin et al., Linearized relative positional encoding, 2023】、GLA【64, Yang et al., Gated linear attention transformers with hardware-efficient training, 2023】、GateLoop【12, Katsch, Gateloop: Fully data-controlled linear recurrence for sequence modeling, 2023】、DUR【22, Mao, Fine-tuning pre-trained transformers into decaying fast weights, 2022】、GFW【45, Schlag & Schmidhuber, Gated fast weights for associative retrieval, 2018】、HGRN【39, Qin et al., Hierarchically gated recurrent neural network for sequence modeling, 2024】、【38, Qin et al., Hgrn2: Gated linear rnns with state expansion, 2024】和LRN【23, Martin & Cundy, Parallelizing linear recurrent neural nets over sequence length, 2018】。我们在表5中列出了所有这些模型及其对应的元素。

表5:在定义的通用形式内典型线性复杂度序列建模方法的清单。对于每种方法,列出了以下状态:输入状态、扩展状态、振荡状态、收缩状态和记忆状态。如果状态与输入序列直接相关,则下标i被强调。注意,我们使用1(k) ∈ Rk,其中1(k)j = 1对于j = 1, . . . , k,以及J(kd) = 1(k)1(d)⊤ ∈ Rk×d。
表5:在定义的通用形式内典型线性复杂度序列建模方法的清单。对于每种方法,列出了以下状态:输入状态、扩展状态、振荡状态、收缩状态和记忆状态。如果状态与输入序列直接相关,则下标i被强调。注意,我们使用1(k) ∈ Rk,其中1(k)j = 1对于j = 1, . . . , k,以及J(kd) = 1(k)1(d)⊤ ∈ Rk×d。

我们还为每种建模方法提供了完整的解释如下:
* S4: 在S4中,我们通过线性投影从输入$x_t$得到$u_t \in \mathbb{R}^d$,并通过SSM参数化得到$A \in \mathbb{R}^{k \times k}, B, C \in \mathbb{R}^{k \times 1}$。计算如下:


注意,S4的原始定义是作为通道级别的映射$f_i, i = 1, \dots, d$ of $\mathbb{R}^{n \times 1} \to \mathbb{R}^{n \times 1}$。
* S5: S5的递归方程与S4相同,唯一的区别是直接定义了映射$\mathbb{R}^{n \times d} \to \mathbb{R}^{n \times d}$以及$B, C \in \mathbb{R}^{k \times d}$。
* DSS: DSS的递归方程与S4/S5相同,唯一的区别是直接定义了映射$\mathbb{R}^{n \times d} \to \mathbb{R}^{n \times d}$以及$B, C \in \mathbb{R}^{k \times d}, A = \text{Diag } a \in \mathbb{R}^{k \times k}$。
* TNN: 根据【28, Qin & Zhong, Accelerating toeplitz neural network with constant-time inference complexity, 2023】,TNN可以无损转换为SSM,其中$C = J(kd) \in \mathbb{R}^{k \times d}, B \in \mathbb{R}^{k \times d}, A = \text{Diag} \lambda_1, \dots, \lambda_k \in \mathbb{R}^{k \times k}$,通过线性投影从$x_t$得到$u_t$,并且可以表示为递归公式:

* Linear Attention: 在线性注意力中,我们通过线性投影从输入$x_t \in \mathbb{R}^d$得到查询$q_t \in \mathbb{R}^k$,键$k_t \in \mathbb{R}^k$,值$v_t \in \mathbb{R}^d$,递归计算如下:

* TNL/RetNet: TNL/RetNet是带有指数衰减的线性注意力形式,获取$q_t, k_t, v_t$的方法与线性注意力相同,$\lambda$是预定义的不可学习参数。其递归计算为:

* Mamba: Mamba可以看作是数据依赖的S4。它使用类似的方法获取$u_t, A, B, C$,其中$A_t, B_t, C_t$通过$x_t$和$A, B, C$计算。其递归方程定义为:

* RWKV-4: 在RWKV-4中,我们通过线性投影从输入$x_t$和可学习权重$w$得到$r_t, k_t, v_t$。忽略RWKV-4的分母,递归方程可以简化为:

与S4类似,RWKV-4使用通道级别的映射$f_i, i = 1, \dots, d$ of $\mathbb{R}^{n \times 1} \to \mathbb{R}^{n \times 1}$。
* Cosformer: 在Cosformer中,我们从输入$x_t \in \mathbb{R}^d$和预定义的$\theta$(不可学习)得到查询$q_t \in \mathbb{R}^k$,键$k_t \in \mathbb{R}^k$,值$v_t \in \mathbb{R}^d$。然后递归计算如下:

* Lrpe: 在Lrpe中,我们从输入$x_t \in \mathbb{R}^d$和可学习权重$\theta$得到查询$q_t \in \mathbb{R}^k$,键$k_t \in \mathbb{R}^k$,值$v_t \in \mathbb{R}^d$,并递归计算如下:

* GLA/GateLoop: 在GLA/GateLoop中,我们从输入$x_t \in \mathbb{R}^d$得到查询$q_t \in \mathbb{R}^k$,键$k_t \in \mathbb{R}^k$,值$v_t \in \mathbb{R}^d$,衰减$g_t \in \mathbb{R}^k$,并递归计算如下:

* DUR/GFW: 在DUR/GFW中,我们从输入$x_t \in \mathbb{R}^d$得到查询$q_t \in \mathbb{R}^k$,键$k_t \in \mathbb{R}^k$,值$v_t \in \mathbb{R}^d$,衰减$g_t \in \mathbb{R}^k$,$\bar{g}_t \in \mathbb{R}^d$,并递归计算如下:

* HGRN/LRN: 在HGRN/LRN中,我们从输入$x_t \in \mathbb{R}^1$得到输出门$o_t \in \mathbb{R}^1$,遗忘门$f_t \in \mathbb{R}^1$,输入状态$i_t \in \mathbb{R}^1$,并递归计算如下:

与S4类似,HGRN/LRN使用通道级别的映射$f_i, i = 1, \dots, d$ of $\mathbb{R}^{n \times 1} \to \mathbb{R}^{n \times 1}$。

泛化性的实验验证。为经验性地验证LASP的泛化性,我们采用表2的实验设置,并将LASP应用于表5中列出的另外三种线性序列建模方法,即Cosformer、RetNet和Mamba。收敛结果见表6,表明LASP不会对收敛产生负面影响,并达到了与基线相当的性能。

表6:LASP在Cosformer、RetNet和Mamba上的收敛结果。测试了0.4B参数的TNL模型,批大小为2,序列长度为16K。
表6:LASP在Cosformer、RetNet和Mamba上的收敛结果。测试了0.4B参数的TNL模型,批大小为2,序列长度为16K。

A.5 额外的实验结果

A.5.1 混合SP在混合模型上的结果

混合SP在混合模型上的结果。我们进行了一个小规模实验(8× A100,1B参数,DDP后端)来评估混合SP方法在混合模型上的可行性。在此设置中,“1/4混合”模型表示每四层中有一层是Softmax注意力Transformer层。用“S”表示Softmax注意力,“L”表示线性注意力,一个16层的“1/4混合”模型遵循“LLLSLLLSLLLSLLLS”的模式。表7中的结果表明,混合SP有效地扩展了TNL和Linear Transformer的最大可训练序列长度,而只导致训练速度轻微下降。

表7:混合SP在层间混合模型上的结果。“1/4混合”指每四层中有一层是Softmax注意力Transformer层的模型。报告了最大序列长度和吞吐量(tokens/sec)。
表7:混合SP在层间混合模型上的结果。“1/4混合”指每四层中有一层是Softmax注意力Transformer层的模型。报告了最大序列长度和吞吐量(tokens/sec)。

A.5.2 下游任务评估结果

下游任务评估结果。我们进行了一个延长训练时长至30万步(消耗400亿个token)的实验,以评估LASP的性能及其在下游任务上的评估结果。研究了0.4B参数的TNL和Linear Transformer。我们在多个下游基准上评估了训练后模型的性能,包括PIQA、HellaSwag (HS)、WinoGrande (WG)、ARC-E、ARC-C、OBQA和CSR-AVG。结果呈现在表8和表9中。LASP不会对下游任务性能产生负面影响。

表8:LASP在延长30万步训练后的收敛结果。测试了0.4B参数的TNL和Linear Transformer,批大小为2,序列长度为16K。
表8:LASP在延长30万步训练后的收敛结果。测试了0.4B参数的TNL和Linear Transformer,批大小为2,序列长度为16K。

表9:下游任务评估结果。HS: HellaSwag, WG: WinoGrande。分数越高表示性能越好。
表9:下游任务评估结果。HS: HellaSwag, WG: WinoGrande。分数越高表示性能越好。

A.5.3 定量的可扩展性结果

定量的可扩展性结果。见下页表10。

表10:LASP在吞吐量(tokens/sec)和每GPU内存使用(GB)上的定量可扩展性结果。实验在TNL-1B上进行,序列长度从2K扩展到4096K,批大小为1。测试了DDP和FSDP后端。
表10:LASP在吞吐量(tokens/sec)和每GPU内存使用(GB)上的定量可扩展性结果。实验在TNL-1B上进行,序列长度从2K扩展到4096K,批大小为1。测试了DDP和FSDP后端。