Fast Inference from Transformers via Speculative Decoding

文章标题:通过推测解码实现Transformer的快速推理
作者:Yaniv Leviathan, Matan Kalman, Yossi Matias
机构:Google

A1 主要贡献

本文旨在解决大型自回归模型(如Transformer)推理速度慢的问题。这类模型的解码过程是串行的,即解码K个token需要K次模型的串行运行。现有加速方法通常需要改变模型架构、重新训练,并且无法保证输出与原模型完全一致。

核心问题与研究目标
本文的核心观察是:(1)困难的语言建模任务通常包含更容易的子任务,这些子任务可以被更高效的小模型很好地近似;(2-)推理过程的瓶颈往往是内存带宽而非算术运算,这意味着额外的计算资源可能是可用的。基于此,本文提出了一种名为推测解码(speculative decoding)的算法,旨在不改变模型架构、无需重新训练、且不改变模型输出分布的前提下,通过增加并发性来加速推理。

创新点与核心贡献
1. 随机环境下的推测执行:本文将计算机体系结构中常见的推测执行(speculative execution)技术推广到了随机设置中。为此,作者提出了一种新颖的采样方法,称为推测采样(speculative sampling)
2. 推测解码机制:该机制利用一个高效的近似模型(approximation model)生成若干个候选token(即推测性前缀),然后让速度较慢的目标模型(target model)并行地对这些候选token进行验证。推测采样方法能够最大化这些推测任务被接受的概率,同时严格保证最终的输出分布与单独使用目标模型完全相同。通过这种方式,目标模型的一次并行运行可能一次性生成多个token,从而显著减少所需的串行解码步骤。
3. 实用性与效果:该方法无需修改现有模型或重新训练,可以直接应用于现成的(off-the-shelf)模型。实验证明,在T5-XXL(11B参数)模型上,该方法相比标准的T5X实现取得了2到3倍的加速,且输出结果完全相同。下图展示了该技术在一个无条件语言建模任务中的应用实例,其中38个token的生成仅需9次目标模型的串行运行。

图1. 本文技术在无条件语言建模案例中的图示。每一行代表算法的一次迭代。绿色token是近似模型(一个6M参数的GPT类Transformer)提出的建议,并被目标模型(一个97M参数的GPT类Transformer)接受。红色和蓝色token分别是-被拒绝的建议及其修正。例如,在第一行中,目标模型仅运行一次,就生成了5个token。
图1. 本文技术在无条件语言建模案例中的图示。每一行代表算法的一次迭代。绿色token是近似模型(一个6M参数的GPT类Transformer)提出的建议,并被目标模型(一个97M参数的GPT类Transformer)接受。红色和蓝色token分别是-被拒绝的建议及其修正。例如,在第一行中,目标模型仅运行一次,就生成了5个token。

A2 方法细节

2.1 概述

核心思想:本文的核心思想是利用一个高效的近似模型和一个较慢但更强大的目标模型协同工作来加速解码。设Mp为我们试图加速其推理过程的目标模型,其在给定前缀x<t的情况下输出的分布为p(xt|x<t)。设Mq为一个针对相同任务的更高效的近似模型,其输出的分布为q(xt|x<t)。推测解码的步骤如下:(1) 使用更高效的模型Mq生成γ个候选补全(γ是一个正整数,其最优选择将在3.5节讨论);(2) 使用目标模型Mp并行地评估所有这些来自Mq的猜测及其各自的概率,并接受所有能够导向相同分布的猜测;(3) 从一个调整后的分布中采样一个额外的token,用于修正第一个被拒绝的猜测,或者在所有猜测都被接受的情况下添加一个新的token。通过这种方式,目标模型Mp的每次并行运行至少会产生一个新token(因此,即使在最坏的情况下,目标模型的串行运行次数也不会超过标准的自回归方法),但它有可能生成多达γ + 1个新token,具体数量取决于MqMp的近似程度。

2.2 标准化采样

采样方法的统一:尽管存在多种采样方法和参数,如argmax、top-k、nucleus以及温度(temperature)设置,并且流行的实现通常在logits层面以不同方式处理它们,但所有这些方法都可以被轻松地转换为从一个调整后的概率分布中进行标准采样。例如,argmax采样等价于将分布中非最大值的元素置零然后进行归一化。因此,我们可以只处理从概率分布中进行标准采样的问题,并将所有其他类型的采样都归入该框架下。在下文中,我们假设p(x)q(x)分别是经过采样方法调整后,来自模型MpMq的分布。

2.3 推测采样

采样机制:为了从p(x)分布中采样一个x,我们首先从q(x)分布中采样一个x。如果满足q(x) ≤ p(x),我们就保留这个样本。如果q(x) > p(x),我们以1 - p(x)/q(x)的概率拒绝该样本,并从一个调整后的分布p'(x) = norm(max(0, p(x) - q(x)))中重新采样x。可以证明(见附录A.1),对于任何分布p(x)q(x),通过这种方式采样的x确实服从p(x)分布。

并行化与算法:给定从Mq模型在某个条件前缀上运行得到的分布q(x),我们可以从中采样一个token x1。然后,我们通过在prefix上运行Mp来计算分布p(x),同时并行地、推测性地在prefix + [x1]上运行Mp来计算下一个token x2的分布。一旦这两个计算完成,我们便按照上述逻辑进行:如果x1被拒绝,我们就丢弃对x2的计算,并从调整后的分布中重新采样x1;如果x1被接受,我们则保留这两个token。算法1将这个思想推广到一次性采样1到γ + 1个token。

算法1 推测解码步骤

算法1 推测解码步骤
算法1 推测解码步骤

A3 分析

3.1 生成的Token数量

期望Token数分析:我们分析该算法在目标模型串行调用次数上的减少因子,这等价于分析算法1单次运行所产生的期望token数量。首先定义接受率βx:给定前缀x<tβx是根据2.3节的推测采样方法接受从q(xt|x<t)采样的xt的概率。E(β)是衡量Mq近似Mp程度的一个自然指标。如果我们做一个简化假设,即各个β是独立同分布的(i.i.d.),并记α = E(β),那么算法1单次运行产生的token数量是一个有上限的几何变量,其成功概率为1 - α,上限为γ + 1。因此,算法1生成的期望token数满足以下公式(1),如图2所示。

公式(1)
公式(1)

图2. 算法1生成的期望token数量作为α的函数,展示了不同γ值下的情况。
图2. 算法1生成的期望token数量作为α的函数,展示了不同γ值下的情况。

3.2 计算α

α的计算公式:现在我们为α推导一个简单的计算公式。首先定义一个自然散度DLKDLK(p, q) = Σx |p(x) - M(x)|,其中M(x) = (p(x) + q(x))/2。基于此,我们有以下引理:

  • 引理3.3: DLK(p, q) = 1 - Σx min(p(x), q(x))
    • 证明: DLK(p, q) = Σx |p(x) - M(x)| = Σx |(p-q)/2| = (1/2) * (Σx |p-q|) = (1/2) * Σx (p+q - |p-q|) = 1 - Σx min(p(x), q(x))

从引理3.3可以立即得到以下推论:

  • 推论3.4: DLK(p, q)是一个在[0, 1]区间内的对称散度。DLK(p, q) = 0当且仅当p = qDLK(p, q) = 1当且仅当pq的支撑集不相交。

接受率βDLK散度之间的关系由以下定理给出:

  • 定理3.5: β = 1 - DLK(p, q)

最后,我们得到α的计算公式:

  • α = E(β) = E(Σx min(p(x), q(x)))

根据定义,β是接受一个从q(x)中采样的token的概率,即Ex∼q(x) min(1, p(x)/q(x))。展开后可以得到:
Ex∼q(x) min(1, p(x)/q(x)) = Σx q(x) min(1, p(x)/q(x)) = Σx min(q(x), p(x))
因此,α的最终计算公式为α = E[Σx min(p(x), q(x))]。表3中展示了在我们的实验中凭经验观察到的α值。

3.3 壁钟时间(Walltime)改进

改进因子分析:我们已经证明,在独立同分布的假设下,我们的算法将目标模型的调用次数减少了(1 - α^(γ+1)) / (1 - α)倍。推测执行通常假设有足够的计算资源来支持增加的并发性(见3.4节)。在进行壁钟时间分析时,我们假设可以并行运行γ + 1Mp的评估而不会增加壁钟时间。为了得到总的壁钟时间改进,我们现在需要考虑运行近似模型Mq的成本。我们定义成本系数cMq单次运行时间与Mp单次运行时间之比。c的值取决于硬件配置和软件实现细节,在我们的实验中,c通常小于0.05。

  • 定理3.8: 算法1在总壁钟时间上的期望改进因子为 (1 - α^(γ+1)) / ((1 - α)(γc + 1))

    • 证明: 记Mp单步运行成本为T。算法1每次运行的成本是T * c * γ + T(运行Mq γ次和Mp一次),平均产生 (1 - α^(γ+1)) / (1 - α) 个token。因此,使用算法1生成一个token的期望成本为 T * (cγ + 1)(1 - α) / (1 - α^(γ+1))。由于标准解码算法生成一个token的成本是T,两相比较即可得到结果。
  • 推论3.9: 如果α > c,那么存在一个γ值可以带来改进,并且改进因子至少为 (1 + α) / (1 + c)

    • 证明: 如果对于某个γ有改进,那么对于任何0 < γ* < γ也一定有改进。因此,要判断我们的方法是否能带来改进,可以评估当γ = 1时定理3.8的公式,得到 (1 - α^2) / ((1 - α)(c + 1)) = (1 + α) / (1 + c)。如果这个值大于1,即α > c,则方法有效。

3.4 算术运算数量

运算量分析:算法1并行运行γ + 1Mp,因此并发算术运算的数量增加了γ + 1倍。由于算法1每次运行最多产生γ + 1个token,总的算术运算量可能高于标准解码算法。当我们接受来自Mq的样本时,增加的并发性是“免费的”,总运算量没有增加。但是,当我们拒绝一个猜测时,计算就被浪费了。我们定义ĉ为近似模型Mq每个token的算术运算量与目标模型Mp的比值。

  • 定理3.11: 算法1导致的总运算量期望增加因子为 (γĉ + γ + 1) / ( (1 - α^(γ+1)) / (1 - α) )
    • 证明: 记为标准解码基线每个token的算术运算量(即Mp单次运行的运算量)。那么算法1单次迭代的成本是T̂ * ĉ * γ + T̂ * (γ + 1)γMq运行和γ + 1次并行的Mp运行)。将此成本除以算法1产生的期望token数(公式1),再除以,即可得到结果。

如果α值较低,算术运算量的增加会很高,反之亦然。值得注意的是,对于Transformer解码器,算法1的总算术运算量(不计Mq的运行)的上限可以被同等大小的Transformer编码器单次运行的运算量所界定。与总算术运算量不同,总内存访问次数可能通过我们的方法减少。具体来说,目标模型的权重和KV缓存可以在算法1的每次执行中只读取一次,因此读取它们的内存访问次数减少了(1 - α^(γ+1)) / (1 - α)倍。

3.5 选择γ

最优γ的选择:在给定cα并假设有足够计算资源的情况下,最优的γ是使壁钟时间改进公式(定理3.8)最大化的那个值:(1 - α^(γ+1)) / ((1 - α)(γc + 1))。由于γ是整数,可以通过数值方法轻松找到,如图3所示。表1和图4展示了在不同αγ值下,推理速度与总算术运算量之间的权衡(假设c = ĉ = 0)。图5展示了一个简化的执行追踪图。

图3. 最优γ作为α的函数,展示了不同c值下的情况。
图3. 最优γ作为α的函数,展示了不同c值下的情况。

动态γ的可能性:除了根据α选择一个固定的γ值外,由于β不是恒定的,我们还可以通过预测β的值并相应地在算法1的运行期间动态改变γ来获得进一步的改进。为了得到额外改进因子的上界,假设我们有一个γ的神谕(oracle)。那么我们期望的生成token数为1 / (1 - α)。对于典型的cα值,并假设计算资源无限,增强后的壁钟时间改进因子可以比固定γ的改进因子高出约60%。这部分工作留待未来探索。

表1. 在不同γ和α值下,总算术运算量和推理速度与基线的对比,假设c = ĉ = 0。
表1. 在不同γ和α值下,总算术运算量和推理速度与基线的对比,假设c = ĉ = 0。

图4. 在不同γ值下,加速因子和算术运算量增加量作为α的函数。
图4. 在不同γ值下,加速因子和算术运算量增加量作为α的函数。

图5. 一个完整的编码器-解码器Transformer堆栈的简化追踪图。顶行显示了γ = 7的推测解码,因此每次调用Mp(紫色块)之前都有7次调用Mq(蓝色块)。左侧的黄色块是Mp编码器的调用,橙色块是Mq编码器的调用。同样,中间行显示了γ = 3的推测解码,底行显示了标准解码。
图5. 一个完整的编码器-解码器Transformer堆栈的简化追踪图。顶行显示了γ = 7的推测解码,因此每次调用Mp(紫色块)之前都有7次调用Mq(蓝色块)。左侧的黄色块是Mp编码器的调用,橙色块是Mq编码器的调用。同样,中间行显示了γ = 3的推测解码,底行显示了标准解码。

3.6 近似模型

近似模型的选择:推测采样和推测解码保证了对于任何近似模型Mq的选择,输出分布都保持一致。在实验中,我们主要测试了现成的、较小的Transformer作为近似模型。我们选择的近似模型与目标模型Mp具有相同的架构,并使用相同的概率标准化方法。在这种设置下,选择比Mp小大约两个数量级的Mq通常表现最好,因为它在αc之间取得了良好的平衡。

其他类型的近似模型
* 可忽略成本模型:这类模型的c ≈ 0。在这种情况下,期望的壁钟时间改进为(1 - α^(γ+1)) / (1 - α),其上界为1 / (1 - α)。一个有趣的例子是n-gram模型,其评估仅相当于查表。在经验测试中,即使是这些简单的n-gram模型也能获得非零的α。例如,在英德翻译任务中,当Mp为T5-XXL 11B而Mq为一个简单的二元模型(bigram model)时,我们得到α ≈ 0.2,这在γ = 3时能带来1.25倍的推理速度提升。
* 启发式模型:其他简单的启发式方法也可以用作可忽略成本的近似模型。例如,在长序列很可能重复的情况下(如摘要任务或聊天界面),一个简单的近似模型可以从上下文中找到匹配的前缀并直接复制token,这可能会产生很高的α值。
* 非自回归模型:推测解码也可以使用非自回归模型【Stern et al., 2018】作为近似模型。这样,我们只需调用一次非自回归模型,而不是算法1中的自回归循环。
* 随机模型:从理论角度看,一个随机选择token的近似模型也能保证对所有Mp模型带来一些(尽管非常小)的改进。

A4 实验

实验环境

  • 模型架构与参数
    • T5: 目标模型Mp为T5-XXL v1.1(11B参数)。近似模型Mq测试了T5-large(800M)、T5-base(250M)和T5-small(77M)。所有模型均使用预训练检查点。
    • GPT-like: 目标模型Mp是一个97M参数的类GPT解码器。近似模型Mq是一个6M参数的Transformer解码器,以及简单的unigram和bigram模型。
    • LaMDA: 目标模型Mp为LaMDA(137B参数)。近似模型Mq测试了LaMDA 8B、LaMDA 2B和LaMDA 100M。
  • 数据集与任务
    • T5: 任务(1)在WMT EnDe上微调的英德翻译;任务(2)在CNN/DM上微调的文本摘要。
    • GPT-like: 在lm1b数据集上训练的无条件语言生成。
    • LaMDA: 一个对话任务。
  • 硬件配置
    • 所有壁钟时间测量均在单块TPU-v4上进行。
  • 软件配置
    • 算法实现在T5X代码库中完成,并与其中的基线实现进行比较。
    • 批量大小(batch size)为1。
    • 测试了两种采样设置:argmax采样(temperature=0)和标准采样(temperature=1)。

实验结果

A4.1 实际壁钟时间改进

我们实现了推测解码算法,并将其与T5X代码库中用于加速T5-XXL的实现进行了比较。实验结果如表2所示。
* 核心发现: 实验表明,T5-small(77M)在成本系数c和接受率α之间取得了最佳平衡,为T5-XXL(11B)提供了最高的加速比。
* 具体数据:
* 在英德翻译任务中,当温度为1时获得2.6倍加速,温度为0(argmax采样)时获得3.4倍加速。
* 在文本摘要任务中,加速比略低,温度为1时为2.3倍,温度为0时为3.1倍
* 结论: 实验结果与理论预测吻合良好,存在的微小差异主要源于实现细节的优化差异(见附录A.3)。argmax采样(temp=0)由于其生成的概率分布更尖锐,通常能获得更高的α值,从而带来更大的加速。

表2. 加速T5-XXL 11B模型推理的实证结果。
表2. 加速T5-XXL 11B模型推理的实证结果。

A4.2 经验α值

为了更广泛地验证方法的有效性,我们在多种任务、模型和采样设置下测量了α值(基于10K个由Mp生成的token进行评估)。
* 核心发现:
* 当近似模型Mq比目标模型Mp小大约两个数量级时,通常能产生0.5到0.9之间的α值。
* 调整后的概率分布越“尖锐”(例如argmax采样),α值越高。
* 即使是简单的n-gram近似模型也能产生不可忽略的α值。例如,在英德翻译任务中,使用bigram模型作为Mqα值约为0.2。由于c在这种情况下几乎为0,这能带来1.25倍的速度提升,对于如此简单的近似模型来说,这个结果令人惊讶。
* 具体数据: 表3总结了在不同模型和任务下的α值。例如,在LaMDA 137B模型上,使用LaMDA 2B作为近似模型,在标准采样(T=1)下α值为0.82,在argmax采样(T=0)下α值高达0.9。

表3. 各种目标模型Mp、近似模型Mq和采样设置下的经验α值。T=0和T=1分别表示argmax采样和标准采样。
表3. 各种目标模型Mp、近似模型Mq和采样设置下的经验α值。T=0和T=1分别表示argmax采样和标准采样。

A5 结论与讨论

本文介绍了推测采样,它使得在随机设置下进行高效的推测执行成为可能。我们分析了它通过推测解码对自回归模型(如Transformer)解码过程的影响,并证明在有足够计算资源的情况下,与流行的优化实现T5X相比,我们在实践中获得了显著的2到3倍的加速。

局限性:推测执行的一个普遍局限性是通过增加并发性来换取延迟的降低,但这会以增加算术运算总量为代价。因此,对于没有额外可用计算资源的配置,我们的方法没有帮助。

优势与应用场景:然而,在内存带宽是瓶颈而计算资源有富余的常见情况下,我们的方法提供了显著的优势:
1. 无需修改模型架构
2. 无需重新训练
3. 最重要的是,输出分布保证与原模型完全相同
该方法易于实现,可用于加速现成的模型,而无需开发和评估定制方案。

未来工作展望
1. 与集束搜索(Beam Search)的兼容性:进一步研究将推测解码与集束搜索相结合的可能性(见附录A.4)。
2. 定制近似模型:虽然使用现成的近似模型已取得显著加速,但通过定制近似模型(如定制架构、非自回归模型或各种启发式方法)或定制训练过程(如从Mp进行知识蒸馏或直接优化Mq以提高α)可能会获得更大的改进。
3. 分层推测解码:探索算法的分层版本,即近似模型本身也由一个更快的模型来加速,这可能允许使用能力更强的近似模型。
4. 动态调整:在推理过程中动态改变近似模型和猜测数量γ可能会带来额外的性能提升。
5. 跨领域应用:将推测解码应用于文本以外的其他领域(如图像生成)。
6. 更广泛的应用:推测采样本身可以用于自回归模型解码之外的场景。例如,在物理模拟或强化学习中,当一个慢函数f(x)生成一个分布,而另一个慢函数g(y)的输入从该分布中采样时,我们的方法可以用来并行运行fg

A6 附录

A.1 推测采样的正确性证明

证明过程:我们将证明对于任何分布p(x)q(x),通过推测采样得到的token分布与直接从p(x)采样的分布是相同的。设β为接受概率。调整后的分布p'(x)的归一化常数为1 - β

公式
公式

其中:
x被接受的概率由以下公式给出:

公式
公式

x被拒绝后从调整分布中采样的概率为:

公式
公式

将两者相加,我们得到最终的概率:

公式
公式

这证明了最终的采样概率等于p(x),符合我们的要求。

A.2 推测采样与拒绝采样的比较

拒绝采样的过程:拒绝采样是一个表面上与我们的方法相似的迭代采样过程:
1. 从q(x)中采样x,从均匀分布U(0, 1)中采样r
2. 如果r < p(x) / (M * q(x)),则返回x。其中M = max_x(p(x)/q(x))
3. 回到步骤1。

效率对比:如果我们采用拒绝采样的非迭代版本(即步骤1和2失败后直接从p(x)采样),其效率会远低于我们的方法。具体来说,拒绝采样的期望接受概率为 Ex∼q(x)[p(x) / (M * q(x))] = Σx p(x) / M ≤ Σx min(p(x), q(x)) = α。这个接受概率可能远低于我们方法中的接受概率α

A.3 理论预测与实际运行时间的对比

理论与实践的差异:表4比较了基于定理3.8的期望运行时间改进与表2中凭经验测量的运行时间。我们根据性能分析器(profiler)的追踪数据估算了各种模型的c值。可以看出,理论预测在很大程度上与测量的运行时间相符。较大的差异主要由两个原因造成:(1)我们的实现与基线实现之间的优化差异;(2)β值是独立同分布的这一简化假设只是一个近似。

表4. 期望改进因子(EXP)与凭经验测量的改进因子(EMP)的对比。
表4. 期望改进因子(EXP)与凭经验测量的改进因子(EMP)的对比。

应用思路:我们的方法可以在付出一些性能代价的情况下应用于集束搜索采样。给定原始的集束宽度w,我们可以使用近似模型Mq和更宽的集束u ≥ w执行γ步集束搜索。然后,我们可以使用Mp并行地检查所有候选序列(计算成本为(w + uγ)Mp的运行)。最后,对于每一步,只要满足top_w(Mp) ⊆ top_u(Mq),我们就可以接受Mq的猜测,从而获得与单独使用Mp进行常规集束搜索相同的结果。对这种设置下我们方法的分析更为复杂,我们将其留作未来的工作。

A.5 宽容度(Lenience)

引入宽容度参数:算法1的一个强属性是输出分布保证不变。但是,如果我们愿意允许一些有保证的变化,就可以获得进一步的推理速度提升。我们可以引入一个宽容度参数l ∈ [0, 1],并在算法1中将q(x)p(x)比较前,先将q(x)乘以l。这样做仍然能保持一个很好的保证,即任何token的采样概率都不会超过p(x)/l。例如,当l=0.1时,任何token的采样概率不会超过其真实概率的10倍,从而保证极度罕见的token仍然极度罕见。

对α的影响与结果:引入宽容度因子l后,α的计算变为 α = Σx min(p(x)/l, q(x))。表5显示了当Mp为T5-XXL(11B),Mq为T5-small(77M)时,不同l值对应的α值。假设c = 0.015,使用宽容度值1、0.5、0.3和0.1时,我们分别获得了2.5倍、3.1倍、3.6倍和5倍的改进因子。对于argmax采样(temperature=0),也可以采用类似的宽容度策略,例如当p(x)在一个l因子范围内接近最大概率时接受Mq的采样,这同样能带来α值的显著提升。

表5. 在标准采样下,当Mp为T5-XXL(11B)用于EnDe翻译任务时,不同l值对应的α值。
表5. 在标准采样下,当Mp为T5-XXL(11B)用于EnDe翻译任务时,不同l值对应的α值。