MeshSlice: Efficient 2D Tensor Parallelism for Distributed DNN Training
- 作者: Hyoungwook Nam, Gerasimos Gerogiannis, Josep Torrellas
- 机构: University of Illinois at Urbana-Champaign
A1 主要贡献
本文旨在解决大规模DNN模型分布式训练中张量并行(TP)的通信瓶颈问题。现有的1D TP因通信成本高而可扩展性有限,而2D TP虽然能通过将矩阵分片到2D加速器网格中来减少通信,但其核心的通用矩阵乘法(GeMM)算法存在效率问题。具体来说,Cannon算法通信流量大;SUMMA算法同步开销高;而使用集体通信操作的2D GeMM无法将通信与计算重叠。此外,优化2D TP的众多参数(如数据流、网格形状、分片方式)非常困难,通常需要专家手动配置。
为应对这些挑战,本文做出了以下核心贡献:
- 提出新颖的MeshSlice算法:这是一种为分布式DNN训练中的2D TP设计的高效2D GeM M算法。MeshSlice通过将AllGather (AG) / ReduceScatter (RdS) 等集体通信操作切分为多个部分的集体操作,从而实现了通信与计算的重叠。这种方法有效隐藏了大部分通信延迟,解决了现有算法无法在行列两个维度上同时实现重叠的问题。
- 开发MeshSlice LLM自动调优器(Autotuner):该工具能够自动为大型语言模型(LLM)的训练找到最优的2D TP配置。它首先选择一个高效的2D GeMM数据流,然后利用分析性成本模型协同优化网格形状和通信粒度,从而替代了繁琐的人工调优过程。
- 全面的评估与实现:通过模拟训练GPT-3和Megatron-NLG模型的TPUv4集群,本文验证了MeshSlice的性能。结果显示,MeshSlice在高达256路的2D TP中仍保持高效率。在一个256个TPU的集群中,MeshSlice训练GPT-3和Megatron-NLG模型的速度分别比现有最先进的算法快12.0%和23.4%。此外,本文还在真实的Google TPUv4集群上实现了MeshSlice,验证了其切片操作的开销很小,且自动调优器的成本模型能准确估算通信和计算成本。
A3 背景知识/关键Observation/设计原则
2.1 分布式训练方法
-
分布式训练并行策略。大规模DNN(特别是LLM)通常在分布式计算设备系统上进行训练。本文将计算设备称为“芯片”,可以是ML加速器、GPU或CPU。在芯片间分配DNN计算主要有三种并行类型:数据并行(DP)、流水线并行(PP)和张量并行(TP)。为了最大化训练的可扩展性,现代LLM训练方法通常将这三种并行方式结合起来,形成3D芯片网络【索引5,Centauri: Enabling Efficient Scheduling for Communication-Computation Overlap in Large Model Training via Communication Partitioning,2024,ASPLOS】、【索引22,Deepspeed: System optimizations enable training deep learning models with over 100 billion parameters,2020,KDD】、【索引27,Using DeepSpeed and Megatron to train MegatronTuring NLG 530B, a large-scale generative language model,2022,arXiv】、【索引29,Llama 2: Open foundation and fine-tuned chat models,2023,arXiv】。
-
数据并行(DP)。DP在不同芯片间划分输入数据【索引7,Project Adam: Building an efficient and scalable deep learning training system,2014,OSDI】。由于DNN参数(权重和偏置)在芯片间是复制的,唯一的通信发生在参数更新和同步时。DP的通信成本可以被有效隐藏,因为一层的参数更新通信可以与另一层的计算并行进行。
-
流水线并行(PP)。PP将不同的DNN层分配给不同的芯片【索引12,Gpipe: Efficient training of giant neural networks using pipeline parallelism,2019,NeurIPS】。通信只发生在流水线阶段的边界。PP的可扩展性受限于DNN模型的网络结构,且其开销随流水线阶段数量的增加而增加。
-
张量并行(TP)。TP在不同芯片间划分DNN层的所有矩阵(权重、输入和输出)。由于所有矩阵都被划分,TP的内存占用最小,但通信流量在三种并行方式中最大。TP在每次通用矩阵乘法(GeMM)计算中都会产生通信流量。在LLM训练中,这种通信流量由全连接(FC)层产生【索引16,Reducing activation recomputation in large transformer models,2023,MLSys】。
-
1D TP GeMM的通信模式。在1D TP GeMM中,权重矩阵按其输入或输出维度划分到一个1D芯片网络中。在前一种情况下,输入和权重分片相乘以计算部分输出,然后通过ReduceScatter(RdS)通信累加部分输出。在后一种情况下,通过AllGather(AG)通信从所有芯片收集输入分片,然后与权重分片相乘。
-
1D TP的可扩展性限制。1D TP的可扩展性受限,因为通信流量随芯片数量线性增长。这是因为输出分片必须在所有芯片间通过RdS进行累加,或者输入分片必须通过AG复制到所有芯片。因此,要实现1D TP的线性性能扩展,通信带宽必须随芯片数量呈二次方扩展。一个解决方案是NVIDIA的NVSwitch【索引19,DGX H100: AI for Enterprise,2023,NVIDIA】,它将8个GPU连接成一个全连接的ICI网络。这种方法的可扩展性有限,因为构建更大规模的全连接交换机的难度呈二次方增长。因此,大多数3D LLM训练集群将TP限制在8路【索引5,Centauri: Enabling Efficient Scheduling for Communication-Computation Overlap in Large Model Training via Communication Partitioning,2024,ASPLOS】、【索引8,The Llama 3 herd of models,2024,arXiv】、【索引27,Using DeepSpeed and Megatron to train MegatronTuring NLG 530B, a large-scale generative language model,2022,arXiv】、【索引29,Llama 2: Open foundation and fine-tuned chat models,2023,arXiv】。
2.2 2D张量并行
-
2D TP作为更具扩展性的方案。一个更具扩展性的TP解决方案是2D TP,其中矩阵被划分到组织成2D网格(连接为2D环面)的芯片中。每个芯片本地持有一个矩阵的一个分片。然后,FC层使用一个2D分布式GeMM算法。
-
2D GeMM的通信优势。在2D GeMM中,每个分片只与同一行或同一列的芯片通信,而1D分布式GeMM中一个分片需要与所有其他芯片通信。这是因为计算输出矩阵的一个元素仅依赖于左输入矩阵的一行和右输入矩阵的一列。因此,2D TP产生的通信流量比1D TP少。所以,对于给定的通信成本,2D TP更具扩展性;对于给定的芯片数量,其通信成本更低。此外,构建大规模网格的硬件成本低于构建大规模全连接网络的成本。
-
利用2D TP扩展性的方式。2D TP的额外可扩展性可以有多种利用方式。例如,Llama 3在一个16K GPU集群上使用3D并行(DP+PP+TP)进行训练,其中仅使用了8路1D TP【索引8,The Llama 3 herd of models,2024,arXiv】。如果使用128路2D TP替代8路1D TP,我们可以构建一个16倍大的集群,拥有256K个芯片。这不仅实现了更大规模,还降低了DP的通信成本。因为现在每个芯片持有1/128的权重矩阵分片而不是1/8,每个芯片的DP流量比之前小了16倍。
-
利用2D TP优化现有集群。或者,我们可以在保持芯片总数不变的情况下,应用128路2D TP,并将DP和PP的并行度各降低4倍。在这个集群中,每个芯片的DP流量比使用8路1D TP的集群小64倍。由于通信成本降低和流水线阶段减少,我们可以预期计算利用率会比原始集群更高。
2.3 2D GeMM 算法
2.3.1 通用方面
-
2D GeMM的核心作用。2D GeMM是分布式DNN训练中2D TP的核心。尽管存在多种2D GeMM算法,但它们有许多共同之处。假设我们计算输出矩阵 $C=AB$,即左输入 $A$ 和右输入 $B$ 的乘积。所有三个矩阵都在行和列维度上被划分为分片,并分配给一个2D芯片网格。也就是说,在一个有 $N_r$ 行和 $N_c$ 列的网格中,$C$ 被划分为分片 $C_{00} \dots C_{(N_r-1)(N_c-1)}$。然后,$C_{ij}$ 存储在网格的第 $i$ 行和第 $j$ 列的芯片 $(i,j)$ 中。$A$ 和 $B$ 也同样处理。
-
2D GeMM的三种数据流。2D GeMM算法可以以三种可能的数据流【索引30,SUMMA: Scalable universal matrix multiplication algorithm,1997,Concurrency: Practice and Experience】、【索引36,The general matrix multiplyadd operation on 2D torus,2006,IPDPS】进行GeMM计算,如图1所示。在每种数据流中,三个矩阵($A$, $B$, $C$)中的一个的分片保持在其芯片中不动(stationary),而另外两个的分片则在垂直(行间)或水平(列间)方向上进行通信。
-
具体数据流介绍。在输出固定(OS)数据流(图1a)中,输出 $C$ 是固定的, $A$ 的每个分片被传输到同一行的芯片(列间),$B$ 的每个分片被传输到同一列的芯片(行间),结果是 $C=AB$。在左侧固定(LS)数据流(图1b)中,左输入 $A$ 是固定的,右输入 $B$ 的每个分片被传输到同一列的芯片,输出 $C$ 的每个分片被传输并累加到同一行的芯片中,结果是 $C=AB^T$。最后,右侧固定(RS)数据流(图1c)是LS数据流的对称版本,结果是 $C=A^TB$。LS和RS数据流分别等同于脉动阵列(systolic arrays)【索引24,A Systematic Methodology for Characterizing Scalability of DNN Accelerators using SCALE-Sim,2020,ISPASS】中的输入固定和权重固定数据流。
-
2D GeMM的通信流量成本。2D GeMM的通信流量取决于数据流和网格的形状。假设我们有一个 $N_r$ 行和 $N_c$ 列的网格,两个矩阵 $M_r$ 和 $M_c$ 分别在行间(垂直)和列间(水平)方向流动。每个矩阵分片必须被通信到其所在行或列的所有其他芯片。我们可以如下计算传输分片所需的时间(不包括同步和其他开销)。对于行间(垂直)传输,时间是 $(N_r - 1) \times size(M_r)/(N_r \times N_c)/BW_{row}$;对于列间(水平)传输,时间是 $(N_c - 1) \times size(M_c)/(N_r \times N_c)/BW_{col}$。这里,$BW_{row}$ 是行间链路带宽,$BW_{col}$ 是列间链路带宽。我们将这些时间称为流量成本。2D GeMM的流量成本是行间和列间流量成本的最大值,因为我们需要等待两个方向的通信都完成。
-
最小化流量成本的条件。如果 $BW_{row} = BW_{col}$,当 $(N_r - 1)/(N_c - 1) = size(M_c) / size(M_r)$ 时,流量成本最小化。然而,最小化流量成本的 $N_r$ 和 $N_c$ 值可能不会最小化总通信成本,因为还有同步和其他开销。
图1:2D GeMM算法的三种数据流。
2.3.2 Cannon算法
-
Cannon算法的特点和局限性。Cannon算法【索引4,A cellular computer to implement the Kalman filter algorithm,1969,PhD dissertation, Montana State University】是最早的2D GeMM算法之一。它只适用于方形网格。在计算开始前,矩阵分片以一种倾斜(skewed)的方式进行移位。然后,Cannon算法以脉动方式移位分片(使用SendRecv通信操作),同时计算部分乘积。Cannon算法是脉动阵列【索引14,In-datacenter performance analysis of a tensor processing unit,2017,ISCA】和3D/2.5D GeMM算法【索引1,A three-dimensional approach to parallel matrix multiplication,1995,IBM Journal of Research and Development】、【索引28,Communication-optimal parallel 2.5 D matrix multiplication and LU factorization algorithms,2011,Euro-Par】的基础。
-
Cannon算法的通信流量问题。Cannon算法的主要限制是其流量成本高于其他2D GeMM算法,原因有二。首先,开始时对矩阵分片进行的倾斜操作会产生额外的通信流量,这在其他算法中是不需要的。其次,虽然不同的网格形状会改变2D GeMM算法的流量成本,但Cannon算法只适用于方形网格。因此,当矩阵形状严重不平衡时,Cannon算法会比其他2D GeMM算法产生更高的流量成本。
2.3.3 SUMMA 算法
-
SUMMA算法的优势。SUMMA【索引30,SUMMA: Scalable universal matrix multiplication algorithm,1997,Concurrency: Practice and Experience】解决了Cannon的两个局限性,即它不需要倾斜矩阵分片,并且可以支持任何网格形状。图2a展示了SUMMA针对不同数据流的伪代码。SUMMA将矩阵划分为 $K \times K$ 个分片,其中 $K$ 是 $N_r$ 和 $N_c$ 的公倍数。然后,它在一个包含 $K$ 次迭代的循环中执行通信和计算。SUMMA在同一行或列的分片上使用广播(bcast)和归约(reduce)通信操作。由于2D环面的每一行或每一列都以环形拓扑连接,SUMMA运行环形广播和归约算法。
-
SUMMA LS算法示例。例如,假设我们在一个 $P \times Q$ 的网格上运行SUMMA LS算法(图2a,中)。$B_{ij}$ 是位于网格第 $i$ 行第 $j$ 列芯片上的 $B$ 的一个分片。在第 $k$ 次迭代中,发生以下情况。首先,对于所有的 $i$,芯片 $(i,k)$ 将它们的 $B_{ik}$ 分片广播到第 $k$ 列的芯片。然后,在所有芯片中,本地的 $A_{ij}$ 分片与 $B_{ik}$ 的转置相乘,产生部分结果 $C'$。最后,对于每一行 $i$,该行所有芯片中的部分结果 $C'$ 被归约到第 $k$ 列芯片 $(i,k)$ 的分片 $C_{ik}$ 中。
-
SUMMA算法的低效性。不幸的是,SUMMA的一对多广播和多对一归约通信在连接有高带宽网络链路的大型网格中是低效的。考虑一个广播操作,如图3左侧所示。为了在传输过程中利用环(一行或一列)中的所有链路,要广播的分片被分解成 $D$ 个数据包,以细粒度传输的方式在环上流式传输。流式传输分 $D+P-1$ 个流水线阶段完成。这里有两个开销来源:流水线中的气泡(bubbles)和同步。每个链路都会经历 $P-1$ 个气泡——一些在传输开始时,一些在结束时。此外,每个流水线阶段都需要一次同步,因此广播需要 $D+P-1$ 次同步。归约操作具有相同的通信模式,并遭受相同的开销。由于SUMMA中有 $K$ 次迭代,总同步开销以 $O(K^2)$ 的速度增长。
图2:SUMMA和Collective 2D GeMM算法针对三种数据流的伪代码。代码中,下标为row的表达式表示同一列内的行间通信,而下标为col的表达式表示同一行内的列间通信。A_ij是位于网格第i行第j列芯片上的A的分片。
2.3.4 Collective 2D GeMM
-
Collective 2D GeMM的引入。为了避免SUMMA的开销,一种流行的方法是使用AllGather (AG) 和 ReduceScatter (RdS) 【索引25,Parallel matrix multiplication: A systematic journey,2016,SIAM Journal on Scientific Computing】、【索引35,GSPMD: general and scalable parallelization for ML computation graphs,2021,arXiv】集体通信操作来执行2D GeMM。AG涉及所有按列或按行广播的并行执行。同样,RdS是所有按列或按行归约操作的并行执行。我们将这种方法称为Collective 2D GeMM。
-
算法对应关系。对于SUMMA算法的每种数据流,都有一个对应的Collective 2D GeMM算法【索引25,Parallel matrix multiplication: A systematic journey,2016,SIAM Journal on Scientific Computing】。图2b展示了三种数据流的算法。例如,在LS数据流中(图2b,中),SUMMA LS(图2a,中)中的多个
bcast_col
操作被合并成一个单一的allgather_col
操作,SUMMA LS中的多个reduce_row
操作被合并成一个单一的reducescatter_row
操作。计算也是在一步中完成,而不是像SUMMA那样在 $K$ 次迭代中完成。 -
Collective通信的效率优势。使用集体AG/RdS通信解决了bcast/reduce操作的低效问题。考虑一个在 $P$ 个芯片组成的环(一行或一列)中的AG操作,如图3右侧所示。在 $P-1$ 个步骤中的每一步,每个链路都向邻居传输一个完整的分片。因此,与bcast相比,AG消除了流水线气泡,传输了更大的数据包,并调用了更少的同步。RdS与reduce相比也是如此。此外,由于Collective 2D GeMM在每个方向上只调用一次AG或RdS,其总同步开销以 $O(P)$ 的速度增长。因此,AG/RdS比bcast/reduce更高效,并能获得更高的带宽利用率。
-
Collective 2D GeMM的主要局限性。Collective 2D GeMM的主要局限性是它不能将通信与计算重叠。由于没有循环迭代,并且计算与通信操作之间存在真实依赖关系,因此无法应用软件流水线来重叠它们。
-
Wang等人提出的部分解决方案。Wang等人【索引34,Overlap communication with dependent computation via decomposition in large deep learning models,2022,ASPLOS】通过将一个方向上的集体通信拆分为多个SendRecv通信,提出了一个部分解决方案。然后,通过应用软件流水线,将SendRecv通信与部分GeMM计算重叠。这种2D GeMM等同于FSDP【索引37,Pytorch FSDP: experiences on scaling fully sharded data parallel,2023,arXiv】和1D TP的组合。然而,这个解决方案只能划分和重叠一个方向上的通信操作;它不能重叠另一个方向上的通信操作。要在两个方向上都将AG/RdS操作划分为多个SendRecv操作,就需要使用Cannon算法,其局限性已在2.3.2节讨论。
图3:在一个P芯片环中,广播(Broadcast)和全收集(AllGather)操作的通信模式。
A2 方法细节
本文为2D TP做出了两项贡献。首先,提出了一种新的2D GeMM算法,解决了现有2D GeMM算法的局限性。其次,设计了一个LLM自动调优器,为LLM训练找到一个高效的2D TP配置。该LLM自动调优器优化了数据流、网格形状和通信粒度的配置。
我们提出的2D GeMM算法称为MeshSlice。图4可视化了先前算法的时间线,并与MeshSlice进行了比较。该图显示了计算、行间通信和列间通信的时间进展。Cannon需要进行倾斜操作且只支持方形网格形状,因此其流量高于其他算法,增加了总执行时间。SUMMA使用低效的bcast/reduce通信操作,由于细粒度的数据包而产生流水线气泡和同步开销。Collective算法不将集体通信与计算重叠。Wang的算法只划分了一个方向上的集体通信,因此另一个方向上的通信没有被重叠。最后,MeshSlice能够在两个方向上都将通信与计算重叠,从而实现最快的执行速度。
3.1 MeshSlice 2D GeMM 算法
-
MeshSlice算法的三个特性。MeshSlice算法具有三个特点。首先,它在两个方向上都划分并重叠通信操作。其次,它使用高效的AG和RdS通信操作,而不是bcast、reduce或SendRecv操作。最后,它支持任何网格形状和灵活的通信粒度。
-
MeshSlice的核心思想。为了将通信与计算重叠,我们需要将集体通信划分为更小的通信,并应用软件流水线。划分集体AG/RdS操作有两种现有方法。第一种是将其分解为多个SendRecv通信。Wang的算法【索引34,Overlap communication with dependent computation via decomposition in large deep learning models,2022,ASPLOS】将此方法应用于单个方向,而将其应用于两个方向则需要使用Cannon算法【索引4,A cellular computer to implement the Kalman filter algorithm,1969,PhD dissertation, Montana State University】。第二种方法是将AG和RdS分别划分为多个bcast和reduce操作。这种方法导致了SUMMA【索引30,SUMMA: Scalable universal matrix multiplication algorithm,1997,Concurrency: Practice and Experience】。
-
MeshSlice的新型划分方法。相反,MeshSlice引入了一种新的划分方法:将AG和RdS操作划分为部分的AG和RdS操作。该算法的核心是将矩阵分片切成 $K$ 个子分片。在一个 $K$ 次迭代的循环中,我们对一个子分片应用一个部分的AG或RdS,并计算一个部分的GeMM。当所有 $K$ 个子分片处理完毕后,算法完成。
-
MeshSlice算法伪代码概述。图5显示了在网格的每个芯片 $(i, j)$ 上运行的MeshSlice算法针对三种数据流的伪代码。每个算法都在一个有 $K$ 次迭代的循环中执行。这里我们给出算法的高层直觉;在后续部分,我们将详细解释这些操作。在OS算法中(图5,左),对于每个 $k=0 \dots K-1$,发生以下情况。首先,每个芯片使用
slice_col
沿列维度切分其本地矩阵分片 $A_{ij}$ 以获取其本地的第 $k$ 个子分片 $A_k$。类似地,它使用slice_row
沿行维度切分其本地的 $B_{ij}$ 以获取其本地的第 $k$ 个子分片 $B_k$。然后,每个芯片分别使用allgather_row
和allgather_col
操作,从同一行的所有芯片收集 $A_k$ 子分片,从同一列的所有芯片收集 $B_k$ 子分片。最后,每个芯片用收集到的子分片计算部分GeMM,并将结果累加到其本地输出分片 $C_{ij}$ 中。关键是,通过循环的软件流水线,一次迭代中的部分GeMM操作与另一次迭代中的AG和切片操作重叠。 -
LS和RS数据流的切片方法。我们可以对LS和RS数据流应用类似的切片方法。在LS数据流的每次迭代中(图5,中),$A_{ij}$ 和 $B_{ij}$ 分别沿其行和列维度进行切片。然后,同一列芯片中的第 $k$ 个子分片 $B_k$ 被
allgather
到 $B'$。接着,计算部分乘积结果 $C' = A_{ij}(B')^T$。最后,$C'$ 被reduce-scatter
到同一行芯片中的第 $k$ 个子分片 $C_k$。RS数据流(图5,右)遵循类似的流程。 -
MeshSlice的灵活性和挑战。与Collective和SUMMA一样,MeshSlice可以应用于任何形状的网格。此外,我们可以控制切片数 $K$ 来调整通信的粒度。一个小的 $K$(粗粒度)会导致软件流水线期间产生较大的未重叠的前导(prologue)和收尾(epilogue)。一个大的 $K$(细粒度)会减小前导和收尾的大小,但会因执行更多通信操作而增加总同步开销。考虑到这种权衡,对于不同的2D GeMM配置和硬件架构,存在不同的最优 $K$ 值。
-
MeshSlice的主要挑战。MeshSlice的主要挑战是设计一个正确且高效的切片机制。这不是一个简单的问题:大多数任意的切片方式都会导致计算错误。在下文中,我们描述了如何实现正确的切片机制。
图5:MeshSlice 2D GeMM算法在三种数据流中的伪代码。
3.1.1 MeshSlice算法的数学描述
- 2D GeMM计算的数学表示。假设我们在一个形状为 $N_r \times N_c$ 的网格上,以OS数据流计算一个2D GeMM,其中 $A \in R^{M \times K}$ 和 $B \in R^{K \times N}$。在每个芯片 $(i, j)$ 中,计算结果将是分片 $C_{ij} \in R^{M/N_r \times N/N_c}$,它是 $A_{i*} = \text{AllGather}(A_{ij}) \in R^{M/N_r \times K}$ 和 $B_{*j} = \text{AllGather}(B_{ij}) \in R^{K \times N/N_c}$ 的乘积。图6描绘了对于给定的 $i$ 和 $j$ 的 $A_{i*}$ 和 $B_{*j}$。然后,$A_{i*}$ 被分解为 $K$ 个列向量 $a_{i0}, \dots, a_{i(K-1)} \in R^{M/N_r}$,$B_{*j}$ 被分解为 $K$ 个行向量 $b_{0j}, \dots, b_{(K-1)j} \in R^{1 \times N/N_c}$。图中高亮并标记了其中的四个向量:$a_{i0}$, $a_{i(K-1)}$, $b_{0j}$, 和 $b_{(K-1)j}$。最后,$C_{ij} = A_{i*}B_{*j}$ 等价于 $K$ 个列向量和行向量的外积之和,如下所示。
图6:在2D GeMM中生成一个输出分片。
- MeshSlice的计算切分。我们的算法在一个包含 $S$ 次迭代的循环中对这个计算进行切片,其中一次迭代计算每 $S$ 个外积中的一个。对于 $S=3$,图7显示了循环的第一次迭代(紫色)和第二次迭代(黄色)访问的向量。算法1展示了这个算法。例如,第一次迭代将 {$a_{i0}b_{0j} + a_{iS}b_{Sj} + a_{i2S}b_{2Sj} + \dots$} 累加到输出分片 $C_{ij}$ 中。这次迭代访问的列向量 {$a_{i0}, a_{iS}, \dots$} 是通过AG从第 $i$ 行芯片的分片 {$A_{i0}, A_{i1}, \dots$} 中收集的。同样,行向量 {$b_{0j}, b_{Sj}, \dots$} 是从第 $j$ 列芯片的分片 {$B_{0j}, B_{1j}, \dots$} 中收集的。
图7:在MeshSlice中使用S=3对分片进行切片。
3.1.2 MeshSlice算法的详细实现
- 分片与向量的对应关系。在本节中,我们详细描述图5中呈现的MeshSlice算法,并证明其部分GeMM计算与算法1中的切片GeMM计算是相同的。首先,请注意 $A_{ij}$ 包含 $K/N_c$ 个 $A_{i*}$ 的列向量,而 $B_{ij}$ 包含 $K/N_r$ 个 $B_{*j}$ 的行向量,如下所示:
- 子分片的生成。在MeshSlice OS算法(图5,左)的第 $k$ 次迭代中,对 $A_{ij}$ 应用
slice_col
会收集 $A_{ij}$ 中的每 $S$ 个列向量,对 $B_{ij}$ 应用slice_row
会收集 $B_{ij}$ 中的每 $S$ 个行向量。我们称 $A_k$ 和 $B_k$ 分别为 $A_{ij}$ 和 $B_{ij}$ 的第 $k$ 个子分片。
- 子分片的全收集与计算。如果我们从网格同一行的所有芯片中AllGather(
allgather_row
)$A_k$ 子分片,并从同一列的所有芯片中AllGather(allgather_col
)$B_k$ 子分片,我们将得到以下 $A'$ 和 $B'$ 矩阵。
-
计算等价性与优化。这些 $A'$ 和 $B'$ 向量就是图7中所示的向量。然后,计算 $C_{ij} += A'B'$ 在数学上等同于计算算法1的第 $k$ 次迭代。
-
内存访问优化。我们的切片操作可能导致非连续的内存访问。例如,
slice_col
操作访问的列向量 $a_{i(jK/N_c+k)}, a_{i(jK/N_c+k+S)}, \dots$ 在内存中不是连续的。这在大多数内存子系统中是低效的。因此,我们进一步优化切片操作(slice_col
和slice_row
),使得 $v_c \in R^{M/N_r}$ 的列向量变成矩阵 $M_c \in R^{M/N_r \times B}$,其中 $B$ 是一个依赖于架构的块大小(例如,缓存行大小)。同时,$v_r \in R^{1 \times N/N_c}$ 的行向量变成矩阵 $M_r \in R^{B \times N/N_c}$。这种设计确保了连续的内存访问。 -
分块列切片算法示例。例如,分块列切片算法
blocked_slice_col
如算法2所示,其中rows
和cols
是本地分片的维度。块大小 $B$ 由硬件架构决定。例如,由于TPU通过2D的128x8块访问其内存【索引10,Cloud TPU performance guide,2024,Google】,我们为TPU设置 $B=8$。用户随后可以从cols/B
的因子中选择任何切片数 $S$。
3.2 MeshSlice LLM 自动调优器
-
优化参数的挑战。MeshSlice 2D GeMM算法有几个决定其通信成本和效率的参数。首先,网格的形状决定了流量成本,如2.3.1节所讨论。其次,将矩阵划分为网格中芯片的分片(即分片方式)有多种方法,不同的划分方式会改变2D GeMM的数据流。最后,我们需要确定MeshSlice的切片数 $K$,它影响同步开销和通信重叠。
-
手动优化的困难。找到这些参数的最优配置是一个难题,通常依赖于人工输入【索引35,GSPMD: general and scalable parallelization for ML computation graphs,2021,arXiv】。手动寻找最优参数需要系统架构的专业知识和昂贵的反复试验。
-
MeshSlice LLM自动调优器的设计。为了解决这个问题,我们设计了MeshSlice LLM自动调优器,它可以为LLM训练找到高效的参数配置。自动调优器的输入是LLM架构、训练超参数(例如,批量大小和输入序列长度)以及芯片集群可能的2D网格形状。自动调优器分两个阶段运行。首先,它为FC层确定高效的数据流和LLM张量的分片方式。其次,它使用分析成本模型联合优化集群的网格形状和每个FC层的切片数 $K$。
3.2.1 阶段1:数据流和分片
-
分片优化的复杂性。2D TP的第一个优化问题是选择正确的分片方式。给定一个网格,张量的分片方式是将网格维度映射到要划分的张量维度的映射。例如,给定一个2D网格和一个4D张量,有 $4P2=12$ 种可能的分片方式,因为我们需要选择两个张量维度分别在芯片的行和列之间进行切分。
-
自动寻找最优分片的难点。自动寻找最优分片方式很困难,原因有二。首先,有很多选择。例如,考虑两个4D张量的2D GeMM。由于我们有两个输入张量和一个输出张量,因此有 $(4P2)^3 = 1728$ 种可能的分片组合。其次,估计每种分片组合的性能很困难。一种分片方式的改变可能会导致不同的数据流,需要可能的重新分片(由于转置),并导致不同的计算效率。更糟糕的是,我们不仅要考虑前向计算的分片,还要考虑训练期间反向传播计算的分片。
-
先选择数据流再确定分片。在这项工作中,我们反向处理这个问题:自动调优器首先选择数据流,然后确定每个张量的分片方式。以2D GeMM $C=AB$ 为例,它将输入 $A$ 和权重 $B$ 相乘以计算输出 $C$。前向传播计算有三种2D GeMM数据流,如表1的Forward列所示。自动调优器选择使三个矩阵中最大的那个保持固定的数据流。例如,如果输出 $C$ 是最大的矩阵,则选择 $C$-stn。
-
处理前向和反向传播。为了训练一个DNN层,有两个计算过程:前向传播计算输出,和反向传播计算梯度。一个前向传播 $C=AB$ 会产生两个反向传播计算,即反向数据和反向权重。反向数据计算计算输入梯度 $A' = C'B^T$,作为输出梯度 $C'$ 和 $B$ 的转置的乘积。反向权重计算计算权重梯度 $B' = A^TC'$,通过将 $A$ 的转置与 $C'$ 相乘。每个反向数据和反向权重计算的计算和通信需求几乎与前向传播相同。
-
为反向传播选择数据流。给定前向传播的数据流,自动调优器从表1的同一行中为反向传播计算选择数据流。这确保了以下属性。首先,最大的矩阵在所有三个计算(前向和两个反向)中都保持固定。其次,每个矩阵及其梯度矩阵在所有三个计算中都沿相同方向流动。最后,没有矩阵需要被转置来计算反向传播。
-
数据流选择示例。举个例子,假设我们为前向传播选择 $B$-stn 数据流($C = OS(A, B^T)$)。这里 $B^T$ 在其初始化期间被静态转置。在前向计算中,$A$ 水平流动,$B^T$ 垂直流动。反向数据计算的数据流是 $A' = OS(C', B^T)$。与前向传播一样,$B^T$ 是固定的,$C'$ 水平流动,$B^T$ 垂直流动。反向权重计算的数据流是 $B'^T = OS(C', A)$,这使得 $A$ 固定,$C'$ 水平流动,$B'^T$ 垂直流动。总的来说,在所有计算中,$B$ 和 $B'$ 保持固定,因为它们是最大的矩阵,$A$ 和 $A'$ 水平流动,$C^T$ 和 $C'^T$ 垂直流动。
-
转置数据流和启发式选择。对于表1的每一行,我们可以转置所有矩阵并翻转两个非固定矩阵的数据流方向,以获得相应的转置数据流。例如,$C = OS(A, B)$ 的转置版本是 $C^T = OS(B^T, A^T)$。因此,对于每个层,有两种数据流选择(非转置和转置)可以使最大矩阵保持固定。为一个 $L$ 层的神经网络寻找绝对最优的数据流选择是一个有 $2^L$ 种可能性的搜索问题。因此,MeshSlice LLM自动调优器使用一个简单的启发式方法:对于每个层,默认选择非转置数据流,除非该层的输入 $A$ 需要被转置以保持非转置数据流。在大多数LLM中,这种启发式方法消除了层与层之间的转置。
表1:2D GeMM C=AB的三种数据流。每种数据流使A固定(A-stn)、B固定(B-stn)或C固定(C-stn)。 -
分片方式的确定。一旦自动调优器选择了数据流,三个矩阵的分片方式就自动确定了。矩阵的行在芯片的行之间分片,矩阵的列在芯片的列之间分片。
-
LLM中4D张量的分片。在实践中,LLM使用形状为 $(B, S, H, D)$ 的4D张量,其中 $B$ 是批量大小,$S$ 是序列长度,$H$ 是注意力头的数量,$D$ 是每个头的隐藏维度。在FC层中,4D张量被重塑为 $(B \times S, H \times D)$ 维度的2D矩阵。MeshSlice遵循划分2D矩阵两个最外层维度的简单原则。因此,4D张量的 $(B \times S)$ 维度在网格行之间分片,$(H \times D)$ 维度在网格列之间分片。
-
其他操作的分片。除了FC层,LLM网络还有许多其他操作。这些其他操作的分片对性能影响最小,因为它们不产生通信流量【索引16,Reducing activation recomputation in large transformer models,2023,MLSys】。因此,一旦FC层张量的分片确定,我们让其他张量遵循相同的分片方式以避免重新分片的流量。
3.2.2 阶段2:网格形状和切片数
-
优化目标。此阶段配置MeshSlice的两个剩余参数:网格形状和切片数。为此,我们设计了分析成本模型,为每个配置估算GeMM执行时间。然后,我们使用这些成本模型来协同优化这两个参数。
-
通信成本模型。我们通过离线测量小型ML加速器集群中的同步延迟、网络带宽和通信操作启动开销来构建通信的分析成本模型。集体通信操作的成本定义如下:
公式5 -
模型解释。这里 $T_{launch}$ 是操作启动的开销,$P$ 是行或列中的芯片数量,$L_{sync}$ 是同步延迟,$size(shard)$ 是要传输的分片大小,$BW$ 是测得的链路带宽。这个线性模型非常适合在行或列环上的AG/RdS通信。这是因为在环上的AG或RdS(图3,右)中,分片传输是同步的,没有网络竞争。
-
计算成本模型。我们的分析计算成本模型将本地GeMM的总FLOP数除以ML加速器的有效FLOPS吞吐量。有效FLOPS是通过在单个加速器芯片上对几个GeMM操作进行基准测试来测量的。这个计算模型对于LLM训练足够准确,因为LLM训练中的大多数GeMM都足够大,可以完全饱和ML加速器的计算吞吐量。为了提高准确性,可以在单个加速器芯片上测量计算执行时间,而不是使用分析模型。
-
执行时间估算。对于每个FC层,自动调优器将MeshSlice算法的计算加通信执行时间分解为三部分:前导(prologue)、稳态(steady-state)和收尾(epilogue)。前导和收尾分别是第一个和最后一个循环迭代中无法通过软件流水线重叠的操作。例如,在我们的图5的OS算法中,第一次迭代中的两个all-gather操作构成了前导,最后一次迭代中的部分GeMM计算构成了收尾。我们假设一个方向上的通信可以与另一个方向上的通信并行执行。因此,前导时间是两个AG操作中最长的时间,每次迭代的稳态时间是三个操作中最长的时间:两个AG操作和部分GeMM。$K$ 次迭代的总估算执行时间是:前导 + $(K-1) \times$ 每次迭代的稳态时间 + 收尾。
-
协同优化过程。利用成本模型,自动调优器使用穷举搜索来协同优化集群的网格形状和每个FC层 $l$ 的切片数 $K_l$。对于每个可能的网格形状,自动调优器通过搜索所有可能的 $K$ 值来为每个FC层 $l$ 调整 $K_l$。由于不同FC层的最优 $K_l$ 值相互独立,自动调优器一次只优化一个层的 $K_l$ 值。最后,它选择执行时间最短的配置。网格形状和切片数的搜索空间很小,因为它们只有少数可能的整数选择。因此,由于搜索空间小和分析成本模型简单,自动调优器在几秒钟内就能完成。
A4 实验环境
-
模型架构:
- GPT-3: 175B参数,最流行的LLM之一。
- Megatron-NLG: 530B参数,需要更大规模的分布式训练。
- 两种模型均基于Transformer架构,由多个相同的神经网络块堆叠而成,每个块包含多头注意力和前馈网络两个子网络,共四个全连接(FC)层。实验中只对FC层进行详细模拟,其他层在真实TPUv4上进行基准测试,因为它们不产生通信成本。
-
硬件配置:
- 模拟环境:
- 加速器: 模拟Google TPUv4。每个TPU节点包含一个TPU和一个主机。TPU有两个核心和一个NIC共享HBM内存。每个核心有64MB暂存内存和四个128x128脉动阵列。
- 网络: NIC连接到一个带有四个ICI链路的路由器,将TPU连接成2D环面网络。
- 真实硬件环境:
- 使用Google Cloud上的4x4 TPUv4集群进行真实硬件测试和模拟器校准。
- 模拟环境:
-
软件与模拟器配置:
- 模拟器: 使用定制化的SST模拟器。修改了SST的rdmaNic模拟器以模拟NIC,并使用DRAMSim3模拟HBM内存。
- 实现: 在Jax中实现了MeshSlice,使用了
shard_map
进行计算划分,dynamic_slice
实现分块切片算法。 - 基线算法:
- 2D基线: Cannon, SUMMA, Collective 2D GeMM, Wang's algorithm。
- 1D基线: 1D TP (Sequence Parallelism), Fully-Sharded Data Parallelism (FSDP)。
- 模拟器校准: 通过在4个真实TPUv4芯片组成的环上进行基准测试来校准模拟器参数(频率、带宽、延迟等),确保模拟与真实硬件的执行时间差异小于10%。
-
成本模型构建:
- 通信成本模型: 在2芯片和4芯片的TPUv4集群上,通过对不同分片大小(8KB到512MB)的集体通信操作进行基准测试,使用线性回归拟合出启动开销、同步延迟和带宽等参数。
- 计算成本模型: 通过在单个TPUv4上分析几个GeMM操作的性能来测量有效计算吞吐量。
A5 实验结果
5.1 分布式GeMM算法性能
-
弱扩展性能 (Fig 9):
- 实验内容: 比较MeshSlice与四种2D基线和两种1D基线在不同集群规模(16到256个芯片)下的弱扩展性能。批量大小随芯片数量增加而增加(批量大小 = 芯片数 / 2),序列长度为2048。
- 实验结果: MeshSlice在所有情况下都是最快的算法,并随着芯片数量增加保持了良好的效率。在256路并行下,对于GPT-3和Megatron模型,MeshSlice分别比最先进的Wang算法快13.8%和26.0%。计入非FC层的性能后,端到端训练速度分别提升了12.0%和23.4%。从16路扩展到256路,MeshSlice的效率仅分别下降了16.8%(GPT-3)和5.8%(Megatron-NLG)。
- 分析结论: 其他算法(SUMMA、Cannon、1D TP、FSDP)在大规模集群中效率低下,甚至慢于不进行通信计算重叠的Collective算法。
图9:在弱扩展情况下,不同分布式GeMM算法在FC层上的FLOP利用率。上图为训练GPT-3,下图为训练Megatron。
-
通信成本分解 (Fig 10):
- 实验内容: 在256芯片集群上,分析各算法在FC层训练中的总通信时间(相对于其计算时间)的构成,包括启动(launch)、传输(transfer)和同步(sync)开销。
- 实验结果:
- Cannon: 由于需要方形网格和额外的倾斜操作,流量成本高。
- SUMMA: 同步开销巨大,且随网格规模二次方增长,导致其在大规模网格中效率极低。
- 1D方法 (1D TP/FSDP): 固有的高流量和只能利用环形拓扑(2个链路而非4个)导致通信时间长。
- Collective: 总通信时间最短,但无法与计算重叠。
- MeshSlice/Wang: 通信时间略高于Collective(因多次启动或同步),但MeshSlice能有效隐藏大部分通信时间。
- 分析结论: 该分析揭示了各基线算法的性能瓶颈,并证实了MeshSlice通过重叠通信来隐藏大部分通信开销的有效性。
图10:在256芯片集群中,不同算法在FC层中的通信时间分解(相对于各自的计算时间)。上图为GPT-3,下图为Megatron。
-
强扩展性能 (Fig 12):
- 实验内容: 比较各算法在固定批量大小(32)下的强扩展性能。
- 实验结果: 在16芯片的计算密集型场景下,所有算法效率都很高,差异不大。在256芯片的通信密集型场景下,通信成本占主导,MeshSlice通过重叠通信带来的增益减小,其利用率接近Collective和Wang,但仍优于1D TP和SUMMA。
- 分析结论: MeshSlice无论在计算密集型还是通信密集型场景下都是一个安全且高效的选择。
图12:在强扩展情况下,不同分布式GeMM算法在FC层上的FLOP利用率。上图为训练GPT-3,下图为训练Megatron。
-
不同矩阵形状下的性能 (Fig 11):
- 实验内容: 评估各2D GeMM算法在GPT-3和Megatron-NLG训练过程中遇到的16种不同形状的GeMM操作上的性能。
- 实验结果: MeshSlice在所有16种GeMM变体上都表现出一致的性能优势,平均比Collective和Wang分别快27.8%和19.1%。
- 分析结论: MeshSlice的优势在执行时间较长的大型GeMM操作上更为明显。
图11:在256芯片集群中,不同2D GeMM算法在不同矩阵形状(M,N,K)下的FLOP利用率。
5.2 LLM自动调优器和成本模型
-
数据流优化的有效性 (Table 2):
- 实验内容: 比较使用默认数据流(B-stn)和经过MeshSlice自动调优器优化后的数据流的FC层训练FLOP利用率。
- 实验结果: 数据流优化为GPT-3和Megatron分别带来了21.2%和5.1%的性能提升。
- 分析结论: 自动调优器选择的数据流能显著减少不必要的通信开销,尤其是在计算量相对较小的模型(如GPT-3)上,效果更为显著。
表2:在256芯片集群中,有无MeshSlice数据流优化的FC层训练FLOP利用率对比。
-
成本模型的准确性 (Fig 13, 14):
- 实验内容: 比较自动调优器的分析成本模型预测的FLOP利用率与模拟器得到的实际FLOP利用率,评估其在选择最优网格形状和切片数(K)方面的准确性。
- 实验结果:
- 网格形状 (Fig 13): 成本模型准确地预测出了能带来最高性能的最优网格形状(相比非最优形状,性能可提升2.4倍)。
- 切片数K (Fig 14): 成本模型找到的最优切片数与通过模拟找到的最优值完全一致。
- 分析结论: MeshSlice LLM自动调优器使用的分析成本模型虽然简单,但足够准确,能够可靠地找到高性能的MeshSlice配置。
图13:自动调优器成本模型估算的FLOP利用率与模拟得到的利用率对比,针对256芯片集群的不同网格形状。
图14:自动调优器成本模型估算的FLOP利用率与模拟得到的利用率对比,针对32x8网格中的不同切片数K。
5.3 在真实硬件上的MeshSlice性能
-
在4x4 TPU网格上的性能比较 (Table 3):
- 实验内容: 在一个真实的4x4 TPUv4集群上运行Collective、Wang和MeshSlice算法。当前TPUv4不支持AG/RdS操作与计算的重叠。
- 实验结果: 在无法重叠的情况下,MeshSlice比Collective慢约4.5%,这部分开销主要来自更细粒度的部分GeMM和部分AG/RdS操作,而切片操作本身的开销仅为1.3%。同时,预估如果支持重叠,MeshSlice将比Collective在GPT-3和Megatron上分别快38.6%和32.8%。
- 分析结论: MeshSlice的切片机制本身非常高效,开销极小。一旦硬件支持其所需的异步通信,将能释放巨大的性能潜力。
表3:在真实的4x4 TPUv4集群中,2D GeMM算法在FC层上的FLOP利用率。MeshSlice-Overlap显示了如果AG/RdS与计算重叠时的估算值。
-
通信成本模型的验证 (Fig 15):
- 实验内容: 将通信成本模型估算的8个不同FC层的总通信时间与在真实硬件上测量的通信时间进行比较。
- 实验结果: 模型的估算值与实测值非常吻合,平均误差仅为5.1%。
- 分析结论: 验证了通信成本模型的高准确性,这得益于行/列通信中无网络竞争的特性,为自动调优器的可靠性提供了有力支撑。
图15:MeshSlice中8个不同FC层的估算和实测总通信时间对比。
A6 结论
本文提出了MeshSlice算法,一种为分布式DNN训练设计的高效2D张量并行方法。MeshSlice通过将通信操作切分为多个部分,并利用软件流水线在行列两个维度上都实现了通信与计算的高效重叠,从而解决了现有2D GeMM算法(如Cannon、SUMMA、Collective GeMM)存在的流量大、同步开销高或无法重叠等问题。此外,本文还设计了MeshSlice LLM自动调优器,该工具能够通过选择高效的数据流,并利用精确的成本模型协同优化加速器网格形状和通信粒度,从而自动化了复杂的性能调优过程。
在模拟的256个TPUv4集群上的评估表明,MeshSlice在训练GPT-3和Megatron-NLG模型时,端到端性能分别比当前最先进的算法快12.0%和23.4%。
未来的工作方向包括:
1. 扩展到GPU集群:通过在GPU集群的物理网络上构建逻辑网格,将MeshSlice应用于更广泛的硬件平台,并相应调整自动调优器以考虑网络竞争。
2. 应用于推理场景:调整MeshSlice及其自动调优器以适应推理任务中更可能出现的内存瓶颈。
3. 支持其他DNN层:将MeshSlice应用于可转换为GeMM操作的其他层,如卷积层,或用于优化GNN中的2D分布式稀疏GeMM。
4. 结合专家混合(MoE)模型:将MeshSlice的2D TP与MoE的专家并行(EP)相结合,以支持更大规模模型的训练。
💬 评论讨论
欢迎在这里分享您的想法和见解!