作者/机构: Albert Gu (卡内基梅隆大学机器学习系), Tri Dao (普林斯顿大学计算机科学系)

A1 主要贡献

本文旨在解决现有主流基础模型(Foundation Models)骨干架构 Transformer 在处理长序列时计算效率低下的问题。Transformer 的核心注意力机制虽然效果强大,但其在上下文窗口内的密集信息路由能力导致了两个根本性缺陷:无法对窗口外的内容进行建模,以及计算复杂度随序列长度呈二次方增长。尽管已有许多亚二次方时间复杂度的架构(如线性注意力、门控卷积、循环模型和结构化状态空间模型SSM)被提出,但它们在语言等重要模态上的性能未能超越注意力机制。

核心问题: 现有高效的序列模型(如SSM)缺乏一种关键能力——基于内容进行推理(content-based reasoning),这限制了它们在信息密集型数据(如文本)上的表现。

研究目标: 提出一种新的序列模型架构,既能保持 Transformer 的强大建模能力,又能实现随序列长度线性扩展的计算效率,从而成为适用于各类序列数据的通用基础模型骨干。

创新点/主要贡献:

  1. 选择机制 (Selection Mechanism):

    • 本文发现先前 SSM 模型的一个关键弱点是它们无法根据输入数据动态地选择信息(即关注或忽略特定输入)。
    • 为了解决这个问题,本文提出了一种简单的选择机制:让SSM的参数成为输入的函数。这一改进使得模型能够根据当前词元(token)选择性地在序列维度上传播或遗忘信息,从而有效处理离散和信息密集的模态。
  2. 硬件感知的高效算法 (Hardware-aware Algorithm):

    • 将SSM参数变为输入依赖(时变)后,无法再使用高效的卷积运算。
    • 为克服这一技术挑战,本文设计了一种硬件感知的并行算法。该算法在循环模式下通过扫描(scan)操作来计算模型,但通过巧妙利用GPU内存层级结构(在高速SRAM中计算,避免在慢速HBM中物化扩展状态),避免了IO瓶颈。
    • 该实现不仅在理论上实现了线性时间复杂度(优于所有基于卷积的SSM的伪线性复杂度),在现代硬件(A100 GPU)上也比先前方法快3倍。
  3. Mamba 架构 (Architecture):

    • 本文将先前SSM架构的设计与Transformer中的MLP模块相结合,提出了一个简化且同质化的端到端神经网络架构——Mamba。该架构不包含注意力机制,甚至没有独立的MLP块,而是由集成了选择性SSM的Mamba模块堆叠而成。

Mamba的核心特性:

  • 高质量: 选择机制在语言和基因组学等密集模态上带来了强大的性能。
  • 快速训练和推理: 训练时的计算和内存随序列长度线性扩展。推理时,由于无需缓存先前元素(没有KV-cache),每一步自回归生成仅需常数时间,吞吐量比Transformer高5倍。
  • 长上下文: 高质量和高效率的结合使其在真实数据上的性能能够随序列长度增长而提升,最长可达一百万。

实证成果: Mamba 作为一种通用的序列模型骨干,在多种模T态上均取得了最先进的性能。
- 合成任务: 轻松解决了如选择性复制(Selective Copying)和归纳头(Induction Heads)等对大语言模型至关重要的任务,并能无限外推(超过1M词元)。
- 音频和基因组学: 在音频波形和DNA序列建模方面,Mamba在预训练质量和下游任务指标上均优于SaShiMi、Hyena和Transformers等先前SOTA模型。
- 语言建模: Mamba是首个真正达到Transformer级别性能的线性时间序列模型。Mamba-3B模型在预训练和下游评估中均优于同等规模的Transformers,并与两倍于其规模的Transformers性能相当。

图1:(概述。)结构化SSM将输入的每个通道(例如 D=5)通过一个更高维的潜在状态h(例如 N=4)独立地映射到输出y。先前的SSM通过巧妙的替代计算路径(需要时间不变性)避免了物化这个巨大的有效状态(大小为L×D×N,再乘以批次大小B):(Δ, A, B, C)参数在时间上是恒定的。我们的选择机制重新引入了输入依赖的动态性,这也需要一个精心的硬件感知算法,只在更高效的GPU内存层级中物化扩展状态。
图1:(概述。)结构化SSM将输入的每个通道(例如 D=5)通过一个更高维的潜在状态h(例如 N=4)独立地映射到输出y。先前的SSM通过巧妙的替代计算路径(需要时间不变性)避免了物化这个巨大的有效状态(大小为L×D×N,再乘以批次大小B):(Δ, A, B, C)参数在时间上是恒定的。我们的选择机制重新引入了输入依赖的动态性,这也需要一个精心的硬件感知算法,只在更高效的GPU内存层级中物化扩展状态。

A3 背景知识与设计动机

状态空间模型(State Space Models)

结构化状态空间序列模型(S4) 是一类新兴的深度学习序列模型,与RNN、CNN和经典状态空间模型广泛相关。它们受到一个特定连续系统的启发,该系统通过一个隐式潜在状态 $h(t) \in \mathbb{R}^N$ 将一维函数或序列 $u(t) \in \mathbb{R}$ 映射到 $y(t) \in \mathbb{R}$。

公式 (1)
公式 (1)

公式 (2)
公式 (2)

公式 (3)
公式 (3)

S4模型由四个参数(Δ, A, B, C)定义,通过两个阶段完成序列到序列的转换。

  • 离散化 (Discretization):第一阶段通过固定的公式 $\bar{\mathbf{A}} = f_A(\Delta, \mathbf{A})$ 和 $\bar{\mathbf{B}} = f_B(\Delta, \mathbf{A}, \mathbf{B})$ 将“连续参数”(Δ, A, B)转换为“离散参数”($\bar{\mathbf{A}}, \bar{\mathbf{B}}$)。这对 $(f_A, f_B)$ 称为离散化规则,例如零阶保持(ZOH)规则如公式(4)所示。从机械角度看,离散化可视为SSM前向传播计算图的第一步。

    公式 (4)
    公式 (4)

  • 计算 (Computation):参数从 (Δ, A, B, C) 转换为 ($\bar{\mathbf{A}}, \bar{\mathbf{B}}, \mathbf{C}$) 后,模型可以通过两种方式计算:线性循环(公式2)或全局卷积(公式3)。通常,模型在训练时使用高效并行的卷积模式,在自回归推理时切换到高效的循环模式

线性时间不变性(LTI) 是公式(1)到(3)的一个重要特性,即模型的动态性在时间上是恒定的。换句话说,(Δ, A, B, C) 以及因此得到的 ($\bar{\mathbf{A}}, \bar{\mathbf{B}}$) 对所有时间步都是固定的。LTI特性与循环和卷积紧密相关。迄今为止,所有结构化SSM都是LTI模型,这是因为存在根本的效率约束。然而,本文的一个核心洞见是LTI模型在建模某些类型的数据时存在根本局限,本文的技术贡献在于移除LTI约束的同时克服效率瓶颈。

结构和维度 结构化SSM为了高效计算,需要对 A 矩阵施加结构,最流行的结构是对角结构,本文也采用此结构。在这种情况下,$\mathbf{A} \in \mathbb{R}^{N \times N}, \mathbf{B} \in \mathbb{R}^{N \times 1}, \mathbf{C} \in \mathbb{R}^{1 \times N}$ 矩阵都可以用 $N$ 个数字表示。对于一个批次大小为 $B$、长度为 $L$、通道数为 $D$ 的输入序列 $x$,SSM会独立应用于每个通道。此时,总的隐藏状态维度为 $D \times N$,在整个序列上计算它需要 $O(BDLN)$ 的时间和内存,这是本文在第3.3节中解决的根本效率瓶颈的根源。

通用状态空间模型 “状态空间模型”一词含义广泛,泛指任何具有潜在状态的循环过程。本文中,“SSM”特指结构化SSM或S4模型及其衍生物。

SSM架构 SSM是可集成到端到端神经网络架构中的独立序列变换模块。一些知名的SSM架构包括:
- 线性注意力 (Linear attention):可视为一种退化的线性SSM。
- H3:可以看作是在一个SSM两侧夹了两个门控连接的架构。
- Hyena: 与H3架构相同,但用一个MLP参数化的全局卷积替代了S4层。
- RetNet: 在架构中增加了一个额外的门,并使用一个更简单的SSM,允许一种替代的并行计算路径。
- RWKV: 一种基于另一种线性注意力近似的现代RNN,其核心“WKV”机制涉及LTI循环,可看作是两个SSM的比值。

设计动机:作为压缩手段的选择机制

序列建模的根本问题是将上下文压缩成一个更小的状态。 不同的序列模型在此问题上做出了不同的权衡。注意力机制之所以有效但低效,是因为它完全不压缩上下文(自回归推理需要存储完整的KV缓存)。相反,循环模型之所以高效,是因为它们具有有限的状态,但其效果受限于这个状态对上下文的压缩程度。

为了理解这一原则,本文关注两个合成任务示例。
- 选择性复制 (Selective Copying) 任务通过改变需要记忆的词元位置,要求模型具备内容感知能力,以记忆相关词元(彩色)并过滤掉不相关词元(白色)。
- 归纳头 (Induction Heads) 任务是一个被假设为解释大语言模型大部分上下文学习能力的关键机制。它要求模型具备上下文感知能力,以便在适当的上下文中产生正确的输出。

图2:(左)标准版的复制任务,输入和输出元素之间有固定的间距,很容易被时间不变模型(如线性循环和全局卷积)解决。(右上)选择性复制任务,输入之间有随机的间距,需要能够根据内容选择性记忆或忽略输入的时变模型。(右下)归纳头任务是联想回忆的一个例子,需要根据上下文检索答案,这是LLM的一项关键能力。
图2:(左)标准版的复制任务,输入和输出元素之间有固定的间距,很容易被时间不变模型(如线性循环和全局卷积)解决。(右上)选择性复制任务,输入之间有随机的间距,需要能够根据内容选择性记忆或忽略输入的时变模型。(右下)归纳头任务是联想回忆的一个例子,需要根据上下文检索答案,这是LLM的一项关键能力。

这些任务揭示了LTI模型的失败模式。 从循环角度看,它们恒定的动态性(如公式2中的($\bar{\mathbf{A}}, \bar{\mathbf{B}}$)转移)无法让它们从上下文中选择正确的信息,或以输入依赖的方式影响沿序列传递的隐藏状态。从卷积角度看,全局卷积虽然能解决常规复制任务(因为它只需时间感知),但在选择性复制任务上遇到困难,因为缺乏内容感知能力,无法处理输入输出之间变化的间距。

总结来说,序列模型的效率与效果权衡取决于它们压缩状态的好坏。 高效模型必须有小状态,而有效模型必须有包含所有必要上下文信息的状态。因此,本文提出构建序列模型的一个基本原则是选择性(selectivity):即具备上下文感知能力,以关注或过滤输入,并将其整合到一个序列状态中。选择机制控制信息如何沿序列维度传播或交互。

A2 方法细节

3.2 通过选择机制改进SSM

将参数变为输入依赖。在模型中引入选择机制的一种方法是,让那些影响序列交互的参数(例如RNN的循环动态或CNN的卷积核)依赖于输入。

算法1和2
算法1和2

算法1和2展示了本文使用的主要选择机制。核心区别在于,将几个参数 Δ, B, C 变为输入的函数,并对张量形状进行相应调整。特别地,这些参数现在有了一个长度维度 $L$,意味着模型从时间不变变成了时间变化。这使得模型不再等价于卷积(公式3),对其效率产生了影响,将在下一节讨论。

参数化的具体选择。本文具体选择 $\mathbf{B}(x_t) = \text{Linear}_N(x_t)$, $\mathbf{C}(x_t) = \text{Linear}_N(x_t)$, $\tau_\Delta(x_t) = \text{Broadcast}_D(\text{Linear}_1(x_t))$, 并且 $\Delta_t = s_\Delta(\tau_\Delta(x_t))$,其中 $s_\Delta = \text{softplus}$,$\text{Linear}_d$ 是一个参数化的投影,将输入投影到维度 $d$。选择 $s_\Delta$ 和 $\tau_\Delta$ 是基于与RNN门控机制的联系,将在3.5节中解释。

3.3 选择性SSM的高效实现

先前模型的动机回顾。我们首先回顾先前方法的动机,并概述我们如何克服它们的局限性。
- 从高层次看,像SSM这样的循环模型总是在表达能力速度之间进行权衡:如3.1节所述,具有更大隐藏状态维度的模型应该更有效但更慢。因此,我们的目标是在不牺牲速度和内存成本的情况下最大化隐藏状态维度。
- 循环模式比卷积模式更灵活,因为后者(公式3)是通过展开前者(公式2)得到的。然而,这需要计算并物化形状为 (B, L, D, N) 的潜在状态 $h$,这个状态比输入 $x$ 和输出 $y$ 的形状 (B, L, D) 大得多(大了 $N$ 倍,即SSM状态维度)。因此,引入了更高效的卷积模式,它可以绕过状态计算,只物化一个大小仅为 (B, L, D) 的卷积核(公式3a)。
- 先前的LTI状态空间模型利用循环-卷积的双重形式,将有效状态维度增加了 $N$ 倍(约10-100),远大于传统RNN,且没有效率损失。

选择性扫描概述:硬件感知的状态扩展。选择机制旨在克服LTI模型的局限性;同时,我们因此需要重新审视SSM的计算问题。我们通过三种经典技术来解决这个问题:核函数融合(kernel fusion)、并行扫描(parallel scan)和重计算(recomputation)。我们有两个主要观察:
- 朴素的循环计算使用 $O(BDLN)$ FLOPs,而卷积计算使用 $O(BDL \log(L))$ FLOPs,前者的常数因子更低。因此,对于长序列和不太大的状态维度 $N$,循环模式实际上可能使用更少的FLOPs。
- 两个挑战是循环的顺序性大的内存使用。为了解决后者,就像卷积模式一样,我们可以尝试不实际物化完整的状态 $h$。

核心思想是利用现代加速器(GPU)的特性,只在内存层级结构中更高效的层次上物化状态 $h$。具体来说,大多数操作(矩阵乘法除外)都受内存带宽的限制。这包括我们的扫描操作,我们使用核函数融合来减少内存IO的数量,从而与标准实现相比,速度得到显著提升。具体地,我们不是在GPU HBM(高带宽内存)中准备大小为 (B, L, D, N) 的扫描输入($\bar{\mathbf{A}}, \bar{\mathbf{B}}$),而是将SSM参数(Δ, A, B, C)直接从慢速HBM加载到快速SRAM中,在SRAM中执行离散化和循环,然后将大小为 (B, L, D) 的最终输出写回HBM。

为避免顺序循环,我们观察到尽管它不是线性的,但仍可以用一种工作高效的并行扫描算法来并行化【10, Blelloch 1990; 71, Martin and Cundy 2018; 98, Smith, Warrington, and Linderman 2023】。

最后,我们必须避免保存中间状态,这些状态是反向传播所必需的。我们谨慎地应用经典的重计算技术来减少内存需求:中间状态不被存储,而是在反向传播中当输入从HBM加载到SRAM时重新计算。因此,融合的选择性扫描层具有与使用FlashAttention的优化版Transformer实现相同的内存需求。融合核和重计算的细节在附录D中。完整的选择性SSM层和算法如图1所示。

3.4 简化的SSM架构

与结构化SSM一样,选择性SSM是独立的序列变换,可以灵活地集成到神经网络中。H3架构是大多数知名SSM架构的基础,通常由一个受线性注意力启发的模块和一个MLP(多层感知器)模块交错组成。我们通过将这两个组件合并为一个,并进行同质化堆叠,简化了该架构(图3)。这一设计的灵感来自于门控注意力单元(GAU)【53, Hua et al. 2022】,它对注意力机制做了类似的事情。

图3:(架构。)我们简化的块设计结合了H3块(大多数SSM架构的基础)和现代神经网络中无处不在的MLP块。我们不是交错这两个块,而是简单地同质化地重复Mamba块。与H3块相比,Mamba用激活函数替换了第一个乘法门。与MLP块相比,Mamba在主分支上增加了一个SSM。对于σ,我们使用SiLU / Swish激活函数【48, Hendrycks and Gimpel 2016; 89, Ramachandran, Zoph, and Quoc V Le 2017】。
图3:(架构。)我们简化的块设计结合了H3块(大多数SSM架构的基础)和现代神经网络中无处不在的MLP块。我们不是交错这两个块,而是简单地同质化地重复Mamba块。与H3块相比,Mamba用激活函数替换了第一个乘法门。与MLP块相比,Mamba在主分支上增加了一个SSM。对于σ,我们使用SiLU / Swish激活函数【48, Hendrycks and Gimpel 2016; 89, Ramachandran, Zoph, and Quoc V Le 2017】。

该架构通过一个可控的扩展因子 E 来扩展模型维度 D。对于每个块,大部分参数($3ED^2$)位于线性投影中($2ED^2$用于输入投影,$ED^2$用于输出投影),而内部SSM贡献的参数较少。SSM的参数数量(用于Δ, B, C的投影,以及矩阵A)相对要小得多。我们重复这个块,并与标准的归一化和残差连接交错,以构成Mamba架构。我们总是在实验中固定 $E=2$,并使用两个堆叠的块来匹配Transformer中交错的MHA(多头注意力)和MLP块的$12D^2$参数。我们使用SiLU / Swish激活函数,这样门控MLP就变成了流行的“SwiGLU”变体【16, Chowdhery et al. 2023; 22, Dauphin et al. 2017; 95, Shazeer 2020; 105, Touvron et al. 2023】。最后,我们还额外使用了一个可选的归一化层(我们选择LayerNorm【4, J. L. Ba, Kiros, and Hinton 2016】),这是受到RetNet在类似位置使用归一化层的启发【100, Y. Sun et al. 2023】。

3.5 选择机制的特性

选择机制是一个更广泛的概念,可以以不同方式应用,例如应用于更传统的RNN或CNN,应用于不同的参数(例如算法2中的 A),或使用不同的变换 $s(x)$。

3.5.1 与门控机制的联系

我们强调最重要的联系:RNN的经典门控机制是SSM选择机制的一个实例。RNN门控与连续时间系统离散化之间的联系是公认的【32, Funahashi and Nakamura 1993; 102, Tallec and Ollivier 2018】。实际上,定理1是对【40, Gu, Johnson, Goel, et al. 2021, Lemma 3.1】的改进,推广到了ZOH离散化和输入依赖的门(证明在附录C)。更广泛地看,SSM中的Δ可以被视为扮演了RNN门控机制的广义角色。与先前工作一致,我们认为SSM的离散化是启发式门控机制的原则性基础。

定理1。当 $N=1, A=-1, C=1, \tau_\Delta = \text{Linear}(x)$, 且 $s_\Delta = \text{softplus}$ 时,选择性SSM循环(算法2)的形式为:

公式 (5)
公式 (5)

如3.2节所述,我们对 $\tau_\Delta, s_\Delta$ 的具体选择正是源于这种联系。特别地,请注意,如果某个给定的输入 $x_t$ 应该被完全忽略(如在合成任务中必需的那样),所有 $D$ 个通道都应该忽略它,因此我们在通过广播操作重复Δ之前,将输入投影到1维。

3.5.2 选择机制的诠释

我们阐述了选择机制的三个具体机理效果
- 可变间距 (Variable Spacing)。选择性允许过滤掉可能出现在感兴趣输入之间的不相关噪声词元。这在选择性复制任务中得到了体现,但在常见的数据模态中普遍存在,尤其是在离散数据中——例如语言中“嗯”之类的填充词。这个特性之所以出现,是因为模型可以机械地过滤掉任何特定的输入 $x_t$,例如在门控RNN的情况下(定理1),当 $\bar{\mathbf{B}}_t \to 0$ 时。
- 过滤上下文 (Filtering Context)。经验观察表明,许多序列模型并不会随着上下文的增长而改善【96, F. Shi et al. 2023】,尽管原则上更长的上下文应该带来更好的性能。一种解释是,许多序列模型在必要时无法有效忽略不相关的上下文;一个直观的例子是全局卷积(以及一般的LTI模型)。另一方面,选择性模型可以随时重置其状态以移除多余的历史,因此其性能原则上随上下文长度单调提高(例如4.3.2节)。
- 边界重置 (Boundary Resetting)。在将多个独立序列拼接在一起的场景中,Transformer可以通过实例化特定的注意力掩码来保持它们的分离,而LTI模型则会在序列之间泄露信息。选择性SSM也可以在边界处重置其状态(例如 $\Delta_t \to \infty$,或者定理1中当 $\bar{\mathbf{A}}_t \to 1$ 时)。这些场景可能人为出现(例如,为了提高硬件利用率而将文档打包在一起)或自然发生(例如,强化学习中的情节边界【68, Lu et al. 2023】)。

此外,我们详细阐述了每个选择性参数的效果
- Δ的诠释。总的来说,Δ 控制着对当前输入 $x_t$ 的关注或忽略程度。它推广了RNN的门控(例如定理1中的 $\bar{\mathbf{B}}_t$):从机理上讲,大的Δ会重置状态 $h$ 并专注于当前输入 $x$,而小的Δ会保持状态并忽略当前输入。
- A的诠释。我们注意到,虽然 A 参数也可以是选择性的,但它最终只通过与Δ的相互作用影响模型,即 $\bar{\mathbf{A}} = \exp(\Delta \mathbf{A})$(离散化公式4)。因此,Δ的选择性足以确保($\bar{\mathbf{A}}, \bar{\mathbf{B}}$)的选择性,并且是性能提升的主要来源。我们假设,在Δ之外(或替代Δ)使A具有选择性会产生类似的性能,为简单起见我们将其省略。
- B和C的诠释。如3.1节所述,选择性最重要的特性是过滤掉不相关的信息,以便将序列模型的上下文压缩到一个高效的状态中。在SSM中,将BC修改为选择性的,允许更精细地控制是否让输入 $x_t$ 进入状态 $h_t$,或让状态进入输出 $y_t$。这可以被解释为允许模型根据内容(输入)和上下文(隐藏状态)分别调节循环动态。

3.6 额外的模型细节

实数与复数。大多数先前的SSM在其状态 $h$ 中使用复数,这对于在感知模态中许多任务上取得强劲性能是必要的【37, Gu, Goel, and Ré 2022】。然而,经验观察表明,在某些情况下,完全实值的SSM似乎效果不错,甚至可能更好【70, Ma et al. 2023】。我们默认使用实数值,这在我们除了一项任务外的所有任务中都表现良好;我们假设复数-实数的权衡与数据模态中的连续-离散谱有关,其中复数对连续模态(如音频、视频)有帮助,但对离散模态(如文本、DNA)则不然。

初始化。大多数先前的SSM也建议特殊的初始化,特别是在复数值情况下,这在几种情况下(如低数据量场景)有所帮助。我们对于复数情况的默认初始化是S4D-Lin,对于实数情况是S4D-Real【39, Gu, Gupta, et al. 2022】,这基于HIPPO理论【36, Gu, Dao, et al. 2020】。然而,我们预计许多初始化方法都能工作得很好,特别是在大数据和实值SSM的场景中;一些消融研究在4.6节中进行了考虑。

Δ的参数化。我们将对Δ的选择性调整定义为 $\tau_\Delta(x) = \text{Broadcast}_N(\text{Linear}_1(x))$,这是由Δ的机理所驱动的(3.5节)。我们观察到,这可以从维度1推广到更大的维度R。我们将其设置为D的一小部分,与块中的主要线性投影相比,使用的参数数量可以忽略不计。我们还注意到,广播操作可以被看作是另一个线性投影,初始化为特定的1和0模式;如果这个投影是可训练的,这将导致另一种选择 $\tau_\Delta(x) = \text{Linear}_R(\text{Linear}_R(x))$,可以看作是一个低秩投影。在我们的实验中,Δ参数(可看作一个偏置项)被初始化为 $s_\Delta^{-1}(\text{Uniform}([0.001, 0.1]))$,遵循了先前关于SSM的工作【41, Gu, Johnson, Timalsina, et al. 2023】。

备注3.1。为了在我们的实验结果中保持简洁,我们有时将选择性SSM缩写为S6模型,因为它们是具有选择机制并通过扫描计算的S4模型。

A4 实验

实验环境

  • 数据集:
    • 语言模型: The Pile数据集【33, L. Gao, Biderman, et al. 2020】,包含800GB的多样化文本。
    • DNA序列: HG38数据集(人类基因组),训练集包含约45亿个DNA碱基对。下游任务使用5种大型猿类的基因组进行分类。
    • 音频波形: YouTubeMix【23, DeepSound 2017】,一个包含4小时独奏钢琴音乐的标准数据集,采样率为16000 Hz。SC09数据集【109, Warden 2018】用于语音生成,包含0-9数字的1秒音频片段。
    • 合成任务: Selective Copying任务和Induction Heads任务。
  • 模型架构:
    • Mamba: 核心模型,参数量从约1.25亿到69亿不等。
    • 基线模型: 标准Transformer (GPT-3架构)、Transformer++ (基于PaLM和LLaMa的优化配方)、Hyena、H3、RetNet、RWKV、SaShiMi等。
    • 模型尺寸(深度和宽度)通常参照GPT-3规范设置。
  • 硬件配置:
    • GPU: 实验在A100 80GB PCIe GPU上进行基准测试和训练。
  • 软件配置:
    • 实现: 模型代码和预训练检查点在https://github.com/state-spaces/mamba开源。
    • 依赖库: PyTorch。基准测试中与FlashAttention-2【19, Dao 2024】进行比较。下游评估使用EleutherAI的LM evaluation harness【34, L. Gao, Tow, et al. 2021】。

实验结果

合成任务

  • 选择性复制(Selective Copying): 实验结果(表1)证实,包含门控的架构(如H3和Mamba)性能仅部分提升,而引入选择机制(将S4升级为S6)能轻松解决此任务,尤其是在与更强大的架构结合时。这表明选择机制是实现内容感知推理的关键。
  • 归纳头(Induction Heads): Mamba(更确切地说是其选择性SSM层)能够完美解决该任务,因为它能选择性地记住相关词元而忽略中间的所有其他内容。如表2所示,Mamba能完美地泛化到百万长度的序列,比其训练长度长4000倍,而其他方法最多只能泛化到2倍。
架构 内部层 准确率(%)
GatedMLP S4 (LTI) 2.0
H3 S4 (LTI) 3.1
Mamba S4 (LTI) 3.1
GatedMLP S6 (选择性) 99.9
H3 S6 (选择性) 100.0
Mamba S6 (选择性) 100.0
表1: (选择性复制。) 架构与内部序列层组合的准确率。
模型 L=64 L=256 (训练) ... L=1,048,576
MHA-Abs 99.8 ...
MHA-RoPE 99.9 ...
MHA-xPos 99.9 ...
H3 (S4) 99.8 ... 50.3
Hyena 99.8 ... 50.2
Mamba (S6) 100.0 ... 100.0
表2: (归纳头。) 模型在序列长度256上训练,并在64到1,048,576的序列长度上测试。

语言建模

  • 扩展定律(Scaling Laws): 如图4所示,在1.25亿到13亿参数的模型规模上,Mamba是首个性能与一个非常强大的Transformer配方(Transformer++)相匹配的无注意力模型,尤其是在序列长度增加时,其优势更加明显。
  • 下游评估: 如表3所示,在一系列流行的零样本下游评估任务上,Mamba在其各自的模型规模上均是同类最佳,并且通常能与两倍于其大小的基线模型(如Pythia)性能相匹配。例如,Mamba-3B在常识推理任务上的平均得分比Pythia-3B高4个点,甚至超过了Pythia-7B。

图4:(扩展定律。)模型大小从约1.25亿到约13亿参数,在Pile数据集上训练。Mamba的扩展性优于所有其他无注意力模型,并且是第一个在性能上与一个非常强大的“Transformer++”配方相匹配的模型,尤其是在序列长度增长时。
图4:(扩展定律。)模型大小从约1.25亿到约13亿参数,在Pile数据集上训练。Mamba的扩展性优于所有其他无注意力模型,并且是第一个在性能上与一个非常强大的“Transformer++”配方相匹配的模型,尤其是在序列长度增长时。

表3:(零样本评估。)Mamba在每个模型规模上均是同类最佳,通常与两倍大小的基线模型性能相当。
表3:(零样本评估。)Mamba在每个模型规模上均是同类最佳,通常与两倍大小的基线模型性能相当。

DNA建模

  • 模型规模扩展: 如图5(左)所示,在固定上下文长度为1024时,Mamba的预训练困惑度随模型规模平滑提升,并且扩展性优于HyenaDNA和Transformer++。
  • 上下文长度扩展: 如图5(右)所示,在固定模型大小时,Mamba能够利用长达100万的上下文,其预训练困惑度随上下文增加而改善。相反,HyenaDNA模型的性能随序列长度增加而变差,这直观地验证了选择机制能够过滤无关信息的特性。
  • 合成物种分类: 在一项极具挑战性的下游任务——对DNA相似度高达99%的五种大型猿类进行分类中,Mamba在所有序列长度上都取得了更高的准确率,尤其是在长序列上优势更为明显(图6)。

图5:(DNA扩展定律。)在HG38(人类基因组)数据集上预训练。(左)固定短上下文长度1024,增加模型大小,Mamba的扩展性优于基线。(右)固定模型大小,增加序列长度,Mamba的选择机制促进了在更长上下文下的性能提升,而基线模型则不然。
图5:(DNA扩展定律。)在HG38(人类基因组)数据集上预训练。(左)固定短上下文长度1024,增加模型大小,Mamba的扩展性优于基线。(右)固定模型大小,增加序列长度,Mamba的选择机制促进了在更长上下文下的性能提升,而基线模型则不然。

图6:(大型猿类DNA分类。)在使用相同上下文长度的预训练模型上,对长度从1024到1,048,576的序列进行微调后的准确率。
图6:(大型猿类DNA分类。)在使用相同上下文长度的预训练模型上,对长度从1024到1,048,576的序列进行微调后的准确率。

音频建模与生成

  • 长上下文自回归预训练: 如图7所示,在YouTubeMix钢琴音乐数据集上,Mamba在自回归音频建模方面优于先前的SOTA模型(SaShiMi),并且在长达一分钟(约100万个采样点)的上下文中性能持续提升。值得注意的是,这是本文唯一一个使用复数参数化的实验。
  • 自回归语音生成: 在SC09语音生成基准测试中,如表4所示,一个小型Mamba模型在自动化指标上就超越了最先进的(且大得多的)基于GAN和扩散的模型。与基线模型参数匹配的更大Mamba模型在保真度指标上取得了更显著的提升。

图7:(音频预训练。)Mamba在自回归音频建模方面优于先前的SOTA模型(Sashimi),并且在长达一分钟或百万长度序列的上下文中性能持续提升(在控制计算量的情况下)。
图7:(音频预训练。)Mamba在自回归音频建模方面优于先前的SOTA模型(Sashimi),并且在长达一分钟或百万长度序列的上下文中性能持续提升(在控制计算量的情况下)。

表4:(SC09)在具有挑战性的定长语音片段数据集上进行无条件生成的自动化指标。
表4:(SC09)在具有挑战性的定长语音片段数据集上进行无条件生成的自动化指标。

速度和内存基准

  • 训练速度: 如图8(左)所示,本文的高效扫描实现比标准PyTorch实现快40倍,在序列长度超过2K时比FlashAttention-2更快。
  • 推理吞吐量: 如图8(右)所示,作为循环模型,Mamba由于没有KV缓存,可以使用更大的批次大小,从而实现了比同等规模Transformer高4-5倍的推理吞吐量。
  • 内存: Mamba的内存占用与最优化的Transformer实现相当(表15)。

图8:(效率基准。)(左)训练:我们高效的扫描实现比标准实现快40倍。(右)推理:作为循环模型,Mamba可以实现比Transformers高5倍的吞吐量。
图8:(效率基准。)(左)训练:我们高效的扫描实现比标准实现快40倍。(右)推理:作为循环模型,Mamba可以实现比Transformers高5倍的吞吐量。

模型消融实验

  • 架构和SSM层(表6): Mamba块的性能与H3相似但更简单。性能的主要提升来自于选择性SSM(S6),而非LTI模型(S4)的具体参数化方式。
  • 选择性参数(表7): Δ是实现选择性最重要的参数,但同时让B和C也具备选择性能产生协同效应,带来进一步提升。
  • SSM状态维度(表10): 增加SSM状态维度N可以显著提升性能,且参数/FLOPs成本几乎不变,但这只有在B和C也具有选择性时才有效,验证了本文的核心动机。

A7 补充细节

本文讨论了相关工作、局限性和一些未来方向。

无免费午餐:连续-离散谱。结构化SSM最初被定义为连续系统的离散化,对连续时间数据模态(如音频、视频)有很强的归纳偏置。如第3.1和3.5节所讨论,选择机制克服了它们在离散模态(如文本和DNA)上的弱点;但这反过来可能会影响它们在LTI SSM擅长的数据上的性能。本文对音频波形的消融实验更详细地研究了这种权衡。

下游可供性(Downstream Affordances)。基于Transformer的基础模型(特别是LLM)拥有丰富的生态系统,包括微调、适配、提示、上下文学习、指令调优、RLHF、量化等多种交互模式。一个特别令人感兴趣的问题是,像SSM这样的Transformer替代品是否具有类似的属性和可供性。

扩展性(Scaling)。本文的实证评估仅限于较小的模型尺寸,低于大多数强大的开源LLM(如Llama)以及其他循环模型(如RWKV和RetNet)的规模,后者已在7B参数规模及以上进行了评估。Mamba在这些更大规模下是否仍具竞争力尚待评估。我们还注意到,扩展SSM可能涉及本文未讨论的进一步工程挑战和模型调整。

A5 结论

本文为结构化状态空间模型引入了一种选择机制,使其能够在序列长度上线性扩展的同时执行上下文相关的推理。当集成到一个简单的无注意力架构中时,Mamba在多种领域取得了最先进的结果,其性能与强大的Transformer模型相当或更优。我们对选择性状态空间模型在为不同领域构建基础模型,特别是在需要长上下文的新兴模态(如基因组学、音频和视频)中的广泛应用感到兴奋。我们的结果表明,Mamba是成为通用序列模型骨干的有力竞争者。

A6 附录

附录 A: 选择机制的讨论

本文的选择机制受到并与门控、超网络和数据依赖等概念相关。 但我们认为它是一个值得澄清的独特概念。
- 门控(Gating): 最初指LSTM和GRU等RNN的门控机制,控制信息如何随时间传播并沿序列长度维度交互。现在,“门控”一词的用法已泛化为任何乘法交互,即使不沿序列维度交互,这与原始RNN的意义非常不同。
- 超网络(Hypernetworks)和数据依赖(Data-dependence): 这些概念非常宽泛,几乎可以涵盖任何参数依赖于数据的模型,包括注意力机制和简单的GLU激活函数,因此信息量不大。
- 选择(Selection): 我们认为选择机制与传统RNN的门控机制关系最密切(定理1证明了后者是前者的一个特例)。我们使用“选择”一词来特指模型选择或忽略输入并促进数据沿序列长度维度交互的机理作用

附录 B: 相关工作

  • S4变体和衍生物:
    • S4: 引入了第一个结构化SSM,采用对角加低秩(DPLR)结构。
    • DSS/S4D: 发现了对角结构SSM的有效性。
    • S5: 独立发现了对角SSM近似,并且是第一个使用并行扫描进行循环计算的S4模型。但它采用了MIMO(多输入多输出)公式,降低了有效状态维度。S6保留了SISO维度,通过硬件感知算法实现了更大的有效循环状态,并增加了选择机制。
    • Mega: 将S4简化为实数值,并将其解释为指数移动平均(EMA)。
    • Liquid S4: 增加了输入依赖的状态转移,但形式有限,仍以卷积方式计算,接近LTI。
    • SGConv, Hyena, LongConv等: 专注于S4的卷积表示,但无法直接进行快速自回归推理。
  • SSM架构:
    • GSS: 第一个结合SSM的门控神经网络架构。
    • H3: 结合S4与线性注意力的架构,是后续架构的基础。
    • RetNet: 基于线性注意力,可视为线性SSM的一个特例,使用类似MHA的变体进行并行化。
    • RWKV: 基于AFT(无注意力Transformer)的现代RNN,其核心机制可视为两个SSM的比值。
  • 与RNN的关系:
    • 一些较早的RNN如QRNN和SRU,由于其门控机制,可被视为选择性SSM的实例。但它们不使用状态扩展(N>1)或选择性的B, C参数,且其门控是启发式的。
    • 现代结构化SSM通过更仔细的参数化(如离散化)和理论分析,克服了传统RNN的效率和梯度消失问题。

附录 C: 选择性SSM的机理

定理1的证明。考虑一个选择性SSM(算法2),其中 $N=1, A=-1, C=1, \tau_\Delta = \text{Linear}(x), s_\Delta = \text{softplus}$。对应的连续时间SSM(公式1)是一个漏积分器(leaky integrator)。离散化步长为 $\Delta_t = \text{softplus}(\mathbf{v} \cdot \mathbf{x}_t + b)$。应用零阶保持(ZOH)离散化公式:
$\bar{\mathbf{A}}_t = e^{\Delta_t A} = e^{-\Delta_t}$
$\bar{\mathbf{B}}_t = (\Delta_t A)^{-1}(e^{\Delta_t A} - I) B = (-\Delta_t)^{-1}(e^{-\Delta_t} - 1) \cdot 1 = \frac{1 - e^{-\Delta_t}}{\Delta_t}$
最终的离散循环(公式2a)为 $h_t = e^{-\Delta_t} h_{t-1} + \frac{1 - e^{-\Delta_t}}{\Delta_t} x_t$,这与门控RNN的形式相符。

附录 D: 选择性SSM的硬件感知算法

在没有输入依赖的选择性时,SSM可以高效地实现为卷积。有了选择性,SSM不再等价于卷积,但我们可以利用并行关联扫描。我们通过核函数融合重计算使SSM扫描在现代硬件(GPU)上既快速又内存高效。
- 速度:
- 标准实现方式需要在慢速的GPU HBM(高带宽内存)中物化大小为(B, L, D, N)的扫描输入,导致大量内存IO。
- 我们的融合核将离散化、扫描和与C的乘法合并为一个操作:
1. 从HBM读取SSM参数(Δ, A, B, C)到快速SRAM。
2. 在SRAM中离散化得到$\bar{\mathbf{A}}, \bar{\mathbf{B}}$。
3. 执行并行扫描,中间状态仍在SRAM中。
4. 与C相乘求和,得到最终输出并写回HBM。
- 这种方式将IO减少了 $O(N)$ 倍,实际操作中速度提升20-40倍。对于过长的序列,可分块处理。
- 内存:
- 为了避免内存爆炸,我们不在前向传播中保存大小为(B, L, D, N)的中间状态。
- 在反向传播中,我们使用重计算技术,当需要时重新计算这些中间状态。这不仅节省了存储空间,还因为避免了从HBM读取大量中间状态而加快了反向传播的速度。
- 通过对整个选择性SSM块(包括输入投影、激活、扫描等)进行重计算优化,其总内存需求与使用FlashAttention的优化版Transformer实现相当。

附录 E: 实验细节和附加结果

  • 合成任务:
    • 选择性复制: 序列长度4096,词汇表大小16,记忆16个“数据”词元。2层模型,维度D=64。
    • 归纳头: 2层模型,D=64(Mamba)或128(其他),在序列长度256上训练。
  • 语言建模:
    • 扩展定律: 遵循GPT-3配方,在Pile数据集上训练,使用Chinchilla扩展定律调整训练词元数。基线模型包括一个采用旋转位置编码(RoPE)和SwiGLU MLP的Transformer++。Mamba和Transformer++等模型使用了改进的训练配方(更高的学习率、RMSNorm、无偏置等)。
    • 附加消融(图9):
      • 将Mamba块与MLP块交错(Mamba-MLP),性能略有下降但仍强于多数基线。
      • 将Mamba块与MHA块交错(Mamba-MHA),性能仅略微提升,这表明在选择性SSM存在时,注意力的增益可能不如在LTI SSM中那么大。
      • 训练配方的改进(例如从Hyena到Hyena+)能带来巨大性能提升。
  • DNA建模:
    • 预训练: 在HG38数据集上进行。对于不同上下文长度,数据处理方式不同以确保利用全部训练数据。
    • 物种分类: 微调10个epoch,批次大小64。长序列使用序列长度预热(SLW)。Mamba使用1e-4的学习率,Hyena使用4e-5。
  • 音频细节:
    • YouTubeMix预训练: 模型约3.5M参数。数据为mu-law编码的8位离散词元。
    • 附加消融(图10): 在音频这种均匀采样的“连续”信号上,LTI模型(S4)实际上比选择性模型(S6)表现更好,这符合其归纳偏置。然而,如果只在U-Net的内部层使用选择性模型(外层仍为LTI),性能差距会显著缩小,表明靠近原始信号的层更适合LTI,而经过压缩的深层表示则不然。
  • 效率基准:
    • 扫描操作: 在A100 GPU上使用BF16数据类型进行测量,模型维度D=1024,状态维度N=16。
    • 端到端推理: 提示长度2048,生成长度128。在A100 80GB GPU上测量不同批次大小下的吞吐量(tokens/s)。

表11:(归纳头。)模型在序列长度256上训练,并在各种序列长度(从64到1,048,576)上测试。✓表示完美的泛化准确率,而✗表示内存不足。
表11:(归纳头。)模型在序列长度256上训练,并在各种序列长度(从64到1,048,576)上测试。✓表示完美的泛化准确率,而✗表示内存不足。