Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality
标题:Transformer即SSM:通过结构化状态空间对偶性实现的广义模型和高效算法
作者:Tri Dao, Albert Gu
机构:普林斯顿大学计算机科学系,卡内基梅隆大学机器学习系
A1 主要贡献
本文的核心目标是建立结构化状态空间模型(SSMs)与多种注意力变体之间的丰富理论联系,从而将为Transformer开发的算法和系统优化迁移到SSM上,构建性能优于Transformer且序列长度扩展效率更高的基础模型。
核心问题:
1. Transformer模型在训练时存在序列长度的二次方缩放问题,在自回归生成时需要线性大小的缓存。
2. 结构化状态空间模型(SSMs)虽然在训练时实现了线性缩放,在生成时具有恒定的状态大小,但其发展与Transformer社区的努力脱节,导致理论理解、实验和硬件优化上存在困难。
研究目标:
本文旨在通过结构化矩阵的抽象,建立一个连接结构化SSMs与注意力变体的理论框架,称为结构化状态空间对偶性(Structured State Space Duality, SSD)。该框架旨在结合SSMs和注意力的优势,开发新的高效算法和模型架构。
主要贡献与创新点:
1. 理论框架(状态空间对偶性 SSD):通过结构化半可分矩阵(structured semiseparable matrices)的分解,建立了SSMs和注意力变体之间的理论联系。该框架从矩阵变换和张量收缩两个角度揭示了这种对偶性。
* SSM与半可分矩阵的等价性:证明了状态空间模型与一类成熟的结构化矩阵——半可分矩阵是等价的(第3节)。这揭示了SSM的新属性和算法,表明计算SSM的不同方法可以被重新构建为对结构化矩阵的各种矩阵乘法算法。
* 线性注意力的理论改进:通过张量收缩的语言,为线性注意力的循环形式提供了简洁的证明,并将其推广到一个新的结构化掩码注意力(Structured Masked Attention, SMA)家族(第4节)。
* 连接SSM与SMA:证明了SSM和SMA存在一个很大的交集,它们互为对偶,同时拥有类似SSM的线性形式和类似注意力的二次形式(第5节)。同时证明了任何具有快速循环形式的核注意力方法都必须是一个SSM。
2. 高效算法(SSD算法):基于半可分矩阵的块分解,提出了一种新的SSD算法(第6节)。该算法结合了线性SSM循环和二次对偶形式的优点,在训练和推理计算、内存使用以及利用现代硬件的矩阵乘法单元方面实现了最优权衡。
* 性能:专门实现的SSD算法比Mamba的优化选择性扫描实现快2-8倍,同时允许更大的循环状态大小(Mamba的8倍或更高)。与FlashAttention-2相比,SSD在序列长度2K时性能超越对方,在序列长度16K时快6倍。
3. 架构设计(Mamba-2):利用SSD框架,将注意力机制的设计思想(如多头机制)引入SSM,并对Mamba架构进行改进,设计了Mamba-2架构(第7节)。
* 引入多头机制:为SSM引入了类似多头注意力(MHA)的“头”结构。分析表明Mamba架构是一种多输入SSM(MIS),类似于多值注意力(MVA)。
* 支持张量并行:通过引入分组值注意力(GVA)头结构,并将所有数据依赖的投影移至块的开头并行进行,使Mamba-2能够实现张量并行。
* 性能优势:在Chinchilla缩放定律下,Mamba-2在困惑度和训练时间上均优于Mamba和Transformer++。在Pile数据集上训练的Mamba-2模型在标准下游评测中匹配或超越了Mamba和开源Transformer。例如,27亿参数的Mamba-2(训练300B tokens)优于28亿参数的Mamba、Pythia-2.8B乃至Pythia-6.9B。
4. 系统优化:将为Transformer开发的系统优化技术(如张量并行、序列并行、变长序列处理)应用于SSM和Mamba-2(第8节),提升了大规模训练和微调的效率。
* 张量并行(TP):Mamba-2设计对TP友好,将每个块的同步点减少了一半。
* 序列并行:通过在设备间传递循环状态,实现了对SSM的序列并行训练。
* 变长序列处理:Mamba-2能高效处理变长序列,无需填充令牌。
A3 背景知识与概述
2.1 结构化状态空间模型
S4模型简介。结构化状态空间序列模型(S4)是一类新兴的序列模型,与RNN、CNN和经典状态空间模型相关。其灵感来源于一个特定的连续系统(1),该系统通过一个隐式潜状态$h \in R^{(T,N)}$将一维序列$u \in R^T$映射到$y \in R^T$。
结构化SSM的通用离散形式如公式(1)所示。
其中,$A_t \in R^{(N,N)}$, $B_t \in R^{(N,1)}$, $C_t \in R^{(N,1)}$。结构化SSM之所以得名,是因为控制时间动态的$A_t$矩阵必须是结构化的,才能高效地计算这种序列到序列的转换。最初引入的结构是对角加低秩(DPLR)【42,Efficiently Modeling Long Sequences with Structured State Spaces,2022,ICLR】和对角结构【43,On the Parameterization and Initialization of Diagonal State Space Models,2022,NeurIPS;46,Diagonal State Spaces are as Effective as Structured State Spaces,2022,NeurIPS;93,Simplified State Space Layers for Sequence Modeling,2023,ICLR】,后者至今仍是最流行的结构。
SSM的不同视角。本文中,状态空间模型(SSM)指代结构化SSM。这类模型有多种形式,与连续时间模型、循环模型和卷积模型等主流神经序列模型范式有深厚联系【44,Combining Recurrent, Convolutional, and Continuous-time Models with the Linear State Space Layer,2021,NeurIPS】。
* 连续时间模型:最初的结构化SSM源于对函数$u(t) \in R \mapsto y(t) \in R$的连续时间映射。在连续时间视角下,公式(1a)中的矩阵$(A_t, B_t)$不是直接学习的,而是由底层参数$(\mathring{A}, \mathring{B})$和一个参数化的步长$\Delta$生成。这些“连续参数”$(\Delta, \mathring{A}, \mathring{B})$通过固定的离散化规则$A = f_A(\Delta, \mathring{A})$和$B = f_B(\Delta, A)$转换为“离散参数”$(A_t, B_t)$。为简化表述,本文后续将省略这一参数化和离散化步骤。
* 循环模型:公式(1)和(2)呈现为一种在其输入$u$上是线性的循环形式。因此,结构化SSM可被视为一种循环神经网络(RNN),其线性特性赋予了它们额外的属性,并允许它们避免传统RNN的顺序计算。尽管有这种简化,SSM在序列转换方面仍然具有完全的表达能力(在通用近似的意义上)【54,Linear Dynamical Systems as a Core Computational Primitive,2020,NeurIPS;68,Resurrecting Recurrent Neural Networks for Long Sequences,2023,ICML;104,State-space Models with Layer-wise Nonlinearity are Universal Approximators with Exponential Decaying Memory,2023,arXiv】。
* 卷积模型:当SSM的动态随时间保持不变时(如公式(1)),该模型被称为线性时不变(LTI)模型。在这种情况下,它们等价于卷积。因此,SSM也可被视为一种CNN,但其卷积核是通过SSM参数$(A, B, C)$隐式参数化的,并且通常是全局而非局部的。反之,根据经典的信号处理理论,所有行为良好的卷积都可以表示为SSM。之前的LTI SSM通常在训练时使用卷积模式(高效并行),在自回归推理时切换到循环模式(1)。
* 选择性状态空间模型:Mamba引入了公式(2)的形式,其中参数$(A_t, B_t, C_t)$可以随时间变化,称为选择性SSM。与标准的LTI公式(1)相比,该模型可以在每个时间步选择性地关注或忽略输入。它在信息密集型数据(如语言)上表现远优于LTI SSM,尤其是在其状态大小N增加时。然而,它只能以循环模式计算,需要精心设计的硬件感知实现才能高效。即便如此,由于它不利用现代加速器(如GPU和TPU)专门优化的矩阵乘法单元,其效率仍低于CNN和Transformer等硬件友好模型。本文将展示选择性SSM与注意力之间的深层关系,并利用它显著提升SSM的训练速度,同时允许更大的状态大小N。
结构化SSM作为序列转换。
* 定义2.1 (序列转换):我们将序列转换定义为一个参数化的映射$y = f_\theta(x)$,其中$x, y \in R^{(T,P)}$,$\theta$是任意参数集合。T代表序列或时间轴;下标索引第一个维度,例如$x_t, y_t \in R^P$。序列转换是深度序列模型(如Transformer)的基石。公式(1)或(2)中的SSM是一个P=1的序列转换;通过在该维度上广播,可以推广到P>1(即视输入为P个独立序列并分别应用SSM)。可将P看作一个头维度。
* 定义2.2 (SSM算子):我们将SSM算子$SSM(A, B, C) = SSM(A_{0:T}, B_{0:T}, C_{0:T})$定义为由公式(2)定义的序列转换$x \in R^{(T,P)} \mapsto y \in R^{(T,P)}$。在SSM中,N维是一个称为状态大小或状态维度的自由参数,也称为状态扩展因子。
* 定义2.3 (矩阵转换):如果一个序列转换$y = f_\theta(x)$可以写成$y = Mx$的形式,其中M是一个依赖于参数$\theta$的矩阵,我们称之为一个矩阵转换。我们将序列转换与矩阵M等同起来。
2.2 注意力
注意力机制概述。注意力是一种为序列中每对位置分配分数的计算,允许每个元素“关注”其余部分。最常见的变体是softmax自注意力,定义为:
其中$Q, K, V \in R^{(T,P)}$。成对比较(通过物化$QK^T$)导致了注意力特有的二次方训练成本。
线性注意力。线性注意力【52,Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention,2020,ICML】是本文关注的重要变体。它通过将softmax折叠到一个核特征图中来去除softmax,并利用矩阵乘法的结合律重写$(QK^T) \cdot V = Q \cdot (K^T V)$。在重要的因果(自回归)注意力情况下,当因果掩码被整合为$(M \circ QK^T) \cdot V$(其中M是下三角全1矩阵),右侧可以展开为一个循环。最近的如RetNet【95,Retentive network: A successor to transformer for large language models,2023,arXiv】和GateLoop【53,GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling,2023,arXiv】的工作将此推广到更一般的M形式。本文的结构化掩码注意力将极大地推广这些思想。
2.3 结构化矩阵
结构化矩阵定义。一般矩阵$M \in R^{(T,T)}$需要$T^2$个参数表示,并且矩阵向量乘法等基本操作需要$O(T^2)$时间。结构化矩阵是指那些 (i) 可以通过压缩表示以亚二次方(理想情况下是线性)参数来表示,并且 (ii) 通过直接对这种压缩表示进行操作而拥有快速算法(最重要的是矩阵乘法)的矩阵。除了经典的稀疏和低秩矩阵,还存在许多其他类型,如Toeplitz、Cauchy、Vandermonde和蝶形矩阵,它们都已在机器学习中用于高效模型。本文将证明SSM等价于另一类之前未在深度学习中使用过的结构化矩阵,并利用这种联系推导出高效的方法和算法。
2.4 概述:结构化状态空间对偶性
SSD层概述。状态空间对偶(SSD)层可以被定义为选择性SSM (2) 的一个特例。SSM的标准计算方式是作为循环(或并行扫描),其序列长度复杂度是线性的。与Mamba中使用的版本相比,SSD有两个微小差异:
* $A_t$的结构从对角简化为标量乘以单位矩阵。此时每个$A_t$可以等同于一个标量。
* 使用更大的头维度P(通常P={64, 128}),类似于现代Transformer的惯例。
与原始选择性SSM相比,这些变化可以看作是略微降低了表达能力,以换取显著的训练效率提升,特别是允许使用现代加速器上的矩阵乘法单元。
SSD的对偶(二次)形式。SSD的对偶形式是一种与注意力密切相关的二次计算,定义为:
其中$a_t$是输入依赖的标量,范围在[0, 1]内。与标准softmax注意力相比,主要有两个不同:
* 去除了softmax。
* 注意力矩阵与一个额外的掩码矩阵L进行了逐元素相乘。
这两个变化可以看作是解决了vanilla注意力中的问题,例如softmax导致的“注意力沉洞”现象。更重要的是,掩码矩阵L可以看作是用一种不同的数据依赖的位置掩码取代了Transformer中启发式的位置嵌入,以控制信息在时间上的传递。此形式是本文在第4节中定义的结构化掩码注意力(SMA)的一个实例。
矩阵形式与SSD算法。SSD的各种形式通过统一的矩阵表示连接起来,即SSM具有矩阵转换形式$y = M_\theta x$,其中矩阵$M_\theta \in R^{(T,T)}$依赖于$\theta = (A, B, C)$。SSD的对偶形式等价于与矩阵M的朴素(二次时间)乘法,而循环形式是利用M中结构的一种特定高效(线性时间)算法。本文提出的硬件高效SSD算法(第6节)是一种新的结构化矩阵乘法方法,涉及M的块分解,从而在纯线性和纯二次形式之间取得了更好的效率权衡。与通用选择性SSM相比,它相对简单且易于实现。
2.5 符号表示
本文偏好使用可映射到代码的精确符号。
* 矩阵与向量:小写字母表示向量,大写字母表示矩阵。 · 表示标量或矩阵乘法,◦ 表示Hadamard(逐元素)乘法。
* 索引:使用Python风格索引,例如$i:j$表示范围。$[T]$等价于$0:T$。速记法$\times_{i:j} A_i$表示$A_i \times \cdots \times A_{j-1}$的乘积。
* 维度:使用打字机字体的大写字母(如D, N, T)表示维度和张量形状。
* 张量收缩:将大量使用张量收缩或einsum表示法。例如,矩阵-矩阵乘法表示为 contract(MN, NK → MK)(X, Y)。
附录A包含了符号术语表。
A2 方法细节
3 状态空间模型是结构化矩阵
本节探讨了将状态空间模型作为序列转换的不同视角,并概述了这类映射的属性和算法。主要成果是证明状态空间模型与一类称为半可分矩阵(semiseparable matrices)的结构化矩阵之间的等价性,这带来了新的效率结论(定理3.5和3.7)。
3.1 状态空间模型的矩阵转换形式
从循环到矩阵。回顾SSM的定义是通过(2)式定义的参数化映射。我们的理论框架始于将此转换写成一个矩阵乘法,将向量$x \in R^T$映射到$y \in R^T$。通过对$h_0 = B_0 x_0$进行归纳,可以推导出:
将上式乘以$C_t$得到$y_t$,并对$t \in [T]$进行向量化,我们得到SSM的矩阵转换形式:
3.2 半可分矩阵
半可分矩阵的定义。公式(3)中的M是一种称为半可分矩阵的特定表示。
* 定义3.1:一个(下三角)矩阵M若其下三角部分(即对角线或对角线以下)的任何子矩阵的秩最多为N,则称其为N-半可分矩阵。我们称N为半可分矩阵的阶或秩。
半可分矩阵有多种结构化表示,包括层次半可分(HSS)、序列半可分(SSS)和Bruhat形式。我们将主要使用SSS形式。
3.2.1 序列半可分(SSS)表示
SSS表示的定义。
* 定义3.2:一个下三角矩阵$M \in R^{(T,T)}$如果可以写成以下形式,则称其具有N-序列半可分(SSS)表示:
其中向量$B_0, \dots, B_{T-1}, C_0, \dots, C_{T-1} \in R^N$,矩阵$A_0, \dots, A_{T-1} \in R^{(N,N)}$。我们定义算子SSS使得$M = SSS(B{0:T}, C_{0:T}, A_{0:T})$。半可分矩阵的一个基本结果是它们与具有SSS表示的矩阵完全等价。
SSS矩阵即N-半可分矩阵。一个方向的证明是建设性的。
* 引理3.3:一个具有表示(4)的N-SSS矩阵M是N-半可分的。
* 证明:考虑任何非对角块$M_{i':i,j:j'}$,其中$i' > i \ge j > j'$。它有一个明确的秩-N分解:
这个公式(5)将在推导我们序列模型的快速算法中被广泛使用。另一个方向在半可分矩阵的文献中已经确立。
* 命题3.4:每个N-半可分矩阵都有一个N-SSS表示。
尽管定义3.2涉及$O(N^2 T)$个参数,但实际上可以压缩到$O(NT)$个参数。因此,本文将结构化矩阵类(定义3.1)与其特定表示(定义3.2)等同看待,并使用N-SS来指代SSS形式的N-半可分矩阵。
3.2.2 1-半可分矩阵:标量SSM循环
1-SS矩阵的特例。我们特别关注1-SS矩阵的情况。在这种情况下,$C_j$和$B_i$是标量,可以从SSS表示(4)中分解出来:
忽略对角矩阵项后,1-SS矩阵的基本表示为$M_{ji} = a_{j:i}$ 或:
1-SS矩阵的重要性在于其与标量循环的最小形式等价——即状态维度N=1且没有$(B, C)$投影的退化SSM。矩阵乘法$y=Mx$可以通过以下循环计算:
因此,我们也称1-SS矩阵乘法为标量SSM循环或cumprodsum(累积乘积和)算子。它是我们主要算法的基础构建块。许多序列模型算法都可以归结为结构化矩阵乘法算法,1-SS矩阵是这一联系的例证。
3.3 状态空间模型是半可分矩阵
SSM与半可分矩阵的等价性。公式(3)直接建立了状态空间模型和序列半可分表示之间的联系,后者又等价于一般的半可分矩阵(引理3.3和命题3.4)。
* 定理3.5:状态空间模型转换$y = SSM(A, B, C)(x)$(状态大小为N)与通过一个序列半可分表示的N-SS矩阵的矩阵乘法$y = SSS(C, B, A) \cdot x$是相同的。
换句话说,序列转换算子SSM(定义2.2)与矩阵构建算子SSS(定义3.2)是重合的。由于缩写上的巧合,SSM(状态空间模型/半可分矩阵)、SSS(结构化状态空间/序列半可分)或SS(状态空间/半可分)可以互换使用。
3.4 通过结构化矩阵算法计算状态空间模型
利用结构化矩阵算法进行高效计算。定理3.5的重要性在于,它允许我们将高效计算SSM的问题归结为高效的结构化矩阵乘法算法。半可分矩阵是经典的结构化矩阵:(i) 它们有压缩表示,如SSS形式,只有$O(T)$而不是$O(T^2)$个参数;(ii) 它们有直接操作于压缩表示的快速算法。
* 命题3.6:一个大小为T的N-SS矩阵可以用$O(NT)$个参数表示,并且矩阵-向量乘法的时间和空间复杂度为$O(NT)$。
3.4.1 线性(循环)模式
对角结构SSM的计算。命题3.6对于对角结构化SSM(S4D【43,On the Parameterization and Initialization of Diagonal State Space Models,2022,NeurIPS】)很容易看出,只需展开状态空间模型公式(2)的循环。正式的张量收缩算法如下:
这里,$L \in R^{(T,T)}$被定义为$1SS(A_n)$。该算法对应于(2)的三个步骤:(i) 通过输入矩阵$B$扩展输入$x$,(ii) 展开独立的标量SSM循环,(iii) 通过输出矩阵$C$收缩隐藏状态$h$。我们注意到(8)是Mamba(S6)模型的一个特例,但朴素实现很慢,因为需要物化大小为(T, P, N)的张量$x'$和$h$;Gu和Dao【40,Mamba: Linear-Time Sequence Modeling with Selective State Spaces,2023,arXiv】引入了硬件感知的实现来避免物化这些张量。
一般SSM的计算。令人惊讶的是,定理3.5和命题3.6立即意味着所有SSM都具有与算法(8)相同的渐近效率。
* 定理3.7:任何状态大小为N、序列长度为T的状态空间模型(定义2.2)都可以在$O(TN)$时间内计算(不考虑可能的预处理)。
这个结果对于结构化SSM文献来说是新的。特别是,对于密集的非结构化$A_t$矩阵,总表示似乎是$O(TN^2)$。因此,定理3.7陈述了一个非平凡的结果:即使是一个非结构化SSM,在经过预处理步骤后,也可以以最优效率计算,其上界与$y$和$x$的大小给出的下界$O(TN)$相匹配。在实践中,高效计算的SSM仍然需要在$A_t$上施加额外的结构,特别是为了避免昂贵的预处理步骤。这些结构是过去关于结构化SSM工作(如S4(D)和Mamba)以及我们新算法的焦点。
3.4.2 二次(朴素)模式
矩阵视角的朴素计算。我们的新矩阵视角揭示了另一种计算SSM的方法。对SSM矩阵表示(3)的朴素计算是直接物化序列转换矩阵$M = SSS(C, B, A)$。这是一个(T, T)矩阵,因此这个朴素算法的复杂度将是序列长度的二次方。然而,当序列长度T较短时,由于常数因子和计算模式的硬件友好性(例如,利用矩阵-矩阵乘法),这实际上可能比线性算法更高效。对于特定情况的结构化SSM,这看起来非常类似于二次注意力的计算(第5节)。
3.4.3 总结
SSM的矩阵变换视角。许多序列模型,特别是Transformer,被明确定义为矩阵序列转换。而RNN和SSM以前没有这样描述。通过提供状态空间模型的显式矩阵转换形式,我们揭示了理解和使用它们的新方法。从计算角度看,计算状态空间模型前向传播的任何方法都可以被视为对半可分矩阵的矩阵乘法算法。半可分矩阵的视角提供了状态空间对偶性(SSD)的一个透镜,其中对偶模式分别指线性时间的半可分矩阵乘法算法和二次时间的朴素矩阵乘法。
4 结构化掩码注意力:用结构化矩阵推广线性注意力
本节从第一性原理重新审视线性注意力框架。主要成果是通过张量收缩的简单证明(命题4.1),以及我们对结构化掩码注意力的推广抽象(定义4.2)。本节可以独立于第3节阅读。
4.1 注意力框架
4.1.1 注意力
通用注意力形式。单头注意力的基本形式是三个向量序列$(Q, K, V)$到$O$的映射。
$Q = \text{input (T, N)}$
$K = \text{input (S, N)}$
$V = \text{input (S, P)}$
$S = QK^T \text{ (T, S)}$
$A = f(S) \text{ (T, S)}$
$O = AV \text{ (T, P)}$
其中S和T代表源和目标序列长度,N代表特征维度,P代表头维度。最常见的softmax注意力使用$f=softmax$来归一化S矩阵的行。
4.1.2 自注意力
自注意力的特殊情况。我们的处理主要针对最重要的自注意力情况,其中(i) 源和目标序列相同(S = T),(ii) 特征和头维度相同(N = P),(iii) Q, K, V由对同一输入向量的线性投影生成。我们的表述抽象了这些选择,从Q, K, V矩阵开始。
4.1.3 核注意力
核注意力的分解。softmax应用于格拉姆矩阵S的步骤可以分解为两部分:1. 对S矩阵进行指数化。2. 在S轴上对S矩阵进行归一化。我们可以暂时忽略归一化项。指数化项可以看作是一个核转换:存在一个(无限维)特征映射$\phi$,使得$\exp(QK^T) = \phi(Q)\phi(K)^T$。通过将特征映射抽象到Q和K本身的定义中,我们可以忽略softmax转换,并假设Q, K是由核特征映射任意生成的,且可能N≠P。许多核注意力的实例化已被提出,包括原始的线性注意力【52,Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention,2020,ICML】、随机特征注意力(RFA)【73,Random Feature Attention,2021,ICLR】、Performer【19,Rethinking Attention with Performers,2021,ICLR】和cosFormer【83,CosFormer: Rethinking Softmax in Attention,2022,ICLR】等。
4.1.4 掩码(核)注意力
掩码注意力的定义。设M是一个形状为(T, S)的掩码。在自回归自注意力中,M通常是一个下三角全1矩阵,代表因果掩码。掩码注意力通常用矩阵表示为:
更精确地,用形状注释并分解为计算序列:
我们注意到这个公式可以写成一个单一的收缩:
而(11)中的算法可以被重构为按特定顺序计算(12)的成对收缩:
4.2 线性注意力
线性注意力的动机。线性注意力的动机通常是通过改变核心注意力计算$(QK^T)V = Q(K^T V)$中矩阵结合的顺序。当加入掩码时,推导不那么直接。线性注意力方法声称以下公式等价于(10):
* 命题4.1:自回归核注意力,即带有因果掩码的掩码核注意力,可以通过一个每步耗时恒定的循环在$O(T)$时间内计算。
4.2.1 线性注意力的张量收缩证明
简洁证明。我们提出了一个简单而严谨的线性注意力推导,它也立即揭示了如何推广它。主要思想是按另一种顺序执行收缩(12):
直观地,我们将这个收缩顺序解释如下:第一步(15a)执行特征扩展;第三步(15c)收缩扩展的特征维度。如果将V视为主输入,那么K和Q分别执行扩展和收缩。第二步最关键,解释了线性注意力的“线性”部分。当掩码M是标准的因果注意力掩码(下三角全1)时,与M的矩阵-向量乘法等同于一个特征维度的累积和:
4.3 结构化掩码注意力
从线性注意力到SMA。从掩码注意力的张量收缩视角(15)可以看出,原始线性注意力的关键在于与因果掩码的矩阵-向量乘法等价于累积和算子。然而,注意力掩码不必是全1。只要M是一个结构化矩阵(即具有快速矩阵乘法的矩阵),线性注意力就可以快速计算。我们可以使用任何具有亚二次方(理想情况下是线性)矩阵-向量乘法的掩码矩阵M,这将通过加速瓶颈方程(15b)而具有与标准线性注意力相同的复杂度。
* 定义4.2(结构化掩码注意力 SMA):SMA被定义为在查询/键/值Q, K, V以及任何结构化矩阵M(即具有亚二次方矩阵乘法)上的函数,通过四路张量收缩实现:
* SMA二次模式算法是由(13)定义的成对收缩序列,对应于标准的(掩码)注意力计算。
* SMA线性模式算法是由(15)定义的成对收缩序列,其中步骤(15b)通过亚二次方的结构化矩阵乘法进行优化。
4.3.1 总结:掩码注意力的对偶形式
函数与算法的分离。我们将掩码注意力视为一个特定的函数(12),而标准的二次注意力计算(13)和线性注意力(15)是计算该函数的两种不同算法。这两种算法对应于两种不同的张量收缩顺序,导致了二次与线性的计算复杂度差异。就像状态空间模型可以通过多种方式计算,具有对偶的二次与线性形式一样,线性注意力也存在由两种收缩顺序产生的类似对偶性。这些实际上是同一底层对偶性的不同视角,我们将在第5节中明确这一点。
5 状态空间对偶性
本节将前两节定义的结构化状态空间模型和结构化注意力联系起来。我们的主要结果是证明了结构化状态空间模型的一个特例与结构化注意力模型的一个特例重合,并且线性时间SSM算法和二次时间核注意力算法是彼此的对偶形式。
5.1 标量-单位结构的状态空间模型
特殊SSM的二次形式。在第3节中,我们展示了SSM等价于半可分矩阵转换。考虑$A_t$是标量的特殊情况,即$A_t = a_t I$。那么SSS表示$M_{ij} = C_i^T (\prod_{k=j+1:i} A_k) B_j$可以重排为:
这可以向量化为:
其中$C, B \in R^{(T,N)}$。使用这个公式,完整的输出$y=Mx$被精确地计算为:
其中S=T。这与原始的掩码核注意力定义(13)完全相同!因此,通过物化半可分矩阵M并执行二次矩阵-向量乘法来朴素地计算标量结构化SSM,与二次掩码核注意力完全相同。
5.2 1-半可分结构化掩码注意力
特殊SMA的线性形式。当SMA中的结构化掩码M是因果掩码时,它就是标准的线性注意力。因果掩码是$M=SS(1_T)$,即由$a_t=1$生成的1-SS掩码。这启发我们将M推广到1-半可分掩码类,即1-半可分结构化掩码注意力(1-SS SMA)。在这种情况下,线性注意力循环中的cumsum被更通用的循环——标量SSM扫描(即1-半可分矩阵乘法)所取代。
计算1-SS SMA的线性形式是对角状态空间模型的一个特例。SMA的线性形式是算法(15),其瓶颈步骤(15b)可以看作是与1-SS掩码的矩阵乘法。在第3节中,我们为对角SSM写出了计算公式(8),其瓶颈步骤(8b)是一个标量SSM循环,等价于1-SS乘法。唯一的区别是(8b)在$x'$中有一个额外的N维,因为矩阵A是大小为N的对角矩阵。如果A的所有对角项都相同,这个N维就会消失,从而得到推论5.1。
- 推论5.1:1-SS SMA(使用1-半可分结构化矩阵M的掩码注意力)(15)是对角SSM (8)的一个特例,其中对角矩阵是单位矩阵的标量倍。
高效自回归注意力的特征。我们还可以证明一个逆向结果,刻画了哪些SMA实例具有高效的自回归能力。
* 定理5.2:对于任何具有有界阶自回归过程的结构化掩码注意力实例化(定义4.2),其结构化掩码M必须是一个半可分矩阵。
换句话说,高效的自回归注意力就是通用的半可分SMA。定理5.2在附录C.2中证明。总之,1-半可分结构化注意力是SMA最重要的特例,因为它自然推广了线性注意力,是高效自回归注意力的最简单形式,并且是对角状态空间模型的一个特例。
5.3 结构化状态空间对偶性 (SSD)
总结。
* 结构化状态空间模型(第3节):通常通过线性时间循环定义,但通过展开其线性序列到序列转换的矩阵公式,可以导出一个二次形式。
* 注意力变体(第4节):通过二次时间成对交互定义,但通过将其视为四路张量收缩并以不同顺序约简,可以导出一个线性形式。
* 对偶性:两者各有一个自然的特例——具有标量-单位结构$A_t$矩阵的SSM,和具有1-半可分结构$M$掩码的SMA——它们互为对偶,具有完全相同的线性和二次形式。
6 SSD模型的高效硬件算法
本节展示了如何从结构化矩阵乘法的各种算法中推导出计算SSD模型的各种高效算法。我们提出的算法结合了线性(循环)模式和二次(注意力)模式,在计算上与SSM一样高效(序列长度线性扩展),在硬件上与注意力一样友好(主要使用矩阵乘法)。
- 定理6.1:考虑一个状态扩展因子为N、头维度P=N的SSD模型。存在一种算法,在任何输入$x \in R^{(T,P)}$上计算该模型,其训练FLOPs仅为$O(TN^2)$,推理FLOPs为$O(TN)$,推理内存为$O(N^2)$,并且其工作主要由矩阵乘法主导。
公式
这些界限都是紧的。定理6.1背后的主要思想是再次将计算状态空间模型的问题视为半可分矩阵乘法,但以一种新的方式利用其结构。我们对矩阵进行块分解,而不是以循环或注意力模式计算整个矩阵。对角块可以使用对偶注意力模式计算,这可以通过矩阵乘法高效完成;而非对角块可以利用半可分矩阵的秩结构进行分解,并简化为更小的循环。代码清单1提供了一个自包含的SSD算法实现。
图5:(SSD算法。)通过使用状态空间模型的矩阵转换视角将其写成半可分矩阵(第3节),我们通过块分解矩阵乘法算法开发了一种更硬件高效的SSD模型计算方法。该矩阵乘法也可以解释为状态空间模型,其中块代表对输入和输出序列的分块。对角块表示块内计算,非对角块表示块间计算,通过SSM的隐藏状态进行分解。
块分解。我们将矩阵M划分为一个$T \times T$大小的子矩阵网格,每个子矩阵大小为$Q \times Q$。非对角块根据半可分矩阵的定义是低秩的。
一个“块”的输出$y_{kQ:(k+1)Q}$可以分为两部分:块内输入$x_{kQ:(k+1)Q}$的影响,和块前输入$x_{0:kQ}$的影响。
6.1 对角块
块内计算。对角块是自相似的、规模更小的问题。第k个块代表计算范围$t=kQ:(k+1)Q$的$SSM(A_t, B_t, C_t)(x_t)$。关键是这个块可以用任何期望的方法计算。特别是,对于小的块长度Q,使用对偶二次SMA形式计算更高效。此外,这些块可以并行计算。这可以解释为:假设块的初始状态为0,计算每个块的输出。
6.2 低秩块
块间计算。低秩分解由3个项组成,计算也相应地分为三部分。我们将低秩因子分别称为右因子(B-块因子)、中心因子(A-块因子)和左因子(C-块因子)。
* 右因子:此步骤计算与右B-块因子的乘法。对于每个块,这是一个(N, Q)乘以(Q, P)的矩阵乘法,结果是每个块的一个(N, P)张量,与扩展的隐藏状态h具有相同的维度。这可以解释为:假设块的初始状态为0,计算每个块的最终状态。
* 中心因子:此步骤计算中心A-块因子的影响。上一步中每个块的最终状态总形状为(T/Q, N, P),现在与由$\times_{k=2Q-1:Q-1}A_k, \dots$生成的1-SS矩阵相乘。此步骤可以通过任何计算1-SS乘法(即标量SSM扫描或cumprodsum算子)的算法来完成。这可以解释为:考虑所有先前输入,计算每个块的实际最终状态,即计算真正的隐藏状态$h_{kQ}$。
* 左因子:此步骤计算与左C-块因子的乘法。对于每个块,这可以表示为矩阵乘法contract(QN, NP → QP)。这可以解释为:考虑正确的初始状态$h_{kQ-1}$,并假设块内输入为0,计算每个块的输出。
6.3 计算成本
成本分析。我们用BMM(B, M, N, K)表示批处理矩阵乘法。若设N=P=Q,则所有BMM项都变为BMM(T/N, N, N, N)。计算特性为:
* 总FLOPs为$O(TN^2)$。
* 总内存为$O(TN)$。
* 工作主要由(N, N)形状矩阵的矩阵乘法构成。
内存消耗是紧的。FLOPs计数反映了自回归状态大小带来的额外N倍成本。除了矩阵乘法,还有一个在$NP=N^2$个特征上,序列长度为T/Q的标量SSM扫描,其成本$O(T/QN^2)$可以忽略不计。与纯SSM相比,SSD利用矩阵乘法单元,效率更高。与二次注意力相比,SSD的状态大小可控(N),而注意力状态大小与序列长度T成比例,因此SSD在长序列上更快。
代码清单1 状态空间对偶(SSD)模型的完整PyTorch示例。
def segsum(x):
"""朴素的段求和计算。
exp(segsum(A)) 产生一个1-SS矩阵,等价于一个标量SSM。"""
T = x.size(-1)
x_cumsum = torch.cumsum(x, dim=-1)
x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
def ssd(X, A, B, C, block_len=64, initial_states=None):
"""
参数:
X: (batch, length, n_heads, d_head)
A: (batch, length, n_heads)
B: (batch, length, n_heads, d_state)
C: (batch, length, n_heads, d_state)
返回:
Y: (batch, length, n_heads, d_head)
"""
assert X.dtype == A.dtype == B.dtype == C.dtype
assert X.shape[1] % block_len == 0
# 重排为块/分块
X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
A = rearrange(A, "b c l h -> b h c l")
A_cumsum = torch.cumsum(A, dim=-1)
# 1. 计算每个块内(对角块)的输出
L = torch.exp(segsum(A))
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
# 2. 计算每个块内的状态
# (非对角块低秩分解的右项;B项)
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
# 3. 计算块间SSM循环;在块边界产生正确的SSM状态
# (非对角块分解的中间项;A项)
if initial_states is None:
initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]
# 4. 计算每个块的状态 -> 输出转换
# (非对角块低秩分解的左项;C项)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
# 将块内和块间项的输出相加(对角块和非对角块)
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
return Y, final_state
7 Mamba-2 架构
本节我们讨论一些使用Transformer思想来理解和修改SSD层的例子,最终形成了Mamba-2架构。
7.1 块设计
Mamba-2块的修改。
* 并行参数投影:在Mamba-1中,SSM参数A, B, C是SSM输入x的函数。在Mamba-2中,我们将SSD层视为从x, A, B, C到y的映射。因此,在块的开始处通过单个投影并行生成x, A, B, C更有意义,这类似于标准注意力架构中并行创建Q, K, V投影。这种并行投影减少了参数,并且通过使用标准的Megatron分片模式,更有利于大型模型的张量并行。
* 额外归一化:为了解决大型模型中可能出现的不稳定性,我们在块的最终输出投影之前添加了一个额外的归一化层(如LayerNorm)。这种用法与NormFormer架构【91,NormFormer: Improved Transformer Pretraining with Extra Normalization,2021,arXiv】最直接相关。这也类似于TransNormerLLM【82,TransNormerLLM: A Faster and Better Large Language Model with Improved TransNormer,2023,arXiv】和RetNet【95,Retentive network: A successor to transformer for large language models,2023,arXiv】在不稳定的线性注意力层之后添加LayerNorm或GroupNorm。
7.2 序列转换的多头模式
多头模式定义。多头序列转换由H个独立的头组成,总模型维度为D = d_model。
* 多头SSM (MHS) / 多头注意力 (MHA) 模式:经典的MHA模式假设头维度P能整除模型维度D,头的数量为H=D/P。通过创建每个参数的H个独立副本,创建H个核心序列转换的副本。
* 多收缩SSM (MCS) / 多查询注意力 (MQA) 模式:MQA【90,Fast Transformer Decoding: One Write-head is All You Need,2019,arXiv】是一种优化注意力的技术,通过在所有V头之间共享一个(K, Q)头来提高自回归推理速度。利用状态空间对偶性,我们可以定义等价的SSM版本,其中B和A(SSM中对应于注意力的K和Q)在H个头之间共享。
* 多输入SSM (MIS) / 多值注意力 (MVA) 模式:对于SSM,将x视为主输入更自然,因此A和B参数在输入通道间共享。我们定义了一种新的MVA或MIS模式,这可以应用于任何序列转换,如SSD。
Mamba架构的特征。
* 命题7.2:Mamba架构的S6层可以被视为具有:(1) 头维度P=1:每个通道都有独立的SSM动态$A_t$。(2) MIS或MVA头结构:B, C矩阵(对应注意力的K, Q)在输入x(对应注意力的V)的所有通道间共享。
实验发现,MVA模式如Mamba原始使用的一样表现最好。
分组头模式。MQA可以扩展到分组查询注意力(GQA)【1,GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints,2023,arXiv】。类似地,Mamba-2中使用的MIS头模式可以轻松扩展到分组输入SSM(GIS),或同义的分组值注意力(GVA)。
7.3 从线性注意力到SSD的其他扩展
核注意力对softmax注意力的近似。许多线性/核注意力变体将softmax看作由指数核和归一化组成。
* 指数核特征映射:在Mamba-2中,我们为B和C分支(对应注意力的K和Q)加入了一个灵活的核特征映射$f$。默认我们选择$f$为逐元素的Swish/SiLU函数。
* 引入归一化(分母)项:为了模拟softmax的归一化,我们需要计算$A1$。这可以通过将V增广一个全1列来实现。
实验表明,这些从线性注意力文献中借鉴的想法并没有显著提高性能。
8 SSM的系统优化
本节描述了针对SSM(特别是Mamba-2)的几种系统优化,用于大规模高效训练和推理。
8.1 张量并行
TP在Mamba中的挑战。张量并行(TP)【92,Megatron LM: Training Multi-Billion Parameter Language Models Using Model Parallelism,2019,arXiv】是分割模型层到多个加速器上的技术。在Mamba架构中,由于$\Delta, B, C$是$x_z$的函数,直接应用TP会在计算这些参数前需要一次额外的all-reduce,使得每个块需要两次all-reduce,通信开销翻倍。
TP在Mamba-2中的优化。Mamba-2的设计目标是每个块只有一个all-reduce,类似于Transformer中的注意力或MLP块。我们让$\Delta, B, C$的投影直接来自$x$而不是$x_z$,允许我们分割这些投影矩阵。这意味着不同GPU上有不同的$\Delta, B, C$组。此外,我们在每个块内使用GroupNorm,其组数可被TP度整除,这样TP组内的GPU就不需要在块内通信。这样,我们只需在块的末尾进行一次all-reduce,设计与TP对注意力和MLP层的设计类似。
8.2 序列并行
处理长序列。对于非常长的序列,我们需要沿序列长度维度分割输入和激活。
1. 残差和归一化操作的序列并行(SP):Korthikanti等人【55,Reducing activation recomputation in large transformer models,2023,Proceedings of Machine Learning and Systems 5】提出的技术,将TP中的all-reduce分解为reduce-scatter和all-gather。Mamba-2架构使用相同的残差和归一化结构,因此SP可以直接应用。
2. 令牌混合操作的序列并行(上下文并行 CP):对于注意力层,这需要复杂的负载均衡。对于SSM,我们可以简单地分割序列:每个工作节点接收一个初始状态,计算其输入的SSM,返回最终状态,并将其传递给下一个工作节点。通信带宽与工作节点数量呈线性关系。这种分解与SSD算法中的块分解完全相同。
8.3 变长序列
高效处理变长输入。在微调或推理中,模型可能需要处理不同长度的输入序列。朴素的方法是右填充到最大长度,但这效率低下。对于SSM和Mamba,我们可以通过将整个批次视为一个长序列来处理变长序列,并避免在单个序列之间传递状态。这等价于在一个序列末尾的令牌t处简单地设置$A_t=0$,以防止其将信息传递给属于不同序列的下一个令牌t+1。
A4 实验
实验环境
- 数据集:
- 语言模型预训练:The Pile [35],使用GPT-2或GPT-NeoX [12]分词器,训练数据量高达300B tokens。
- 合成任务:Multi-Query Associative Recall (MQAR) [6],词汇量8192,序列长度256到1024。
- 模型架构:
- 遵循GPT-3 [18]的尺寸规格,从125M到2.7B参数。
- Mamba-2默认配置:状态扩展因子N=64,头维度P=64,并行参数投影,额外归一化层,MVA头结构。
- 硬件配置:
- A100 80GB PCIe GPU用于速度基准测试。
- 软件配置:
- PyTorch实现。
- 速度基准测试中与Mamba的融合扫描实现和FlashAttention-2 [23]进行比较。
- 下游评估使用EleutherAI的LM evaluation harness [36]。
实验结果
9.1 合成任务:关联召回
MQAR任务。在挑战性的多查询关联召回(MQAR)任务中,Mamba-2表现出色,而Mamba-1则表现不佳。即使在状态大小受控(N=16)的情况下,Mamba-2也显著优于Mamba-1。该任务还验证了状态大小的重要性:将N从16增加到64和256,性能持续提升,因为更大的状态允许记忆更多的键值对。
9.2 语言模型
缩放定律。在Chinchilla缩放协议下,从约125M到1.3B参数的模型训练中,Mamba-2在困惑度、理论FLOPs和实际训练时间上都优于Mamba和强大的"Transformer++"基线。
下游评估。在零样本下游评估任务中,Mamba-2在每个模型尺寸上都优于Mamba,并且通常能与两倍模型尺寸的Pythia模型相媲美。例如,2.7B参数的Mamba-2在多个任务上优于2.8B的Pythia,甚至接近6.9B的Pythia。
混合模型。实验发现SSD层和注意力层是互补的。在350M模型(48层)中,混合约10%的注意力层可以获得最佳性能(表2)。在2.7B规模下,包含6个注意力层的Mamba-2-Attention模型在困惑度和下游任务上均优于纯Mamba-2和Transformer++模型(表3)。加入MLP层(Mamba-2-MLP)会降低模型质量,但可以加速训练和推理。
9.3 速度基准测试
SSD算法效率。由于利用了GPU上的张量核心,SSD算法比Mamba的融合扫描快2-8倍。随着状态扩展因子N的增加,Mamba扫描实现的速度呈线性下降,而SSD的减速幅度很小,可以处理更大的状态扩展因子。与FlashAttention-2相比,SSD在序列长度2k以上时更快。
9.4 架构消融实验
块设计。消融实验(表4)验证了Mamba-2块的设计选择。并行投影创建(x, A, B, C)不仅节省了参数,性能也略好于Mamba的顺序投影,并且更利于张量并行。额外的归一化层也略微提升了性能并有助于训练稳定性。
头结构。对多头结构的消融实验(表5)显示,多值(MVA)模式显著优于多查询(MQA)或多键(MKA)模式,尽管它们的总状态大小相同。标准多头模式的性能介于MVA和MQA/MKA之间。
注意力核近似。对多种线性注意力核近似方法(如cosFormer, RFA, Performer, (Re)Based)的消融实验(表6、7)表明,这些方法并未优于简单的逐点非线性激活函数(如Swish)。因此,Mamba-2的默认设置沿用了Mamba-1的Swish激活。
A7 补充细节:相关工作与讨论
10.1 状态空间模型
SSM的分类。结构化SSM可沿时间不变性、维度和$A_t$结构等轴进行分类。SSD可被描述为一个具有SISO维度和标量-单位结构的选择性SSM。与Mamba相比,SSD仅在$A_t$的对角形式上做了更严格的限制,用这种表达能力的牺牲换取了硬件效率和实现的简易性。一个重要的特点是状态扩展,即使用大的状态维度N,这对于信息密集领域(如语言)至关重要。SSD的主要优势之一是允许更大的状态扩展因子而不减慢模型速度。
10.2 结构化矩阵
矩阵混合器视角。SSD框架的核心成果之一是将SSM视为具有特定结构——半可分矩阵——的矩阵混合器。线性和二次对偶性则表现为结构化矩阵乘法与朴素矩阵乘法。这种视角不仅催生了高效的SSD算法,也为序列模型的研究开辟了新方向,例如设计有原则的非因果Mamba变体,或通过分析矩阵转换结构来弥合softmax注意力与亚二次模型之间的差距。
10.3 (线性) 注意力
与标准注意力的比较。与标准因果注意力相比,SSD有两个主要区别:
1. 无Softmax:SSD不使用softmax,而是通过线性注意力框架实现线性扩展。
2. 1-半可分掩码:SSD将logits矩阵乘以一个输入依赖的1-半可分掩码。这个掩码可以看作是取代了softmax,并提供了一种更具原则性的相对位置编码形式,其累积乘积控制着不同位置间的信息交互量。
结构化掩码注意力(SMA)框架。SSD是更通用的SMA框架的一个特例,该框架揭示了对偶性源于简单的四路张量收缩的不同收缩顺序。SMA比SSD更通用,其他形式的结构化掩码可能导致具有不同属性的新型高效注意力变体。
10.4 相关模型
近期涌现了大量与Mamba和Mamba-2相似的序列模型:
* RetNet [95] 和 TransNormerLLM [82] 使用时间不变的衰减项,可视为SSD中$A_t$为常数的特例。
* GateLoop [53]、Gated Linear Attention (GLA) [108] 和 HGRN(2) [85, 84] 都引入了数据依赖的门控或衰减因子,与SSD的思想相似。
* Griffin [29] 和 RecurrentGemma [14] 结合了输入依赖门控的RNN和局部注意力,表现出与强Transformer相当的竞争力。
* xLSTM [9]、RWKV-5/6 [72] 等模型也采纳了状态扩展和选择性(数据依赖门控)的思想。
A5 结论
本文提出了一个基于成熟的结构化矩阵理论的框架,弥合了SSM和注意力变体之间的概念鸿沟。该框架不仅解释了为何像Mamba这样的近期SSM在语言建模上能与Transformer媲美,还提供了新的思路,通过连接双方的算法和系统进展来改进SSM(甚至Transformer)。作为示范,该框架指导我们设计了一个位于SSM和结构化注意力交叉点的新架构——Mamba-2。
A6 附录
A 术语表
B 标量SSM扫描(1-SS乘法)的高效算法
问题定义。标量SSM扫描,或称cumprodsum(累积乘积和),定义为计算$h_t = a_t h_{t-1} + b_t x_t$的循环。这等价于与一个1-SS矩阵的矩阵-向量乘积。本节通过结构化矩阵分解的视角,阐述了计算此操作的多种算法,支持了“高效序列模型算法可被视为结构化矩阵乘法算法”的核心主题。
算法。
1. 顺序循环:按时间步顺序计算。
2. 并行关联扫描:将循环转化为一个2x2矩阵的累积乘积,由于矩阵乘法是关联的,因此可以使用并行扫描算法计算。
3. 通过结构化矩阵分解的高效算法:
* 扩张模式(Dilated Mode):将1-SS矩阵分解为一系列步长递增的矩阵乘积,类似于扩张卷积。这表明1-SS矩阵是蝶形矩阵的特例。
* 状态传递(分块)模式(State-Passing Mode):将矩阵分解为块,计算一个块的答案并将其最终状态传递给下一个块。这是标准循环模式的推广。
* (并行)块分解模式(Block Decomposition Mode):使用相同的矩阵分解,但并行地递归处理子问题,用计算换取并行性。总工作量为$O(T \log(T))$,深度为$O(\log(T))$。
* 关联扫描模式(Associative Scan Mode):通过一种不同的三阶段矩阵分解,可以推导出与关联扫描算法计算图相同的算法,该算法同时具有$O(T)$的工作量和$O(\log(T))$的深度。
C 理论细节
C.1 SSM的闭包性质
半可分矩阵的性质。半可分矩阵在加法、乘法和求逆等基本操作下是封闭的。例如,一个N-SS矩阵和一个M-SS矩阵的和是(N+M)-SS的,乘积是(N+M)-SS的,逆是(N+1)-SS的。这些性质意味着状态空间模型也具有相应的闭包性质,例如,两个并行SSM的和仍然是一个SSM,两个串行SSM的组合也是一个SSM,其总状态大小是相加的。
C.2 自回归掩码注意力是半可分结构化注意力
证明定理5.2。本节旨在刻画哪些SMA实例具有高效的自回归生成能力(即增量计算下一步输出的时间是常数)。
* 定义C.2 (自回归转换):一个阶为k的自回归转换$y=f(x)$,其每个输出$y_t$只依赖于当前输入和过去的k个输出:$y_t = f(x_t, y_{t-1}, ..., y_{t-k})$。
* 定理C.3 (定理5.2的正式表述):设M是一个阶为k的高效自回归转换矩阵。那么M是一个阶为k+1的状态空间模型。
* 证明思路:将自回归定义(23)重排并向量化,可以表示为$Ly=Dx$,其中L是一个(k+1)-带状下三角矩阵。因此$M=L^{-1}D$。由于(k+1)-带状矩阵是(k+1)-半可分的,根据闭包性质,其逆M至多是(k+2)-半可分的(更紧的界是k+1)。最后,根据定理3.5,这表明M是一个阶为k+1的状态空间模型。
结论:高效的自回归注意力就是半可分SMA。
💬 评论讨论
欢迎在这里分享您的想法和见解!