Mesh-Attention: A New Communication-Efficient Distributed Attention with Improved Data Locality
Mesh-Attention: A New Communication-Efficient Distributed Attention with Improved Data Locality
文章标题:Mesh-Attention: 一种具有改进数据局部性的新型通信高效分布式注意力机制
作者/机构:Sirui Chen (普渡大学), Jingji Chen (字节跳动), Siqi Zhu (伊利诺伊大学厄巴纳-香槟分校/清华大学), Ziheng Jiang (字节跳动), Yanghua Peng (字节跳动/清华大学), Xuehai Qian (清华大学)
A1 主要贡献
核心问题:大型语言模型(LLMs)要扩展上下文窗口(context window)面临着一个基本问题,即分布式注意力的计算和内存需求会随着上下文窗口大小的增加而急剧增长。现有的主流方法,如Ring-Attention,虽然通过重叠通信和计算实现了良好的并行化,但其通信开销过大,随着GPU数量和序列长度的增加,通信时间占比可高达91.5%,限制了其可扩展性。
研究目标:本文旨在通过一个新的基于矩阵的模型来重新思考分布式注意力的设计空间,从而提出一种新的分布式注意力算法,以解决现有方法的通信瓶颈和可扩展性问题。
创新点:
-
提出基于矩阵的新模型和Mesh-Attention算法:本文引入了一个“分配矩阵”(Assignment Matrix, AM)的抽象模型来描述和分析分布式注意力中的计算分配和通信。基于此模型,提出了Mesh-Attention算法。与Ring-Attention将一整行或一整列计算分配给一个GPU不同,Mesh-Attention将一个二维的计算块(tile)分配给每个GPU。这种方式允许在Q和KV两个维度上微调数据局部性,从而实现更低的通信-计算比(CommCom ratio)。
-
保证局部Q-KV属性的索引旋转方法:为了让每个GPU都能高效地执行其本地Q和KV块之间的计算,本文提出了一种通过巧妙旋转KV(或Q)维度索引的方法,使得每个GPU分配到的计算块都包含其本地Q-KV对的计算,这进一步减少了通信量。
-
高效的通信计算重叠调度算法:Mesh-Attention的通信模式比Ring-Attention更复杂,涉及Q、KV和部分输出的传输。为了最大化计算和通信的重叠,本文将整体过程分解为多个步骤,并提出了一个贪心算法来自动搜索计算和通信的最优调度方案。该算法通过引入两个限制条件来约束搜索空间,并基于“解锁”计算块的收益来优先选择通信操作,同时调度“恰到好”数量的计算任务以确保最佳重叠。
下图展示了从Ring-Attention到Mesh-Attention的演进过程,以及调度搜索空间和剪枝策略。
A3 背景知识
2.1 LLMs 和注意力机制
LLMs与Transformer架构:大型语言模型(LLMs)【【索引26,A survey of large language models,2025】】已成为现代自然语言处理的基石。大多数先进的LLM都基于Transformer架构【【索引22,Attention is all you need,2023】】,其简化形式如图2a所示。在每个层内,输入嵌入首先通过一个自注意力块,然后是一个由多层感知机(MLP)组成的位置前馈网络(FFN),两个块周围都应用了残差连接和层归一化。注意力机制是每个Transformer层的关键组件,它通过学习到的查询(query)、键(key)和值(value)投影来计算token之间的成对亲和度。
注意力机制的计算与开销:在注意力层中,当序列长度很长时,涉及WQ、WK、WV和WO的点积操作并非主要开销来源。相反,投影后的查询(Q)、键(K)和值(V)激活值构成了主要的内存开销,而对它们执行的核心注意力操作是主要的计算成本来源:
$$\begin{aligned} \begin{aligned} P & = \mathbf{softmax}(\frac{QK^T}{\sqrt{d}}) \\ O & = PV \end{aligned} \end{aligned}$$Q和KT之间的矩阵乘法计算了每对查询-键之间的相似度分数。这些分数通过$1/\sqrt{d}$进行缩放并应用softmax,生成一个归一化的概率矩阵P。最后,将P乘以V,根据学习到的注意力权重聚合整个序列的值向量,从而为每个token生成输出表示。此操作的优化是加速LLMs最关键的方面之一。FlashAttention【【索引6,Flashattention-2: Faster attention with better parallelism and work partitioning,2023】;【索引7,Flashattention: Fast and memory-efficient exact attention with io-awareness,2022】】通过重排序和融合矩阵乘法与softmax操作,显著减少了内存带宽和中间存储。然而,将注意力扩展到数千甚至数百万token会给模型的内存和计算资源带来巨大负担,特别是在自注意力块中,因为其计算和内存复杂度与序列长度成二次方关系。
序列并行:为了缓解这些瓶颈,研究者们提出了各种分布式和分层注意力方案。除了数据并行【【索引10,Accurate, large minibatch sgd: Training imagenet in 1 hour,2018】】、张量并行【【索引20,Megatron-lm: Training multi-billion parameter language models using model parallelism,2020】】和流水线并行【【索引11,Pipedream: Fast and efficient pipeline parallel dnn training,2018】;【索引12,Gpipe: Efficient training of giant neural networks using pipeline parallelism,2019】】等传统并行优化方法外,还提出了一种沿序列应用的新并行维度,即序列并行,以应对长上下文任务带来的挑战。如图2b所示,在序列并行中,输入嵌入x按序列分片并分发给所有GPU。经过投影步骤后,每个GPU只持有其本地的Q、K和V张量块。因此,需要跨GPU通信来执行核心的分布式注意力操作,其中所有token在Q-KV维度上进行交互。最终,每个处理器只保留与其分配的输入区域相对应的O块。
2.2 Ring-Attention
Ring-Attention工作原理:Ring-Attention【【索引14,Ring attention with blockwise transformers for near-infinite context,2023】】是一种序列并行技术,它将Q、K和V张量在所有GPU上切分,并进行循环旋转来传递KV(K和V)张量块。每个GPU使用接收到的KV张量块与其本地的Q块计算注意力,并累积部分输出,直到组合成完整的结果。
分块计算与通信重叠:通过利用在线softmax归一化技巧【【索引18,Online normalizer calculation for softmax,2018】】,注意力计算可以沿着Q和KV维度被划分为多个小块。每个分块注意力计算产生一个部分输出,该输出被增量地累加到最终的O(输出)张量中。Ring-Attention应用这种分块方案,在聚合每个GPU的本地O张量时,将通信与计算重叠。
Ring-Attention的计算流程:再次参考图1(a),Q和KV张量沿序列维度被平均分割成9个块{$Q_i$}$_{i=0}^8$和{$KV_i$}$_{i=0}^8$,分布在9个GPU上。最初,GPU i持有块$Q_i$和$KV_i$。所有9个KV块通过逻辑环传递,直到每个GPU i都收集到所有KV块,并在每个传递步骤中与其本地的$Q_i$执行分块注意力。分块注意力计算可描述为:
$O_{i,j}, lse_{i,j} = \text{BlockAttention}(Q_i, KV_j)$
其中$O_{i,j}$表示使用$Q_i$和$KV_j$计算的块输出,而$lse_{i,j}$表示用于在线softmax的log-sum-exp(lse)系数块。当GPU i遍历j(从0到8)应用上述公式时,所有的输出块和lse块通过在线softmax被累积归约到$O_i$和$lse_i$中。$O_i$将与$Q_i$具有相同的维度,并传递给模型的后续层。由于lse块是一个向量,其大小与O块相比可以忽略不计。
Ring-Attention的瓶颈:尽管Ring-Attention将通信与分块注意力计算重叠,但在低带宽、大规模部署中,大型KV块在所有GPU之间的循环交换可能成为通信瓶颈。在128个GPU上对长度为1M的输入进行推理的预填充(prefill)阶段实验中,我们发现Ring-Attention有91.5%的时间在等待通信而没有与计算重叠。
2.3 其他相关工作
相关工作的演进:Yang等人【【索引23,Context parallelism for scalable million-token inference,2025】】对Ring-Attention做了微小修改,设计了一种模式,沿着环传递Q块而非KV块,并根据上下文选择更优的选项。传递Q块需要在最后进行额外的reduce-scatter通信来聚合部分结果。DeepSpeed-Ulysses (DS-Ulysses)【【索引13,Deepspeed ulysses: System optimizations for enabling training of extreme long sequence transformer models,2023】】是另一种与Ring-Attention正交的序列并行策略。它通过两次all-to-all交换在序列和头(head)维度之间转置数据块,使得每个GPU能本地进行完整的头级注意力计算。然而,DS-Ulysses的可扩展性受限于多头注意力机制中的头的数量。为了解决此限制,USP(Unified Sequence Parallelism)【【索引8,Usp: A unified sequence parallelism approach for long context generative ai,2024】】提出了一种融合DS-Ulysses和Ring-Attention的混合方法。Startrail【【索引16,Startrail: Concentric ring sequence parallelism for efficient near-infinite-context transformer model training,2025】】通过引入额外的并行维度来缓解Ring-Attention的通信瓶颈。它将GPU分组并通过All-Gather在组间分发KV张量分片,其3D拓扑降低了通信复杂性。然而,其复杂的分组策略和组间集合通信会产生冗余数据传输,且这些集合通信与本地计算的重叠能力有限,限制了整体性能。
A2 方法细节
3.1 定义与分析
核心概念定义:本节以通用形式定义引言中解释的概念,并强调Ring-Attention的缺点。
* 分配矩阵(Assignment Matrix, AM):考虑一个分布式注意力操作,其中Q被划分为n个块,KV也被划分为n个块。分配矩阵是一个$n \times n$的矩阵,其中AM[i][j]的值是负责计算Q[i]和KV[j]之间Q-KV对的GPU的ID。
* 通信-计算比(CommCom):指整个分布式注意力操作中所有GPU的通信量与计算量之比,或单个GPU的该比率。
* 局部Q-KV属性(Local Q-KV Property):一个GPU负责计算其本地Q块和KV块之间的Q-KV对。
* 步数分配(Step Number Assignment):给定k个步骤和一个GPU的($a \times b$)-tile,将一个从0到k-1的步数分配给AM中的每个计算元素,以及(a+b)个输入的Q和KV块和a个输出的O块用于通信。对于本地输入块,步数为0;对于本地输出块,GPU只接收部分输出而不发送。
* Ring-Attention:是一种分布式注意力算法,其中一个GPU ID被分配给AM的一行,反映了一个GPU总是计算其本地Q块与所有本地和远程KV块之间的Q-KV对的原则。Ring-Attention的局部性不平衡,Q实现了完美的局部性,而KV则是最差的局部性,导致通信量过大。
3.2 基于分块的工作负载分布
工作负载分配:在n个GPU的任意设置中,n的任何因式分解$n = a \times b$都提供了一个可以映射到AM上的有效分块(tiling)大小。我们将GPU 0到GPU n-1以行优先的方式排列到AM的tile中。左侧列的Q块索引保持为0到n-1,而顶部行的KV块索引按以下方式修改:对于每连续的b个索引,每个都被替换为它们下方tile中的b个GPU的索引,如图1所示。这完成了Mesh-Attention的工作负载分配。
分组策略与对称性:由于AM上KV索引的重新排序,每个GPU不仅计算本地的Q-KV注意力块,而且还具有对称属性,这体现在以下分组策略中:
* Q组:每个Q组i由一组a个连续的GPU形成,索引为{$ai + x$}$_{x=0}^{a-1}$,$i \in \{0, . . . , b − 1\}$。
* KV组:每个KV组i由模a同余的GPU形成,索引为{$i+ax$}$_{x=0}^{b-1}$,$i \in \{0, . . . , a − 1\}$。
每个GPU i恰好是一个Q组和一个KV组的成员。索引为{$a\lfloor i/a\rfloor + x$}$_{x=0}^{a-1}$的Q块通过Q组获取;索引为{$i \pmod a +ax$}$_{x=0}^{b-1}$的KV块通过KV组获取。因为$Q_i$和$KV_i$最初由GPU i持有,所以通信量为a-1个Q块,加上b-1个KV块,再加上a-1个O块。
网状通信结构:如果我们将每个组内的GPU连接成一个环,并让所有通信沿着这些环进行,整个系统就形成了一个环绕式网格(wrap-around mesh)。我们观察到一个有趣的特性:所有GPU可以同步地执行通信和计算。每个GPU只需知道其在Q组或KV组中的前驱和后继,即可执行相同的、锁步的通信操作,从而消除了由不平衡引起的“气泡”延迟。
3.3 计算流程
算法流程:基于分块的工作负载分布,我们指定了每个GPU的计算流程。图3中的分区意味着所有GPU的工作是对称的,即每个GPU的计算流程是相同的,因此我们只需要指定一次。每个GPU的计算流程在算法1中指定。每个GPU必须在执行计算之前从指定的Q组和KV组中获取远程块。所有计算完成后,会生成相应的O块,其中a-1个必须发送给Q组中“拥有”相应本地Q块的其他a-1个GPU。总的来说,所有O块生成后,它们必须被重排:每个GPU将其a-1个非本地块发送给适当的对等方,并从它们那里接收回自己块的部分输出。
具体步骤描述:算法1描述了没有任何重叠和优化的功能性计算流程。远程Q和KV块通过两次all-gather操作获取,注意力计算之后,输出的重排和归约通过一次reduce-scatter操作实现。具体来说,在第1行和第2行之后,任何GPU i获取了由AM中分配的工作负载决定的Q块{$a\lfloor i/a\rfloor + x$}$_{x=0}^{a-1}$和KV块{$i \pmod a + ax$}$_{x=0}^{b-1}$。因此,在第3行,GPU i能够计算分配给它的所有Q-KV块。在第4行,当GPU i发送其索引为({$a\lfloor i/a\rfloor + x$}$_{x=0}^{a-1}$ - {$i$})的$O^{(i)}$块时,它同时从同一Q组中的其他(a-1)个GPU接收相应的$O^{(...)}$块,并将它们归约为最终输出$O_i$。
高效调度的挑战:一个高效的分布式注意力实现需要精细的计算和通信调度,以便两者尽可能地重叠。上述计算流程带来了三个问题。首先,该算法仅指定了操作,但没有提供如何有效重叠通信和计算的直观方法。其次,如引言所述,当我们将问题表述为将计算操作放入每个通信时隙时,搜索空间是巨大的。第三,每个通信时隙的延迟不是固定的,它取决于硬件平台和数据的源/目的地,这妨碍了准确估计要放入时隙中的计算操作数量。
3.4 tile内通信调度
核心观察与策略:为了应对这些挑战,我们利用了Mesh-Attention通信模式中的一个关键观察:虽然AM中的计算在GPU之间按二维tile划分,但我们之前定义的每个Q/KV组是一维的。因此,我们可以在每个Q/KV组内应用Ring-Attention中成熟的计算和通信重叠技术,其中组内的GPU形成一个逻辑环。这个想法同时解决了后两个问题:(1)Q/KV组内更受限的通信和计算调度因附加假设而缩小了搜索空间;(2)将all-gather和reduce-scatter操作分解为具有固定延迟的小通信步骤,使得能够更准确地选择要重叠的计算操作。
P2P通信分解:具体来说,Q/KV组内的集合通信可以分解为细粒度的点对点(P2P)通信步骤:在all-gather的一步中,每个GPU将其本地数据发送给其后继,并转发从其前驱接收的数据,直到所有GPU都收集到完整的数据集。在reduce-scatter的一步中,每个GPU将一个块发送给其后继,并将接收到的块与适当的本地项进行归约,直到每个GPU获得其分配块的最终归约结果。
同步操作与索引映射:Q/KV组内的细粒度点对点通信步骤在组内GPU之间同步执行,但传输消息的源和目的地不同。为了为所有GPU推导出一个简洁且相同的操作规范,我们引入了一个如表1所示的从本地块索引(表示为Q#u、KV#u或O#u)到全局块索引的映射。我们假设tile形状为$a \times b$,总共有$n = a \times b$个tile。该映射指定了全局块索引作为GPU索引i、a、b和n的函数。每个GPU都有自己的输出块$O\#u^{(i)}$,因为每个GPU产生的部分输出根据其在组内逻辑环中的位置而不同。
表1:GPUi上的本地到全局块索引映射
P2P操作定义:根据分解方法和洞察,我们定义了每个GPU的P2P操作。每个GPU的后继和前驱是根据3.2节中描述的程序推导出的AM明确定义的。
* 引入三个初始化为零的变量$i_q$, $i_{kv}$和$i_o$来跟踪P2P操作的进展。最初,每个GPU持有其本地的Q和KV块,记为Q#0和KV#0。
* Recv Q/KV: 1) 将Q#$i_q$/KV#$i_{kv}$发送到Q/KV组中的后继,2) 从前驱接收Q#($i_q$+1)/KV#($i_{kv}$+1),3) 将$i_q$/$i_{kv}$增加1。
* Send O: 1) 将O#($i_o$+1)发送到Q组中的后继,2) 从前驱接收O#(($i_o$+2) mod a)并使用在线softmax归约到本地块中,3) 将$i_o$增加1。
调度约束与搜索空间缩减:值得注意的是,组内的所有GPU会同步执行Recv Q、Recv KV或Send O操作,并且除了第一步外,GPU发送的是上一步接收到的Q/KV块。图4阐述了基于图3中AM的GPU 4在Q和KV组中的两个操作步骤。总的来说,Q块上的all-gather被分解为Q组中每个GPU上的a-1次Recv Q操作;KV块上的all-gather被分解为KV组中每个GPU上的b-1次Recv KV操作。O块的reduce-scatter被分解为Q组中每个GPU上的a-1次Send O操作。这些操作由每个GPU调用,而无需显式指定要发送的数据,传输的数据由表1中的映射和特定操作被调用的次数隐式确定。这限制了计算的调度并减小了搜索空间。通常,对Q#i-KV#j的计算在分配的GPU接收到Q#i和KV#j后变为可执行状态,但在基于环的通信中,它在i-1次Recv Q操作和j-1次Recv KV操作完成后才变为可执行状态。这自然地禁止了如图1(e)上部所示的任意重排序,从而减小了搜索空间。此外,第i次Send O操作只能在对{$Q\#(i)-KV\#(j)$}$_{j=0}^{b-1}$的注意力计算完成后执行。
3.5 自动调度生成
调度生成问题:基于Q/KV组内基于环的分解通信,我们提出了一个贪心算法来推导高效的调度,旨在最大化计算和通信的重叠。参考图1(d),问题可以被构造成将一个步数分配给一个tile的AM中的每个元素以及相关的Q/KV块通信。我们在选择步数分配时遵循三个原则。
调度三原则:
* 快速且平衡地使计算就绪:每个远程Q/KV块的传输会使某些计算块变为可执行状态。若调度不当,这个转换可能太慢(如图5(a)的Ring-Attention)或不平衡(如图5(b)的Mesh-Attention行优先调度),导致用于重叠的计算量不足或过多。为解决此问题,我们静态地分析$c_Q, c_{KV}, c_O$(分别表示完全隐藏一次Q, KV, O块传输时间所需的最少计算块数),动态地获取$n_Q, n_{KV}$(表示因一次Recv Q/KV而变为可执行的计算块数)。因此,选择Q或KV通信的收益可视为$n_Q/c_Q$或$n_{KV}/c_{KV}$。在每一步,我们选择收益更高的通信。
-
为重叠触发恰到好处的计算:即使选择了好的Q和KV通信,也可能触发过多的可执行计算块。为解决此问题,我们利用分析得到的$c_Q, c_{KV}, c_O$值,仅调度不超过这些值的计算块。计算块应按行优先顺序触发,因为一旦整行完成,输出O就准备好可以发送。
-
降低非关键路径计算的优先级:对于每个GPU,其本地Q对应的行(本地块索引为0)是其负责的最终输出O的一部分,它只需接收归约后的O而无需向外发送部分O。换言之,该行的计算不在其他GPU等待的关键路径上,因此在必要时应降低其优先级。
完整算法描述:基于这三个原则,完整的算法如算法2所述。ComputeBlocks函数按照行优先顺序最多触发x个计算块。整个注意力操作大致分为三部分:(1)调度所有Recv Q/KV操作,旨在最大化“收益”,使所有计算块就绪(第4-9行);(2)如果仍有部分输出未计算,则触发相应行中的计算块以生成输出,将其发送给其他GPU,并调度适当数量的计算块进行重叠(第10-13行);(3)如果本地输出尚未生成,则完成它(第14-15行)。图5(e)为9-GPU示例提供了该算法的运行实例。
// 算法2:前向传播的自动调度生成
输入: cQ, cKV, cO,分别代表完全隐藏一次Q, KV, O块传输所需的最少计算块数
输出: steps
1 函数 ComputeBlocks(x):
2 按行优先顺序计算x个可执行的块(但第一行优先级最低)。如果可执行块少于x,则全部计算。
3 steps ← [], step_num ← 0
4 while 并非所有 Recv Q/KV 操作都已执行 do
5 设 Recv Q 或 Recv KV 解锁的可计算块数为 nQ 或 nKV
6 if nQ/cQ > nKV/cKV then
7 steps[step_num++] ← “执行 Recv Q,调用 ComputeBlocks(cQ) 进行重叠”
8 else
9 steps[step_num++] ← “执行 Recv KV,调用 ComputeBlocks(cKV) 进行重叠”
10 while 并非所有 Send O 操作都已执行 do
11 while Send O 无效 do
12 steps[step_num++] ← “调用 ComputeBlocks(1)”
13 steps[step_num++] ← “执行 Send O,调用 ComputeBlocks(cO) 进行重叠”
14 while 并非所有块都已计算 do
15 steps[step_num++] ← "ComputeBlocks(1)"
最优形状选择流程:该贪心算法为给定的a和b生成特定形状tile的调度。但对于给定的$n = a \times b$,a和b有多种选择。我们需要对每种设置分析$c_Q, c_{KV}, c_O$,根据贪心算法生成相应的最佳调度,估计运行时间,并选择具有最小运行时间的形状和调度。流程如图6所示。
3.6 Mesh-Attention的反向传播
反向传播的修改:为实现反向传播,Mesh-Attention需稍作修改。遵循FlashAttention【【索引6,Flashattention-2: Faster attention with better parallelism and work partitioning,2023】】的重计算策略,每个Q组内传输的块包括前向输出O、其梯度dO、查询张量Q和log-sum-exp系数lse(数学符号记为OdOQ)。同时,KV组内传输的块仍为KV张量。计算后,非本地KV张量的梯度dK和dV(统称为dKV)必须沿着KV组传回,以便在其对应的GPU上进行归约;而Q张量的梯度dQ则沿着Q组进行归约。因此,我们需要四种类型的P2P通信:Recv OdOQ/KV和Send dQ/dKV。
调度约束:对OdOQ#i-KVj的计算在执行了i-1次Recv Q操作和j-1次Recv KV操作后变为可执行。至于梯度,第i次Send dQ操作只能在{OdOQ#i-KV#j}$_{j=0}^{b-1}$整行的注意力计算完成后执行。类似地,第i次Send dKV操作只能在{OdOQ#i-KV#j}$_{i=0}^{a-1}$整列的注意力计算完成后执行。
反向传播调度策略:我们分析$c_{OdOQ}, c_{KV}, c_{dQ}, c_{dKV}$,它们分别代表完全隐藏传输一个OdOQ, KV, dQ或dKV块所需的最少计算块数,并遵循与前向传播类似的调度原则。然而,简单的行优先计算顺序不再适用,因为它会将所有dKV块的传输推迟到最后,妨碍了与计算的充分重叠;同样,列优先顺序会推迟所有dQ块的传输。这一挑战源于Q(行)和KV(列)维度都涉及必须发送的梯度。
交替行列计算方案:为解决此问题,我们需要一个交替完成行和列的方案,这可以通过贪心方法实现。选择下一个要计算的块时,有两个候选:行优先的块和列优先的块。我们假设沿着行再计算$n_{dQ}$个块,或沿着列再计算$n_{dKV}$个块后,可以发送一个dQ/dKV块。由于总计算时间是固定的,通信应尽早发起以最大化重叠。因此,我们通过比较$c_{dQ}/n_{dQ}$和$c_{dKV}/n_{dKV}$的收益来做决策:值较大者表明一个梯度块可以更早地被发送(按发送成本加权)。此外,由于选择行优先或列优先的块会使$n_{dQ}$或$n_{dKV}$减一,这种贪心选择会产生正反馈效应:只要可能,它会自然地继续直到完成整行或整列的计算,满足了交替完成行和列的启发式策略。算法3展示了具体细节。
3.7 Causal Mask 机制
因果掩码支持:因果掩码确保每个token只能关注自身及之前的token,这强制了自回归生成所需的从左到右的依赖约束。Ring-Attention可以通过Striped Attention【【索引4,Striped attention: Faster ring attention for causal transformers,2023】】或Zig-Zag Attention【【索引27,Feature request: Balancing computation with zigzag blocking,2024】URL: https://github.com/zhuzilin/ring-flash-attention/issues/2】来支持因果注意力 。
Mesh-Attention的实现:Mesh-Attention可以自然地通过应用Striped Attention来支持因果掩码。首先,每个块收集所有索引与n取模后余数相同的token的数据。例如,$Q_i$/$KV_i$包含token {$i + nx$}$_{x=0}^{N/n-1}$的查询/键值投影。然后,我们调整KV块到GPU的映射,以保留之前讨论的局部性属性。图7展示了这一过程的示例,其中有16个token和4个块,第一步后,KV1和KV2被交换,以便每个GPU可以在其本地Q和KV块上执行计算。可以看到,每个计算块都严格遵循因果掩码,并且计算在GPU之间是平衡的。Zig-Zag Attention也可以适配到Mesh-Attention,但因过程复杂,本文未提供细节。
3.8 理论通信复杂度分析
分析设置:我们对Ring-Attention、DS-Ulysses、StarTrail和Mesh-Attention进行了通信复杂度分析。假设批量大小为1,序列长度为N,并行度(即GPU数量)为n,隐藏大小为d。我们只讨论前向传播。
- Ring-Attention【【索引14,Ring attention with blockwise transformers for near-infinite context,2023】】:每个GPU向环中的后继发送n-1个KV块,每个块大小为$2Nd/n$,导致通信量为$(2 - 2/n)Nd$。
- DS-Ulysses【【索引13,Deepspeed ulysses: System optimizations for enabling training of extreme long sequence transformer models,2023】】:前向传播包含4次Q、K、V和O块的all-to-all操作,每次计为$(n-1)/n^2Nd$。因此,每个GPU的通信量为$4(n-1)/n^2Nd$。
- StarTrail【【索引16,Startrail: Concentric ring sequence parallelism for efficient near-infinite-context transformer model training,2025】】:C是StarTrail的可配置超参数。通信分为四个阶段,每个GPU的总通信量高达$((4C-1)/n + 2/C)Nd$个元素。
- Mesh-Attention:每个GPU的通信量包括接收a-1个Q块、接收n/a-1个KV块和发送a-1个O块。因此,每个GPU的通信量由以下方程计算:
根据算术-几何平均不等式(AM-GM),当$a \to \sqrt{n}$时,每个GPU的通信量最小化,得到最优通信量为$(4\sqrt{n}/n - 4/n)Nd$。实际上,Mesh-Attention的性能由多种因素决定,因此选择的因式分解可能不是理论上通信复杂度最低的那个。
复杂度对比:在表2中,我们比较了这四种序列并行方法的理论通信量。StarTrail的通信复杂度比Mesh-Attention差,因为它存在冗余通信。相比之下,Mesh-Attention的通信结构确保每个GPU接收到的Q或KV块都被充分利用于计算,消除了冗余通信。此外,在计算和通信重叠方面,StarTrail的多阶段工作流存在缺点,其all-gather操作只能与QKV投影重叠,而其他阶段缺乏相应的计算来重叠,从而影响整体性能。除了通信复杂度,DS-Ulysses受限于头的数量,导致并行性有限。
表2:不同序列并行方法中的理论通信量(前向)
A4 实验环境
- 硬件配置:实验在一个配备了256个GPU的集群上进行。
- 模型架构参数:注意力配置方面,头(head)的数量为32,每个头的维度为128,因此总的隐藏层大小(hidden size)为4096。
A5 实验结果
4.2 整体性能
性能对比:如表3和表4所示,在不同的序列长度和GPU数量下,Mesh-Attention始终优于Ring-Attention。Mesh-Attention实现了高达3.4倍(平均2.9倍)的加速。在模型FLOPs利用率(MFU)方面,Mesh-Attention也始终表现更优,平均比Ring-Attention高2.5倍(最高达3.4倍)。两种方法的MFU都随着序列长度的增加而提高,因为注意力计算量随序列长度二次方增长,而通信量仅线性增长。带有因果掩码时的MFU通常较低,因为因果掩码将计算量减半,使得通信成为更严重的瓶颈。
表3:注意力前向+后向吞吐量(单位 10⁻² iter/s)
表4:注意力模型FLOPs利用率(MFU)(%)
4.3 可扩展性
强可扩展性:图8(a)展示了强可扩展性结果。我们固定序列长度为100万token,并改变用于并行化注意力算子的GPU数量。Mesh-Attention表现出比Ring-Attention更好的可扩展性:Ring-Attention只能扩展到64个GPU,使用更多GPU时性能显著下降。相比之下,得益于其更低的通信复杂度,Mesh-Attention可以扩展到128个GPU。Ring-Attention的最佳执行时间在64个GPU上实现(37.5秒),比Mesh-Attention在128个GPU上实现的最佳时间(11.9秒)慢3.15倍。
弱可扩展性:图8(b)展示了弱可扩展性分析。我们改变GPU数量,并相应地扩展序列长度,以保持每个GPU的计算量不变。由于注意力计算量与序列长度成二次方关系,当GPU数量翻倍时,我们将序列长度增加$\sqrt{2}$倍。我们观察到,由于通信开销更高,Ring-Attention的运行时间在扩展到更大规模时增长得更快。对于Ring-Attention,256个GPU的执行时间是32个GPU的3.74倍。相比之下,Mesh-Attention仅为2.83倍。
4.4 执行时间分解分析
性能优势来源:我们在图9a中展示了执行时间分解分析。结果证实,Mesh-Attention的性能优势主要来自于通信开销的减少:Mesh-Attention和Ring-Attention的计算时间几乎相同,而Mesh-Attention的通信等待时间比Ring-Attention少高达74.9%(平均74.0%)。值得注意的是,尽管显著减少,Mesh-Attention的通信开销仍然不可忽略。在256个GPU下,通信开销仍占总执行时间的86.6%。
4.5 通信分析
通信量对比:图9b展示了每个GPU的通信量。Mesh-Attention能够将通信量减少高达85.5%(平均78.2%)。值得注意的是,当扩展GPU数量时,Ring-Attention的通信量几乎保持不变,而Mesh-Attention的通信量显著下降。例如,256个GPU的Mesh-Attention的单GPU通信量比32个GPU时减少了51.8%。这一观察结果与我们在3.8节的理论分析一致,证明了Mesh-Attention更好的横向扩展通信特性。
4.6 峰值内存分析
内存消耗:我们在表5中分析了Mesh-Attention的峰值内存消耗。总的来说,Mesh-Attention比Ring-Attention消耗更多的峰值内存,因为它需要在整个注意力计算过程中缓存多个KV/Q块以进行数据重用。相比之下,Ring-Attention在内存中最多保留2个KV块和1个Q块。然而,值得注意的是,Mesh-Attention的高峰值内存消耗是暂时的:大部分GPU内存会在当前层的注意力前向/后向计算完成后迅速释放。Mesh-Attention不会增加反向传播所需的、需要暂存的前向激活量。
表5:注意力前向/后向峰值内存(GB)
4.7 GQA评估
GQA性能:除了标准多头注意力(MHA),我们还评估了Mesh-Attention在分组查询注意力(GQA)【【索引2,Gqa: Training generalized multi-query transformer models from multi-head checkpoints,2023】】上的性能。GQA通过一个因子g减少KV头的数量,并允许连续的g个Q头共享同一个KV头。随着g的增加,KV张量的大小减小,因此Ring-Attention的通信瓶颈得到缓解。我们在图10中比较了不同g值下Mesh-Attention和Ring-Attention的性能。在所有设置中,Mesh-Attention始终优于Ring-Attention。当g较大时(如4/8),Mesh-Attention仍然大大减少了通信开销。然而,由于Ring-Attention的通信问题不那么严重,Mesh-Attention的整体性能增益不那么显著。
A6 结论
本文提出了Mesh-Attention,一种新的分布式注意力算法,它为每个GPU分配一个二维的计算块(tile)而非一维的行或列,通过更低的通信-计算比(CommCom ratio)实现更高的效率。该通用方法将Ring-Attention作为一个特例包含在内,并允许通过不同的tile形状来调整CommCom比。重要的是,我们提出了一个贪心算法,该算法可以在有约束的条件下高效地搜索tile内的调度空间,以确保GPU之间的高效通信。理论分析表明,与其他现有算法相比,Mesh-Attention的通信复杂度要低得多,并表现出良好的可扩展性。大量的实验结果令人信服地证实了Mesh-Attention的优势。未来的工作将探索更先进的技术以进一步减少和隐藏通信成本,并研究内存效率更高的调度方案。
💬 评论讨论
欢迎在这里分享您的想法和见解!