MOBA: MIXTURE OF BLOCK ATTENTION FOR LONG-CONTEXT LLMS
文章标题:MOBA:用于长上下文大语言模型的混合块注意力机制
作者/机构:Enzhe Lu, Zhejun Jiang, Jingyuan Liu, Yulun Du, Tao Jiang, Chao Hong, Shaowei Liu, Weiran He, Enming Yuan, Yuzhi Wang, Zhiqi Huang, Huan Yuan, Suting Xu, Xinran Xu, Guokun Lai, Yanru Chen, Huabin Zheng, Junjie Yan, Jianlin Su, Yuxin Wu, Neo Y. Zhang, Zhilin Yang, Xinyu Zhou (Moonshot AI); Mingxing Zhang (Tsinghua University); Jiezhong Qiu (Zhejiang Lab/Zhejiang University)
A1 主要贡献
本文旨在解决大语言模型(LLM)在处理长序列时面临的核心挑战,即传统自注意力机制带来的二次方计算复杂度增长问题。现有方法通常引入强结构偏见(如窗口注意力)或从根本上改变注意力机制(如线性近似),前者可能损害模型的泛化能力,后者在复杂推理任务上的有效性尚未得到充分验证。
核心问题与研究目标:
如何设计一个鲁棒且适应性强的注意力架构,该架构能保留 Transformer 的原始框架,同时遵循“更少结构”原则,让模型自主决定注意力分配,而不是依赖预定义的偏见?理想情况下,该架构应能在全注意力(full attention)和稀疏注意力(sparse attention)模式之间无缝切换,以兼容现有预训练模型,并在不牺牲性能的前提下实现高效推理和加速训练。
主要创新点:
本文提出了混合块注意力(Mixture of Block Attention, MoBA),这是一种新颖的架构,它将混合专家(Mixture of Experts, MoE)的原理创新性地应用于 Transformer 模型的注意力机制。
1. MoE 原理应用于注意力:与主要将 MoE 用于前馈网络(FFN)层的传统做法不同,MoBA 首次将其应用于长上下文注意力,允许每个查询(query)动态地从历史键(key)和值(value)块中选择最相关的部分进行计算。
2. 动态块稀疏注意力:MoBA 将上下文划分为多个块,并使用一个门控网络(gating mechanism)来选择性地将查询路由到最相关的块。这种块稀疏注意力显著降低了计算成本。
3. 无缝模式切换与兼容性:该架构被设计为全注意力的直接替代品,参数量相同,能够在全注意力和稀疏注意力之间平滑过渡。这使其能够高效地进行长上下文模型的持续预训练,并与现有模型框架高度兼容。
MoBA 架构已成功部署于 Kimi 的长上下文服务中,展现了其在高效注意力计算方面的显著优势。
A2 方法细节
本文提出了一种名为混合块注意力(Mixture of Block Attention, MoBA)的新颖架构,通过动态选择历史片段(块)进行注意力计算,扩展了 Transformer 模型的能力。MoBA 的灵感来源于混合专家(MoE)和稀疏注意力技术。前者主要应用于 Transformer 架构中的前馈网络(FFN)层,而后者已被广泛用于扩展 Transformer 以处理长上下文。本文的方法创新性地将 MoE 原理应用于注意力机制本身,从而实现对长序列更高效、更有效的处理。
2.1 预备知识:Transformer 中的标准注意力
标准注意力机制回顾:首先回顾 Transformer 中的标准注意力。为简化起见,我们考虑单个查询词元 $q \in R^{1 \times d}$ 与 $N$ 个键和值词元(分别表示为 $K, V \in R^{N \times d}$)进行交互的情况。标准注意力的计算公式如下:
其中 $d$ 表示单个注意力头的维度。为清晰起见,我们专注于单头场景。多头注意力的扩展涉及将多个此类单头注意力操作的输出进行拼接。
图 1: 混合块注意力(MoBA)图示。(a) MoBA 的一个运行示例;(b) MoBA 与 Flash Attention 的集成。
2.2 MoBA 架构
基于块选择的稀疏注意力:与标准注意力中每个查询词元关注整个上下文不同,MoBA 使每个查询词元只关注键和值的一个子集:
其中 $I \subseteq [N]$ 是所选键和值的集合。
块划分与 Top-k 门控机制:MoBA 的关键创新在于其块划分和选择策略。我们将长度为 $N$ 的完整上下文划分为 $n$ 个块,每个块代表一部分连续的词元。不失一般性,我们假设上下文长度 $N$ 能被块数 $n$ 整除。我们进一步定义块大小为 $B = N/n$,并用以下公式表示第 $i$ 个块的范围:
通过应用来自 MoE 的 top-k 门控机制,我们使每个查询能够选择性地关注来自不同块的词元子集,而不是整个上下文:
门控网络与亲和度分数计算:模型采用一个门控机制(如公式4中的 $g_i$)为每个查询词元选择最相关的块。MoBA 门控首先计算亲和度分数 $s_i$,用于衡量查询 $q$ 与第 $i$ 个块之间的相关性,然后在所有块之间应用 top-k 门控。形式上,第 $i$ 个块的门控值 $g_i$ 计算如下:
其中 $Topk(\cdot, k)$ 表示在为每个块计算的亲和度分数中,值最高的 $k$ 个分数的集合。在本文中,分数 $s_i$ 是通过查询 $q$ 与 $K[I_i]$ 沿序列维度进行平均池化(mean pooling)后的向量之间的内积来计算的:
运行示例:我们在图1a中提供了一个 MoBA 的运行示例,其中有两个查询词元和四个 KV 块。路由器(门控网络)为每个查询动态选择前两个块进行关注。如图1a所示,第一个查询被分配到第一个和第二个块,而第二个查询被分配到第三个和第四个块。
因果性保持机制:在自回归语言模型中保持因果性至关重要,因为它们通过基于先前词元进行下一词元预测来生成文本。这个顺序生成过程确保一个词元不能影响它之前的词元,从而保持因果关系。MoBA 通过两种具体设计来保持因果性:
1. 对未来块不进行注意力计算:MoBA 确保查询词元不能被路由到任何未来的块。通过将注意力范围限制在当前和过去的块,MoBA 遵循了语言建模的自回归特性。形式上,记 $pos(q)$ 为查询 $q$ 的位置索引,对于任何满足 $pos(q) < i \times B$ 的块 $i$,我们设置 $s_i = -\infty$ 和 $g_i = 0$。
2. 当前块注意力与因果掩码:我们将“当前块”定义为包含查询词元本身的块。对当前块的路由也可能违反因果性,因为对整个块进行平均池化可能会无意中包含来自未来词元的信息。为了解决这个问题,我们强制每个词元必须被路由到其各自的当前块,并在当前块的注意力计算期间应用因果掩码。这个策略不仅避免了来自后续词元的任何信息泄露,还鼓励了对局部上下文的关注。形式上,对于查询词元位置 $pos(q)$ 位于区间 $I_i$ 内的块 $i$,我们设置 $g_i = 1$。从混合专家(MoE)的角度看,MoBA 中的当前块注意力类似于现代 MoE 架构(如【Dai et al. 2024, Deepseekmoe: Towards ultimate expert specialization in mixture-of-experts language models, 2024, arXiv】; 【A. Yang et al. 2024, Qwen2. 5 Technical Report, 2024, arXiv】)中共享专家的角色,在专家选择时添加了静态路由规则。
其他关键设计:接下来,我们讨论 MoBA 的一些额外关键设计选择,例如其块分割策略以及 MoBA 与全注意力的混合使用。
细粒度块分割:细粒度专家分割在提升模型性能方面的积极影响已在混合专家(MoE)文献(【Dai et al. 2024, Deepseekmoe: Towards ultimate expert specialization in mixture-of-experts language models, 2024, arXiv】; 【A. Yang et al. 2024, Qwen2. 5 Technical Report, 2024, arXiv】)中得到充分证明。在本文中,我们探索了将类似的细粒度分割技术应用于 MoBA 的潜在优势。受 MoE 启发,MoBA 沿上下文长度维度而非 FFN 中间隐藏维度进行分割。因此,我们的研究旨在确定当我们将上下文划分为更细粒度的块时,MoBA 是否也能受益。更多实验结果见3.1节。
MoBA 与全注意力的混合:MoBA 被设计为全注意力的替代品,保持相同数量的参数,不增不减。这一特性启发我们在全注意力和 MoBA 之间进行平滑过渡。具体来说,在初始化阶段,每个注意力层都可以选择全注意力或 MoBA,并且这个选择可以在训练期间根据需要动态改变。之前的工作【X. Zhang et al. 2024, Simlayerkv: A simple framework for layer-level KV cache reduction, 2024, arXiv】也研究过将全注意力过渡到滑动窗口注意力的类似想法。更多实验结果见3.2节。
与滑动窗口注意力和注意力池的比较:滑动窗口注意力(SWA)和注意力池(attention sink)是两种流行的稀疏注意力架构。我们证明两者都可以被视为 MoBA 的特例。对于滑动窗口注意力【Beltagy et al. 2020, Longformer: The long-document transformer, 2020, arXiv】,每个查询词元只关注其相邻的词元。这可以解释为 MoBA 的一个变体,其门控网络持续选择最近的块。类似地,对于注意力池【G. Xiao et al. 2023, Efficient streaming language models with attention sinks, 2023, arXiv】,每个查询词元关注初始词元和最近词元的组合,这可以看作是 MoBA 的一个变体,其门控网络总是选择初始和最近的块。以上讨论表明 MoBA 比滑动窗口注意力和注意力池具有更强的表达能力。此外,它还表明 MoBA 可以通过整合特定的门控网络灵活地逼近许多静态稀疏注意力架构。
自适应与动态聚焦:总的来说,MoBA 的注意力机制允许模型自适应地、动态地聚焦于上下文中信息最丰富的块。这对于涉及长文档或长序列的任务尤其有益,因为在这些任务中,关注整个上下文可能是不必要的,且计算成本高昂。MoBA 选择性地关注相关块的能力,使得信息的处理更加精细和高效。
2.3 实现
高性能实现方案:我们通过整合来自 FlashAttention 【Dao, D. Fu, et al. 2022, FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, 2022, NeurIPS】 和 MoE 【Rajbhandari et al. 2022, Deepspeed-moe: Advancing mixture-of-experts inference and training to power nextgeneration ai scale, 2022, ICML】的优化技术,提供了 MoBA 的高性能实现。图2展示了 MoBA 的高效率,而关于效率和可扩展性的详细实验将在3.4节中讨论。我们的实现包括五个主要步骤:
- 根据门控网络和因果掩码确定查询词元到 KV 块的分配。
- 根据查询词元分配的 KV 块重新排列其顺序。
- 为每个 KV 块及其分配的查询词元计算注意力输出。这一步可以通过使用可变长度的 FlashAttention 进行优化。
图 2: MoBA 与全注意力(使用 Flash Attention 实现)的效率对比。(a) 1M 模型加速评估:在序列长度增加(8K-1M)时,MoBA 与 Flash Attention 在 1M 模型上的计算时间扩展对比。(b) 固定稀疏度比例扩展:在序列长度增加(8K-10M)时,MoBA 与 Flash Attention 的计算时间扩展对比,保持 95.31% 的恒定稀疏度(固定 64 个 MoBA 块,块大小可变,固定的 top-k=3)。
- 将注意力输出重新排列回其原始顺序。
- 使用在线 Softmax(即分块计算)合并相应的注意力输出,因为一个查询词元可能会关注其当前块和多个历史 KV 块。
算法流程:算法工作流在算法1中被形式化,并在图1b中进行了可视化,展示了如何基于 MoE 和 FlashAttention 实现 MoBA。
1. 首先,将 KV 矩阵划分为块(第1-2行)。
2. 接着,根据公式6计算门控分数,该分数衡量查询词元和 KV 块之间的相关性(第3-7行)。
3. 在门控分数(连同因果掩码)上应用 top-k 操作,产生一个稀疏的查询到 KV 块的映射矩阵 G,表示查询到 KV 块的分配(第8行)。
4. 然后,根据查询到 KV 块的映射重新排列查询词元,并计算分块的注意力输出(第9-12行)。值得注意的是,对历史块的注意力(第11和14行)和当前块的注意力(第10和13行)是分开计算的,因为在当前块的注意力中需要保持额外的因果性。
5. 最后,将注意力输出重新排列回原始顺序,并使用在线 softmax(第16行)进行合并【Milakov et al. 2018, Online normalizer calculation for softmax, 2018, arXiv】; 【H. Liu et al. 2023, Blockwise Parallel Transformer for Large Context Models, 2023, arXiv】。
算法 1 MoBA(混合块注意力)实现
需要:查询、键和值矩阵 $Q, K, V \in R^{N \times h \times d}$;MoBA 超参数(块大小 $B$ 和 top-k);$h$ 和 $d$ 分别表示注意力头的数量和头的维度。另记 $n = N/B$ 为块的数量。
1: // 将 KV 分割成块
2: $\{ \tilde{K}_i, \tilde{V}_i \} = \text{split_blocks}(K, V, B)$,其中 $\tilde{K}_i, \tilde{V}_i \in R^{B \times h \times d}, i \in [n]$
3: // 计算门控分数以进行动态块选择
4: $\bar{K} = \text{mean_pool}(K, B) \in R^{n \times h \times d}$
5: $S = Q\bar{K}^\top \in R^{N \times h \times n}$
6: // 使用因果约束选择块(不对未来块进行注意力计算)
7: $M = \text{create_causal_mask}(N, n)$
8: $G = \text{topk}(S + M, k)$
9: // 组织注意力模式以提高计算效率
10: $Q_s, \tilde{K}_s, \tilde{V}_s = \text{get_self_attn_block}(Q, \tilde{K}, \tilde{V})$
11: $Q_m, \tilde{K}_m, \tilde{V}_m = \text{index_select_moba_attn_block}(Q, \tilde{K}, \tilde{V}, G)$
12: // 分别计算注意力
13: $O_s = \text{flash_attention_varlen}(Q_s, \tilde{K}_s, \tilde{V}_s, \text{causal=True})$
14: $O_m = \text{flash_attention_varlen}(Q_m, \tilde{K}_m, \tilde{V}_m, \text{causal=False})$
15: // 使用在线 softmax 合并结果
16: $O = \text{combine_with_online_softmax}(O_s, O_m)$
17: return $O$
A4 实验环境
- 数据集:实验中使用了标准的语言模型预训练语料库,具体名称未详细说明。下游任务评估使用了多个长上下文基准测试,包括 Needle in a Haystack 和 RULER。
- 模型架构:
- 扩展定律实验:训练了五个不同规模的语言模型(具体参数见表1),以验证 MoBA 的可扩展性。
- 消融和混合实验:使用了 1.5B 参数规模的模型。
- 大型语言模型评估:基于 Llama 3.1 8B 基础模型进行持续预训练,最终模型称为 Llama-8B-1M-MoBA。
- 硬件配置:未明确提及具体的硬件型号(如GPU),但实验规模(如训练 10M 上下文长度的模型)表明使用了大规模分布式 GPU 集群。为了支持 10M 上下文长度,将张量并行扩展到了查询头级别。
- 软件配置:
- 代码实现基于 PyTorch。
- 依赖了 FlashAttention 库进行注意力计算优化。
- 代码已在 https://github.com/MoonshotAI/MoBA 开源。
表 1: 扩展定律实验的配置
A4 实验结果
3.1 扩展定律实验与消融研究
- 与模型规模的可扩展性:通过训练五个不同规模的模型,对比了 MoBA 和全注意力的验证集损失。
- 实验内容:在 8K 序列长度上训练 MoBA 和全注意力模型。MoBA 设置块大小为 512,top-k=3,稀疏度高达 81.25%。
- 实验结果:如图 3a 所示,MoBA 和全注意力的验证集损失曲线展现出非常相似的扩展趋势,两者之间的差异始终在 1e-3 范围内。
- 结论:尽管 MoBA 具有高达 75% 的稀疏度,但其扩展性能与全注意力相当。
图 3: MoBA 与全注意力的扩展定律对比。(a) 在验证集上的语言模型损失(序列长度=8K);(b) 在验证集上的尾部语言模型损失(序列长度=32K,最后1K词元);(c) 拟合的扩展定律曲线。
- 长上下文可扩展性:为更准确评估长上下文能力,使用尾部词元 LM 损失(trailing LM loss)作为指标,仅计算长序列最后部分词元的损失。
- 实验内容:将最大序列长度增加到 32K,MoBA 的稀疏度提升至 95.31%。
- 实验结果:如图 3b 所示,尽管 MoBA 的尾部 LM 损失略高于全注意力,但随着模型规模的增加,损失差距逐渐缩小。
- 结论:该实验证明了 MoBA 在长上下文场景下的可扩展性。
- 细粒度块分割的消融研究:探究了 MoBA 块粒度对性能的影响。
- 实验内容:使用 1.5B 参数模型和 32K 上下文长度,将上下文划分为 8, 16, 32, 64, 128 个块,并相应调整 top-k 值以保持 75% 的恒定稀疏度。
- 实验结果:如图 4 所示,块的粒度对 MoBA 性能有显著影响。最粗粒度设置(从8个块中选2个)与更细粒度设置之间存在 1e-2 的性能差异。
- 结论:细粒度分割是提升 MoE 家族模型(包括 MoBA)性能的通用技术。
图 4: 细粒度块分割。验证集上的 LM 损失与不同块粒度的 MoBA 的对比。
3.2 MoBA 与全注意力的混合使用
- MoBA/全注意力混合训练:验证了在训练过程中从 MoBA 平滑切换到全注意力的有效性。
- 实验内容:训练了三个 1.5B 模型:1) 混合模型(90% token 使用 MoBA,10% 使用全注意力);2) 纯全注意力模型;3) 纯 MoBA 模型。使用逐位置 LM 损失进行评估。
- 实验结果:如图 5a 所示,纯 MoBA 模型在尾部词元上有较高的损失,而混合训练策略的损失曲线与纯全注意力模型几乎完全相同。切换过程中未观察到明显的损失尖峰。
- 结论:MoBA/全注意力混合训练是一种在平衡训练效率和模型性能方面非常有效的策略,展示了 MoBA 的灵活性和鲁棒性。
图 5: MoBA 和全注意力的混合。(a) MoBA、全注意力和 MoBA/全注意力混合训练的逐位置 LM 损失;(b) SFT LM 损失与层级混合中全注意力层数的关系;(c) SFT 尾部 LM 损失(序列长度=32K,最后2K)与层级混合中全注意力层数的关系。
- 层级混合策略:为解决 MoBA 在监督微调(SFT)中有时性能不佳的问题,提出了层级混合策略。
- 实验内容:在 SFT 阶段,将 Transformer 的最后几层从 MoBA 切换为全注意力,其余层保持 MoBA。
- 实验结果:如图 5b 和 5c 所示,增加全注意力层的数量可以显著降低 SFT 损失和尾部 SFT 损失。
- 结论:层级混合策略有效解决了 SFT 中因损失掩码(loss masking)可能导致的稀疏梯度问题。
图 6: 持续预训练和 SFT 的流程。
3.3 大型语言模型评估
- 真实世界下游任务评估:在多个长上下文基准上评估了基于 Llama 3.1 8B 训练的 Llama-8B-1M-MoBA 模型。
- 实验设置:模型通过持续预训练将上下文长度逐步扩展到 1M。采用了层级混合策略(最后3层为全注意力)。在所有评估任务中,预填充(prefill)阶段使用 MoBA,生成(generation)阶段切换为全注意力。
- 实验结果:如表 2 所示,Llama-8B-1M-MoBA 的性能与全注意力基线 Llama-8B-1M-Full 高度相当。在最长的 RULER 基准测试中,两者得分几乎持平(0.7818 vs 0.7849)。在“大海捞针”测试中,如图 7 所示,模型在 1M 上下文长度下也表现出令人满意的性能。
- 结论:MoBA 在真实世界的长上下文任务中能够达到与全注意力相当的性能水平。
表 2: MoBA 与全注意力在不同评估基准上的性能对比。
图 7: Llama-8B-1M-MoBA 在大海捞针基准测试上的性能(上下文长度高达 1M)。
3.4 效率与可扩展性
- 计算效率对比:比较了 Llama-8B-1M-MoBA 和 Llama-8B-1M-Full 在注意力层前向传播的时间。
- 实验结果:如图 2a 所示,MoBA 在所有上下文长度上都比全注意力更高效,展现出亚二次方的计算复杂度。在处理 1M 词元时,速度提升高达 6.5 倍。
- 长度可扩展性:将上下文长度逐步增加到 1000 万词元,同时保持恒定的注意力稀疏度。
- 实验结果:如图 2b 所示,MoBA 在扩展到更长序列时表现出卓越的效率。在 10M 词元时,注意力计算时间减少了 16 倍。
- 结论:MoBA 的高效率得益于其块稀疏注意机制和结合 MoE 与 FlashAttention 的优化实现,有效解决了全注意力的二次复杂度瓶颈。
A5 结论
本文介绍了混合块注意力(MoBA),一种受混合专家(MoE)启发的创新注意力架构,旨在提升大语言模型(LLM)在长上下文任务中的效率和可扩展性。MoBA 通过将上下文分块并采用动态门控机制,将查询路由至最相关的 KV 块,从而在降低计算复杂度的同时保持了模型性能,并实现了全注意力与稀疏注意力之间的无缝切换。
大量实验证明,MoBA 在性能上与全注意力相当,同时显著提升了计算效率。结果表明,MoBA 能够有效扩展至长上下文,在各类基准测试中保持低 LM 损失和高性能。此外,MoBA 的灵活性使其能够以较低的训练成本集成到现有模型中,为提升 LLM 的长上下文能力提供了一个实用的持续预训练解决方案。
总而言之,MoBA 代表了高效注意力领域的一项重要进展,在性能和效率之间取得了平衡。未来的工作可能包括进一步优化 MoBA 的块选择策略,探索其在其他模态上的应用,以及研究其在提升复杂推理任务泛化能力方面的潜力。
A6 附录
A.1 长上下文可扩展性
关注尾部词元损失的动机:为了解决自然数据分布中偏向短上下文的偏差,我们策略性地根据实际位置将整个序列分割成离散的片段。例如,跨越位置 30K-32K 的片段仅反映与上下文长度超过 30K 的文档相关的损失,并且也对从 30K 到 32K 的位置进行了掩码。这种方法确保了在不同上下文长度上进行更均衡和有代表性的评估。
关键发现与简化评估:在探索长上下文可扩展性的过程中,我们做出了一个关键发现:尾部词元(trailing tokens)是导致全上下文基线与新提出的稀疏注意力架构之间性能差异的主要原因。因此,我们通过专注于尾部词元扩展来简化长上下文的扩展过程。这不仅简化了计算需求,还显著提高了研究长上下文场景的效率和有效性。这一发现对未来开发更高效、更可扩展的注意力机制具有重要意义。
图 8: 位置 0-16k 的扩展定律
图 8: 位置 16-32k 的扩展定律
表 3: 不同位置的损失扩展
💬 评论讨论
欢迎在这里分享您的想法和见解!