作者/机构: Jingyang Yuan∗1,2, Huazuo Gao1, Damai Dai1, Junyu Luo2, Liang Zhao1, Zhengyan Zhang1, Zhenda Xie1, Y. X. Wei1, Lean Wang1, Zhiping Xiao3, Yuqing Wang1, Chong Ruan1, Ming Zhang2, Wenfeng Liang1, Wangding Zeng1

1DeepSeek-AI 2北京大学计算机学院多媒体信息处理重点实验室, PKU-Anker LLM实验室 3华盛顿大学 {yuanjy, mzhang_cs}@pku.edu.cn, {zengwangding, wenfeng.liang}@deepseek.com

A1 主要贡献

本文提出了一种名为NSA(Natively trainable Sparse Attention,原生可训练稀疏注意力)的机制,旨在解决长上下文语言模型中标准注意力机制计算成本高昂的挑战。
- 核心问题: 随着序列长度的增加,标准注意力机制的二次复杂度成为延迟的关键瓶颈,尤其在解码阶段,当上下文长度达到64k时,注意力计算可占总延迟的70-80%。现有的稀疏注意力方法存在两个主要问题:1) 硬件对齐的推理加速不足,理论上的计算量减少难以转化为实际的延迟降低;2) 缺乏有效的训练支持,大多在推理时应用稀疏性,导致性能下降或无法降低训练成本。
- 研究目标: 设计一种新型的稀疏注意力架构,该架构需同时满足硬件效率和端到端可训练性,以实现高效的长上下文建模,同时不牺牲模型性能。
- 核心创新点:
1. 动态分层稀疏策略: NSA采用了一种结合粗粒度令牌压缩和细粒度令牌选择的动态分层稀疏策略。它将键(keys)和值(values)组织成时间块,并通过三个并行的注意力路径处理:用于捕获粗粒度模式的压缩注意力、用于保留重要令牌块的选择性注意力,以及用于局部上下文的滑动窗口注意力。这种设计旨在同时保持全局上下文感知和局部精度。
2. 硬件对齐的系统设计: 通过平衡算术强度的算法设计,并为现代硬件(如Tensor Cores)实现专门的优化内核,将理论上的计算减少转化为实际的速度提升。这确保了在训练、预填充和解码等各个阶段都能获得显著的加速。
3. 训练感知的设计: NSA通过高效的算法和反向传播算子实现了稳定的端到端训练。这使得模型能够在预训练阶段就学习稀疏模式,从而降低了预训练的计算成本,而不会牺牲模型在通用基准、长上下文任务和指令推理等方面的性能。

图1 | 全注意力模型与我们的NSA在性能和效率上的比较。左图:尽管是稀疏的,NSA在通用基准、长上下文任务和推理评估中的平均表现超过了全注意力基线。右图:对于64k长度的序列处理,NSA在解码、前向传播和反向传播所有阶段都比全注意力实现了显著的计算加速。
图1 | 全注意力模型与我们的NSA在性能和效率上的比较。左图:尽管是稀疏的,NSA在通用基准、长上下文任务和推理评估中的平均表现超过了全注意力基线。右图:对于64k长度的序列处理,NSA在解码、前向传播和反向传播所有阶段都比全注意力实现了显著的计算加速。

A3 对稀疏注意力方法的反思

现代稀疏注意力方法在降低Transformer模型的理论计算复杂度方面取得了显著进展。然而,大多数方法主要在推理阶段应用稀疏性,同时保留预训练的全注意力主干,这可能会引入架构偏差,限制其充分利用稀疏注意力的优势。在介绍我们的原生稀疏架构之前,我们从两个关键角度系统地分析了这些局限性。

2.1 高效推理的幻象

稀疏注意力方法在实际推理中面临的延迟挑战 尽管在注意力计算中实现了稀疏性,但许多方法未能相应地减少推理延迟,这主要源于两个挑战:

阶段限制的稀疏性 像H2O【索引43,H2o: Heavy-hitter oracle for efficient generative inference of large language models,NeurIPS 2023】这样的方法在自回归解码期间应用稀疏性,但在预填充阶段需要计算密集型的预处理(例如,注意力图计算、索引构建)。相反,像MInference【索引21,Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention,arXiv 2024】这样的方法仅关注预填充阶段的稀疏性。这些方法未能实现所有推理阶段的加速,因为至少有一个阶段的计算成本与全注意力相当。这种阶段专业化降低了这些方法在以预填充为主的工作负载(如书籍摘要和代码补全)或以解码为主的工作负载(如长链式思维【索引34,Chainof-thought prompting elicits reasoning in large language models,NeurIPS 2022】推理)中的加速能力。

与先进注意力架构的不兼容性 一些稀疏注意力方法未能适应现代高效解码架构,如多查询注意力(MQA)【索引28,Fast transformer decoding: One write-head is all you need,CoRR 2019】和分组查询注意力(GQA)【索引1,Gqa: Training generalized multi-query transformer models from multi-head checkpoints,arXiv 2023】,这些架构通过在多个查询头之间共享KV显著减少了解码期间的内存访问瓶颈。例如,在像Quest【索引30,Quest: Query-aware sparsity for efficient long-context llm inference,arXiv 2024】这样的方法中,每个注意力头独立选择其KV缓存子集。尽管它在多头注意力(MHA)模型中表现出一致的计算稀疏性和内存访问稀疏性,但在基于GQA等架构的模型中情况则不同,其中KV缓存的内存访问量对应于同一GQA组内所有查询头选择的并集。这种架构特性意味着,虽然这些方法可以减少计算操作,但所需的KV缓存内存访问仍然相对较高。这一限制迫使我们做出关键选择:虽然一些稀疏注意力方法减少了计算量,但它们分散的内存访问模式与先进架构的高效内存访问设计相冲突。

现有方法的局限性与设计动机 这些局限性的产生是因为许多现有的稀疏注意力方法专注于KV缓存的缩减或理论计算的减少,但在先进的框架或后端中难以实现显著的延迟降低。这激励我们开发结合先进架构和硬件高效实现的算法,以充分利用稀疏性来提高模型效率。

2.2 可训练稀疏性的迷思

原生可训练稀疏注意力的动机 我们对原生可训练稀疏注意力的追求源于对仅推理方法的两个关键洞察:(1)性能下降:事后应用稀疏性迫使模型偏离其预训练的优化轨迹。正如Chen等人【索引7,Magicpig: Lsh sampling for efficient llm generation,arXiv 2024b】所证明的,前20%的注意力只能覆盖总注意力分数的70%,这使得预训练模型中的检索头等结构在推理过程中容易被剪枝。(2)训练效率需求:高效处理长序列训练对于现代LLM开发至关重要。这包括在更长的文档上进行预训练以增强模型能力,以及后续的适应阶段,如长上下文微调和强化学习。然而,现有的稀疏注意力方法主要针对推理,使得训练中的计算挑战在很大程度上未得到解决。这一限制阻碍了通过高效训练开发更强大的长上下文模型。此外,尝试将现有稀疏注意力用于训练也暴露出一些挑战:

不可训练组件 像ClusterKV【索引25,Clusterkv: Manipulating llm kv cache in semantic space for recallable compression,arXiv 2024】(包括k-means聚类)和MagicPIG【索引7,Magicpig: Lsh sampling for efficient llm generation,arXiv 2024b】(包括基于SimHash的选择)等方法中的离散操作在计算图中造成了不连续性。这些不可训练的组件阻碍了梯度流通过令牌选择过程,限制了模型学习最佳稀疏模式的能力。

低效的反向传播 一些理论上可训练的稀疏注意力方法在实际训练中效率低下。像HashAttention【索引12,Hashattention: Semantic sparsity for faster inference,arXiv 2024】这样的方法中使用的令牌粒度选择策略,导致在注意力计算期间需要从KV缓存中加载大量单个令牌。这种非连续的内存访问妨碍了像FlashAttention这样依赖连续内存访问和块状计算以实现高吞吐量的快速注意力技术的高效应用。结果,实现被迫退回到低硬件利用率的方案,显著降低了训练效率。

2.3 原生稀疏性势在必行

NSA的设计动机 这些在推理效率和训练可行性方面的局限性,促使我们从根本上重新设计稀疏注意力机制。我们提出了NSA,一个原生的稀疏注意力框架,同时解决了计算效率和训练需求。在接下来的章节中,我们将详细介绍NSA的算法设计和算子实现。

A2 方法细节

我们的技术方法涵盖了算法设计和核函数优化。在接下来的子章节中,我们首先介绍我们方法的背景知识,然后介绍NSA的整体框架及其关键算法组件。最后,我们将详细介绍我们为最大化实际效率而进行的硬件优化核函数设计。

3.1 背景

注意力机制 在语言建模中被广泛使用,其中每个查询令牌 $q_i$ 计算与所有先前键 $k_{1:i}$ 的相关性分数,以生成值 $v_{1:i}$ 的加权和。形式上,对于长度为 $i$ 的输入序列,注意力操作定义为:

公式1
公式1

其中Attn表示注意力函数:

公式2
公式2

这里,$A_{i,j}$ 表示 $q_i$ 和 $k_j$ 之间的注意力权重,$d_k$ 是键的特征维度。随着序列长度的增加,注意力计算在总计算成本中变得越来越占主导地位,给长上下文处理带来了重大挑战。

算术强度 是计算操作与内存访问的比率。它内在地影响着硬件上的算法优化。每个GPU都有一个由其峰值计算能力和内存带宽决定的临界算术强度,计算为这两个硬件限制的比率。对于计算任务,算术强度高于此临界阈值则变为计算密集型(受GPU FLOPS限制),而低于此阈值则变为内存密集型(受内存带宽限制)。

因果自注意力机制的算术强度特性 对于因果自注意力机制,在训练和预填充阶段,批处理的矩阵乘法和注意力计算表现出高算术强度,使得这些阶段在现代加速器上是计算密集型的。相比之下,自回归解码则受内存带宽限制,因为它每次前向传播只生成一个令牌,却需要加载整个键值缓存,导致算术强度低。

不同阶段的优化目标 这导致了不同的优化目标——在训练和预填充期间降低计算成本,而在解码期间减少内存访问。

3.2 整体框架

引入紧凑的键值对表示 为了利用具有自然稀疏模式的注意力的潜力,我们建议将公式(1)中原始的键值对 $k_{1:i}, v_{1:i}$ 替换为更紧凑且信息密集的表示键值对 $\tilde{k}_i, \tilde{v}_i$,这是针对每个查询 $q_i$ 给定的。具体来说,我们正式定义优化的注意力输出如下:

公式3
公式3

公式4
公式4

其中 $\tilde{k}_i, \tilde{v}_i$ 是根据当前查询 $q_i$ 和上下文内存 $k_{1:i}, v_{1:i}$ 动态构建的。我们可以设计各种映射策略来获得不同类别的 $\tilde{k}_{i,c}, \tilde{v}_{i,c}$,并按如下方式组合它们:

公式5
公式5

NSA的三种映射策略 如图2所示,NSA有三种映射策略 $C = \{cmp, slc, win\}$,分别代表键和值的压缩、选择和滑动窗口。$\sigma_c \in [0, 1]$ 是相应策略 $c$ 的门控分数,通过一个MLP和sigmoid激活函数从输入特征中导出。让 $N_i$ 表示重映射的键/值的总数:

公式6
公式6

我们通过确保 $N_i \ll i$ 来维持一个高稀疏率。

图2 | NSA架构概览。左图:该框架通过三个并行的注意力分支处理输入序列:对于给定的查询,前面的键和值被处理成用于粗粒度模式的压缩注意力、用于重要令牌块的选择性注意力,以及用于局部上下文的滑动注意力。右图:每个分支产生的不同注意力模式的可视化。绿色区域表示需要计算注意力分数的区域,而白色区域表示可以跳过的区域。
图2 | NSA架构概览。左图:该框架通过三个并行的注意力分支处理输入序列:对于给定的查询,前面的键和值被处理成用于粗粒度模式的压缩注意力、用于重要令牌块的选择性注意力,以及用于局部上下文的滑动注意力。右图:每个分支产生的不同注意力模式的可视化。绿色区域表示需要计算注意力分数的区域,而白色区域表示可以跳过的区域。

3.3 算法设计

本小节中,我们将介绍我们的重映射策略 $\mathcal{M}_k$ 和 $\mathcal{M}_v$ 的设计:令牌压缩、令牌选择和滑动窗口。

3.3.1 令牌压缩

块级表示的生成 通过将连续的键或值块聚合成块级表示,我们获得了能够捕捉整个块信息的压缩键和值。形式上,压缩键表示定义为:

公式7
公式7

其中 $L$ 是块长度,$S$ 是相邻块之间的滑动步长,而 $F$ 是一个可学习的MLP,带有块内位置编码,用于将一个块内的键映射为单个压缩键。$\tilde{k}_{\text{cmp}, i} \in \mathbb{R}^{d_k \times \lfloor \frac{i-L}{S} \rfloor}$ 是由压缩键组成的张量。通常,我们采用 $S < L$ 来减轻信息碎片化。对于压缩值表示 $\tilde{v}_{\text{cmp}, i}$ 也存在类似的公式。压缩表示捕捉了更粗粒度的、更高层次的语义信息,并减少了注意力的计算负担。

3.3.2 令牌选择

细粒度信息保留 仅使用压缩的键和值可能会丢失重要的细粒度信息,这促使我们选择性地保留单个键和值。下面我们描述我们高效的令牌选择机制,该机制以低计算开销识别并保留最相关的令牌。

块状选择 我们的选择策略以空间上连续的块处理键和值序列,这受两个关键因素驱动:硬件效率考虑和注意力分数的内在分布模式。块状选择对于在现代GPU上实现高效计算至关重要。这是因为现代GPU架构对于连续块访问的吞吐量远高于基于随机索引的读取。此外,块状计算能够实现Tensor Cores的最佳利用。这种架构特性已使块状内存访问和计算成为高性能注意力实现的基本原则,如FlashAttention的基于块的设计所例证。块状选择遵循了注意力分数的内在分布模式。先前的工作【索引21,Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention,arXiv 2024】表明,注意力分数通常表现出空间连续性,表明相邻的键倾向于具有相似的重要性水平。我们在第6.2节的可视化也显示了这种空间连续模式。

块状选择的实现 为了实现块状选择,我们首先将键、值序列划分为选择块。为了识别注意力计算中最重要的块,我们需要为每个块分配重要性分数。下面我们介绍计算这些块级重要性分数的方法。

重要性分数计算 计算块重要性分数可能会引入显著的开销。幸运的是,压缩令牌的注意力计算会产生中间注意力分数,我们可以利用这些分数来推导出选择块的重要性分数,公式如下:

公式8
公式8

其中 $p_{\text{cmp},i} \in \mathbb{R}^{\lfloor \frac{i-L}{S} \rfloor+1}$ 是 $q_i$ 和压缩键 $\tilde{k}_{\text{cmp},i}$ 之间的注意力分数。令 $L'$ 表示选择块的大小。当压缩块和选择块共享相同的分块方案时,即 $L' = L = S$,我们可以直接通过 $p_{\text{slc},i} = p_{\text{cmp},i}$ 得到选择块的重要性分数 $p_{\text{slc},i}$。对于分块方案不同的情况,我们根据它们的空间关系推导选择块的重要性分数。给定 $S \le L', L | S$ 和 $S | L'$,我们有:

公式9
公式9

其中$[\cdot]$表示访问向量元素的索引操作符。对于采用GQA或MQA的模型,其中键值缓存在查询头之间共享,必须确保这些头之间的块选择一致,以最小化解码期间的KV缓存加载。组内各头共享的重要性分数正式定义为:

公式10
公式10

其中上标 $(h)$ 表示头的索引,$h_g$ 是每个组中的查询头数量。这种聚合确保了同一组内各头之间的块选择一致。

Top-k块选择 在获得选择块的重要性分数后,我们保留按块重要性分数降序排列的前k个稀疏块内的令牌,公式如下:

公式11
公式11

公式12
公式12

其中 $\text{rank}(\cdot)$ 表示降序排列中的排名位置,$\text{rank}=1$ 对应最高分,$\mathcal{I}_i$ 是所选块的索引集合,Cat表示拼接操作。$\tilde{k}_{\text{slc},i} \in \mathbb{R}^{d_k \times k L'}$ 是由压缩键组成的张量。对于细粒度值 $\tilde{v}_{\text{slc},i}$ 也适用类似的公式。然后,选择的键和值与 $q_i$ 一起参与注意力计算,如公式(5)所定义。

3.3.3 滑动窗口

解决局部模式主导问题 在注意力机制中,局部模式通常适应得更快,并可能主导学习过程,从而可能阻止模型有效地从压缩和选择令牌中学习。为了解决这个问题,我们引入了一个专门的滑动窗口分支,明确处理局部上下文,让其他分支(压缩和选择)专注于学习它们各自的特征,而不会被局部模式所捷径化。具体来说,我们在一个大小为 $W$ 的窗口中保留最近的令牌 $\tilde{k}_{\text{win},i} = k_{i-W:i}, \tilde{v}_{\text{win},i} = v_{i-W:i}$,并将不同信息源(压缩令牌、选择令牌、滑动窗口)的注意力计算隔离到独立的分支中。然后,这些分支的输出通过一个学习到的门控机制进行聚合。为了在边际计算开销下进一步防止跨注意力分支的捷径学习,我们为三个分支提供了独立的键和值。这种架构设计通过防止局部和长距离模式识别之间的梯度干扰,实现了稳定的学习,同时引入了最小的开销。

最终注意力输出 在获得所有三类键和值($\tilde{k}_{\text{cmp},i}, \tilde{v}_{\text{cmp},i}; \tilde{k}_{\text{slc},i}, \tilde{v}_{\text{slc},i};$ 和 $\tilde{k}_{\text{win},i}, \tilde{v}_{\text{win},i}$)之后,我们按照公式(5)计算最终的注意力输出。结合上述的压缩、选择和滑动窗口机制,这就构成了NSA完整的算法框架。

3.4 核函数设计

为训练和预填充实现FlashAttention级别的加速 为了在训练和预填充期间实现FlashAttention级别的加速,我们基于Triton实现了硬件对齐的稀疏注意力核函数。考虑到MHA对于解码是内存密集型且效率低下,我们遵循当前最先进的LLM,专注于具有共享KV缓存的架构,如GQA和MQA。虽然压缩和滑动窗口注意力的计算可以轻松与现有的FlashAttention-2核函数兼容,但我们为稀疏选择注意力引入了专门的核函数设计。如果我们遵循FlashAttention的策略,将时间上连续的查询块加载到SRAM中,将会导致低效的内存访问,因为一个块内的查询可能需要不相交的KV块。为了解决这个问题,我们的关键优化在于采用不同的查询分组策略:对于查询序列上的每个位置,我们将一个GQA组内的所有查询头(它们共享相同的稀疏KV块)加载到SRAM中。图3展示了我们前向传播的实现。

提出的核函数架构特点 提出的核函数架构具有以下关键特点:
1. 以组为中心的数据加载:对于每个内循环,加载位置 $i$ 处组内所有头的查询 $Q \in \mathbb{R}^{[h,d_q]}$ 及其共享的稀疏键/值块索引 $\mathcal{I}_i$。
2. 共享KV获取:在内循环中,按顺序将由 $\mathcal{I}_i$ 索引的连续键/值块加载到SRAM中,作为 $K \in \mathbb{R}^{[B_k,d_k]}, V \in \mathbb{R}^{[B_k,d_v]}$,以最小化内存加载,其中 $B_k$ 是满足 $B_k | L'$ 的核函数块大小。
3. 网格上的外循环:由于内循环长度(与所选块数量 $k$ 成比例)对于不同的查询块几乎相同,我们将查询/输出循环放在Triton的网格调度器中,以简化和优化核函数。

实现近乎最优的算术强度 这种设计通过以下方式实现了近乎最优的算术强度:(1)通过组级别的共享消除冗余的KV传输,以及(2)在GPU流式多处理器之间平衡计算工作负载。

图3 | NSA的核函数设计。该核函数按GQA组加载查询(网格循环),获取相应的稀疏KV块(内循环),并在SRAM上执行注意力计算。绿色块表示SRAM上的数据,蓝色块表示HBM上的数据。
图3 | NSA的核函数设计。该核函数按GQA组加载查询(网格循环),获取相应的稀疏KV块(内循环),并在SRAM上执行注意力计算。绿色块表示SRAM上的数据,蓝色块表示HBM上的数据。

A4 实验环境

  • 模型架构: 实验采用了一个包含分组查询注意力(GQA)和专家混合(MoE)的主干网络,总参数量为27B,其中活跃参数为3B。模型包含30个层,隐藏维度为2560。
  • 模型关键参数:
    • GQA: 设置为4个组,共64个注意力头。每个头的查询、键和值的隐藏维度分别为 $d_q=d_k=192$ 和 $d_v=128$。
    • MoE: 采用DeepSeekMoE结构【索引9, Deepseekmoe: Towards ultimate expert specialization in mixture-of-experts language models,arXiv 2024; 索引10, Deepseek-v2: A strong, economical, and efficient mixture-of-experts language model,arXiv 2024】,包含72个路由专家和2个共享专家,top-k专家设为6。为保证训练稳定,第一层的MoE被替换为SwiGLU形式的MLP。
    • NSA: 压缩块大小 $L=32$,滑动步长 $S=16$,选择块大小 $L'=64$,选择块数量 $k=16$(包括固定激活1个初始块和2个局部块),滑动窗口大小 $W=512$。
  • 数据集与训练过程: 全注意力和稀疏注意力模型都在270B个8k长度文本的token上进行预训练,随后使用YaRN【索引27, Yarn: Efficient context window extension of large language models,ICLR 2024】在32k长度文本上进行持续训练和监督微调,以实现长上下文适应。
  • 硬件配置: 实验在一个8卡A100 GPU系统上进行。
  • 软件配置: 使用Triton【索引31, Triton: an intermediate language and compiler for tiled neural network computations,MLPL 2019】实现优化的注意力核函数。

A5 实验结果

图4 | 27B参数模型上全注意力与NSA的预训练损失对比。两个模型都表现出稳定的收敛,NSA取得了更低的损失值。
图4 | 27B参数模型上全注意力与NSA的预训练损失对比。两个模型都表现出稳定的收敛,NSA取得了更低的损失值。

预训练稳定性:如图4所示,NSA和全注意力基线模型的预训练损失曲线都呈现稳定平滑的下降,且NSA的损失始终低于全注意力模型,表明了其训练的稳定性和有效性。

通用基准评估
- 实验内容: 在知识(MMLU, MMLU-PRO, CMMLU)、推理(BBH, GSM8K, MATH, DROP)和编码(MBPP, HumanEval)等一系列基准上评估预训练的NSA和全注意力模型。
- 实验结果: 如表1所示,尽管NSA是稀疏的,但其总体性能优越,在9个指标中的7个上超过了全注意力基线。特别是在与推理相关的基准测试中(DROP: +0.042, GSM8K: +0.034),NSA表现出显著优势。
- 分析结论: NSA的预训练机制迫使模型关注最重要的信息,可能通过过滤不相关注意力路径的噪声来提高性能。这证实了NSA作为一个通用架构的鲁棒性。

表1 | 全注意力基线与NSA在通用基准上的预训练性能比较,涵盖知识(MMLU, MMLU-PRO, CMMLU)、推理(BBH, GSM8K, MATH, DROP)和编码(MBPP, HumanEval)任务。尽管稀疏度很高,NSA在大多数基准上取得了更优的平均性能。

表1
表1

长上下文评估
- 大海捞针测试 (Needle-in-a-Haystack): 如图5所示,NSA在64k上下文长度下,对所有位置的检索准确率均达到100%。这得益于其分层稀疏设计,即压缩令牌用于全局上下文扫描,选择令牌用于精确局部信息检索。
- LongBench评估: 如表2所示,NSA在LongBench上取得了0.469的最高平均分,超过了所有基线(比全注意力高0.032,比Exact-Top高0.046)。在需要复杂长上下文推理的任务上表现尤为突出,如多跳问答(HPQ, 2Wiki)、代码理解(LCC)和段落检索(PassR-en)。
- 分析结论: NSA的原生稀疏设计使其能够在预训练期间端到端优化稀疏模式,实现了局部和全局信息处理的平衡,从而在各种长上下文挑战中表现出色。

图5 | 在64k上下文长度下,大海捞针测试中不同上下文位置的检索准确率。NSA通过其分层稀疏注意力设计实现了完美的准确率。
图5 | 在64k上下文长度下,大海捞针测试中不同上下文位置的检索准确率。NSA通过其分层稀疏注意力设计实现了完美的准确率。

表2 | NSA与基线模型在LongBench上的性能比较,包括单文档问答、多文档问答、合成任务和代码任务等子集。NSA的表现优于包括全注意力在内的大多数基线。

表2
表2

思维链推理评估

  • 实验内容: 通过对来自DeepSeek-R1的10B token的32k长度数学推理轨迹进行监督微调(SFT),训练了NSA-R和Full Attention-R两个模型,并在AIME 24基准上进行评估。
  • 实验结果: 如表3所示,在8k上下文设置下,NSA-R的准确率显著高于Full Attention-R(+0.075),这一优势在16k上下文中依然存在(+0.054)。
  • 分析结论: 原生稀疏注意力能够有效捕捉对复杂数学推导至关重要的长距离逻辑依赖,并且其硬件对齐设计支持不断增长的推理深度,验证了其在高级推理任务中的可行性。

表3 | 经过监督微调后在AIME指令式评估中的表现。我们的NSA-R在8k和16k序列长度上均表现出比Full Attention-R更好的性能。

表3
表3

效率分析
- 训练速度: 如图6所示,与基于Triton的FlashAttention-2相比,NSA的Triton实现随着上下文长度的增加,速度优势越发明显,在64k上下文长度下,前向传播加速达9.0倍,后向传播加速达6.0倍。
- 解码速度: 如表4所示,NSA显著减少了KV缓存的加载量,在64k上下文长度下实现了高达11.6倍的解码速度提升。
- 分析结论: NSA的硬件对齐算法设计(如块状内存访问和优化的循环调度)以及对内存访问的有效减少,是其在训练和解码阶段实现显著加速的关键。

图6 | 基于Triton的NSA核函数与基于Triton的FlashAttention-2核函数的比较。我们的实现在所有上下文长度上都显著降低了延迟,且随着输入长度的增加,提升效果更为明显。
图6 | 基于Triton的NSA核函数与基于Triton的FlashAttention-2核函数的比较。我们的实现在所有上下文长度上都显著降低了延迟,且随着输入长度的增加,提升效果更为明显。

表4 | 解码期间每次注意力操作的内存访问量(以等效token数量计)。由于解码的算术强度低且受内存限制,预期的加速比与内存访问量近似成线性关系。

表4
表4

A5 结论

本文提出了NSA,一种为高效长上下文建模而设计的硬件对齐稀疏注意力架构。通过将分层令牌压缩与块状令牌选择集成到一个可训练的架构中,NSA在保持全注意力性能的同时,实现了训练和推理的加速。实验结果表明,NSA在通用基准测试中性能与全注意力基线相当,在长上下文评估中展现了更强的建模能力,并提升了推理能力,同时显著降低了计算延迟,实现了显著的加速效果。

A6 附录

讨论

在本节中,我们回顾NSA的开发过程,并讨论从探索不同稀疏注意力策略中获得的关键见解。虽然我们的方法展示了有希望的结果,但了解替代策略遇到的挑战和分析注意力模式为未来的研究方向提供了宝贵的背景。我们首先审视了激发我们设计选择的替代令牌选择策略的挑战,然后通过可视化提供了对注意力分布模式的见解。

6.1 替代令牌选择策略的挑战

NSA设计前的探索 在设计NSA之前,我们探索了将现有的稀疏注意力方法应用于训练阶段。然而,这些尝试遇到了各种挑战,促使我们设计一种不同的稀疏注意力架构:

基于键聚类的策略 我们研究了像ClusterKV【索引25,Clusterkv: Manipulating llm kv cache in semantic space for recallable compression,arXiv 2024】这样的基于聚类的策略。这些方法将来自同一聚类的键和值存储在连续的内存区域中。虽然理论上可用于训练和推理,但它们面临三个重大挑战:(1)动态聚类机制引入了不可忽略的计算开销;(2)算子优化困难,因簇间不平衡而加剧,尤其是在专家混合(MoE)系统中,其中倾斜的专家并行(EP)组执行时间导致持续的负载不平衡;(3)实现上的限制,源于需要强制性的周期性重聚类和块序列训练协议。这些综合因素造成了巨大的瓶颈,严重限制了它们在实际部署中的有效性。

图7 | 在3B参数模型上,比较全注意力与不同令牌选择策略的训练损失。我们的NSA取得了更好的性能。
图7 | 在3B参数模型上,比较全注意力与不同令牌选择策略的训练损失。我们的NSA取得了更好的性能。

其他块状选择策略 我们还考虑了不同于NSA的块状键、值选择策略,例如Quest【索引30,Quest: Query-aware sparsity for efficient long-context llm inference,arXiv 2024】和InfLLM【索引36,Infllm: Training-free long-context extrapolation for llms with an efficient context memory,NeurIPS 2024a】。这些方法依赖于为每个KV块计算一个重要性分数,并根据它们与 $q_i$ 的相似性选择前k个块。然而,现有方法面临两个关键问题:(1)由于选择操作是不可微的,基于神经网络的重要性分数计算依赖于辅助损失,这增加了算子开销并常常降低模型性能;(2)启发式的无参数重要性分数计算策略存在低召回率的问题,导致性能次优。我们在一个具有相似架构的3B参数模型上评估了这两种方法,并将其损失曲线与NSA和全注意力进行比较。对于基于辅助损失的选择方法,我们为每个令牌引入额外的查询,为每个块引入代表性的键来估计块的重要性分数。我们通过对每个键块内的注意力分数进行平均池化来计算块级监督信号,并使用KL散度来监督块重要性预测。我们保持单个查询的粒度而不是块平均查询,以适应高效的解码。这种基于辅助损失的重要性估计在概念上与SeerAttention【索引16,Seerattention: Learning intrinsic sparse attention in your llms,arXiv 2024】有相似之处。对于启发式的无参数选择方法,我们遵循Quest的策略,使用查询与键块的坐标级最小-最大值的乘积进行直接选择,而不引入额外参数。我们还探索了一种冷启动训练方法,即在初始1000步应用全注意力,然后转换到启发式块状选择。如图7所示,两种方法的损失都较差。

6.2 可视化

注意力分布模式的探索 为了探索Transformer注意力分布中的潜在模式并为我们的设计寻求灵感,我们在图8中可视化了我们预训练的27B全注意力模型的注意力图。可视化揭示了一些有趣的模式,其中注意力分数倾向于表现出块状聚类的特征,即邻近的键通常显示出相似的注意力分数。这一观察启发了我们NSA的设计,表明基于空间连续性选择键块可能是一个有前途的方法。块状聚类现象表明,序列中相邻的令牌可能与查询令牌共享某些语义关系,尽管这些关系的确切性质需要进一步研究。这一观察促使我们探索一种在连续令牌块而不是单个令牌上操作的稀疏注意力机制,旨在提高计算效率并保留高注意力模式。

图8 | 全注意力Transformer上的注意力图可视化。浅色区域表示较高的注意力值。如图所示,注意力分数表现出块状聚类的分布。
图8 | 全注意力Transformer上的注意力图可视化。浅色区域表示较高的注意力值。如图所示,注意力分数表现出块状聚类的分布。

A. AIME 结果示例

问题1提示:
“|用户|设 $a, b, c$ 为正实数,满足以下方程组:

问题1方程组
问题1方程组

那么 $| \log_2 (a^4 b^3 c^2) |$ 的值为 $m/n$,其中 $m$ 和 $n$ 是互质的正整数。求 $m + n$。请逐步推理,并将最终答案放在方框内。|助手|”

我们的结果:
⟨思考⟩...⟨/思考⟩ (2275 tokens)
将对数方程转换为指数形式,得到 $b^a = 2, c^b = 2^{1/3}, a^c = 2^{1/4}$。通过代换求解,得到 $c = 2^{-5/12}, b = 2^{-3/8}, a = 2^{-7/24}$。接着计算 $a^4 b^3 c^2 = (2^{-7/24})^4 (2^{-3/8})^3 (2^{-5/12})^2 = 2^{-7/6} \cdot 2^{-9/8} \cdot 2^{-5/6} = 2^{-25/8}$。取以2为底的对数,得到 $\log_2(a^4 b^3 c^2) = -25/8$。其绝对值为 $25/8$。因此,$m=25, n=8$, $m+n=33$。

基线模型结果:
⟨思考⟩...⟨/思考⟩ (9392 tokens)
模型首先将对数方程转为指数形式,但求解过程出现错误,并错误地判断方程组可能无解。随后,模型转换思路,令 $x = \log_2(a), y = \log_2(b), z = \log_2(c)$,将原方程组转换为线性方程组。通过矩阵求逆和克莱默法则求解,得到 $x = -7/24, y = -9/24, z = -10/24$(此处计算有误)。最终计算 $\log_2(a^4 b^3 c^2) = 4x+3y+2z = -115/16$。因此,$m=115, n=16$, $m+n=131$。(注:基线模型推理过程冗长且最终答案错误)。


问题2提示:
“|用户|实数 $x$ 和 $y$ 满足 $x, y > 1$,且 $\log_x(x^2 y^3) = \log_y(x^4 y) = 10$。求 $xy$ 的值。请逐步推理,并将最终答案放在方框内。|助手|”

我们的结果:
⟨思考⟩...⟨/思考⟩ (15147 tokens)
将两个对数方程展开:
1. $\log_x(x^2 y^3) = 2 + 3\log_x(y) = 10 \implies \log_x(y) = 8/3$
2. $\log_y(x^4 y) = 4\log_y(x) + 1 = 10 \implies \log_y(x) = 9/4$
令 $u = \log_x(y)$,则 $\log_y(x) = 1/u$。我们有 $u=8/3$ 和 $1/u=9/4$,这两个方程不一致,说明推理有误。
模型重新尝试,令 $x = u^k, y = u^m$。但最终采用换元法,令 $u = \log_x(y)$。从第一个方程得到 $u = 8/3$。从第二个方程得到 $1/u = 9/4$。模型在此处卡住,但最终通过另一个冗长的推导得出 $x = 10y$ 和 $y = 5/2x$(此步骤有误),最终计算 $xy=25$。

基线模型结果:
⟨思考⟩...⟨/思考⟩ (16223 tokens)
模型使用换底公式进行转换,但过程复杂且引入了不必要的 ln 函数。通过一系列复杂的代数变换,模型最终得到 $y = x^{4/9}$,并代入第一个方程,得到 $\log_x(x^2 \cdot (x^{4/9})^3) = 10$,解得 $x = 2^{9/2}$。随后计算 $y = (2^{9/2})^{4/9} = 2^2=4$。最终 $xy = 2^{9/2} \cdot 4 = 16\sqrt{2} \cdot 4 = 64\sqrt{2}$。(注:基线模型推理过程非常冗长,且最终答案错误)。