通过多令牌预测实现更好、更快的大型语言模型

Fabian Gloeckle * 1 2, Badr Youbi Idrissi * 1 3, Baptiste Rozière 1, David Lopez-Paz + 1, Gabriel Synnaeve + 1


A1 主要贡献

本文研究的核心问题是,当前大型语言模型(LLMs)普遍采用的“下一令牌预测”(next-token prediction)训练范式在样本效率上存在不足,模型需要比人类儿童多出几个数量级的数据才能达到同等流利程度。研究认为,这种方法容易固守于局部模式而忽略“困难”的决策点。

为了解决这一问题,本文提出的研究目标是探索一种更高效的训练方法:多令牌预测(multi-token prediction)。具体而言,该方法在训练语料库的每个位置,要求模型同时并行地预测未来的n个令牌。如下图所示,该架构包含一个共享的主干网络(shared trunk)和n个独立的输出头(output heads),每个头负责预测一个未来令牌。


图 1: 多令牌预测概览。(上图) 训练期间,模型通过一个共享主干和4个专用输出头,一次性预测未来4个令牌。推理期间,我们仅使用下一个令牌的输出头。可选地,其他三个头可用于加速推理时间。(下图) 多令牌预测提升了MBPP代码任务的pass@1,且模型尺寸越大,提升越显著。误差棒是通过对数据集样本进行自举法计算的90%置信区间。

本文的主要贡献如下:
1. 提出了一种简单的多令牌预测架构:该架构在训练时间和内存开销上与传统方法相比没有额外负担(详见第2节)。
2. 提供了大规模实验证据:实验证明该训练范式在模型规模扩大时(最高达13B参数)尤其有效,在代码生成任务上,性能平均提升约15%(详见第3节)。
3. 实现了更快的推理速度:多令牌预测使得自推测解码(self-speculative decoding)成为可能,在各种批量大小下,模型推理速度可提升高达3倍(详见第3.2节)。

作者希望这项工作能激发对超越下一令牌预测的新型辅助损失函数的研究兴趣,以提升这些模型的性能、连贯性和推理能力。

A2 方法细节

标准语言模型的目标函数。标准语言模型通过实现下一令牌预测任务来学习大型文本语料库 $x_1, \dots, x_T$。其学习目标是最小化以下交叉熵损失:


其中,$P_{\theta}$ 是我们正在训练的大型语言模型,目标是在给定历史令牌 $x_{t:1} = x_t, \dots, x_1$ 的情况下,最大化下一个未来令牌 $x_{t+1}$ 的概率。

多令牌预测的目标函数泛化。本文将上述目标泛化为多令牌预测任务,即在训练语料库的每个位置,模型被指示一次性预测n个未来的令牌。这对应于以下交叉熵损失:

多令牌预测损失的分解。为了使问题易于处理,我们假设大型语言模型 $P_{\theta}$ 使用一个共享的主干网络(shared trunk)来为观察到的上下文 $x_{t:1}$ 生成一个潜在表示 $z_{t:1}$,然后将该表示输入到n个独立的头(independent heads)中,以并行预测n个未来的令牌(如图1所示)。这使得多令牌预测的交叉熵损失可以分解为以下形式:

具体架构实现。在实践中,我们的架构由一个共享的Transformer主干网络 $f_s$ 构成,它从观察到的上下文 $x_{t:1}$ 生成隐藏表示 $z_{t:1}$;然后是n个独立的输出头,这些头由Transformer层 $f_h$ 实现;最后是一个共享的unembedding矩阵 $f_u$。因此,为了预测n个未来的令牌,我们计算:


对于 $i = 1, \dots, n$,其中,特别地,$P_{\theta}(x_{t+1} | x_{t:1})$ 是我们的下一令牌预测头。关于多令牌预测架构的其他变体,请参见附录B。

内存高效的实现策略。训练多令牌预测器的一大挑战是降低其GPU内存利用率。这是因为在当前的LLM中,词汇表大小V远大于潜在表示的维度d,因此logit向量成为GPU内存使用的瓶颈。如果朴素地实现多令牌预测器,将所有logits及其梯度(形状均为(n, V))都实例化,会严重限制可用的批量大小和平均GPU内存利用率。因此,在我们的架构中,我们提议仔细调整前向和后向传播操作的顺序,如图2所示。具体来说,在通过共享主干 $f_s$ 的前向传播之后,我们顺序地计算每个独立输出头 $f_i$ 的前向和后向传播,并将梯度累积到主干上。虽然这会为输出头 $f_i$ 创建logits(及其梯度),但这些张量在处理下一个输出头 $f_{i+1}$ 之前就被释放了,从而只需要长期存储d维的主干梯度 $\partial L_n / \partial f_s$。总而言之,我们将峰值GPU内存利用率从 $O(nV + d)$ 降低到了 $O(V + d)$,而没有增加运行时间(见附录表S5)。


图 2: n=2头的n-令牌预测模型中前向/后向传播的顺序。通过按顺序执行头的前向/后向传播,我们避免了同时在内存中实例化所有unembedding层的梯度,从而降低了峰值GPU内存使用量。

推理阶段的策略。在推理时,该架构最基本的用法是使用下一令牌预测头 $P_{\theta}(x_{t+1} | x_{t:1})$ 进行标准的自回归预测,同时丢弃所有其他头。然而,额外的输出头可以被用来加速下一令牌预测头的解码过程,通过使用自推测解码(self-speculative decoding)方法,例如块式并行解码(blockwise parallel decoding)【索引52,Blockwise parallel decoding for deep autoregressive models,2018,Mitchell Stern et al.】——这是推测解码【索引30,Fast inference from transformers via speculative decoding,2023,Yaniv Leviathan et al.】的一种变体,无需额外的草稿模型——以及使用类似Medusa的树状注意力进行推测解码【索引5,Medusa: Simple llm inference acceleration framework with multiple decoding heads,2024,Tianle Cai et al.】。

A3 关键洞察与设计原则

第5节 为何有效?一些推测

核心直觉。多令牌预测之所以在代码评估基准和小型算法推理任务上表现优越,其核心直觉在于,它缓解了训练时使用的“教师强制”(teacher forcing)与推理时使用的自回归生成之间的分布差异。本节通过一个关于多令牌预测如何根据令牌对文本后续内容的相关性来赋予其隐式权重的说明性论证,以及一个多令牌预测损失的信息论分解来支持这一观点。

5.1 前瞻性加强了决策点(Choice Points)

并非所有令牌都同等重要。并非所有令牌决策对于从语言模型生成有用的文本都同等重要【索引2,The pitfalls of next-token prediction,2024,Gregor Bachmann and Vaishnavh Nagarajan;索引33,Rho-1: Not all tokens are what you need,2024,Zhenghao Lin et al.】。一些令牌允许文体上的变化,而不会限制文本的其余部分,而另一些令牌则代表了与文本更高层次语义属性相关的“决策点”,它们可能决定一个答案被认为是有效的还是跑偏的。

多令牌预测的隐式加权机制。多令牌预测会根据训练令牌与其后续令牌的关联紧密程度,隐式地为它们分配权重。以图9中描绘的序列为例,其中一个转换是难以预测的决策点,而其他转换则被认为是“不重要的”。紧跟在决策点之后的不重要转换同样难以提前预测。通过标记和计算损失项,我们发现n-令牌预测通过其相关项,为决策点赋予了 $n(n+1)/2$ 的权重,而为不重要的点赋予了较小的权重n。更多细节请参见附录L.3。总的来说,我们认为文本生成的质量取决于在决策点做出正确的选择,而n-令牌预测损失则促进了这一点。


图 9: 多令牌预测损失为关键令牌分配更高的隐式权重。图中显示了一个序列,其中除了“5 → A”之外的所有转换都很容易预测,旁边是3-令牌预测对应的预测目标。由于困难转换“5 → A”的后果同样难以预测,这个转换通过其相关项“3 → A”, ..., “5 → C”在总损失中获得了更高的隐式权重。

5.2 信息论论证

教师强制与自回归的矛盾。语言模型通常通过教师强制进行训练,即模型在训练期间接收每个未来令牌的真实值。然而,在测试时,生成是无引导的自回归过程,错误会不断累积。我们认为,教师强制鼓励模型专注于在极短期内进行良好预测,这可能会以忽略生成序列整体结构中的长期依赖为代价。

多令牌预测的信息论分解。为了说明多令牌预测的影响,考虑以下信息论论证。这里,X表示下一个未来令牌,Y表示下下个未来令牌。这两个令牌的生成都以某个观察到的输入上下文C为条件,为简化起见,我们在公式中省略了C。在令牌X之前,标准的下一令牌预测关注的是量H(X),而n=2的多令牌预测目标是H(X) + H(Y)。我们将这两个量分解如下:


通过舍弃项H(Y|X)——该项在下一个位置进行预测时会再次出现——我们观察到2-令牌预测将I(X; Y)的重要性提高了2倍。因此,多令牌预测器在预测那些对文本后续内容有重要影响的令牌X时更为准确。在附录L.2中,我们给出了上述公式的相对版本,该版本显示了在2-令牌预测损失的分解中,相对互信息(relative mutual information)的权重增加了。

A4 实验环境

  • 数据集:

    • 代码: 在至少91B token的代码数据上从头训练。部分7B模型在200B、314B(字节级)或1T token的代码数据上训练。
    • 自然语言: 在200B或500B token的自然语言数据上训练7B模型。
    • 合成数据: 用于消融研究的数据集包括:儿童故事集、书籍与儿童故事的9:1混合数据集、一个多项式算术任务数据集。
    • 评估基准:
      • 代码: MBPP, HumanEval, APPS, CodeContests。
      • 自然语言:
        • 选择题/似然任务: ARC Challenge, COPA, Hellaswag, Natural Questions, PIQA, SIQA, TriviaQA。
        • 生成任务: 八个摘要数据集 (如Multi-news, samsum等)。
        • 数学推理: GSM8K。
  • 模型架构:

    • 基于Transformer架构,模型参数规模从300M到13B不等。
    • 为保证公平比较,当n-令牌预测模型增加n-1个预测头层时,会从共享主干网络中移除n-1层,以保持总参数量不变。
    • 具体模型参数(如层数、头数、维度)见附录表S14。
  • 硬件配置:

    • GPU: A100-80GB 和 H100。
    • 总计算量约为500K GPU小时。
  • 软件配置:

    • 实现库: 使用xFormers【索引29,xformers: A modular and hackable transformer modelling library,2022,Benjamin Lefaudeux et al.】实现贪婪自推测解码。
    • 优化器: Adam【索引24,Adam: A method for stochastic optimization,2015,Diederik Kingma and Jimmy Ba】,参数 $\beta_1=0.9, \beta_2=0.95$,解耦权重衰减系数为0.1。
    • 学习率调度: 线性预热(warmup)和余弦衰减(cosine decay)【索引34,Sgdr: Stochastic gradient descent with warm restarts,2017,Ilya Loshchilov and Frank Hutter】。
    • 具体超参数见附录表S13。

A4 实验结果

真实数据实验 (第3节)

  • 3.1 模型规模效应: 在代码任务(MBPP, HumanEval)上,多令牌预测的优势随模型规模的增大而愈发明显。对于小模型,其性能可能不如基线,但在大规模模型(如13B)上显著超越基线。这可能是该方法过去被忽视的原因。(图3)


    图 3: 不同模型大小的n-令牌预测模型在MBPP上的结果。

  • 3.2 推理加速: 通过自推测解码,7B参数的4-令牌预测模型在代码生成上实现了3.0倍的加速,在文本生成上实现了2.7倍的加速。一个8字节预测模型甚至达到了6.4倍的推理加速。(表S2, S3)

  • 3.3 字节级预测: 在字节级tokenization的极端情况下,8字节预测模型相比单字节预测模型在MBPP pass@1上解决了多67%的问题,在HumanEval pass@1上多解决了20%的问题,证明多令牌预测有助于学习更全局的模式。(表1)


    表 1: 多令牌预测提升性能并解锁高效的字节级训练。

  • 3.4 最优n值探索: 对于7B模型和200B代码数据,n=4在HumanEval和MBPP上表现最佳。而在APPS/Intro任务上,n=6表现更好,表明最优预测窗口大小可能与数据分布有关。(表1)

  • 3.5 多轮训练: 即使在同一数据上进行多轮(epoch)训练,多令牌预测模型依然保持优势,尽管优势有所减小。

  • 3.6 微调性能: 在CodeContests数据集上进行微调时,使用多令牌预测预训练的模型优于使用下一令牌预测预训练的模型。有趣的是,在4-令牌预测模型上进行下一令牌预测微调(即只使用一个头)取得了最佳效果。(图4)


    图 4: 在CodeContests上的微调性能对比。

  • 3.7 自然语言任务:

    • 在标准的多项选择NLP基准上,2-令牌预测模型与基线持平,而4-令牌模型性能有所下降。(图5)
    • 在生成式任务(如摘要)上,2-令牌和4-令牌模型均优于基线,且数据量越大,性能差距越小。(图6)
    • 在数学推理任务(GSM8K)上,2-令牌模型在200B数据量时占优,但当数据量增至500B时,基线模型反超。(图S13)

      图 5: 7B模型的多令牌训练在选择任务上未提升性能。

      图 6: 在抽象文本摘要任务上的性能。

合成数据消融实验 (第4节)

  • 4.1 归纳能力 (Induction Capability): 在小型模型(<30M参数)上,多令牌预测显著促进了归纳头(induction heads)的形成。这种优势在较大模型(>100M参数)中消失,因为大模型本身就能学习到这种能力。(图7)


    图 7: n-令牌预测模型的归纳能力。

  • 4.2 算法推理: 在多项式算术任务中,多令牌预测模型在所有任务难度上都表现更优,尤其是在域外泛化方面。其性能提升效果甚至超过了将模型大小增加三倍。(图8)


    图 8: 在多项式算术任务上不同操作次数下的准确率。

A5 结论

本文提出将多令牌预测作为下一令牌预测的改进方案,用于训练语言模型的生成或推理任务。实验(模型参数最高达7B,训练数据最高达1T tokens)表明,该方法对于更大的模型尤其有效,并在代码任务上展现出强大的改进效果。作者推测,该方法减少了教师强制训练与自回归生成之间的分布失配。当与推测解码结合使用时,精确推理速度可提高3倍。

未来的工作方向包括:
1. 更好地理解如何自动选择多令牌预测损失中的n值,例如使用损失缩放和平衡技术【索引11,High fidelity neural audio compression,2022,Alexandre Défossez et al.】。
2. 探索适用于多令牌预测的最优词汇表大小,这可能与下一令牌预测不同,调整词汇表大小可能带来更好的结果和效率权衡。
3. 开发在嵌入空间中操作的更优辅助预测损失【索引28,A path towards autonomous machine intelligence version 0.9. 2,2022,Yann LeCun】。

A6 附录

A. 自推测解码的额外结果

解码速度和延迟分析。图S10展示了使用4-令牌预测模型的k个头进行自推测解码时,相对于标准自回归解码的速度和延迟。所有数值均以相同批量大小的自回归(k=1)基线为参照。


图 S10: 使用自推测解码相对于标准自回归解码的速度和延迟。

不同数据集上的加速效果。表S2和S3分别展示了在自然语言/代码数据集和字节级代码数据集上的相对加速比。在7B模型上,使用3个额外头(共4个)在代码上可实现3.0倍加速,平均接受2.5个建议令牌。在字节级模型上,8字节预测模型可实现6.4倍加速。加速比在不同批量大小下保持稳定。


表 S2: 使用自推测解码的相对加速比。

表 S3: 在字节级模型上使用自推测解码的相对加速比。

B. 替代架构

架构探索。第2节中描述的并行头架构并非唯一选择,但经实验证明技术上可行且性能良好。本节描述并比较了一些替代架构,如表S4所示,这些替代方案虽有改进但不如并行架构稳定。


表 S4: 替代架构与基线相比有所改进,但一致性不高。

复制Unembeddings矩阵。一种简单实现是复制n次unembeddings矩阵,但这会产生形状为(d, nV)的矩阵,对于大规模训练来说是不可行的。

线性头。除了使用单层Transformer作为头,也可以使用无非线性激活的单层线性层,这相当于对模型的残差表示z进行线性探测。

因果与反因果变体。不同于使预测头 $P_i(x_{t+i} | z_{t:1})$ 相互独立,也可以让它们依赖于其他头的(unembedding前的)输出。
* 因果变体:后一个预测头建立在前一个之上,即第i个预测头 $P_i$ 的计算公式为:


* 反因果变体:网络从预测最远的令牌开始,逐步精化到预测下一个令牌:

这些架构同样允许像并行架构一样的顺序前向/后向传播,如图S11所示,从而实现内存高效的训练。

图 S11: n=2头的因果n-令牌预测模型中前向/后向传播的顺序。

C. 训练速度

训练时间开销。如表S5所示,使用多令牌预测的训练时间相对于下一令牌预测有轻微开销。这主要是由于在我们的实现中,对每个头进行独立的后向传播时,失去了层权重通信与计算的重叠,这是由于对完全分片数据并行(Fully Sharded Data Parallel)的次优使用造成的。如果重新正确实现,这个轻微的开销可以被消除。


表 S5: 相对于下一令牌预测训练的训练时间。

D. 微调

微调Llama 2的探索。我们尝试使用4-令牌预测来微调Llama 2,但与基线相比并未产生显著的性能提升。我们推测,这种新的损失函数可能过于剧烈地改变了模型的初始化状态,导致模型未能真正恢复。尽管如此,在某些指标上仍观察到一些改进,例如MBPP Pass@1。所有实验均使用200B token的代码数据。


表 S6: 使用多令牌预测微调Llama 2并未显著提升性能。

F. 关于CodeContests微调的细节

实验设置。我们使用了CodeContests【索引31,Competition-level code generation with alphacode,2022,Yujia Li et al.】训练集中带有奖励注释(“正确”/“不正确”)的Python子集,并在评估时以正确解决方案为条件。对于评估,我们为测试集中的每个问题在每个温度T ∈ {0.5, 0.6, 0.7, 0.8, 0.9}下生成1000个样本,并计算了pass@k的无偏估计器【索引7,Evaluating large language models trained on code,2021,Mark Chen et al.】。由于不同预训练损失的模型可能有不同的最优pass@k温度,我们在图4中展示了 $k \rightarrow \max_T \text{pass\_at}(k, T)$,即给予模型一个温度预言机(temperature oracle)。小k值衡量任务理解和解决能力,大k值额外青睐输出的多样性。根据图4的结果,多令牌预测预训练在这两个方面都带来了更好的微调模型。

K. 算法推理的额外结果

计算共享假说。我们研究了计算共享假说来解释多令牌预测的有效性。该假说认为,由于不同令牌的预测难度不同,多令牌预测损失通过鼓励相邻令牌位置之间的信息共享,可以被看作是一种学习更有效地将计算资源分配给最受益的令牌的方法。为了验证这一点,我们通过在问题和答案之间插入“暂停令牌”【索引18,Think before you speak: Training language models with pause tokens,2023,Sachin Goyal et al.】来增强多项式算术任务。这些暂停令牌提供了额外的计算资源,模型可以利用它们来思考答案。根据计算共享假说,多令牌预测模型应该能更好地利用这些资源。然而,图S15的结果显示,尽管多令牌模型在带暂停令牌的任务上仍然优于基线模型,但性能差距并未明显扩大或缩小,因此实验无法证实该假说。表S11通过在代码提示中添加空格和换行符进行了类似实验,结果显示多令牌模型的优势微乎其微。


图 S15: 在带暂停令牌的多项式算术任务上的准确率。

表 S11: 在代码基准中利用额外空白令牌的情况。

模型大小与损失函数的比较。图S16显示,在多项式算术任务上,将模型大小从30M参数增加到100M参数,其性能提升效果不如将下一令牌预测损失替换为多令牌预测损失。这表明改变训练目标比增加模型容量更能有效提升算法推理能力。


图 S16: 两种模型大小在多项式算术任务上的准确率。

L. 关于多令牌预测的额外直觉

L.1 与计划采样(Scheduled Sampling)的比较。计划采样【索引3,Scheduled sampling for sequence prediction with recurrent neural networks,2015,Samy Bengio et al.】也旨在弥合教师强制训练与自回归评估之间的差距,它通过在训练中逐渐用模型生成的令牌替换真实令牌。然而,我们认为该方法不适用于语言建模,因为文本是离散的,混合真实和生成的令牌容易产生不连贯的文本。此外,该技术难以适应Transformer等并行训练框架,而多令牌预测则没有这个问题。

L.2 信息论论证的细节。我们对第5.2节中的信息论分解进行更严谨的推导。我们定义了条件交叉熵 $H(p_{X|Y}, q_{X|Y})$ 和相对互信息 $I_{p\|q}(X; Y)$。通过这些定义,我们得到了一个相对版本的分解公式:


这个公式表明,2-令牌预测模型的交叉熵损失可以分解为一个局部交叉熵项、一个权重为2的互信息项和一个移位的下一令牌交叉熵项。这说明,通过在损失中加入 $H(p_Y, q_Y)$,2-令牌预测激励模型预先计算对预测Y有用的特征,并增加了相对互信息项在损失中的权重。最小化相对互信息意味着模型学会了利用X和Y在真实分布p下的互信息。

L.3 前瞻性加强了决策点。我们通过一个简化的序贯决策任务模型来说明多令牌预测如何增强“决策点”的重要性。如图S17所示,在一个需要做出正确选择才能到达目标的任务中,错误的决策会导致“脱轨”(derailing)。在教师强制训练下,一个n步预测模型会因为一个决策点上的错误及其后续影响,而收到 $1 + \dots + n = n(n+1)/2$ 个损失项的惩罚。相比之下,每个“不重要的”转换只被加强了n次。这意味着,n步预测模型中,决策点在损失中的重要性平均是不重要转换的 $(n+1)/2$ 倍。此外,我们假设n步预测能促进形成一种电路,专门发现早期和晚期预测之间的不一致性,从而避免在自回归生成时做出不可逆的次优决策。


图 S17: 一个带有脱轨风险的序贯预测任务示例。

L.4 分解顺序(Factorization Orders)。因果语言建模将文本序列概率分解为 $P(x_t, \dots, x_1) = \prod_{i=1}^t P(x_i | x_{i-1}, \dots, x_1)$。虽然按时间前进是最自然的选择,但在某些情况下(如处理屈折语中的语法一致性),其他分解顺序可能更优。例如,在德语句子 "Wie konnten auch Worte meiner durstenden Seele genügen?" 中,动词 "genügen" 决定了其宾语 "Seele" 的格,进而影响了代词 "meiner" 和形容词 "durstenden" 的形式。一个更简单的生成顺序可能是先确定动词,再确定名词,最后确定修饰词。人类和语言模型必须在其潜在激活中执行这种偏离因果顺序的分解。4-令牌预测损失通过明确鼓励模型在其潜在表示中拥有所有关于后续4个令牌的信息,从而简化了这一过程。

M. 训练超参数

超参数概览。表S13和S14提供了本文所有实验的训练超参数和模型架构的详细信息。


表 S13: 使用的所有训练超参数概览。

表 S14: 用于规模分析的模型架构概览。