MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
MEDUSA:一个带有多个解码头的简单LLM推理加速框架
作者/机构: Tianle Cai * 1 2, Yuhong Li * 3, Zhengyang Geng 4, Hongwu Peng 5, Jason D. Lee 1, Deming Chen 3, Tri Dao 1 2
A1 主要贡献
本文旨在解决大型语言模型(LLM)自回归解码过程中的一个核心瓶颈:由于每个token的生成都依赖于前一个token的输出,这种顺序计算导致推理延迟较高。从系统角度看,LLM推理主要是内存带宽受限的,因为每一步前向传播都需要将完整的模型参数从高带宽内存(HBM)传输到加速器的缓存中,而这个过程仅生成一个token,严重未充分利用现代加速器的算术计算能力。
为了应对这一挑战,研究目标是提高解码过程的算术强度(总浮点运算次数与总数据移动量的比率)并减少解码步骤的数量。现有的推测解码(speculative decoding)方法虽然遵循此思路,但获取和维护一个独立的、小型的草稿模型(draft model)非常困难,且在分布式系统中的集成也充满挑战。
本文提出了MEDUSA框架,作为一种高效的LLM推理加速方法。其核心创新点如下:
1. 多头并行预测:MEDUSA通过在现有LLM的主干模型之上增加额外的解码头(MEDUSA heads),来并行预测未来的多个token。这些头以参数高效的方式进行微调,无需引入一个独立的草稿模型,从而简化了集成过程,尤其适用于分布式环境。
2. 树状注意力机制(Tree Attention):为了有效利用多头预测,MEDUSA生成多个候选续写序列,并通过一种树状注意力机制来并行验证这些候选项。这种机制通过调整注意力掩码实现,使得在单次解码步骤中可以同时处理多个不同长度的候选序列。
3. 两种微调策略:
* MEDUSA-1:在冻结的主干LLM之上直接微调MEDUSA头。这种方法计算资源需求低,可以实现无损的推理加速,且不影响原始模型的性能。
* MEDUSA-2:将MEDUSA头与主干LLM联合微调。这种方法能提升MEDUSA头的预测准确性,从而获得更高的加速比,但需要特殊的训练策略来保持模型原有的能力。
4. 典型接受方案(Typical Acceptance Scheme):作为传统拒绝采样方案的替代,本文提出了一种典型接受方案。该方案使用温度作为阈值来选择合理的候选序列,而不是严格匹配原始模型的输出分布,从而在保持生成质量的同时进一步提升加速率。
5. 自蒸馏(Self-distillation):为了解决训练数据不可用的情况(例如,模型经过RLHF或训练数据集未公开),本文提出了一种自蒸馏方法。该方法利用模型自身生成训练数据来微调MEDUSA头,有效解决了数据缺失问题。
实验证明,MEDUSA-1可以在不牺牲生成质量的前提下实现超过2.2倍的加速,而MEDUSA-2能将加速比进一步提升至2.3到2.8倍。
A2 方法细节
MEDUSA遵循与推测解码相同的框架,每个解码步骤主要包括三个子步骤:(1)生成候选;(2)处理候选;(3)接受候选。对于MEDUSA,步骤(1)通过MEDUSA头实现,步骤(2)通过树状注意力实现。由于MEDUSA头位于原始模型之上,在步骤(2)中计算的logits可用于下一个解码步骤的子步骤(1)。最后的步骤(3)可以通过拒绝采样【28,Leviathan et al., Fast inference from transformers via speculative decoding, 2022】、【5,Chen et al., Accelerating large language model decoding with speculative sampling, 2023】或典型接受(章节2.3.1)来实现。整体流程如图1所示。
本节首先介绍MEDUSA的关键组件,包括MEDUSA头和树状注意力。然后,我们提出两种级别的MEDUSA微调程序,以满足不同用例的需求。最后,我们提出MEDUSA的两个扩展,包括自蒸馏和典型接受,分别用于处理MEDUSA没有可用训练数据的情况和提高解码过程的效率。
图 1. MEDUSA在LLM的最后隐藏状态之上引入多个头,从而能够并行预测多个后续token(章节2.1.1)。在推理过程中,每个头为其指定位置生成多个顶部预测。这些预测被组装成候选,使用基于树的注意力机制并行处理(章节2.1.2)。最后一步是验证候选并接受一个续写。除了标准的拒绝采样方案,这里还可以使用典型接受方案(章节2.3.1)来选择合理的续写,最长的被接受的候选前缀将用于下一解码阶段。
2.1. 关键组件
2.1.1. MEDUSA头
用额外的解码头替代草稿模型。在推测解码中,后续的token由一个辅助的草稿模型预测。这个草稿模型必须足够小且有效,以生成能被原始模型接受的续写。满足这些要求是一项具有挑战性的任务,现有方法【43,Spector & Re, Accelerating llm inference with staged speculative decoding, 2023】、【33,Miao et al., Specinfer: Accelerating generative llm serving with speculative inference and token tree verification, 2023】通常需要单独预训练一个较小的模型。这个预训练过程需要大量的额外计算资源。例如,在【33,Miao et al., Specinfer: Accelerating generative llm serving with speculative inference and token tree verification, 2023】中,据报道使用了275个NVIDIA A100 GPU小时。此外,单独的预训练可能会在草稿模型和原始模型之间产生分布偏移,导致原始模型可能不倾向于接受其生成的续写。【5,Chen et al., Accelerating large language model decoding with speculative sampling, 2023】也指出了在分布式环境中服务多个模型的复杂性。
MEDUSA头的设计与实现。为了简化和普及LLM推理的加速,我们从【44,Stern et al., Blockwise parallel decoding for deep autoregressive models, 2018】中获得灵感,该工作利用并行解码来完成机器翻译和图像超分辨率等任务。MEDUSA头是附加到原始模型最后隐藏状态的额外解码头。具体来说,给定原始模型在位置t的最后隐藏状态$h_t$,我们在$h_t$上添加K个解码头。第k个头用于预测后续token中第(t + k + 1)个位置的token(原始语言模型头用于预测第(t + 1)个位置)。第k个头的预测表示为$p_t^{(k)}$,代表词汇表上的一个分布,而原始模型的预测表示为$p_t^{(0)}$。我们遵循【44,Stern et al., Blockwise parallel decoding for deep autoregressive models, 2018】的方法,为每个头使用一个带有残差连接的单层前馈网络。我们发现这种简单的设计足以获得令人满意的性能。第k个头的定义如下:
其中d是LLM最后一个隐藏层的输出维度,$V$是词汇表大小。我们将$W_2^{(k)}$初始化为与原始语言模型头相同,而$W_1^{(k)}$初始化为零。这使得MEDUSA头的初始预测与原始模型的预测保持一致。我们采用SiLU激活函数【14,Elfwing et al., Sigmoid-weighted linear units for neural network function approximation in reinforcement learning, 2017】,与Llama模型【45,Touvron et al., Llama 2: Open foundation and fine-tuned chat models, 2023】保持一致。
MEDUSA头的优势。与草稿模型不同,MEDUSA头是与原始主干模型联合训练的,主干模型可以在训练期间保持冻结(MEDUSA-1)或一同训练(MEDUSA-2)。这种方法即使在单个GPU上也能微调大型模型,利用了强大的基础模型学到的表示。此外,它确保了MEDUSA头的分布与原始模型的分布对齐,从而减轻了分布偏移问题。另外,由于新头只包含一个类似于原始语言模型头的单层网络,MEDUSA不会增加服务系统设计的复杂性,并且对分布式设置友好。我们将在2.2节讨论MEDUSA头的训练方案。
2.1.2. 树状注意力
通过树状注意力并行处理多个候选。通过MEDUSA头,我们获得了后续K+1个token的概率预测。这些预测使我们能够创建长度为K+1的续写作为候选。虽然推测解码研究【28,Leviathan et al., Fast inference from transformers via speculative decoding, 2022】、【5,Chen et al., Accelerating large language model decoding with speculative sampling, 2023】建议采样单个续写作为候选,但在解码过程中利用多个候选可以提高单个解码步骤内的预期接受长度。然而,更多的候选也会增加计算需求。为了取得平衡,我们采用树状结构的注意力机制来并行处理多个候选。这种注意力机制不同于传统的因果注意力范式。在这个框架内,只有来自同一次续写的token被视为历史数据。我们从图神经网络领域【54,Ying et al., Do transformers really perform badly for graph representation?, 2021】中将图结构嵌入注意力的概念中获得灵感,将树结构融入我们的注意力掩码中,如图2所示。值得注意的是,类似的想法也已在独立的工作中被探索,如【33,Miao et al., Specinfer: Accelerating generative llm serving with speculative inference and token tree verification, 2023】、【43,Spector & Re, Accelerating llm inference with staged speculative decoding, 2023】,它们采用自下而上的方法,通过合并由草稿模型生成的多个候选来构建树。在我们的方法中,我们则采用自上而下的方法来构建树,这得益于MEDUSA头生成的候选结构。
树的构建与注意力掩码。对于给定的第k个头,其top-$s_k$个预测作为候选形成的基础,其中$s_k$是一个指定的超参数。这些候选是通过取每个头的top-$s_k$个预测的笛卡尔积来建立的。例如,在图2中,$s_1=2$且$s_2=3$,每个第一头的预测都可以由第二头的任何一个预测接续。这导致了一个树结构,其中第k层存在$s_k$个分支(将虚拟根视为第0层,实际上,这个第0层用于原始模型的语言模型头的预测,可以独立采样)。在这棵树中,只有token的前驱被视为历史上下文,我们的注意力掩码确保注意力只作用于token的前驱上。通过使用这个掩码并为位置编码正确设置位置索引,我们可以同时处理大量候选,而无需扩大批处理大小。新增token的总数为 $\prod_{k=1}^K \sum_{i=1}^{k} s_i$。
图 2. 我们展示了如何使用树状注意力来并行处理多个候选。如图所示,第一个MEDUSA头的top-2预测和第二个的top-3预测共产生2×3=6个候选。每个候选对应于树结构中的一个不同分支。为确保每个token只能访问其前驱,我们设计了一个注意力掩码,该掩码只允许注意力从当前token流向其先行token。位置编码的位置索引也根据此结构进行调整。
树结构的灵活性。在本节中,我们展示了通过取笛卡尔积来构建树结构的最简单和最常规的方法。然而,可以用更复杂的方式构建树结构,并利用不同头的不同顶部预测的不平衡准确性。我们将在2.3.3节讨论这一点。
2.2. 训练策略
在最基本的层面上,我们可以通过冻结主干模型并微调MEDUSA头来训练它们。然而,将主干与MEDUSA头一起训练可以显著提高MEDUSA头的准确性。根据计算资源和用例的具体要求,我们提出了两种级别的MEDUSA头训练策略。
在本节中,我们假设有一个与目标模型输出分布一致的训练数据集。这可能是用于目标模型监督微调(SFT)的数据集。我们将在2.3.2节讨论如何使用自蒸馏方法来消除对这类数据集的需求。
2.2.1. MEDUSA-1: 冻结主干
训练方法与损失函数。要使用冻结的主干模型训练MEDUSA头,我们可以使用MEDUSA头的预测与真实值之间的交叉熵损失。具体来说,给定位置$t+k+1$的真实token $y_{t+k+1}$,第k个头的损失为$L_k = -\log p_t^{(k)}(y_{t+k+1})$,其中$p_t^{(k)}(y)$表示第k个头预测token $y$的概率。我们还观察到,当k较大时,$L_k$也较大,这是合理的,因为当k较大时,第k个头的预测更不确定。因此,我们可以为$L_k$添加一个权重$\lambda_k$来平衡不同头的损失。总的MEDUSA损失是:
实践中的参数设置与优化。在实践中,我们将$\lambda_k$设置为一个常数(如0.8)的k次幂。由于我们只使用主干模型提供隐藏状态,我们可以使用主干模型的量化版本来减少内存消耗。这引入了一种更普及的加速LLM推理的方法,因为通过量化,MEDUSA可以在单个消费级GPU上为大型模型进行训练,类似于QLoRA【11,Dettmers et al., Qlora: Efficient finetuning of quantized llms, 2023】。训练只需几个小时(例如,在单个NVIDIA A100 PCIE GPU上,使用60k ShareGPT样本训练Vicuna 7B模型的MEDUSA-1需要5小时)。
2.2.2. MEDUSA-2: 联合训练
联合训练的挑战与策略。为了进一步提高MEDUSA头的准确性,我们可以将MEDUSA头与主干模型一起训练。然而,这需要一个特殊的训练方案来保持主干模型的下一token预测能力和输出质量。为实现这一目标,我们提出了三种策略:
* 组合损失 (Combined loss):为了保持主干模型的下一token预测能力,我们需要将主干模型的交叉熵损失 $L_{LM} = -\log p_t^{(0)}(y_{t+1})$ 添加到MEDUSA损失中。我们还添加一个权重 $\lambda_0$ 来平衡主干模型和MEDUSA头的损失。因此,总损失为:
* 差分学习率 (Differential learning rates):由于主干模型已经训练得很好,而MEDUSA头需要更多的训练,我们可以为它们使用不同的学习率,以实现MEDUSA头的更快收敛,同时保持主干模型的能力。
* 头部预热 (Heads warmup):注意到在训练初期,MEDUSA头的损失很大,这会导致梯度很大,可能会扭曲主干模型的参数。借鉴【26,Kumar et al., Fine-tuning can distort pretrained features and underperform out-of-distribution, 2022】的思想,我们可以采用两阶段训练过程。在第一阶段,我们只像MEDUSA-1那样训练MEDUSA头。在第二阶段,我们使用预热策略将主干模型和MEDUSA头一起训练。具体来说,我们首先训练主干模型几个epoch,然后再将MEDUSA头与主干模型一起训练。除了这种简单的策略,我们还可以使用更复杂的预热策略,通过逐渐增加主干模型损失的权重$\lambda_0$来实现。我们发现在实践中这两种策略都效果很好。
联合训练的优势。综合运用这些策略,我们可以在不损害主干模型能力的情况下,将MEDUSA头与主干模型一起训练。此外,这个方案可以与监督微调(SFT)一起应用,使我们能够获得一个原生支持MEDUSA的模型。
2.2.3. 如何选择头的数量
头的数量选择建议。根据经验,我们发现五个头在大多数情况下是足够的。因此,我们建议使用五个头进行训练,并参考2.3.3节中描述的策略来确定树状注意力的最佳配置。通过优化的树状注意力,有时三个或四个头可能就足够用于推理了。在这种情况下,我们可以忽略多余的头,而不会产生额外开销。
2.3. 扩展
2.3.1. 典型接受
拒绝采样的局限性。在推测解码的论文中【28,Leviathan et al., Fast inference from transformers via speculative decoding, 2022】、【5,Chen et al., Accelerating large language model decoding with speculative sampling, 2023】,作者采用拒绝采样来产生与原始模型分布一致的多样化输出。然而,后续的实现【23,Joao Gante, Assisted generation: a new direction toward low-latency text generation, 2023】、【43,Spector & Re, Accelerating llm inference with staged speculative decoding, 2023】表明,随着采样温度的升高,这种采样策略的效率会降低。直观上,这可以从草稿模型与原始模型相同的极端情况来理解:使用贪心解码,草稿模型的所有输出都将被接受,从而最大化效率。相反,拒绝采样引入了额外的开销,因为草稿模型和原始模型是独立采样的。即使它们的分布完全一致,草稿模型的输出仍可能被拒绝。
典型接受方案的提出。然而,在现实场景中,从语言模型中采样通常是为了生成多样化的响应,而温度参数仅用于调节响应的“创造性”。因此,更高的温度应该为原始模型接受草稿模型的输出提供更多机会。我们确定,通常没有必要严格匹配原始模型的分布。因此,我们提出采用一种典型接受方案来选择合理的候选,而不是使用拒绝采样。这种方法受到截断采样研究【18,Hewitt et al., Truncation sampling as language model desmoothing, 2022】的启发(详见附录A)。我们的目标是选择那些典型的候选,即它们不会被原始模型认为极不可能产生。我们使用原始模型的预测概率作为这一点的自然衡量标准,并根据预测分布建立一个阈值来决定是否接受。
典型接受的判断标准。具体来说,给定上下文$x_1, x_2, \dots, x_n$,在评估候选序列$(x_{n+1}, x_{n+2}, \dots, x_{n+K+1})$(由原始语言模型头和MEDUSA头的顶部预测组成)时,我们考虑以下条件:
其中$H(\cdot)$表示熵函数,$\epsilon$和$\delta$分别是硬阈值和依赖于熵的阈值。这个标准改编自【18,Hewitt et al., Truncation sampling as language model desmoothing, 2022】,基于两个观察:(1)具有相对高概率的token是有意义的;(2)当分布的熵很高时,各种续写都可能被认为是合理的。在解码期间,每个候选都使用此标准进行评估,如果候选的前缀满足条件,则接受该前缀。为了保证每一步至少生成一个token,我们对第一个token应用贪心解码并无条件接受它,而对后续token采用典型接受。当前步骤的最终预测由所有候选中最长的被接受前缀确定。
典型接受方案的特性。审视这个方案可以得出几个见解。首先,当温度设置为0时,它会退化为贪心解码,因为只有最可能的token才具有非零概率。当温度超过0时,只要$\epsilon, \delta$设置得当,贪心解码的结果将始终被接受,因为这些token具有最大概率,从而产生最大加速。同样,在一般情况下,温度升高将相应导致更长的被接受序列,我们的实验结果也证实了这一点。根据经验,我们验证了典型接受可以在保持与图5所示相似的生成质量的同时,实现更好的加速比。
2.3.2. 自蒸馏
解决训练数据缺失问题。在2.2节中,我们假设存在一个与目标模型输出分布匹配的训练数据集。然而,情况并非总是如此。例如,模型所有者可能只发布模型而不提供训练数据,或者模型可能经过了人类反馈强化学习(RLHF)过程,这使得模型的输出分布与训练数据集不同。为了解决这个问题,我们提出了一个自动化的自蒸馏流程,利用模型自身为MEDUSA头生成训练数据集,该数据集与模型的输出分布相匹配。
数据集生成过程。数据集的生成过程非常直接。我们首先从一个与目标模型领域相似的公共种子数据集中获取数据;例如,对于聊天模型,使用ShareGPT【41,ShareGPT, 2023】数据集。然后,我们简单地从数据集中提取提示,并要求模型对这些提示进行回复。为了获得多轮对话样本,我们可以顺序地将种子数据集中的提示喂给模型。或者,对于像Zephyr 7B【46,Tunstall et al., Zephyr: Direct distillation of lm alignment, 2023】这样的模型,它们在对话的两个角色上都进行了训练,具有自对话的能力,我们可以简单地输入第一个提示,让模型生成多轮对话。
MEDUSA-1与MEDUSA-2的自蒸馏策略。对于MEDUSA-1,这个数据集足以训练MEDUSA头。然而,对于MEDUSA-2,我们观察到仅使用这个数据集来训练主干和MEDUSA头通常会导致生成质量下降。事实上,即使不训练MEDUSA头,仅用这个数据集训练主干模型也会导致性能下降。这表明,我们还需要使用原始模型的概率预测,而不是使用真实token作为主干模型的标签,这与经典的知识蒸馏工作【25,Kim & Rush, Sequence-level knowledge distillation, 2016】类似。具体来说,主干模型的损失是:
其中$p_{\text{original},t}^{(0)}$表示原始模型在位置t的预测概率分布。
内存优化。然而,要获得原始模型的概率预测,通常需要在训练期间维护两个模型,这增加了内存需求。为了进一步缓解这个问题,我们提出了一种简单而有效的方法来利用自蒸馏设置。我们可以使用像LoRA【22,Hu et al., Lora: Low-rank adaptation of large language models, 2021】这样的参数高效适配器来微调主干模型。这样,原始模型就是关闭适配器的模型。因此,蒸馏不需要额外的内存消耗。总的来说,这个自蒸馏流程可以在不损害主干模型能力的情况下训练MEDUSA-2,并且几乎不增加额外的内存消耗。最后,关于使用自蒸馏的一个技巧是,在这种情况下最好使用不带量化的LoRA,否则教师模型将是量化模型,这可能导致生成质量下降。
2.3.3. 搜索优化的树构建
超越笛卡尔积的树构建方法。在2.1.2节中,我们介绍了通过取笛卡尔积来构建树结构的最简单方法。然而,在树中总节点数预算固定的情况下,规则的树结构可能不是最佳选择。直观地,由不同头的顶部预测组成的那些候选序列可能具有不同的准确性。因此,我们可以利用对准确性的估计来构建树结构。
基于准确性贪心构建树。具体来说,我们可以使用一个校准数据集,并计算不同头的顶部预测的准确性。让$a_k(i)$表示第k个头的第i个顶部预测的准确性。假设准确性是独立的,我们可以将由不同头的顶部$[i_1, i_2, \dots, i_k]$预测组成的候选序列的准确性估计为$\prod_{j=1}^k a(i_j)_j$。让$I$表示所有可能的$[i_1, i_2, \dots, i_k]$组合的集合,并且$I$中的每个元素都可以映射到树的一个节点(不仅包括叶节点,还包括所有节点)。那么,候选序列的期望接受长度是:
贪心算法。考虑通过逐个添加节点来构建树,一个新节点对期望值的贡献恰好是与该节点相关的准确性。因此,我们可以通过选择连接到当前树且具有最高准确性的节点来贪心地向树中添加节点。这个过程可以重复进行,直到总节点数达到期望的数量。通过这种方式,我们可以构建一个最大化期望接受长度的树。更多细节可以在附录C中找到。
A3 实验环境
- 数据集:
- MT-Bench【57,Zheng et al., Judging llm-as-a-judge with mt-bench and chatbot arena, 2023】:一个多轮对话格式的基准测试,用于评估模型的整体性能和生成质量。
- ShareGPT【41,ShareGPT, 2023】:一个公开的对话数据集,用于训练Vicuna-7B/13B模型的MEDUSA头,并作为Vicuna-33B和Zephyr-7B自蒸馏的种子数据集。
- UltraChat【12,Ding et al., Enhancing chat language models by scaling high-quality instructional conversations, 2023】:作为自蒸馏的另一个种子数据集。
- Alpaca-eval【13,Dubois et al., Alpacafarm: A simulation framework for methods that learn from human feedback, 2023】:用于校准和优化树状注意力结构。
- 模型架构:
- Vicuna【7,Chiang et al., Vicuna: An open-source chatbot impressing gpt-4 with 90%* chatgpt quality, 2023】:使用了7B、13B和33B版本,这些模型是基于Llama/Llama-2【45,Touvron et al., Llama 2: Open foundation and fine-tuned chat models, 2023】微调的聊天模型。
- Zephyr-7B【46,Tunstall et al., Zephyr: Direct distillation of lm alignment, 2023】:一个经过监督微调和对齐(如RLHF)训练的模型。
- MEDUSA头:所有实验中训练了5个MEDUSA头,每个头包含1个单层前馈网络。
- 硬件配置:
- GPU:实验在NVIDIA A100-80GB-PCIe, A40, 和 A6000 GPU上进行。
- 软件配置:
- 训练框架:使用Axolotl【2,Axolotl, 2023】框架进行所有模型的训练。
- 优化器:使用8-bit AdamW【9,Dettmers et al., 8-bit optimizers via block-wise quantization, 2021】优化器。
- 学习率调度器:余弦学习率调度器(cosine learning rate scheduler)与预热(warmup)。
- 参数高效微调:对于MEDUSA-2,使用LoRA【22,Hu et al., Lora: Low-rank adaptation of large language models, 2021】或QLoRA【11,Dettmers et al., Qlora: Efficient finetuning of quantized llms, 2023】进行微调。LoRA应用于所有线性层,rank=32, alpha=16, dropout=0.05。
- 量化:在MEDUSA-1训练和MEDUSA-2的QLoRA配置中使用了4-bit量化的主干模型。
A4 实验结果
3.1. 案例研究:Vicuna 7B和13B上MEDUSA-1与MEDUSA-2的对比
- 实验内容:在Vicuna-7B和13B模型上,使用公开的ShareGPT数据集训练MEDUSA头,对比MEDUSA-1(冻结主干)和MEDUSA-2(联合微调)相对于基线(默认Huggingface实现)的加速效果。
- 实验结果:
- Vicuna-7B:MEDUSA-1实现了2.18倍的加速,MEDUSA-2进一步提升至2.83倍。
- Vicuna-13B:MEDUSA-1实现了2.33倍的加速,MEDUSA-2同样达到2.83倍。
- 任务类型分析:在MT-Bench的8个类别中,编码(coding)任务获得了3.29倍的加速,而提取(Extraction)任务的加速比最高,达到3.62倍(如图3右图所示)。
- 分析结论:MEDUSA-1和MEDUSA-2都能显著提升推理速度,且MEDUSA-2的加速效果更佳。MEDUSA在特定任务(如编码和信息提取)上表现尤为出色,显示出其在优化特定领域LLM方面的巨大潜力。
图 3. 左图:基线、MEDUSA-1和MEDUSA-2在Vicuna-7B/13B上的速度比较。与基线实现相比,MEDUSA-1实现了超过2倍的壁钟时间加速,而MEDUSA-2进一步显著提升了加速比。右图:Vicuna-7B与MEDUSA-2在MT-Bench的8个类别上的详细加速性能。
3.2. 案例研究:在Vicuna-33B和Zephyr-7B上使用自蒸馏进行训练
- 实验内容:针对训练数据不可用的场景(Vicuna-33B使用私有数据集,Zephyr-7B经过RLHF),采用自蒸馏方法生成约10万样本的数据集来训练MEDUSA-2模型。
- 实验结果:
- 如表1所示,所有MEDUSA-2模型都实现了显著的加速,同时生成质量与原始模型相当(由GPT-4评估的MT-Bench分数变化很小)。
- Vicuna-33B模型的加速率(2.32倍)相对较低,但质量保持得很好。推测原因可能是自蒸馏所用数据集与模型原始训练数据集之间存在不匹配。
- 与开源草稿模型实现的推测解码(SpecDecoding)相比,MEDUSA在所有模型尺寸上都取得了更高的加速比。
- 分析结论:自蒸馏是一种可行的训练策略,即使在没有原始训练数据的情况下,也能有效训练MEDUSA头,实现加速同时保持模型质量。这表明MEDUSA-2的配置在提升处理效率的同时,能谨慎地保持模型输出的质量。
表格 1. 不同MEDUSA-2模型的比较。第一部分报告了MEDUSA-2的细节,包括加速率、开销和质量,质量表示在MT-Bench上与原始模型相比的平均得分。第二部分分别列出了SpecDecoding和MEDUSA的加速比(S)。
3.3. 消融研究
3.3.1. 树状注意力的配置
- 实验内容:在MEDUSA-2 Vicuna-7B上,对比随机采样的密集树配置与优化的稀疏树配置对加速率和实际速度(tokens/s)的影响。
- 实验结果:
- 加速率:如图4a所示,一个包含64个节点的优化稀疏树的加速率优于一个包含256个节点的密集树。这表明优化的树结构效率更高。
- 实际速度:如图4b所示,随着树中候选token数量的增加,虽然加速率可能提高,但由于计算开销(矩阵乘法)的增加,实际速度(tokens/s)会下降。
- 分析结论:树的结构对性能至关重要。虽然更复杂的树可以提高加速率,但必须权衡其带来的计算开销。初始阶段增加树的大小收益显著,但当加速率的增长低于开销的增长时,整体性能会下降。因此,选择一个优化的、大小适中的稀疏树是关键。
图 4. 树引入的用于解码的候选token数量的有效性(使用KV缓存时,默认解码候选token数为1)。左图:随机采样的密集树设置(蓝点)和优化的稀疏树设置(红星)的加速率。右图:两种设置的速度(tokens/s)。趋势线表明,虽然稀疏树的加速率保持相对稳定,但随着候选token的增加,速度有明显下降。
3.3.2. 典型接受的阈值
- 实验内容:在MEDUSA-2 Vicuna 7B上,固定采样温度为0.7,研究典型接受方案中不同阈值$\epsilon$(从0.01到0.25)对加速率和生成质量(MT-Bench得分)的影响。
- 实验结果:如图5所示,存在一个明显的权衡:随着阈值$\epsilon$的增加,生成质量得到提升,但加速率会下降。对于需要创造性的任务,默认的随机采样(RS)性能优于贪心采样,而当$\epsilon$增加时,典型采样的性能与随机采样相当。
- 分析结论:典型接受方案的阈值$\epsilon$提供了一个可调的旋钮,用户可以根据对质量和速度的不同需求进行权衡。
图 5. 使用提出的典型采样时MEDUSA的性能比较。该模型是从Vicuna-7B完全微调而来。图中展示了在固定温度0.7下,3种不同设置的加速率和在写作与角色扮演(MT-Bench)上的平均得分:贪心采样和随机采样(RS)分别以星号和点表示,以及在不同阈值下的典型采样曲线。
3.3.3. 两阶段微调的有效性
- 实验内容:对比Vicuna-7B上三种不同的微调策略:MEDUSA-1(仅微调头)、MEDUSA-2(两阶段微调)和直接联合微调。
- 实验结果:如表2所示,
- MEDUSA-1实现了2.18倍的加速,且不损害生成质量。
- MEDUSA-2(两阶段微调)在保持生成质量的同时,提供了更高的加速比(2.83倍)。
- 直接联合微调(Direct fine-tuning)导致了生成质量的下降。
- 分析结论:两阶段微调策略(MEDUSA-2)是有效的,它能够在不牺牲模型原始能力的前提下,最大化MEDUSA带来的加速效果。
表格 2. Vicuna-7B不同设置的比较。质量是通过在MT-Bench上使用GPT-4作为评判者评估模型获得的(越高越好)。
A5 结论
本文提出的MEDUSA框架通过为LLM配备额外的预测性解码头,实现了并行生成多个token,从而绕过了顺序解码的限制,将LLM的推理速度提升了2.3至2.8倍。MEDUSA的主要优势包括其简单性、参数效率高以及易于集成到现有系统中。它避免了对专门的草稿模型的需求。提出的典型接受方案在提供合理输出的同时,消除了拒绝采样带来的复杂性。本文还提供了两种高效的训练程序,确保了在各种模型和提示类型下都能获得高质量的输出。如表3所示,各项技术的引入逐步提升了最终的加速效果。
尽管本文主要关注批处理大小为1的场景,但我们强调,论文中提出的思想可以推广到更大的批处理大小设置,目前已有TensorRT和Huggingface TGI等库根据本文的研究成果支持了这些设置。
表格 3.各项技术对加速比的影响
A6 附录
A. 相关工作
A.1. LLM推理加速
LLM推理效率低下的根源。大型语言模型(LLM)推理效率低下的主要原因是自回归解码过程的内存带宽受限特性。为了缓解这个问题,已经提出了几种方法来改善推理延迟和吞吐量。传统上,批处理推理被用作一种直接的方法来提高算术强度并摆脱内存带宽的限制。然而,对于LLM,模型参数和键值(KV)缓存都消耗了大量的加速器内存,这阻碍了使用大批量大小。解决这个问题现有方法可以概念上分为两大类:(1)减少内存消耗,从而最小化内存传输开销并启用更大的批处理大小;(2)最小化解码步骤的数量以直接减少延迟。
减少KV缓存。诸如多查询注意力(Multi-query attention)【42,Shazeer, Fast transformer decoding: One write-head is all you need, 2019】和分组查询注意力(Grouped-query attention)【1,Ainslie et al., Gqa: Training generalized multi-query transformer models from multi-head checkpoints, 2023】等方法采用直接的方式来减少KV缓存。通过在注意力模块中使用比查询头更少的键和值头,这些策略显著减少了KV的内存消耗,从而促进了更大的批处理大小和增强的加速器利用率【40,Pope et al., Efficiently scaling transformer inference, 2022】。此外,【56,Zhang et al., H2o: Heavy-hitter oracle for efficient generative inference of large language models, 2023】提出选择性地保留最关键的KV token,进一步减少KV缓存。从系统角度看,【27,Kwon et al., Efficient memory management for large language model serving with pagedattention, 2023】引入了一种分页内存管理方案,以减少KV缓存的碎片化。
量化。量化技术被广泛用于缩减LLM的内存消耗。【52,Xiao et al., Smoothquant: Accurate and efficient post-training quantization for large language models, 2023a】在激活和参数之间应用重新缩放来消除异常值并简化量化过程。【10,Dettmers et al., Llm. int8 (): 8-bit matrix multiplication for transformers at scale, 2022】将矩阵乘法分解为主要的8位和少数的16位操作。【17,Frantar et al., Gptq: Accurate post-training quantization for generative pretrained transformers, 2022】迭代地将权重列舍入为3/4位,而【30,Lin et al., Awq: Activation-aware weight quantization for llm compression and acceleration, 2023】提出了一种激活感知的量化方案,以保护显著权重并将LLM压缩到3/4位。【24,Kim et al., Squeezellm: Dense-and-sparse quantization, 2023】引入了一种稀疏加低精度的模式来处理一小部分重要权重等技术。
推测解码(Speculative Decoding)。作为与上述方法正交的一种方法,推测解码【28,Leviathan et al., Fast inference from transformers via speculative decoding, 2022】、【5,Chen et al., Accelerating large language model decoding with speculative sampling, 2023】旨在并行执行多个解码步骤,从而减少所需的总步骤数。这种并行化是通过使用一个较小的草稿模型来推测几个后续词,然后LLM集体评估并接受这些词是否合适来实现的。虽然这与非自回归生成文献【53,Xiao et al., A survey on non-autoregressive generation for neural machine translation and beyond, 2023b】相呼应,但该方法专门为LLM量身定制,以解决前述的效率低下问题。与先前的工作不同,我们建议利用原始模型进行预测,而不是引入额外的草稿模型。这种方法更直接,并且可以无缝集成到现有系统中,而没有管理两个模型的复杂性。独立地,【33,Miao et al., Specinfer: Accelerating generative llm serving with speculative inference and token tree verification, 2023】、【43,Spector & Re, Accelerating llm inference with staged speculative decoding, 2023】提出了使用树状结构注意力来并行生成多个候选,其中【33,Miao et al., Specinfer: Accelerating generative llm serving with speculative inference and token tree verification, 2023】建议使用模型集成来提出候选,而【43,Spector & Re, Accelerating llm inference with staged speculative decoding, 2023】主张为草稿模型添加另一个层次结构。然而,草稿模型需要专门的预训练并与目标模型对齐。虽然使用多个草稿模型可能很麻烦并涉及管理并行性的复杂性,但我们的方法仅依赖于解码头,提供了一种更简单的替代方案。【33,Miao et al., Specinfer: Accelerating generative llm serving with speculative inference and token tree verification, 2023】使用多个草稿模型生成token,并使用树状注意力将它们合并,而【43,Spector & Re, Accelerating llm inference with staged speculative decoding, 2023】利用一个小的草稿模型分批处理树的每一层。相比之下,我们的方法直接使用每个MEDUSA头的顶部预测token来创建一个静态的稀疏树,无需自回归或调整树结构。这种方法简化了过程并提高了效率。此外,我们通过详细的消融研究展示了树的节点如何影响解码速度。
A.2. 采样方案
采样方案对生成质量的影响。从大型语言模型(LLM)中采样文本的方式可以显著影响生成输出的质量。最近的研究表明,直接从语言模型采样可能导致不连贯或无意义的结果【39,Pillutla et al., MAUVE: Measuring the gap between neural text and human text using divergence frontiers, 2021】、【21,Holtzman et al., The curious case of neural text degeneration, 2020】。为了应对这一挑战,引入了截断采样方案【15,Fan et al., Hierarchical neural story generation, 2018】、【3,Basu et al., MIROSTAT: A neural text decoding algorithm that directly controls perplexity, 2021】、【31,Meister et al., On the probability-quality paradox in language generation, 2022】、【18,Hewitt et al., Truncation sampling as language model desmoothing, 2022】、【32,Meister et al., Locally typical sampling, 2023】。这些方法旨在通过在每个解码步骤中对特定允许集上的截断分布进行采样来产生高质量和多样化的样本。
不同截断采样策略。不同的策略以各种方式定义这个允许集。例如,top-k采样【15,Fan et al., Hierarchical neural story generation, 2018】保留k个最可能的词,而top-p采样【21,Holtzman et al., The curious case of neural text degeneration, 2020】包含占p%概率的最小词集。另一种方法,称为典型解码(typical decoding)【32,Meister et al., Locally typical sampling, 2023】,利用预测分布的熵来建立包含的阈值。【18,Hewitt et al., Truncation sampling as language model desmoothing, 2022】提供了一个统一的框架来全面理解截断采样技术。
典型接受方案的启发。受这些方法的启发,我们的典型接受方案与定义一个允许集以从采样过程中排除不可能的候选的概念相一致。然而,我们的不同之处在于我们不坚持输出与语言模型分布之间的精确对应。这种偏差使我们能够促进更多样化但高质量的输出,在不损害生成文本完整性的情况下实现更高的效率。
B. 实验设置
B.1. 常用术语
我们澄清三个常用术语:a) 加速率(Acceleration rate):指每个解码步骤解码的平均token数。在标准的自回归模型中,此速率为1.0。b) 开销(Overhead):用于描述与经典解码相比,每个解码步骤的开销,通过将MEDUSA模型的平均每步延迟除以原始模型的平均每步延迟计算得出。c) 加速比(Speedup):指壁钟时间的加速率。根据这些定义,我们有关系式:加速比 = 加速率 / 开销
。
B.2. 共享设置
对于所有实验,我们使用Axolotl【2,Axolotl, 2023】框架进行训练。我们使用带预热的余弦学习率调度器,并使用8位AdamW【9,Dettmers et al., 8-bit optimizers via block-wise quantization, 2021】优化器。我们训练5个MEDUSA头,每个头有1层,并将公式(1)中的$\lambda_k$设置为$0.8^k$。对于MEDUSA-2,我们使用LoRA【22,Hu et al., Lora: Low-rank adaptation of large language models, 2021】或QLoRA【11,Dettmers et al., Qlora: Efficient finetuning of quantized llms, 2023】进行微调,并将MEDUSA头的学习率设置为主干模型的4倍。LoRA应用于主干模型的所有线性层,包括语言模型头。LoRA适配器的秩设置为32,$\alpha$设置为16。在LoRA适配器上添加了0.05的dropout。
B.3. MEDUSA-1 vs. MEDUSA-2 on Vicuna 7B and 13B
我们使用64的全局批处理大小,主干模型的峰值学习率为5e-4,MEDUSA头的峰值学习率为2e-3,预热40步。我们对两个模型都使用4位量化的主干模型。我们首先用MEDUSA-1训练模型,并使用这些训练好的模型作为初始化来训练MEDUSA-2。我们对MEDUSA-2使用QLoRA,并将公式(2)中的$\lambda_0$设置为0.2。
B.4. Training with Self-Distillation on Vicuna-33B and Zephyr-7B
我们对这两个模型都使用MEDUSA-2,而不是使用两阶段训练过程。我们使用正弦调度来处理$\theta_0$,使其在训练结束时逐渐增加到峰值。我们发现这种方法同样有效。我们将主干LoRA适配器的峰值学习率设置为1e-4,预热步数设置为20,因为自蒸馏损失相对较小。我们将公式(2)中的$\lambda_0$设置为0.01。
C. 优化树状注意力的可视化
图6展示了为MEDUSA-2 Vicuna-7B模型构建的稀疏树的结构。该树结构延伸四层深,表明计算中涉及四个MEDUSA头。该树最初通过笛卡尔积方法形成,随后根据在Alpaca-eval数据集【13,Dubois et al., Alpacafarm: A simulation framework for methods that learn from human feedback, 2023】上测量的每个MEDUSA头的top-k预测的统计期望进行剪枝。树向左倾斜的形态直观地表示了算法对每个头上具有更高概率的节点的偏好。
图 6. MEDUSA-2 Vicuna-7B的稀疏树设置可视化。该树有64个节点,代表候选token,深度为4,表示有4个MEDUSA头参与计算。每个节点表示来自MEDUSA头top-k预测的一个token,边表示它们之间的连接。红线突出显示了正确预测未来token的路径。
D. 推测解码的结果
在本研究中,推测解码被应用于不同大小的Vicuna模型【7,Chiang et al., Vicuna: An open-source chatbot impressing gpt-4 with 90%* chatgpt quality, 2023】,具体为7B、13B和33B。初步框架利用了开源模型,如Llama-68M和160M【33,Miao et al., Specinfer: Accelerating generative llm serving with speculative inference and token tree verification, 2023】,以及Tiny-Llama【55,Zhang et al., Tinyllama: An open-source small language model, 2024】和Tiny-Vicuna【38,Pan, Tiny vicuna 1b, 2023】,后者是从Tiny-Llama使用Vicuna风格的指令调优策略微调而来。由于推测解码方法【5,Chen et al., Accelerating large language model decoding with speculative sampling, 2023】、【28,Leviathan et al., Fast inference from transformers via speculative decoding, 2022】的专有性,我们部署了开源替代方案进行评估。此外,我们利用torch.compile()来加速草稿模型的推理速度。
我们的结果如图7所示,揭示了草稿模型的最佳设置随Vicuna模型大小而变化。具体来说,Llama-68M在草稿token数$\gamma=4$的设置下,对Vicuna-7B表现最佳,而相同的草稿模型在$\gamma=3$时对Vicuna-13B最有效。对于更大的Vicuna-33B,Tiny-Vicuna(Vicuna-1B)在$\gamma=3$时提供了最大的加速。这些结果表明,草稿模型的选择和设置应根据LLM的大小量身定制,这为该领域提供了进一步探索的方向。
图 7. 在MT-Bench上使用推测解码的各种模型的推理速度。基线模型速度以灰色虚线表示以供比较。γ表示草稿token的数量。
E. 所有模型的附加结果
我们在图8中展示了各种模型的加速比。
图 8. 各种模型使用MEDUSA-2的加速比。MEDUSA-2在所有模型上都显示出显著的速度提升,而使用自蒸馏训练的模型(Zephyr-7B, Vicuna-13/33B)由于在保持质量和提升速度之间的权衡,加速比较弱。
F. 在AlpacaEval数据集上的附加结果
我们在AlpacaEval【29,Li et al., Alpacaeval: An automatic evaluator of instruction-following models, 2023】数据集上进行了进一步的实验。MEDUSA-2取得了与MT-Bench上结果相似的一致加速比。
表格 4. 在AlpacaEval【29,Li et al., Alpacaeval: An automatic evaluator of instruction-following models, 2023】数据集上的加速比结果。
G. 硬件约束与MEDUSA的探索与建模
我们通过引入一个简化的Llama系列模型,探讨了硬件约束,特别是内存带宽限制,及其对MEDUSA式并行解码的影响。首先,我们确定了涉及矩阵乘法的算子,如线性层和注意力矩阵乘法,是开销的主要来源。我们在多种GPU上(包括A100-80GB-PCIe、A40和A6000)分析了FLOP/s与操作强度(Operational Intensity,即FLOP/s与带宽(字节/秒)的比率)的性能。接下来,我们研究了在使用MEDUSA时不同算子的FLOP/s与操作强度的变化。最后,我们应用一个简单的分析模型来计算加速率,并将其与硬件基准测试相结合。这为不同模型大小、序列长度和批处理大小下的影响提供了见解。
G.1. 算子的Roofline模型
分析设置。我们对大型语言模型(LLM)中各种算子的roofline模型进行了分析,特别关注Llama-7B、Llama-13B和Llama-33B【45,Touvron et al., Llama 2: Open foundation and fine-tuned chat models, 2023】。这些模型在不同的GPU上进行了基准测试,包括A100-80GB-PCIe、A40和A6000。我们研究了三类矩阵乘法算子,因为它们是这些模型中主要的计算开销来源。我们的研究遵循了报告【6,Chen, Dissecting batching effects in gpt inference, 2023】,该报告调查了批处理大小的效果,但我们的研究更侧重于解码和并行解码。
复杂度分析。表5详细说明了预填充(prefill)、解码(decoding)和MEDUSA解码阶段中每个算子的计算和空间复杂度。这些算子包括用于查询、键、值矩阵的线性层($XW_Q, XW_K, XW_V$),注意力矩阵乘法($QK^T, PV$),以及上/门/下线性层($XW_u, XW_g, XW_d$)。其中b代表批大小,s代表序列长度,h代表隐藏维度,i代表中间维度,n代表注意力头的数量,d代表头维度,q代表MEDUSA的候选长度。
表格 5. 不同阶段主要算子的计算和空间复杂度。该表基于报告【6,Chen, Dissecting batching effects in gpt inference, 2023】中的表2。
基准测试结果。图9-17展示了在不同模型(7/13/33B)和各种设置下三类算子的基准测试结果。为了评估每个算子的性能和吞吐量,我们选择了从1到64(2的幂次方)的批处理大小和从128到8192(2的幂次方)的序列长度的组合(每个算子49个设置)。从所有图中,我们观察到每个算子在预填充和解码阶段的数据点在所有GPU和各种模型大小上都聚集在非常相似的位置。
预填充与解码阶段的性能特征。在预填充阶段,增加批处理大小会改变注意力矩阵乘法的FLOP/s(见‘qk/pv init’),但不会影响操作强度(见图9中的垂直虚线箭头)。相比之下,增加序列长度会影响预填充阶段的FLOP/s和操作强度(见图9中的对角虚线箭头)。在解码阶段,注意力矩阵乘法明显受到内存带宽的限制。尽管随着批处理大小和序列长度的变化,FLOP/s有所增加,但操作强度几乎保持不变(见‘qk/pv ar’)。这表明自注意力机制中的资源利用不佳。
线性层的性能特征。预填充阶段的线性层大多是计算受限的(见‘qkv mlp init’和‘up/gate/down init’)。在解码阶段,线性层的数据点形成一条斜率与GPU内存带宽相同的线(见‘qkv mlp ar’和‘up/gate/down ar’)。这表明解码阶段的线性层也受到内存带宽的限制。增加批处理大小可以通过更好的并行性在内存带宽约束下提高实现的FLOP/s和操作强度。请注意,线性层只处理新的token,与序列长度无关(见表5中的‘解码’部分)。
图 9. 该图显示了在A100-80GB-PCIe上Llama-7B算子的所有基准测试数据点的FLOP/s与操作强度之间的关系。虚线分别代表HBM带宽限制(1,935GB/s)和峰值性能限制(312 TFLOP/s)(NVIDIA)。'qkv mlp'代表将隐藏特征投影到查询/键/值特征的线性层。'up/gate/down'代表注意力块之后的线性层。'qk/pv'代表注意力矩阵乘法的两个步骤。'ar'代表解码(自回归),'init'代表预填充阶段。
图 10. A100-80GB-PCIe上的Llama-13B算子。
Roofline模型 (Llama 33B, A100 80GB PCIe)
图 11. A100-80GB-PCIe上的Llama-33B算子。
Roofline模型 (Llama 7B, A40)
图 12. A40上的Llama-7B算子。
Roofline模型 (Llama 13B, A40)
图 13. A40上的Llama-13B算子。
Roofline模型 (Llama 33B, A40)
图 14. A40上的Llama-33B算子。
Roofline模型 (Llama 7B, A6000)
图 15. A6000上的Llama-7B算子。
Roofline模型 (Llama 13B, A6000)
图 16. A6000上的Llama-13B算子。
图 17. A6000上的Llama-33B算子。
G.2. MEDUSA中FLOP/s与操作强度的变化
MEDUSA对操作强度和FLOP/s的影响。我们研究了MEDUSA如何改变操作强度并提升FLOP/s。我们选择在A100-80GB-PCIe上运行Llama 33B作为实验设置。
注意力矩阵乘法。首先,我们检查注意力矩阵乘法。图18和表6展示了在批处理大小固定为16时MEDUSA的效果。我们观察到,随着添加更多候选token,FLOP/s和操作强度均增加(原始解码结果以灰点绘制)。这表明MEDUSA可以利用额外的候选token来提高计算吞吐量。与常规解码相比,在批处理大小为16、序列长度为1024且有64个候选token的设置下,MEDUSA实现了44倍的FLOP/s和41倍的操作强度。图19和表7展示了在序列长度固定为1024时MEDUSA解码的效果。在这种情况下,增加批处理大小并不能提高操作强度。
线性层。接下来,我们检查线性层,重点关注上/门/下线性层。结果如图20和表8所示。由于解码阶段的线性层只处理未来的token,而过去的token被缓存,因此它们与序列长度无关。我们改变批处理大小来观察效果。随着MEDUSA增加候选token数量以及批处理大小的增加,我们观察到性能从内存带宽受限区域转向计算受限区域。这种转变展示了MEDUSA如何将线性层的性能特征从受内存带宽限制转变为受计算能力限制。
Llama 33B, A100 80GB PCIe
图 18. 批处理大小为16时注意力矩阵乘法的FLOP/s与操作强度。
图 19. 序列长度为1024时注意力矩阵乘法的FLOP/s与操作强度。
Llama 33B, A100 80GB PCIe
图 20. 线性层的FLOP/s与操作强度。
表格 6. 在A100 80GB PCIe上,Llama 33B批处理大小为16时注意力矩阵乘法的TFLOP/s和操作强度。
表格 7. 在A100 80GB PCIe上,Llama 33B序列长度为1024时注意力矩阵乘法的TFLOP/s和操作强度。
表格 8. 在A100 80GB PCIe上,Llama 33B线性层(上/门/下)的TFLOP/s和操作强度。
G.3. 预测MEDUSA性能
建模与仿真。我们进一步采用一个简单的分析模型来预测加速率。3.3.1节中的消融研究结果表明,加速率可以用一个简单的对数函数来近似。利用图4a的结果,我们将曲线建模为 加速率 = 0.477 * log(候选数量)
。我们通过首先将批处理大小固定为1,序列长度固定为1024,来模拟Llama-7B模型一个简化块(顺序处理$XW_Q, XW_K, XW_V, QK^T, PV, XW_u, XW_g, XW_d$)的延迟。候选token通过构建2.1.2节中描述的树状注意力进行并行处理。我们忽略了MEDUSA的后处理步骤(包括验证和接受)的延迟,因为它们引入的开销很小。
不同候选数量的影响。图21展示了在这些设置下,不同候选token数量的模拟加速率和加速比。随着候选token数量的增加,加速率和加速比最初都显示出改善。然而,超过64个后,加速比开始下降,表明进一步增加候选长度的回报递减。这与图4b中的实验结果一致,并表明存在一个最佳的候选token数量范围,在此范围内MEDUSA提供最显著的性能增益。
不同批大小的影响。我们在图22中绘制了在固定序列长度1024下不同批处理大小设置的模拟加速比。结果表明,当批处理大小超过32时,加速比下降,甚至可能产生负面影响。这是因为线性层从内存带宽受限转向计算受限。
不同序列长度的影响。我们进行了另一项实验,使用批处理大小为4和不同的序列长度。如图23所示,最佳候选token数量在不同序列长度上保持相对一致。然而,随着序列长度的增加,整体性能下降。这种性能下降主要是由于注意力矩阵乘法的开销,而线性层计算保持不变,因为线性层的计算与序列长度无关。
仿真结论。我们的仿真表明,最佳候选token数量是使用MEDUSA进行模型扩展的关键,因为超过一定范围后收益会减少。最初,增加批处理大小通过并行性提高了性能,但过大的批处理大小会使线性层从内存带宽受限转为计算受限,从而降低加速比。更长的序列会增加注意力矩阵乘法的开销,降低性能,并强调了优化注意力机制的需求。有效的模型扩展需要平衡候选token的数量,调整批处理大小以避免计算受限的转变,并为更长的序列增强注意力机制。这些策略确保了更好的资源利用和更高的性能,展示了仿真在预测性能和指导加速策略设计中的价值。
Llama 7B, 批大小: 1, 序列长度: 1024
图 21. 在A100 80GB PCIe上,Llama-7B在批大小1和序列长度1024设置下,使用不同数量候选token的模拟加速率、加速比和归一化延迟消融研究。
Llama 7B, 序列长度: 1024
图 22. Llama-7B序列长度为1024时的模拟加速比。
Llama 7B, 批大小: 4
图 23. Llama-7B批大小为4时的模拟加速比。
💬 评论讨论
欢迎在这里分享您的想法和见解!