Mixture of Attentions for Speculative Decoding
Mixture of Attentions for Speculative Decoding
作者/机构: Matthieu Zimmer, Milan Gritta, Gerasimos Lampouras (Huawei Noah’s Ark Lab); Haitham Bou Ammar‡ (Huawei Noah’s Ark Lab, UCL Centre for Artificial Intelligence); Jun Wang‡ (UCL Centre for Artificial Intelligence)
A1 主要贡献
本文旨在解决大型语言模型(LLM)因参数量巨大而导致的计算需求激增和部署成本高昂的问题。研究聚焦于一种名为“推测解码”(Speculative Decoding, SD)的加速技术,该技术利用一个更小、更高效的模型(草稿模型)来起草未来词元(tokens),然后由大型语言模型(LLM)并行验证。
核心问题与现有方法局限性:
当前最先进的推测解码方法,如 EAGLE 和 MEDUSA,虽然利用了 LLM 的激活值来提升解码速度,但存在几个架构上的局 ઉ限性:
1. 部分可观测性(Partial Observability):草稿模型无法获取 LLM 的完整状态信息,导致其预测并非最优。
2. 缺乏在策略性(On-policyness):草稿模型通常在理想条件下(假设输入完美)进行训练,这与实际推理中部分输入由其自身生成的真实场景不符。随着草稿模型生成的词元增多,其与训练分布的偏差会越来越大。
这些局限性会降低推测解码的性能和可靠性。
研究目标与创新点:
为应对上述挑战,本文提出了一种新颖的推测解码架构,旨在增强草稿模型预测未来词元的准确性,并使其训练过程更贴近实际推理过程。该架构引入了多项关键改进:
1. 注意力混合架构(Mixture of Attentions):这是本文的核心贡献,通过结合不同的注意力机制来解决现有方法的局限性。
* 层自注意力(Layer Self-Attention, LSA):旨在缓解部分可观测性问题,让草稿模型能获取 LLM 所有层的综合信息。
* 交叉注意力(Cross-Attention, CA):用于提升训练的在策略性和效率,使模型能更好地处理由自身生成的输入。
* 目标层推理(Target Layer Inference, TLI):提供一种灵活的机制,允许在计算效率和预测准确性之间进行权衡。
- 重用 LLM 激活值:新架构在草稿模型中重用更多来自 LLM 的激活值,从而在草稿生成速度和响应质量之间实现权衡。
- 两种应用场景:
- 单设备部署:在传统的单个设备上部署 LLM 和草稿模型。
- 客户端-服务器部署:一种创新的部署模式,将草稿模型部署在消费级设备(客户端),而 LLM 部署在服务器上。此模式旨在最小化服务器调用,并在网络断开时仍能保持高准确性。
主要实验成果:
本文通过广泛的实验证明了所提方法的有效性:
- 在单设备场景下,与 EAGLE-2 相比,解码速度提升了 9.5%,接受长度(acceptance length)增加了 25%,达到了业界领先水平。
- 在客户端-服务器场景下,与 EAGLE-2 相比,速度提升了 84%,接受长度增加了 53%。该方法在不同网络条件下(4G、5G)实现了业界领先的延迟和最少的服务器调用。
- 在完全网络断连的情况下,本文方法相比其他 SD 方法能保持更高的准确性,并且优于依赖 API 调用的方式(后者将无法继续生成)。
最终,本文不仅提出了一种新的推测解码架构,还为在客户端-服务器模式下提供 LLM 服务提出了一个有效的新框架。
A3 背景知识
2.1 LLM 解码
LLM解码的动态系统视角。解码是 LLM 响应输入查询生成词元的过程,通常以自回归方式进行,即每个新词元 $y_t$ 的采样都依赖于查询和之前已生成的词元 $y_{<t}$。本文从动态系统的角度来审视解码过程,为结合大小模型的新解码机制奠定基础【索引18,Aligning large language models with representation editing: A control perspective,2024】。LLM 的内部工作可以被看作一个随着词元生成而演化的动态系统。对于一个大模型 $M_{Large}$,其标准解码的状态转移模型可以描述为:</p>
$$\boldsymbol{h}_{\leq t+1}, \boldsymbol{o}_{t+1}=f_{\text {Large }}\left(\boldsymbol{h}_{\leq t}, \text { token\_embed }\left(y_{t}\right)\right), \quad y_{t+1} \sim \operatorname{Softmax}\left(L M\_h e a d\left(\boldsymbol{o}_{t+1}\right)\right)$$在此公式中,$h_{\le t}$ 代表截至当前时间步 $t$ 的每一层的键(key)和值(value)张量,$y_t$ 是最新生成的词元,$y_{t+1}$ 是从一个 softmax 分布中采样得到的下一个词元。token embed 是一个查找表,为词汇表 $V$ 中的特定词元分配一个嵌入向量。LM head 是一个从嵌入大小到词汇表大小 $|V|$ 的投影。$f_{Large}(\cdot)$ 是聚合 $M_{Large}$ 所有解码器层的函数,$o_{t+1}$ 是最后一个解码器层的激活值。因此,该动态系统的状态由 $(h_{\le t}, y_t)$ 构成,这是从 $M_{Large}$ 采样下一个词元所需的最小信息。
2.2 推测解码
推测解码的基本原理。最早的推测解码工作由【索引33,Blockwise parallel decoding for deep autoregressive models,2018】提出,后被【索引23,Fast inference from transformers via speculative decoding,2023】扩展到非贪婪采样。这些方法的动机在于,$M_{Large}$ 的自回归生成成本高昂,可以通过使用一个更高效的草稿模型 $M_{Small}$ 来缓解,该模型能高效生成那些不需要 $M_{Large}$ 全部能力的词元。这些被称为“草稿”的假设随后可以由 $M_{Large}$ 使用拒绝采样进行并行验证,即在第一个不匹配的词元之后丢弃所有后续词元【索引23,Fast inference from transformers via speculative decoding,2023】。本文遵循标准的“草稿与验证”循环,该循环迭代以下两个步骤,直到生成“序列结束”词元或达到最大序列长度:
1. $M_{Small}$ 生成新的词元 $y_{t+1}, \dots, y_{t+K}$,其中 $K$ 是草稿序列的长度。生成方式可以是自回归的【索引36,Speculative decoding: Exploiting speculative execution for accelerating seq2seq generation,2023;索引25,EAGLE-2: Faster inference of language models with dynamic draft trees,2024】或并行的【索引1,Hydra: Sequentially-dependent draft heads for medusa decoding,2024】。词元数量 $K$ 通常是固定的,但近期也提出了动态 $K$ 的方法【索引30,Tandem transformers for inference efficient llms,2024;索引27,Accelerating speculative decoding using dynamic speculation length,2024;索引14,Specdec++: Boosting speculative decoding via adaptive candidate lengths,2024】。
2. $M_{Large}$ 在一次前向传播中验证 $K$ 个草稿词元。验证方法包括贪婪(精确)匹配【索引36,Speculative decoding: Exploiting speculative execution for accelerating seq2seq generation,2023】、推测采样【索引43,Distillspec: Improving speculative decoding via knowledge distillation,2023;索引23,Fast inference from transformers via speculative decoding,2023】或“近似”验证【索引33,Blockwise parallel decoding for deep autoregressive models,2018】,后者放宽了接受标准但不能保证输出与 $M_{Large}$ 的分布一致。在所有情况下,一旦出现第一个被拒绝的词元,其后的所有词元通常都会被丢弃。
草稿生成策略:链式与树状。当草稿 $y_{t+1}, \dots, y_{t+K}$ 是逐个自回归生成时,这种策略被称为链式起草(chain drafting)。一些工作【索引3,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024;索引25,EAGLE-2: Faster inference of language models with dynamic draft trees,2024】将此方法扩展为树状起草(tree drafting)以进行额外优化。在这种情况下,对于未来的每个位置 $i$,可以提出多个候选词元 $y_{t+i}$。验证过程使用树状注意力(Tree Attention)【索引28,Specinfer: Accelerating large language model serving with tree-based speculative inference and verification,2024】来高效处理树状起草提出的多个分支路径。因此,与链式起草相比,这能增加接受长度并减少对 $M_{Large}$ 的调用次数。对于小批量大小的场景,LLM 的生成是内存受限的,推测解码尤其在配合树状注意力时能更好地利用闲置的计算资源。
本文采用的树状起草方法。本文采用了 EAGLE-2 【索引25,EAGLE-2: Faster inference of language models with dynamic draft trees,2024】中的树状起草方法来构建可变结构的树。从根节点开始,我们从模型 $M_{Small}$ 中扩展出 $B$ 个最可能的词元 $(y_{t+1}|\cdot)$。然后,对于一个固定的深度 $D$,我们递归地执行以下步骤:对于每个已有的分支,我们计算其 $B$ 个子词元的联合概率 $Q_{t \in D} M_{Small}(y_t|y_{t-1}, \dots)$,这会导致 $B^2$ 次扩展,因为我们有 $B$ 个分支,每个分支有 $B$ 个子节点。从这 $B^2$ 次扩展中,我们根据联合概率选择前 $B$ 个分支用于下一层树的扩展。当达到最大深度时,我们从总共 $B + (D-1)B^2$ 个节点中保留最多 $m$ 个具有最高联合概率的词元进行验证。
2.3 $M_{Small}$ 的架构
$M_{Small}$ 的架构分类。推测解码的架构大致分为两类:独立型(independent)和自起草型(self-drafting)。独立型草稿模型通常是与 $M_{Large}$ 同一模型家族的较小版本【索引24,Nearest neighbor speculative decoding for llm generation and attribution,2024;索引41,Ouroboros: Speculative decoding with large model enhanced drafting,2024;索引13,Rest: Retrieval-based speculative decoding,2023】,而自起草方法则利用 $M_{Large}$ 的一部分和/或新初始化的参数【索引1,Hydra: Sequentially-dependent draft heads for medusa decoding,2024;索引2,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024】。
本文基于 EAGLE 架构进行构建。本文的贡献建立在 EAGLE【索引25,EAGLE-2: Faster inference of language models with dynamic draft trees,2024】之上,这是一种自起草架构,在 Spec-Bench【索引37,Unlocking efficiency in large language model inference: A comprehensive survey of speculative decoding,2024】排行榜上取得了目前最好的成绩。该草稿模型重用了 $M_{Large}$ 的 token embed 和 LM head 参数。它接收 $M_{Large}$ 最后一个解码器层的真实激活值 $o_1, \dots, o_t$ 以及序列词元 $y_1, \dots, y_t$ 作为输入,来预测下一个激活值 $\hat{o}{t+1}$,该激活值随后被传递给 LM head 来预测下一个词元的分布:
$\hat{o}} = M_{EAGLESmall}((o_1, \dots, o_t), \text{token embed}(y_1, \dots, y_t)), \quad \hat{y{t+1} \sim \text{Softmax}(\text{LM head}(\hat{o}))$
这个过程通过将 $\hat{o}{t+1}, \hat{y}$ 附加到输入中来重复进行,从而自回归地起草后续词元 $\hat{y}_{t+2}$。
A2 方法细节
3.1 注意力混合
我们首先定义 $M_{Small}$ 的重要属性,然后详细阐述架构选择。
3.1.1 部分可观测性
部分可观测性问题。在马尔可夫决策过程(MDPs)【索引16,Planning and acting in partially observable stochastic domains,1998】中,部分可观测性是一个普遍挑战,即智能体没有足够的信息来了解真实底层状态以做出最优决策。这一限制会严重降低智能体的性能。为缓解此问题,已提出多种方法,例如增加额外的历史观测【索引29,Human-level control through deep reinforcement learning,2015】。在起草过程中,为了更准确地起草未来词元,避免部分可观测性至关重要。我们在此背景下扩展此概念:
属性 3.1 (部分可观测性). 给定一个真实函数 $F: X \to Z$,一个草稿函数 $f: Y \to Z$ 和一个观测函数 $g: X \to Y$,使得对于任何 $x \in X$,$f(g(x))$ 都在建模 $F(x)$。如果 $g$ 是非单射的:$\exists (x, x') \in X^2, x \neq x', g(x) = g(x')$,则称 $f$ 存在部分可观测性。
我们可以观察到,当 $F$ 是 $f_{Large}$,$f$ 是 $M_{EAGLESmall}$ 且 $o_1, \dots, o_t = g(h_{\le t}, (y_1, \dots, y_t))$ 时,EAGLE 草稿模型存在部分可观测性。换言之,$o_1, \dots, o_t$ 只是动态系统(公式1)真实状态 $(h, y)$ 的一个部分观测,这妨碍了 $M_{Small}$ 预测 $M_{Large}$ 正确词元的能力。
层自注意力(Layer Self-Attention, LSA)。为缓解这一问题,我们的新架构将动态系统(公式1)的状态作为输入。然而,$h_{\le t}$ 是一个形状为 $(T, L, 2E_{kv})$ 的大张量,其中 $T$ 是序列长度,$L$ 是 $M_{Large}$ 的层数,$E_{kv}$ 是键和值的嵌入大小。因此,我们引入了层自注意力(LSA),后跟一个均值聚合操作,将其维度降至 $(T, 2E_{kv})$,并从每一层中提取最相关的词元信息(图1)。自注意力是在层的维度上执行的,每个词元在该层中被独立处理。在起草过程中,我们可以访问所有层的过去键值,因此 LSA 的注意力掩码是双向/完全的(见图2)。我们只在每个起草阶段开始时执行一次 LSA 计算,关于信息流的详细算法见附录 A.3。
3.1.2 缺乏在策略性(ON-POLICYNESS)
训练与测试的不一致性。训练和测试场景之间的差异源于:在训练期间,Transformer 模型通常以真实的序列为条件,假设所有先前的输入都是正确的。如果说这个假设对于 Transformer 的标准训练来说不成问题,那么在推测解码场景中为 $M_{Small}$ 的训练做此假设则要微妙得多。众所周知,一些先前的输入是直接由 $M_{Small}$ 生成的,因此准确性要低得多。我们仅用 $M_{Small}$ 预测的词元越多,就越可能累积误差。为了缓解这个问题,EAGLE 在训练时向其观测值 $(o_1, \dots, o_t)$ 中添加了均匀噪声,但这并非理想之举。
在策略训练的必要性与挑战。为了优化地训练 $M_{Small}$,我们需要确保其训练和推理条件紧密匹配。具体来说,这意味着训练 $M_{Small}$ 时,要模拟部分先前词元是由其自身生成的情况。此外,我们还应考虑在起草周期中无法获得 $M_{Large}$ 激活值的情况。这种方法被称为在策略(on-policy)训练。在在策略训练中,用于训练的数据是由当前正在训练的同一策略(或模型)生成的。例如,当我们在静态数据集上使用下一个词元预测来训练 Transformer 时,这被认为是离策略(off-policy)的,因为数据不会根据模型的决策而改变。然而,如果我们将这个静态数据集与模型在训练过程中自身生成的数据混合,我们就向更在策略的方法迈进。同样,如果模型在生成过程中无法访问某些信息(例如 $M_{Large}$ 的激活值),那么总是在训练 $M_{Small}$ 时提供这些信息也被认为是离策略的。然而,在策略训练的成本非常高,因为我们需要在训练过程中从模型生成数据。为了形式化这一局限性,我们引入了 T 步有界(T-step boundedness)的概念:
属性 3.2 (T 步有界). 如果一个草稿模型 $f$ 在一次前向传播中,可以预测最多 $T$ 个未来词元而无需 $M_{Large}$ 的额外输入,即 $f(y_1, y_2, \dots, y_t) \to (\hat{y}_{t+1}, \hat{y}_{t+2}, \dots, \hat{y}_{t+T})$,则称其为 T 步有界。
这个属性对于高效训练草稿模型很重要。例如,EAGLE 草稿模型是 1 步有界的。如果想在时间 $t+2$ 进行预测,由于其自回归层需要前一个 $\hat{o}_{t+1}$ 作为输入,将需要两次前向传播,这在策略训练中成本极高。相比之下,一个 T 步有界且 $T>1$ 的草稿模型可以在单次前向传播中预测多个未来词元。
交叉注意力(Cross-Attention)。为了使我们的草稿模型部分地 T 步有界($T>1$),我们架构的主要组成部分是一个交叉注意力(CA)层,其中查询(query)来自词元,而键(key)和值(value)来自 $M_{Large}$ 的激活值。更准确地说,键和值来自 LSA 的输出。让时间 $t+1$ 到 $t+K$ 的输入查询进入 CA 层,而键值仅来自 $M_{Large}$ 直到时间 $t$ 的信息,这实际上意味着 CA 层是 K 步有界的。这使我们能够更高效地进行在策略训练,因为它模拟了生成过程中会发生的情况:我们只能访问 $M_{Large}$ 直到时间 $t$ 的激活值,但仍需对直到时间 $t+K$ 的情况进行预测。请注意,这仍然不是完全在策略的,因为时间 $t+1$ 到 $t+K$ 的输入查询并未假设是由 $M_{Small}$ 生成的。在训练过程中,通过改变 CA 层的掩码来采样多个 $K$ 值,以模拟不同长度的已接受草稿。例如,在图 2 中,我们先有 $K=4$,后有 $K=3$。相反,在生成过程中,我们不应用掩码,因为我们希望让 $M_{Small}$ 关注所有当前可用的 $M_{Large}$ 激活值。
自注意力(Self-Attention)。为了说明引入自注意力(SA)层的动机,我们首先观察到交叉注意力层相对于输入查询是输入无关的(属性3.3),即一个输入查询不影响另一个查询的结果。
属性 3.3 (输入无关性). 如果对于任意 n 个输入 $x = (x_1, \dots, x_n)$,我们有 $f(x) = (f(x_0), \dots, f(x_n))$,则称层 $f$ 是输入无关的。
因此,如果 CA 层的查询直接来自嵌入的词元 $y_1, \dots, y_t$,$M_{Small}$ 将不会意识到先前已起草的词元。它只会知道 $M_{Large}$ 处理的最后一个词元和最新的 $y_t$。但是,为了做出准确的预测,$M_{Small}$ 需要知道先前已起草的词元。因此,我们引入一个因果自注意力层作用于查询上以缓解此问题,如图 1 所示,并在表 1 中总结。
表1:我们新架构的属性比较。
3.2 目标层推理
挑战最终层假设。之前的工作假设语言模型头部(LM head)之前的最后一个隐藏层是 $M_{Small}$ 应该预测的最合适目标(激活值)。然而,我们挑战了这一假设,并假设针对 $M_{Large}$ 更深层的层可能在草稿质量方面更有优势。因此,我们通过引入层索引 $l$(作为上标)逐层分解动态系统(公式1):
$$\begin{aligned} \begin{aligned} \boldsymbol{o}_{t+1}^1 &= token\_embed(y_t), & \boldsymbol{h}_{\leq t+1}^l, \boldsymbol{o}_{t+1}^{l+1} &= f_{\text{decoder}}^l(\boldsymbol{h}_{\leq t}^l, \boldsymbol{o}_{t+1}^l), \\ y_{t+1} &\sim \text{Softmax}(LM\_head(\boldsymbol{o}_{t+1}^{L+1})) & l &= 1, \dots, L \end{aligned} \end{aligned}$$其中 $f_{decoder}^l$ 是 $M_{Large}$ 在第 $l$ 层的解码器层。这个新动态系统的状态由 $(o_{t+1}^l, h_{<t+1}^{\le l}, h_{\le t}^{\ge l})$ 组成。我们观察到,要完美预测 $o_{t+1}^{L+1}$,只需完美预测 $o_{t+1}^L$ 并重用 $M_{Large}$ 的 $f_{decoder}^L$ 以及在时间 $t$ 已经计算好的第 $L$ 层的 KV 缓存 $h_{\le t}^L$。同样的递归推理可以用于从 $o_{t+1}^{L-1}$ 预测 $o_{t+1}^L$,以此类推。我们假设(并将在后文证明)预测 $o_{t+1}^l$ 总是比预测 $o_{t+1}^k$ (对于 $l < k$) 更容易,因为 $o_{t+1}^l$ 经历的层变换更少。</p>
目标层推理(TLI)超参数。因此,我们引入一个新的超参数 TLI 来指代 $M_{Small}$ 应该预测的目标层 $o^{L+1-TLI}$。当 TLI > 0 时,$M_{Large}$ 的最后 TLI 层(在训练期间保持冻结)及其 KV 缓存被用来输出 $o^{L+1}$。此后,我们使用符号 (TLI = l) 来表示用于推理的目标层,其中 l 是一个整数。现在我们可以给出描述我们的 $M_{OursSmall}$ 在给定 TLI 下的方程,假设 $t$ 是我们最后一次用 $M_{Large}$ 验证的时间点:
$$\begin{aligned} \begin{aligned} \boldsymbol{\hat{o}}_{T+1}^{L+1-\mathrm{TLI}} & =\mathcal{M}_{\mathrm{Small}}^{\mathrm{Ours}}\left(\boldsymbol{h}_{\leq t}, \text {token\_embed}\left(y_{1}, \cdots, y_{t}, \hat{y}_{t+1}, \cdots, \hat{y}_{T}\right)\right), \\ \hat{\boldsymbol{h}}_{T+1}^{l}, \boldsymbol{\hat{o}}_{T+1}^{l+1} & =f_{\text {decoder }}^{l}\left(\left(\boldsymbol{h}_{\leq t}^{l}, \hat{\boldsymbol{h}}_{>t, \leq T}^{l}\right), \boldsymbol{\hat{o}}_{T+1}^{l}\right), \quad l=L-\mathrm{TLI}, \ldots, L, \\ \hat{y}_{T+1} & \sim \operatorname{Softmax}\left(L M\_head\left(\boldsymbol{\hat{o}}_{T+1}^{L+1}\right)\right) . \end{aligned} \end{aligned}$$3.3 损失函数
损失函数定义。设 $M_{Small}$ 由参数 $\theta$ 参数化,我们使用与 EAGLE 类似的训练损失,即前向 KL 散度损失,并结合 $M_{Small}$ 预测的激活值 $\hat{o}^{L+1-TLI}$ 与从 $M_{Large}$ 获得的目标激活值之间的 Smooth-L1 损失 $\mathcal{L}$:
$$\arg \min _{\boldsymbol{\theta}} \lambda_{0} \mathrm{KL}\left[\mathcal{M}_{\text {Large }} \| \mathcal{M}_{\text {Small }}(\boldsymbol{\theta})\right]+\lambda_{1} \mathcal{L}\left(\hat{\boldsymbol{o}}^{L+1-\mathrm{TLI}}, \boldsymbol{o}^{L+1-\mathrm{TLI}}\right) .$$为了保持训练的轻量级,我们在训练期间不从 $M_{Large}$ 或 $M_{Small}$ 生成数据。此损失仅在固定训练数据集的提示响应部分上定义。
A4 实验环境与结果
实验环境
- 大模型 ($M_{Large}$): Llama3-8B-Instruct【索引9,The llama 3 herd of models,2024】。
- 训练数据集: 所有小模型 ($M_{Small}$) 均在 Ultrachat 数据集【索引7,Enhancing chat language models by scaling high-quality instructional conversations,2023】上进行训练,不使用系统提示。Ultrachat 包含约 20 万个提示和约 2.4 亿个词元。
- 测试数据集: 使用多个测试集进行生成评估,涵盖推理、代码生成、多轮对话和摘要等任务。主要依赖 SpecBench 基准【索引37,Unlocking efficiency in large language model inference: A comprehensive survey of speculative decoding,2024】,并包括 MT-Bench【索引42,Judging llm-as-a-judge with mt-bench and chatbot arena,2023】、HumanEval【索引4,Evaluating large language models trained on code,2021】、GSM8K【索引5,Training verifiers to solve math word problems,2021】、Alpaca【索引34,Stanford alpaca: An instruction-following llama model,2023】、CNN/Daily Mail【索引31,Abstractive text summarization using sequence-to-sequence rnns and beyond,2016】和 Natural Questions【索引21,Natural questions: a benchmark for question answering research,2019】。
- 硬件与网络配置 (客户端-服务器): 客户端和服务器位于相距约 300 公里的两个不同城市。服务器的 float16 tflops 是客户端的 3 倍。设备间的 ping 约为 9ms,带宽约 50 Mbits/sec。使用 Linux 流量控制子系统模拟 4G(20 Mbits/sec,延迟 21±19ms,0.1% 丢包率)和 5G(延迟 10±10ms,0.1% 丢包率)网络。
- 软件与参数配置: 使用 EAGLE-2 的树状解码策略,最大宽度为 8,深度为 6,最多验证 62 个词元。使用 float16 精度,但注意力 softmax 权重升至 float32。
- 对比模型:
- EAGLE-2: 我们自己训练的版本。
- EAGLE-2 off.: 【索引25,EAGLE-2: Faster inference of language models with dynamic draft trees,2024】发布的官方权重。
- Independent: 独立蒸馏的、规模相似的 $M_{Small}$。
- 为保证公平比较,"Ours (TLI=0)"、"EAGLE-2"、"EAGLE-2 off."、"Independent 1.3B" 和 "Glide" 模型均包含 1.3B 参数(2.5亿可训练参数和 10.5亿冻结参数)。
实验结果
4.1 单设备场景
实验内容: 在 SpecBench 基准【索引37,Unlocking efficiency in large language model inference: A comprehensive survey of speculative decoding,2024】上进行单设备性能测试,不使用系统提示以确保公平比较。
实验结果: 如表 2 所示,本文提出的“注意力混合”架构在 TLI=1 和 TLI=0 时均取得了当前最佳(SOTA)的加速效果。与 EAGLE-2 相比,我们的方法平均生成速度(tokens-per-second)快 9.5%。当 TLI=1 时,平均接受长度($\tau$)提高了 25%。
表2:在 SpecBench 上使用来自 MT-Bench、HumanEval、GSM8K、Alpaca、Sum 和 QA 数据集的提示得到的加速比和接受长度 τ。每个模型都经过 30 个 epoch 的微调,并使用 EAGLE-2 树状解码。
结论: 在标准的单设备部署中,我们的方法在解码速度和接受长度上均优于现有的 SOTA 方法 EAGLE-2。
4.2 客户端-服务器场景
实验内容: 研究自起草方法在客户端-服务器场景下的性能。$M_{Small}$ 部署在客户端,$M_{Large}$ 部署在服务器。分析了在模拟 4G 和 5G 网络条件下的解码速度、接受长度和服务器调用次数,并考虑了消息大小对性能的影响。
实验结果:
- 解码速度: 如表 3 所示,"Ours (TLI=0)" 实现了最快的解码速度(5G下 34.1 tokens/sec,4G下 25.1 tokens/sec),甚至超过了那些只交换文本ID、不传输激活值的独立小模型。虽然速度相较于单设备场景有所下降,但仍能恢复单设备上标准解码的速度(约 33 tokens/sec)。
- 网络断连鲁棒性: 模拟了与服务器完全断开连接的情况。如表 4 所示,与依赖 API 调用的方式不同(后者会完全中断),我们的方法可以在客户端设备上继续生成响应,并在 HumanEval 数据集上完成额外的编程问题。例如,在生成 50 个新词元后断开连接,"Ours (TLI=3)" 的成功率达到 31.67%,远高于其他方法。
表3:在 5G 和 4G 网络环境下,使用 EAGLE-2 树状解码在 HumanEval 上的性能。
表4:在客户端与服务器连接中断情况下,在 HumanEval 上的成功率(pass@1,贪婪解码)。使用了 EAGLE-2 树状解码。
结论: 本文提出的客户端-服务器框架是可行的。它不仅提供了可接受的解码速度,还能在网络断连时继续生成有用的响应,为在小型设备上提供 LLM 服务开辟了新途径。
4.3 消融研究
实验内容: 在 HumanEval 数据集上,对“注意力混合”架构的各个组件(CA 层、LSA、TLI)进行消融实验,以验证其各自的贡献。为控制变量,该实验使用链式起草(最大长度4)并训练10个周期。
实验结果:
* CA 层的效果(在策略性): 将 EAGLE 与 "Ours (TLI=0, -LSA)"(即仅包含 CA 和 SA)进行比较。结果显示,CA 层带来的在策略性和 T 步有界属性使 tokens-per-second 提升了 26%,接受长度提升了 33%。这证明了该组件的巨大积极影响。
* LSA 的效果(部分可观测性): 将 "Ours (TLI=0, -LSA)" 与 "Ours (TLI=0)" 进行比较。引入 LSA 后,tokens-per-second 性能提升了 6%。这表明减少部分可观测性是有益的,但其影响不如 CA 层带来的在策略性关键。
* TLI 的效果: 观察表 2、3 和 5 中不同 TLI 值的变化,可以发现增加 TLI 会提高接受长度,正如假设的那样。然而,这并不总能转化为更高的 tokens-per-second,因为它也增加了起草的计算时间。但在客户端-服务器场景中,更高的 TLI 在网络断连时能提高响应质量。
表5:我们提出的架构在 HumanEval 上的消融研究。每个模型在约 2.4B 词元上训练。本研究使用链式(非树状)起草,最大长度为 4。平均值是在大约 8500 个起草-验证周期上计算的。
结论: 消融研究证实了所提出的架构中每个组件的有效性:CA 层对于提升在策略性至关重要,LSA 有效缓解了部分可观测性问题,而 TLI 则提供了一个在速度和质量之间权衡的有效手段。
A7 补充细节
5 相关工作
利用 $M_{Large}$ 激活值的自起草方法。Medusa【索引2,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024】是利用 $M_{Large}$ 激活值作为 $M_{Small}$ 输入的最早工作之一,它通过训练 K 个新的 LM head 来并行生成 K 个未来词元。后来被 Kim 等人【索引17,Accelerating blockwise parallel language models with draft refinement,2025】扩展,通过使用轻量级重排序器来优化草稿。EAGLE【索引26,EAGLE: Speculative sampling requires rethinking feature uncertainty,2024】和 Hydra【索引1,Hydra: Sequentially-dependent draft heads for medusa decoding,2024】是 Medusa 的自回归扩展,它们观察到非自回归生成会限制接受长度,因为 $M_{Small}$ 无法感知到先前生成的词元。由于 EAGLE 在 SpecBench 排行榜上排名更高,本文未与 Medusa 或 Hydra 进行比较。
与 $M_{Large}$ 隐藏状态对齐的方法。Tandem Transformers【索引30,Tandem transformers for inference efficient llms,2024】提出让 $M_{Small}$ 关注 $M_{Large}$ 降维后的隐藏状态,通过共享隐藏状态表示来对齐两个模型,从而提高草稿接受率。由于缺乏开源实现、使用闭源 LLM 以及未公开的复现所需数据/计算量,我们未能与其进行比较。此外,该方法似乎在大小模型之间有很高的通信开销,不适用于客户端/服务器设置。
免训练的推测解码方法。与本文工作正交,研究人员最近提出了免训练的 SD 方法。Lookahead Decoding【索引10,Break the sequential dependency of llm inference using lookahead decoding,2024】使用 Jacobi 迭代通过单个 $M_{Large}$ 生成新词元,后被 CLLM【索引19,Cllms: Consistency large language models,2024】和 Ouroborous【索引41,Ouroboros: Speculative decoding with large model enhanced drafting,2024】扩展。我们在我们的设置中评估了后者,但发现其效率低于 EAGLE-2 的树状解码策略(见附录 A.4)。更多相关工作可参考【索引37,Unlocking efficiency in large language model inference: A comprehensive survey of speculative decoding,2024】中的详细综述。
利用 KV 缓存的方法。Du 等人【索引8,Glide with a cape: A low-hassle method to accelerate speculative decoding,2024】先前提出利用 $M_{Large}$ 某些层的 KV 缓存。他们没有从理论上解释为何使用 KV 缓存而不是每层的输出,也没有说明如何精确选择要包含哪些层作为 $M_{Small}$ 的输入。然而,通过我们的动态系统视角,我们证明了所有层的 KV 缓存都是状态的一部分。LSA 的引入使得可以在有限层数内完整利用它,而 Du 等人的方法则需要在 $M_{Small}$ 和 $M_{Large}$ 中有相同数量的层才能完全捕捉,导致起草速度变慢。
本文方法的普适性。尽管我们专注于改进当前的 SOTA 方法(EAGLE-2),但我们的观察(部分可观测性、在策略性和目标推理层)对许多自起草方法都成立。例如,它也可以应用于 Medusa【索引2,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024】、MLP Speculator【索引35,Accelerating production llms with combined token/embedding speculators,2024】或 Gloeckle 等人【索引11,Better & faster large language models via multi-token prediction,2024】和 Kim 等人【索引17,Accelerating blockwise parallel language models with draft refinement,2025】的工作。对于非自起草的 SD,需要逐案研究。例如,目标推理层可能可以应用于独立的的小模型。许多学生-教师蒸馏框架【索引12,Miniplm: Knowledge distillation for pre-training language models,2024;索引43,Distillspec: Improving speculative decoding via knowledge distillation,2023】已经利用了在策略属性,但大多是 1 步有界的(因此训练成本高)。对于基于前瞻解码的 SD 方法,通常不适用。一个例外是 Ouroboros【索引41,Ouroboros: Speculative decoding with large model enhanced drafting,2024】,它利用了一个小模型进行前瞻解码,其小模型也可能从我们的解决方案中受益。
A5 结论
本文介绍了一种用于推测解码的注意力混合架构,有效解决了现有先进方法的多项局限性。为了提高草稿模型 $M_{Small}$ 的起草准确性,我们提出了一种注意力层组合:
1. 层自注意力(Layer Self-Attention, LSA),以缓解部分可观测性问题。
2. 自注意力(Self-Attention) 后接 交叉注意力(Cross-Attention),以实现更具在策略性的训练。
接着,我们引入了目标层推理(Target Layer Inference, TLI),这是一种新颖的方法,允许 $M_{Small}$ 重用 $M_{Large}$ 的最后 N 层,从而在起草速度和准确性之间进行权衡。
实验结果表明,在标准的单设备设置中,我们实现了业界领先的解码加速,比 EAGLE-2 提升了 9.5%,并将接受长度延长了高达 25%。我们还引入了一个客户端-服务器范式,并证明了我们的自起草推测解码方法是调用 $M_{Large}$ API 的一种可行替代方案。在此范式下,即使在网络完全断开后,客户端仍能以最高的准确性和速度继续生成响应。
作为未来的研究方向,探索是否可以由 $M_{Small}$ 来预测 TLI (N) 的值,以便根据当前的网络状况自动平衡速度和准确性,将是一个有趣的研究课题。
A6 附录
A.1 超参数
表6:我们的超参数列表。
A.2 客户端-服务器部署
表7:消息大小(量化前),单位为字节。M = 草稿树中的节点数,A = 接受的词元数,E = 隐藏层大小,Ekv = 键和查询向量的隐藏层大小。
A.3 算法
$M_{OursSmall}$ 的生成算法。以下算法展示了假设使用链式解码的 $M_{OursSmall}$ 的生成过程。
Algorithm 1: MOursSmall 的生成算法(假设链式解码)
Require: 输入序列 y = (y1, y2, . . . , yt),草稿长度 K,目标层推理 TLI
1: 通过在 MLarge 中对输入 y 进行一次前向传播,获得 h≤t 激活值和 yt+1
2: y ← (y, yt+1)
3: kv ← LSA layer with mean (h≤t)
4: while y 未满足停止条件 do
5: for i = 1 to K do
6: q ← SA layer(token embed(y))
7: oˆL+1−N ← CA layer(q, kv)
8: if N > 0 then
9: for l = L − N to L do
10: [hˆ l, oˆl+1] ← f ldecoder((hl≤t, hˆ l>t,≤t+i), oˆl)
11: end for
12: end if
13: yˆ ∼ Softmax(LM head(oˆL+1))
14: y ← (y, yˆ)
15: end for
16: 从 y 的 K 个最新词元中识别出 K′ 个已验证的词元,获得相关的 h′ 和 y′,这是通过在 MLarge 中对输入 y|y|−K,··· ,|y| 和 h≤t 进行一次前向传播得到的
17: kv′ ← LSA layer with mean (h′)
18: kv ← (kv, kv′)
19: 通过附加新的 h′ 组件来更新 h
20: 丢弃之前的 hˆ
21: y ← y1,··· ,|y|−K+K′ (仅保留已验证的词元)
22: t ← |y|
23: y ← (y, y′)
24: end while
25: return y
A.4 附加实验结果
生成文本的准确性。我们运行了多项实验以评估使用贪婪解码生成响应的质量。我们专注于 SpecBench 中的 3 个数据集(HumanEval、GSM8K 和 CNN/DM),这些数据集的评估不需要访问专有模型/API(例如 llm-as-a-judge)。结果如表 8 所示。HumanEval 上的 pass@1 在所有方法中都相同。GSM8K 上的准确率对于 Ours (TLI=1) 和 EAGLE-2 在一个问题上相对于基础模型有所提高。最后,ROUGE 分数也非常相似,我们得出结论,与基础模型的任何差异都可以忽略不计,几乎可以肯定是由于使用 float16 引起的。
表8:生成文本的质量。
在 Qwen2.5 3B 上的实验。我们额外训练了 3 个小模型在 Ultrachat 数据集上,以加速 Qwen2.5 3B。EAGLE 建议使用大 LLM 的一个解码器层来定义小 LM 的大小,这导致可训练参数大小为 80M。我们保持共享的 "embed tokens/LM head" 层冻结。结果如表 9 所示,我们的方法同样取得了最好的加速效果和接受长度。
表9:在使用 Qwen2.5-3B Instruct 的情况下,在 SpecBench 上使用 MT-Bench、HumanEval、GSM8K、Alpaca、Sum 和 QA 数据集提示的加速比和接受长度 τ。
在 vLLM 中支持更高批量。我们在 vLLM 【索引22,Efficient memory management for large language model serving with pagedattention,2023】中实现了我们的方法(不使用树状解码),以支持更高的批量大小和连续批处理。如图 4 所示,我们的方法 (MOA) 在不同请求率下都比 EAGLE-2 表现出更好的延迟加速比。
在单设备上对 HumanEval 的完整测试。为了进行这个实验,我们重用了在消融研究中使用的完整 HumanEval 数据集,并采用了严格的停止标准,在单设备设置下进行测试。如表 10 所示,我们的方法比 EAGLE/EAGLE-2 快 26%。我们也比独立的小模型和 Ouroboros 【索引41,Ouroboros: Speculative decoding with large model enhanced drafting,2024】更快。
表10:在 Human Eval 上的测试,每个模型训练 30 个 epoch。
A.5 复杂度分析
标准解码器复杂度。标准解码器-only Transformer 的标准解码:
* 在第一个预填充(prefill)阶段,复杂度为 $O(LKE(E+K))$,其中有 $L$ 个自注意力层, $K$ 个输入词元,嵌入大小为 $E$。
* 对于 $K'$ 个新解码的词元,复杂度为 $O(\sum_{i=1}^{K'} L(E^2 + E(K+i))) = O(LE(EK' + KK' + K'^2))$。
如果假设 $E$ 和 $L$ 固定,则总体复杂度为 $O((K+K')^2)$。
推测解码复杂度。对于推测解码,预填充阶段是相同的。假设一次验证 $S$ 个词元,验证的复杂度将增长为 $O(\sum_{i=S}^{K'} L(SE^2 + SE(K+i))) = O(LE(EK' + KK' + K'^2))$,导致与标准解码相同的复杂度。这主导了自起草的复杂度,但我们仍然可以分析它。
* EAGLE: 解码一个新词元增长为 $O(E^2+EK)$,因为它是一个单一的自注意力层。
* 本文方法: 我们的自注意力和交叉注意力层也增长为 $O(E^2+EK)$。层自注意力仅在每个验证阶段后调用一次,而不是在每个解码步骤,其复杂度为 $O(ALE_{kv}^2 + AE_{kv}L^2)$,其中 $A$ 是上一阶段接受的词元数。在我们的实验中,$ALE_{kv}^2$ 小于解码词元数 $\times E^2$,因为 $E_{kv}$ 比 $E$ 小 4 倍,$L$ 为 32,$A$ 平均为 4.5,解码词元数为 48。同样,$AE_{kv}L^2$ 通常也小于解码词元数 $\times EK$,只要请求包含超过 24 个词元。因此,总体时间复杂度与 EAGLE 相同。
A.6 隐私应用
客户端-服务器设置的隐私优势。客户端-服务器设置的另一个优点是,我们可以通过仅将提示的非敏感部分发送到服务器来选择性地确保客户端的隐私。本质上,客户端可以将其输入分为连续的“安全”文本和“私有”文本。服务器只处理“安全”文本,这可能是通用上下文或非敏感信息。客户端保留“私有”文本,例如机密数据或敏感指令,并在本地用 $M_{Small}$ 处理这部分。
应用实例。例如,客户端可能会向服务器发送一些 Python 代码以及一般性描述。但是,任何敏感信息,例如要注入代码的登录名和密码,都保留在客户端,不传输到服务器,只传递给 $M_{Small}$。这种方法利用 $M_{Large}$ 的激活值来提高 $M_{Small}$ 在任务某些部分的准确性,同时确保敏感信息永远不会暴露在客户端环境之外。
💬 评论讨论
欢迎在这里分享您的想法和见解!