EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees
- 文章标题: EAGLE-2: 借助动态草稿树实现更快的语言模型推理
- 作者/机构: Yuhui Li♠, Fangyun Wei‡, Chao Zhang♠, Hongyang Zhang♣† (♠北京大学, ‡微软研究院, ♣滑铁卢大学, †Vector Institute)
A1 主要贡献
本文旨在解决现代大型语言模型(LLM)推理成本高昂且耗时的问题。投机采样(speculative sampling)是一种有效的解决方案,其通过快速生成草稿令牌(draft tokens)然后并行验证来加速推理。
-
核心问题:现有的基于树结构的投机采样方法,如EAGLE和Medusa,采用静态的草稿树,这隐含地假设草稿令牌的接受率仅与其在树中的位置有关。然而,本文通过实验发现,草稿令牌的接受率不仅与位置相关,还高度依赖于上下文(context-dependent)。因此,静态草稿树结构存在固有的局限性。
-
研究目标与创新点:
- 提出上下文感知的动态草稿树:为了克服静态树的限制,本文提出了EAGLE-2,引入了一种上下文感知的动态草稿树技术。该技术可以根据不同上下文中草稿令牌的接受率来动态调整草稿树的结构。
- 利用草稿模型的校准特性:动态调整草稿树需要预估接受率,但这通常需要目标LLM的验证,会产生额外开销。本文的一个关键发现是,EAGLE的草稿模型是良好校准的(well-calibrated),其输出的置信度分数能够以很小的误差近似实际的接受率。EAGLE-2利用这一特性,使用草稿模型的置信度分数来低成本地指导动态树的构建。
- 实现更优的加速性能:通过动态调整草稿树,EAGLE-2能够生成更多的被接受令牌,从而实现比现有方法更优的加速效果。在多个LLM系列和六个任务上的广泛评估表明,EAGLE-2实现了3.05x至4.26x的加速比,比其前身EAGLE-1快20%-40%。
-
附加优势:
- 开箱即用(Out-of-the-box usability):与EAGLE相比,EAGLE-2无需训练任何额外的模型来预测树结构,它直接利用投机采样中必不可少的草稿模型的置信度分数,因此无需额外训练。
- 可靠性(Reliability):EAGLE-2不微调或更新原始LLM的参数,也不放宽接受条件,这从理论上保证了其生成的文本分布与原始LLM完全一致,是一种无损加速算法。
A2 预备知识与关键观察
2. 预备知识
2.1 投机采样
投机采样的核心思想是“先草稿,后验证”。它首先快速生成一段可能正确的草稿,然后检查草稿中有哪些令牌可以被接受。我们使用 $t_i$ 表示第 $i$ 个令牌,用 $T_{a:b}$ 表示令牌序列 $t_a, t_{a+1}, \dots, t_b$。投机采样在草稿阶段和验证阶段之间交替进行。
草稿与验证过程。给定一个前缀 $T_{1:j}$,在草稿阶段,投机采样调用一个草稿模型(比原始LLM小的模型)自回归地生成一个草稿 $T_{\hat{j+1}:j+k}$,同时记录每个令牌的概率 $\hat{p}$。在验证阶段,投机采样调用原始LLM来检查草稿 $T_{\hat{j+1}:j+k}$ 并记录其概率 $p$。接着,投机采样从前到后依次决定草稿令牌的接受与否。对于令牌 $\hat{t}_{j+i}$,它被接受的概率是 $min(1, p_{j+i}(\hat{t}_{j+i}) / \hat{p}_{j+i}(\hat{t}_{j+i}))$。如果令牌被接受,则继续检查下一个;否则,从分布 $norm(max(0, p_{j+i} - \hat{p}_{j+i}))$ 中采样一个令牌来替换 $\hat{t}_{j+i}$,并丢弃草稿中剩余的令牌。文献【22, Fast inference from transformers via speculative decoding, 2023, ICML】的附录A.1证明了投机采样与原始自回归解码的分布是一致的。EAGLE和EAGLE-2都应用了这个框架。
2.2 EAGLE
EAGLE是对投机采样的改进。在本文提交时,EAGLE在Spec-Bench【44, Unlocking efficiency in large language model inference: A comprehensive survey of speculative decoding, 2024】上排名第一,这是一个为评估各种场景下投机解码方法而设计的综合基准。
- 草稿阶段(Drafting Stage)。与标准的投机采样自回归地预测令牌序列不同,EAGLE在结构性更强的特征(LM head之前)层面进行自回归,然后使用原始LLM的LM Head来获得草稿令牌。由于采样过程给特征序列带来了不确定性,为了解决这个问题,EAGLE还将提前一个时间步的令牌序列输入到草稿模型中,如图3a所示。
- 验证阶段(Verification Stage)。在标准的投机采样中,草稿是链式结构的,如果一个草稿令牌被拒绝,就需要丢弃所有后续的令牌。EAGLE使用树状结构的草稿,如果一个草稿令牌被拒绝,可以尝试其他分支。图3b展示了两者的区别。
EAGLE与EAGLE-2的区别。EAGLE的草稿树形状是固定的,草稿阶段只是填充相应的位置。EAGLE-2旨在通过引入一个可动态调整的草稿树来改进这一点。图4通过一个简单的例子说明了EAGLE和EAGLE-2之间的区别。
3. 观察
3.1 接受率依赖于上下文
首先评估使用动态草稿树的必要性。这取决于草稿令牌的接受率是否仅与它们的位置有关。我们在Alpaca数据集和Vicuna 7B上测试了草稿树中不同位置令牌的接受率。结果如图5所示。总体而言,草稿令牌的接受率与位置相关,位置P1的接受率最高,P6的最低。草稿树左上侧的草稿令牌(如位置P1)接受率较高,而右下侧的(如位置P6)接受率较低。这支持了像EAGLE和Medusa等方法中使用的静态草稿树在左上侧节点更多、右下侧节点更少的合理性。然而,我们同时也观察到在同一位置,接受率存在显著的方差,这表明一个草稿令牌被接受的概率不仅取决于其位置,还取决于上下文。这说明一个上下文感知的动态草稿树比静态草稿树具有更大的潜力。
3.2 良好校准的草稿模型
要应用动态草稿树,需要一种低成本的方法来估计草稿令牌的接受率,而无需调用原始LLM。我们在Alpaca数据集上进行了实验,探索草稿模型的置信度分数(LLM对每个令牌输出的概率)与接受率之间的关系。如图6所示,草稿模型的置信度分数与令牌的接受率之间存在很强的正相关关系。置信度分数低于0.05的草稿令牌,其接受率约为0.04;而置信度分数高于0.95的草稿令牌,其接受率约为0.98。因此,我们可以使用草稿模型的置信度分数来估计接受率,而无需额外开销,从而实现对草稿树的动态调整。在其他方法中,如GLIDE和CAPE【9, Glide with a cape: A low-hassle method to accelerate speculative decoding, 2024】,也观察到了类似的现象。
A3 方法细节
4. 上下文感知的动态草稿树
基于上述观察,我们引入了EAGLE-2。这是一种用于LLM推理的加速算法,它能动态地调整草稿树。EAGLE-2不改变草稿模型的训练和推理,也不影响验证阶段。它的改进集中在两个方面:如何扩展草稿树(第4.1节)以及如何对草稿令牌进行重排序(第4.2节)。在扩展阶段,我们将草稿树最新一层中最有希望的节点输入到草稿模型中以形成下一层。在重排序阶段,我们选择接受概率更高的令牌来构成验证阶段原始LLM的输入。在草稿树中,一个节点代表一个令牌。在下文中,我们交替使用“节点”和“令牌”。
4.1 扩展阶段
利用树注意力机制并行扩展节点。得益于树注意力(tree attention),草稿模型可以同时输入当前层的所有令牌,并在一次前向传播中计算出下一个令牌的概率,从而扩展当前层的所有令牌。然而,一次性输入过多的令牌会减慢草稿模型的前向传播速度,并且草稿树每层令牌的数量呈指数级增长。因此,我们需要有选择地扩展草稿树。
依据全局接受概率选择扩展节点。我们选择当前层中全局接受概率最高的top-k个令牌进行扩展。在投机采样中,拒绝一个草稿令牌会导致丢弃所有后续令牌;一个令牌最终被接受,当且仅当它的所有前缀都被接受。一个令牌 $t_i$ 的全局接受率是其从根节点到 $t_i$ 路径上所有令牌接受率的乘积。我们将其定义为值 $V_i$:
其中,$Path(root, t_i)$ 表示草稿树中从根节点到节点 $t_i$ 的路径,$p_j$ 表示节点 $t_j$ 的接受率,$c_j$ 表示草稿模型对 $t_j$ 的置信度分数。第3.2节的实验表明,置信度分数与接受率高度正相关。我们利用这种关系来近似该值。
扩展流程。从值更高的令牌开始的分支更有可能被接受。因此,我们选择最后一层中值最高的top-k个节点作为草稿模型的输入,并根据输出扩展草稿树。图7的上半部分展示了扩展阶段。
4.2 重排序阶段
重排序以选择全局最优的草稿令牌。扩展阶段的目的是加深草稿树。由于接受率在0到1之间,更深的令牌其值会更低。一些未被扩展的浅层节点可能比更深的已扩展节点具有更高的值。因此,我们不直接使用扩展阶段选择的令牌作为草稿,而是对所有草稿令牌进行重排序,并选择值最高的top-m个令牌。一个节点的值总是小于或等于其父节点的值。对于值相同的节点,我们优先选择更浅的节点。这确保了重排序后选择的top-m个令牌仍然构成一个连通的树。
构造验证阶段的输入与注意力掩码。之后,我们将选定的令牌展平为一维序列,作为验证阶段的输入。为了确保与原始自回归解码的一致性,我们还需要调整注意力掩码。在原始自回归解码中,每个令牌可以看到所有前面的令牌,从而形成一个下三角注意力矩阵。当使用草稿树时,来自不同分支的令牌不应该互相看到。因此,必须根据树结构调整注意力掩码,以确保每个令牌只能看到其祖先节点。图7的下半部分展示了重排序阶段。
A4 实验
实验环境
- 模型 (Models): 实验在 Vicuna 7B, 13B【7, Vicuna: An open-source chatbot impressing gpt-4 with 90%* chatgpt quality, 2023】、LLaMA2-Chat 7B, 13B, 70B【43, Llama: Open and efficient foundation language models, 2023】以及 LLaMA3-Instruct 8B, 70B【28, LLaMA3, 2024】模型上进行。
- 任务与数据集 (Tasks): 在六个生成任务上进行了综合评估,包括:
- 多轮对话: MT-bench【54, Judging llm-as-a-judge with mt-bench and chatbot arena, 2023】
- 代码生成: HumanEval【4, Evaluating large language models trained on code, 2021】
- 数学推理: GSM8K【8, Training verifiers to solve math word problems, 2021】
- 指令遵循: Alpaca【42, Stanford alpaca: An instruction-following llama model, 2023】
- 摘要: CNN/Daily Mail【31, Abstractive text summarization using sequence-to-sequence rnns and beyond, 2016】
- 问答: Natural Questions【21, Natural questions: a benchmark for question answering research, 2019】
- 设置 (Settings): 遵循LLM社区常用的zero-shot/few-shot设置,即同一个原始LLM在所有任务上使用相同的草稿模型权重。
- 软件与实现:
- 基准: 使用Huggingface.transformers库中的模型,后端为PyTorch,并预先分配KV缓存。
- 对比方法:
- 标准投机采样: 使用HuggingFace Transformers库的
assisted generation
功能。 - PLD, Lookahead, Medusa, Hydra: 使用默认设置和官方发布的权重。
- EAGLE: Vicuna和LLaMA2-Chat的草稿模型使用官方权重;LLaMA3-Instruct的草稿模型使用ShareGPT数据集进行训练。
- 标准投机采样: 使用HuggingFace Transformers库的
- EAGLE-2: 对于7B(8B)、13B和70B的原始LLM,草稿令牌总数分别设置为60、50和48,草稿树深度为6,扩展阶段选择10个节点。
- 硬件: 为了公平性,所有方法的测试都在相同的设备上进行,并利用了Spec-Bench【44, Unlocking efficiency in large language model inference: A comprehensive survey of speculative decoding, 2024】进行对比实验。
实验结果
- 评估指标 (Metrics): 由于EAGLE-2是无损加速方法,不评估生成质量,主要使用以下指标评估加速性能:
- 加速比 (Speedup Ratio): 相对于原始自回归解码的实际测试加速比。
- 平均接受长度 (Average Acceptance Length τ): 每个“草稿-验证”周期生成的令牌数,即从草稿中接受的令牌数量。该指标独立于硬件和运行环境。
5.1 有效性
- 总体性能: 在所有测试的数据集和LLM上,EAGLE-2都实现了最高的加速比(如图1、2和表1、2所示)。
- 任务表现:
- 大多数投机采样方法在代码生成任务(HumanEval)上加速效果最好,EAGLE-2在该任务上实现了高达5倍的加速。
- 在使用Vicuna作为原始LLM时,PLD在摘要任务(CNN/DM)上加速比最高,这得益于其基于检索的草稿生成方式。
- 标准投机采样(使用Vicuna-68M作为草稿模型)也取得了显著加速,但其训练开销远高于其他方法。
- 平均接受长度: 在所有测试中,EAGLE-2的平均接受长度τ最长,每个周期约生成4-5.5个令牌,大约是标准投机采样和Medusa的两倍(见表1、2)。PLD和Lookahead的τ较短,但因其草稿阶段开销极低,其加速比接近于τ。
- 数据依赖性分析: Medusa、Hydra、EAGLE和EAGLE-2在问答(Natural Questions)和摘要(CNN/DM)任务上的τ和加速比相较于其他任务有所下降。这可能是因为它们的草稿模型仅使用SFT数据集训练,而这些任务需要通过预训练获取的世界知识或摘要能力。相比之下,标准投机采样的草稿模型使用了预训练和SFT数据集。尽管如此,EAGLE-2在这两个数据集上仍然优于标准投机采样。
5.2 消融研究
- 价值(Value)与置信度分数(Confidence Score): 在选择扩展节点时,使用“价值”(即一个草稿令牌及其所有祖先节点置信度的乘积)作为排序依据,相比仅使用该令牌自身的置信度分数,能够带来更高的加速比和平均接受长度。表3的实验结果证明了EAGLE-2方法的合理性。
- 重排序(Reranking): 在扩展阶段之后,对所有草稿令牌进行全局重排序,能够进一步提升平均接受长度和加速比。如表3所示,重排序操作带来了明显的性能改进。
A5 结论
本文介绍了EAGLE-2,一种高效且无损的投机采样方法。我们发现EAGLE的草稿模型置信度能够很好地近似草稿令牌的接受率。基于此,EAGLE-2采用了一种依赖于上下文的草稿树结构,显著增加了被接受的草稿令牌数量,从而带来了更优的加速比。EAGLE-2确保了生成结果与原始LLM一致,并且不需要额外的训练。我们在多种LLM和多个数据集上进行了广泛的评估,并将EAGLE-2与几种最先进的投机采样方法进行了比较。在所有的实验中,EAGLE-2都取得了最高的加速比。
A6 附录
A. 实现细节
- Vanilla: 我们使用Huggingface.transformers库中的模型,后端为PyTorch,并预先分配了KV缓存。其他方法也使用这些模型作为其基础。
- (Standard) Speculative Sampling: 我们使用HuggingFace Transformers库中的assisted generation功能。
- PLD, Lookahead, Medusa, and Hydra: 我们使用默认设置和官方发布的权重。
- EAGLE: Vicuna和LLaMA2-Chat的草稿模型使用官方发布的权重,而LLaMA3-Instruct则是使用ShareGPT数据集进行训练(与Medusa和Hydra保持一致)。
- EAGLE-2: 对于7B (8B)、13B和70B的原始LLM,我们分别将草稿令牌的总数设置为60、50和48,草稿树深度为6,并在扩展阶段选择10个节点。
💬 评论讨论
欢迎在这里分享您的想法和见解!