ZeCO: Zero Communication Overhead Sequence Parallelism for Linear Attention
ZeCO: Zero Communication Overhead Sequence Parallelism for Linear Attention
作者/机构: Yuhong Chou (香港理工大学), Zehao Liu (香港理工大学), Ruijie Zhu (加州大学圣克鲁兹分校), Xinyi Wan (新加坡国立大学), Tianjian Li (TikTok), Congying Chu (中国科学院自动化研究所), Qian Liu (TikTok), Jibin Wu (香港理工大学), Zejun Ma (TikTok)
A1 主要贡献
核心问题: 线性注意力机制通过提供线性计算复杂度,为处理超长序列(例如1M上下文)的大型语言模型(LLM)带来了显著优势。然而,现有的序列并行(SP)方法是跨设备分配这些工作负载的关键,却因巨大的通信开销而成为主要瓶颈。
研究目标: 提出一种新的序列并行方法,旨在克服现有方法的局限性,并为长序列训练实现端到端的近线性可扩展性。例如,使用ZeCO在64个设备上训练一个序列长度为1M的模型,其耗时与在单个设备上训练16k序列大致相同。
创新点 (ZeCO):
本文介绍了ZeCO,一种专为线性注意力模型设计的零通信开销序列并行方法。其核心是一种名为All-Scan的新型集体通信原语。All-Scan能够为每个SP rank精确提供其所需的初始算子状态,同时保持最小的通信足迹,从而有效消除通信开销。
主要贡献概括:
1. 提出ZeCO方法: 这是一种新颖的线性注意力模型序列并行方法。ZeCO通过利用我们提出的All-Scan集体通信重新构建了序列并行。All-Scan采用流水线通信方式,实现了理论上最小的通信量。这种集成方法能够有效重叠通信和计算,同时只产生极小的额外计算和I/O开销。
2. 理论最优性证明: 从理论上证明了ZeCO的最优性。通过对不同序列并行策略的时间成本分析,表明ZeCO构成了最低的必要成本,从而确立了其效率。
3. 全面的多层次实验验证: 在集体通信、算子和模型三个层面上进行了全面的实验,结果表明ZeCO带来了显著的性能提升。如图1所示,All-Scan集体通信实现了高达3.9倍的通信加速,是现有最快的序列并行方法;而ZeCO序列并行算子则带来了高达9.3倍的整体加速。在模型层面,ZeCO将吞吐量提高了60%以上,并在8到256个设备上展示了近线性的可扩展性,即使上下文长度高达8M token。
图 1: (左) ZeCO在序列长度与设备数量成比例扩展时,表现出近线性的扩展效率,接近理论上限。(右) ZeCO在三个性能指标上显著优于SOTA SP方法:通信时间、运行时额外成本和单GPU吞吐量。通信/吞吐量指标使用256个GPU(8M序列)获得,运行时额外成本使用128个GPU(4M序列)获得。
A3 背景知识
本文中,我们使用粗体大写字母(如Q)表示矩阵,并用相同字母表示矩阵的行,因此$Q_t$指的是Q的第$t$行。除非另有说明,$p$表示设备数量,$L$表示每个设备的序列长度,$d$表示隐藏维度,$h$表示注意力头的数量,$C$表示块(chunk)的长度。
2.1. 门控线性注意力(GLA)的循环和分块形式
线性注意力机制。线性注意力机制利用核技巧(kernel trick)消除了全注意力中的softmax计算,并通过交换计算顺序将注意力计算复杂度从二次降至线性【索引13,Katharopoulos et al.,Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention,2020,arXiv】、【索引4,Choromanski et al.,Rethinking Attention with Performers,2022,arXiv】。实现此机制有多种方式。多项工作【索引41,S. Yang, Bailin Wang, Shen, et al.,Gated Linear Attention Transformers with HardwareEfficient Training,2024,arXiv】、【索引21,Qin, S. Yang, et al.,Hgrn2: Gated linear rnns with state expansion,2024,arXiv】、【索引5,Chou et al.,MetaLA: Unified Optimal Linear Approximation to Softmax Attention Map,2024,arXiv】已经总结出一种统一的(对角衰减)线性模型形式。本文使用门控线性注意力(GLA)算子【索引41,S. Yang, Bailin Wang, Shen, et al.,Gated Linear Attention Transformers with HardwareEfficient Training,2024,arXiv】,作为线性模型的一种泛化形式来演示我们的算法。注意力状态按以下方式循环更新:
$$S_t = (\alpha_t^\top 1) \odot S_{t-1} + K_t^\top V_t, O_t = Q_t S_t,$$其中$\alpha_t^T \in (0, 1)^{d_k}$是衰减因子。
分块形式。为了在训练期间实现高效并行,序列被划分为$N$个长度为$C$的块(chunk),并且循环以分块方式重新表述。设块$i$包含从$iC$到$(i+1)C-1$的token,其衰减向量为$\alpha_{iC+j}$。设$S[i] \in \mathbb{R}^{d \times d}$是处理完$i$个块后的块级隐藏状态,即$S[i] := S_{iC}$。GLA定义了块的累积衰减:$\gamma[i] = \prod_{j=1}^{C} \alpha_{iC+j}$,以及token级别的缩放因子:$\Gamma_{iC+j} = \frac{b_{(i+1)C}}{b_{iC+j}}$,$\Lambda_{iC+j} = \frac{b_{iC+j}}{b_{iC}}$,其中$b_t = \prod_{s=1}^{t} \alpha_s$。块级的GLA状态和输出计算如下:
$$ \boldsymbol{S}_{[i]}=\left(\boldsymbol{\gamma}_{[i]}^{\top} \mathbf{1}\right) \odot \boldsymbol{S}_{[i-1]}+\left(\boldsymbol{K}_{[i]} \odot \boldsymbol{\Gamma}_{[i]}\right)^{\top} \boldsymbol{V}_{[i]}, $$每个块的输出应计算为:
$$ \boldsymbol{O}_{[i]} = \underbrace{(\boldsymbol{Q}_{[i]} \odot \boldsymbol{\Lambda}_{[i]}) \cdot S_{[i-1]}}_{O_{[i]}^{\text{inter}}} + \underbrace{\left[\left((\boldsymbol{Q}_{[i]} \odot \boldsymbol{\Lambda}_{[i]}) \cdot (\boldsymbol{K}_{[i]} \odot \boldsymbol{\Gamma}_{[i]})^{\top}\right) \odot M\right] \cdot \boldsymbol{V}_{[i]}}_{O_{[i]}^{\text{intra}}}, $$其中,跨块(inter-chunk)部分循环地更新全局状态,而块内(intra-chunk)部分则处理对角线上的掩码注意力计算。
2.2. 线性注意力模型的序列并行
LASP1。LASP1【索引28,Weigao Sun, Qin, et al.,Linear Attention Sequence Parallelism,2025,arXiv】采用分块并行策略,将输入序列划分为多个连续的块,并均匀分配到各个设备上。每个设备基于线性注意力公式串行地计算其输出。在通信方面,每个设备从前一个块接收状态,更新后再传递给下一个设备。尽管这种方式避免了冗余的通信量,但它强制设备间严格按串行顺序执行,导致总计算时间随设备数量线性增长,严重限制了并行效率和吞吐量。
LASP2。LASP2【索引27,Weigao Sun, Lan, et al.,LASP-2: Rethinking Sequence Parallelism for Linear Attention and Its Hybrid,2025,arXiv】、【索引14,A. Li et al.,Minimax-01: Scaling foundation models with lightning attention,2025,arXiv】遵循类似的分块计算结构,但用All-Gather通信取代了串行状态传递。每个设备必须首先从所有其他设备收集本地状态,然后对相同的数据执行相同的扫描操作。这使得设备可以进行计算并行,但引入了巨大的通信开销。总通信量随设备数量线性增长,因为每个设备都必须从所有其他设备收集状态张量。
2.3. 全注意力的序列并行
现有方法。Megatron-LM中的序列并行(也称为上下文并行)【索引25,Shoeybi et al.,Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism,2020,arXiv】和Ring Attention【索引18,H. Liu, Zaharia, and Abbeel,Ring attention with blockwise transformers for nearinfinite context,2023,arXiv】、【索引2,Brandon et al.,Striped attention: Faster ring attention for causal transformers,2023,arXiv】分别基于All-Gather和P2P通信,以一次性或流水线的方式实现KV块的全局同步,并已被广泛采用【索引10,Grattafiori, Dubey, Jauhri, Pandey, Kadian, Al-Dahle, Letman, Mathur, Schelten, Vaughan, et al.,The Llama 3 Herd of Models,2024b,arXiv】、【索引27,Weigao Sun, Lan, et al.,LASP-2: Rethinking Sequence Parallelism for Linear Attention and Its Hybrid,2025,arXiv】。Ulysses【索引12,Jacobs et al.,DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models,2023,arXiv】将不同自注意力头的计算分布在设备上,这种方法易于实现,但与张量并行(TP)不兼容,且受限于头的数量。Ring Self Attention【索引15,S. Li et al.,Sequence parallelism: Long sequence training from system perspective,2021,arXiv】是最早提出全注意力序列并行的方法,但它没有利用自注意力的I/O高效优化【索引7,Dao, Fu, et al.,Flashattention: Fast and memory-efficient exact attention with io-awareness,2022,Advances in neural information processing systems】、【索引6,Dao,Flashattention-2: Faster attention with better parallelism and work partitioning,2023,arXiv】、【索引24,Shah et al.,Flashattention-3: Fast and accurate attention with asynchrony and low-precision,2024,Advances in Neural Information Processing Systems】、【索引23,Rabe and Staats,Self-attention does not need O(n2) memory,2021,arXiv】,这限制了其适用性。
根本局限。尽管如此,全注意力的序列并行从根本上受限于自注意力算法本身:即使不考虑通信,由于算法固有的复杂性,对于超长序列,计算成本也会变得极其昂贵。
A2 方法细节
3.1. ZeCO序列并行方法与通信需求
符号定义。假设每个设备被分配一个长度为$L$的序列,该序列被划分为$N = L/C$个块。输入$X \in R^{L \times d}$经过投影后,我们得到$N$个不重叠的块$Q[n]$, $K[n]$, $G[n]$,其中$n \in N$。我们用$S_i$, $i \in PL$表示全局状态,用$S[n]$, $n \in N$表示每个设备上的本地块状态。
图 2: ZeCO图示。ZeCO在三个维度上突显其优势:(1) 并行可扩展性:实现与DP相当的效率(子图:算法比较);(2) 算子级计算:实现通信与本地计算的重叠,以最大化资源利用率(子图:ZeCO计算);以及(3) 通信模式:利用定制的流水线式All-Scan通信模式,大幅减少设备间的同步延迟。(子图:All-Scan通信)
本地状态计算。根据公式(2),在每个设备内部,我们从初始状态$S[0]=0$开始,顺序地计算本地状态:
$$ \boldsymbol{S}_{[n]} = \left( \boldsymbol{\gamma}_{[n]}^{\top} \mathbf{1} \right) \odot \boldsymbol{S}_{[n-1]} + \tilde{\boldsymbol{K}}_{[n]}^{\top} \boldsymbol{V}_{[n]}, \quad \text{for} \quad n = 1, \ldots, N. $$与GLA相比,我们额外维护了累积衰减向量$\tilde{\gamma}[n]$,它保存了从第一个块到第n个块的总乘积衰减:
$$\tilde{\gamma}_{[n]}=\prod_{i=0}^{n} \gamma_{[i]}, \quad \text { for } \quad n=1, \ldots, N.$$在本地循环的末尾,我们得到一个本地块状态列表{$S[0], S[1]...S[N]$}。
全局状态更新。全局状态的递归由公式(1)定义。为了获得设备$p$的一系列全局状态,设备$p$必须从设备$p-1$获取最后一个全局状态$S_{(p-1)L}$。然后,我们可以通过以下方式更新当前设备的全局状态:
$$\boldsymbol{S}_{(p-1)L+nC} = (\tilde{\gamma}_{[n]}^{\top} \mathbf{1}) \odot \boldsymbol{S}_{(p-1)L} + \boldsymbol{S}_{[n] \cdot}$$我们在附录A.1中给出了上述全局更新计算的证明。由于将每个本地状态$S[n]$更新为其全局状态$S_{(p-1)L+nC}$在各个块之间是独立的,我们可以首先将$S[N]$更新为设备$p$的最后一个全局状态$S_{pL}$,并将其发送出去,以启动设备$p+1$中的全局状态更新。
All-Scan通信与计算重叠。为了满足这一通信需求,我们在3.2节中提出了All-Scan集体通信算子。All-Scan通信与不依赖通信的本地计算重叠;在实践中,我们同时计算对角注意力分数,如图2所示。All-Scan使得ZeCO能够并行化设备间通信和设备内计算。在实现上,ZeCO以最小的额外计算和I/O成本重新安排了GLA的标准形式,从而实现了高效的序列并行训练(见算法1)。
3.2. All-Scan集体通信
解决设备间依赖。为了将每个设备中的本地块状态转换为全局一致的值,每个设备$p$需要其前驱设备的最终状态$S_{(p-1)L}$。这在设备之间呈现出一个依赖链,会导致与设备数量相关的通信延迟。为了高效地解决这个问题,我们提出了一种All-Scan集体通信策略,该策略集接收、更新和发送于一体。具体来说,All-Scan将大的状态张量分割成可以顺序传输和处理的较小片段。
流水线状态扫描。我们不是接收完整的状态$S_{(p-1)L}$,而是将其沿$d_k$维度划分为$K$个连续的块从设备$p-1$发送:
$$ \boldsymbol{S}_{(p-1)L} = \left[ \boldsymbol{S}_{(p-1)L}^{(1)}, \boldsymbol{S}_{(p-1)L}^{(2)}, \dots, \boldsymbol{S}_{(p-1)L}^{(K)} \right], \quad \boldsymbol{S}_{(p-1)L}^{(k)} \in \mathbb{R}^{\frac{d_k}{K} \times d_v}. $$相应地,衰减因子也被分割成对齐的片段$\tilde{\gamma}^{(j)}[N] \in \mathbb{R}^{1 \times \frac{d_k}{K}}$。每个状态块都以流水线方式从$p-1$传输到$p$,并立即应用更新并发送:
$$S_{pL}^{(k)} = (\tilde{\gamma}_{[N]}^{(k)\top} \mathbf{1}) \odot S_{(p-1)L}^{(k)} + S_{[N]}^{(k)} \quad \text{for} \quad k=0, ..., K$$这种设计使得设备$p+1$在接收到$S_{pL}$的第一个块时,就能立即开始在All-Scan中更新其最后一个全局状态$S_{(p+1)L}$,如算法2所示。作为一个通信原语,All-Scan可以与其他CUDA流独立运行,实现了细粒度的通信-计算重叠,从而在长上下文训练中最大化设备利用率和吞吐量。
3.3. 最优性分析
最优性条件。我们现在正式确立,使用All-Scan集体通信算法的序列并行策略可以达到线性注意力SP的理论最优性。我们确定了最优性的两个必要和充分条件,并证明ZeCO同时满足这两个条件:
1. 零通信开销:每个设备只传输和接收最小必要尺寸的信息(数据),没有冗余通信。
算法1:使用All-Scan通信的ZeCO前向传播
1: 注意: 高亮部分代表ZeCO算法对GLA算法的SP适配,红色部分的低成本代表SP的较低额外成本。
2: 输入: Q, K, ∈ R^(L×d_k), V ∈ R^(L×d_v), G = [α₁...α_L] ∈ R^(L×d_k), 块大小C, 设备数P, 设备排名p ∈ {0, 1, . . . , P − 1}
3: 将Q, K, G划分为N = L/C个块 {Q[1]...Q[N]}, {K[1]...K[N]}, {G[1]...G[N]},每个大小为C × d_k。将V划分为N个块 {V[1]...V[N]},每个大小为C × d_v。
4: 在SRAM上初始化 S = 0 ∈ R^(d_k×d_v), γ̃ = 1 ∈ R^(d_k)
5: 将γ̃, S写入HBM,记为γ̃[0], S[0]。
6: for n ← 0 to N do
7: 从HBM加载K[n], G[n], V[n]到SRAM。
8: 在芯片上,计算γ[n] ∈ R^(d_k), Γ[n] ∈ R^(C×d_k) 和 K̄[n] = K[n] ⊙ Γ[n], γ̃ = γ̃ ⊙ γ[n]
9: 将γ̃写入HBM,记为γ̃[n]。
10: 在芯片上,计算 S = (γᵀ[n]1) ⊙ S + K̃ᵀ[n]V[n]。
11: 将S写入HBM,记为S[n]。
12: end for
13: 并行执行:
14: 并行流 1:
15: S_((p−1)L), S_(pL) ← All-Scan(S[N], γ̃[N])
16: 并行流 2:
17: parfor n ← 1 to N do
18: 从HBM加载Q[n], K[n], G[n] ∈ R^(C×d_k)到SRAM。
19: 在芯片上,计算Λ[n] ∈ R^(C×d_k), Q̃[n] = Q[n] ⊙ Λ[n], K̄[n] = K[n]/Λ[n]
20: 在芯片上,构建因果掩码 M ∈ R^(C×C)
21: 在芯片上,计算 P = (Q̃[n]K̄ᵀ[n]) ⊙ M ∈ R^(C×C)
22: 将P作为P[i]写入HBM。
23: end parfor
24: 流屏障
25: for n ← 1 to N do
26: 从HBM加载Q[n], G[n], V[n], S_((p−1)L), S[n], γ̃[n−1], P到SRAM。
27: 在芯片上,计算Λ[n]
28: 在芯片上,计算Q̃[n] = Q[n] ⊙ Λ[n]
29: 在芯片上,计算 O_inter[n] = Q̃[n](S[n−1]+(γ̃ᵀ[n-1]1) ⊙ S_((p−1)L)), O_intra = PV[n] ∈ R^(C×d_v)
30: 在芯片上,计算O[n] = O_inter + O_intra
31: 将O[n]存储到HBM。
32: end for
33: 返回 O = {O[1]...O[N]}, S = {S_((p−1)L), S[1]...S[N], S_(pL)}。
算法2:All-Scan算法
零通信开销分析。令$S \in \mathbb{R}^{d_k \times d_v}$表示累积的状态张量。根据线性注意力的输出公式(3),这个状态代表了必须在块之间通信的最小信息。对于一个分布在$P$个设备上的序列,任何SP算法都必须至少在设备边界上传输状态信息。令$V^{(p)}_{ZeCO}$表示ZeCO(All-Scan)中第$p$个设备的通信量。每个设备精确地向下一个设备发送一次最后的全局状态,因此:
$$V_{\text{ZeCO}}^{(p)} = |S| = d_{k} \times d_{v}.$$这代表了理论下限。然而,像LASP-2这样的现有方法依赖于all-gather操作,从所有其他$P-1$个设备接收本地状态,导致通信量为$(P-1) \times d_k \times d_v$,该值随设备数量的增加而增加。如图3所示,ZeCO在SP场景中实现了可能的最小通信量。
最优SP策略分析。令$T_P^{SP}(L)$表示使用$P$个设备通过序列并行处理长度为$L$的序列的总运行时间,令$T_1^{SP}(PL)$表示在单个设备上处理总长度为$PL$的序列的运行时间。令$T_{P, \text{ideal-sp}}(PL)$表示理想条件下的运行时间,假设是零额外开销的完美并行。对于理想的SP,应满足以下属性:
$$T_{\text {ideal-sp }}^{P}(P L)=T_{\text {ideal-SP }}^{1}(L)=\frac{T_{\text {ideal-SP }}^{1}(P L)}{P}$$公式(10)的含义是,在理想情况下,序列并行的吞吐量应随设备数量线性扩展(即处理时间与吞吐量成反比)。然而,在实践中,序列并行会引入额外开销。因此,我们接下来分析实际场景下的延迟。在实践中,设备间传输和同步数据所需的前提通信延迟、计算和I/O会引入额外延迟。对于ZeCO(以及其他序列并行),关系变为:理想序列并行成本 + 额外成本,可以形式化为:
$$T_{\mathrm{SP}}^{P}(P L)=T_{\text {ideal-SP }}^{P}(P L)+T_{\text {extra\_comp\&I/O }}+\left(T_{\text {All\_Scan }}-T_{\text {overlaped\_comp }}\right).$$对于ZeCO,前两个部分与通信无关。后两个部分代表通信延迟,计入了All-Scan算子中无法被本地对角线注意力重叠的部分(我们假设最坏情况)。这也可以看作是与理想SP的差距。从公式(11)可以看出,$T_{P, \text{ideal-sp}}(PL)$和$T_{\text{overlaped_comp}}$是算法固有的。因此,SP的关键问题在于,线性注意力的额外时间成本,即$T_{\text{extra_comp\&I/O}}$和$T_{\text{All_Scan}}$,可以减少到何种程度。接下来将通过分析All-Scan通信的最优性,以及证明$T_{\text{extra_comp\&I/O}}$是一个可忽略的系统成本,来展示ZeCO如何成功地达到最优。
All-Scan时间成本分析。在All-Scan中,我们将$S$划分为$K$个块,每个块的大小为$\mathbb{R}^{\frac{d_k}{K} \times d_v}$,并以流水线方式传输这些块,如图2所示。通过将状态$S$划分为$K$个块并以流水线方式更新它们,有效的通信延迟可以计算为:
$$T_{\text {All\_Scan }}=\tau\left(d_{k} \times d_{v}\right)+\frac{(P-1) \tau\left(d_{k} \times d_{v}\right)}{K},$$其中$\tau(\cdot)$表示通信给定大小张量所需的时间。公式12显示了成本$T_{All\_Scan}$的两个组成部分。第一项代表可以通过流水线方法并行化的开销,这是必要的,并对应于最小通信要求。第二项是边界处的开销。随着$K$的增加,边界开销减少,重叠程度提高。因此,当$K$变得足够大时,边界开销趋近于零。因此,带有All-Scan的ZeCO实现了通信的最小时间成本。
额外计算和I/O成本分析。$T_{\text{extra_comp\&I/O}}$项包括两部分:少量的额外浮点运算,以及几个辅助张量的HBM加载和存储操作。在算法1中,第9行和第26行对$\tilde{\gamma}[n]$的加载和存储操作是向量级别的,仅占状态张量的$\frac{1}{d_v}$。所需的额外状态可以被重用$N$次,仅产生$\frac{1}{N}$的开销。对于序列长度为8192和块大小为64($N=128$),这与典型的$d_v$相当,增加的开销小于1%。元素级乘法的成本可以忽略不计。因此,$T_{\text{extra_comp\&I/O}}$在实践中可以安全地忽略。这证明了ZeCO的时间公式(11)应为:
$$\begin{aligned} \begin{aligned} T_{\mathrm{ZeCO}}^{P}(P L) & =T_{\text {ideal-SP }}^{1}(L)-T_{\text {overlaped\_comp }}+\tau\left(d_{k} \times d_{v}\right)+\epsilon \\ & \approx T_{\text {ideal-SP }}^{1}(L)-T_{\text {overlaped\_comp }}+\tau\left(d_{k} \times d_{v}\right), \end{aligned} \end{aligned}$$
图 3: ZeCO在满足最低通信量的前提下,通信时间最短。左侧两图显示了算法计算速度和通信量的理论值,右图显示了实际通信时间。
其中$\epsilon$代表可忽略的计算和I/O成本。
与其他方法的对比。相比之下,像LASP这样的现有方法在设备间具有严格的串行依赖性,导致(我们假设其他方法中的$T_{\text{extra_comp\&I/O}}$项也可以优化到可忽略的水平。即便如此,这些方法仍然是次优的):
$$T_{\mathrm{LASP}}^{P}(P L)=P \times\left(T_{\text {ideal-SP }}^{1}(L)+\tau\left(d_{k} \times d_{v}\right)\right)>T_{\mathrm{ZeCO}}^{P}(P L).$$而LASP-2虽然通过并行计算改进了LASP,但遭受了更高的通信成本:
$$T_{\text{LASP-2}}^{P}(PL) = T_{\text{ideal-SP}}^{1}(L) + P \times \tau(d_{k} \times d_{v}) > T_{\text{ZeCO}}^{P}(PL).$$因此,ZeCO以最小的额外成本实现了零通信开销和最优的SP策略。这种最优性直接转化为卓越的性能。在图3中,我们展示了不同SP算法的通信成本和计算开销的理论值,以及通信时间的实际值。我们还在附录A.2中对现有SP算法进行了统一的通信和运行时分析。
A4 实验环境
- 硬件配置: GPU集群,配备256块H100 80GB GPU。
- 软件配置: 模型在Lingua【索引35,Videau et al.,Meta Lingua: A minimal PyTorch LLM training library,2024,https://github.com/facebookresearch/lingua】中进行训练,这是一个基于PyTorch的分布式训练框架 。
- 模型架构: 1B参数的GLA模型。
- 实现细节: 为了与基线序列并行(SP)方法如LASP1【索引28,Weigao Sun, Qin, et al.,Linear Attention Sequence Parallelism,2025,arXiv】和LASP2【索引27,Weigao Sun, Lan, et al.,LASP-2: Rethinking Sequence Parallelism for Linear Attention and Its Hybrid,2025,arXiv】进行公平比较,我们从Flash Linear Attention【索引43,S. Yang and Y. Zhang,FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism,2024,https://github.com/fla-org/flash-linear-attention】仓库中适配了分块门控线性注意力算子用于我们的实现。完整的实验设置和数据在附录A.3中提供 。
A4 实验结果
我们评估了所提出的ZeCO SP算法和All-Scan通信算子在1B-GLA模型上的效率和可扩展性。评估重点关注两个方面:(1)不同集体通信算子的通信速度;(2)在增加GPU数量时,算法级和模型级的可扩展性。
4.1. 通信速度
实验内容: 我们评估了不同通信算子在各自正确训练所需通信负载下的通信运行时。实验在8到256个GPU上进行,每个GPU分配8K序列长度。每个通信核预热5轮,然后报告50次运行的平均值。
实验结果: 如图3所示,All-gather (Megatron)在128和256个GPU的实验中发生了内存溢出。对于其他方法,为了便于展示,Y轴的上半部分是对数尺度的渲染结果。All-Scan在不同规模的集群中都显著优于其他方法。值得注意的是,在256个GPU上,All-Gather (LASP2)比All-Scan慢4倍。
4.2. SP算法运行时和模型吞吐量
实验内容: 我们评估了线性注意力SP方法(包括LASP1、LASP2和ZeCO)在微观层面(算法)和宏观层面(模型训练)的性能。
- SP算法运行时: 我们测量了每个SP算子在相同设置(L=16K或32K,H=16)下的前向和后向传播时间,并与理想情况下的DP(数据并行)算子进行比较。DP算子的时间作为理论下限。
- 模型吞吐量: 我们使用不同的序列并行方法测试了1B-GLA模型的训练吞吐量,设置相同(L=16K或32K,H=16),并与使用DP的理想模型进行比较。使用DP训练的GLA模型的吞吐量作为理论上限。
实验结果:
- SP算法运行时: 图4显示,在128个GPU的实验中(2M和4M序列长度),ZeCO的单次前向和后向传播仅比理论下限慢3毫秒,这符合我们在3.3节中的分析,并证明了我们算法的最优性。
- 模型吞吐量: 如图4所示,随着GPU数量的增加,ZeCO实现了总吞吐量的线性增长,这满足了序列并行的初衷,而其他方法则出现了严重的性能下降。
图 4: LASP在SP算子运行时(上半部分)和吞吐量(下半部分)的可扩展性评估。在每个GPU 16k和32k序列长度的对比测试中,ZeCO算法表现出与DP算法同样稳定的时间。在16k和32k两种情况下,ZeCO都展现出接近DP的线性扩展吞吐量增长曲线,而其他方法则出现性能退化。
A5 结论
工作总结: 在这项工作中,我们提出了用于线性注意力的ZeCO序列并行方法,在理论和实证结果上均达到了SOTA水平。更重要的是,我们的方法充分释放了线性模型的算法效率,并首次实现了序列并行的近线性吞吐量扩展。在系统层面,我们的方法引入了新颖的All-Scan集体通信原语,它不仅是ZeCO效率的基石,也为推动线性模型社区的分布式计算提供了基础性创新。
未来工作: 未来,我们计划从三个主要方向展开工作。
1. 算法实现改进: 我们将进一步改进All-Scan集体通信原语的算法实现,例如采用树状实现。
2. 算法泛化: 我们旨在将线性注意力的序列并行算法推广到对角衰减之外,使其支持包括矩阵变换结构在内的多种形式。
3. 并行拓扑研究: 我们将研究大规模模型中序列并行的高效并行拓扑。
A6 附录
A.1. 全局块更新证明
证明目标。我们证明主文公式(6)中使用的全局校正公式的正确性,该公式将位置$(p-1)L+nC$处的全局状态表示为:
$$ \boldsymbol{S}_{(p-1)L+n\text{C}} = (\tilde{\gamma}_{[n]} \mathbf{1}) \odot \boldsymbol{S}_{(p-1)L} + \boldsymbol{S}_{[n]} $$证明过程。我们从每个设备内门控线性注意力(GLA)状态更新的分块递归开始。对于任何块$n$,递归关系为:
$$ \boldsymbol{S}_{[n]}=\left(\boldsymbol{\gamma}_{[n]}^{\top} \mathbf{1}\right) \odot \boldsymbol{S}_{[n-1]}+\tilde{\boldsymbol{K}}_{[n]}^{\top} \boldsymbol{V}_{[n]}, $$初始状态为$S[0]=0$。展开递归,我们得到最终本地状态$S[n]$的闭式表达式:
$$ \boldsymbol{S}_{[n]} = \sum_{i=1}^{n} \left( \prod_{j=i+1}^{n} \boldsymbol{\gamma}_{[j]}^{\top} \mathbf{1} \right) \odot ( \tilde{\boldsymbol{K}}_{[i]}^{\top} \boldsymbol{V}_{[i]} ) + \left( \prod_{j=1}^{n} \boldsymbol{\gamma}_{[j]}^{\top} \mathbf{1} \right) \odot \boldsymbol{S}_{[0]} . $$公式19表示本地计算的结果,它捕获了从零初始状态开始并仅考虑当前块$n$内本地贡献所获得的最终本地状态。关键的观察是,由于初始条件$S[0]=0$,第二项$(\prod_{j=1}^{n} \gamma^T[j]) \odot S[0]$消失了,使得$S[n]$完全由本地信息决定。
线性分解特性。我们现在展示全局状态更新的线性分解特性。关键的见解是,当存在非零初始状态时,最终的全局状态可以分解为两个独立的线性贡献:全局计算的衰减传播和当前块的本地贡献。现在假设我们从一个非零初始状态$S_{(p-1)L}$(即前一个设备的最终全局状态)开始执行相同的递归。在全局索引$(p-1)L+nC$处的更新状态变为:
$$\begin{aligned} \begin{aligned} \boldsymbol{S}_{(p-1) L+n C} & =\sum_{i=1}^{n}\left(\prod_{j=i+1}^{n} \boldsymbol{\gamma}_{[j]}^{\top} \mathbf{1}\right) \odot\left(\tilde{\boldsymbol{K}}_{[i]}^{\top} \boldsymbol{V}_{[i]}\right)+\left(\prod_{j=1}^{n} \boldsymbol{\gamma}_{[j]}^{\top} \mathbf{1}\right) \odot\left(\mathbf{0}+\boldsymbol{S}_{(p-1) L}\right) \\ & =\left(\tilde{\boldsymbol{\gamma}}_{[n]} \mathbf{1}\right) \odot \boldsymbol{S}_{(p-1) L}+\boldsymbol{S}_{[n]} . \end{aligned} \end{aligned}$$这种线性特性允许本地计算$S[n]$只存储来自块$n$的残余贡献,而将传入的全局状态$S_{(p-1)L}$乘以累积衰减$\tilde{\gamma}[n]$,就能精确地重建完整的全局状态。
A.2. 序列并行方法的统一分析
分析框架。在本节中,我们对全注意力和线性注意力模型中的几种代表性序列并行(SP)方法进行了统一分析,使用多头注意力,其中头数$H=32$,$d_k = d_v = D = e$。我们主要从以下三个角度进行比较:
- 通信量:SP执行期间每个设备传输的数据总量。
- 计算成本:并行处理长度为$PL$的序列所需的总计算时间。
- 额外计算开销:由于SP特定逻辑引入的额外操作。
全注意力模型。
- Ulysses: 使用All-to-All通信来交换Q、K、V和输出张量。每个设备的通信量为$4LD$。由于全注意力的二次复杂度,计算成本为$L^2DP$。
- Megatron CP: 利用All-Gather从所有设备收集Q和K。通信量为$2PLD$。计算成本与Ulysses相同,为$L^2DP$。
线性注意力模型。在线性注意力中,由于每个设备处理长度为$L$的序列,每个设备的固有计算量为$LDe$(这里我们忽略了分块算法引入的额外低阶项),这与$P$无关。
- LASP-1: 采用串行P2P通信。每个设备传输一个状态张量$S \in \mathbb{R}^{H \times e \times e}$,通信量为$De$。然而,设备是顺序执行的,导致等效的时间开销(包括通信和计算)如同每个设备执行了$P$倍的工作量。
- LASP-2: 使用All-Gather收集所有设备上的中间状态张量。每个设备处理所有$P$个全局状态,导致通信量为$PDe$,以及额外的计算成本$\log(P)De + NDe$用于求和归约和状态更新。
- ZeCO (我们的方法): 通过All-Scan实现流水线通信。每个设备仅发送/接收一个状态$S$,通信量为$De$。它额外维护$N$个累积衰减向量$\tilde{\gamma}$,并使用全局递归更新$N$个中间状态。因此,存在额外的计算成本$NDe + Nd$。
结论。总而言之,对于全注意力的序列并行,计算成本和设备参数数量都强烈依赖于设备数量$P$,这成为一个主要的效率瓶颈。在线性注意力的情况下,尽管LASP-1和LASP-2的通信和计算成本都随$P$扩展,但计算成本仅占总开销的一小部分。因此,通信成本对$P$的依赖性成为主要瓶颈。相比之下,我们的ZeCO算法实现了通信和计算成本均与设备数量$P$无关。
表 1: 序列并行方法比较:通信量和计算成本(对于LASP-1,我们考虑了顺序执行顺序)
A.3. 实验设置和补充数据
4.1节实验设置。在4.1节的实验中,$H$为32,每个分块的张量大小为16384,隐藏维度$d$为4096,每个设备的序列长度$L$为8192。实验设置包括5轮预热和50轮实验取平均值,见表2。
4.2节实验设置。在4.2节的算法运行时实验中,我们测试了配备不同SP方法的GLA-attention算法,记录了1次迭代FWD和BWD的时间。$H$为16,每个分块的张量大小为16384,隐藏维度$d$为2048,每个设备的序列长度$L$分别为16384和32768。实验设置包括5轮预热,并报告了50轮实验的平均值,见表3和表4。在模型吞吐量实验中,我们测试了配备不同SP方法的GLA-1B模型,并记录了训练阶段的吞吐量。$H$为32,每个分块的张量大小为16384,模型层数为20,隐藏维度$d$为2048,每个设备的序列长度$L$分别为16384和32768。实验设置包括5轮预热,并报告了100步实验的平均值,见表5和表6。
表 2: 通信运行时
表 3: 算法运行时 (16k序列长度)
表 4: 算法运行时 (32k序列长度)
表 5: GLA模型SP吞吐量/GPU (tokens/sec) on 1B-16k
表 6: GLA模型SP吞吐量/GPU (tokens/sec) on 1B-32k
A.4. 使用All-Scan通信的ZeCO反向传播
反向传播过程。在ZeCO算法的反向传播中,大部分过程与前向传播相似。需要注意的是,这里的符号有所不同:$\tilde{\gamma}[n]$表示反向累积乘积的衰减因子。此外,在门控线性注意力的官方实现中,$S[n]$需要在反向传播过程中重新计算。然而,由于全局初始状态已在前向传播过程中获得,因此在重新计算$S[n]$时不需要进行all-scan通信。
算法3:使用All-Scan通信的ZeCO反向传播
输入: Q, K, G ∈ R^(L×d_k), V, dO ∈ R^(L×d_v), 块大小 C, 设备数 P, 设备排名 p ∈ {0, 1, . . . , P − 1}
在SRAM上初始化 dS = 0 ∈ R^(d_k×d_v)
1: for n ← N to 0 do
2: 从HBM加载 G[n] ∈ R^(C×d_k), Q[n] ∈ R^(C×d_k), dO[n] ∈ R^(C×d_v) 到SRAM
3: 在芯片上,计算 γ[n], Γ[n] 和 Q̃[n] = Q[n] ⊙ G[n], γ̃ = γ̃ ⊙ γ[n]
4: 将 γ̃ 存储在HBM中,记为 γ̃[n]
5: 在芯片上,计算 dS = (γᵀ[n]1) ⊙ dS + Q̃ᵀ[n]dO[n]
6: 将 dS 存储在HBM中,记为 dS[n]
7: end for
8: 并行执行:
9: 并行流 1:
10: dS_((p−1)L), dS_(pL) ← All-Scan(dS[0], γ̃[0])
11: 并行流 2:
12: 从HBM加载 S_((p−1)L) 到SRAM
13: 在芯片上,用 S[0] = S_((p−1)L) 重新计算 S[n], n = {0, 1, 2, . . . , N − 1}
14: 存储 {S[n], n ∈ {0, 1, 2, . . . , N − 1}}
15: for n ← 1 to N 并行 do
16: 从HBM加载 Q[n], K[n], G[n], V[n], dO[n] 到SRAM
17: 从HBM加载 , ∈ R^(d_k×d_v) 到SRAM
18: 在芯片上,构建因果掩码 M ∈ R^(B×B)
19: 在芯片上,计算 Λ[n], Γ[n] ∈ R^(C×d_k)
20: 在芯片上,计算 Q̃[n] = Q[n] ⊙ Λ[n], K̃[n] = K[n] ⊙ Γ[n]
21: 在芯片上,计算 P[n] = (Q̃[n]K̃ᵀ[n]) ⊙ M ∈ R^(C×C)
22: 在芯片上,计算 dP[n] = (dO[n]Vᵀ[n]) ⊙ M
23: 在芯片上,计算 dK̄[n] = Q̃ᵀ[n]dP
24: 在芯片上,计算 dK[n] = dK̄[n]/Λ[n]
25: 在芯片上,计算 dQ̃[n] = dPK̄[n]
26: 在芯片上,计算 dQ[n] = dQ̃[n] ⊙ Λ[n]
27: 将 P[n], dQ[n], dK[n] 存储在HBM中
28: end for
29: 流屏障
30: for n ← 1 to N 并行 do
31: 从HBM加载 P[n], dQ[n], dK[n], dO[n], Q[n], K[n], G[n], γ̃[n-1], dS_(pL), S[n−1] 到SRAM
32: 在芯片上,计算 Λ[n], Γ[n] ∈ R^(C×d_k)
33: 在芯片上,计算 K̃[n] = K[n] ⊙ Γ[n]
34: 在芯片上,计算 dK̃[n] = V[n](dSᵀ[n−1]+(γ̃ᵀ[n-1]1) ⊙ dSᵀ_(pL))
35: 在芯片上,计算 dK[n] = dK[n] + dK̃[n] ⊙ Γ[n]
36: 在芯片上,计算 dQ̃[n] = dO[n]Sᵀ[n−1]
37: 在芯片上,计算 dQ[n] = dQ[n] + dQ̃[n] ⊙ Λ[n]
38: 在芯片上,计算 dV[n] = Pᵀ[n]dO[n] + K̃[n](dSᵀ[n−1]+(γ̃ᵀ[n-1]1) ⊙ dSᵀ_(pL))
39: 将 dK[n], dV[n] 存储在HBM中
40: end for
41: 令 dQ = {dQ[1], . . . , dQ[N]}, dK = {dK[1], . . . , dK[N]}, dV = {dV[1], . . . , dV[N]}
42: 计算 dA = Q ⊙ dQ − K ⊙ dK, dG = revcum(dA)
43: 返回 dQ, dK, dV, dG
💬 评论讨论
欢迎在这里分享您的想法和见解!