EFFICIENT LONG-CONTEXT LANGUAGE MODEL TRAINING BY Core Attention Disaggregation
EFFICIENT LONG-CONTEXT LANGUAGE MODEL TRAINING BY Core Attention Disaggregation
作者/机构: Yonghao Zhuang, Junda Chen, Bo Pang, Yi Gu, Yibo Zhu, Yimin Jiang, Ion Stoica, Eric Xing, Hao Zhang
A1 主要贡献
核心问题:在长上下文大语言模型(LLM)的训练中,由于文档长度可变,通过文档打包(document packing)方式处理数据会导致严重的负载不均衡。其根本原因在于,Transformer模型中的自注意力(self-attention)计算量随序列长度呈二次方增长,而模型其余部分的计算量仅呈近似线性增长。这种计算复杂度的不匹配导致在数据并行(DP)和流水线并行(PP)中出现“掉队者”(stragglers),即处理较长文档(或注意力计算量大的数据块)的设备会拖慢整个训练过程,降低系统吞吐量。
研究目标:本文旨在通过一种新的系统架构来解决长上下文LLM训练中的负载不均衡问题,从而提升端到端的训练吞吐量。
创新点与核心思想:
本文提出核心注意力解耦(Core Attention Disaggregation, CAD),其核心思想是将无参数的核心注意力(Core Attention, CA)计算(即softmax(QK⊤)V部分,如图1所示)从模型的其他部分(如线性层、前馈网络等)中分离出来,并将其调度到一个独立的资源池(称为注意力服务器 Attention Servers)上执行。
图 1. Transformer及其由核心注意力引起的工作负载不平衡。
这一方法的可行性基于两个关键观察:
1. 无状态性(Statelessness): 核心注意力(CA)没有可训练的参数,且只存储极少的瞬时状态(如每行的softmax统计数据)。因此,对其进行负载均衡简化为了一个计算密集型任务的调度问题,而无需考虑内存状态的平衡。
2. 可组合性(Composability): 核心注意力的计算可以在token级别上进行任意粒度的切分。不同文档(来自不同DP副本或PP阶段)的token级计算任务可以被重新组合(re-batch)成一个大的、高设备利用率的融合内核调用,而不会损失现代注意力内核(如FlashAttention)的效率。
基于CAD思想,本文实现了一个名为DistCA的系统,并包含三项关键优化:
- 原地注意力服务器(In-place attention server): 通过让GPU在计算上下文无关层和作为注意力服务器之间分时复用,实现了高计算和高内存利用率。
- 乒乓执行机制(Ping-pong execution): 将每个微批次(microbatch)分为两个更小的“Ping”和“Pong”纳米批次(nano-batches),通过交错执行来将数据通信与计算完全重叠,隐藏通信开销。
- 负载均衡调度器(Workload balanced scheduler): 开发了一个感知通信的贪心调度算法,该算法动态地将文档切分为token级任务,并在注意力服务器之间进行调度,以在最小化通信开销的同时实现近乎完美的计算负载均衡。
通过在多达512个H200 GPU上对高达512K上下文长度的负载进行评估,DistCA相比现有系统,端到端训练吞吐量提升高达1.35倍,并消除了DP/PP中的掉队者现象。
A3 背景知识、关键观察与设计原则
2.1 LLM架构
- 上下文无关层:这类层包括QKV投射、输出投射、前馈网络(FFN)、层归一化(layernorm)和位置嵌入。它们的操作是逐个token进行的,即每个token的输出仅依赖于其自身的隐藏状态。因此,其计算量和激活内存都与token数量近似线性相关。
- 上下文相关层:这类层仅包括核心注意力(Core Attention, CA)。本文特意将核心注意力与通常文献中提到的注意力(后者是QKV投射、输出投射、层归一化和核心注意力的组合)区分开来。核心注意力计算
O = softmax(QK⊤)V,它没有可训练的参数,也不需要其他token的CA中间输出。现代的IO感知注意力内核【索引2,Flashattention: Fast and memory-efficient exact attention with io-awareness,2022,NeurIPS】通过在反向传播中重新计算来避免在前向传播中物化巨大的注意力分数矩阵P,这使得核心注意力的中间状态量可以忽略不计,因而是无状态的。
2.2 LLM训练并行化
- 数据并行(Data Parallelism, DP):各个副本处理不同的数据批次,并在梯度同步屏障处汇合。任何接收到注意力工作负载更大批次的副本都会成为掉队者,从而拖慢其他副本。
- 张量并行(Tensor Parallelism, TP):将模型层进行切分,但代价是每层都需要通信。当扩展到单个节点之外时(TP size > 8),这种通信开销是无法承受的。TP可以在TP分片间平衡内存和计算,因为所有设备处理完全相同的数据批次。
- 上下文并行(Context Parallelism, CP):沿序列维度在多个GPU上对数据进行切分。上下文无关层可以在每个分片上独立运行,但核心注意力需要一次all-gather操作来收集上下文的token状态。为解决因果掩码(causal mask)导致的负载不均衡,最近的工作【索引3,The llama 3 herd of models,2024】采用了头尾分片分配策略。当使用文档打包时,按文档进行CP(Per-document CP)能更好地平衡负载。
- 流水线并行(Pipeline Parallelism, PP):将模型的层划分为多个阶段,并将输入数据切分为微批次,微批次依次流过每个阶段。由于不同微批次可能包含不同长度文档打包而成的块,导致各阶段处理时间不均衡,产生流水线气泡,从而使下游所有阶段空闲。
3.1 计算与内存的负载不均衡
- 不均衡问题的形式化:一个长度为
l的文档的计算量可表示为 $FLOPs(l) = \alpha l^2 + \beta l$,其中$\alpha l^2$来自核心注意力,$\beta l$来自上下文无关层。激活内存为 $M(l) = \gamma l$。要使两个微批次(分别包含长度为 $\{l_i\}_{i=1}^n$ 和 $\{l'_j\}_{j=1}^m$ 的文档)在计算和内存上都达到平衡,必须同时满足 $\sum_{i=1}^n l_i = \sum_{j=1}^m l'_j$ (内存平衡)和 $\sum_{i=1}^n l_i^2 = \sum_{j=1}^m l'_j^2$ (计算平衡)两个条件,这在实践中极难实现。
3.2 现有方法的问题
-
可变长度数据块(Variable-length data chunk):该方法通过重新分配文档来平衡计算量(即 $\sum l_i^2$),但这会导致各微批次的总token数($\sum l_i$)不同,从而在某些设备上造成激活内存膨胀。随着序列长度增长,该方法会达到内存上限,无法再通过移动序列来完全平衡注意力计算。如图4所示,在512K长度的负载下,DP=8时GPU的平均空闲时间比例高达55%。
图 4. 在8B模型上处理512K token数据块时,不同并行策略下可变长度数据块的吞吐量和内存差异。 -
按文档上下文并行(Per-document CP):该方法为每个文档分配相等份额给每个CP rank,从而同时平衡了计算和内存。但它在规模化时面临三大瓶颈:
- 小分片效率低:即使是短文档也会被切分,产生的小分片可能无法充分利用注意力内核的计算单元,降低了算术强度。图5显示,当文档长度小于128个token时,内核吞吞吐量显著下降。
- 通信开销大:该方法虽然将每个rank的计算量减少为1/c,但需要all-gather KV状态,其成本与全局token数成正比。如图3a所示,随着CP规模扩大,all-gather延迟占比从2个节点上的3%急剧上升到32个节点上的近40%。
- 内存压力:最后一个CP rank必须存储整个文档聚合后的KV状态以用于反向传播,导致内存压力。如图3b所示,KV内存占比随CP规模扩大而增加,从2个节点上的3%增长到16个节点上的近30%。
图 3. Llama-8B模型下,上下文并行中all-gather的延迟和内存分解。文档长度均为32k。
图 5. 核心注意力的吞吐量。 -
两种方法的组合:结合使用这些技术会继承各自的缺点。图6显示,在一个64-GPU、512K-token的实验中,增加CP度可以减少不平衡,但会降低吞吐量并有OOM风险;而增加DP度会导致严重的负载不平衡和次优的吞吐量。
图 6. 应用可变长度数据块和按文档CP时的吞吐量。
3.3 核心注意力解耦的动机
- 解耦边界:一个自然的想法是,将核心注意力(CA)的二次项$\alpha \sum l_i^2$与上下文无关层的线性项$\beta \sum l_i$和$\gamma \sum l_i$分开。解耦后,可以独立地调度和均衡CA的计算(因其无状态),并独立地平衡上下文无关层的计算和内存(因其均为线性)。
- 可行性分析:
- 如何平衡CA计算? 关键在于可分性(Divisibility)和内核可组合性(kernel composability)。CA计算在token粒度上是可分的。现代注意力内核(如FlashAttention 2)可以在由不同长度分片组成的融合序列上保持高MFU,只要分片大小大于内核的tile大小(例如128个token)。这意味着文档可以被任意切分,然后重新组合成一个高利用率的CA内核调用,从而在不牺牲内核效率的情况下平衡计算。图5的性能测试验证了这一点。
- 如何隐藏通信开销? 解耦CA也需要通信KV状态,但其开销可以被有效隐藏,原因有四:
- 由于因果掩码,早期分片不需要后期分片的K、V。可以使用all-to-all操作只发送必要的分片。
- 通信密集型分片(即文档的后期分片)可以被调度到不同设备上,以避免all-to-all中的掉队者。
- 可以优化分片方案以最小化通信,而不受限于统一的序列切分或固定的CP度。
- 可以通过乒乓执行机制将一个批次的CA输入通信与另一个批次的计算重叠。
A2 方法细节
DistCA系统由一个运行时系统和一个工作负载调度器组成。运行时负责交替执行上下文无关层和核心注意力层,并插入必要的通信。调度器则决定如何对每个文档进行分片以及如何将分片放置到不同设备上。
CA:核心注意力层;Linear:FFN,qkvo-proj;MISC:layer norm, dropout, ...
图 2. DistCA架构。
4.1 运行时
-
核心注意力任务(CA-task)的定义:注意力服务器的工作负载被定义为核心注意力任务(CA-task),记为
t。一个CA-task是针对一个查询分片q(t)及其上下文的键值分片kv(t)的核心注意力计算。一个文档被分割成多个不重叠的分片q1, q2, ..., qn,其完整的核心注意力结果是对应任务t1, t2, ..., tn结果的集合。 -
系统工作流程:如图2所示,在一组处理多个文档的GPU中,每个GPU都可以作为一个独立的注意力服务器。对于一批文档,在经过上下文无关层处理后,它们被分割成CA-tasks。每个CA-task被分配给一个注意力服务器。服务器接收到其被分配的所有任务的输入张量后,利用内核的可组合性,将所有CA-task批处理并在一个单独的内核中执行(例如通过一次FlashAttention调用)。计算完成后,每个CA-task的输出被发送回处理后续上下文无关层的源GPU。
-
中央调度器:一个运行在CPU上的中央调度器负责确定分片策略。它在GPU处理当前批次时,预取下一批次的文档,并使用预先计算的性能分析数据来估计每个文档和潜在CA-task的计算成本,从而生成一个分片和分配计划。
-
原地注意力服务器(In-place attention server):为了避免为CA计算分配专用GPU而导致的内存严重未充分利用问题(因为CA是无状态的,而FFN等层内存消耗巨大),DistCA采用原地注意力服务器设计。每个GPU周期性地在计算上下文无关层和充当注意力服务器之间切换角色。这样既能实现高内存利用率,又能平衡GPU间的计算负载。
-
乒乓执行(Ping-Pong execution):为隐藏通信开销,系统采用乒乓执行调度。每个输入微批次被分为两个等token数的纳米批次——“Ping”和“Pong”。这两个纳米批次的执行是交错的,使得一个的通信可以与另一个的计算重叠。此外,系统还将张量并行所需的节点内通信(通常通过NVLink)与核心注意力解耦引起的节点间通信(通常通过InfiniBand)进行重叠。
-
流水线并行支持:CAD可以自然地与DP和TP集成,并替代CP。对于PP,CA-task因其无权重的特性,来自不同PP阶段的任务可以与来自不同微批次的任务一同被调度和平衡。对于上下文无关层,由于所有PP阶段的微批次包含相同数量的token,它们的计算负载是相同的,因此是平衡的。为了防止设备在切换角色时空闲,系统调整了调度,使所有阶段在同一个tick内执行相同的阶段(要么全做前向,要么全做后向),这是通过将部分后向微批次逻辑上延迟到调度末尾的流水线气泡中实现的(如图8所示)。此外,在流水线预热和排空阶段,部分空闲的GPU时间被重新用于作为注意力服务器运行CA-task。
图 8. 正常1F1B和解耦注意力下的流水线并行调度。
4.2 感知通信的贪心调度算法
- 调度目标:调度器旨在平衡两个相互竞争的目标:(1) 最小化注意力服务器之间的负载不平衡(以FLOPs衡量),同时 (2) 最小化通信量(以字节衡量)。
- 性能分析器(Profiler):为了估算一个分片的CA-task成本,系统构建了一个性能分析器,该分析器在一个查询长度和键值长度的网格上对CA进行基准测试。对于一个给定的CA-task,其执行时间通过在四个最近的网格点上进行双线性插值来预测。
- 调度单元(Scheduling units):调度算法中的基本单位是Item,它可以是一个完整的文档或一个文档的分片。每个Item已经驻留在计算其上下文无关层的设备上。一个Item的CA计算精确对应一个CA-task。
- 调度算法步骤:
- 确定目标负载:首先,调度器通过将所有Item的总FLOPs除以注意力服务器数量
n,计算出理想的每个服务器负载F¯。然后,将注意力服务器划分为负载有盈余(load >F¯)和有赤字(load <F¯)两类。 - 为赤字服务器迭代迁移:调度器按赤字降序遍历每个赤字目标服务器
d,并尝试从盈余源服务器迁移Item来弥补d的负载缺口。- 为了找到最有效的迁移Item,调度器对每个候选Item使用成本效益启发式进行评估。它首先确定一个Item可迁移的最大FLOPs量
∆Fmax = min(F_Item, S_source, D_destination)。 - 然后,选择一个FLOPs等于
∆Fmax且通信成本Vcomm最小的分片。 - 接着,计算每个候选Item的优先级分数,定义为单位计算量转移的通信成本:
E = ∆Fmax / Vcomm。E越高表示迁移效率越高。 - 调度器选择得分最高的Item。根据计算出的
∆Fmax,如果∆Fmax = F_Item,则整个Item被迁移;如果∆Fmax < F_Item,则该Item被拆分为两个子Item,新创建的具有∆FmaxFLOPs的子Item被分派到目标服务器。
- 为了找到最有效的迁移Item,调度器对每个候选Item使用成本效益启发式进行评估。它首先确定一个Item可迁移的最大FLOPs量
- 终止条件:当每个服务器上的负载都在
ϵF¯(一个容忍度epsilon)范围内,或者当剩余的迁移无法使E值超过一个很小的阈值时,调度器停止。这样,调度器在确保系统负载平衡的同时,避免了因微不足道的迁移而产生不必要的通信。
- 确定目标负载:首先,调度器通过将所有Item的总FLOPs除以注意力服务器数量
5 实现
- 代码实现:DistCA用2000行Python代码实现。为了高效地分派注意力服务器的输入和输出,实现了一个遵循【索引11,FLASH: fast all-to-all communication in GPU clusters,2025】思想的All-to-All通信核,额外使用了1000行CUDA和C++代码,通信利用了NVSHMEM【索引17,NVSHMEM: GPU-Centric OpenSHMEM,2025】。
- 系统集成:由于DistCA只改变了注意力计算的逻辑,它被集成到Megatron-LM【索引21,Megatron-lm: Training multibillion parameter language models using model parallelism,2019】中,以重用其对token无关层、模型架构、4D并行化和端到端训练流水线的高效实现。集成工作用了1000行Python代码。
A4 实验环境
-
模型架构:实验使用了LLaMA 8B和LLaMA 34B模型。具体配置如表2所示。
表 2. 实验模型配置。“Hidden”是隐藏维度大小,“#Head”是注意力头数,“Head Size”是每头维度。
-
硬件配置:所有实验均在NVIDIA DGX H200节点上运行。每个节点包含8块140GB H200 GPU。
- 软件配置:DistCA由Python、CUDA和C++实现,并集成到Megatron-LM框架中。
- 数据集:实验使用了两种合成的数据分布:
- Pretrain:模拟预训练场景,其中长文档在预训练数据分布中被上采样。
- ProLong:一个公开的长上下文训练专用数据集【索引5,How to train long-context language models (effectively),2025】,包含更高比例的长文档。
- 基线系统:主要与WLB-LLM【索引25,WLB-LLM: workload-balanced 4d parallelism for large language model training,2025c】(在本文中称为“WLB-ideal”)进行比较。由于没有官方实现,作者基于其方法重新实现了该系统。
A4 实验结果
6.2 端到端实验
-
3D并行(无PP):
- 实验设置:如表3所示。
- 结果:如图9所示,DistCA在所有配置下均优于“WLB-ideal”。在Pretrain数据集上实现了1.07-1.20倍的加速,在ProLong数据集上实现了1.05-1.12倍的加速,并表现出更好的扩展性。
- 分析:“WLB-ideal”要么受困于高CP度带来的通信开销,要么受困于高DP度带来的调度难题。而DistCA通过乒乓执行隐藏了通信开销,并通过灵活的token级调度实现了更好的负载均衡。
表 3. 3D训练配置。
图 9. 3D并行(无PP)实验。加速比定义为WLB-LLM的平均运行时间除以DistCA的平均运行时间。 -
4D并行(含PP):
- 实验设置:如表4所示。
- 结果:如图10所示,DistCA同样普遍优于“WLB-ideal”。对于8B模型,在Pretrain数据集上加速比为1.15-1.30倍,在ProLong数据集上为1.10-1.35倍。对于34B模型,在多数配置下也实现了高达1.15倍(Pretrain)和1.25倍(ProLong)的加速。
- 分析:除了3D并行中的优势外,DistCA还能平衡流水线各阶段的计算,并重新利用预热和排空阶段的空闲GPU。相比之下,WLB-LLM难以找到有效的配置,常因高CP或DP度而内存溢出,并且流水线并行会加剧其负载不平衡。
- 性能限制:在34B模型的4D并行实验中,由于每个微批次处理的张量形状不同,导致内存分配器频繁创建和释放不同大小的内存块,造成内存碎片化和PyTorch垃圾回收,这反过来延迟了GPU核函数的启动,限制了DistCA的性能。
表 4. 4D并行训练配置。
图 10. 4D并行(含PP)实验。加速比定义为WLB-LLM的平均运行时间除以DistCA的平均运行时间。
6.3 消融研究
-
系统开销:
- 实验设计:比较了DistCA与两个基线:1) “Signal”(通信量降为1字节,仅反映计算不平衡开销);2) “Single Stream”(移除乒乓执行,通信与计算在同一流上串行)。
- 结果与分析:如图11所示,DistCA的延迟几乎与“Signal”相同,表明通信几乎被完全重叠。与“Single Stream”相比,DistCA消除了10-17%的延迟。唯一的例外是8B模型在8个节点上,计算工作量太小,无法完全隐藏通信。
图 11. 不同通信模式下的吞吐量。 -
调度器中的超参数:
- 实验设计:研究了调度器中的容忍度因子
ϵ如何权衡CA负载平衡与通信量。 - 结果与分析:如图12所示,对于34B模型,当容忍度因子低于0.10时,通信量增加导致延迟升高;当容忍度因子过大时,负载不平衡导致延迟线性增加。在多数情况下,将容忍度因子从0调整到0.15,可以在平均耗时几乎不变甚至改善的情况下,将内存需求降低20-25%。
图 12. 计算不平衡容忍度因子的影响。 - 实验设计:研究了调度器中的容忍度因子
A7 补充细节
7 相关工作
- 长上下文训练中的负载不平衡:
- FlexSP【索引24,Flexsp: Accelerating large language model training via flexible sequence parallelism,2025b】引入了动态上下文并行,但其使用的整数线性规划(ILP)求解器是NP难问题,限制了可扩展性。
- WLB-LLM【索引25,WLB-LLM: workload-balanced 4d parallelism for large language model training,2025c】引入了可变大小的数据块,用不平衡的MLP来补偿注意力的工作量不平衡。但这导致了内存不平衡,并且随着上下文长度增加,MLP计算占比缩小,使其难以抵消注意力的不平衡。
- 与这些CP的变体相比,CAD通过token粒度的可变长度分片实现了更精确的负载平衡,并通过感知通信效率的分片减少了规模化时的通信。
- 模型解耦:
- 模型解耦在LLM推理中被广泛采用,最常见的是预填充-解码解耦。
- MegaScale-Infer【索引28,Megascale-infer: Serving mixtureof-experts at scale with disaggregated expert parallelism,2025】等工作在推理中进一步将注意力和FFN解耦到不同的物理设备上,主要针对MoE模型,因为解耦引入的层间传输可以与MoE固有的token路由通信合并。
- 相比之下,本文的工作专注于以吞吐量为导向的语言模型训练,并证明了其优势与模型架构无关。
8 局限性
- 原地注意力服务器:DistCA使用原地注意力服务器来维持内存利用率。如果内存需求得到满足,为注意力分配更多专用GPU(而不扩展其他部分的GPU)可以在保持负载平衡和低通信开销的同时进一步减少计算时间。
- 调度器限制:
- 当前调度器限制每个CA-task使用一个Q分片和完整的K、V上下文,这限制了灵活性。
- 在估算通信量时,当前模型悲观地假设所有token都被传输,忽略了目标设备上已存在的K、V,这可能高估字节数并导致非最优的传输。
A5 结论
本文提出了核心注意力解耦(CAD),这是一种用于大语言模型训练的新架构,它将核心注意力模块与模型的其余部分分离开来,以实现独立的扩展和调度。基于核心注意力是无状态且在token粒度上可组合的观察,本文实现了DistCA系统,该系统具有一个感知工作负载的调度器以平衡计算同时最小化通信,以及一个乒乓执行方案来隐藏分派延迟。端到端评估表明,与最先进的训练系统相比,DistCA的吞吐量提升高达1.35倍,并且随着规模的扩大,其优势愈发明显。
A6 附录
A 核心注意力服务器最大分区大小的上限
- 通信量分析:假设一个长度为
l的文档被均匀分为s个分片。查询状态的总通信量为$l \cdot h_q$。键值状态的总通信量为$h_{kv} \cdot (l/s \cdot s + l/s \cdot (s - 1) + \dots) = (s + 1)lh_{kv}/2$。 - 上限推导:当通信与计算完全重叠时,应满足 $t \cdot l \geq l \cdot (h_q + h_{kv}(s + 1)/2)/B$,其中
t是计算一个token的上下文无关层的时间,B是网络带宽。整理后得到分片数量的上限为:$s \leq 2(tB - h_q)/h_{kv} - 1$。 -
Llama-34B示例:以Llama-34B为例,其配置如表5所示。假设InfiniBand带宽为50GB/s,H200节点的MFU为50%(FP16下990TFLOPs)。
表 5. Llama-34B配置hidden (h) key-value hidden (hkv) intermediate (i) 8192 2048 22016 -
计算时间
t:一个token的上下文无关层总FLOPs为:
计算出的时间t代入上限公式后得到:
- 结论:对于Llama-34B,一个文档最多可以被分成约31个分片而不会产生通信开销。对于更大的模型,由于
t随隐藏大小$h_q$二次方增长,这个上限s甚至会增加。
B 通信开销函数
- 通信成本函数
v(·)的构建:对于一个有nq个查询token和nkv个键值token的分片,通信成本为:
- 优化问题:在迁移
∆FmaxFLOPs时,选择最优的分片大小nq可以最小化通信量v(·)。最优解的推导如下:
- 实践中的简化:在实践中,不同分片的MFU存在差异,直接用FLOPs估算时间并不精确。然而,性能分析结果显示,如果继续采用头尾上下文并行的做法(即一个Item同时包含文档开头和结尾的token),用FLOPs进行估算仍然是准确的。因此,系统保留了头尾上下文并行的策略,并将更精确的建模留作未来工作。在这种方式下,通信成本被修改为:
由于i + nq = j = nkv,通信量简化为:
因此,最优的nq对应于可能的最小值,具体由以下约束决定:
💬 评论讨论
欢迎在这里分享您的想法和见解!