Overcoming Long-Context Limitations of State-Space Models via Context-Dependent Sparse Attention

作者/机构: Zhihao Zhan, Jianan Zhao, Zhaocheng Zhu, Jian Tang (Mila - Québec AI Institute, University of Montréal, HEC Montréal, CIFAR AI Chair)

A1 主要贡献

核心问题与研究目标

高效的长上下文建模是自然语言处理(NLP)领域的一个关键挑战,因为主流的Transformer架构的时间复杂度会随着序列长度呈二次方增长。虽然状态空间模型(SSM)提供了亚二次方的替代方案,但它们在有效捕捉长距离依赖方面存在困难。本文旨在深入理解并提升SSM的长上下文建模能力。

对现有研究的批判与新任务的提出

研究首先指出,以往广泛使用的合成任务——关联回忆(associative recall),不足以模拟现实世界中长上下文建模的复杂性。关联回忆任务假设每个“键”(key)都唯一地关联一个“值”(value),而忽略了周围的上下文。然而,在自然语言中,同一个键可能根据上下文对应不同的值。如图1所示,当被问及人们在路的哪一侧开车时,正确的答案取决于所指的国家(美国还是英国)。

为了解决这一局限,本文将关联回忆任务扩展为一个新的合成任务——联合回忆(joint recall)。联合回忆要求模型在给定特定上下文的条件下,回忆与键关联的值。这使得联合回忆成为一个更真实、更严格的合成任务,适用于理论分析和实证基准测试。


图 1: 联合回忆与关联回忆的比较。关联回忆不考虑上下文。联合回忆通过将上下文依赖性纳入键值关联中来扩展关联回忆。例如,关联回忆可能将“pay with”映射到“dollar”或“pound”,而联合回忆允许它根据上下文在美国映射到“dollar”,在英国映射到“pound”。这使得联合回忆成为一个更真实、更严格的合成任务,适用于长上下文建模的理论分析和实证基准。

理论创新与新架构的提出

理论上,本文证明了标准的SSM缺乏在亚二次方时间复杂度内解决多查询联合回忆任务的表达能力。为了克服这一瓶颈,研究提出将SSM与上下文相关的稀疏注意力(Context-Dependent Sparse Attention, CDSA)相结合。CDSA是一类稀疏注意力机制,其稀疏模式根据上下文表示动态调整,例如局部敏感哈希(LSH)注意力。相比之下,上下文无关的稀疏注意力(Context-Independent Sparse Attention, CISA),如滑动窗口注意力,其模式是固定的。本文从理论上证明了,存在一种CDSA,当与SSM集成时,能够在亚二次方时间内解决多查询联合回忆任务,并揭示了集成CDSA的SSM与集成CISA的SSM在解决此任务上存在表达能力差距。

基于这一理论洞见,并为了弥合理论与实践之间的差距,本文提出了一种新颖的架构:带有稀疏密钥选择的局部敏感哈希注意力(locality-sensitive Hashing Attention with sparse Key Selection, HAX)。HAX通过结合本文提出的密钥选择(Key Selection, KS)注意力来增强LSH注意力的表达能力,并与最先进的SSM(Mamba和Mamba2)集成,从而实例化了上述理论上可行的解决方案。

主要贡献总结

  1. 引入联合回忆任务:提出了一个新的合成任务——联合回忆,它将关联回忆扩展到上下文相关的键值关联,为长上下文建模的理论分析和实证基准提供了新视角。
  2. 理论分析:通过对联合回忆的理论分析,证明了将状态空间模型(SSM)与上下文相关的稀疏注意力(CDSA)集成,具有以亚二次方计算复杂度解决多查询联合回忆任务的表达能力。
  3. 提出HAX架构:在理论指导下,提出了一种基于SSM与CDSA集成的新型架构HAX,该架构在合成和真实世界的长上下文基准测试中,其性能始终优于SSM以及集成了上下文无关稀疏注意力(CISA)的SSM。

A3 背景知识

2.1 稀疏注意力

标准注意力与稀疏注意力的定义。对于一个长度为 l 的序列,自回归序列建模的注意力分数可以表示为:

$$\mathbf{A} = \text{Softmax}(\mathbf{M} \odot \mathbf{Q}\mathbf{K}^{\top})$$

其中,$M \in \{0, 1\}^{l \times l}$ 是自回归掩码,$Q, K \in \mathbb{R}^{l \times d}$,$d$ 是隐藏维度。本文中,我们将稀疏注意力的注意力分数定义为:

$$\mathbf{A} = \text{Softmax}(\mathbf{S} \odot \mathbf{M} \odot \mathbf{Q K}^{\top})$$

其中 $S \in \{0, 1\}^{l \times l}$ 代表稀疏注意力模式。我们考虑稀疏性约束:

$$\| \mathbf{S} \|_0 \ll l^2$$

为了确保每一步的计算效率,我们进一步收紧此约束,要求:

$$\forall i,\|\mathbf{S}_i\|_0 \ll l$$

其中 $S_i$ 是 $S$ 的第 $i$ 行。这意味着每个查询最多关注 $k$ 个键,其中 $k \ll l$。

2.2 局部敏感哈希注意力

LSH注意力模式生成算法。鉴于局部敏感哈希(LSH)注意力是自回归建模中最有效的输入依赖稀疏注意力之一【索引20,Reformer: The efficient transformer,2020,ICLR】,我们重新构建了一个简单的算法来生成LSH的稀疏注意力模式。该算法接收查询和键矩阵 $Q$ 和 $K$ 作为输入,并输出一个二元稀疏注意力模式 $S_{LSH}$。在每次前向传播中,首先对 $Q$ 和 $K$ 进行中心化和归一化,得到 $\tilde{Q}$ 和 $\tilde{K}$:

$$\tilde{\mathbf{Q}}_{i}=\operatorname{normalize}\left(\mathbf{Q}_{i}-\overline{\mathbf{Q}}_{i}\right), \quad \tilde{\mathbf{K}}_{i}=\operatorname{normalize}\left(\mathbf{K}_{i}-\overline{\mathbf{K}}_{i}\right)$$

接下来,采样一个随机投影矩阵 $H \stackrel{\text{i.i.d.}}{\sim} \mathcal{N}(0, 1) \in \mathbb{R}^{d \times h}$,将归一化后的向量 $\tilde{Q}$ 和 $\tilde{K}$ 投影到哈希空间。然后,我们考虑两种分桶规则,将每个向量 $\tilde{Q}_i$ 和 $\tilde{K}_i$ 分配到一个哈希桶中:argmax分桶规则【索引20,Reformer: The efficient transformer,2020,ICLR;索引35,Sparse attention with learning to hash,2022,ICLR】和sign-bit分桶规则【索引8,Memoryformer: Minimize transformer computation by removing fully-connected layers,2024,NeurIPS;索引5,Hashattention: Semantic sparsity for faster inference,2025,ICML】。argmax分桶规则将每个向量分配到其在 $H$ 中最对齐的列的索引:

$$\text{bin}_{Q_i} = \text{argmax}(\tilde{\mathbf{Q}}_i \mathbf{H}), \quad \text{bin}_{K_i} = \text{argmax}(\tilde{\mathbf{K}}_i \mathbf{H})$$

sign-bit分桶规则通过计算投影值的符号来构建一个二进制哈希码,并将其解释为一个二进制数:

$$\text{bin}_{Q_i} = \sum_{j=1}^{h} \mathbf{1}[(\tilde{\mathbf{Q}}_i \mathbf{H})_j > 0] \cdot 2^{h-j}, \quad \text{bin}_{K_i} = \sum_{j=1}^{h} \mathbf{1}[(\tilde{\mathbf{K}}_i \mathbf{H})_j > 0] \cdot 2^{h-j}$$

argmax分桶规则将向量分配到 $h$ 个桶,而sign-bit分桶规则将向量分配到 $2^h$ 个桶。我们将在附录A中进一步讨论这两种分桶策略之间的关系。基于分配的桶,通过允许每个查询关注同一桶内所有在它之前的键,来构建一个初步的稀疏模式 $\tilde{S}_{LSH}$:

$$\tilde{\mathbf{S}}_{\mathrm{LSH}_{i j}}=\mathbf{1}\left[\operatorname{bin}_{Q_{i}}=\operatorname{bin}_{K_{j}}\right]$$

最后,为了满足公式4中定义的稀疏性约束,应用一个桶内滑动窗口掩码 $M_{LSH}$,使得每个查询只关注同一桶内最多 $k$ 个最近的键:

$$\mathbf{S}_{\mathrm{LSH}}=\mathbf{M}_{\mathrm{LSH}} \odot \tilde{\mathbf{S}}_{\mathrm{LSH}}$$

2.3 广义状态空间模型

SSM的形式化定义。遵循【索引19,Repeat after me: Transformers are better than state space models at copying,2024,ICML】中引入的定义,我们将广义状态空间模型形式化为由一个更新规则 $u : U \times V \to U$ 和一个输出函数 $r : U \to V$ 定义的序列模型,其中 $V$ 表示词元词汇表,$U$ 表示循环状态。令 $U_0(\emptyset) \in U$ 表示初始状态。给定一个输入序列 $v_1, ..., v_n \in V$,对于 $i \in \{1...n\}$,状态 $U_i(v_1, ..., v_i) \in U$ 及其对应的输出 $R_i(v_1, ..., v_i) \in V$ 递归定义为:

$$\begin{aligned} \begin{aligned} U_i(v_1, ..., v_i) &= u(U_{i-1}(v_1, ..., v_{i-1}), v_i) \\ R_i(v_1, ..., v_i) &= r(U_i(v_1, ..., v_i)) \end{aligned} \end{aligned}$$

2.4 关联回忆

关联回忆任务介绍。关联回忆任务最初由【索引2,Using fast weights to attend to the recent past,2016,NeurIPS】引入。【索引27,In-context learning and induction heads,2022】发现,大语言模型在该任务上的表现与其上下文学习能力密切相关。【索引1,Zoology: Measuring and improving recall in efficient language models,2024,ICLR】将关联回忆扩展到多查询设置:首先给模型一系列关联的键值对,然后要求模型在被查询相应键时回忆出每个值。关联回忆已被广泛用作长上下文建模的合成基准【索引4,Transformers are SSMs: Generalized models and efficient algorithms through structured state space duality,2024,arXiv;索引18,Ruler: What’s the real context size of your long-context language models?,2024,CoLM】。

3.1 动机

提出联合回忆的动机。提出联合回忆的动机是为了克服关联回忆设置的一个关键局限:每个键对应一个固定的值。虽然这种设置非常适合研究强调捕捉稳定词汇模式(如子词单元或固定多词表达)的任务,但它无法体现自然语言中意义的上下文敏感性。例如,“美国政府的立法部门被称为国会”,其中的值“国会”由属键“立法部门”和差异上下文“美国政府的”共同确定。另一个例子,“周一早上,爱丽丝学习数学”,其中的值“数学”只有在将上下文元素“周一”和“早上”与键“爱丽丝”结合起来时才能确定。这些例子说明,自然语言中准确的语义解释通常需要整合上下文和键,这表明模型必须超越关联回忆的简单一对一映射,以捕捉意义的组合和上下文依赖性。这一观察促使我们引入了一个新的合成任务,我们称之为联合回忆。

多查询联合回忆:在指定上下文中,回忆与一个键相关联的每个值:
图 2: 合成多查询联合回忆与关联回忆的比较。

3.2 公式化

联合回忆任务的定义。关联回忆要求模型记忆 $n_k$ 个关联的键值对。联合回忆推广了这一任务:模型需要记忆一个 $n_c \times n_k$ 的上下文特定键值关联表,其中 $n_k$ 个键在 $n_c$ 个上下文中的每一个都关联着不同的值。受【索引1,Zoology: Measuring and improving recall in efficient language models,2024,ICLR】的启发,我们也把联合回忆扩展到多查询设置,要求模型恢复整个表,而不是表中的某个特定条目。图2展示了 $n_c = 2$ 和 $n_k = 2$ 的多查询联合回忆。遵循自然语言的结构,序列化的表输入包含 $n_c$ 个上下文块,每个块以一个上下文词元(如图2中的大写字母)开始,后面跟着该上下文特定的 $n_k$ 个键值对(如图2中的小写字母-数字对)。然后,模型被要求在上下文和键序任意排列的情况下,为每个上下文-键对回忆出相关联的值。附录C进一步将联合回忆的公式扩展到多级上下文,并提供了基于此扩展公式的理论分析。

3.3 理论分析

3.3.1 稀疏注意力的分类

上下文相关与无关的稀疏注意力。在2.1节中定义的稀疏注意力模式 $S$ 可以分为上下文相关或上下文无关,这取决于它是预先确定的还是从上下文表示中动态推断的。上下文无关的稀疏注意力(CISA)模式,如滑动窗口注意力、A形注意力和扩张注意力【索引6,Longnet: Scaling transformers to 1,000,000,000 tokens,2023】,无论上下文如何都是固定的。相比之下,上下文相关的稀疏注意力(CDSA)模式,如LSH注意力【索引20,Reformer: The efficient transformer,2020,ICLR】,会根据上下文表示进行调整。附录图8提供了这两种类别的图示。

3.3.2 SSMs的有限表达能力

SSM解决联合回忆任务的局限性。作为【索引19,Repeat after me: Transformers are better than state space models at copying,2024,ICML】中定理2.7的一个推论,我们证明解决多查询联合回忆任务要求SSM的状态维度相对于联合回忆表中的条目数 $n = n_c \times n_k$ 线性增长。设 $|U|$ 为循环状态空间 $U$ 可以编码的不同表示的数量,对于一个容量为 $b$ 比特的状态,有$|U| = 2^b$。我们定义均匀多查询联合回忆分布为所有值都从词元词汇表 $V$ 上的均匀分布中独立同分布采样。在此设置下,我们得到以下推论:

推论 3.1 (SSM的有限表达能力)。在均匀多查询联合回忆分布下,对于任何 $n$,一个在2.3节中定义的广义状态空间模型会产生至少 $1 - \frac{|U|}{|V|^n}$ 的错误率。

备注 3.1。为了保证 $Pr[err] = 0$,状态可表示的数量必须满足 $|U| \geq |V|^n$。对两边取对数,得到条件 $b \geq n \log |V|$。这意味着模型的状态空间维度必须随联合回忆表中的条目数 $n$ 线性增长,突显了SSM表示能力的一个基本限制。

3.3.3 SSM与CDSA集成后的表达能力提升

SSM集成稀疏注意力的理论结果。对于与稀疏注意力集成的SSM,我们建立了以下结果:

命题 3.2 (SSM与CDSA集成后的表达能力提升)。存在一个由SSM层后跟一个LSH注意力层组成的2层自回归混合模型,它可以在 $O(n \log^2 n)$ 的时间复杂度内,以 $O(\log n)$ 的SSM状态维度解决多查询联合回忆任务。

命题 3.3 (SSM与CISA集成后的有限表达能力)。不存在一个由SSM层后跟一个CISA层组成的2层自回归混合模型,它能以 $o(n^2)$ 的时间复杂度解决多查询联合回忆任务,因为它需要至少 $O(\frac{n_k}{k})$ 的SSM状态维度,其中 $k$ 是稀疏注意力模块中每个查询允许关注的最大键数,如公式4所定义。

备注 3.2。比较命题3.2和命题3.3,我们可以看到SSM与CDSA集成和SSM与CISA集成之间存在明显的表示能力差距。

备注 3.3。在实践中,对于一个合适的常数 $k$,将CISA与SSM集成仍然有优势:与只能访问最后一个状态表示的SSM不同,CISA层可以同时关注 $k$ 个不同的状态表示,代价是 $k$ 倍的计算预算。

完整的证明在附录B中提供。

A2 方法细节

根据第3.3节的理论分析,我们提出了一种新的架构,即带有稀疏密钥选择的局部敏感哈希注意力(HAX)。HAX通过结合我们提出的密钥选择(KS)注意力来增强LSH注意力的表达能力,并进一步与最先进的SSM集成,作为上下文相关稀疏注意力(CDSA)与SSM集成的实例,从而受益于第3.3节讨论的理论优势。本节首先在4.1节讨论LSH注意力的表达能力局限,然后在4.2节通过引入我们提出的密钥选择(KS)注意力来解决这些局限。最后,4.3节详细介绍了HAX的架构以及如何将HAX与最先进的SSM架构Mamba和Mamba2【索引14,Mamba: Linear-time sequence modeling with selective state spaces,2024,CoLM;索引4,Transformers are SSMs: Generalized models and efficient algorithms through structured state space duality,2024,arXiv】集成。


图 3: HAX架构以及与Mamba [14]集成的示意图。

4.1 LSH注意力的局限性

LSH难以捕捉“垂直条纹”模式。在大型语言模型中,某些键(特别是序列开头的键)通常会受到大多数查询的关注,形成独特的“垂直条纹”注意力模式【索引39,Analyzing the structure of attention in a transformer language model,2019,ACL Workshop BlackboxNLP】,如附录图9所示。这些被全局关注的键在指令遵循等任务中扮演着至关重要的角色,因为模型需要将其注意力集中在指令词元上【索引23,Lost in the middle: How language models use long contexts,2024,TACL】。虽然LSH是CDSA的一种实现,但它存在一个关键局限:难以捕捉“垂直条紋”注意力模式。这是因为在每个哈希轮次中,每个键都被映射到单个桶中,而注意力仅限于在同一桶内的查询和键之间发生。因此,当许多查询被迫关注有限的几个键桶时,这些桶会变得过载,从而降低表示的多样性并最终影响注意力质量。

4.2 密钥选择(KS)注意力

KS注意力的设计目标。为了解决LSH注意力在捕捉“垂直条纹”注意力模式方面的局限性,我们提出通过集成一个新的密钥选择(KS)注意力模块来增强LSH注意力。该模块旨在满足以下理想属性:
1. “垂直条纹”能力:KS注意力能够表达“垂直条纹”注意力模式。
2. 自回归兼容性:当前词元的KS注意力计算不依赖于未来的查询或键。
3. 上下文相关的稀疏性:KS稀疏注意力模式以查询和键的上下文表示为条件,并满足公式4的约束。

KS注意力的工作流程。KS注意力以查询和键矩阵 $Q$ 和 $K$ 为输入,分两个阶段运行。第一阶段是密钥评分,一个评分模块根据键本身和所有先前的查询为每个键计算一个重要性分数:

$$x_i = f_\theta(\mathbf{K}_i, \mathbf{Q}_{1 \dots i})$$

第二阶段是密钥选择:每个查询关注得分最高的 $k$ 个先前的键:

$$\mathbf{S}_{\text{KS}_{ij}} = \mathbf{1}[x_j \in \text{Top-}k\{x_1, \dots, x_i\}]$$

通过一个理想的密钥评分模块,为全局重要的键分配最高的分数,KS注意力能有效地在注意力模式中覆盖 $k$ 个“垂直条纹”。

密钥评分网络的实现。为简单起见,我们使用一个多层感知机(MLP)作为密钥评分网络:

$$f_{\theta}(\mathbf{K}_{i}, \mathbf{Q}_{1..i}) \triangleq \operatorname{MLP}(\mathbf{K}_{i}, \operatorname{normalize}(\sum_{1 \leq p \leq i} \mathbf{Q}_{p}))$$

评分MLP的训练方法。为了训练评分MLP,在每个层中,我们随机采样 $k$ 个候选键,其索引记为 $I$。我们通过一个简单的线性注意力模块计算参考注意力权重,并计算这些参考权重与预测分数之间的成对排序损失。具体来说,我们计算:

$$\mathbf{A}^{\prime}=\mathbf{Q} \mathbf{K}[\mathcal{I}]^{\top}, \quad y=\sigma\left(\mathbf{A}^{\prime}\right) \odot \mathbf{M}[\mathcal{I}],$$

其中 $K[I]$ 是选定的键表示,$M[I]$ 是限制在这些位置上的自回归掩码,$\sigma(\cdot)$ 是sigmoid函数。对于预测分数 $x \in \mathbb{R}^k$ 和目标分数 $y \in \mathbb{R}^k$,我们构建成对的logits和目标:

$$\begin{aligned} \mathbf{P}_{i j}(x)=x_{i}-x_{j}, \quad \mathbf{T}_{i j}(y)=\begin{cases}1 & \text { if } y_{i}>y_{j}, \\ 0.5 & \text { if } y_{i}=y_{j}, \\ 0 & \text { if } y_{i}<y_{j} .\end{cases} \end{aligned}$$ <p>并定义排序损失为:

$$\mathcal{L}_{\mathrm{score}}(x, y)=\frac{1}{k^{2}} \sum_{i, j} \operatorname{BCE}\left(\mathbf{P}_{i j}(x), \mathbf{T}_{i j}(y)\right),$$

其中 BCE(·, ·) 表示二元交叉熵。该目标函数鼓励评分网络为那些获得较高注意力权重的、信息丰富的键分配更高的分数。

最终训练目标。最终的训练目标是自回归语言建模损失 $L_{LM}$ 与跨所有层的排序损失之和:

$$\mathcal{L}=\mathcal{L}_{\mathrm{LM}}+\alpha \sum_{\text {layers }} \mathcal{L}_{\text {score }}$$

其中 $\alpha$ 是一个标量超参数,用于平衡排序损失的贡献。

4.3 混合块设计

HAX的组合方式。最后,我们提出了带有稀疏密钥选择的局部敏感哈希注意力(HAX),它结合了LSH和KS的注意力模式:

$$ \mathbf{S}_{\mathrm{HAX}}=\max \left\{\mathbf{S}_{\mathrm{LSH}}, \mathbf{S}_{\mathrm{KS}}\right\} \in\{0,1\}^{l \times l} $$

当 $\forall i, \|S_{\text{LSH}_i}\|_0 \le k_2, \|S_{\text{KS}_i}\|_0 \le k_2$ 时,它满足

$$\forall i,\|\mathbf{S}_{\text{HAX}_i}\|_0 \leq k$$

LSH与KS的互补性。直观上,LSH和KS注意力是互补的,各自解决了对方的局限性。LSH注意力通过随机哈希将每个查询路由到语义相似的键,提供了灵活的、基于内容的交互,这是KS注意力单独所缺乏的。相比之下,KS注意力引入了到一小组全局重要键(如指令或格式标记)的广播连接,从而恢复了LSH注意力难以表达的“垂直条纹”模式。LSH注意力促进了多样化的上下文表示,减轻了表示崩溃的风险,而KS注意力则通过将注意力权重分配给信息最丰富的位置来提高焦点,从而实现更强的长距离控制。重要的是,这两种机制本质上都是稀疏的,因此它们的组合引入了亚二次方的计算成本。

HAX与SSM的集成。图3展示了基于Mamba和Mamba2的HAX层。提出的混合稀疏注意力层通过与一个并行的稀疏注意力分支耦合,缓解了SSM的表示能力限制。一个参数化的门在融合前对稀疏注意力的输出进行重新缩放,这有助于稳定优化过程。

A4 实验环境

多查询联合回忆实验

  • 数据集:自建的多查询联合回忆数据集。上下文块数和每个上下文的键数从区间 [5, 16] 中独立采样,词汇表大小固定为 |V| = 16。数据集包含140万个训练样本,以及各1.44万个验证和测试样本。
  • 模型架构:基础架构为 Mamba 和 Mamba2,层数为2,隐藏维度为64。在此基础上评估了多种混合稀疏注意力模型,包括扩张注意力(D)、滑动窗口注意力(SW)、SW+D、A形注意力(A)、LSH注意力、本文提出的KS注意力和HAX。所有混合架构的总稀疏度参数k(每个查询关注的最大键数)固定为64。
  • 软件配置:实现基于 Flash-Linear-Attention [42]。优化器为AdamW。
  • 硬件配置:未明确说明。

自然语言持续预训练实验

  • 数据集

    • 持续预训练:使用 The Pile [11] 和 TxT-360 [36] 数据集,并筛选出长度至少为2048个词元的样本。
    • 指令微调:使用 UltraChat [7] 数据集。
  • 模型架构:基于公开发布的 Mamba 130M 和 790M checkpoints。稀疏度参数k固定为128。对于组合稀疏注意力(如HAX=LSH+KS),每个组件的k设为64。HAX仅增加了约1%的额外参数。

  • 软件配置:代码实现基于公开的Mamba代码库。优化器为AdamW。
  • 硬件配置:未明确说明。
  • 训练流程:所有模型使用2K的上下文长度进行50K步的持续预训练,随后在UltraChat上进行10K步的指令微tuning。

A4 实验结果

5.1 多查询联合回忆的实证验证

  • 实验内容:在自建的多查询联合回忆数据集上,评估了Mamba和Mamba2基础架构以及集成了多种稀疏注意力(包括CISA类的D、SW、A和CDSA类的LSH、KS、HAX)的混合模型。
  • 实验结果:如表1所示,与Mamba和Mamba2基础架构相比,大多数混合稀疏注意力模型都提升了性能。特别地,本文提出的HAX模型在两种基础架构上都始终取得了最佳性能,例如在Mamba2上将准确率从36.6%提升至74.3%。
  • 分析结论:这些结果从经验上验证了理论发现,即集成上下文相关的稀疏注意力(如HAX)能有效提升SSM解决上下文依赖的键值检索任务的能力。

表 1: 多查询联合回忆的结果。将Mamba或Mamba2与HAX集成可获得最佳性能,以粗体显示。

5.2 自然语言上的持续预训练

  • 验证损失
    • 实验内容:在对Mamba 130M模型进行持续预训练期间,追踪验证集上的语言模型损失(LLM Loss)。
    • 实验结果:如图4所示,Mamba基础模型和所有基线变体(如Mamba+D, Mamba+SW等)的验证损失要么出现训练不稳定,要么在训练早期就趋于平缓。相比之下,本文提出的Mamba+HAX是唯一一个在整个50K步训练过程中验证损失持续下降的架构。
    • 分析结论:HAX的集成显著提高了模型的训练稳定性和持续学习的能力。


图 4: 在Mamba 130M模型上进行持续预训练期间的验证LLM损失。与HAX集成的Mamba是唯一在整个训练过程中验证损失持续下降的架构。

  • Ruler和LongBench评估
    • 实验内容:将持续预训练并指令微调后的Mamba 790M模型在Ruler和LongBench两个长上下文NLP基准上进行评估。Ruler是合成基准,评估检索、多跳推理等能力;LongBench是真实世界任务基准,包括问答、摘要等。
    • 实验结果:如表2和表3所示,在所有混合稀疏注意力变体中,HAX模型是唯一在平均性能上显著优于Mamba基线的模型。在Ruler上,HAX的平均分达到44.57,高于基线的42.43。在LongBench上,HAX的平均分达到19.95,高于基线的19.79。
    • 分析结论:HAX在合成和真实世界的长上下文任务中都表现出优越的性能。

表 2: 在2K上下文长度下的Ruler基准测试结果。我们比较了与Mamba 790M模型集成的不同稀疏注意力,包括CISA方法:扩张注意力(D)、滑动窗口注意力(SW)及其组合(SW+D),以及A形注意力(A),和CDSA方法:LSH注意力,以及我们提出的密钥选择(KS)注意力和HAX。最佳平均性能以粗体显示。

表 3: LongBench英语任务。我们比较了与Mamba 790M模型集成的不同稀疏注意力,包括CISA和CDSA方法。最佳平均性能以粗体显示。

  • 外推评估
    • 实验内容:将在2K上下文长度上训练的模型,直接在4K上下文长度的Ruler基准上进行评估,以测试其外推能力。
    • 实验结果:如表4所示,HAX模型再次持续优于所有基线,平均分达到28.20,高于Mamba基线的27.04。
    • 分析结论:HAX模型展现出强大的上下文长度外推鲁棒性。

表 4: 在Ruler基准上从2K上下文长度外推到4K的性能。我们比较了与Mamba 790M模型集成的不同稀疏注意力,包括表2中的CISA和CDSA方法。最佳平均性能以粗体显示。

A7 补充细节

6.1 状态空间模型

SSM的发展历程。状态空间模型(SSM)起源于控制理论,如阻尼弹簧系统【索引29,Mamba-360: Survey of state space models as transformer alternative for long sequence modelling: Methods, applications, and challenges,2024】。HiPPO【索引15,Hippo: Recurrent memory with optimal polynomial projections,2020,NeurIPS】是早期将SSM应用于机器学习的尝试之一。LSSL【索引17,Combining recurrent, convolutional, and continuous-time models with linear state-space layers,2021,NeurIPS】在SSM框架下统一了卷积神经网络(CNN)、循环神经网络(RNN)和常微分方程(ODE),使其能在深度神经网络中实现。H3【索引10,Hungry hungry hippos: Towards language modeling with state space models,2023,ICLR】将SSM层与短卷积滤波器结合以增强序列建模能力。Mamba【索引14,Mamba: Linear-time sequence modeling with selective state spaces,2024,CoLM】通过使所有SSM参数都依赖于输入,极大地提升了SSM的表示能力。Mamba2【索引4,Transformers are SSMs: Generalized models and efficient algorithms through structured state space duality,2024,arXiv】进一步改进了架构,并建立了SSM与Transformer注意力之间的联系。

6.2 对状态空间模型的分析

SSM的局限性与表达能力分析。近期的实证研究表明,SSM在长上下文建模方面的性能通常落后于Transformer架构【索引40,An empirical study of mamba-based language models,2024】。【索引19,Repeat after me: Transformers are better than state space models at copying,2024,ICML】证明,即使是解决像复制这样的简单任务,SSM的状态维度也需要随序列长度线性增长。此外,【索引25,The illusion of state in state-space models,2024,ICML】指出SSM的表达能力受限于复杂度类TC0。【索引34,The expressive capacity of state space models: A formal language perspective,2024,NeurIPS】进一步表明,SSM和Transformer捕获了TC0中重叠但不同的子集,这为开发结合两者优点的混合模型提供了理论基础。

6.3 上下文相关的稀疏注意力

CDSA的相关工作。局部敏感哈希(LSH)注意力【索引20,Reformer: The efficient transformer,2020,ICLR】是应用最广泛的上下文相关稀疏注意力之一。后续的LSH注意力变体大致可按其分桶规则分类:argmax分桶规则被用于【索引20,Reformer: The efficient transformer,2020,ICLR;索引35,Sparse attention with learning to hash,2022,ICLR】,而sign-bit分桶规则被用于【索引8,Memoryformer: Minimize transformer computation by removing fully-connected layers,2024,NeurIPS;索引5,Hashattention: Semantic sparsity for faster inference,2025,ICML】。【索引28,Faster causal attention over large sequences through sparse flash attention,2023,NeurIPS】引入了一个Triton内核来加速基于哈希的稀疏注意力。近期,【索引43,Native sparse attention: Hardware-aligned and natively trainable sparse attention,2025,ACL】提出了native sparse attention,这是一种新的上下文相关稀疏注意力,在从头开始训练时性能优于Transformer,显示了其在高效长上下文建模中的潜力。

6.4 混合架构

结合SSM与其他机制的混合架构。一些工作探索了将大量SSM层与少量全注意力层混合的架构,并报告了超越标准Transformer的性能【索引40,An empirical study of mamba-based language models,2024】。这类混合架构的有效性已在十亿参数规模上得到进一步验证【索引21,Jamba: Hybrid transformer-mamba language models,2025,ICLR】。与此同时,研究人员也探索了混合稀疏注意力模型的设计【索引33,Sparse modular activation for efficient sequence modeling,2023,NeurIPS;索引32,Samba: Simple hybrid state space models for efficient unlimited context language modeling,2025,ICLR;索引9,Hymba: A hybrid-head architecture for small language models,2025,ICLR;索引26,Expansion span: Combining fading memory and retrieval in hybrid state space models,2024;索引44,B’MOJO: Hybrid state space realizations of foundation models with eidetic and fading memory,2024,NeurIPS】,这些模型提供了亚二次方的计算复杂度,为高效长上下文建模提供了一个有前景的方向。

A5 结论

本文引入了联合回忆,一个将关联回忆推广到上下文相关键值检索的新颖合成任务。理论上,我们证明了SSM和与上下文无关稀疏注意力(CISA)集成的SSM都无法在亚二次方时间内解决多查询联合回忆任务,而与上下文相关稀疏注意力(CDSA)集成的SSM则能够克服这一限制。基于这一理论洞见,我们提出了将最先进的SSM与一种新颖的CDSA——带有稀疏密钥选择的局部敏感哈希注意力(HAX)——相结合。实验结果证实,HAX在合成和真实世界的长上下文NLP基准测试中,均取得了改进的训练稳定性和持续的性能提升。

因此,联合回忆任务为长上下文建模提供了一个统一的理论视角和实证衡量标准,而HAX则展示了理论驱动的架构设计的力量。这些结果突显了将模型设计与表达能力提升相结合的重要性,并表明将高效序列模型与CDSA结合是实现可扩展长上下文建模的一个有前景的方向。

A6 附录


图 5: 自回归多查询联合回忆任务的输入和输出组件。输入序列进一步分为信息组件和查询组件。

A Argmax和Sign-Bit LSH分桶规则之间的关系

证明sign-bit规则是argmax规则的特例。本节展示了sign-bit LSH分桶规则(公式7)如何被解释为应用于一个具有 $2^h$ 列的扩展投影矩阵的argmax LSH分桶规则(公式6)。我们首先构建扩展矩阵,然后证明其等价性。

扩展投影矩阵。设原始随机投影为 $H = [H_1, ..., H_h] \in \mathbb{R}^{d \times h}$,$H_j \stackrel{\text{i.i.d.}}{\sim} \mathcal{N}(0, 1)$。定义一个包含 $2^h$ 个带符号原型的码本:

$$\mathcal{B} = \left\{ \mathbf{B}_{\mathbf{b}} = \sum\nolimits_{j=1}^{h} b_{j} \mathbf{H}_{j} \mid \mathbf{b} = (b_{1}, \ldots, b_{h}) \in \{-1, +1\}^{h} \right\} \subset \mathbb{R}^{d}.$$

将所有 $B_b$ 作为列堆叠起来,得到隐式矩阵 $\tilde{H} \in \mathbb{R}^{d \times 2^h}$。

两种分桶规则的等价性。对于一个归一化的查询向量 $\tilde{Q}_i$,我们定义其符号投影为 $s = \text{sign}(\tilde{Q}_i^\top H) \in \{-1, +1\}^h$。$\tilde{Q}_i$ 与一个原型 $B_b \in B$ 的内积为:

$$\langle\tilde{\mathbf{Q}}_i, \mathbf{B}_{\mathbf{b}}\rangle = \sum_{j=1}^h b_j \langle\tilde{\mathbf{Q}}_i, \mathbf{H}_j\rangle.$$

因为每一项中若 $b_j \neq s_j$,则会翻转正值 $|\langle \tilde{Q}_i, H_j \rangle|$ 的符号,所以对于所有 $b \neq s$,我们有严格不等式 $\langle \tilde{Q}_i, B_s \rangle > \langle \tilde{Q}_i, B_b \rangle$。因此:

$$\operatorname{argmax}_{\mathbf{b} \in\{-1,+1\}^{h}}\left\langle\tilde{\mathbf{Q}}_{i}, \mathbf{B}_{\mathbf{b}}\right\rangle=\mathbf{s}=\operatorname{bin}_{Q_{i}}^{(\operatorname{sign})},$$

sign-bit的分配规则恰好是应用于 $\tilde{H}$ 的argmax规则。同样的论证对键 $\tilde{K}_j$ 也成立。因此,sign-bit方法等价于具有 $2^h$ 个(扩展的)列的argmax方法。

B 理论证明

多查询联合回忆要求模型回忆一个 $n_c \times n_k$ 的上下文特定键值关联表,其中 $n_k$ 个键在 $n_c$ 个上下文中的每一个都关联不同的值,总条目数为 $n = n_c \times n_k$。为了清晰起见,我们为自回归设置下的多查询联合回忆引入一些额外的符号,如附录图5所示。输入序列被分为一个信息组件和一个查询组件。信息组件提供了上下文特定的键值关联。查询组件对信息组件中的上下文和键的顺序进行排列,模型需要根据每个指定上下文下的每个键来预测对应的值。


图 6: SSM、混合扩张注意力模型和混合局部敏感哈希(LSH)注意力模型在联合回忆上的比较。通过选择性地绕过不相关的上下文,稀疏注意力缓解了SSM层的记忆过载,并增强了混合模型检索相关信息的能力。

B.1 推论3.1的证明

推论 3.1 (SSM的有限表达能力)。在均匀多查询联合回忆分布下,对于任何 $n$,一个在2.3节中定义的广义状态空间模型会产生至少 $1 - \frac{|V|^n}{|U|}$ 的错误率。

证明的直觉。推论3.1证明背后的直觉很简单:一个模型能够准确表示的可能联合回忆数据实例的数量,从根本上受限于其循环状态的容量。由于输入序列的所有信息在处理过程中必须被编码到一个固定的循环状态中,因此可区分输出的总数受限于状态空间的表示容量 $|U|$。因此,如果输出空间的大小 $|V|^n$ 超过 $|U|$,模型不可避免地会产生不可忽略的错误。作为【索引19,Repeat after me: Transformers are better than state space models at copying,2024,ICML】中定理2.7的直接推论,我们采用其证明策略。我们将【索引19,Repeat after me: Transformers are better than state space models at copying,2024,ICML】中的引理D.1重新表述为以下引理B.1。设 $m$ 为信息组件中最后一个词元的索引。那么,对于查询组件中上下文和键的任何固定排列 $P$,以下引理B.1成立:

引理 B.1。设 $M$ 是一个固态广义SSM,它在任何固定的上下文和键的排列 $P$ 下,将联合回忆输入空间 $X$ 映射到输出空间 $V^n$。那么存在一个函数 $G: U \to V^n$,使得对于所有输入 $x \in X$,模型输出满足 $M(x) = G(U_m(x))$,$U_m$ 在公式10中定义。

遵循【索引19,Repeat after me: Transformers are better than state space models at copying,2024,ICML】,我们通过比较可能的模型状态数量与不同输入实例的数量来界定模型的错误。

证明

$$1-\operatorname{Pr}[e r r]=\operatorname{Pr}\left[\mathcal{M}(\mathbf{x})=\mathbf{y} \mid \mathbf{y} \in \mathcal{V}^n\right]$$ $$= \frac{1}{|\mathcal{V}|^n} \sum_{\mathbf{y} \in \mathcal{V}^n} \mathbb{1}[\mathcal{M}(\mathbf{x}) = \mathbf{y}]$$ $$ = \frac{1}{|\mathcal{V}|^n} \sum_{\mathbf{y} \in \mathcal{V}^n} \sum_{\mathbf{u} \in \mathcal{U}} \mathbf{1}[G(\mathbf{u})=\mathbf{y}] \cdot \mathbf{1}\left[U_m(\mathbf{x})=\mathbf{u}\right] $$ $$\le \frac{1}{|\mathcal{V}|^n} \sum_{\mathbf{u} \in \mathcal{U}} \mathbf{1}[U_m(\mathbf{x}) = \mathbf{u}]$$ $$\le \frac{|\mathcal{U}|}{|\mathcal{V}|^n}$$

为了保证 $Pr[err] = 0$,状态可表示的数量必须满足 $|U| \geq |V|^n$。对两边取对数得到条件 $b \geq n \log|V|$。这意味着模型的状态空间维度必须随联合回忆表中的条目数 $n$ 线性增长,突显了SSM表示能力的一个基本限制。相比之下,如图6所示,混合稀疏注意力模型通过允许多条并行路径传播信息来缓解这一限制,从而减轻了由顺序状态更新造成的瓶颈。

B.2 命题3.2的证明

命题 3.2 (SSM与CDSA集成后的表达能力提升)。存在一个由SSM层后跟一个LSH注意力层组成的2层自回归混合模型,它可以在 $O(n \log^2 n)$ 的时间复杂度内,以 $O(\log n)$ 的SSM状态维度解决多查询联合回忆任务。

构造性证明。我们通过构造来证明。在第一层,我们期望SSM状态将每个值词元表示与其关联的键词元表示和上下文词元表示连接起来。具体来说,我们期望在每个值词元处的SSM状态表示为:

$$[\mathbf{c}, \mathbf{k}, \mathbf{v}, is\_v] \in \mathcal{U}$$

其中 $c$ 是当前关联上下文词元的表示,$k$ 是当前关联键词元的表示,$v$ 是最近的值词元表示。is_v 是一个符号指示符(-1或1),指定当前词元是否为值词元。

为了实现这一点,我们首先将每个向量 $c, k, v$ 构造为一个不同的 $b$ 维单位范数向量且没有零元素,即 $\|c\|_2 = 1, \|k\|_2 = 1, \|v\|_2 = 1$,且对所有 $j$,$c_j \neq 0, k_j \neq 0, v_j \neq 0$。由于可以从单位球面上抽取的不同向量数量随维度呈指数增长,因此 $O(\log n)$ 的嵌入维度足以确保所有表示都是可区分的。然后我们定义一个嵌入空间,其中上下文和值词元被映射到结构化表示。具体来说,一个上下文词元被嵌入为:

$$\mathbf{e}=[\mathbf{c}, 0,0,-1]$$

其中 $c$ 是这个上下文词元在单位球面上的构造表示,最后一个坐标设为-1,表示当前词元不是一个值。类似地,一个键词元被嵌入为:

$$ \mathbf{e} = [0, \mathrm{k}, 0, -1] $$

一个值词元被嵌入为:

$$\mathbf{e}=[0,0, \mathbf{v}, 1]$$

其中 $k$ 和 $v$ 分别是来自单位球面的键和值表示,最后一个坐标仅在当前词元是值词元时设为1。遵循公式10,我们将广义SSM的更新规则定义如下:

$$\begin{aligned} \begin{aligned} U_i &= u(U_{i-1}, \mathbf{e}) = U_{i-1} \odot \mathbf{1}[\mathbf{e}_j = 0] + \mathbf{e} \odot \mathbf{1}[\mathbf{e}_j \neq 0] \\ R_i &= r(U_i) = U_i \end{aligned} \end{aligned}$$

其中 $e$ 是当前输入嵌入,$e_j$ 指其第 $j$ 维。该更新规则作为条件覆盖操作:如果一个位置不携带信息(即 $e$ 中对应的维度为0),则保留先前的状态;否则,用当前嵌入更新它。遵循此更新规则,信息组件中每个值词元处的SSM状态形式为:

$[\mathbf{c}, \mathbf{k}, \mathbf{v}, 1]$

而每个键词元处的SSM状态形式为:

$[\mathbf{c}, \mathbf{k}, ?, -1]$

在第二层,LSH注意力在SSM状态 $[c, k, v, \text{is_v}] \in U$ 上操作,使用 $[c, k, 0, \text{is_v}]$ 作为LSH注意力的键表示,$[c, k, 0, 1]$ 作为LSH注意力的查询表示,以及 $[0, 0, v, 1]$ 作为LSH注意力的值表示。这个设计确保了信息组件中的值词元和查询组件中的键词元,如果共享相同的上下文和键(即SSM状态中相同的 $c$ 和 $k$ 表示),将总是被分配到同一个哈希桶。通过足够数量的哈希桶(例如 $O(n)$),这可以使用带有 $O(\log n)$ 个随机投影向量的sign-bit分桶规则高效构建,与每个指定上下文中的每个键关联的值可以被LSH注意力可靠地检索。这一步是计算的瓶颈,时间复杂度为 $O(n \log^2 n)$。□

B.3 命题3.3的证明

命题 3.3 (SSM与CISA集成后的有限表达能力)。不存在一个由SSM层后跟一个CISA层组成的2层自回归混合模型,它能以 $o(n^2)$ 的时间复杂度解决多查询联合回忆任务,因为它需要至少 $O(\frac{n_k}{k})$ 的SSM状态维度,其中 $k$ 是稀疏注意力模块中每个查询允许关注的最大键数,如公式4所定义。

证明。考虑自回归联合回忆任务查询组件中给定的一个键。当这个键词元作为输入时,模型需要输出相关联的值词元。从这个键词元获取查询表示后,稀疏注意力最多可以关注来自先前词元的 $k$ 个键表示,其中键表示是基于第一层SSM的状态表示计算的。要解决联合回忆任务,这 $k$ 个被关注的键表示必须共同编码完整的信息组件。由于完整信息组件的长度是 $O(n)$,根据推论3.1,广义SSM在第一层的 $k$ 个状态表示必须使用至少 $O(\frac{n}{k})$ 的维度来共同存储信息组件。因此,第二层稀疏注意力所需的每个键的计算成本是 $O(k \cdot \frac{n}{k}) = O(n)$,所以总时间复杂度是 $O(n^2)$。□

比较命题3.2和命题3.3,我们看到SSM与CDSA集成和SSM与CISA集成之间存在明显的表示能力差距。然而在实践中,对于一个合适的常数 $k$,将CISA与SSM集成仍然有优势:与只能访问最后一个状态表示的SSM不同,CISA层可以同时关注 $k$ 个不同的状态表示,代价是 $k$ 倍的计算预算。


图 7: 自然语言中多级上下文的一个例子。

C 将联合回忆扩展到多级上下文

多级上下文的动机。如图7所示,在许多情况下,自然语言上下文表现出层次依赖性。这促使我们将联合回忆扩展到多级上下文设置,其中我们将键视为最后一级的上下文。

C.1 公式化

多级上下文联合回忆的定义。给定 $w$ 个不同级别的上下文词汇表 $C_1, C_2, ..., C_w$ 和词元词汇表 $V$,多级上下文联合回忆要求模型恢复映射 $C_1 \times C_2 \times ... \times C_w \to V$。多级上下文联合回忆的上下文结构类似于自然语言,具有层次性。它被分为 $|C_1|$ 个一级块,每个一级块以来自一级上下文词汇表 $C_1$ 的词元开始。每个一级块又被进一步划分为以来自二级上下文词汇表 $C_2$ 的二级上下文词元开始的子块,这种递归划分一直持续到第 $w$ 级。最后一级块由一个第 $w$ 级上下文词元后跟一个来自 $V$ 的值词元组成。注意,关联回忆是 $w=1$ 的多级联合回忆的特例,而联合回忆是 $w=2$ 的多级联合回忆的特例。我们类似地定义了多查询多级上下文联合回忆,其中模型需要回忆出完整的上下文-值超表中所有 $n = |C_1| \times |C_2| \times ... \times |C_w|$ 个条目。

C.2 SSM与CDSA在多级上下文联合回忆上的表达能力

SSM+CDSA在多级上下文任务上的能力。在多查询多级上下文联合回忆上,推论3.1和命题3.3在相同假设下仍然成立。我们现在将命题3.2扩展为以下命题C.1,它表明即使存在 $w$ 级的层次上下文,集成了CDSA的SSM仍然具有表达能力。

命题 C.1 (SSM与CDSA在多级上下文联合回忆上的表达能力)。存在一个由SSM层后跟一个LSH注意力层组成的2层自回归混合模型,它可以在 $O(wn \log^2 n)$ 的时间复杂度内,以 $O(w \log n)$ 的SSM状态维度解决多查询多级上下文联合回忆任务。

构造性证明。与命题3.2类似,我们通过构造来证明。在第一层,我们希望SSM状态由上下文和值表示组成:

$$[\mathbf{z}_{1}, \mathbf{z}_{2}, \dots, \mathbf{z}_{w}, \mathbf{v}, i s_{-} v] \in \mathcal{U}$$

其中 $z_i$ 是最近的第 $i$ 级上下文词元的表示,$v$ 表示最近的值词元,is_v 是一个符号指示符(-1或1),指定当前词元是否为值词元。

为实现此目的,我们类似地将每个向量 $z_i$ 和 $v$ 构造为一个不同的 $b$ 维单位范数向量且没有零元素,即 $\|z_i\|_2 = 1, \|v\|_2 = 1$,且对所有 $j$,$z_{ij} \neq 0, v_j \neq 0$。由于可以从单位球面上抽取的不同向量数量随维度呈指数增长,因此 $O(\log n)$ 的嵌入维度足以确保所有表示都是可区分的。因此,SSM状态的总大小为 $O(w \log n)$。

然后我们定义一个嵌入空间,其中上下文和值词元被映射到结构化表示。具体来说,一个第 $i$ 级的上下文词元被嵌入为:

$$ \mathbf{e} = [\mathbf{0}, ..., \mathbf{0}, \mathbf{z}_i, \mathbf{0}, ..., -1] $$

其中 $z_i$ 是上下文词元表示,最后一个坐标设为-1以表示该词元不是值。类似地,一个值词元被嵌入为:

$$\mathbf{e}=[0,0, \ldots, 0, \mathbf{v}, 1]$$

其中 $v$ 是值词元表示,最后一个坐标设为1以标记其为值词元。我们保持广义SSM的更新规则如公式29:

$$\begin{aligned} \begin{aligned} U_i &= u(U_{i-1}, \mathbf{e}) = U_{i-1} \odot \mathbf{1}[\mathbf{e}_j = 0] + \mathbf{e} \odot \mathbf{1}[\mathbf{e}_j \neq 0] \\ R_i &= r(U_i) = U_i \end{aligned} \end{aligned}$$

它作为条件覆盖操作。遵循此更新规则,每个值词元处的SSM状态形式为:

$$[\mathbf{z}_{1}, \mathbf{z}_{2}, \dots, \mathbf{z}_{w}, \mathbf{v}, 1]$$

其中 $z_1, z_2, ..., z_w$ 是当前词元的 $w$ 级上下文表示,$v$ 是值表示,最后一个维度设为1以表示当前词元是值词元。

在第二层,LSH注意力在SSM状态 $[z_1, z_2, ..., z_w, v, \text{is_v}]$ 上操作,使用 $[z_1, z_2, ..., z_w, 0, \text{is_v}]$ 作为LSH注意力的键表示,$[z_1, z_2, ..., z_w, 0, 1]$ 作为LSH注意力的查询表示,以及 $[0, 0, ..., v, 1]$ 作为LSH注意力的值表示。这个设计确保了共享相同上下文(即SSM状态中相同的 $[z_1, z_2, ..., z_w]$)的词元将总是被分配到同一个哈希桶。通过足够数量的哈希桶(例如 $O(n)$),这可以使用带有 $O(\log n)$ 个随机投影向量的sign-bit分桶规则高效构建,与相同键组合关联的值表示可以被LSH注意力可靠地检索。这一步是计算的瓶颈,时间复杂度为 $O(wn \log^2 n)$。

该构造证明了由广义SSM后跟LSH注意力组成的2层混合模型可以高效地解决多查询多级联合回忆任务,其时间复杂度相对于输入序列长度为亚二次方,状态复杂度为亚线性。

D 实验细节

D.1 联合回忆的实证验证

详细配置。对于所有模型,我们将层数固定为2,隐藏大小设为64,并使用 $k=64$。对于集成多个稀疏注意力组件的变体,即SW+D、A(由一个SW组件和一个只关注序列前 $k$ 个词元的sink attention组件组成)和HAX(LSH+KS),我们为每个组件分配 $k=32$,以保持全局 $k=64$。对于LSH和HAX(LSH+KS),我们采用sign-bit分桶策略(公式7),其中 $h=8$,并在每个训练步骤刷新随机哈希矩阵。为了弥补没有稀疏注意力机制的不足,我们额外评估了隐藏大小加倍至128的Mamba和Mamba2基线,以确保公平比较。此外,我们还包括一个Samba基线,由2个Mamba层和2个滑动窗口注意力层组成。对于Samba,隐藏大小和滑动窗口宽度($k$)都设为64。我们使用AdamW优化器。所有模型都使用64的批量大小训练40万步。我们的实现基于Flash-Linear-Attention [42]。

D.2 自然语言上的持续预训练

详细配置。对于所有实验,我们将稀疏性参数 $k$ 固定为128。对于集成多个稀疏注意力组件的变体,即SW+D、A和HAX(LSH+KS),我们为每个组件设置 $k=64$,以保持全局 $k=128$。对于LSH和HAX(LSH+KS),我们采用argmax分桶策略(公式6),其中 $h=k$。我们在每个训练步骤重新采样随机哈希矩阵,并在评估期间固定一个随机哈希矩阵。

训练细节。为了增强长上下文建模能力,我们筛选样本,只保留词元化长度至少为2,048的样本。在持续预训练开始时,K和Q的投影权重分别使用Mamba的B和C投影参数进行初始化,这是基于状态空间对偶性【索引4,Transformers are SSMs: Generalized models and efficient algorithms through structured state space duality,2024,arXiv】。在持续预训练期间,我们对130M模型使用余弦学习率调度,最大学习率为3e-4,对790M模型为1.5e-4,最小学习率为1e-5。在余弦调度之前,有一个200步的学习率为0的热身阶段,然后是800步的线性预热。对于指令微调,我们也应用了200步的0学习率阶段,然后是800步的线性预热,之后学习率保持在3e-6不变。对于持续预训练和指令微调,我们都使用AdamW优化器,并以64的全局批量大小和2K的上下文长度进行训练。

E 额外实验

E.1 短上下文建模基准

短上下文性能评估。我们遵循Mamba【索引14,Mamba: Linear-time sequence modeling with selective state spaces,2024,CoLM】的做法,在EleutherAI的LM评估套件基准【索引12,A framework for few-shot language model evaluation,2024】上评估持续预训练模型的零样本短上下文建模性能。我们在表5中的结果显示,在长序列上进行持续预训练不会导致短上下文基准性能的显著下降,其中未进行持续预训练的Mamba结果摘自Mamba论文【索引14,Mamba: Linear-time sequence modeling with selective state spaces,2024,CoLM】。

表 5: 持续预训练Mamba模型的LM评估套件基准。我们比较了与Mamba集成的不同稀疏注意力,包括表2中的CISA和CDSA方法。


图 8: 输入相关和输入无关稀疏注意力模式的示例。


图 9: LLM中“垂直条纹”注意力模式的一个例子。我们将维基百科词条“哈利·波特”【索引41,Harry Potter,2025】的第一段输入到Llama 3.2 1B模型【索引13,The llama 3 herd of models,2024】中,并可视化第一层最后一个头的对数注意力概率。输入文本是:“哈利·波特是由英国作家J.K.罗琳创作的七部奇幻小说系列。小说记录了年轻巫师哈利·波特和他的朋友罗恩·韦斯莱和赫敏·格兰杰在霍格沃茨魔法学校的生活。主要故事情节涉及哈利与伏地魔的冲突,伏地魔是一个意图变得不朽、推翻被称为魔法部的巫师管理机构,并征服所有巫师和麻瓜(非魔法人士)的黑巫师。”