文章标题: 不遗漏任何上下文:基于 Infini-attention 的高效无限上下文 Transformer
作者/机构: Tsendsuren Munkhdalai, Manaal Faruqui and Siddharth Gopal, Google

A1 主要贡献

本文旨在解决 Transformer 及基于 Transformer 的大型语言模型(LLMs)因注意力机制的二次复杂度而在处理长序列时面临的内存和计算瓶G颈问题。现有模型在扩展到如 1M 词元这类长序列时,其注意力键值(KV)状态的内存占用变得极其巨大,导致训练和服务的成本高昂。

为应对此挑战,本文提出了一种名为 Infini-attention 的新颖注意力机制,其核心目标是让 Transformer 模型能够以有界的内存和计算资源高效处理无限长的输入。

主要贡献如下:

  1. 提出 Infini-attention 机制:这是一种实用且强大的注意力机制,它将一个长期压缩记忆模块(compressive memory)与标准的局部因果注意力(local causal attention)相结合。这种设计使得模型能在一个 Transformer 块内同时高效地建模长期和短期的上下文依赖。
  2. 支持即插即用的长上下文适应:Infini-attention 对标准缩放点积注意力的改动极小。它通过复用标准的查询(Q)、键(K)、值(V)状态来进行长期记忆的整合与检索,从而天然地支持对现有 LLMs 进行即插即用的持续预训练和长上下文微调。
  3. 实现流式处理无限长上下文:该方法通过分段处理极长的输入,并利用压缩记忆来维持整个上下文历史,使得 Transformer LLMs 能够以有界的内存和计算资源扩展至无限长的上下文,实现了高效的流式推理。

A3 背景知识

循环神经网络(RNNs)的局限性:RNNs 在每个时间步 $t$ 处理单个词元 $x_t$,并计算一个循环隐藏状态 $h_t$ 来表示整个输入序列,其计算公式为:

公式
公式

尽管 RNN 的计算效率很高,因为它只需为输入序列维持一个固定大小的向量 $h_t$,但将全部上下文信息压缩进单个定长向量的限制,影响了其在处理长序列任务中的效用。

带外部记忆的RNNs:为解决上述局限,研究者们为标准 RNN 增加了可读写的外部记忆组件。一个例子是元学习神经记忆(Metalearned Neural Memory, MNM)【【42】Munkhdalai et al., Metalearned neural memory, NeurIPS 2019】,其计算过程如下:

公式
公式

MNM 学习一个由前馈神经网络(FFN)参数化的额外记忆状态 $\theta$,并使用查询、键和值向量(QKV)与记忆交互。它通过将键向量作为输入、值向量作为目标来修改 FFN 的参数以存储信息,并通过将查询向量前向传播通过记忆 FFN 来检索对应的值。然而,与 RNN 类似,MNM 中的记忆状态仍然是有界的。

注意力机制的无状态特性:与 RNN 不同,注意力机制不维持循环状态,仅对输入序列段 $X_s$ 执行前馈计算:

公式
公式

注意力输出 $O_s$ 直接传递给下一层,同一注意力层中没有任何状态会传递给下一个输入序列 $X_{s+1}$。为了捕捉连续片段 $X_s$ 和 $X_{s+1}$ 之间的依赖关系,必须将它们同时处理,这在输入序列增长时成为计算资源的瓶颈。为了在保持注意力机制表达能力的同时提高效率,本文引入了一种循环注意力层。

A2 方法细节

Infini-Transformer 架构概述

与 Transformer-XL 的对比:本文提出的模型 Infini-Transformer 与 Transformer-XL【【14】Dai et al., Transformer-xl: Attentive language models beyond a fixed-length context, arXiv 2019】类似,都对分段序列进行操作。在每个段内,模型计算标准的因果点积注意力上下文,因此点积注意力计算是局部的,仅覆盖当前段索引为 $S$ 的 $N$ 个词元($N$ 为段长度)。然而,局部注意力机制在处理下一个段时会丢弃前一个段的注意力状态。Infini-Transformer 则不同,它不抛弃旧的 KV 注意力状态,而是利用它们通过一个压缩记忆来维持整个上下文历史。因此,Infini-Transformer 的每个注意力层都同时拥有全局压缩状态和局部细粒度状态。这种高效的注意力机制被称为 Infini-attention。

图2:Infini-Transformer(上)拥有完整的上下文历史,而 Transformer-XL(下)则丢弃旧的上下文,因为它只缓存最后一个段的KV状态。
图2:Infini-Transformer(上)拥有完整的上下文历史,而 Transformer-XL(下)则丢弃旧的上下文,因为它只缓存最后一个段的KV状态。

3.1 Infini-attention

Infini-attention 的工作机制:如图1所示,Infini-attention 是一种循环注意力机制,它计算局部和全局上下文状态,并将它们组合成最终输出。类似于多头注意力(MHA),除了点积注意力外,它在每个注意力层为每个头(共 $H$ 个)维持一个并行的压缩记忆。与 RNN 和 MNM 类似,它维持一个循环记忆状态,以高效地追踪长序列上下文:

公式
公式

图1:Infini-attention 拥有一个额外的带线性注意力的压缩记忆,用于处理无限长的上下文。{KV}s−1 和 {KV}s 分别是当前和前一个输入段的注意力键和值,Qs 是注意力查询。PE 表示位置嵌入。
图1:Infini-attention 拥有一个额外的带线性注意力的压缩记忆,用于处理无限长的上下文。{KV}s−1 和 {KV}s 分别是当前和前一个输入段的注意力键和值,Qs 是注意力查询。PE 表示位置嵌入。

3.1.1 缩放点积注意力

标准多头注意力(MHA)回顾:多头缩放点积注意力【【61】Vaswani et al., Attention is all you need, NeurIPS 2017】,特别是其自注意力变体,是 LLMs 的主要构建模块。MHA 在自回归生成模型中被广泛利用,因其强大的上下文相关动态计算能力和便捷的时间掩码机制。

单头注意力计算过程:在标准的 MHA 中,单个头从输入序列段 $X \in \mathbb{R}^{N \times d_{\text{model}}}$ 计算其注意力上下文 $A_{\text{dot}} \in \mathbb{R}^{N \times d_{\text{value}}}$。首先,它计算注意力查询、键和值状态:

公式
公式

其中,$W_K \in \mathbb{R}^{d_{\text{model}} \times d_{\text{key}}}$, $W_V \in \mathbb{R}^{d_{\text{model}} \times d_{\text{value}}}$ 和 $W_Q \in \mathbb{R}^{d_{\text{model}} \times d_{\text{key}}}$ 是可训练的投影矩阵。然后,注意力上下文通过所有其他值的加权平均计算得出:
公式
公式

对于 MHA,模型为每个序列元素并行计算 $H$ 个注意力上下文向量,将它们沿第二个维度拼接,最后将拼接后的向量投影到模型空间以获得最终的注意力输出。

3.1.2 压缩记忆

QKV 状态的复用:在 Infini-attention 中,模型不为压缩记忆计算新的记忆条目,而是复用点积注意力计算中的查询、键和值状态(Q、K 和 V)。点积注意力和压缩记忆之间的状态共享与复用,不仅实现了高效的即插即用长上下文适应,还加速了训练和推理。其目标是将键和值的绑定关系存储在压缩记忆中,并使用查询向量进行检索,这与先前的工作【【42】Munkhdalai et al., Metalearned neural memory, NeurIPS 2019】类似。

记忆参数化与线性注意力:尽管文献中提出了不同形式的压缩记忆,为简化和提高计算效率,本文采用关联矩阵(associative matrix)【【52】Schlag et al., Learning associative inference using fast weight memory, arXiv 2020】来参数化记忆。这种方法使得记忆的更新和检索过程可以被视为线性注意力机制【【55】Shen et al., Efficient attention: Attention with linear complexities, arXiv 2018】,并能借鉴相关方法的稳定训练技术。具体来说,本文采用了 Katharopoulos 等人【【28】Katharopoulos et al., Transformers are rnns: Fast autoregressive transformers with linear attention, ICLR 2020】提出的更新和检索机制,主要因为其简单性和有竞争力的性能。

记忆检索:在 Infini-attention 中,模型使用查询 $Q \in \mathbb{R}^{N \times d_{\text{key}}}$ 从记忆 $M_{s-1} \in \mathbb{R}^{d_{\text{key}} \times d_{\text{value}}}$ 中检索新内容 $A_{\text{mem}} \in \mathbb{R}^{N \times d_{\text{value}}}$:

公式
公式

这里,$\sigma$ 是一个非线性激活函数,$z_{s-1} \in \mathbb{R}^{d_{\text{key}}}$ 是一个归一化项。非线性和归一化方法的选择对训练稳定性至关重要。遵循 Katharopoulos 等人【【28】Katharopoulos et al., Transformers are rnns: Fast autoregressive transformers with linear attention, ICLR 2020】的做法,模型记录所有键的总和作为归一化项 $z_{s-1}$,并使用逐元素的 ELU + 1 作为激活函数【【13】Clevert et al., Fast and accurate deep network learning by exponential linear units (elus), arXiv 2015】。
表格1:对采用段级记忆的 Transformer 模型进行比较。对于每个模型,记忆大小和有效上下文长度根据其模型参数定义(N:输入段长度,S:段数,l:层数,H:注意力头数,c:Compressive Transformer 记忆大小,r:压缩率,p:软提示摘要向量数,m:摘要向量累积步数)。
表格1:对采用段级记忆的 Transformer 模型进行比较。对于每个模型,记忆大小和有效上下文长度根据其模型参数定义(N:输入段长度,S:段数,l:层数,H:注意力头数,c:Compressive Transformer 记忆大小,r:压缩率,p:软提示摘要向量数,m:摘要向量累积步数)。

记忆更新:检索完成后,模型用新的 KV 条目更新记忆和归一化项,得到下一状态:

公式
公式

新的记忆状态 $M_s$ 和 $z_s$ 随后被传递到下一个段 $S+1$,从而在每个注意力层中构建了一个循环。公式(8)中的右侧项 $\sigma(K)^T V$ 被称为关联绑定算子(associative binding operator)【【56】Smolensky, Tensor product variable binding and the representation of symbolic structures in connectionist systems, Artificial intelligence 1990;【21】Hebb, The organization of behavior: A neuropsychological theory, Psychology press 2005;【52】Schlag et al., Learning associative inference using fast weight memory, arXiv 2020】。

Delta 规则的引入:受 Delta 规则成功的启发【【42】Munkhdalai et al., Metalearned neural memory, NeurIPS 2019;【52】Schlag et al., Learning associative inference using fast weight memory, arXiv 2020;【53】Schlag et al., Linear transformers are secretly fast weight programmers, ICML 2021】,本文也将其整合到 Infini-attention 中。Delta 规则试图通过一种稍加改进的记忆更新方式:在应用关联绑定作为新更新之前,首先检索现有的值条目,并从新值中减去它们。

公式
公式

这种更新规则(称为 Linear + Delta)在 KV 绑定已存在于记忆中时,保持关联矩阵不变,同时为保证数值稳定性,仍追踪与前者(Linear)相同的归一化项。

长期上下文注入:模型通过一个学习到的门控标量 $\beta$ 来聚合局部注意力状态 $A_{\text{dot}}$ 和从记忆中检索的内容 $A_{\text{mem}}$:

公式
公式

这个设计每头仅增加一个标量值作为训练参数,同时允许模型在长期和局部信息流之间进行可学习的权衡【【63】Wu et al., Memorizing transformers, arXiv 2022】。

多头 Infini-attention 输出:与标准 MHA 类似,对于多头 Infini-attention,模型并行计算 $H$ 个上下文状态,然后将它们拼接并投影以获得最终的注意力输出 $O \in \mathbb{R}^{N \times d_{\text{model}}}$:

公式
公式

其中 $W_O \in \mathbb{R}^{H \times d_{\text{value}} \times d_{\text{model}}}$ 是可训练的权重。

3.2 记忆与有效上下文窗口

Infini-Transformer 的内存优势:Infini-Transformer 以有界的内存占用实现了无界的上下文窗口。表1列出了先前的段级记忆模型及其上下文内存占用和有效上下文长度。Infini-Transformer 在单层每头中存储压缩上下文($M_s$ 和 $z_s$)的内存复杂度为常数 $d_{\text{key}} \times d_{\text{value}} + d_{\text{key}}$。相比之下,其他模型的内存复杂度随序列维度增长:Transformer-XL【【14】Dai et al., Transformer-xl: Attentive language models beyond a fixed-length context, arXiv 2019】、Compressive Transformer【【47】Rae et al., Compressive transformers for long-range sequence modelling, arXiv 2019】和 Memorizing Transformers【【63】Wu et al., Memorizing transformers, arXiv 2022】的内存复杂度依赖于缓存大小;而 RMT【【7】Bulatov et al., Recurrent memory transformer, NeurIPS 2022】和 AutoCompressors【【18】Ge et al., In-context autoencoder for context compression in a large language model, arXiv 2023】则依赖于软提示的大小。

与其他长上下文模型的比较
- Transformer-XL:通过缓存上一段的 KV 状态,将上下文窗口从 $N$ 扩展到 $N \times l$ 个词元,但带来了 $(d_{\text{key}} + d_{\text{value}}) \times H \times N \times l$ 的额外内存占用。
- Compressive Transformer:在 Transformer-XL 的基础上增加第二个缓存,存储过去段激活的压缩表示,将上下文窗口扩展了 $c \times r \times l$,但上下文内存复杂度依然很高。
- Memorizing Transformers:选择存储整个 KV 状态作为上下文,由于存储开销巨大,他们将上下文计算限制在单层内。通过使用快速 kNN 检索器,它构建了一个覆盖整个序列历史(长度 $N \times S$)的上下文窗口,但存储成本更高。实验表明,Infini-Transformer LM 在 Memorizing Transformers 的基础上实现了超过100倍的压缩率,同时进一步改善了困惑度分数。
- RMT 和 AutoCompressors:理论上允许无限的上下文长度,因为它们将输入压缩成摘要向量,并作为额外的软提示输入到后续段。然而,在实践中,这些技术的成功高度依赖于软提示向量的大小。增加软提示向量数量以提升性能会迅速增加内存和计算复杂度,从而降低效率。AutoCompressors 的研究还发现,需要一个高效的压缩目标来训练这类提示压缩技术。

A4 实验环境与结果

A4.1 实验环境

  • 实现细节:
    • 分段处理 (Segment chunking): 为最小化对现有 Transformer 实现的修改,模型首先对整个输入文本进行一次前向传播,然后在每个 Infini-attention 层内部执行分段。该层将输入分段,逐段处理,最后将各段结果拼接回原始长度,传递给下一层。
    • 随时间反向传播 (BPTT): 每个 Infini-attention 层使用 BPTT 进行训练,通过计算关于压缩记忆状态的梯度,类似于 RNN 的训练方式。为节省内存,在逐段处理序列时采用了梯度检查点技术。
    • 位置嵌入 (PE): 为了在长期记忆中仅存储全局上下文信息,压缩记忆的键(K)和查询(Q)向量不使用位置嵌入。PE 仅在压缩记忆读取和更新之后应用于 QK 向量。
  • 模型与硬件:
    • 长上下文语言建模: 模型包含12层,8个注意力头,每个头维度为128,FFN隐藏层维度为4096。
    • LLM 持续预训练:
      • Passkey 任务: 使用了一个 1B 参数的 LLM。
      • BookSum 任务: 使用了一个 8B 参数的 LLM。
      • 硬件配置未明确说明。
  • 数据集:
    • 长上下文语言建模: PG19【【47】Rae et al., Compressive transformers for long-range sequence modelling, arXiv 2019】和 Arxiv-math【【63】Wu et al., Memorizing transformers, arXiv 2022】。
    • LLM 持续预训练: PG19, Arxiv-math, 以及长度超过4K词元的C4文本【【48】Raffel et al., Exploring the limits of transfer learning with a unified text-to-text transformer, The Journal of Machine Learning Research 2020】。
    • Passkey 检索: 一项合成任务,用于测试模型在长文本中检索信息的能力。
    • 书籍摘要: BookSum 数据集【【30】Krysci ´ nski et al., Booksum: A collection of datasets for long-form narrative summarization, arXiv 2021】。

A4.2 实验结果

长上下文语言建模

  • 实验设置: Infini-attention 的段长度 N 设为 2048,训练序列长度为 32768。这使得 Infini-attention 能够在压缩记忆状态上展开16个步骤。RMT 基线模型在序列长度为8196、摘要向量数为100时表现最佳。
  • 实验结果: 如表2所示,Infini-Transformer 的性能优于 Transformer-XL 和 Memorizing Transformers 基线。与在第9层使用长度为 65K 的基于向量检索的 KV 记忆的 Memorizing Transformer 模型相比,Infini-Transformer 维持的记忆参数少了114倍。
    • 100K 长度训练: 将 Arxiv-math 数据集的训练序列长度从32K增加到100K,Linear 和 Linear+Delta 模型的困惑度进一步降低到 2.21 和 2.20。
  • 分析:
    • 门控分数可视化 (图3): 训练后,Infini-attention 中出现了两种类型的头:专业头(门控分数接近0或1)和混合头(分数接近0.5)。专业头要么专注于局部注意力计算,要么专注于从压缩记忆中检索;混合头则将当前上下文信息和长期记忆内容结合起来。有趣的是,每层至少有一个短程头,保证了输入信号能一直传播到输出层。实验也观察到长短期内容检索在整个前向计算中是交错进行的。

表格2:长上下文语言建模结果在平均词元级困惑度上进行比较。Comp.表示压缩率。Infini-Transformer 优于记忆长度为65K的Memorizing Transformers,并实现了114倍的压缩率。
表格2:长上下文语言建模结果在平均词元级困惑度上进行比较。Comp.表示压缩率。Infini-Transformer 优于记忆长度为65K的Memorizing Transformers,并实现了114倍的压缩率。

图3:训练后,Infini-attention 中出现了两种类型的头:门控分数接近0或1的专业头,以及分数接近0.5的混合头。专业头要么通过局部注意力机制处理上下文信息,要么从压缩记忆中检索;而混合头则将当前上下文信息和长期记忆内容聚合到单个输出中。
图3:训练后,Infini-attention 中出现了两种类型的头:门控分数接近0或1的专业头,以及分数接近0.5的混合头。专业头要么通过局部注意力机制处理上下文信息,要么从压缩记忆中检索;而混合头则将当前上下文信息和长期记忆内容聚合到单个输出中。

LLM 持续预训练

  • 实验设置: 在 PG19、Arxiv-math 和 C4 文本上进行轻量级持续预训练,以适应长上下文。所有实验中段长度 N 设为 2K。

  • 1M Passkey 检索基准测试:

    • 设置: 将一个 1B LLM 中的标准 MHA 替换为 Infini-attention,在 4K 长度的输入上持续预训练 30K 步,然后在 5K 长度的输入上微调 400 步以完成 Passkey 检索任务。
    • 结果 (表3): Infini-Transformer 在仅用 5K 长度输入进行微调后,成功解决了上下文长度高达 1M 的 Passkey 任务,展示了卓越的长度外推能力。结果分为零样本和微调后的词元级准确率,涵盖了密钥在长文本的开头、中间和结尾的不同位置。

表格3:Infini-Transformers 在5K长度输入上微调后,解决了高达1M上下文长度的passkey任务。我们报告了在32K到1M长度的长输入中,隐藏在不同部分(开始/中间/结束)的passkey的词元级检索准确率。
表格3:Infini-Transformers 在5K长度输入上微调后,解决了高达1M上下文长度的passkey任务。我们报告了在32K到1M长度的长输入中,隐藏在不同部分(开始/中间/结束)的passkey的词元级检索准确率。

  • 500K 长度书籍摘要 (BookSum):
    • 设置: 对一个 8B LLM 进行持续预训练,输入长度为 8K,训练 30K 步。然后在 BookSum 任务上进行微调,微调输入长度为 32K,评估时长度增加到 500K。
    • 结果 (表4): 模型通过处理书籍的全部文本,在 BookSum 任务上取得了新的 SOTA 成绩,优于之前专为摘要任务构建的编码器-解码器模型及其基于检索的长上下文扩展版本。
    • 性能趋势 (图4): 在 BookSum 验证集上的 Rouge 总体分数显示,随着提供更多书籍文本作为输入,Infini-Transformers 的摘要性能指标也随之提升。

表格4:500K长度书籍摘要(BookSum)结果。BART、PRIMERA和Unlimiformer的结果来自Bertsch等人(2024)。
表格4:500K长度书籍摘要(BookSum)结果。BART、PRIMERA和Unlimiformer的结果来自Bertsch等人(2024)。

图4:Infini-Transformers在提供更多书籍文本作为输入时,获得了更好的Rouge总体分数。
图4:Infini-Transformers在提供更多书籍文本作为输入时,获得了更好的Rouge总体分数。

A5 结论

本文引入了一种将压缩记忆模块与标准点积注意力层紧密集成的方案——Infini-attention。这一对注意力层的精妙而关键的修改,使得 LLMs 能够利用有界的内存和计算资源处理无限长的上下文。实验证明,该方法可以自然地扩展到百万词元级别的输入序列,在长上下文语言建模和书籍摘要任务上均优于基线模型。此外,该方法还展示了强大的长度泛化能力:一个在最多 5K 序列长度的 Passkey 实例上微调的 1B 模型,成功解决了 1M 长度的问题。一个有效的记忆系统不仅对于 LLMs 理解长上下文至关重要,也对推理、规划、持续学习新知识,甚至学会如何学习都至关重要。

A6 附录

A. 额外的训练细节

长上下文语言建模任务:
* 学习率: 通过在 {0.003, 0.005, 0.01, 0.03} 中搜索,最终设定为 0.01。
* 优化器: 使用 Adafactor 优化器【【54】Shazeer & Stern, Adafactor: Adaptive learning rates with sublinear memory cost, ICML 2018】,配合 1000 步的线性预热(warmup),之后进行余弦衰减(cosine decay)。
* 内存优化: 在每个段之后应用梯度检查点技术以节省内存。
* 批量大小: 设置为 64。

LLM 实验:
* 在持续预训练和任务微调期间,学习率设定为 0.0001。

B. Passkey 检索任务

输入格式示例:
下面展示了 Passkey 任务的输入格式。

There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. (重复 x )
The pass key is 9054. Remember it. 9054 is the pass key.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and ack again. (重复 y )
What is the pass key?
The pass key is