Faster Cascades via Speculative Decoding

文章标题:通过推测解码实现更快的级联
作者/机构:Harikrishna Narasimhan¹, Wittawat Jitkrittum¹, Ankit Singh Rawat¹, Seungyeon Kim², Neha Gupta³,†, Aditya Krishna Menon¹, Sanjiv Kumar¹
¹Google Research, ²Google DeepMind, ³Mistral AI


A1 主要贡献

大型语言模型(LLMs)在各种自然语言处理任务上取得了显著的质量提升,但代价是推理延迟的增加。为了在不过度牺牲质量的情况下降低LLMs的推理成本,研究者们提出了多种方法,其中两种典型的策略是模型级联(model cascading)和推测解码(speculative decoding)。这两种方法都通过交错使用不同大小的模型(例如一个小型和一个大型模型)来进行预测,但它们的机制根本不同。

  • 模型级联:采用一种“延迟规则”(deferral rule)来识别“困难”的输入,并且只在这些输入上调用较大的模型。
  • 推测解码:使用小型模型通过标准的自回归解码来起草一个token块,然后由大型模型并行验证这些草稿。

这两种方法各有优势:级联提供了引人注目的成本-质量权衡,有时甚至能超越大型模型本身的性能;而推测解码在保证质量不变的前提下,提供了显著的速度提升。鉴于它们的互补性,本文的核心问题是:我们能否结合这两种技术的优点?

本文通过设计新的推测级联(speculative cascading)技术来回答这个问题。该技术通过推测执行(speculative execution)的方式来实现级联的延迟规则:让小型模型自回归地生成草稿,大型模型并行地对这些草稿进行处理,以决定是否延迟。这种方法结合了两种策略的优点,能够产生比标准级联和推测解码更好的成本-质量权衡。

图1:通过延迟规则在小型和大型LM之间进行推测级联推理。
图1:通过延迟规则在小型和大型LM之间进行推测级联推理。

本文的主要贡献如下:

(i) 提出了一种通用的推测执行方案:该方案旨在模仿一个通用的目标分布,该分布交织了草稿模型(drafter)和验证模型(verifier)的分布。有损推测采样(Lossy speculative sampling)是该方案在特定目标分布下的一个特例(第4.1节)。
(ii) 将常见级联规则推测化:展示了如何通过将常见级联规则(如Chow规则和置信度差异阈值)的目标分布插入到我们的框架中,以推测方式实现它们。我们将这些方法称为推测级联(第4.2节)。
(iii) 刻画并实现了最优推测级联规则:从理论上刻画了推测级联的最优延迟规则,并设计了一种实现该最优规则的插件估计(plug-in estimate)的推测级联技术(第4.3节,引理4,表1)。此外,还提出了这些延迟规则的特定于token的变体(第5节)。
(iv) 实验验证:在Gemma和T5模型上,针对一系列语言基准任务(包括摘要、翻译、推理、编码和问答)进行了实验,结果表明,推测级联比其顺序级联和推测解码的基线方法能够提供更好的成本-质量权衡(第6节)。


A3 背景知识与关键观察

2. 两种高效的语言模型推理策略

基础设定。设 $V$ 是一个有限的词汇表, $V^*$ 是由该词汇表生成的所有有限长度序列的集合。$\Delta_V$ 表示 $V$ 中所有token上的概率分布集合。给定任意长度的序列 $x = x_1x_2...x_L \in V^*$ 和索引 $i \leq L$,我们用 $x_{<i} = x_1x_2...x_{i-1}$ 表示前缀。语言模型(LM)是 $V^*$ 上的一个概率分布。设 $P$ 表示数据生成概率分布。我们将根据LM模仿 $P$ 的接近程度来衡量其质量。假设我们有两个LM,$q$ 和 $p$,其中 $p$ 是更大(更昂贵)的模型。我们的目标是设计一种推理策略,选择性地调用 $q$ 和 $p$ 以权衡质量和延迟。我们将 $q(x_t|x_{<t})$ 表示为模型 $q$ 在给定前缀 $x_{<t}$ 的情况下,为token $x_t$ 分配的概率, $p(x_t|x_{<t})$ 同理。为方便起见,我们使用简写 $q_t(·)$ 表示 $q(·|x_{<t})$ ,$p_t(·)$ 表示 $p(·|x_{<t})$ 。</p>

级联(Cascades)。级联是一种有效的策略,通过让较小的模型 $q$ 处理“简单”的样本,而让较大的模型 $p$ 处理“困难”的样本来权衡成本和质量。一种常见的级联方法是置信度阈值法或Chow规则【12, C Chow. On optimum recognition error and reject tradeoff. 1970. IEEE Transactions on information theory】,即首先在输入上运行 $q$ ,当 $q$ 对其生成响应的置信度足够低时,则延迟给 $p$ 。这种策略通常在序列级别实现,对于给定的前缀 $x_{<m}$ ,我们调用 $q$ ,评估其在所有响应上的最大条件概率,并检查其是否低于阈值 $\alpha \in [0, 1]$:</p>

$$\max_{x_m \dots x_{m+n}} q(x_m \dots x_{m+n} \mid x_{<m}) < 1 - \alpha.$$ <p>如果该条件成立,我们延迟给 $p$ 生成新的响应;否则,我们用 $q$ 生成响应。可以通过调整 $\alpha$ 来实现期望的成本-质量权衡。

表1:与不同推理算法相关的目标分布,其中α是自由参数,β ≥ 1 − α取决于α、q和p。最后一列表示执行是顺序的(算法2)、通过预言机(算法3)还是推测性的(算法5)。有关δ的详细信息,请参见(6)。第三行提出了Kim等人(2023)的BiLD算法的一个变体,其中D(q, p)是q和p之间差异的度量;原始算法的不同之处在于使用具有动态草稿窗口的确定性推测解码过程(参见§B)。
表1:与不同推理算法相关的目标分布,其中α是自由参数,β ≥ 1 − α取决于α、q和p。最后一列表示执行是顺序的(算法2)、通过预言机(算法3)还是推测性的(算法5)。有关δ的详细信息,请参见(6)。第三行提出了Kim等人(2023)的BiLD算法的一个变体,其中D(q, p)是q和p之间差异的度量;原始算法的不同之处在于使用具有动态草稿窗口的确定性推测解码过程(参见§B)。

推测解码(Speculative decoding)。推测解码是另一种策略,它在 $q$ 和 $p$ 之间进行token级别的交错,旨在以降低的推理成本可证明地匹配大模型的质量。给定前缀 $x_{<t}$ ,我们通过从 $q$ 进行自回归采样来草拟 $\gamma$ 个草稿token $x_t, ..., x_{t+\gamma-1}$ ,并通过在 $\gamma$ 个前缀 $x_{<t}, ..., x_{<t+\gamma-1}$ 上并行运行 $p$ 来验证这些token是否可以被接受。然后我们回滚到第一个被拒绝的token $t+j^*$(其中 $j^* \in \{0, 1, ..., \gamma-1\}$ ),用一个新的token替换 $x_{t+j^*}$ ,并用前缀 $x_{<t+j^*+1}$ 重复该过程。在验证阶段,由 $q$ 生成的草稿token $x_{t+j}$ 以概率 $\min(1, \frac{p_{t+j}(x_{t+j})}{q_{t+j}(x_{t+j})})$ 被接受,否则被拒绝。然后,一个被拒绝的token会被一个新的token替换,该token从修改后的分布 $norm(\max\{0, p_{t+j}(·) - q_{t+j}(·)\})$ 中采样,其中 $norm(·)$ 表示归一化使其总和为1。这个采样过程被证明等同于从 $p$ 中为前缀 $x_{<t}$ 自回归地采样 $\gamma$ 个token。</p>

有损推测解码(Lossy speculative decoding)。在实践中,可以使用上述采样的有损变体【55, Vivien Tran-Thien. An optimal lossy variant of speculative decoding. 2023. Unsupervised Thoughts (Blog)】,它允许与验证器分布 $p$ 有一些偏差。在这种情况下,草稿token $x_{t+j}$ 以概率 $\min(1, \frac{p_{t+j}(x_{t+j})}{(1-\alpha) \cdot q_{t+j}(x_{t+j})})$ 被接受,其中 $\alpha \in [0, 1)$ 是一个严格性参数,值越高表示与 $p$ 的偏差越大。被拒绝的token可以被从残差分布 $norm(\max\{0, \frac{1}{\beta} \cdot p_{t+j}(·) - q_{t+j}(·)\})$ 中采样的token替换,其中 $\beta \geq 1 - \alpha$ 是一个取决于 $\alpha$、$q$ 和 $p$ 的参数。一个常见的启发式方法是简单地设置 $\beta=1$ 。

3. 级联与推测解码的结合

3.1 热身:Token级级联

Token级Chow规则。将序列级的Chow规则扩展到token级级联是直接的。对于前缀 $x_{<t}$ ,我们首先计算小模型的分布 $q_t(·)$ ,并检查 $\max_{v \in V} q_t(v)$ 是否低于预选的阈值。如果是,我们评估 $p_t(·)$ 并采样 $x_t \sim p_t(·)$ ;否则,我们采样 $x_t \sim q_t(·)$ 。更一般地,我们可以设计一个token级别的延迟规则 $r: V^{t-1} \rightarrow \{0, 1\}$ ,它以 $x_{<t}$ 为输入并输出一个二元决策,其中 $r(x_{<t}) = 1$ 表示我们延迟到 $p$ 。Token级别的Chow规则可以写成:</p>

$$r_{\text{Chow}}(x_{<t}) = 1 \iff \max_{v \in \mathcal{V}} q_t(v) < 1 - \alpha,$$ <p>其中 $\alpha$ 是一个阈值参数。

3.2 最优的Token级级联延迟规则

最优延迟规则的目标。Chow规则虽然易于实现,但可能不是最优的。一个自然的问题是,一个token级级联的最优延迟规则 $r$ 是什么样的。为此,我们必须首先指定在每个步骤 $t$ 最小化的目标。一个合理的目标是最小化延迟规则相对于数据生成分布 $P$ 的预期损失,并为延迟到大模型增加一个成本。对于固定的前缀 $x_{<t}$,该目标如下:</p>

$$L_{\text{def}}(r; x_{<t}) = \mathbb{E}_{v \sim \mathbb{P}(\cdot|x_{<t})} \left[ (1 - r(x_{<t})) \cdot \ell(v, q_t) + r(x_{<t}) \cdot (\ell(v, p_t) + \alpha) \right],$$ <p>其中成本惩罚为 $\alpha \geq 0$ ,损失函数为 $\ell: V \times \Delta_V \rightarrow \mathbb{R}_+$ 。常见的损失函数选择包括0-1损失 $\ell_{0-1}(v, q_t) = \mathbf{1}(v \neq \arg\max_{v'} q_t(v'))$ 和对数损失 $\ell_{\log}(v, q_t) = -\log(q_t(v))$ 。

引理1(Token级级联的最优延迟规则)。上述目标的最小化器形式如下:

$$r^*(x_{<t}) = 1 \iff \mathbb{E}_{v \sim \mathbb{P}(\cdot|x_{<t})} [\ell(v, q_t)] > \mathbb{E}_{v \sim \mathbb{P}(\cdot|x_{<t})} [\ell(v, p_t)] + \alpha.$$ <p>直观上,我们比较 $q$ 的预期损失和调用 $p$ 的预期成本,当后者更小时决定延迟。

最优规则的插件估计器。最优规则需要计算对数据生成分布 $P(·|x_{<t})$ 的期望,这在推理时是不可用的。一个常见的方法是用模型的置信度估计来代替预期损失。例如,当损失函数为0-1损失时,可以使用 $1 - \max_v q_t(v)$ 作为 $q_t$ 的预期0-1损失的估计,用 $1 - \max_v p_t(v)$ 作为 $p_t$ 的预期0-1损失的估计。由此产生的插件估计器(称为Diff规则)对来自两个分布的置信度估计的差异进行阈值判断:</p>

$$\hat{r}_{\text {Diff }}(x_{<t})=1 \iff \max _{v} q_{t}(v)<\max _{v} p_{t}(v)-\alpha.$$ <p>备注1(Diff规则在token级级联中不可实现)。我们不能直接在token级级联中使用 $\hat{r}_{\text{Diff}}$,因为它需要在每一步 $t$ 都调用大模型,这违背了级联的目的。然而,它可以作为一个预言机(oracle),用于分析改进Chow规则的潜力空间。

3.3 对比Token级级联和推测解码的权衡

分布差异。Token级级联和推测解码旨在模仿的token分布不同。推测解码旨在模仿大模型的输出分布,而token级级联旨在输出能够紧密近似标签分布的分布,并可能提供良好的成本-质量权衡,有时甚至能产生比大模型更好的质量。当草稿模型在某些输入上表现优于验证器时,级联非常有用。

实验观察。我们考虑使用在WMT EN→DE翻译任务和XSum摘要任务上微调的T5模型构建token级级联。我们使用Chow规则和作为预言机的Diff规则来构建这些级联。图2中,我们绘制了质量与延迟到大模型的样本比例的函数关系图。推测解码保证匹配大模型的质量,而基于Diff规则的级联产生了一系列的成本-质量权衡,通常优于大模型。即使是次优的Chow规则,在小范围内也优于大模型,这可能归因于级联中的集成效应。然而,与推测解码相比,token级级联可能需要更多次地延迟到大模型才能达到相同的质量,因为它们是顺序执行的。推测解码则并行地验证草稿token,并且其随机验证算法通常导致更少的token被拒绝。这些观察引出了一个问题:我们如何才能利用这两种技术的最佳之处?

图2:使用T5模型构建的级联中,质量与延迟到大模型token数占总生成token数的比例的函数关系图(温度采样T=1)。最左边的点代表小模型,最右边的点代表大模型。我们比较了使用Chow规则(Chow)和预言机延迟规则(Diff)构建的token级级联。虽然推测解码将匹配大模型的质量(见虚线水平线),但预言机延迟规则在一系列延迟率上产生了明显更好的质量。
图2:使用T5模型构建的级联中,质量与延迟到大模型token数占总生成token数的比例的函数关系图(温度采样T=1)。最左边的点代表小模型,最右边的点代表大模型。我们比较了使用Chow规则(Chow)和预言机延迟规则(Diff)构建的token级级联。虽然推测解码将匹配大模型的质量(见虚线水平线),但预言机延迟规则在一系列延迟率上产生了明显更好的质量。

A2 方法细节

4. 推测级联:结合两者的优点

本节介绍了我们的主要贡献:推测级联(speculative cascades),这是一种原则性的方法,将token级级联提供的更好权衡与推测解码的更快执行相结合。与token级级联不同,推测级联在每 $\gamma$ 个草稿token后都以并行评分模式调用大模型。因此,它们能够实现顺序级联无法实现的延迟规则,从而可能实现更低的延迟。

4.1 具有通用目标分布的推测解码

通用推测采样。我们首先考虑一个通用版本的推测解码,它旨在模仿一个从草稿模型和验证模型的分布中派生出的通用目标分布。在算法4中概述的采样过程中,我们像以前一样从草稿模型的分布中自回归地采样token。然而,在验证步骤中,我们不将草稿模型的token概率与验证模型的分布进行比较。相反,我们使用一个用户指定的目标分布 $\pi = T(q, p) \in \Delta_V$,该分布是从草稿模型和验证模型在位置 $t$ 的分布派生出来的,其中 $T(·, ·)$ 是一个计算成本不高的函数。当 $q(x_t) \leq \pi(x_t)$ 时,我们接受草稿token $x_t$,否则以 $1 - \frac{\pi(x_t)}{q(x_t)}$ 的概率拒绝它。拒绝后,我们从残差分布 $norm(\max\{0, \pi(·) - q(·)\})$ 中重新采样。这个通用过程不仅包含了标准推测解码($T(q, p) = p$),还包括了有损推测解码作为一个特例。

引理2。当 $T(q, p)(v) = \max\{\min\{q(v), \frac{p(v)}{1-\alpha}\}, \frac{p(v)}{\beta}\}$ 时,算法4简化为带有参数 $\alpha$ 和 $\beta$ 的有损推测采样过程。

4.2 从顺序级联到推测级联

将延迟规则转化为目标分布。我们现在提出新的级联技术,以推测方式实现其延迟规则。一个token级级联的下一个token分布可以写成:

$$\pi(v)=(1-r(x_{<t})) \cdot q_{t}(v)+r(x_{<t}) \cdot p_{t}(v).$$ <p>对于所有描述的延迟规则,结果分布可以由形式为 $T_\delta$ 的目标分布函数描述:

$$\mathbb{T}_{\delta}(q, p)(v)=(1-\delta(q, p)) \cdot q(v)+\delta(q, p) \cdot p(v),$$

其中函数 $\delta: \Delta_V \times \Delta_V \rightarrow \{0, 1\}$ 将分布 $(q, p)$ 映射到一个二元决策。例如,要实现Chow规则,我们可以选择 $\delta(q, p) = \mathbf{1}(\max_v q(v) < 1 - \alpha)$。

推测级联的提出。我们的提议是使用 $T_\delta$ 作为目标分布函数来调用算法4中的推测采样过程。我们在算法5中概述了这种通用的推测级联方法,并将其与算法2中延迟规则的顺序执行进行对比。有趣的是,之前无法实现的Diff规则,现在可以通过推测级联高效实现,使用 $\delta(q, p) = \mathbf{1}(\max_v q(v) < \max_v p(v) - \alpha)$ 。

备注2(Diff规则在推测级联中可实现)。在token级级联中,大模型的分布 $p$ 在做出延迟决定时不能使用。然而,通过推测级联,我们可以使用像Diff这样依赖于 $q$ 和 $p$ 的规则。这是因为我们并行地在小模型 $q$ 生成的草稿上运行大模型 $p$ ,从而可以在每个前缀上计算 $p(·)$ 和 $q(·)$ 。

4.3 最优的推测级联延迟规则

目标。我们寻求一个延迟规则 $r: V^{t-1} \rightarrow \{0, 1\}$,它能在质量和推理成本之间进行权衡。推理成本主要取决于验证阶段草稿token被拒绝的频率。因此,我们旨在最小化延迟规则的预期损失,同时对由此产生的拒绝率施加约束。

引理3。对于给定的前缀 $x_{<t}$ 和目标分布 $\pi = (1 - r(x_{<t})) \cdot q_t + r(x_{<t}) \cdot p_t$,从草稿分布 $q_t$ 中抽取的token被拒绝的概率等于 $r(x_{<t}) \cdot D_{TV}(p_t, q_t)$,其中 $D_{TV}(p, q)$ 是 $p$ 和 $q$ 之间的总变差距离。</p>

优化问题。对于固定的前缀 $x_{<t}$,我们寻求解决以下问题:<br /> $\min \mathbb{E}_{v \sim P(\cdot|x_{<t})} [(1 - r(x_{<t})) \cdot \ell(v, q_t) + r(x_{<t}) \cdot \ell(v, p_t)]$<br /> 约束条件为 $r(x_{<t}) \cdot D_{TV}(p_t, q_t) \leq B$,对于某个预算 $B > 0$。这等价于最小化一个无约束的目标(拉格朗日形式):<br /> $L_{\text{spec}}(r; x_{<t}) = \mathbb{E}_{v \sim P(\cdot|x_{<t})} [(1 - r(x_{<t})) \cdot \ell(v, q_t) + r(x_{<t}) \cdot (\ell(v, p_t) + \alpha \cdot D_{TV}(p_t, q_t))]$<br /> 与token级级联的延迟风险相比,延迟到大模型的成本不再是一个常数,而是取决于 $q_t$ 和 $p_t$ 之间的相似性。

引理4(推测级联的最优延迟规则)。上述目标的最小化器形式如下:

$$r^*(x_{<t}) = 1 \iff \mathbb{E}_{v \sim \mathbb{P}(\cdot|x_{<t})} [\ell(v, q_t)] > \mathbb{E}_{v \sim \mathbb{P}(\cdot|x_{<t})} [\ell(v, p_t)] + \alpha \cdot D_{\text{TV}}(p_t, q_t)$$ <p>当 $p_t$ 和 $q_t$ 相似时,拒绝率低,延迟决策主要取决于哪个模型产生更低的预期损失。当 $p_t$ 和 $q_t$ 非常不同时,只有当 $p_t$ 产生的损失显著低于 $q_t$ 时,才延迟到 $p_t$ 。

插件估计器。最优规则需要估计对数据生成分布 $P(·|x_{<t})$ 的期望。我们采用与token级级联类似的插件估计器。当损失为0-1损失时,我们将预期0-1损失替换为(1减去)模型的最大概率,得到:</p>

$$\hat{r}_{\text{OPT}}(x_{<t}) = 1 \iff \max\nolimits_v q_t(v) < \max\nolimits_v p_t(v) - \alpha \cdot D_{\text{TV}}(p_t, q_t).$$ <p>引理5($\hat{r}_{OPT}$的遗憾界)。该插件估计器的有效性由以下遗憾界形式化:

$$L_{\text{spec}}(\hat{r}_{\text{OPT}}; x_{<t}) - \min_{r} L_{\text{spec}}(r; x_{<t}) \leq \max_{v \in \mathcal{V}} |\mathbb{P}_t(v) - q_t(v)| + \max_{v \in \mathcal{V}} |\mathbb{P}_t(v) - p_t(v)|.$$ <p>现在可以使用该规则在算法5中运行推测级联过程。

算法 1 SpecDecode
输入: 模型 q, p, 前缀 x<t, 块大小 γ
T(q, p) := p
输出: GenSpecSample(q, p, T, x<t, γ)

算法 2 TokenCascade
输入: 模型 q, p, 延迟逻辑 δ, 前缀 x<t
qt(·) := q(·|x<t)
if δ(qt, ∅) = 0 then
    采样 xt ∼ qt(·)
else
    pt(·) = p(·|x<t); 采样 xt ∼ pt(·)
end if
输出: xt

算法 3 OracleCascade
输入: 模型 q, p, 延迟逻辑 δ, 前缀 x<t
qt(·) = q(·|x<t); pt(·) = p(·|x<t)
if δ(qt, pt) = 0 then
    采样 xt ∼ qt(·)
else
    采样 xt ∼ pt(·)
end if
输出: xt

算法 4 GenSpecSample
输入: 模型 q, p, 目标分布 T, 前缀 x<t, 块大小 γ
// 从 q 自回归采样 γ 个 token
for j = 0 to γ − 1 do
    qt+j(·) := q(·|x<t+j); xt+j ∼ qt+j(·)
end for
// 并行运行 p 对 γ 个草稿 token 进行评分
pt+j(·) = p(·|x<t+j), ∀j ∈ [γ]
πt+j = T(qt+j, pt+j)
// 找到最早被拒绝的草稿 token
aj ∼ Ber(min{1, πt+j(xt+j) / qt+j(xt+j)}), ∀j ∈ [γ − 1]; aγ = 0
j∗ = min{j ∈ [γ] : aj = 0}
// 从残差分布中采样一个新 token
if j∗ < γ then
    pres(·) = norm(max{0, πt+j∗(·) − qt+j∗(·)})
else
    pres(·) = πt+γ(·)
end if
采样 xt+j∗ ∼ pres(·)
输出: xt, . . . , xt+j∗

算法 5 SpecCascade
输入: 模型 q, p, 延迟逻辑 δ, 前缀 x<t, 块大小 γ
Tδ(q, p) := (1 − δ(q, p)) · q + δ(q, p) · p
输出: GenSpecSample(q, p, Tδ, x<t, γ)

5. 超越级联延迟:特定于Token的交错

问题。到目前为止我们看到的延迟规则是通过比较 $q_t(·)$ 和 $p_t(·)$ 的最大token概率来决定使用哪个分布。这种方法的缺点是,采样的草稿token $x_t \sim q_t(·)$ 可能与使 $q_t(·)$ 最大化的token不同。因此,即使 $x_t$ 质量很差,我们最终也可能接受它,因为 $q_t$ 恰好比 $p_t$ 更尖锐。

特定于Token的交错。为了解决上述问题,我们提出了使用特定于token的延迟规则 $r: V^{t-1} \times V \rightarrow \{0, 1\}$,它同时使用前缀 $x_{<t}$ 和候选token $v$ 来提供一个二元决策 $r(x_{<t}, v) \in \{0, 1\}$,其中0表示该token质量可接受。然后我们可以构建以下形式的目标分布:</p>

$$\pi_{\text{Token}}(v) = q_t(v) \cdot (1 - r(x_{<t}, v)) + p_t(v) \cdot \eta,$$ <p>其中 $\eta = \sum_{v' \in V} r(x_{<t}, v') \cdot q_t(v')$ 是一个归一化项。该目标分布在延迟规则 $r$ 认为质量可接受的token上紧密模仿 $q_t(·)$,否则延迟到 $p_t(·)$。</p>

延迟规则设计。我们提出了Diff规则的一个启发式变体。具体来说,我们比较草稿token $v$ 是 $x_{<t}$ 的错误下一个token的概率(根据数据生成分布 $P$)与如果我们延迟到验证器 $p_t$ 将产生的预期0-1损失:</p>

$$r(x_{<t}, v) = 1 \iff 1 - \mathbb{P}(v | x_{<t}) > \mathbb{E}_{v' \sim \mathbb{P}(\cdot | x_{<t})} [\ell_{0\text{-}1}(v', p_t)] + \alpha,$$ <p>以下是一些简单的插件近似:

$$\begin{aligned} \begin{aligned} \hat{r}_{\text{TokenV1}}(x_{<t}, v) = 1 & \iff q_t(v) < \max_{v'} p_t(v') - \alpha \\ \hat{r}_{\text{TokenV2}}(x_{<t}, v) = 1 & \iff p_t(v) < \max_{v'} p_t(v') - \alpha \\ \hat{r}_{\text{TokenV3}}(x_{<t}, v) = 1 & \iff p_t(v) < \max_{v'} p_t(v') \cdot (1 - \alpha). \end{aligned} \end{aligned}$$ <p>例如,使用第一个近似规则,目标分布具有直观的形式:

$$\pi_{\text{TokenV1}}(v) = q_t(v) \cdot \mathbf{1}(v \in \mathcal{T}_\alpha) + p_t(v) \cdot \textstyle\sum_{v' \notin \mathcal{T}_\alpha} q_t(v'),$$

其中 $T_\alpha = \{v \in V: q_t(v) \geq \max_{v'} p_t(v') - \alpha\}$ 是被认为重要的token集合。

与有损推测采样的对比。有损推测采样也旨在模仿 $q_t$ 和 $p_t$ 的特定于token的交错,但其目标分布的选择可能严重限制可实现的成本-质量权衡范围,特别是在使用小温度或top-P采样时,因为 $q_t$ 和 $p_t$ 的支持集可能不重叠。相比之下,我们提出的方法通过在未缩放的 $q_t$ 和 $p_t$ 上计算延迟规则 $r$ ,同时在缩放后的版本之间进行交错,从而在各种采样方案下都能实现更广泛的权衡。有趣的是,我们提出的TokenV3方法在温度趋于0时,会简化为【32, Yaniv Leviathan, Matan Kalman, and Yossi Matias. Fast inference from transformers via speculative decoding. 2023. ICML】为贪婪解码提出的一个有损推测解码变体,因此它可以被看作是适用于贪婪和非贪婪解码的通用延迟规则。


A4 实验环境

数据集
* T5模型
* WMT EN→DE 翻译 【5, Ondˇrej Bojar, Christian Buck, Christian Federmann, et al. Findings of the 2014 workshop on statistical machine translation. 2014. WMT】
* CNN/DM 摘要 【24, Karl Moritz Hermann, Tomas Kocisky, Edward Grefenstette, et al. Teaching machines to read and comprehend. 2015. NeurIPS】
* XSum 抽象摘要 【39, Shashi Narayan, Shay B Cohen, and Mirella Lapata. Don’t give me the details, just the summary! topic-aware convolutional neural networks for extreme summarization. 2018. EMNLP】

  • Gemma模型:除了WMT和CNN/DM,还使用了以下6个基准测试,其中许多来自SpecBench套件【62, Heming Xia, Zhe Yang, Qingxiu Dong, et al. Unlocking efficiency in large language model inference: A comprehensive survey of speculative decoding. 2024. arXiv】:
    • GSM8K (数学推理) 【14, Karl Cobbe, Vineet Kosaraju, Mohammad Bavarian, et al. Training verifiers to solve math word problems. 2021. arXiv】
    • MBPP (Python编程) 【2, Jacob Austin, Augustus Odena, Maxwell Nye, et al. Program synthesis with large language models. 2021. arXiv】
    • SQuAD 2.0 (问答) 【44, Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, and Percy Liang. SQuAD: 100,000+ questions for machine comprehension of text. 2016. EMNLP】
    • WebQuestions (问答) 【3, Jonathan Berant, Andrew Chou, Roy Frostig, and Percy Liang. Semantic parsing on freebase from question-answer pairs. 2013. EMNLP】
    • NaturalQA (问答) 【31, Tom Kwiatkowski, Jennimaria Palomaki, Olivia Redfield, et al. Natural questions: A benchmark for question answering research. 2019. TACL】
    • TriviaQA (问答) 【27, Mandar Joshi, Eunsol Choi, Daniel Weld, and Luke Zettlemoyer. TriviaQA: A large scale distantly supervised challenge dataset for reading comprehension. 2017. ACL】

模型架构

  • T5:使用T5 v1.1系列编码器-解码器模型,包括T5-small (77M)、T5-large (800M)和T5-XL (3B)。
  • Gemma:使用Gemma v2指令调优的仅解码器模型。使用2B参数模型作为草稿模型,9B或27B模型作为验证模型。

硬件配置
* T5模型的实验在TPUv4设备上运行。

软件配置
* T5模型:使用公开检查点初始化,进一步预训练100K步,然后在各个任务上进行监督微调。
* Gemma模型:采用少样本(few-shot)推理方式。
* 通用参数
* 延迟测量:使用批大小为1的壁钟解码时间。
* 块大小 $\gamma$:T5实验中通常为5,Gemma实验中为1。
* 温度 T:0, 0.1, 0.5, 1.0。
* Top-P采样也进行了评估。


A5 实验结果

评估协议:本文的评估协议与标准推测解码算法的评估不同。标准评估的目标是在保持大模型输出分布的同时加速推理。而本文的重点是通过交错使用不同大小的两个模型来权衡质量与推理成本。

方法与基线

  • 提出的方法:评估了四种延迟规则的推测级联:(i) Chow,(ii) Diff,(iii) OPT,以及 (iv) Token-specific(V3变体)。
  • 基线方法
    • 序列级级联 (SeqCascade [Chow])
    • Token级级联 (TokenCascade [Chow])
    • 有损推测解码 (SpecDecode [Lossy]),包括 $\beta=1$ 和调优 $\beta$ 的版本
    • Big-Little Decoder (BiLD),包括原始确定性版本 (BiLD) 和随机拒绝采样变体 (BiLD*)

微调T5级联实验

  • 结果总结 (表2)
    • 在匹配大模型质量时,SpecCascade [OPT]SpecCascade [Token] 提供了最大的延迟降低(速度提升)。例如,在T5-small→T5-large的WMT任务上,SpecCascade [Token] 达到了1.85倍的加速。
    • 在不超过大模型延迟的前提下,SpecCascade [Token] 能够达到最高的质量。例如,在WMT任务上,其BLEU分数为22.50,显著高于大模型的17.26。
    • 顺序级联和token级级联在质量和延迟指标上通常表现较差。
    • SpecCascade [Token] 通常比SpecCascade [OPT]Diff 表现更好,因为后两者基于分布的峰值做决策,这在温度较高、采样token偏离众数时可能不利。
表2:在匹配大模型质量时的延迟降低(第2-7列),以及在不超过大模型延迟时的最佳质量指标(第8-13列)。质量通过WMT的BLEU和XSum及CNNDM的ROUGE-2来衡量。第1-4行是基线;第5-6行是使用旧延迟规则的提议方法(§3);第7-8行是使用新延迟规则的提议方法(§4)。
表2:在匹配大模型质量时的延迟降低(第2-7列),以及在不超过大模型延迟时的最佳质量指标(第8-13列)。质量通过WMT的BLEU和XSum及CNNDM的ROUGE-2来衡量。第1-4行是基线;第5-6行是使用旧延迟规则的提议方法(§3);第7-8行是使用新延迟规则的提议方法(§4)。

少样本Gemma级联实验

  • 结果总结 (图3)
    • 实验比较了在Gemma 2B(草稿)和27B(验证)模型上的性能,块大小 $\gamma=1$。
    • 除了TriviaQA任务外,SpecCascade [Token] 在以更低的拒绝率匹配27B模型质量,以及在实现最佳整体质量(通常优于27B)方面都表现出色。
    • 由于所有方法使用完全相同的实现,仅拒绝标准不同,因此直接比较它们的拒绝率是公平的。
    • 与T5模型不同,OPT规则在Gemma上效果不佳。这归因于Gemma 2B和27B之间巨大的参数差异,使得大模型的概率分布成为衡量置信度的更好指标,这正是TokenV3规则所做的。
    • 当小模型在部分数据上性能与大模型相当或更好时,推测级联相比于推测解码的优势更大。
图3:Gemma 2B与Gemma 27B交错使用时,质量与拒绝率的关系图(γ=1)。我们使用指令调优模型;对于MBPP,我们报告了使用预训练模型的额外结果。其余图表、与(13-14)的比较以及2B→9B级联的结果见§F.7。
图3:Gemma 2B与Gemma 27B交错使用时,质量与拒绝率的关系图(γ=1)。我们使用指令调优模型;对于MBPP,我们报告了使用预训练模型的额外结果。其余图表、与(13-14)的比较以及2B→9B级联的结果见§F.7。

A6 结论

本文提出了新的推测级联技术,它结合了自回归草稿生成和并行验证来实现其延迟规则。实验证明,这些技术比标准的顺序级联和推测解码能产生更好的成本-质量权衡。

局限性:本文方法虽然通过并行执行提供了更低的延迟,但与顺序级联相比,它也产生了更高的总计算成本。

未来工作

  1. 用一个在真实样本上训练的路由器模型(router model)来替代插件估计器,以更精确地近似最优规则。
  2. 将该提议扩展到两个以上模型的场景。

A7 附录

B 更多相关工作

改进推测解码的草稿过程。一些工作通过让草稿模型和验证模型共享骨干网络、使用多个小型草稿模型、使用树状结构的草稿批次、用验证模型蒸馏草稿模型以及利用多个采样草稿等方法来改进推测解码。

与BiLD的比较。与本文最相关的工作是Big Little Decoder (BiLD) 【29, Sehoon Kim, Karttikeya Mangalam, Suhong Moon, et al. Speculative decoding with big little decoder. 2023. NeurIPS】,它可以被看作是另一种有损推测解码。BiLD有两个阶段:一个回退阶段(fallback phase),在此期间草稿模型 $q$ 自回归运行,直到其最大预测概率足够低;以及一个回滚阶段(rollback phase),在此期间验证模型 $p$ 在 $q$ 生成的前缀上并行运行,并在差异度量 $D(q, p) > \alpha$ 的地方回滚。BiLD的目标分布旨在近似 $p$ 。因此,在 $q$ 与 $p$ 显著偏离的情况下,BiLD会选择 $p$ 作为目标分布,即使 $q$ 在某个前缀上提供更好的质量。相比之下,本文提出的方法使用推测解码来近似那些旨在在 $q$ 和 $p$ 之间实现最优级联的目标分布。

C 在不同采样方案下对比推测级联和有损推测采样

C.1 温度采样和Top-P采样下的推测级联

在实现带有温度采样和top-P采样的推测级联时,我们在原始分布 $p$ 和 $q$ 上计算延迟规则,但使用延迟决策来交错 $p$ 和 $q$ 的温度缩放(或top-P截断)版本。对于级联延迟规则,目标分布构建如下:

$$\tilde{\pi}_{t}(v)=\left(1-\delta\left(q_{t}, p_{t}\right)\right) \cdot \mathbb{S}\left(q_{t}\right)(v)+\delta\left(q_{t}, p_{t}\right) \cdot \mathbb{S}(p)(v),$$

其中 $S: \Delta_V \rightarrow \Delta_V$ 表示分布的变换。对于OPT规则,我们会在约束问题中使用变换后分布之间的TV距离。对于特定于token的延迟规则,我们也在未缩放的分布上计算规则,但在变换后的分布上构建目标分布。

C.2 温度采样下与有损推测采样的对比

在温度采样下实现有损推测解码时,接受标准和残差分布都是使用温度缩放后的分布计算的。由此产生的目标分布 $\pi_{\text{Lossy}}(v)$ 具有以下形式:

$$\pi_{\text {Lossy }}(v)=\max \left\{\min \left\{\mathbb{S}\left(q_{t}\right)(v), \frac{\mathbb{S}\left(p_{t}\right)(v)}{1-\alpha}\right\}, \frac{\mathbb{S}\left(p_{t}\right)(v)}{\beta}\right\}.$$

这种选择可能会严重限制可实现的成本-质量权衡范围,因为当 $S(p_t)(v) = 0$ 时,$\pi_{\text{Lossy}}(v)$ 也为0,这使得权衡参数 $\alpha$ 和 $\beta$ 对于这些token无效。当温度为0(即贪婪解码)时,问题更加严重,此时 $\pi_{\text{Lossy}}(v)$ 与 $p_t$ 的one-hot版本相同,使得 $\alpha$ 和 $\beta$ 完全无关。相比之下,由于我们的级联延迟规则和特定于token的延迟规则都使用未缩放的分布,它们即使在温度为0时也能提供有意义的权衡。

C.3 Top-P采样下与有损推测采样的对比

在top-P采样下实现有损推测采样时,当 $P$ 变小时,大多数token的概率变为0,导致权衡参数 $\alpha$ 在大多数候选token上无效。当 $P \rightarrow 0$ 时,有损推测解码变得与标准无损推测解码相同。推测级联不存在此问题,因为它使用权衡参数 $\alpha$ 来构建一个新的目标分布,该分布即使在top-P采样下也适合更高的接受率。

表4:不同推测推理策略在非贪婪解码和贪婪解码下的接受标准。第2行和第3行表明,在温度为0时,带有TokenV3延迟规则的推测级联与SpecDecode [Lossy, Greedy]具有相同的接受标准;见引理7。
表4:不同推测推理策略在非贪婪解码和贪婪解码下的接受标准。第2行和第3行表明,在温度为0时,带有TokenV3延迟规则的推测级联与SpecDecode [Lossy, Greedy]具有相同的接受标准;见引理7。
C.4 Leviathan等人(2023)的有损推测贪婪解码变体

对于贪婪解码的特殊情况,【32, Yaniv Leviathan, Matan Kalman, and Yossi Matias. Fast inference from transformers via speculative decoding. 2023. ICML】提出了另一种有损变体,其中当 $p_t(v) \geq (1-\alpha) \cdot \max_{v'} p_t(v')$ 时,草稿token $v$ 被确定性地接受。我们证明了带有TokenV3延迟规则的推测级联在温度为0时与此变体是相同的(引理7)。

D 最优延迟:额外讨论

Chow规则的推导。Chow规则是对一个优化问题的插件估计,该问题对延迟到大模型的选择施加一个恒定的惩罚 $\alpha$。

当损失为对数损失时的最优顺序/推测延迟。当损失函数 $\ell = \ell_{\log}$ 时,可以使用分布的熵作为其预期对数损失的估计,从而推导出相应的Diff和OPT规则的插件估计器。

贪婪解码下的最优推测延迟。在贪婪解码下,最优延迟规则OPT等价于Diff延迟规则(引理10)。

优化问题的等价性。证明了带约束的优化问题(7)和一个无约束的拉格朗日形式问题(8)是等价的。

E 特定于Token的推测级联

算法6。提供了算法5的修改版以适应特定于token的延迟规则。

最优的特定于Token的延迟。讨论了推导最优特定于token的延迟规则的复杂性。与之前的级联情况不同,这里的优化问题没有简单的闭式解。因此,在实验中,我们采用了受Diff规则形式启发的三个启发式规则。

F 额外的实验细节

实验设置和超参数。提供了T5和Gemma数据集的详细信息(验证集大小、最大输入/输出长度),模型微调细节,以及运行时评估协议(TPUv4,batch size 1,抽样500个样本)。还说明了BiLD基线的差异度量 $D(q, p)$ 的选择。

图4:T5模型在温度T=0.5和T=1,块大小γ=5时的质量与延迟关系图。每种方法都交错使用T5-small和T5-large。我们包括了使用Chow、Diff、OPT和TokenV3(称为Token)延迟规则的推测级联,并将其与三个基线进行比较:SpecDecode [Lossy]、TokenCascade [Chow]和BiLD*。x轴追踪相对于在所有输入上调用大模型的延迟。水平虚线表示大模型的质量。
图4:T5模型在温度T=0.5和T=1,块大小γ=5时的质量与延迟关系图。每种方法都交错使用T5-small和T5-large。我们包括了使用Chow、Diff、OPT和TokenV3(称为Token)延迟规则的推测级联,并将其与三个基线进行比较:SpecDecode [Lossy]、TokenCascade [Chow]和BiLD*。x轴追踪相对于在所有输入上调用大模型的延迟。水平虚线表示大模型的质量。

不同温度和贪婪解码下的比较。图5和图6展示了在不同温度(T=0, 0.1, 0.5, 1.0)下的质量与延迟图。贪婪解码下,所有方法的质量都更高,并且SpecCascade [TokenV3]等价于SpecDecode [Lossy, Greedy]。随着温度降低,SpecDecode [Lossy]产生的权衡点变少,而SpecCascade [Token]仍然提供广泛的权衡范围。

图5:T5模型在贪婪解码(T=0)且块大小γ=5时的质量与延迟关系图。每种方法都交错使用T5-small和T5-large。x轴追踪相对于在所有输入上调用大模型的延迟。水平虚线表示大模型的质量。SpecDecode [Lossy, Greedy]是Leviathan等人(2023)提出的有损推测解码的贪婪版本。SpecCascade [Token]使用(15)中的TokenV3延迟规则。如§C.4所述,当T→0时,SpecDecode [TokenV3]与SpecDecode [Lossy, Greedy]相同。
图5:T5模型在贪婪解码(T=0)且块大小γ=5时的质量与延迟关系图。每种方法都交错使用T5-small和T5-large。x轴追踪相对于在所有输入上调用大模型的延迟。水平虚线表示大模型的质量。SpecDecode [Lossy, Greedy]是Leviathan等人(2023)提出的有损推测解码的贪婪版本。SpecCascade [Token]使用(15)中的TokenV3延迟规则。如§C.4所述,当T→0时,SpecDecode [TokenV3]与SpecDecode [Lossy, Greedy]相同。

Top-P采样下的比较。图7和表5显示了在top-P采样下的结果。当P值较小时,SpecCascade [Token]相比SpecDecode [Lossy]提供了更大的加速。这是因为随着P变小,SpecDecode [Lossy]中的权衡参数 $\alpha$ 效果减弱,而SpecCascade [Token]不受此影响。

图7:T5模型在不同P值的Top-P采样下的质量与延迟关系图。每种方法都交错使用T5-small和T5-large。x轴追踪相对于在所有输入上调用大模型的延迟。水平虚线表示大模型的质量。随着P变小,SpecDecode [Lossy]能够产生更少的独特权衡点,因为其接受标准中通过权衡参数α进行的缩放效果较差(§C.3)。使用(15)中TokenV3延迟规则的SpecCascade [Token]则没有同样的问题;这里的权衡参数不作为接受标准中的缩放项出现。
图7:T5模型在不同P值的Top-P采样下的质量与延迟关系图。每种方法都交错使用T5-small和T5-large。x轴追踪相对于在所有输入上调用大模型的延迟。水平虚线表示大模型的质量。随着P变小,SpecDecode [Lossy]能够产生更少的独特权衡点,因为其接受标准中通过权衡参数α进行的缩放效果较差(§C.3)。使用(15)中TokenV3延迟规则的SpecCascade [Token]则没有同样的问题;这里的权衡参数不作为接受标准中的缩放项出现。

不同块大小 $\gamma$ 下的比较。图8显示,在不同块大小(3, 5, 7)下,推测级联都优于有损推测解码。更大的 $\gamma$ 值会产生更宽的权衡范围,但当需要高质量时,可能会因频繁回滚而损害加速效果。

图8:T5模型在不同块大小γ下的质量与延迟关系图。每种方法都交错使用T5-small和T5-large。x轴追踪相对于在所有输入上调用大模型的延迟。水平虚线表示大模型的质量。
图8:T5模型在不同块大小γ下的质量与延迟关系图。每种方法都交错使用T5-small和T5-large。x轴追踪相对于在所有输入上调用大模型的延迟。水平虚线表示大模型的质量。

BiLD变体。图9(顶部)比较了本文中使用的BiLD(随机拒绝采样版本)与原始BiLD算法。结果表明,BiLD的性能与原始BiLD的最佳参数选择相当。

有损推测解码变体。图9(底部)比较了 $\beta=1$ 和调优 $\beta$ 的有损推测解码。结果显示,设置 $\beta=1$ 的性能与调优 $\beta$ 相当或更好,可能是因为调优目标(KL散度)与评估指标(BLEU/ROUGE-2)不一致。

图9:顶部:比较BiLD*与Kim等人(2023)的原始BiLD算法在不同最大草稿窗口大小γ和回退置信度阈值αf下的质量与延迟关系图。底部:比较β=1 [Lossy]和根据(Tran-Thien, 2023)中的程序调优的β [Lossy?]的有损推测解码。
图9:顶部:比较BiLD*与Kim等人(2023)的原始BiLD算法在不同最大草稿窗口大小γ和回退置信度阈值αf下的质量与延迟关系图。底部:比较β=1 [Lossy]和根据(Tran-Thien, 2023)中的程序调优的β [Lossy?]的有损推测解码。

特定于Token的延迟规则变体。图10-12比较了三种特定于token的延迟规则(V1, V2, V3)。在T5模型上结果混合,但在Gemma模型上,V3变体通常表现最好,这可能是因为它使用大模型的分布来衡量置信度,这在模型大小差距较大时特别有帮助。

图10:T5模型使用(13)-(15)中所有三种特定于token的推测级联延迟规则的质量与延迟关系图。每种方法都交错使用一个T5 small和一个T5 large模型。x轴追踪相对于在所有输入上调用大模型的延迟。水平虚线表示大模型的质量。
图10:T5模型使用(13)-(15)中所有三种特定于token的推测级联延迟规则的质量与延迟关系图。每种方法都交错使用一个T5 small和一个T5 large模型。x轴追踪相对于在所有输入上调用大模型的延迟。水平虚线表示大模型的质量。

G 局限性

插件估计器的依赖性。本文提议的一个局限性是使用插件估计器来近似最优规则,这依赖于单个模型的校准程度。一个替代方案是使用一个明确训练来模仿最优规则的路由器模型。

局部优化目标。另一个局限性是,我们寻求最小化的优化目标是局部的,即在当前位置 $t$ 做出最佳延迟决策。它们忽略了在当前步骤中选择特定模型对下游的后续影响。设计一个考虑下游错误的全局延迟目标将是一个有趣的未来工作方向。

公平性。更广泛地说,本文旨在改善LM推理中的成本-质量权衡。重要的是,这种改进不应以牺牲其他数据切片或部分人群为代价,不公平地偏向某一部分。确保我们的方法提供的权衡增益在不同数据切片上是公平的,是未来的另一个重要方向。