作者/机构: Joshua Ainslie∗, James Lee-Thorp∗, Michiel de Jong∗ Yury Zemlyanskiy, Federico Lebrón, Sumit Sanghai (Google Research)

A1 主要贡献

自回归解码器的推理过程是Transformer模型的一个严重瓶颈,主要因为在每个解码步骤都需要加载解码器权重以及所有的注意力键(keys)和值(values),这带来了巨大的内存带宽开销【Shazeer, 2019, Fast transformer decoding: One write-head is all you need, arXiv preprint】。多查询注意力(Multi-query attention, MQA)通过使用多个查询头(query heads)但共享单一的键和值头,可以显著减少加载键和值所带来的内存带宽。然而,MQA可能会导致模型质量下降和训练不稳定,并且为追求更快的推理速度而专门训练一个独立的模型在实践中可能并不可行。此外,尽管像PaLM【Chowdhery et al., 2022, Palm: Scaling language modeling with pathways】这样的一些语言模型已经采用了MQA,但许多公开可用的模型,如T5【Raffel et al., 2020, Exploring the limits of transfer learning with a unified text-to-text transformer, JMLR】和LLaMA【Touvron et al., 2023, Llama: Open and efficient foundation language models】,仍然使用多头注意力(Multi-head attention, MHA)。

针对以上问题,本文提出了两大贡献以加速大型语言模型的推理速度:
1. 提出了一种增量训练(uptraining)方法:该方法能够将已有的多头注意力(MHA)语言模型检查点,以原始预训练计算成本的一小部分(例如5%),转换为使用多查询注意力(MQA)的模型。这为同时获得高质量的MHA检查点和快速推理的MQA模型提供了一种经济高效的途径。
2. 提出了分组查询注意力(Grouped-query attention, GQA):GQA是MHA和MQA之间的一种插值方案。它将查询头分组,每组共享一个键和值头,从而使用的键值头数量介于MQA(1个)和MHA(与查询头数量相等)之间。实验证明,经过增量训练的GQA模型在质量上接近MHA,同时在推理速度上与MQA相当。

A2 方法细节

2.1 增量训练(Uptraining)

从多头模型生成多查询模型的两步法。这个过程分为两个步骤:首先,转换检查点;其次,进行额外的预训练,让模型适应其新结构。图1展示了将多头检查点转换为多查询检查点的过程。具体操作是将所有头的键(key)和值(value)的投影矩阵进行均值池化(mean pooled),融合成单一的投影矩阵。我们发现,这种方法比从多个头中选择单一的键和值头,或从头开始随机初始化新的键和值头效果更好。

图1:从多头注意力到多查询注意力的转换概览。所有头的键和值投影矩阵被均值池化成一个单头。
图1:从多头注意力到多查询注意力的转换概览。所有头的键和值投影矩阵被均值池化成一个单头。

增量预训练。转换后的检查点会使用与原始训练相同的预训练方案,在其原始训练步数的α比例上进行进一步的预训练。

2.2 分组查询注意力(Grouped-query attention)

GQA的定义与转换。分组查询注意力将查询头(query heads)分为G个组,每个组共享一个单一的键头(key head)和值头(value head)。GQA-G指的是有G个分组的分组查询。其中,GQA-1(只有一个组,因此只有一个键和值头)等同于MQA;而GQA-H(分组数等于头数)等同于MHA。图2展示了分组查询注意力与多头/多查询注意力的比较。当将一个多头检查点转换为GQA检查点时,我们通过对该组内所有原始头进行均值池化来构建每个组的键和值头。

图2:分组查询方法概览。多头注意力有H个查询、键和值头。多查询注意力在所有查询头之间共享单个键和值头。分组查询注意力则为每组查询头共享单个键和值头,从而在多头和多查询注意力之间进行插值。
图2:分组查询方法概览。多头注意力有H个查询、键和值头。多查询注意力在所有查询头之间共享单个键和值头。分组查询注意力则为每组查询头共享单个键和值头,从而在多头和多查询注意力之间进行插值。

GQA作为MHA和MQA之间的权衡。一个中间的分组数量可以得到一个插值模型,其质量高于MQA但速度快于MHA,我们后续将证明这代表了一个有利的权衡。从MHA到MQA将H个键和值头减少到单个键和值头,从而将键值缓存(key-value cache)的大小以及需要加载的数据量减少了H倍。然而,更大的模型通常会扩展头的数量,因此MQA在内存带宽和模型容量上都代表了更激进的削减。GQA则允许我们随着模型尺寸的增加,保持带宽和容量的同比例减少。

GQA对大模型的额外优势。此外,对于更大的模型,来自注意力的内存带宽开销相对较小,因为KV缓存随模型维度线性扩展,而模型的FLOPs和参数则随模型维度的平方扩展。最后,针对大型模型的标准分片技术(standard sharding)会根据模型分区的数量复制单个键和值头【Pope et al., 2022, Efficiently scaling transformer inference, arXiv preprint】;GQA通过分组避免了这种分区带来的浪费。因此,我们预期GQA对于更大的模型会是一个特别好的权衡方案。

GQA不适用于编码器。我们注意到,GQA并未应用于编码器的自注意力层;因为编码器的表示是并行计算的,所以内存带宽通常不是其主要瓶颈。

A4 实验环境

  • 模型架构:所有模型均基于T5.1.1架构【Raffel et al., 2020, Exploring the limits of transfer learning with a unified text-to-text transformer, JMLR】。实验主要考察了使用多头注意力的T5 Large和T5 XXL模型,以及经过增量训练的使用多查询和分组查询注意力的T5 XXL模型。MQA和GQA应用于解码器的自注意力和交叉注意力层,不用于编码器自注意力层。
  • 软件配置:使用JAX【Bradbury et al., 2018, JAX: composable transformations of Python+NumPy programs】、Flax【Heek et al., 2020, Flax: A neural network library and ecosystem for JAX】和Flaxformer实现。优化器采用Adafactor,其超参数和学习率调度与原始T5模型相同。
  • 硬件配置:增量训练在TPUv3上进行,大约花费了600个TPUv3芯片日。推理时间测试在8个TPUv4芯片上进行,使用xprof工具【Google, 2020, Profile your model with cloud tpu tools】测量每个样本在单个TPUv4芯片上的耗时。
  • 数据集
    • 摘要任务:CNN/Daily Mail【Nallapati et al., 2016, Abstractive text summarization using sequence-tosequence rnns and beyond, CoNLL】、arXiv和PubMed【Cohan et al., 2018, A discourse-aware attention model for abstractive summarization of long documents, NAACL】、MediaSum【Zhu et al., 2021, Mediasum: A large-scale media interview dataset for dialogue summarization, NAACL-HLT】和Multi-News【Fabbri et al., 2019, Multi-news: A large-scale multi-document summarization dataset and abstractive hierarchical model, ACL】。
    • 翻译任务:WMT 2014英译德。
    • 问答任务:TriviaQA【Joshi et al., 2017, Triviaqa: A large scale distantly supervised challenge dataset for reading comprehension, ACL】。
  • 训练设置
    • 增量训练:从公开的T5.1.1检查点初始化,键和值头通过均值池化转换为MQA或GQA结构,然后在原始预训练数据集上继续预训练原始步数的5%(α = 0.05)。
    • 微调:所有任务均使用0.001的恒定学习率、128的批量大小和0.1的dropout率。输入/输出长度根据任务不同而变化(例如,CNN/Daily Mail为512/256,TriviaQA为2048/32)。采用贪心解码进行推理。

A4 实验结果

主要结果

表1:T5 Large和XXL模型(多头注意力)与5%增量训练的T5-XXL模型(多查询和分组查询注意力)在摘要、翻译和问答数据集上的推理时间和平均开发集性能比较。
表1:T5 Large和XXL模型(多头注意力)与5%增量训练的T5-XXL模型(多查询和分组查询注意力)在摘要、翻译和问答数据集上的推理时间和平均开发集性能比较。

图3展示了MHA T5-Large、MHA T5-XXL以及增量训练比例α=0.05的MQA和GQA-8 XXL模型在所有数据集上的平均性能与平均推理时间的关系。实验结果表明,一个更大的、经过增量训练的MQA模型相比于MHA模型提供了一个更有利的权衡,其质量和推理速度均优于MHA-Large模型。更重要的是,GQA模型在此基础上实现了显著的质量提升,其性能接近MHA-XXL,而速度则接近MQA。表1中包含了所有数据集的完整结果。

图3:增量训练的MQA模型相比MHA-Large模型,在质量和速度上都更有优势;GQA模型则以相似的速度增益实现了与MHA-XXL相当的质量。图中展示了T5-Large和T5-XXL(多头注意力)以及5%增量训练的T5-XXL(MQA和GQA-8)在所有任务上的平均性能与平均推理时间的关系。
图3:增量训练的MQA模型相比MHA-Large模型,在质量和速度上都更有优势;GQA模型则以相似的速度增益实现了与MHA-XXL相当的质量。图中展示了T5-Large和T5-XXL(多头注意力)以及5%增量训练的T5-XXL(MQA和GQA-8)在所有任务上的平均性能与平均推理时间的关系。

消融实验

本节在三个有代表性的任务子集上进行实验:CNN/Daily Mail(短篇摘要)、MultiNews(长篇摘要)和TriviaQA(问答),以研究不同建模选择的影响。

检查点转换方法。图4比较了不同检查点转换方法的性能。结果显示,均值池化(Mean)的效果最好,其次是选择第一个头(First),最差的是随机初始化(Random)。从直观上看,结果的排序与从预训练模型中保留信息的程度成正比。

图4:T5-Large模型在增量训练为MQA(α=0.05)时不同检查点转换方法的性能比较。“Mean”指均值池化键和值头,“First”指选择第一个头,“Random”指从头随机初始化。
图4:T5-Large模型在增量训练为MQA(α=0.05)时不同检查点转换方法的性能比较。“Mean”指均值池化键和值头,“First”指选择第一个头,“Random”指从头随机初始化。

增量训练步数。图5展示了T5 XXL模型在使用MQA和GQA时,性能如何随增量训练比例的变化而变化。首先,GQA在转换后(比例为0)就已经达到了合理的性能,而MQA需要经过增量训练才能变得有效。MQA和GQA都从5%的增量训练中获益,而增加到10%时收益递减。

图5:T5 XXL模型在使用MQA和GQA-8时,性能随增量训练比例的变化情况。
图5:T5 XXL模型在使用MQA和GQA-8时,性能随增量训练比例的变化情况。

分组数量。图6展示了GQA分组数量对推理速度的影响。对于更大的模型,KV缓存带来的内存带宽开销约束较小,同时由于头数增加,键值大小的缩减更为明显。因此,将分组数从1(MQA)增加时,最初只会导致适度的速度下降,但随着分组数接近MHA,成本会越来越高。我们选择8个分组作为有利的折中点。

图6:GQA-XXL模型在输入长度2048和输出长度512时,每个样本的推理时间随GQA分组数量的变化。从1个组(MQA)增加到8个组只会带来适度的推理开销,而继续增加分组的成本会越来越高。
图6:GQA-XXL模型在输入长度2048和输出长度512时,每个样本的推理时间随GQA分组数量的变化。从1个组(MQA)增加到8个组只会带来适度的推理开销,而继续增加分组的成本会越来越高。

A5 结论

大型语言模型在推理时成本高昂,主要原因是加载键和值(keys and values)时产生的内存带宽开销。多查询注意力(MQA)通过减少这种开销来降低成本,但代价是模型容量和质量的下降。本文提出了一种方法,可以用原始预训练计算成本的一小部分,将多头注意力(MHA)模型转换为多查询模型。此外,我们引入了分组查询注意力(GQA),它是MQA和MHA的一种插值方法,能够在保持与MQA相当的推理速度的同时,实现接近MHA的模型质量。

局限性

  • 本文专注于缓解加载键和值所带来的内存带宽开销,这一开销在生成长序列时最为关键,而长序列的质量本身就难以评估。例如,我们使用ROUGE分数来评估摘要任务,但众所周知这是一种有缺陷的评估方法,无法完全反映真实情况,因此很难确定我们所做的权衡是否正确。
  • 由于计算资源有限,我们没有将增量训练的XXL GQA模型与一个从头开始训练的同类模型进行比较,因此我们不知道增量训练与从头训练相比的相对性能。
  • 我们仅在编码器-解码器模型上评估了增量训练和GQA的影响。近来,仅解码器模型非常流行,由于这些模型没有独立的自注意力和交叉注意力,我们预期GQA相比MQA在这些模型上会有更强的优势。

A6 附录

A 训练稳定性

MQA在微调中的不稳定性。我们发现多查询注意力(MQA)在微调期间可能导致训练不稳定,尤其是在与长输入任务结合时。我们从头开始训练了多个使用MQA的T5-Large模型。在每种情况下,预训练过程都遭受了频繁的损失尖峰(loss spikes),并且最终模型在对长输入任务进行微调时立即发散(diverged)。

增量训练的改善及GQA的稳定性。经过增量训练的MQA模型更为稳定,但仍然表现出高方差。因此,对于在不稳定任务上的MQA模型,我们报告了三次微调运行的平均性能。然而,经过增量训练的分组查询注意力(GQA)模型似乎是稳定的,所以我们没有进一步探究MQA不稳定的根本原因。