高效扩展Transformer推理

作者: Reiner Pope, Sholto Douglas, Aakanksha Chowdhery, Jacob Devlin, James Bradbury, Anselm Levskaya, Jonathan Heek, Kefan Xiao, Shivani Agrawal, Jeff Dean


主要贡献

核心问题与研究目标:随着Transformer模型参数扩展到千亿(如【6, Brown et al., Language models are few-shot learners, 2020, NeurIPS】、【23, Kaplan et al., Scaling laws for neural language models, 2020, arXiv】、【34, Rae et al., Scaling language models: Methods, analysis & insights from training Gopher, 2021, CoRR】、【21, Hoffmann et al., Training compute-optimal large language models, 2022, arXiv】)乃至五千亿以上(如【11, Chowdhery et al., PaLM: Scaling language modeling with Pathways, 2022, arXiv】、【41, Smith et al., Using deepspeed and megatron to train megatronturing nlg 530b, a large-scale generative language model, 2022, arXiv】),其在各类应用中的实用价值日益凸显。然而,这些模型的有效部署极具挑战性。与训练阶段的高度并行性不同,生成式推理是逐个token进行的,每个token的计算都依赖于前一个生成的token,这导致其并行性大大降低。本文旨在提供一套工程原则,用于在各种具挑战性的生产环境中高效地服务于大规模Transformer模型。

研究挑战
1. 内存占用:大模型的参数和解码过程中的瞬时状态(如KV缓存)内存占用巨大,通常无法装入单个加速器芯片。
2. 延迟要求:生成式推理的并行性较低,使得满足严格的延迟目标变得尤为困难。巨大的内存占用导致从高带宽内存(HBM)加载参数和KV缓存到计算核心需要大量的内存流量,这对满足延迟目标构成了挑战。
3. 计算成本:注意力机制的推理成本随输入序列长度呈二次方增长(如【43, Sukhbaatar et al., Adaptive attention span in transformers, 2019, arXiv】、【10, Choromanski et al., Rethinking attention with performers, 2020, arXiv】、【13, Dao et al., Flashattention: Fast and memoryefficient exact attention with io-awareness, 2022, arXiv】)。

核心创新点
1. 强大的分区框架:本文建立了一个强大且抽象的分区框架,以充分挖掘Transformer推理有限并行性下的模型并行扩展潜力。在该框架内,作者通过分析模型为给定模型大小和特定应用需求求解最佳分区策略。这使用户能够直观地理解各种权衡,并为其应用选择最佳的多轴张量分区策略、批量大小和芯片配置,而非依赖黑盒式的详尽搜索(如【54, Zheng et al., Alpa: Automating inter-and intra-operator parallelism for distributed deep learning, 2022, arXiv】、【51, Xu et al., GSPMD: general and scalable parallelization for ml computation graphs, 2021, arXiv】)。
2. 内存优化与多查询注意力:本文应用了内存优化技术,并充分利用PaLM的多查询注意力机制来减少不必要的张量开销,最大化在给定芯片数量上可容纳的批量大小,从而实现更高的吞吐量。
3. 性能突破
- 在一个540B参数的密集模型上,使用64个TPU v4芯片,本文实现了在生成过程中每个token 29ms的低批量大小延迟(使用int8权重)和在处理输入token时76%的模型FLOPS利用率(MFU),同时支持2048个token的长上下文。
- 对于一个类似聊天机器人的交互式应用,本文的实现可以在1.9秒内处理用户64个token的输入,查询1920个token的对话历史,并生成一个64个token的响应。
- 对于一个面向吞吐量的离线应用,本文的实现可以处理1984个输入token并生成64个输出token,整体FLOPS效率达到73%。

图1: PaLM模型的成本与延迟关系图。上下文长度为2048。每条线上的点代表效率与延迟的帕累托前沿。C为芯片数,B为批量大小。左图:生成64个token的每token延迟,假设上下文已处理。右图:处理2048个输入token的时间;不包括生成任何输出token的时间。表2和表3展示了帕累托前沿上针对低延迟或高吞吐量应用的几个具体场景的详细信息。
图1: PaLM模型的成本与延迟关系图。上下文长度为2048。每条线上的点代表效率与延迟的帕累托前沿。C为芯片数,B为批量大小。左图:生成64个token的每token延迟,假设上下文已处理。右图:处理2048个输入token的时间;不包括生成任何输出token的时间。表2和表3展示了帕累托前沿上针对低延迟或高吞吐量应用的几个具体场景的详细信息。

推理成本权衡

衡量推理成本的指标。本文使用延迟(latency)、吞吐量(throughput)和模型FLOPS利用率(model FLOPS utilization, MFU)来衡量推理成本。延迟分为处理初始输入token的时间(称为“prefill”)和自回归生成输出token的时间(称为“decode”)。解码延迟也可以用“每步”来衡量。吞吐量是每秒处理或生成的token数量。MFU是观测到的吞-吐量与硬件在峰值FLOPS下无内存或通信开销的理论最大吞吐量之比。

内存成本。权重和KV缓存等张量存储在设备上的高带宽内存(HBM)中,是内存占用的主要部分。在每次模型前向传播(prefill或decode步骤)中,这些张量需要从HBM传输到计算核心,这部分时间称为“内存时间”。在小批量和短序列下,加载权重的时间占主导;而在大批量和长序列下(如批量大小512+,序列长度2048+),加载KV缓存的时间占主导。

计算成本。一个N参数的仅解码器模型,每次前向传播每个token需要2N次矩阵乘法浮点运算(FLOPs)【23, Kaplan et al., Scaling laws for neural language models, 2020, arXiv】。如果所有芯片都以峰值FLOPS运行,这部分时间称为“计算时间”。对于大模型,注意力机制中的矩阵乘法FLOPs通常占比较小,但其KV缓存会占用大量内存容量和带宽。

预期权衡与挑战

通信瓶颈。加载权重和非注意力计算的时间与模型大小成正比,与芯片数量成反比。然而,对于给定的分区布局,芯片间通信所需的时间随着芯片数量的增加下降得更慢(甚至不变),因此随着芯片数量的增多,通信成为一个日益重要的瓶颈。

低延迟场景。若应用要求尽可能低的延迟,需要使用更多芯片并尽可能地对模型进行分区。小批量通常可以实现更低的延迟,但这也会导致更差的MFU,从而增加每个token的总成本。

长上下文场景。长注意力上下文会显著增加推理时间。对于一个超过500B参数且使用多头注意力的模型,当批量大小为512、上下文长度为2048时,KV缓存总计可达3TB,是模型参数大小的3倍。在生成每个token时,片上内存都需要从片外内存加载一次这个巨大的KV缓存,期间计算核心基本处于空闲状态。

高吞吐量场景。若应用是离线推理且不关心延迟,主要目标是最大化单芯片吞吐量。最有效的方法是增加批量大小,因为大批量通常能带来更好的MFU。某些对小批量效率不高的分区策略,在批量增大时会变得高效。

推理设置

基本设置与符号。本文考虑一个拥有n_params个参数的Transformer模型,部署在n_chips个芯片上。模型的模型维度(或嵌入维度)为d_model(或E),前馈网络中间维度为d_ff(或F),拥有n_heads(或H)个头。

推理阶段。一批B个序列中的每个样本都有L_input个输入token和生成L_gen个输出token。由于输入token在推理开始时全部存在,模型可以并行处理所有B × L_input个token,这个单次前向传播步骤称为prefill。输出token是自回归生成的,通过一个L_gen步的序列循环。每一步包含一次模型的前向传播,之后为批次中的B个样本每个采样一个新token。这个循环称为generationdecode。由于prefill可以并行处理L_input个token,而decode必须顺序处理L_gen个token,这两个阶段具有不同的性能特征,因此需要分开分析。

为推理效率进行分区

大型模型必须在多个芯片上进行分区,以容纳权重和激活张量,并使计算和内存时间满足延迟要求。模型分区会引入芯片间的通信,不同的分区策略涉及不同的通信模式和通信量。本节详细介绍了几种用于Transformer大语言模型的高级分区策略,旨在实现成本效益和延迟效益。

分区符号和通信集合操作

分区布局表示法。本文基于具有3D环面拓扑(X × Y × Z)的TPU v4系统来描述分区布局。遵循【51, Xu et al., GSPMD: general and scalable parallelization for ml computation graphs, 2021, arXiv】的表示法,下标用于指定被分区的张量维度。例如,BLE_xyz表示逻辑形状为BLE的张量,其最后一个维度E被分成了X × Y × Z个分区,其中x, y, z对应物理TPU v4的轴,每个芯片上的张量形状为 [B, L, E/(X × Y × Z)]。B、E和F分别代表批量、模型嵌入和MLP前馈维度。L代表序列长度。

通信集合操作。如果一个张量在某个轴(如x)上是复制的,该轴会从表示法中省略。后缀“partialsum-x”表示一个张量在每个芯片上本地进行了收缩(求和),但仍需在TPU的x轴上对芯片间的结果进行求和。本文使用源自MPI的几种通信集合操作【12, Clarke et al., The mpi message passing interface standard, 1994, Springer】:
- all-reduce(x):对x轴上的芯片集合中的partialsum张量进行求和,并将总和广播回所有参与的芯片。
- reduce-scatter(x)all-reduce的归约阶段。它对x轴上的张量求和,但产生一个在该轴上分片而非复制的输出。
- all-gather(x)all-reduce的广播阶段。它将分片张量广播并连接到x轴上的所有芯片,产生一个复制的、比输入大X倍的输出。
- all-to-all:将分片从一个张量维度转移到另一个维度,例如从BLH_xQB_xLHQ
图A.1展示了这些原语。

前馈层的分区

前馈层,1D权重固定布局 (1D weight-stationary layout)

概述。当模型无法放入单个芯片时,最简单的分区策略是1D权重固定,即每个E × F的权重矩阵沿着E或F轴在n_chips个芯片上进行分区。每个芯片上的权重分片与相应的激活分片相乘,结果通过all-gather和/或reduce-scatter在芯片间聚合。一个技巧是【40, Shoeybi et al., MegatronLM: Training multi-billion parameter language models using model parallelism, 2019, CoRR】,在计算两个连续的矩阵乘法(如Transformer的MLP块)时,如果第一个矩阵乘法按输出轴分区,那么每个芯片上产生的激活分片恰好是第二个按输入轴分区的矩阵乘法所需要的,从而避免了矩阵乘法之间的跨芯片通信。随着并行芯片数量的增加,内存延迟和计算延迟通常呈近线性下降,但通信延迟基本保持不变,因为整个激活矩阵都需要在芯片间聚合。因此,当芯片数量增多时,通信成为瓶셔颈。

细节。基线布局中,前馈层的权重和激活沿d_ff维度在n_chips个芯片上分区,如Megatron【40, Shoeybi et al., MegatronLM: Training multi-billion parameter language models using model parallelism, 2019, CoRR】。图2(a)展示了该布局。在TPU v4的3D环面拓扑上,权重的分区布局为EF_xyzF_xyzE,即在X, Y, Z三个物理轴上共分区为X × Y × Z = n_chips份。权重固定在每个芯片上,而激活在芯片间传输以匹配权重布局,需要一次all-gather和一次reduce-scatter。在这种1D权重固定分区策略中,每个芯片在reduce-scatterall-gather中的输入和输出形状均为BLE。附录A.1推导了这些操作的通信成本,其通信时间为:

前馈层,2D权重固定布局 (2D weight-stationary layout)

概述。对于更多的芯片,一个更经济的策略是同时沿E和F轴对每个E × F权重矩阵进行分区,使每个分片大致为方形,这被称为2D权重固定。例如,若E=1024, F=4096, n_chips=64,则E轴分4路,F轴分16路,每个芯片存储一个256x256的权重块。总计算成本与1D布局相同,但通信效率高得多。在连续的权重矩阵乘法中,可以交替在两个轴上执行激活聚合。通过正确的分区,每个芯片总能获得与其权重分片相乘所需的激活分片,而无需持有激活张量的完全复制副本。由于每个轴都在$O(\sqrt{n_{chips}})$上分区,通信时间缩放为$O(1/\sqrt{n_{chips}})$而非保持不变。这意味着即使2D布局在某个芯片数和批量大小下受通信限制,仍可通过增加更多芯片来继续降低延迟。然而,1D布局的技巧只需在d_model维度上聚合,而2D布局需要交替在d_modeld_ff维度上聚合。因此,当n_chips > d_ff/d_model时,2D权重固定才更具通信效率。由于通常d_ff = 4 * d_model,这发生在n_chips > 16时。

图2: 前馈层的分区布局。
图2: 前馈层的分区布局。

细节。图2(b)展示了这种分区布局。与1D权重固定布局中每个芯片处理BLE形状的all-gatherreduce-scatter不同,2D布局对d_model进行了分区,使得对d_ff分区的通信量从BLE减少到BLE/X。代价是引入了第二对reduce-scatterall-gather操作。权重的分区布局是E_xF_yz,即沿d_model维度分成X个分区,沿d_ff维度分成Y × Z个分区。激活的通信包括两次all-gatherreduce-scatter。附录A.2.1推导了最小化总通信时间的最优X, Y, Z值。假设d_ff = 4 × d_model,当$X = \sqrt{0.5 \times n_{chips}}$且$YZ = \sqrt{2 \times n_{chips}}$时,通信时间达到最小。总通信时间为:

前馈层,权重收集布局 (weight-gathered layout)

概述。在之前描述的权重固定策略中,每个芯片存储权重矩阵的一个分片,并负责将其与相应的激活分片相乘。然而,随着批量大小和序列长度的增加,输出激活的大小可能远超权重大小。此时,将激活固定在每个芯片上,而在芯片间传输权重可能更经济。对于非常大的批量大小,最好在连续的矩阵乘法之间完全保持激活固定,这要求在所有芯片间完全传输权重,我们称之为XYZ-weight-gathered。对于中等批量大小,采用一种“混合”方法,即权重和激活都沿不同轴部分传输,是更有益的。我们称这些方法为X-weight-gathered和XY-weight-gathered。

图3: 前馈层的通信量随批量大小的变化。随着批量大小(以token计)的增长,切换到在更多芯片上进行权重all-gather的布局可以最小化通信量。通信量估计基于X=Y=Z=4,d_model=16384,d_ff=65536。
图3: 前馈层的通信量随批量大小的变化。随着批量大小(以token计)的增长,切换到在更多芯片上进行权重all-gather的布局可以最小化通信量。通信量估计基于X=Y=Z=4,d_model=16384,d_ff=65536。

细节。图2(c)展示了XY-weight-gathered布局。我们选择的一个关键布局特性是,权重初始布局与2D权重固定布局中的E_xF_yz相同,这样我们可以在prefill阶段使用权重收集,在decode阶段使用权重固定。在einsum操作之前,权重张量在X和Y轴上进行all-gather,通信量为EF/Z。这相对于权重固定布局增加了通信,但作为回报,我们减少了激活的通信:激活的一对reduce-scatter/all-gather被跳过,另一对的通信量从BLE/X降至BLE/(XY)。通过调整X, Y, Z轴的相对大小,我们可以在权重通信和激活通信之间进行权衡。图3显示了随着批量大小的增长,通信最优配置如何在这些布局之间切换。图A.2展示了三种权重收集布局。现在我们展示权重收集布局的渐近扩展性。令N为权重all-gather的芯片数。总通信时间由$N = \sqrt{BL n_{chips} / F}$的选择最小化(推导见附录A.2.2),总通信时间为:


注意,BL对应于总批量大小(以token计)。权重固定布局的通信时间与BL成线性关系,而权重收集布局的通信时间与$\sqrt{BL}$成线性关系。因此,当批量大小和prefill序列长度足够大时,权重收集布局变得更优。

注意力层的分区

多查询注意力 (Multiquery Attention)。多头注意力可以以前馈层类似的方式进行并行化,用n_heads替换d_ff。但其推理会因存储和加载KV缓存而产生巨大的内存容量和带宽成本。多查询注意力【38, Shazeer, Fast transformer decoding: One write-head is all you need, 2019, arXiv】、【11, Chowdhery et al., PaLM: Scaling language modeling with Pathways, 2022, arXiv】是一个替代方案,它为查询张量生成n_heads个头,但键和值张量只有一个头,这个头在所有查询头之间共享。这使得KV缓存张量的大小减少了n_heads倍,从而减少了加载它们所花费的内存时间。

图4: 多查询注意力在按批次分片时具有更低的加载KV缓存的内存成本。
图4: 多查询注意力在按批次分片时具有更低的加载KV缓存的内存成本。

分区策略。关键设计考虑是最小化重复加载KV缓存的内存时间。我们提出一种多查询注意力的分区策略,其中Q, K, V矩阵沿批量B维度在n_chips个芯片上分区。图4(c)显示,这使得每个芯片加载KV缓存的内存成本降低了n_chips倍,从而内存时间也减少了相同倍数。与图5(a)中按头维度分区的策略相比,这种策略(图5(b))需要通过all-to-all集合操作对输入激活张量进行重分片,从而产生额外的通信成本。

图5: 注意力层分区布局比较:按头分片的多头注意力 vs. 按批次分片的多查询注意力。
图5: 注意力层分区布局比较:按头分片的多头注意力 vs. 按批次分片的多查询注意力。

不同阶段的策略选择。在自回归生成(decode)阶段,Q, K, V张量每个样本只有一个token,而KV缓存有很多token。由于KV缓存比Q, K, V张量大几个数量级,因此在小张量上花费all-to-all通信时间来节省大张量上的内存时间是非常划算的。在prefill阶段,按批次分片通常不划算。Q张量有很多token,它们都查询相同的K和V张量。K和V张量的内存加载在Q张量的所有token上被摊销,因此通常不是瓶颈。所以在prefill阶段我们使用按头分片的布局。

并行注意力/前馈层

并行Transformer块。本文讨论了PaLM【11, Chowdhery et al., PaLM: Scaling language modeling with Pathways, 2022, arXiv】中使用的“并行”Transformer块公式【47, Wang and Komatsuzaki, GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model, 2021, github】带来的推理延迟增益。与标准“串行”公式不同,并行公式中前馈层和注意力层从经过层归一化的输入并行计算,然后相加以获得输出。

并行公式的优点。首先,每层只有一个层归一化而不是两个,这在小批量下降低了延迟。其次,前馈层的输入矩阵可以与注意力层的查询投影矩阵WQ融合,键/值投影矩阵WK和WV可以融合,前馈层的输出矩阵可以与注意力层的输出投影矩阵WO融合。这种融合通过更大的矩阵乘法提高了FLOPS利用率。更重要的是,它消除了每个Transformer层中为d_ff/n_heads并行性所需的两个all-reduce操作中的一个,将该轴上的通信时间减少了一半。

底层优化

通信与计算重叠。我们使用【49, Wang et al., Overlap communication with dependent computation via decomposition in large deep learning models, 2023, ASPLOS】中的Looped CollectiveEinsum技术,使通信与计算并发进行,从而部分或完全隐藏了大部分reduce-scatterall-gather操作的通信时间。我们选择将reduce-scatter归约到隐藏维度(E或F)而非批量或序列维度(B或L),因为这为Looped CollectiveEinsum提供了更有效的机会。

最大化性能。我们开发了一套CollectiveEinsum的变体来优化不同场景,并明确地将通信集合操作与它们应该融合的矩阵乘法相匹配,以最大化重叠的可能性。通过这些优化,性能比最初的编译器分区和调度实现提高了约1.4倍。

其他优化。我们还包括了以下底层优化:更好的张量内存布局以最小化填充和复制、更快的top-k/top-p解码采样实现、更快的Softmax和Swish的log-base-2实现,以及在prefill期间对序列的增量处理支持。

量化

int8权重。我们使用AQT库【30, Lew et al., AQT: Accurate quantized training, 2022, github】将16位权重转换为int8,以减少内存成本,且没有明显的质量损失。这节省了加载权重时的内存时间,在低批量场景下尤其有帮助,并减少了权重收集布局中的通信量。

图6: 在64个芯片上,使用2D和1D权重固定布局对PaLM 540B进行文本生成的每token延迟。
图6: 在64个芯片上,使用2D和1D权重固定布局对PaLM 540B进行文本生成的每token延迟。

实验环境与结果

实验环境

  • 模型: PaLM系列模型(8B, 62B, 540B参数)。对于PaLM 540B模型,为了在64+芯片上更有效地分区,注意力头的数量从48个填充到64个,这增加了18B参数。
  • 硬件: 最多256个TPU v4芯片。每个TPU v4芯片提供275 TFLOPS的bfloat16矩阵运算能力,拥有32 GiB的高带宽内存(HBM),带宽为1200 GB/s,并具有270 GB/s的互连带宽,采用3D环面拓扑结构。
  • 软件: 推理框架基于JAX【4, Bradbury et al., JAX: Composable transformations of Python+NumPy programs, 2018, github】和XLA【50, XLA: Optimizing compiler for TensorFlow, 2019】。原始的高级实现基于T5X【1, T5x, 2021, github】。

实验结果

前馈层分区评估

  • 解码阶段 (Decode): 在批量大小为512时,随着芯片数量增加,2D权重固定布局的性能优于1D布局,因为它具有更好的渐近扩展性,避免了通信瓶颈(图6)。
  • 预填充阶段 (Prefill): 在序列长度为2048时,随着批量大小(以tokens计)从2048增加到100万,最优分区布局从2D权重固定布局切换到权重收集布局。权重收集布局在小批量时效率低下,但在大批量时成为最高效的,实现了76%的MFU,此时通信开销几乎可以忽略不计(图7)。
  • 结论: 最佳策略是在prefill阶段根据批量中的token数量从权重固定和权重收集布局中选择,在decode阶段则选择2D权重固定布局。

图7: 在64个芯片上,对序列长度为2048的PaLM 540B运行prefill的模型FLOPS利用率。随着批量大小(以token计)的增长,从2D权重固定切换到权重收集方法可以提高MFU。
图7: 在64个芯片上,对序列长度为2048的PaLM 540B运行prefill的模型FLOPS利用率。随着批量大小(以token计)的增长,从2D权重固定切换到权重收集方法可以提高MFU。

注意力层分区评估

  • 内存容量: 优化的多查询注意力(按批次分区)与多头注意力和基线多查询注意力(按头分区)相比,可以支持最多32倍的上下文长度(表1)。
  • 生成延迟: 在生成阶段,随着上下文长度增长,加载KV缓存的时间成为推理时间的主要部分。优化的多查询注意力布局显著提高了速度。在8192-32768 tokens的序列长度下,注意力只占总运行时间的8-31%(图8)。

表1: PaLM 540B不同注意力变体在64个芯片上支持的最大上下文长度。我们为KV缓存保留了30%的总内存。优化的多查询注意力支持最多32倍的上下文长度。
表1: PaLM 540B不同注意力变体在64个芯片上支持的最大上下文长度。我们为KV缓存保留了30%的总内存。优化的多查询注意力支持最多32倍的上下文长度。

图8: 在64个芯片上,批量大小为256的8层PaLM 540B版本的每生成token延迟与序列长度的关系。虚线表示在完整的118层模型和超过512的上下文长度下,使用多头注意力或基线多查询分区的KV缓存将无法装入内存。
图8: 在64个芯片上,批量大小为256的8层PaLM 540B版本的每生成token延迟与序列长度的关系。虚线表示在完整的118层模型和超过512的上下文长度下,使用多头注意力或基线多查询分区的KV缓存将无法装入内存。

并行注意力/前馈层评估

  • 在生成阶段(64芯片,批量512,2D权重固定布局),串行公式比并行版本的推理延迟高出14%,原因是激活通信时间增加。在prefill阶段,由于权重收集布局的激活通信较少,这一差异会缩小。

PaLM端到端结果

  • 延迟与成本: 图1(左)展示了生成阶段模型大小、延迟和成本(芯片-秒/token)之间的帕累托前沿。成本与参数数量成正比,降低批量大小可改善延迟但成本更高。
  • 量化效果: int8权重在低延迟目标下成本降低超过2倍,因为低批量成本主要由权重加载时间决定。例如,PaLM 540B上,批量大小64时,int8权重延迟为28.5ms/token,bfloat16为36.9ms/token。
  • Prefill vs Decode: Prefill阶段的成本比decode阶段低2倍,因为prefill阶段使用了效率更高的权重收集布局(图1右)。
  • 具体配置: 表2和表3展示了PaLM 540B和62B在低延迟和高吞吐量场景下的关键配置。低延迟场景下,prefill使用batch-1,decode使用batch 32-64。高吞吐量场景下使用更大的批次和bfloat16权重。
  • 模型大小与延迟: 随着模型规模增大,帕累托前沿上的低批量延迟呈次线性增长(近似平方根关系),因为更大的模型可以被分区到更多芯片上,从而缓解通信瓶颈。
表2: PaLM 540B示例配置
场景 权重 芯片数 批量大小 (Prefill/Decode) Prefill延迟 (ms) Decode延迟 (ms/tok) MFU (Prefill/Decode)
低延迟 int8 64 1/32 480 33.6 13%/22%
int8 128 1/64 338 20.1 9%/18%
高吞吐量 bfloat16 64 512/512 1004 58.7 73%/41%
bfloat16 128 1024/1024 1205 39.4 60%/31%
注:FFN布局为WS 2D和WG XYZ,Attention布局来自3.3节。
表3: PaLM 62B示例配置
场景 权重 芯片数 批量大小 (Prefill/Decode) Prefill延迟 (ms) Decode延迟 (ms/tok) MFU (Prefill/Decode)
低延迟 int8 16 1/32 163 11.2 13%/25%
int8 32 1/64 135 7.9 8%/21%
高吞吐量 bfloat16 16 512/512 316 17.0 76%/46%
bfloat16 32 1024/1024 362 11.1 66%/36%
注:FFN布局为WS 2D和WG XYZ,Attention布局来自3.3节。

与FasterTransformer的基准比较

  • 比较设置: 本文的实现(在64个TPU v4上)与FasterTransformer的基准(在16-32个NVIDIA A100上)进行了比较。为公平起见,使用MFU作为性能指标。
  • 性能优势: 如图9所示,本文的PaLM 540B实现达到了最佳的绝对延迟,并且在几乎所有延迟目标下,本文对Megatron模型的实现也提供了最佳的MFU。PaLM的并行注意力/FFN层是其性能优势的主要原因之一。
  • 可扩展性: FasterTransformer的32路张量并行在所有基准测试中最高达到33%的MFU,而本文的实现能够扩展到64路张量并行,同时仍能达到44%的MFU,这表明本文的2D权重固定分区策略在TPU v4的大型高速互连域上具有更优的可扩展性。

图9: 在不同批量大小下,运行60个输入token、20个输出token推理的MFU与总延迟关系图。
图9: 在不同批量大小下,运行60个输入token、20个输出token推理的MFU与总延迟关系图。

相关工作

并行化方法。先前的工作提出了多种高效分区策略来高效训练大模型,例如NeMo Megatron【26, Korthikanti et al., Reducing activation recomputation in large transformer models, 2022, arXiv】、GSPMD【51, Xu et al., GSPMD: general and scalable parallelization for ml computation graphs, 2021, arXiv】和Alpa【54, Zheng et al., Alpa: Automating inter-and intra-operator parallelism for distributed deep learning, 2022, arXiv】。FasterTransformer【16, FasterTransformer: Gpt guide, github】为多GPU多节点推理建立了一套基准测试。DeepSpeed Inference【3, Aminabadi et al., Deepspeed inference: Enabling efficient inference of transformer models at unprecedented scale, 2022, arXiv】进一步启用了ZeRO卸载以利用CPU和NVMe内存。与这些通过整数线性规划等方法搜索并行策略的工作相比,本文基于直观且有经验支持的分析性权衡来推导分区策略,以满足应用需求,并能很好地随模型大小、上下文长度和芯片数量扩展。

机器学习推理效率。许多方法致力于通过模型架构改进来提升Transformer模型的推理效率,例如高效的注意力层【36, Roy et al., Efficient content-based sparse attention with routing transformers, 2020, arXiv】、【10, Choromanski et al., Rethinking attention with performers, 2020, arXiv】、【24, Kitaev et al., Reformer: The efficient transformer, 2020, arXiv】,蒸馏【37, Sanh et al., Distilbert, a distilled version of bert: smaller, faster, cheaper and lighter, 2019, arXiv】,以及剪枝【31, Li et al., Efficient transformer-based large scale language representations using hardware-friendly block structured pruning, 2020a, arXiv】或量化【14, Dettmers et al., Llm. int8 (): 8-bit matrix multiplication for transformers at scale, 2022, arXiv】等模型压缩技术。本文重用了模型量化的先前工作,并且本文描述的技术也可以与其他模型压缩方法相结合。

结论与展望

本文研究了Transformer推理工作负载的扩展特性,并提出了实用的分区方法,以满足严格的延迟目标等挑战性应用需求。研究表明,最佳延迟是通过将推理扩展到64个甚至更多芯片来实现的,这远超了传统的单服务器推理范式。更长的上下文长度会带来更高的内存成本,但带有适当分区的多查询注意力降低了这一成本,使长上下文推理变得可行。本文提出的分区策略可推广到多种拓扑结构,包括GPU系统中的单节点和多节点NVLink网络。尽管实现了扩展推理工作负载的目标,但FLOPs数量和通信量从根本上限制了密集Transformer模型的推理性能。未来的方向包括稀疏性技术(如基于任务的专家混合架构)和自适应计算技术,这些技术有望减少Transformer模型每个token的FLOPs,从而在成本和延迟方面取得进一步的提升。

附录

A. 分区策略:通信成本推导

A.1 all-gather/reduce-scatter的成本

通信模式。图A.1展示了我们使用的典型集合操作及其在三个设备上的通信模式。对于一个K个分区的all-gather,每个芯片产生大小为D的输出,通信模式要求大小为D/K的块在(K-1)个互连链路上被传输,以复制到(K-1)个芯片。all-gather的通信时间为:


reduce-scatter的通信时间相同,但D是输入缓冲区的大小。因此,all-reduce的总通信时间是all-gather的两倍。在大多数公式中,我们忽略(K-1)/K项,简化为1。

图A.1: 集合操作的通信模式:all-gather、reduce-scatter和all-to-all在三个设备上的示意图。
图A.1: 集合操作的通信模式:all-gather、reduce-scatter和all-to-all在三个设备上的示意图。

A.2 通信时间计算细节

A.2.1 前馈层,2D权重固定布局

布局与成本。如图2(b)所示,权重的分区布局为E_xF_yz。我们展示如何确定X, Y, Z轴的大小以最小化总通信时间。通信时间为:


假设d_ff = 4 × d_model,当$X = \sqrt{0.5 \times n_{chips}}$且$YZ = \sqrt{2 \times n_{chips}}$时,通信时间达到最小。总通信时间为:

A.2.2 前馈层,权重收集布局

布局与成本。图A.2展示了不同的权重收集布局。权重在einsum之前在X和Y轴上进行all-gather,通信量为EF/Z。令N为权重all-gather的芯片数。权重通信为:


激活通信为:

总通信通过选择$N = \sqrt{BS n_{chips} / F}$来最小化,得到总通信时间:

图3展示了最优配置如何随批量大小增长而在不同布局间切换。

图A.2: 前馈层的权重收集布局。
图A.2: 前馈层的权重收集布局。

B. 最低预填充延迟 (Prefill Latency)

成本与延迟关系。图B.1展示了在批量大小为1时,将序列长度从32扫到1024的prefill成本与延迟的帕累托前沿。

图B.1: PaLM模型在一系列序列长度S下的预填充成本与延迟关系。C表示芯片数量。
图B.1: PaLM模型在一系列序列长度S下的预填充成本与延迟关系。C表示芯片数量。

C. MFU与延迟的权衡

MFU与延迟关系。图C.1展示了MFU与延迟的帕累托前沿,其设置与图1相同。解码的MFU通常远低于预填充。在prefill阶段,MFU的“跳跃”显示了从2D权重固定布局到XYZ权重收集布局的转换点。在多数情况下,大模型能达到更高的MFU。

图C.1: PaLM模型的MFU与延迟关系。我们使用2048的上下文长度。每条线上的点代表效率与延迟的帕累托前沿。C为芯片数,B为批量大小。左图:生成64个token的每token延迟。右图:处理2048个输入token的时间。对应的成本与延迟数据见图1。
图C.1: PaLM模型的MFU与延迟关系。我们使用2048的上下文长度。每条线上的点代表效率与延迟的帕累托前沿。C为芯片数,B为批量大小。左图:生成64个token的每token延迟。右图:处理2048个输入token的时间。对应的成本与延迟数据见图1。

D. 与FasterTransformer的完整比较

模型架构差异。表D.1列出了PaLM和Megatron-Turing NLG在推理时的超参数差异。
完整结果。表D.2至D.4报告了在FasterTransformer基准测试的三种配置(20输入/8输出,60输入/20输出,128输入/8输出)下的完整比较结果。在这些表格中,加粗字体表示所有500B级模型结果中的帕累托前沿,下划线表示MT-NLG模型结果中的帕累托前沿。

表D.1: PaLM和Megatron-Turing NLG推理的超参数。
表D.1: PaLM和Megatron-Turing NLG推理的超参数。

表D.2: 20输入token,8输出token基准测试的结果。所有时间单位为毫秒。
表D.2: 20输入token,8输出token基准测试的结果。所有时间单位为毫秒。

表D.3: 60输入token,20输出token基准测试的结果。所有时间单位为毫秒。
表D.3: 60输入token,20输出token基准测试的结果。所有时间单位为毫秒。

表D.4: 128输入token,8输出token基准测试的结果。所有时间单位为毫秒。
表D.4: 128输入token,8输出token基准测试的结果。所有时间单位为毫秒。