GRIFFIN: Effective Token Alignment for Faster Speculative Decoding

作者/机构: Shijing Hu (复旦大学), Jingyang Li (新加坡国立大学), Xingyu Xie (新加坡国立大学), Zhihui Lu (复旦大学), Kim-Chuan Toh (新加坡国立大学), Pan Zhou (新加坡管理大学)

A1 主要贡献

大型语言模型(LLMs)的标准自回归解码逐个生成 token,计算成本高且耗时,对延迟敏感的应用构成挑战。为了加速生成,投机解码(speculative decoding)通过使用一个轻量级的草稿模型提出多个 tokens,然后由目标 LLM 并行验证,从而在目标 LLM 的单次前向传播中生成多个 token,显著降低了延迟。然而,投机解码的效率严重依赖于草稿 tokens 的高接受率以及生成它们的低计算成本。

核心问题: 现有方法(如 EAGLE 和 Medusa)在训练和解码过程之间存在错位问题。训练时,草稿模型使用来自目标模型的特征和训练数据中的真实 tokens;而在解码时,它依赖于自身生成的特征和先前生成的草稿 tokens。这种差异导致了两个关键问题:(1) 特征错位(feature misalignment),即解码时草稿模型生成的特征与训练时使用的特征不一致;(2) token 错位(token misalignment),即解码时使用的草稿 tokens 替代了训练时使用的真实 tokens,导致错误在多步中累积。这些错位问题(类似于暴露偏差【索引11, Samy Bengio et al., Scheduled sampling for sequence prediction with recurrent neural networks, 2015, NeurIPS】, 【索引12, Florian Schmidt, Generalization in generation: A closer look at exposure bias, 2019, arXiv】)显著降低了草稿 tokens 的接受率,从而影响了整体的加速性能。

现有工作的局限性: 尽管像 HASS【索引13, Lefan Zhang et al., Learning harmonized representations for speculative sampling, 2024, arXiv】这样的方法尝试通过在训练中使用草稿模型的特征来解决特征错位问题,但它们忽略了 token 错位问题。Token 错位问题在解码过程中尤为严重,因为早期解码步骤的错误会传播并放大。如图 1(c) 所示,EAGLE2 在训练中遭受 48% 的 token 错位率,导致接受长度不理想。HASS 虽然缓解了特征错位,但其 token 错位率在训练后期仍升至 37%,使其在更深层次的多前向传播协调训练中(例如,训练步数 n ≥ 3)效率低下,如图 1(b) 所示。


图 1: GRIFFIN、EAGLE2 和 HASS 的比较。(a) 加速比比较。(b) 不同训练步骤下的接受长度,其中 "Step n" 表示在训练中对齐草稿模型 n 次。(c) 在每个草稿-验证周期中不同前向传播次数下的错位 token 率,其中 "Forward n" 表示进行 n 次前向传播以生成 n 个草稿 token。

本文的主要贡献: 本文提出了 GRIFFIN,一个新颖的投机解码框架,该框架首次明确识别并解决了先前未被观察到的 token 错位问题,同时也处理了特征错位问题。GRIFFIN 引入了两项核心创新:一个可进行 token 对齐的训练策略和一个可进行 token 对齐的草稿模型,两者协同工作以显著提高解码效率。

  1. 可进行 token 对齐的训练策略: 为了缓解训练期间的 token 错位,GRIFFIN 采用了一种动态损失掩码机制。该机制只选择性地对齐的 tokens 进行反向传播——这些对齐的 tokens 被定义为其真实 token 出现在草稿模型 top-k 预测中的 tokens。这不仅最小化了高度错位 tokens 的破坏性影响,而且协调了训练和解码过程,因为解码中的草稿树是建立在 top-k 预测之上,而不是精确匹配最高概率的 token。
  2. 可进行 token 对齐的草稿模型: 为了进一步减少 token 错位,GRIFFIN 设计了一个可进行 token 对齐的草稿模型,通过将一种名为Token 引导融合(Token-Guided Fusion, TGF)的架构创新融入到 EAGLE【索引8, Yuhui Li et al., EAGLE: Speculative sampling requires rethinking feature uncertainty, 2024, arXiv】的草稿模型中。TGF 执行一个两步融合过程来优化特征表示,并缓解草稿模型和目标模型之间的不一致性。通过两次引入输入 tokens——初次与特征融合,随后用于优化特征——TGF 模块确保草稿模型产生与目标模型更紧密对齐的特征,从而减少特征和 token 的错位。

这两个组件是相辅相成的。可进行 token 对齐的草稿模型减少了错位,增加了可用于有效训练的对齐 tokens 的数量。反过来,训练策略确保这些对齐的 tokens 对模型优化做出有意义的贡献。如图 1(c) 所示,GRIFFIN 在多个前向传播步骤中始终保持比 EAGLE2 和 HASS 低得多的 token 错位率,从而产生更长的接受 token 序列和更大的加速比。实验结果表明,在 LLaMA2-7/13B、LLaMA3-8B 和 Vicuna-7B 等模型上,GRIFFIN 的平均接受长度比 EAGLE2 提高了 20%,比 HASS 提高了 8%;加速比则分别提高了 18% 和 7%。

A3 动机:Token 错位

投机解码的工作流程。投机解码【索引6, Yaniv Leviathan et al., Fast inference from transformers via speculative decoding, 2023, ICML】, 【索引7, Charlie Chen et al., Accelerating large language model decoding with speculative sampling, 2023, arXiv】通过采用“草稿-验证”策略来加速文本生成。在每个周期中,一个轻量级的草稿模型 M 首先通过多次前向传播生成多个 tokens,然后一个更强的目标模型 T 在单次前向传播中验证并接受其中的一个子集。


图 2: EAGLE 中的 Token 和特征错位。

EAGLE范式及其固有的错位问题。EAGLE【索引8, Yuhui Li et al., EAGLE: Speculative sampling requires rethinking feature uncertainty, 2024, arXiv】通过将自回归从 token 级别转移到特征级别来扩展这一范式。如图 2 所示,草稿模型不是直接预测 tokens,而是生成中间隐藏状态特征,这些特征近似于目标模型 T 最后一层(语言建模头 H 之前)的特征。在时间步 t,令 $x_t$ 和 $\bar{x}_t$ 分别表示第 t 个真实 token 和草稿 token,而 $F_t$ 和 $\bar{F}_t$ 分别是它们来自 T 和 M 的隐藏特征。在训练期间,草稿模型使用 $x_t$ 和 $F_t$ 来预测 $\bar{x}_{t+1}$ 和 $\bar{F}_{t+1}$。然而,在解码期间,草稿模型必须完全依赖于先前生成的 tokens $\bar{x}_t$ 和特征 $\bar{F}_t$——无法访问 $x_t$ 或 $F_t$——因为目标模型在所有草稿 tokens 生成后每个周期只调用一次。这种差异引入了两个基本问题:1) 特征错位,即在解码时,草稿模型使用 $\bar{F}_t$ 而不是训练时的 $F_t$ 进行预测;2) token 错位,即解码中使用的 tokens $\bar{x}_t$ 与训练期间看到的真实 tokens $x_t$ 不同。

Token错位的严重性被低估。在这两个问题中,token 错位问题尤为严重但未得到充分研究。图 1(c) 显示,对于 EAGLE2 和 HASS,$\bar{x}_t \ne x_t$ 的比率——即 token 错位率——随着前向传播次数的增加而急剧上升。例如,当每个周期生成五个草稿 tokens 时,EAGLE2 的错位率达到 48%。即使是部分缓解了特征错位的 HASS,其 token 错位率仍高达 37%。这种性能下降源于错误在多次传播中的累积,其中 token 生成的早期错误在后续步骤中传播并复合。

Token错位对训练效率的破坏。严重的是,高的 token 错位率会破坏训练的有效性。如图 1(b) 所示,当 HASS 中的前向传播次数超过三次时,即使继续训练,接受长度也趋于平稳。这表明草稿模型只能生成较少的可接受 tokens,直接限制了解码效率。因此,解决 token 错位问题不仅重要,而且是解锁更深层次的多步投机解码和实现更大加速的必要条件。

简单的修复方案无效。一个看似简单的修复方法——在训练期间用草稿模型生成的 $\bar{x}_t$ 替换 $x_t$(真实训练 tokens)——在实践中是失败的。这是因为:1) 像 EAGLE 和 HASS 这样的框架在训练前会为所有 $x_t$ 预计算并存储 $F_t$,以避免重新生成训练数据的计算负担;2) 换入 $\bar{x}_t$ 会导致不一致的输入-特征对,这破坏了损失计算所需的对齐,并降低了性能,正如 HASS 的附录 B 所证实的。事实上,简单的替换会显著降低接受长度。鉴于这些挑战,我们提出了一个解决 token 错位问题的有效方案,该方案与现有训练工作流兼容,并能在训练和解码之间实现更好的对齐。

A2 方法细节

为了应对第 3 节中描述的 token 错位挑战,我们提出了 GRIFFIN,这是一个新颖的框架,通过两个关键组件来缓解 token 错位:1) 第 4.1 节介绍的可进行 token 对齐的训练策略,以及 2) 第 4.2 节阐述的可进行 token 对齐的草稿模型。

4.1 可进行 Token 对齐的训练

渐进式对齐训练策略。GRIFFIN 的核心是一种渐进式训练策略,它模仿了草稿模型在解码期间的操作方式。我们不是在每一步都依赖于真实 tokens 和特征——这个假设在推理时会失效——而是在训练期间逐渐将模型转向使用其自身的输出。这种对齐对于缓解 token 错位至关重要。

多步前向传播训练。具体来说,GRIFFIN 将训练组织成多个步骤,其中每个训练步骤 n 涉及草稿模型执行 n 次前向传播来预测 n 个未来的 tokens 及其对应的特征。随着每次额外的传播,模型越来越依赖于其自身在前几步生成的 tokens 和特征,而不是目标模型的真实 tokens 和特征。这有效地将训练过程与解码阶段对齐,因为在训练阶段,草稿模型模拟了在解码期间遇到的类似输入条件。接下来我们详细介绍第一个训练步骤及其后续步骤。

首次前向传播 (n = 1)。与标准的自回归生成类似,草稿模型 M 预测草稿 tokens,然后将其送入目标模型 T 进行验证和接受。具体来说,在时间步 t,草稿模型 M 和目标模型中的 LM 头 H 预测第 t 个特征嵌入 $\bar{F}_t$ 和草稿 token $\bar{x}_t$:

$$\bar{\mathbf{F}}_t = \mathcal{M}(\mathbf{x}_{1:t-1}, \mathbf{F}_{1:t-1}), \quad \bar{\mathbf{x}}_t = \mathcal{H}(\bar{\mathbf{F}}_t),$$

其中,$x_{1:t-1}$ 表示来自训练数据集的 token 序列 {$x_i$}$_{i=1}^{t-1}$,而 $F_{1:t-1}$ 是由目标模型 T 生成的特征嵌入序列 {$F_i$}$_{i=1}^{t-1}$。Token 错位仅在第二次前向传播后才会出现。在第一次传播中,训练和解码共享相同的前缀 $x_{1:t-1}$,因此所有预测的 tokens $\bar{x}_t$ 都是完美对齐的。因此,在第一次传播中不需要掩码。在时间步 t 的第一次传播损失为:

$$\mathcal{L}_{\mathcal{M}}^{(1)}=\sum_{t=1}^{l} \ell\left(\bar{\mathbf{x}}_{t}, \mathbf{x}_{t}, \bar{\mathbf{F}}_{t}, \mathbf{F}_{t}\right),$$

其中,$l$ 结合了特征级损失(即 $\bar{F}_t$ 和 $F_t$ 之间的 $l_1$ 距离)和 token 级损失(即 $\bar{x}_t$ 和 $x_t$ 之间的交叉熵)。$l$ 的详细实现在附录 C 中总结。

第n次前向传播 (n ≥ 2) 及 token 对齐。草稿模型 M 将在第 n 次前向传播时预测 n 个草稿 tokens。从第二次前向传播开始,在解码期间,投机解码可能会拒绝一个草稿 token $\bar{x}_t$,在这种情况下,该草稿序列中所有后续的 tokens $\bar{x}_{t+1}, \bar{x}_{t+2}, \dots$ 也会被丢弃。因此,在训练期间,如果 $\bar{x}_t$ 是不可预测的(被拒绝),那么该草稿序列中的后续草稿 tokens $\bar{x}_{t+1}, \bar{x}_{t+2}, \dots$ 就是错位的 tokens。对这些错位的 tokens 进行训练无法提供有用的信号。

基于Top-k预测的可预测性掩码。为了应对 token 错位挑战,我们引入了一种新颖的可进行 token 对齐的训练策略,该策略将草稿模型的训练与其在解码期间的多步行为对齐。与 EAGLE 等仅在训练中使用 top-1 预测的先前方法不同,我们的方法通过监督 top-k 预测将树状解码过程直接融入学习中。在 EAGLE 的解码中,草稿模型的每次前向传播在每个时间步生成一个候选 tokens 的 top-k 列表,形成一棵树,如果 top-1 token 被拒绝,可以探索替代分支。为了与此匹配,GRIFFIN 认为一个草稿 token $\bar{x}_t$ 是可预测的,如果真实 token $x_t$ 出现在其 top-k 预测 Top-k($\bar{x}_t$) 中。这确保了训练能够反映解码阶段,其中任何 top-k token 都可能是有效的。因此,我们引入一个二元可预测掩码 $\bar{m}_t \in \{0, 1\}$,其中如果 $x_t \in \text{Top-k}(\bar{x}_t)$,则 $\bar{m}_t = 1$,否则 $\bar{m}_t = 0$。由于当前的草稿 token $\bar{x}_t$ 是由之前 (n-1) 次前向传播中预测的草稿 tokens $\bar{x}_{t-n+1:t-1}$ 决定的,那么如果 $\bar{x}_{t-n+1:t-1}$ 中的任何一个草稿 token 是不可预测的,那么草稿 token $\bar{x}_t$ 很可能是错位的。为了防止模型因这些不可避免的被拒绝位置而受到惩罚,我们引入一个累积的二元对齐掩码 $m_t$,该掩码由草稿 tokens $x_{t-n+1:t-1}$ 的可预测掩码 $\bar{m}_{t-n+1:t-1}$ 调整:

$$\mathbf{m}_t = \prod\nolimits_{i=t-n+1}^{t-1} \bar{\mathbf{m}}_i.$$

自洽的前向传播和损失计算。这些掩码指示一个 token 是否应该对训练损失做出贡献,确保了训练和推理之间的一致性。接下来,为了进一步确保训练和解码之间的对齐,我们用先前传播中草稿模型生成的特征 $\bar{F}_{t-n+1:t-1}$ 替换目标模型的特征 $F_{t-n+1:t-1}$。然后,草稿模型 M 和 LM 头 H 用于生成特征 $\bar{F}_t$ 和草稿 token $\bar{x}_t$:

$$\bar{\mathbf{F}}_t = \mathcal{M}(\mathbf{x}_{1:t-1}, \mathbf{F}_{1:t-n}, \bar{\mathbf{F}}_{t-n+1:t-1}), \quad \bar{\mathbf{x}}_t = \mathcal{H}(\bar{\mathbf{F}}_t).$$

然后,我们定义以下训练损失来训练草稿模型 M:

$$\mathcal{L}_{\mathcal{M}}^{(n)} = \frac{1}{\sum_{t=1}^{l} \mathbf{m}_{t}} \sum_{t=1}^{l} \mathbf{m}_{t} \ell(\bar{\mathbf{x}}_{t}, \mathbf{x}_{t}, \bar{\mathbf{F}}_{t}, \mathbf{F}_{t}),$$

与先前工作的对比。GRIFFIN 的训练策略与 EAGLE 和 HASS 等依赖于真实 tokens 的先前方法不同。通过逐步调整草稿模型以在其自身的预测下运行,并将其训练与解码对齐,GRIFFIN 通过引入 top-k 对齐掩码、通过生成的 tokens 进行自条件化以及掩码传播来解决 token 错位问题。

4.2 可进行 Token 对齐的草稿模型

模型架构概述。为了提高草稿 token 的准确性并有效解决 token 错位问题,我们提出了一个可进行 token 对齐的草稿模型,该模型系统地解决了先前草稿模型忽略的特征不一致问题。虽然我们的架构建立在 EAGLE 的草稿模型之上,但它引入了两个关键的额外模块:Token 引导融合(Token-Guided Fusion, TGF)Token 增强头(Token-Enhanced Head, TEH)。如图 3(a) 所示,我们在自回归层之前插入 TGF 模块,以融合输入特征 $F_t$ 和 tokens $x_t$。在自回归之后,我们使用 TEH 模块,这是一个受先前工作【索引31, Lujun Gui et al., Boosting lossless speculative decoding via feature sampling and partial alignment distillation, 2024, arXiv】启发的双头设计,用于输出 1) 一个用于 token 预测的预测特征 $\bar{F}^P_{t+1}$ 和 2) 一个用于后续前向传播的回归特征 $\bar{F}^R_{t+1}$。相应地,TEH 可以分离和解耦草稿模型内 token 预测和特征生成的冲突目标,从而提高草稿 token 的准确性。我们在附录 B 中的消融实验证实了其有效性。

TGF模块的设计动机。TGF 模块旨在解决一个核心挑战:即使经过大量训练,草稿模型中的特征表示也常常无法与目标模型的特征表示相匹配。由于在实践中特征级损失无法最小化到零,这种差距会导致持续的错位,从而导致草稿模型和目标模型之间的特征不一致,损害草稿 token 的准确性。TGF 通过在融合过程中优先考虑 token 嵌入来解决这个问题,引导特征生成更好地与目标模型保持一致。如图 3(b) 所示,TGF 分三步操作:


图 3: GRIFFIN 的草稿模型结构。(a) 可进行 Token 对齐的草稿模型。(b) TGF 模块。该图描绘了训练和解码阶段共享的架构——指示 token 流的箭头对应于两种模式下有效的数据依赖关系。

TGF步骤1:嵌入融合。(1) 在图 3 (b-i) 的嵌入融合中,给定输入特征 F 和 token 嵌入 x(两者维度均为 $R^d$),我们将它们连接起来,并使用一个轻量级的 MLP 将结果投影回 $R^d$:

$$\mathbf{h} = \mathcal{C}(\mathbf{F}, \mathbf{x})\mathbf{W}_{\mathrm{m}} + \mathbf{b}_{\mathrm{m}}$$

这里,$W_m \in R^{2d \times d}$ 和 $b_m \in R^d$ 是 MLP 的权重和偏置,$C(\cdot, \cdot)$ 是连接操作符。这产生了一个融合了 token 和特征信息的统一特征。

TGF步骤2:特征归一化与扩展。(2) 在图 3 (b-ii) 的特征归一化与扩展中,我们对 h 和 x 应用层归一化,然后连接它们,并使用一个上投影器(一个线性层)将维度扩展到 4d:

$$\mathbf{z} = \mathcal{C}(\mathcal{N}(\mathbf{h}), \mathcal{N}(\mathbf{x})) \mathbf{W}_{\mathrm{u}} + \mathbf{b}_{\mathrm{u}},$$

其中 $W_u \in R^{2d \times 4d}$ 和 $b_u \in R^{4d}$。这里,$N(\cdot)$ 表示层归一化。在更高维度的空间中操作使模型能够解开和对齐更复杂的 token-特征关系。这种 4d 扩展与许多 Transformer FFN 中使用的中间大小一致,我们在第 5.2 节、附录 A.2 和 A.3 中的消融实验证实了其有效性。

TGF步骤3:优化与稳定。(3) 在图 3 (b-iii) 的优化与稳定中,我们对 z 应用 SiLU 非线性函数 $\sigma$,并使用一个下投影器(一个线性层)将其投影回 $R^d$。与 h 的残差连接稳定了训练过程:

$$\mathbf{o}=\sigma(\mathbf{z}) \mathbf{W}_{\mathrm{d}}+\mathbf{b}_{\mathrm{d}}+\mathbf{h},$$

其中 $W_d \in R^{4d \times d}$ 和 $b_d \in R^d$。非线性函数丰富了表达能力,而残差加法保留了必要的融合信息。

TGF模块的整体作用。通过将 token 嵌入明确地整合到特征融合中,TGF 确保生成的特征能更好地反映目标模型的 token 分布。结合 TEH,它使草稿模型能够生成更准确的草稿 tokens 和特征,这对于缓解多步解码中的错位至关重要。这种 token 感知和特征感知的设计是 GRIFFIN 与先前工作的关键创新之处。

A4 实验

实验环境

  • 模型: 测试了多种代表性 LLMs,包括 LLaMA2-Chat 7B/13B、LLaMA3-Instruct 8B/70B【索引3, Hugo Touvron et al., Llama 2: Open foundation and fine-tuned chat models, 2023, arXiv】、Vicuna-1.5 7B【索引32, Zhenyi Lu et al., Chinese-vicuna: A chinese instruction-following llama-based model, 2023】、Qwen2-Instruct 7B【索引33, Jinze Bai et al., Qwen technical report, 2023, arXiv】 和 Mixtral-8x7B-Instruct-v0.1【索引34, Albert Q Jiang et al., Mixtral of experts, 2024, arXiv】。
  • 硬件配置: 所有推理实验均在单张 NVIDIA A100 80G GPU 上运行,除了 LLaMA3-70B 和 Mixtral-8x7B 需要两张 GPU。
  • 软件配置与基线: 使用标准的自回归解码作为基线(加速比为 1.00x)。GRIFFIN 与最新的 SOTA 投机解码方法进行了比较,包括 SPS【索引6, Yaniv Leviathan et al., Fast inference from transformers via speculative decoding, 2023, ICML】、PLD【索引20, Apoorv Saxena, Prompt lookup decoding, 2023】、Lookahead【索引21, Yichao Fu et al., Break the sequential dependency of llm inference using lookahead decoding, 2024, arXiv】、Medusa【索引10, Tianle Cai et al., Medusa: Simple llm inference acceleration framework with multiple decoding heads, 2024, arXiv】、EAGLE【索引8, Yuhui Li et al., EAGLE: Speculative sampling requires rethinking feature uncertainty, 2024, arXiv】、EAGLE-2【索引9, Yuhui Li et al., EAGLE-2: Faster inference of language models with dynamic draft trees, 2024, arXiv】、FSPAD【索引31, Lujun Gui et al., Boosting lossless speculative decoding via feature sampling and partial alignment distillation, 2024, arXiv】 和 HASS【索引13, Lefan Zhang et al., Learning harmonized representations for speculative sampling, 2024, arXiv】。
  • 数据集与训练参数: 草稿模型在 ShareGPT 数据集上进行训练。Token 对齐设置为 top-k (k=3),训练步数 N=3。其他超参数(如优化器)与 EAGLE-2 保持一致以进行公平比较。
  • 评估任务: 在三个关键任务上评估性能:多轮对话 (MT-Bench【索引4, Lianmin Zheng et al., Judging llm-as-a-judge with mt-bench and chatbot arena, 2023, NeurIPS】)、代码生成 (HumanEval【索引5, Mark Chen et al., Evaluating large language models trained on code, 2021, arXiv】) 和数学推理 (GSM8K【索引14, Karl Cobbe et al., Training verifiers to solve math word problems, 2021, arXiv】)。
  • 评估指标: 遵循先前工作,主要关注加速指标:1) 加速比 (Speedup Ratio, SR),衡量相对于标准自回归解码的实际测试加速比;2) 接受长度 (Acceptance Length, τ),每个草稿-验证周期接受的平均 token 数。由于 GRIFFIN 是无损的,因此无需进行额外的生成质量评估。

实验结果

5.1 与SOTA方法的比较

表 1 展示了不同方法在三个数据集上的接受长度和加速比。GRIFFIN 在所有测试的数据集和 LLM 上始终取得了最高的接受长度和加速比。每个 GRIFFIN 的草稿-验证周期大约能生成 5-6 个 tokens,显著超过了其他方法。

  • 性能优势:

    • 在多轮对话任务 (MT-Bench) 中,使用 LLaMA3 8B (温度 T=0) 时,GRIFFIN 的加速比比 HASS 高 8.7%。
    • 在代码生成任务 (HumanEval) 中,使用 Vicuna 7B (温度 T=0) 时,GRIFFIN 的加速比比 HASS 提高了 7.3%。
    • 在数学推理任务 (GSM8K) 中,使用 LLaMA2 13B (温度 T=0) 时,GRIFFIN 的加速比比 HASS 提高了 6.8%。
    • 与 EAGLE-3 相比,在 LLaMA3 8B 上,GRIFFIN 在所有基准和温度设置下均表现更优。
  • 模型架构的泛化能力:

    • 在 Qwen2 7B 模型上,GRIFFIN 的加速比比 HASS 提高了 5.8%。其加速比略低于 LLaMA2 7B,这归因于 Qwen2 7B 更大的词汇表,导致 LM Head 更大,从而减慢了草稿模型的解码速度。
    • 对于 Mixtral-8x7B (MoE 架构),所有投机解码方法的加速效果都不如其他 LLM 明显,因为同时验证多个 tokens 会带来额外的计算开销。尽管如此,GRIFFIN 在 Mixtral-8x7B 上的加速比仍比 HASS 高出 6.6% 以上,展示了其强大的泛化能力。
  • 鲁棒性: GRIFFIN 在不同温度(T=0 和 T=1)下均能保持稳定的性能提升,表明其在处理不同程度的 token 预测不确定性方面的鲁棒性。

这些结果表明 GRIFFIN 是一个在真实世界场景中加速 LLM 推理的有力候选者。


表 1: 不同投机解码方法的比较。该表展示了在标准 LLM 基准上,温度 T ∈ {0, 1} 时的评估结果,包括加速比 SR 和接受长度 τ。数值越高表示性能越好。

5.2 消融研究

  • GRIFFIN 组件的有效性 (表 2):
    • 我们评估了 GRIFFIN 的两个关键组件——可进行 Token 对齐的训练 (TAT)可进行 Token 对齐的草稿模型 (TAD)——的影响。
    • 移除 TAT 或 TAD 都会导致接受长度和加速比显著下降。移除 TAT 导致平均接受长度在 T=0 时减少 0.26,加速比减少 0.18。移除 TAD 导致平均接受长度在 T=0 时减少 0.19,加速比减少 0.12。
    • 同时移除两个组件导致性能下降最严重,这突显了它们的互补效应。TAT 和 TAD 共同确保了草稿 tokens 在训练和解码期间都能对齐。


表 2: 关于可进行 Token 对齐的训练 (TAT) 和可进行 Token 对齐的草稿模型 (TAD) 的消融研究。该表展示了在 LLM 基准上,温度 T ∈ {0, 1} 时的加速比 SR 和接受长度 τ 的评估结果。数值越高表示性能越好。

  • 可进行 Token 对齐训练中的超参数 (表 3 和表 4):
    • Top-k 参数: 对齐 top-k tokens(从 1到 10)始终比不对齐(NA)效果好。当 k=3 时,接受长度和加速比达到峰值,表明对齐少量但足够的 tokens 可以在对齐和泛化之间取得最佳平衡。
    • 训练步数 N: 增加训练步数 N 能稳定提升 GRIFFIN 的接受长度,直到第 5 步。这与 HASS 在第 3 步后性能停滞(图 1 b))形成对比,证明了 GRIFFIN 的 token 对齐机制的有效性。为了与 HASS 公平比较,实验中选择 N=3。


表 3: GRIFFIN 不同 top-k 参数的比较。该表展示了在标准 LLM 基准上,温度 T ∈ {0, 1} 时的加速比 SR 和接受长度 τ 的评估结果。数值越高表示性能越好。NA 表示不对齐 token。


表 4: GRIFFIN 不同训练步数的比较。该表展示了在标准 LLM 基准上,温度 T ∈ {0, 1} 时的加速比 SR 和接受长度 τ 的评估结果。数值越高表示性能越好。

  • Token 引导融合 (TGF) 的有效性 (表 5):
    • 为了验证 TGF 的改进是源于其 token 感知设计而非仅仅增加参数量,我们进行了消融实验,改变了公式 (7) 中的二次融合输入。
    • 将 token 嵌入 x 替换为原始特征 F,接受长度减少了 0.83,加速比减少了 0.48。将其替换为初始融合特征 h,性能虽有改善,但仍落后于原始设计。
    • 这些结果证实,TGF 的有效性并非来自参数扩展,而是来自对 token 嵌入的明确使用,这对于纠正不一致的特征和使草稿模型与目标分布对齐至关重要。


表 5: TGF 的消融研究。该表展示了在标准 LLM 基准上,温度 T ∈ {0, 1} 时的加速比 SR 和接受长度 τ 的评估结果。数值越高表示性能越好。"Feature" 和 "Fused" 分别表示在公式 (7) 中使用 F 和 h 替换 x。

A5 结论

本文提出了 GRIFFIN,一个可进行 token 对齐的投机解码框架。先前的方法很大程度上忽略了训练和解码之间的 token 错位问题。GRIFFIN 通过引入一种可进行 token 对齐的训练策略,将错位的 tokens 从损失计算中排除,从而解决了这一问题。它还进一步集成了一个可进行 token 对齐的草稿模型,该模型显著减少了错位。在各种 LLMs 和数据集上的广泛评估表明,GRIFFIN 始终优于 SOTA 方法,实现了最高的加速比和接受长度。

局限性: GRIFFIN 采用多步训练过程进行 token 对齐训练,与 EAGLE 相比,这会产生额外的训练开销。然而,由于草稿模型只需训练一次,实际应用中更看重解码效率而非训练开销,因为推理是主要瓶颈。GRIFFIN 的加速比相较于 EAGLE2 提升了超过 18%,这使得额外的训练成本对于其带来的显著推理加速是值得的。此外,GRIFFIN 的总体训练开销与 HASS 相当。在相同的训练成本下,GRIFFIN 的加速比比 HASS 提升了超过 7%,进一步凸显了其有效性。

更广泛的影响: GRIFFIN 在不牺牲输出质量的情况下,通过加速解码速度提升了 LLM 推理的效率。这种改进可以通过使实时应用变得更加可行,从而普及对强大 LLM 的访问。在下游应用中,GRIFFIN 可以为教育、医疗助手、无障碍工具和科学研究等领域提供更流畅、更快速的交互式 AI,从而拓宽有益的应用范围,并为全球用户减少延迟障碍。

A6 附录

A. Token引导融合(TGF)架构分析

A.1 TGF背后的动机

标准策略的局限性。Token引导融合(TGF)模块的动机源于标准“连接后MLP”策略(如EAGLE所采用)的局限性,该策略未能充分捕捉token嵌入和草稿模型特征之间的互补信息。实践中,草稿模型生成的特征常常与目标模型的表示保持错位,这种差异仅靠特征级损失最小化无法有效消除。

TGF的解决方案。TGF通过明确利用token嵌入来引导融合过程,从而解决了这一挑战,使特征分布更紧密地对齐到目标模型的分布。正如消融实验结果(表5)所证实的,这种有针对性的架构增强显著减少了特征不一致性,表明TGF引入的适度复杂性带来了强大的经验增益。

TGF中的关键架构增强

  • 特征归一化和维度扩展:对初始融合特征h和token嵌入x分别应用层归一化(图3 (b-ii)),允许独立的统计缩放,提高了训练稳定性。上投影器(Up Projector)扩展了特征维度,有助于解开相关信息,增加融合表示的表达能力。
  • 非线性优化和整合:SiLU激活函数(图3 (b-iii))引入了非线性,增强了模块捕捉线性操作之外复杂特征交互的能力。之后,下投影器(Down Projector)将表示压缩回目标维度,将丰富的信息聚合成一个稳定的融合特征,供下游处理。

总体效果。总体而言,TGF使草稿模型能够更准确地逼近目标模型的输出空间,这在接受长度和加速比的显著提升中得到了证实(如表1所示)。

A.2 TGF架构的消融研究

实验设计。为了系统地评估TGF模块中每个组件的贡献,我们进行了有针对性的消融实验,每个变体通过选择性地移除或修改特定子模块来构建:
* 消融1:同时移除上投影器(图3 (b-ii))和SiLU激活函数(图3 (b-iii))。
* 消融2:在公式(7)的二次融合步骤中排除token嵌入x。
* 消融3:在公式(7)的二次融合步骤中排除初始融合特征h。

结果分析 (表6)
* 消融1:省略上投影器和SiLU激活函数导致性能显著下降,接受长度减少0.23(T=0),加速比减少0.17(T=0)。这凸显了这些组件在实现富有表现力和稳定的特征融合中的关键作用。
* 消融2:从二次融合中移除token嵌入x,对模型注入特定token信息的能力产生了负面影响,导致接受长度和加速比降低。
* 消融3:排除初始融合特征h导致最严重的性能下降,接受长度减少2.63(T=0),加速比下降1.50(T=0)。这强调了融合特征的循环整合对于捕捉高质量表示和实现有效对齐是不可或缺的。

结论。这些消融结果证实了TGF中每个架构组件对于最大化接受长度和加速比的必要性。


表 6: TGF 架构的消融结果。该表展示了在标准 LLM 基准上,温度 T ∈ {0, 1} 时的评估结果,包括加速比 SR 和接受长度 τ。数值越高表示性能越好。

A.3 TGF扩展维度的消融研究

背景。在TGF模块中,扩展维度指的是上投影器(图3 (b-ii))的输出维度。对于GRIFFIN,我们将其设置为11,008,与目标模型的前馈网络(FFN)的中间大小相匹配。

实验设计与结果 (表7)
* 较小的扩展(4,096):将扩展维度降低到4,096会削弱TGF提取和区分显著特征的能力,导致接受长度和加速比显著降低。
* 较大的扩展(22,016):将扩展维度增加到22,016,接受长度略有改善,表明表示能力略有增强。然而,这被加速比的下降所抵消,主要是由于计算开销增加和额外的0.14B参数。

结论。这些结果验证了我们选择的配置:将TGF扩展维度设置为与目标模型的FFN中间大小相等,在融合性能和计算效率之间取得了有效的平衡。


表 7: GRIFFIN 不同扩展维度的比较。该表展示了在标准 LLM 基准上,温度 T ∈ {0, 1} 时的评估结果,包括加速比 SR 和接受长度 τ。数值越高表示性能越好。

B. 可进行Token对齐的草稿模型(TAD)组件的有效性

组件评估。我们使用LLaMA2-Chat 7B作为基础模型,评估了可进行Token对齐的草稿模型(TAD)的两个主要组件——Token引导融合(TGF)和Token增强头(TEH)的各自贡献。消融结果总结在表8中。

结果分析。移除任一组件都会导致接受长度和加速比的明显且持续的下降:
* Token增强头(TEH):排除TEH导致所有基准测试的性能持续下降,平均接受长度减少0.06,加速比减少0.04。这突显了TEH在提升草稿模型token预测准确性方面的关键作用。
* Token引导融合(TGF):排除TGF导致更严重的性能下降:接受长度下降0.15,加速比在T=0时下降0.09,在T=1时下降0.11。这些发现加强了TGF在缓解投机解码期间特征错位方面的功效。

结论。总的来说,这些结果强调TGF和TEH对于最大化TAD架构的有效性和效率都是不可或缺的。


表 8: TAD 的消融结果。该表展示了在标准 LLM 基准上,温度 T ∈ {0, 1} 时的评估结果,包括加速比 SR 和接受长度 τ。数值越高表示性能越好。

C. GRIFFIN的实现细节

C.1 损失函数

每步损失构成。公式(5)中的每步损失 $l(\bar{x}_t, x_t, \bar{F}_t, F_t)$ 结合了两个互补的组件,用于监督草稿模型 M 与目标模型 T 的 token 预测和潜在特征对齐:

$$\ell(\bar{\mathbf{x}}_t, \mathbf{x}_t, \bar{\mathbf{F}}_t, \mathbf{F}_t) = \lambda_{\text{tok}} \ell_{\text{tok}}(\bar{\mathbf{x}}_t, \mathbf{x}_t) + \lambda_{\text{feat}} \ell_{\text{feat}}(\bar{\mathbf{F}}_t, \mathbf{F}_t),$$

其中 $\lambda_{\text{tok}}$ 和 $\lambda_{\text{feat}}$ 是标量权重(默认:$\lambda_{\text{tok}} = 1, \lambda_{\text{feat}} = 0.1$,除非另有说明)。

Token级损失。Token级监督使草稿模型的预测 token 分布与真实值对齐:

$$\ell_{\text{tok}}(\bar{\mathbf{x}}_t, \mathbf{x}_t) = -\log P_{\mathcal{M}}(\mathbf{x}_t \mid \mathbf{x}_{1:t-1}),$$

这对应于预测 logits $\bar{x}_t$ 和 one-hot 目标 $x_t$ 之间的标准交叉熵。

特征级损失。为了鼓励草稿模型和目标模型的内部表示保持一致,我们最小化它们隐藏特征之间的 $l_1$ 距离:

$$\ell_{\text {feat }}\left(\overline{\mathbf{F}}_{t}, \mathbf{F}_{t}\right)=\left\|\overline{\mathbf{F}}_{t}-\mathbf{F}_{t}\right\|_{1} .$$

该项将草稿模型正则化到目标模型的潜在空间,促进在多步解码中实现稳定对齐。

C.2 草稿树结构

配置。对于所有实验,我们使用一个动态树结构,总共包含60个草稿 tokens,并将草稿树深度设置为6,这与EAGLE-2和HASS中确立的最优配置紧密一致。

C.3 训练配置

优化器和超参数。草稿模型使用AdamW优化器进行训练,关键设置如下:
* 学习率:3e−5
* 批大小:4(每GPU)
* 周期数:20
* 总训练步数:800,000
* 预热:2,000步线性预热;启用学习率调度器
* 优化器:AdamW,betas为(0.9, 0.95)
* 梯度裁剪:0.5(按值)
* 最大序列长度:2,048 tokens

一致性。除非另有说明,所有报告的实验中所有超参数均保持固定。

D. 基线方法的说明

模型和代码来源。对于EAGLE、EAGLE-2和Medusa,我们直接使用了相应作者公开发布的草稿模型参数。对于不需要训练草稿模型的方法,如PLD、Lookahead和SPS,我们使用其GitHub仓库中的官方代码进行性能评估。

HASS和FSPAD的复现。关于HASS,我们使用了LLaMA2-7B、LLaMA3-8B、LLaMA2-13B和LLaMA3-70B的公开发布草稿模型参数。然而,在提交时,Vicuna-7B、Qwen2-7B和Mixtral-8x7B的官方草稿模型参数不可用。为实现公平比较,我们使用其官方GitHub仓库自己训练了HASS草稿模型,并严格遵循HASS论文中描述的配置。我们获得的实验结果与HASS论文报告的结果非常吻合。同样,FSPAD的草稿模型参数在提交时也未公开发布。因此,我们使用其官方代码并按照FSPAD论文中指定的设置训练了FSPAD的草稿模型。我们的实验结果与原始FSPAD论文发布的结果高度一致。

EAGLE-3的复现。关于EAGLE-3,由于EAGLE-3没有为LLaMA3-8B提供预训练的草稿模型,我们使用其官方代码在ShareGPT数据集上进行训练,所有其他超参数与他们的论文保持一致。仅在ShareGPT上训练EAGLE-3就需要超过300个A100-80G GPU小时。遵循其论文的完整协议(UltraChat-200K + ShareGPT)将需要大约2400个GPU小时,这超出了我们的计算资源。然而,EAGLE-3和GRIFFIN都使用了相同的ShareGPT训练数据,确保了公平比较。

训练设置对齐。为确保我们比较的有效性,我们将所有关键的训练设置,包括数据集、优化器和超参数,与EAGLE-2和HASS使用的设置对齐。例如,我们将训练过程与HASS的三步计划相匹配,确保所有实验的一致性和可靠性。

E. GRIFFIN草稿模型的参数大小

不同模型的参数量。对于7B、8B、13B和70B规模的目标模型,相应的GRIFFIN草稿模型大小分别为0.41B、0.42B、0.65B和2.07B参数。对于Mixtral-8x7B,草稿模型大小为0.45B参数。

与基线的比较。相比之下,EAGLE-2和HASS的草稿模型大小分别为0.24B、0.25B、0.37B和0.99B,而FSPAD的草稿模型大小分别为0.42B、0.43B、0.67B和2.09B。对于Mixtral-8x7B,EAGLE-2和HASS的草稿模型大小为0.28B。因此,GRIFFIN的草稿模型比EAGLE-2和HASS多出0.17B到1.08B的参数,但与FSPAD的大小相似。

性能与开销的权衡。尽管参数略有增加,但如表1所示,GRIFFIN的平均加速比提升超过8%。额外的参数数量仅带来边际的计算开销,这被推理效率和整体性能的显著提升所充分证明是合理的。

F. GRIFFIN的训练开销

训练数据一致性。所有方法(GRIFFIN、EAGLE-2、FSPAD和HASS)都使用ShareGPT数据集进行草稿模型训练,确保了所有方法训练的token数量相等。

计算资源开销。在计算资源方面,GRIFFIN采用了与最先进的HASS方法相同的多阶段训练策略,两者都采用了三步训练方案。对于7B、13B和70B参数模型,HASS通常需要大约130、220和500个NVIDIA A100 80G GPU小时,而GRIFFIN的需求约为150、250和600个NVIDIA A100 80G GPU小时。

训练与推理的权衡。关键在于,草稿模型只需训练一次,但在推理过程中会被广泛利用。因此,在实际场景中,主要的计算成本在于解码阶段。GRIFFIN相比HASS提供了大约8%的投机解码速度提升,这意味着训练开销的轻微增加被推理效率的大幅提升所充分证明是合理的。

G. GRIFFIN的吞吐量

vLLM集成与评估。为了评估GRIFFIN在批处理大小大于1时的性能,我们将其集成到开源的vLLM框架中,遵循与EAGLE相同的投机解码接口。所有实验均在LLaMA3-8B-Instruct模型上使用MT-Bench数据集进行,解码温度为0。我们报告了相对于未使用任何投机方法的基线vLLM解码的吞吐量(每秒token数)。

结果(表9)。GRIFFIN在所有评估的批处理配置中始终实现了比EAGLE和HASS更高的吞吐量。


表 9: 不同批处理大小下的吞吐量比较。数字表示相对于标准 vLLM 解码的相对加速比(×) (1.00× 基线)。所有投机方法均使用最大链长为 2 的顺序投机进行评估。

实现约束。这些评估是在vLLM当前投机解码支持的某些限制下进行的。具体来说,该实现不支持树状草稿,这是我们完整解码算法的一个关键组成部分。因此,所有测量都使用了最大链长为2的顺序投机。因此,表9中的吞吐量值不能与正文结果直接比较,后者是使用我们为树状结构投机配置的本地解码后端获得的。

分析。观察到的相对加速比随批处理大小增加而下降的趋势是预期的,并且与理论预期一致。更大的批处理大小提高了目标模型的GPU利用率,减少了冗余计算,并缩小了投机解码与标准解码之间的效率差距。此外,随着批处理的增长,与额外草稿模型评估相关的内存占用和计算开销变得越来越显著,从而减少了净吞吐量增益。

讨论。尽管存在这些挑战,GRIFFIN仍保持了显著的优势——在测试的批处理大小范围内,其吞吐量比EAGLE高6-11%,比HASS高4-8%。这些改进表明,GRIFFIN的对齐机制即使在生产部署中常见的大批量、高吞吐量推理场景中,也能持续带来好处。

H. 解码延迟的分解

动机。虽然投机解码带来了显著的效率提升,但整体加速受限于草稿模型所需的额外计算。在正文的表1中,观察到的加速比(SR)明显低于相应的接受长度。这种差异主要是由草稿模型推理的不可忽略的延迟开销引起的。

延迟公式。设N为解码期间生成的总token数。对于标准自回归解码,总解码延迟为

$$\mathbf{T}_a = N \cdot \mathbf{t},$$

其中t是目标模型单次前向传播的平均每token延迟。

对于投机解码,在每个周期,目标模型验证由草稿模型生成的τ个候选token,其展开深度为d。相应的总延迟可以近似为

$$\mathbf{T}_{s}=\frac{N}{\tau} \cdot(\mathbf{t}+\mathbf{d} \cdot \bar{\mathbf{t}}),$$

其中$\bar{t}$是草稿模型的平均每步延迟。因此,理论上的加速比如下:

$$\text{Speedup Ratio (SR)} = \frac{\mathbf{T}_a}{\mathbf{T}_s} = \frac{\mathbf{t}}{\mathbf{t} + \mathbf{d} \cdot \bar{\mathbf{t}}} \cdot \tau.$$

经验估计。以LLaMA3-8B-Instruct模型在A100-80G GPU上作为代表性设置,我们测得目标模型的前向传播延迟约为t = 25 ms,草稿模型的延迟为$\bar{t}$ = 1.5 ms。如果我们设置接受长度τ = 5,草稿展开深度d = 6,公式(14)给出:

$$\text{SR} = \frac{25}{25 + 6 \times 1.5} \times 5 = 3.68 \times,$$

这与我们的经验结果非常吻合。这种定量上的一致性证实了草稿模型推理贡献的延迟是限制可实现加速的主要因素。

讨论。尽管草稿模型延迟在总预算中占比较小,但其在多个展开步骤中的累积会显著降低整体效率,特别是对于深度或大d的投机配置。未来的工作将探索进一步降低此成本的技术,例如:
* 草稿模型蒸馏以降低前向传播复杂性;
* 异步草稿以重叠草稿和目标评估;
* 核融合和缓存以最小化内存传输开销。

这些分析证实,接受长度和加速比之间的差距可以由草稿推理延迟来定量解释,并激励了进一步的系统级优化。

I. 与EAGLE-3的讨论

动机。EAGLE-3最近提出了一个简化的投机解码框架,从草稿模型目标中移除了特征预测损失。为了检验这一选择的实际效果及其与我们token对齐机制的相互作用,我们进行了两项互补的研究:(i) 对GRIFFIN进行消融,移除特征级损失项;(ii) 在匹配的训练条件下,对我们实现的EAGLE-3和完整的GRIFFIN进行直接比较。

实验设置。所有实验均在LLaMA3-8B-Instruct上进行,使用三个标准评估套件——MT-Bench、HumanEval和GSM8K——解码温度为t=0和t=1。对于EAGLE-3基线,我们遵循其官方开源仓库和超参数设置训练了一个草稿模型,包括相同的优化器、学习率调度和架构。由于计算限制,训练仅使用ShareGPT数据集(不包括额外的UltraChat-200K语料库),否则完全复现大约需要2400个GPU小时。

结果 (表10)
* 观察1:特征级损失的重要性。从GRIFFIN中移除特征预测项,导致在所有评估设置下,SR和τ都有一致的约10-15%的性能下降。这证实了特征级监督对于通过强制草稿模型和目标模型的隐藏表示之间的一致性来稳定token对齐仍然至关重要。
* 观察2:与EAGLE-3的比较。在相同的训练数据和推理设置下,完整的GRIFFIN在所有指标上都优于EAGLE-3。这表明EAGLE-3移除特征损失在此数据规模下并未带来优势,并且即使在控制模型和数据大小的情况下,GRIFFIN的对齐机制也能带来更高效的投机解码。
* 观察3:GRIFFIN技术的适用性。EAGLE-3没有明确解决训练或解码期间的token错位问题。相比之下,GRIFFIN引入了Token-Alignable Draft(TAD)架构和Token-Alignable Training(TAT)程序,两者都旨在缓解此问题。重要的是,这些技术是模块化的,原则上可以应用于EAGLE-3风格的草稿模型,而无需改变其外部解码接口,从而可能提高对齐和吞吐量性能。

总结。本研究表明,即使草稿模型使用大规模数据进行训练,特征级监督仍然是有益的,并且GRIFFIN的token对齐策略提供了超出EAGLE-3所实现的互补改进。我们希望这些发现能阐明特征预测损失的设计影响,并鼓励未来将token对齐原则整合到其他投机解码框架中。


表 10: GRIFFIN、其特征损失消融版本以及 EAGLE-3 在 LLaMA3-8B-Instruct 上的比较。该表报告了在标准 LLM 基准(MT-Bench、HumanEval、GSM8K)上,温度 T ∈ {0, 1} 时的结果,包括加速比 SR 和接受长度 τ。数值越高表示性能越好。