SeerAttention: Self-distilled Attention Gating for Efficient Long-context Prefilling

  • 作者/机构: Yizhao Gao (University of Hong Kong), Zhichen Zeng (University of Washington), Dayou Du (University of Edinburgh), Shijie Cao (Microsoft Research), Peiyuan Zhou (NVIDIA), Jiaxing Qi (NVIDIA), Junjie Lai (NVIDIA), Hayden Kwok-Hay So (University of Hong Kong), Ting Cao (Tsinghua University), Fan Yang (Microsoft Research), Mao Yang (Microsoft Research)

A1 主要贡献

本文旨在解决现代大语言模型(LLM)中注意力机制的二次复杂度问题,该问题尤其在处理长上下文时严重制约了模型的效率和可扩展性。现有利用注意力稀疏性的方法大多依赖于预定义的模式或启发式规则,难以高效地动态适应不同上下文。

核心问题与研究目标
当前主流LLM依赖的全注意力机制计算成本高昂。虽然注意力图中普遍存在稀疏性,尤其是在长上下文中,但这种稀疏性是动态变化的,随模型、输入上下文和注意力头的不同而异。因此,依赖固定模式或启发式规则的稀疏化方法【索引28, MInference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention, 2024】、【索引18, Moa: Mixture of sparse attention for automatic large language model compression, 2024】、【索引54, Duoattention: Efficient long-context llm inference with retrieval and streaming heads, 2024】通用性不足。而从头预训练稀疏注意力模型【索引58, Native sparse attention: Hardware-aligned and natively trainable sparse attention, 2025】、【索引38, Moba: Mixture of block attention for long-context llms, 2025】成本过高,且无法直接应用于已有的密集预训练模型。

创新点与解决方案 (SeerAttention)
本文提出了一种名为 SeerAttention 的简单而有效的后训练(post-training)蒸馏方法,通过自蒸馏(self-distillation)为任何全注意力模型引入可学习的注意力稀疏性,而无需依赖预定义模式。
1. 可学习的门控机制(AttnGate):受专家混合(MoE)模型门控机制的启发【索引45, Outrageously large neural networks: The sparsely-gated mixture-of-experts layer, 2017】,SeerAttention 在传统注意力模块上增加了一个名为 AttnGate 的可学习门控。该门控能选择性地激活注意力图中的一小部分重要块。
2. 轻量级自蒸馏训练:AttnGate 通过学习模仿原始LLM生成的二维块级稀疏性真值(ground truth)进行训练。此蒸馏过程仅需学习门控参数,而模型其他参数保持不变,从而实现了快速收敛。用户可以为不同模型定制专属的 AttnGate,无需依赖人工设计的启发式规则。
3. 高效推理:SeerAttention 结合了专门开发的块稀疏 FlashAttention 内核,在推理时能显著提升在 GPU 上的运行速度。

主要贡献总结
- 提出 SeerAttention,一个创新的可学习注意力门控机制,用以提升长上下文 LLM 的效率。
- 开发了一种自蒸馏训练方案,能高效地训练 AttnGate,使其学习到预训练模型的内在稀疏性。
- 实验证明,SeerAttention 在长上下文模型准确率和预填充(pre-filling)延迟方面均优于先前方法,并能适应不同的上下文长度和稀疏度。在128k序列长度和90%稀疏度下,实现了7.3倍的加速。此外,SeerAttention能够学习到包括A-shape和Vertical-Slash在内的多种稀疏模式,展示了其多功能性和高性能。

A3 背景知识与相关工作

Transformer中强大但复杂的注意力机制。注意力机制,尤其是在Transformer架构【索引51, Attention is all you need, 2017】中的应用,标志着自然语言处理领域的重大进步。通过让每个词元关注序列中的所有其他词元,注意力机制能够更好地处理长距离依赖关系和理解上下文,但这导致了$O(n^2)$的时间和内存复杂度,其中n是序列长度。随着社区向能够处理更长上下文的LLM发展,这一挑战日益显著。许多研究探索了替代性的注意力机制以降低这种复杂性。Reformer架构【索引31, Reformer: The efficient transformer, 2020】将复杂度降至$O(n \log n)$,而线性注意力机制【索引30, Transformers are rnns: Fast autoregressive transformers with linear attention, 2020】、【索引57, Gated linear attention transformers with hardware-efficient training, 2023】则进一步将其降至$O(n)$。近期,重新审视循环神经网络的趋势也催生了新的架构框架,如RWKV【索引40, Rwkv: Reinventing rnns for the transformer era, 2023】、RetNet【索引47, Retentive network: A successor to transformer for large language models, 2023】和Mamba【索引20, Mamba: Linear-time sequence modeling with selective state spaces, 2023】。尽管这些方法在效率上很有前景,但它们在性能上仍难以与全注意力机制相媲美,尤其是在更大的模型和更长的上下文中。

注意力中内在但动态的稀疏性。注意力机制本身就表现出稀疏性,这源于由Q和K生成的注意力图$A = \text{softmax}(QK^T / \sqrt{d})$。softmax函数通常会产生大量可忽略不计的分数,这些分数可以被视为零而不影响模型准确性【索引59, Big bird: Transformers for longer sequences, 2020】、【索引35, Transformer acceleration with dynamic sparse attention, 2021】、【索引52, Spatten: Efficient sparse attention architecture with cascade token and head pruning, 2021】、【索引7, Generating long sequences with sparse transformers, 2019】、【索引36, Deja vu: Contextual sparsity for efficient llms at inference time, 2023】。随着上下文变长,注意力稀疏性变得更加显著,为优化推理速度提供了机会。然而,这种稀疏性是动态的,会因不同的上下文输入和注意力头而异,每个头都显示出不同的稀疏位置和比例。以往的研究尝试使用预定义的模式和启发式方法来近似不同注意力头的稀疏性【索引18, Moa: Mixture of sparse attention for automatic large language model compression, 2024】、【索引28, Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention, 2024】。然而,这些方法缺乏通用性,并且常常依赖于手工设计的特征,难以完全捕捉注意力机制的稀疏行为。注意力稀疏性的动态和输入依赖性与专家混合(MoE)模型【索引45, Outrageously large neural networks: The sparsely-gated mixture-of-experts layer, 2017】、【索引17, Switch transformers: Scaling to trillion parameter models with simple and efficient sparsity, 2022】的原则相呼应,表明稀疏性理想情况下应由模型直接从数据中学习。这种方法将使模型能够自适应地利用稀疏性,从而在保持准确性的同时提高效率。

长上下文LLM优化。处理长上下文的能力对大型语言模型(LLM)至关重要,因为它使其能够保留和利用更广泛的信息。然而,这也带来了巨大的计算和内存成本。各种研究工作探索了不同的策略来优化长上下文处理。一个主要方向是提高预填充(prefill)效率,其中采用了提示压缩【索引27, Longllmlingua: Accelerating and enhancing llms in long context scenarios via prompt compression, 2023】、【索引39, Learning to compress prompts with gist tokens, 2024】、【索引8, Learning to compress prompt in natural language formats, 2024】和稀疏注意力【索引28, Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention, 2024】、【索引18, Moa: Mixture of sparse attention for automatic large language model compression, 2024】、【索引1, Star attention: Efficient llm inference over long sequences, 2024】、【索引62, Selective attention: Enhancing transformer through principled context control, 2024】等技术。另一种方法侧重于通过引入稀疏加载机制【索引56, Tidaldecode: Fast and accurate llm decoding with position persistent sparse attention, 2024】、【索引6, Magicpig: Lsh sampling for efficient llm generation, 2024】来优化解码阶段。此外,还有一些方法旨在压缩KV缓存,包括KV缓存共享【索引3, Gqa: Training generalized multi-query transformer models from multi-head checkpoints, 2023】、【索引5, Reducing transformer key-value cache size with cross-layer attention, 2024】、KV驱逐策略【索引63, H2o: Heavy-hitter oracle for efficient generative inference of large language models, 2023】、【索引34, Snapkv: Llm knows what you are looking for before generation, 2024】、【索引19, Model tells you what to discard: Adaptive kv cache compression for llms, 2023】和KV量化【索引37, Kivi: A tuning-free asymmetric 2bit quantization for kv cache, 2024】、【索引25, Kvquant: Towards 10 million context length llm inference with kv cache quantization, 2024】、【索引14, Qaq: Quality adaptive quantization for llm kv cache, 2024】、【索引61, Kv cache is 1 bit per channel: Efficient large language model inference with coupled quantization, 2024】。

A2 方法细节

SeerAttention 采用一种完全基于学习的方法来在 LLM 中自适应地识别注意力稀疏性,并利用学习到的稀疏性进行高效推理。为了确保在现代硬件(如GPU)上的效率,我们专注于学习块稀疏性,这种稀疏性可以与 FlashAttention【索引13, Flashattention: Fast and memory-efficient exact attention with io-awareness, 2022】、【索引12, Flashattention-2: Faster attention with better parallelism and work partitioning, 2023】的分块计算方案无缝集成。图1展示了 SeerAttention 的整体示意图,它通过一个称为“注意力门”(AttnGate)的可学习门控模块来增强传统注意力。AttnGate 模块包含可学习的参数(线性层),并通过蒸馏来模仿注意力图的二维最大池化(2D-Maxpooled)结果。在推理时,AttnGate 可以为后续的注意力计算预测块级稀疏性,并由一个块稀疏的 FlashAttention 内核执行,通过减少I/O和计算开销显著提升性能。

图1: SeerAttention 概览。SeerAttention 中的 AttnGate 首先对 Q 和 K 张量在序列维度上进行池化,并通过可学习的线性层进行处理。训练 AttnGate 的输出以模仿预训练模型的二维最大池化结果。在推理过程中,它对 AttnGate 的输出应用 TopK 或阈值化来定位激活块。
图1: SeerAttention 概览。SeerAttention 中的 AttnGate 首先对 Q 和 K 张量在序列维度上进行池化,并通过可学习的线性层进行处理。训练 AttnGate 的输出以模仿预训练模型的二维最大池化结果。在推理过程中,它对 AttnGate 的输出应用 TopK 或阈值化来定位激活块。

3.1 注意力门设计

AttnGate 模块的设计与公式。AttnGate 模块的设计目标是以最小的开销学习块级信息。它将原始的 Q 和 K 矩阵作为输入,并使用池化操作沿序列维度对它们进行下采样。如图1所示,对于给定的注意力头,池化后的 Q 和 K 的大小变为 [seq/B, d],其中 B 是池化操作的核大小和步长(非重叠块)。下采样后的 Q 和 K 随后通过一个线性层处理并相乘,类似于标准的注意力操作。这产生一个大小为 [seq/B, seq/B] 的矩阵,其中每个元素对应于原始完整注意力图中的一个块。当块大小为64时,AttnGate 模块的输出大小仅为原始注意力图的1/4096,使其计算效率极高。以最简单的形式,AttnGate 输出的软分数可以表示为:

$$\mathbf{Q}_{c}=\operatorname{RoPE}\left(W_{q} \operatorname{concat}_{i=1}^{m_{q}} P_{i}^{(q)}\left(\mathbf{Q}_{n o p e}\right)\right),$$ $$\mathbf{K}_{c}=\operatorname{RoPE}\left(W_{k} \operatorname{concat}_{j=1}^{m_{k}} P_{j}^{(k)}\left(\mathbf{K}_{n o p e}\right)\right),$$ $$ \mathbf{O} = \text{softmax}(\mathbf{Q}_c \mathbf{K}_c^\top / \sqrt{d}). $$

其中,$P(q_i)$ 和 $P(k)_j$ 代表对 Q 和 K 的不同池化操作,d 是张量的隐藏维度,类似于注意力计算。详细算法将在下文解释。

池化方法的选择与实验。池化操作会对张量进行下采样,可能导致信息损失。为了更好地保留注意力张量的特性,SeerAttention 允许为 Q 和 K 组合不同的池化方法。具体来说,我们考虑了平均池化、最大池化和最小池化。当对 Q 或 K 应用多种池化方法时,得到的池化张量会在隐藏维度上进行拼接,然后输入到后续的线性层。图2展示了使用 LLaMA-3.1-8B 模型在 PG19【索引42, Compressive transformers for long-range sequence modelling, 2019】数据集上排名前15的池化组合的测试困惑度。我们观察到,在 Q 上应用平均池化(AvgPooling),并在 K 上应用最大池化、最小池化和平均池化的组合,在不同的稀疏度下均能产生最佳的困惑度。这一趋势可能与 LLM 量化中的先前发现有关,即 K 张量表现出更多的异常值。因此,结合最大和最小池化有助于捕捉这些极端激活值,从而在池化后产生更丰富的特征表示。

图2: PG19 上不同池化组合的测试困惑度。最佳配置在 Q 上应用 AvgPooling,在 K 上应用 Max、Min 和 AvgPooling 的组合。
图2: PG19 上不同池化组合的测试困惑度。最佳配置在 Q 上应用 AvgPooling,在 K 上应用 Max、Min 和 AvgPooling 的组合。

使用位置编码实现长度外推。近期的顶尖 LLM 通常采用如 RoPE【索引46, Roformer: Enhanced transformer with rotary position embedding, 2024】等位置编码(PE)来编码位置信息。如果 AttnGate 仅依赖模型中原始的 RoPE,即向 AttnGate 输入 $Q_{rope}$ 和 $K_{rope}$,位置信息可能会因池化操作而受损。这会损害 AttnGate 在蒸馏过程中的长度外推能力。为解决此问题,我们在 AttnGate 中对没有PE的输入 $Q_{nope}$ 和 $K_{nope}$ 重新应用块级 PE(如公式9所示)。为了表示块级信息,AttnGate 中的 RoPE 使用一个缩减的 $\theta' = \theta / B$,其中 $\theta$ 是 LLM 原始 RoPE 的 theta。

图3: PG19 数据集上 AttnGate 中两种 PE 设置的困惑度比较。AttnGate 中的块级 RoPE 使其能够有效学习块级位置信息,从而在不同上下文长度下获得更好的测试性能。没有 AttnGate PE,它在长于训练长度的数据上无法提供合理的结果。
图3: PG19 数据集上 AttnGate 中两种 PE 设置的困惑度比较。AttnGate 中的块级 RoPE 使其能够有效学习块级位置信息,从而在不同上下文长度下获得更好的测试性能。没有 AttnGate PE,它在长于训练长度的数据上无法提供合理的结果。

图3展示了在 AttnGate 中使用和不使用块级 RoPE 设计的测试困惑度结果。结果表明,在没有这种块级 RoPE 设计的情况下,使用8k长度数据训练的 AttnGate 在超过8k的评估数据上表现不佳。同样,当用32k长度数据训练时,它在128k长度数据上表现不佳。然而,通过增加块级 RoPE,AttnGate 可以外推到不同的上下文长度,显著提升了模型性能和训练效率。

3.2 AttnGate 训练

训练的挑战与解决方案。尽管我们提出的 SeerAttention 架构很直接,但训练过程面临挑战。像在 MoE 中那样从头开始联合训练门控和模型,成本高昂且困难重重。幸运的是,与 MoE 不同(其中门控网络必须从头学习专家选择),SeerAttention 中的 AttnGate 可以利用标准注意力的结果作为真值进行蒸馏。

真值的获取方法。我们使用来自全注意力的二维最大池化(2D-MaxPooled)注意力图作为真值来蒸馏 AttnGate,如图1所示。从语义上讲,这意味着只有当一个块内所有的注意力分数都很小时,二维最大池化的结果才会很小,这与块稀疏的定义一致。然而,获取用于训练的最大池化注意力图并非易事,尤其是在长上下文场景中,因为中间结果 $QK^T$ 会导致二次方的内存消耗。为了应对这一挑战,我们定制了一个高效的内核,通过修改 FlashAttention 内核直接输出最大池化的注意力图真值,同时大量复用其原始计算流程。详细设计在附录A.1中解释。

损失函数的选择与公式。我们使用 Kullback-Leibler 散度损失【索引29, Kullback-leibler divergence, 2011】来蒸馏 AttnGate。由于 AttnGate 的输出也使用 softmax,类似于全注意力计算,其门控分数的行和将始终为1。KL散度损失使得训练过程能够专注于模仿注意力分布,而不是像均方误差损失那样关注绝对数值。整个蒸馏过程可以表示为:

$$ \mathbf{gt} = \text{MaxPool2D} \left( \text{softmax} \left( \frac{\mathbf{Q}_{rope} \mathbf{K}_{rope}^{T}}{\sqrt{d}} \right) \right), $$ $$\mathbf{o}=\operatorname{AttnGate}\left(\mathbf{Q}_{n o p e}, \mathbf{K}_{n o p e}\right),$$ $$\text{loss} = D_{KL}(\text{gt} \parallel \mathbf{o}).$$

3.3 使用 SeerAttention 进行推理

推理过程概述。经过自蒸馏训练过程后,SeerAttention 可以利用训练好的 AttnGate 为全注意力机制内的每个块生成一个门控分数。这些分数随后被用来选择最终被激活的稀疏块。再结合我们的后端块稀疏 FlashAttention 内核,SeerAttention 可以在保持高准确率的同时,为长上下文预填充实现显著的加速。

生成二元掩码的方法。SeerAttention 提供了将浮点门控分数 o 转换为最终二元块掩码的灵活性,可以使用 TopK 或阈值法(Thresholding)。如果使用 TopK 方法,则会选择每行中得分最高的 k 个块。

$$\begin{aligned} b_{i j}= \begin{cases}1 & \text { if } j \in \operatorname{TopK}\left(\mathbf{o}_i, k\right) \text {.index, } \\ 0 & \text { otherwise. }\end{cases} \end{aligned}$$

或者,用户可以激活分数超过一个阈值的块。这可以进一步省去对 AttnGate 输出分数进行排序的需要。

$$b = \mathbf{o} > threshold$$

值得注意的是,一旦 AttnGate 训练完成,在推理阶段,我们可以根据测试时的需求调整 TopK 的比例或基于阈值的方法,以实现不同的权衡。

后端稀疏核的设计。在设计块稀疏 Flash-Attention 内核时,AttnGate 的块大小与 Flash-Attention 中使用的分块(tiling)大小保持一致,通常为64或128。通过这样做,我们可以创建一个定制的块稀疏 Flash-Attention 内核,该内核利用 AttnGate 生成的二元块掩码来选择性地跳过未激活块的 I/O 和计算。这种方法在现代 GPU 上非常高效,因为它在块级别优化了稀疏数据的处理,而不是处理细粒度的逐元素级别,从而带来了显著的性能提升。

A4 实验

实验环境

  • 通用设置:AttnGate 和稀疏内核的块大小 B 固定为64。AttnGate 仅在预填充(prefill)阶段应用。
  • 模型

    • Llama-3.1-8B-Instruct, Llama-3.1-70B-Instruct
    • Qwen2.5-7B-Instruct, Qwen2.5-14B-Instruct, Qwen2.5-32B-Instruct
  • 基线方法

    • MoA【索引18, Moa: Mixture of sparse attention for automatic large language model compression, 2024】:采用离线搜索方案为不同注意力头应用静态稀疏模式。实验中采用其官方实现的0.5 "KV Sparsity"(对应0.35 "Attention Sparsity")。
    • MInference【索引28, Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention, 2024】:基于预定义稀疏模式,使用启发式方法为每个头动态生成稀疏索引。实验中采用其对Llama-3.1-8B-Instruct的官方配置,所有注意力头选择"Vertical-Slash"模式。
    • DuoAttention【索引54, Duoattention: Efficient long-context llm inference with retrieval and streaming heads, 2024】:将部分注意力头设为流式头,其余为密集头。实验中采用其对Llama-3.1-8B-Instruct的官方配置,50%的头为流式头。
    • :仅 Llama-3.1-8B-Instruct 模型有所有基线方法的官方支持/配置,因此对比实验主要基于此模型。
  • 数据集与任务

    • 长上下文:LongBench【索引4, Longbench: A bilingual, multitask benchmark for long context understanding, 2023】和 RULER【索引26, Ruler: What’s the real context size of your long-context language models?, 2024】。稀疏性仅应用于上下文部分。
    • 短上下文:Open LLM Leaderboard【索引50, Zephyr: Direct distillation of lm alignment, 2023】中的4个任务:HellaSwag, MMLU, ARC-challenge, GSM8K。
  • 硬件配置:所有评估均在 A100 GPU 上运行。

  • 蒸馏训练配置
    • 数据集:RedPajama【索引11, Redpajama: an open dataset for training large language models, 2023】,数据被切分为64k长度的块。
    • 软件:使用 DeepSpeed【索引43, Deepspeed: System optimizations enable training deep learning models with over 100 billion parameters, 2020】 stage 2 优化。
    • 超参数:学习率1e-3,使用余弦衰减;全局批量大小为16;训练步数为500。
    • 训练成本:对于7B或8B模型,在A100 GPU上约需40小时。为避免生成真值时内存爆炸,使用了定制的FlashAttention内核直接输出2D最大池化结果。

实验结果

准确率评估

  • LongBench 评估 (见表1)

    • 实验内容:在 LongBench 上比较 SeerAttention 与基线方法在不同上下文长度下的准确率。SeerAttention 使用统一的 2e-3 阈值,允许稀疏度自适应变化。
    • 实验结果
      • 对于 Llama-3.1-8B-Instruct,SeerAttention 在所有测试长度上均优于其他方法,并且在0-8k长度上甚至超过了密集基线。其平均得分和平均稀疏度均为最高。
      • 对于其他模型,SeerAttention 在平均稀疏度超过50%的情况下,准确率与密集基线相当。
    表1: Llama 和 Qwen 模型在 LongBench 上的结果。
    表1: Llama 和 Qwen 模型在 LongBench 上的结果。
  • RULER 评估 (见表2)

    • 实验内容:在 RULER 基准上评估长上下文性能,测试数据长度从4k到128k。SeerAttention 使用 5e-4 的阈值,稀疏度从4k的约10%自适应到128k的约85%。
    • 实验结果:SeerAttention 在大多数测试中(8k-64k)取得了最佳准确率。虽然在128k测试中准确率略低于DuoAttention,但其稀疏度远高于后者(>80% vs <50%)。最终,SeerAttention 实现了最高的平均准确率(仅比密集基线低0.41%),同时获得了最高的平均端到端预填充加速比(1.41倍)。
    表2: Llama-3.1-8B-Instruct 模型在 RULER 基准测试上的结果。
    表2: Llama-3.1-8B-Instruct 模型在 RULER 基准测试上的结果。
  • 短上下文测试 (见表3)

    • 实验内容:在短上下文任务上,稀疏注意力的延迟优势不明显,本实验主要评估其在高稀疏度下的准确率保持能力。SeerAttention 使用了非常高的阈值 3e-2。
    • 实验结果:SeerAttention 表现出可忽略的准确率损失。例如,在GSM-8K任务中,平均序列长度为872,平均稀疏度为52%时,准确率仅下降0.1%。
    表3: Llama-3.1-8B-Instruct 模型上的短上下文测试。
    表3: Llama-3.1-8B-Instruct 模型上的短上下文测试。

效率评估

  • 内核评估 (见图4, 图5)

    • AttnGate 开销:图4显示,AttnGate 在推理时引入的开销极小。在32K上下文和0.5稀疏度下,AttnGate 仅占注意力层总延迟的1%;在128K上下文中,相对开销几乎可以忽略不计。
    • 块稀疏内核加速:图4显示,我们的内核在不同稀疏度下表现出线性加速。在128K序列长度和90%稀疏度下,SeerAttention 相较于 FlashAttention-2(全注意力)实现了7.3倍的加速。
    • 内核级比较:图5显示,SeerAttention 使用的块稀疏注意力内核在大多数情况下都优于 MoA 和 MInference,能更有效地将稀疏度转化为实际的加速。
    图4: SeerAttention 在内核级别相对于 FlashAttention-2 的加速比。AttnGates 带来的延迟开销极小。我们的块稀疏注意力内核在密集配置上实现了高度线性的加速,在 128k 序列上以 90% 的稀疏度实现了 7.3 倍的加速。AttnGate 的开销在 128k 上下文长度中几乎消失。
    图4: SeerAttention 在内核级别相对于 FlashAttention-2 的加速比。AttnGates 带来的延迟开销极小。我们的块稀疏注意力内核在密集配置上实现了高度线性的加速,在 128k 序列上以 90% 的稀疏度实现了 7.3 倍的加速。AttnGate 的开销在 128k 上下文长度中几乎消失。
    图5: 不同工作之间的内核级加速比较。SeerAttention 更有效地将稀疏性转化为加速。
    图5: 不同工作之间的内核级加速比较。SeerAttention 更有效地将稀疏性转化为加速。
  • 端到端加速比较 (见图6)

    • 实验内容:在 RULER 测试设置下,测量 Llama-3.1-8B-Instruct 的平均预填充时间(TTFT)。
    • 实验结果:所有方法在更长的上下文中都获得了更好的加速。SeerAttention 在长上下文场景(≥16k)中优于其他方法,在128k长度上实现了高达2.43倍的端到端预填充加速。相比之下,MInference 由于运行时搜索稀疏索引的开销较大,在小于64k的数据上出现减速。总体而言,在 RULER 基准测试中,SeerAttention 在保持最高平均准确率的同时,实现了最高的平均加速比(1.41倍)。
    图6: 在 RULER 测试设置上比较预填充时间加速。SeerAttention 在大多数长上下文数据场景(≥16k)中优于相关工作。对于更长的上下文数据,注意力机制占总运行时间的比例更大,使得稀疏方法能够实现更好的加速。总体而言,SeerAttention 在此 RULER 基准设置下实现了最高的平均加速(1.41倍),同时保持了最佳的平均准确率。
    图6: 在 RULER 测试设置上比较预填充时间加速。SeerAttention 在大多数长上下文数据场景(≥16k)中优于相关工作。对于更长的上下文数据,注意力机制占总运行时间的比例更大,使得稀疏方法能够实现更好的加速。总体而言,SeerAttention 在此 RULER 基准设置下实现了最高的平均加速(1.41倍),同时保持了最佳的平均准确率。

训练成本与参数开销

  • 训练成本:SeerAttention 的额外训练成本适中。在 LLaMA-3-8B-Instruct 上约需40个A100 GPU小时,与DuoAttention相当或更低。
  • 参数开销:引入约101M可训练参数,约占模型总参数的1.3%。
  • 内存与延迟开销:与FlashAttention相比,内存和延迟开销增加不到5%。
  • 可扩展性:门控模块对更大模型也能很好地扩展,例如对LLaMA-3.1-70B为503MB,对DeepSeek-R1-Distill-Qwen-32B为252MB,表明该方法是轻量级且可扩展的。

学习到的注意力图可视化 (见图7)

  • 实验内容:可视化 AttnGate 的输出,以展示其学习到的稀疏模式。
  • 实验结果:AttnGate 模块无需任何先验知识或启发式规则,就能自动学习到多样的稀疏模式,包括 (a) "A-shape"(流式头)、(b) "Vertical"(垂直)、(c) 带有垂直空白的 "Slash"(斜线)、(d) 沿对角线的块稀疏以及 (e) 随机模式。这些模式不仅包含了,甚至超越了先前工作(如MoA和MInference)中观察到的模式,展示了我们基于学习的方法的多功能性。
图7: AttnGate 输出的可视化。
图7: AttnGate 输出的可视化。

A5 结论

本文提出了 SeerAttention,一种新的注意力机制,它通过学习和利用注意力中固有的稀疏性来提升长上下文 LLM 的性能。SeerAttention 通过一种轻量级的自蒸馏方法,从 LLM 自身学习注意力稀疏性。我们的实验表明,在长上下文模型准确率和预填充延迟方面,SeerAttention 均优于先前的方法。

对于未来的工作,有几个有希望的方向可以探索,以改进和扩展 SeerAttention 的能力。一个关键领域是增强 SeerAttention 的训练方法,例如在长上下文持续预训练中应用 SeerAttention,并使用更多的训练数据,以在不牺牲准确率的情况下实现更高的稀疏度。另一个重要途径是在解码阶段应用 SeerAttention,特别是对于长思维链(long-CoT)任务。

A6 附录

A.1 使用定制化 GPU 内核训练 SeerAttention

定制化训练核的设计与性能。为了在训练 AttnGate 时高效地获取作为真值的二维最大池化(2D-MaxPooled)注意力图,我们开发了一个基于 Triton【索引49, Triton: an intermediate language and compiler for tiled neural network computations, 2019】的定制内核。该内核通过修改 FlashAttention 内核实现,在很大程度上保留了其原始计算流程,从而克服了直接计算注意力图所带来的二次内存复杂度问题。图8展示了此定制内核的伪代码和示意图。

图8: 带有注意力图池化功能的高效 FlashAttention 内核。
图8: 带有注意力图池化功能的高效 FlashAttention 内核。

具体实现。FlashAttention 通常通过减去最大值来保证 softmax 的数值稳定性,它会计算每个块的局部行最大值,并通过迭代逐步更新全局最大值。

$$\begin{aligned} \begin{aligned} S_{ij} &= Q_i K_j^T; \\ r_{ij} &= \text{rowmax}(S_{ij}); \\ m_{ij} &= \max(m_{i(j-1)}, r_{ij}). \end{aligned} \end{aligned}$$

其中 $r_{ij}$ 通常被视为临时结果。我们的方法是将其存储在 HBM 中,并在迭代结束后,使用最终的全局最大值 $m_i$ 和指数和 $l_i$ 对其进行重新缩放:

$$a_{i j}=\exp (r_{i j}-m_{i})/l_{i}$$

这里的 $a_{ij}$ 代表了原始注意力块正确的行最大值。在此基础上,通过对 $a_{ij}$ 进行列最大值操作,即可实现二维最大池化。这个过程只引入了微小的开销(存储和重新缩放 $r_{ij}$),但显著提高了获取真值的效率。

训练内核性能。如图9所示,我们评估了定制内核的性能。与 PyTorch 的朴素实现相比,我们的内核在内存和延迟上都有巨大优势。PyTorch 内核在序列长度超过4k时便会内存溢出(OOM),而我们的定制内核的峰值内存使用量与 FlashAttention-2 相当。在延迟方面,额外的池化操作引入的开销与 FlashAttention-2 相比微乎其微,远胜于 PyTorch 实现。

图9: 带有最大池化功能的定制 FlashAttention 训练内核的内存和延迟。
图9: 带有最大池化功能的定制 FlashAttention 训练内核的内存和延迟。

A.2 SeerAttention 微调的初步实验

SeerAttention在微调中的初步实验。我们证明了 SeerAttention 可以无缝集成到长上下文扩展微调阶段。我们遵循 YaRN【索引41, YaRN: Efficient context window extension of large language models, 2024】的方法,将 Llama-3-8B 模型的上下文大小从8k扩展到32k。损失函数是原始语言模型损失和 AttnGate 损失之和。为确保训练稳定,AttnGate 在微调整个模型之前,会先通过后训练自蒸馏进行初始化。我们将 SeerAttention 集成到 YaRN 中,并与 YaRN 密集基线以及在 YaRN 之后应用后训练自蒸馏的 SeerAttention 进行比较。图10显示了 YaRN 密集基线以及 SeerAttention 在50%和90%稀疏度下的损失曲线。50%稀疏度的曲线几乎与基线重合,而90%稀疏度的曲线显示出略高的损失。表4展示了在32k上下文长度下评估的 PG19 和 ProofPile 数据集上的测试困惑度。在 YaRN 扩展微调期间应用 SeerAttention,在50%稀疏度下保持了接近无损的性能,即使在90%稀疏度下,损失仍然很小。

图10: 微调损失。
图10: 微调损失。
图11: 通过将 SeerAttention 与 YaRN [41] 结合,将 Llama-3-8B 模型的上下文长度从 8k 扩展到 32k,50% 至 90% 稀疏度的损失曲线与密集的 YaRN 基线几乎相同。
图11: 通过将 SeerAttention 与 YaRN [41] 结合,将 Llama-3-8B 模型的上下文长度从 8k 扩展到 32k,50% 至 90% 稀疏度的损失曲线与密集的 YaRN 基线几乎相同。
表4: YaRN 基线、YaRN 后的 SeerAttention 以及与 SeerAttention 一起进行 YaRN 微调的困惑度。
表4: YaRN 基线、YaRN 后的 SeerAttention 以及与 SeerAttention 一起进行 YaRN 微调的困惑度。

A.3 在解码阶段引入稀疏注意力的初步结果

SeerAttention在解码阶段的初步结果。当前的 AttnGate 设计主要用于加速长上下文预填充,但将注意力门蒸馏的思想应用于解码阶段也是一个可行的方向,这对于提高推理模型生成长序列(即测试时扩展)的效率至关重要。为使当前设计兼容解码情况,需要移除 Query 的序列级池化,以适应逐词元生成的方式。以下是对 AttnGate 设计的修改:

$$\mathbf{Q}_{c}=\operatorname{RoPE}\left(\mathbf{W}_{\text {gate }}^{\mathbf{q}} \operatorname{reshape}\left(\mathbf{Q}_{n o p e},[..., g \cdot d]\right)\right),$$ $$\mathbf{K}_{c}=\operatorname{RoPE}\left(\mathbf{W}_{\text {gate }}^{\mathbf{k}} \text { concat }\left[\mathrm{P}_{\max }\left(\mathbf{K}_{n o p e}\right), \mathrm{P}_{\min }\left(\mathbf{K}_{n o p e}\right), \mathrm{P}_{\mathrm{avg}}\left(\mathbf{K}_{n o p e}\right)\right]\right)$$ $$\mathbf{S} = \text{softmax}(\mathbf{Q}_c \mathbf{K}_c^\top / \sqrt{d_{gate}}).$$

其中,$P_{max}$、$P_{min}$ 和 $P_{avg}$ 代表序列维度上的最大、最小和平均池化。我们聚合每个组内的 Query 头以共享稀疏性决策,从而提高解码效率。具体来说,Q 分支中的一个线性层将每个查询子组(例如,g=4时,32个头减少到8个头)减少为单个 $Q_c$ 头,同时保持 K 头不变,从而实现分组查询间的稀疏性共享。为了沿序列维度压缩 K,我们应用非重叠的块级池化,将最大、最小和平均池化的输出连接起来,然后进行线性投影。此外,AttnGate 在预-RoPE 的 Q 和 K 上重新应用 RoPE,为每个块分配其第一个词元的位置,我们发现这能提高准确性。

在推理任务上的评估。我们使用 DeepSeek-R1-Distill-Qwen-14B【索引21, Deepseek-r1: Incentivizing reasoning capability in llms via reinforcement learning, 2025】在三个推理基准测试上评估此设计:AIME24、MATH-500【索引23, Measuring massive multitask language understanding, 2020】和 GPQA-Diamond【索引44, Gpqa: A graduate-level google-proof q&a benchmark, 2024】。我们与全注意力和 Quest【索引48, Quest: Query-aware sparsity for efficient long-context llm inference, 2024】(一种使用查询感知 KV 缓存选择的免训练稀疏解码方法)进行比较。表5显示了我们的设计在多个推理基准和词元预算下相比 Quest 的优越性能。在 AIME24 上,我们始终优于 Quest。同样,在 MATH500 和 GPQA-Diamond 上,SeerAttention-decoding 展示了更好的预算-准确率权衡,这反映了其在稀疏解码下更好的信息保留能力,得益于其学习到的、共享稀疏性决策的门控,从而实现了更连贯的块选择。

表5: 使用 DeepSeek-R1-Distill-Qwen-14B 模型,SeerAttention-decoding 和 Quest 在不同词元预算下的性能比较。
表5: 使用 DeepSeek-R1-Distill-Qwen-14B 模型,SeerAttention-decoding 和 Quest 在不同词元预算下的性能比较。