SIMPLIFIED STATE SPACE LAYERS FOR SEQUENCE MODELING

简化用于序列建模的状态空间层

作者/机构: Jimmy T.H. Smith*, 1, 2, Andrew Warrington*, 2, 3, Scott W. Linderman2, 3
*同等贡献。
1斯坦福大学计算与数学工程研究所。
2斯坦福大学吴蔡神经科学研究所。
3斯坦福大学统计系。
{jsmith14,awarring,scott.linderman}@stanford.edu.


A1 主要贡献

本文旨在解决机器学习中高效建模长序列的挑战性问题,即关键信息可能编码在相隔数千个时间步的观测数据之间。现有方法如高效 Transformer 在极长序列任务上表现不佳。Gu等人(2021a)提出的结构化状态空间序列(S4)层通过结合线性状态空间模型(SSMs)、HiPPO框架和深度学习,在长程序列建模任务上取得了显著进展。

本文在 S4 层的基础上,引入了一个新的状态空间层——S5层。其核心创新点和贡献如下:

  1. 结构简化:从SISO到MIMO

    • S4 层使用一个包含许多独立的单输入单输出(SISO)SSM 的“银行”(bank)。
    • S5 层将其简化为一个多输入多输出(MIMO)SSM。这简化了模型结构,将原本分离的特征通道在SSM内部进行了混合。
  2. 计算范式转变:从卷积到并行扫描

    • S4 层依赖于一种卷积和频域方法来高效处理序列,这需要一个非平凡的卷积核计算过程。
    • S5 层采用了一种高效且被广泛实现的并行扫描(parallel scan),使其能够纯粹在时域中以循环方式运行,同时保持与 S4 相当的计算效率。这种方法也自然地支持变长时间步和不规则采样数据,而这是 S4 的卷积实现难以处理的。

    图1:一个S5层对序列进行离线应用的计算组件。S5层在一个对角化的线性SSM上使用并行扫描来计算SSM输出y1:L ∈ RL×H。一个非线性激活函数被应用于SSM输出以产生层输出。附录B中包含了S4的类似图示。
    图1:一个S5层对序列进行离线应用的计算组件。S5层在一个对角化的线性SSM上使用并行扫描来计算SSM输出y1:L ∈ RL×H。一个非线性激活函数被应用于SSM输出以产生层输出。附录B中包含了S4的类似图示。

  3. 理论联系与初始化继承

    • 本文建立了 S4 和 S5 之间的数学关系,证明了在特定条件下,S5 的输出可以被解释为某个 S4 系统隐状态的投影。
    • 这一联系使得 S5 能够继承对 S4 成功至关重要的 HiPPO 初始化方案。虽然 S4 使用的 HiPPO-LegS 矩阵无法稳定对角化以用于 S5,但本文证明了对 HiPPO 矩阵的对角近似(即 HiPPO-N 矩阵)在 MIMO 设置下是合理的,这与最近的 DSS 和 S4D 层的发现一致。
  4. 卓越的性能

    • S5 层在多个长程序列建模任务上取得了最先进的性能
    • 在长程竞技场(LRA)基准测试中,S5 平均得分达到 87.4%
    • 在最具挑战性的 Path-X 任务上,准确率高达 98.5%

综上所述,S5 层是一个易于实现、在序列长度上具有线性复杂度、能高效处理时变 SSM 和不规则采样观测,并在多个基准测试中取得最先进性能的新型状态空间层。


A3 背景知识

2.1 线性状态空间模型

连续时间线性SSM。作为 S4 和 S5 层的核心组件,一个连续时间线性 SSM 由以下微分方程定义,给定输入信号 $u(t) \in R^U$,潜状态 $x(t) \in R^P$ 和输出信号 $y(t) \in R^M$:

公式1
公式1

该模型由状态矩阵 $A \in R^{P \times P}$、输入矩阵 $B \in R^{P \times U}$、输出矩阵 $C \in R^{M \times P}$ 和直通矩阵 $D \in R^{M \times U}$ 参数化。

离散化。对于一个恒定的步长 $\Delta$,可以使用欧拉、双线性和零阶保持(ZOH)等方法对 SSM 进行离散化,从而定义线性递推关系:

公式2
公式2

其中,离散时间参数都是由离散化方法指定的连续时间参数的函数。更多关于离散化方法的信息可参见【28,A first course in the numerical analysis of differential equations,2009】。

2.2 使用扫描并行化线性状态空间模型

并行扫描。我们使用并行扫描来高效计算离散化线性 SSM 的状态。给定一个二元结合算子 •(即 $(a • b) • c = a • (b • c)$)和一个包含 L 个元素的序列 $[a_1, a_2, ..., a_L]$,扫描操作(有时称为全前缀和)返回序列:

公式3
公式3

线性递推的并行化。计算一个长度为 L 的离散 SSM 的线性递推 $x_k = \bar{A}x_{k-1} + \bar{B}u_k$ 是扫描操作的一个具体例子。根据【6,Prefix sums and their applications,1990】中第1.4节的讨论,假设有 L 个处理器,并行化离散 SSM 中潜状态转移的线性递推可以在 $O(T_{\text{MM}} \log L)$ 的并行时间内完成,其中 $T_{\text{MM}}$ 代表矩阵-矩阵乘法的成本。对于一个通用矩阵 $A \in R^{P \times P}$,$T_{\text{MM}}$ 为 $O(P^3)$,这在深度学习设置中可能成本过高。然而,如果 $A$ 是一个对角矩阵,并行时间将变为 $O(P \log L)$,且只需要 $O(PL)$ 的空间。最后,我们注意到高效的并行扫描是以工作高效的方式实现的,因此使用对角矩阵的并行扫描的总计算成本是 $O(PL)$ 次操作。关于并行扫描的更多信息见附录H。

2.3 S4: 结构化状态空间序列层

S4层结构。S4层(【20,Efficiently modeling long sequences with structured state spaces,2021a,ICLR】)定义了一个非线性的序列到序列转换,将输入序列 $u_{1:L} \in R^{L \times H}$ 映射到输出序列 $u'_{1:L} \in R^{L \times H}$。一个S4层包含一个由H个独立的单输入单输出(SISO)SSM组成的“银行”,每个SSM的状态维度为N。每个S4 SSM被应用于输入序列的一个维度,这导致从每个输入通道到每个预激活通道都有一个独立的线性变换。然后,一个非线性激活函数被应用于这些预激活值。最后,一个逐位置的线性混合层被用来组合这些独立的特征,并产生输出序列 $u'_{1:L}$。附录中的图4a展示了S4层作为一个独立SSM银行的视图。图2a则展示了S4的另一种视图,即一个具有大小为HN的状态和块对角状态、输入和输出矩阵的大型SSM。

HiPPO初始化与DPLR参数化。每个S4 SSM都利用了HiPPO框架(【18,HiPPO: Recurrent memory with optimal polynomial projections,2020a,NeurIPS】)进行在线函数逼近,通过使用一个HiPPO矩阵(最常用的是HiPPO-LegS矩阵)来初始化状态矩阵。经验证明这能带来强大的性能(【21,Combining recurrent, convolutional, and continuous-time models with linear state space layers,2021b,NeurIPS】;【20,Efficiently modeling long sequences with structured state spaces,2021a,ICLR】),并且可以被证明是关于一个无限长的指数衰减度量的长程依赖的近似(【23,How to train your HIPPO: State space models with generalized orthogonal basis projections,2023,ICLR】)。虽然HiPPO-LegS矩阵不能被稳定地对角化(【20,Efficiently modeling long sequences with structured state spaces,2021a,ICLR】),但它可以表示为一个正规加低秩(NPLR)矩阵。其正规部分,被称为HiPPO-N并表示为 $A_{\text{NormalLegS}}$,可以被对角化。因此,HiPPO-LegS可以通过共轭变换成对角加低秩(DPLR)形式,S4利用这种形式来推导出一个高效的卷积核形式。这推动了S4的DPLR参数化。

双重实现模式:循环与卷积。高效地应用S4层需要两种独立的实现方式,具体取决于上下文:一种是循环模式,一种是卷积模式。对于在线生成,SSM像其他RNN一样循环迭代。然而,当整个序列都可用且观测是均匀间隔时,会使用更高效的卷积模式。该模式利用了将线性递推表示为输入和每个SSM的卷积核之间的一维卷积的能力。然后可以使用快速傅里叶变换(FFT)来高效地并行化这一应用。附录中的图4a展示了S4层用于离线处理的卷积方法。我们注意到,虽然原则上并行扫描可以允许在离线场景中使用循环方法,但对所有H个N维SSM应用并行扫描通常会比S4实际使用的卷积方法昂贵得多。

可训练参数。每个S4层的可训练参数是H个独立的、可学习的SSM参数副本和混合层的 $O(H^2)$ 个参数。对于 $h \in \{1, ..., H\}$ 中的每个S4 SSM,给定一个标量输入信号 $u^{(h)}(t) \in R$,一个S4 SSM使用一个输入矩阵 $B^{(h)} \in C^{N \times 1}$,一个DPLR参数化的转移矩阵 $A^{(h)} \in C^{N \times N}$,一个输出矩阵 $C^{(h)} \in C^{1 \times N}$,以及一个直通矩阵 $D^{(h)} \in R^{1 \times 1}$,来产生一个信号 $y^{(h)}(t) \in R$。为了将S4 SSM应用于离散序列,每个连续时间SSM都使用一个恒定的时间尺度参数 $\Delta^{(h)} \in R^+$ 进行离散化。每个SSM的可学习参数是时间尺度参数 $\Delta^{(h)} \in R^+$,连续时间参数 $B^{(h)}$、$C^{(h)}$、$D^{(h)}$,以及由向量 $\Lambda^{(h)} \in C^N$ 和 $p^{(h)}, q^{(h)} \in C^N$ 参数化的DPLR矩阵,分别代表对角矩阵和低秩项。为了符号简洁,我们将离散时间索引k处的S4 SSM状态的拼接表示为 $x_k^{(1:H)} = [(x_k^{(1)})^T, ..., (x_k^{(H)})^T]^T$,并将H个SSM的输出表示为 $y_k = [y_k^{(1)}, ..., y_k^{(H)}]^T$。

图2:离散化的S4层(Gu et al., 2021a)(上)和S5层(下)的内部结构示意图。为简化起见,省略了D。我们将S4层视为一个单一的块对角SSM,其潜状态大小为HN,后跟一个非线性和混合层以混合独立特征。图注:层输入,u1:L 1 k L 可学习参数 操作。(b) 相比之下,S5层使用一个密集的MIMO线性SSM,其潜状态大小P << HN。
图2:离散化的S4层(Gu et al., 2021a)(上)和S5层(下)的内部结构示意图。为简化起见,省略了D。我们将S4层视为一个单一的块对角SSM,其潜状态大小为HN,后跟一个非线性和混合层以混合独立特征。图注:层输入,u1:L 1 k L 可学习参数 操作。(b) 相比之下,S5层使用一个密集的MIMO线性SSM,其潜状态大小P << HN。


A2 方法细节

3.1 S5 结构:从 SISO 到 MIMO

用MIMO SSM替代SISO SSM组。S5层用一个多输入多输出(MIMO)SSM替换了S4中的SISO SSM“银行”(或大型块对角系统),如公式(1)所示,其潜状态大小为P,输入和输出维度为H。这个MIMO SSM的离散化版本可以应用于一个向量值的输入序列 $u_{1:L} \in R^{L \times H}$,以产生一个向量值的SSM输出(或预激活值)序列 $y_{1:L} \in R^{L \times H}$,使用了潜状态 $x_k \in R^P$。然后,一个非线性激活函数被应用于产生层输出序列 $u'_{1:L} \in R^{L \times H}$。图2b展示了这一结构。与S4不同,我们不需要一个额外的逐位置线性层,因为特征已经在MIMO SSM中混合了。我们在此指出,与S4层中块对角SSM的HN潜状态大小相比,S5的潜状态大小P可以显著更小,这使得可以使用高效的并行扫描,我们将在3.3节讨论。

3.2 S5 参数化:对角化的动态

为并行扫描进行对角化。S5层MIMO SSM的参数化是出于使用高效并行扫描的动机。如2.2节所讨论,需要一个对角的状态矩阵才能使用并行扫描高效地计算线性递推。因此,我们对系统进行对角化,将连续时间状态矩阵写为 $A = V\Lambda V^{-1}$,其中 $\Lambda \in C^{P \times P}$ 表示包含特征值的对角矩阵,而 $V \in C^{P \times P}$ 对应于特征向量。因此,我们可以将公式(1)中的连续时间潜状态动态对角化为:

公式4
公式4

定义 $\tilde{x}(t) = V^{-1}x(t)$,$\tilde{B} = V^{-1}B$,和 $\tilde{C} = CV$ 得到一个重参数化的系统:
公式5
公式5

这是一个具有对角状态矩阵的线性SSM。这个对角化系统可以使用一个时间尺度参数 $\Delta \in R^+$ 通过ZOH方法进行离散化,得到另一个对角化系统,其参数为:
公式6
公式6

在实践中,我们使用一个可学习的时间尺度参数向量 $\Delta \in R^P$(见4.3节)并限制直通矩阵D为对角矩阵。因此,S5层的可学习参数为:$\tilde{B} \in C^{P \times H}$,$\tilde{C} \in C^{H \times P}$,$\text{diag}(D) \in R^H$,$\text{diag}(\Lambda) \in C^P$ 和 $\Delta \in R^P$。

初始化。先前的工作表明,深度状态空间模型的性能对状态矩阵的初始化很敏感(【21,Combining recurrent, convolutional, and continuous-time models with linear state space layers,2021b,NeurIPS】;【20,Efficiently modeling long sequences with structured state spaces,2021a,ICLR】)。我们在2.2节讨论过,状态矩阵必须是对角的才能高效应用并行扫描。我们还在2.3节讨论过,HiPPO-LegS矩阵不能被稳定地对角化,但HiPPO-N矩阵可以。在第4节中,我们将S5的动态与S4联系起来,以说明为什么在MIMO设置中使用类似HiPPO的矩阵进行初始化也可能效果很好。我们通过经验支持这一点,发现对角化HiPPO-N矩阵能带来良好的性能,并在附录E中进行消融实验以与其他初始化方法进行比较。我们注意到DSS(【25,Diagonal state spaces are as effective as structured state spaces,2022,NeurIPS】)和S4D(【22,On the parameterization and initialization of diagonal state space models,2022,NeurIPS】)层在SISO设置中也通过使用HiPPO-N矩阵的对角化发现了强大的性能。

共轭对称性。一个具有实数项的可对角化矩阵的复特征值总是成对出现。我们通过使用一半数量的特征值和潜状态来强制实现这种共轭对称性。这确保了输出是实数,并将并行扫描的运行时间和内存使用量减少了一半。这个想法在【22,On the parameterization and initialization of diagonal state space models,2022,NeurIPS】中也有讨论。

3.3 S5 计算:完全循环

使用并行扫描。与块对角S4层的大HN有效潜状态大小相比,S5层的较小潜状态维度(P)允许在整个序列可用时使用高效的并行扫描。因此,S5层可以作为时域中的一个递推,高效地用于在线生成和离线处理。并行扫描和连续时间参数化还允许高效处理不规则采样的时间序列和其他时变SSM,只需在每一步提供一个不同的 $\bar{A}_k$ 矩阵即可。我们利用这一特性,在6.3节中将S5应用于不规则采样的数据。相比之下,S4层的卷积需要一个时不变系统和规则间隔的观测。

3.4 匹配S4和S5的计算效率

计算复杂度匹配。S5的一个关键设计期望是匹配S4在在线生成和离线递推两方面的计算复杂度。以下命题保证了如果S5的潜状态大小 $P = O(H)$,它们的复杂度在同一数量级。

命题1。给定一个具有H个输入/输出特征的S4层,一个具有H个输入/输出特征且潜状态大小 $P=O(H)$ 的S5层,在运行时间和内存使用方面与S4层具有相同的数量级复杂度。

证明。见附录C.1。

我们还在附录C.2中通过经验比较支持了这一命题。

4.1 等价动态的不同输出投影

S4与S5动态关系的简化假设。我们在一些简化假设下比较S4和S5的动态:
1. 假设1:我们只考虑H维到H维的序列映射。
2. 假设2:我们假设每个S4 SSM的状态矩阵是相同的,$A^{(h)} = A \in C^{N \times N}$。
3. 假设3:我们假设每个S4 SSM的时间尺度是相同的,$\Delta^{(h)} = \Delta \in R^+$。
4. 假设4:我们假设S5中使用的状态矩阵A与S4中相同(也参见假设2)。注意,这也指定了S5的潜状态大小P=N。我们还假设S5的输入矩阵是S4使用的列输入向量的水平拼接:$B = [B^{(1)} | ... | B^{(H)}]$。

S5输出为S4潜状态的线性组合。我们稍后将讨论放宽这些假设,但在这些条件下,推导S4和S5动态之间的关系是直接的:

命题2。考虑一个S5层,其状态矩阵为A,输入矩阵为B,以及某个输出矩阵C(参见假设1);以及一个S4层,其中H个S4 SSM中的每一个都具有状态矩阵A(参见假设2、4)和输入向量 $B^{(h)}$(参见假设4)。如果S4和S5层使用相同的时间尺度进行离散化(参见假设3),那么S5 SSM产生的输出 $y_k$ 等价于H个S4 SSM潜状态的线性组合,$y_k = C_{\text{equiv}}x_k^{(1:H)}$,其中 $C_{\text{equiv}} = [C \cdots C]$。

证明。见附录D.2。

关系解读。重要的是,S5 SSM的输出不等于块对角S4 SSM的输出。相反,它们等价于修改了输出矩阵为 $C_{\text{equiv}}$ 的块对角S4 SSM的输出。然而,在这些假设下,底层的状态动态是等价的。回想一下,用HiPPO初始化S4的动态是其性能的关键(【20,Efficiently modeling long sequences with structured state spaces,2021a,ICLR】),命题2中建立的关系激励了为S5使用HiPPO初始化,我们现在将讨论这一点。

4.2 可对角化的初始化

使用HiPPO-N进行初始化。理想情况下,根据上述解释,我们应该用精确的HiPPO-LegS矩阵来初始化S5。不幸的是,如2.3节所讨论,这个矩阵不能被稳定地对角化,而这对S5使用的高效并行扫描是必需的。然而,【25,Diagonal state spaces are as effective as structured state spaces,2022,NeurIPS】和【22,On the parameterization and initialization of diagonal state space models,2022,NeurIPS】凭经验表明,移除低秩项并用对角化的HiPPO-N矩阵进行初始化仍然表现良好。【22,On the parameterization and initialization of diagonal state space models,2022,NeurIPS】为单输入系统使用这种正规近似提供了一个理论依据:在状态维度无限的极限下,具有HiPPO-N状态矩阵的线性ODE产生与具有HiPPO-LegS矩阵的ODE相同的动态。利用线性性,将这个结果扩展到S5使用的多输入系统是直接的:

推论1(【22,On the parameterization and initialization of diagonal state space models,2022,NeurIPS】中定理3的扩展)。考虑附录B.1.1中定义的 $A_{\text{LegS}} \in R^{N \times N}$,$A_{\text{NormalLegS}} \in R^{N \times N}$,$B_{\text{LegS}} \in R^{N \times H}$,$P_{\text{LegS}} \in R^N$。给定向量值输入 $u(t) \in R^H$,当 $N \rightarrow \infty$ 时,常微分方程 $dx'(t)/dt = A_{\text{NormalLegS}}x'(t) + \frac{1}{2}B_{\text{LegS}}u(t)$ 收敛于 $dx(t)/dt = A_{\text{LegS}}x(t) + B_{\text{LegS}}u(t)$。

证明。我们在附录D.3中包含了这个扩展的一个简单证明。这个扩展激励了使用HiPPO-N来初始化S5的MIMO SSM。注意,S4D(S4的对角扩展)使用相同的HiPPO-N矩阵。因此,在命题2的假设下,一个S5 SSM实际上产生的输出等价于S4D的SSM们产生的潜状态的线性组合。我们在第6节的经验结果表明,用HiPPO-N矩阵初始化的S5与用HiPPO-LegS矩阵初始化的S4表现得一样好。

4.3 放宽假设

放宽假设的策略。我们现在重新审视命题2所需的假设,因为它们只关联了S5的一个受限版本和一个S4的受限版本。关于假设2,【20,Efficiently modeling long sequences with structured state spaces,2021a,ICLR】报告说,具有绑定状态矩阵的S4模型仍然可以表现良好,尽管允许不同的状态矩阵通常会产生更高的性能。同样,根据假设3,要求所有S4 SSM使用单一的标量时间尺度是限制性的。S4通常为每个SSM学习不同的时间尺度参数(【23,How to train your HIPPO: State space models with generalized orthogonal basis projections,2023,ICLR】)以捕捉数据中的不同时间尺度。为了放宽这些假设,注意假设4将S5的维度限制为 $P=N$,而N通常远小于输入的维度H。命题1确定了当 $P=O(H)$ 时,S5可以匹配S4的复杂度。通过允许更大的潜状态大小,可以放宽假设2和3,如附录D.4中所讨论。我们还讨论了这种放宽如何激励在对角线上使用HiPPO-N矩阵进行块对角初始化。最后,为了进一步放宽绑定的时间尺度假设,我们注意到在实践中,我们发现通过学习P个不同的时间尺度(每个状态一个)可以提高性能。关于这一经验发现的进一步讨论和消融实验见附录D.5和E.1。


A4 实验环境

  • 数据集:

    • 长程竞技场 (LRA): 包括6个任务,序列长度从1024到超过16000,用于测试长程建模能力。具体任务包括:ListOps、Text、Retrieval、Image、Pathfinder、Path-X。
    • Speech Commands: 35类关键词识别任务,数据为16kHz的1秒音频波形(序列长度16000)。
    • Pendulum Regression: 序列长度为50的24x24像素图像,从连续轨迹中不规则采样,任务是回归摆的角度。
    • 像素级1-D图像分类: Sequential MNIST (sMNIST), Permuted Sequential MNIST (psMNIST), Sequential CIFAR (sCIFAR)。
  • 模型架构:

    • 基础架构遵循 (Gu et al., 2021a),由线性编码器、多层S5层、平均池化层和线性解码器组成。
    • 激活函数使用加权S型门控单元 (一种无附加线性变换的GLU)。
    • 层间使用dropout、层归一化或批归一化,以及前置或后置归一化。
    • 具体超参数(层数、特征维度H、潜状态大小P、dropout率等)详见附录表11。模型参数量与S4基线相当。
  • 硬件配置:

    • 运行时基准测试使用 16GB NVIDIA V100 GPU
    • 钟摆回归实验使用 NVIDIA GeForce RTX 2080 Ti
  • 软件配置:

    • 代码实现:JAX。附录A提供了S5层的JAX实现代码。
    • 优化器:AdamW,并为SSM参数($\Lambda, \tilde{B}, \Delta$)设置了较小的学习率和不同的权重衰减策略。
    • 学习率调度:所有实验均使用余弦退火
    • 模型方向性:LRA和语音任务使用双向模型,其他任务为公平比较使用单向(因果)模型

A4 实验结果

6.1 长程竞技场 (LRA)

  • 实验内容: 在LRA基准测试的六个任务上评估S5的性能,序列长度从1,024到超过16,000。
  • 实验结果:
    • 如表1所示,S5在具有线性序列长度复杂度的模型中取得了最高的平均分(87.4%),优于S4、S4D以及同期的Liquid-S4和Mega-chunk。
    • 在所有模型中(包括具有 $O(L^2)$ 复杂度的Mega),S5在Path-X任务上取得了最高分(98.5%),该任务的序列长度在基准测试中最长。
  • 分析结论: S5在长程序列建模任务上达到了最先进的性能,尤其擅长处理极长的序列。

表1:LRA基准任务的测试准确率 (Tay et al., 2021)。✗ 表示模型未超过随机猜测。我们在附录中提供了包含完整引用和误差条的扩展表,表7。我们遵循Gu et al. (2021a; 2022)报告的程序,并报告S4、S4D(由Gu et al. (2021a; 2022)报告)和S5在三个种子上的平均值。粗体分数表示最高性能,下划线分数表示第二名。我们还包括了同期方法Liquid-S4 (Hasani et al., 2023)和Mega (Ma et al., 2023)的结果。与S4方法和S5不同,最好的Mega模型保留了Transformer的O(L2)复杂度。
表1:LRA基准任务的测试准确率 (Tay et al., 2021)。✗ 表示模型未超过随机猜测。我们在附录中提供了包含完整引用和误差条的扩展表,表7。我们遵循Gu et al. (2021a; 2022)报告的程序,并报告S4、S4D(由Gu et al. (2021a; 2022)报告)和S5在三个种子上的平均值。粗体分数表示最高性能,下划线分数表示第二名。我们还包括了同期方法Liquid-S4 (Hasani et al., 2023)和Mega (Ma et al., 2023)的结果。与S4方法和S5不同,最好的Mega模型保留了Transformer的O(L2)复杂度。

6.2 原始语音分类

  • 实验内容: 在Speech Commands数据集上进行35类关键词分类。此外,测试了模型在不同采样率下的零样本泛化能力(从16kHz训练数据到8kHz测试数据)。
  • 实验结果:
    • 如表2所示,S5的性能优于基线模型和之前的S4方法,在16kHz测试集上达到98.56%的准确率。
    • 在零样本迁移到8kHz测试集上,S5同样取得了最佳性能(91.56%),这得益于其连续时间参数化特性。
  • 分析结论: S5在原始音频分类任务上表现出色,并且其连续时间参数化使其能够灵活适应不同的数据采样率,无需重新训练。

表2:在35类Speech Commands分类任务上的测试准确率 (Warden, 2018)。我们在附录中提供了一个带误差条的扩展表,表8。训练样本是1秒的16kHz音频波形。最后一列表示在8kHz下的0-shot测试(通过朴素抽取构建)。与Gu et al. (2022)一样,报告了三个随机种子的平均值。从InceptionNet到S4D-Lin的基线性能是从Gu et al. (2022)报告的。
表2:在35类Speech Commands分类任务上的测试准确率 (Warden, 2018)。我们在附录中提供了一个带误差条的扩展表,表8。训练样本是1秒的16kHz音频波形。最后一列表示在8kHz下的0-shot测试(通过朴素抽取构建)。与Gu et al. (2022)一样,报告了三个随机种子的平均值。从InceptionNet到S4D-Lin的基线性能是从Gu et al. (2022)报告的。

6.3 变动观测间隔

  • 实验内容: 在一个摆锤回归任务中测试S5处理不规则采样时间序列的能力。输入是50张在100个时间单位内不规则采样的图像,任务是回归摆锤的角度。S5通过在每一步向离散化过程提供不同的$\Delta_t$值来处理不规则间隔。
  • 实验结果:
    • 如图3所示,任务涉及从噪声图像序列回归动态系统的状态。
    • 如表3所示,S5在回归任务中取得了比基线模型CRU更低的均方误差(MSE),并且在相同硬件上的运行速度显著快于CRU。
  • 分析结论: S5的并行扫描和连续时间参数化使其能够自然且高效地处理不规则采样的数据,这是S4的卷积实现无法做到的。

图3:摆锤回归示例的图示。顶行显示了在指定时间点用作输入的图像。底行显示了sin(θt)和cos(θt)的值,其中θt是摆锤在时间t的角度,这些值被用作回归目标。
图3:摆锤回归示例的图示。顶行显示了在指定时间点用作输入的图像。底行显示了sin(θt)和cos(θt)的值,其中θt是摆锤在时间t的角度,这些值被用作回归目标。

表3:在摆锤回归任务的保留测试集上的回归MSE ×10−3(均值±标准差)和相对应用速度。从mTAND到CRU的基线性能是从Schirmer et al. (2022)报告的。我们在附录中包含了一个扩展表,表9,以及更多细节。CRU(我们的运行)和S5的结果是基于20个种子。
表3:在摆锤回归任务的保留测试集上的回归MSE ×10−3(均值±标准差)和相对应用速度。从mTAND到CRU的基线性能是从Schirmer et al. (2022)报告的。我们在附录中包含了一个扩展表,表9,以及更多细节。CRU(我们的运行)和S5的结果是基于20个种子。

6.4 像素级1-D图像分类

  • 实验内容: 在sMNIST、psMNIST和sCIFAR等常见的像素级1-D图像分类基准上进行评估。
  • 实验结果: 附录F.4中的表10显示,S5的性能与S4大致相当,并优于一系列先进的基于RNN的方法。
  • 分析结论: S5在这些经典的序列分类任务上也表现出强大的竞争力。

A5 结论

本文介绍了用于长程序列建模的S5层。S5层通过修改S4层的内部结构,将S4使用的频域方法替换为一种利用并行扫描的纯循环、时域方法。S5在保持S4计算效率的同时,取得了优异的性能。

此外,S5还提供了新的机遇。例如,与卷积式的S4方法不同,并行扫描使得高效、简便地处理参数随时变化的时变SSM成为可能。第6.3节展示了其在处理可变采样率序列上的应用。同期发展的方法Liquid-S4【26,Liquid structural state-space models,2023,ICLR】使用输入相关的双线性动力系统,进一步凸显了时变SSM的潜力。

更通用的MIMO SSM设计也将有助于建立与经典概率状态空间建模以及近期关于并行化滤波和平滑操作(【57,Temporal parallelization of Bayesian smoothers,2020】)工作的联系。总的来说,我们希望S5层的简洁性和通用性能够扩大状态空间层在深度序列建模中的应用,并催生新的公式化和扩展。


A6 附录

A JAX S5层实现

import jax
import jax.numpy as np
parallel_scan = jax.lax.associative_scan

def discretize(Lambda, B_tilde, Delta):
    &quot;&quot;&quot; Discretize a diagonalized, continuous-time linear SSM
    Args:
        Lambda (complex64): diagonal state matrix (P,)
        B_tilde (complex64): input matrix (P, H)
        Delta (float32): discretization step sizes (P,)
    Returns:
        discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H)&quot;&quot;&quot;
    Identity = np.ones(Lambda.shape[0])
    Lambda_bar = np.exp(Lambda * Delta)
    B_bar = (1 / Lambda * (Lambda_bar - Identity))[..., None] * B_tilde
    return Lambda_bar, B_bar

def binary_operator(element_i, element_j):
    &quot;&quot;&quot; Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A.
    Args:
        element_i: tuple containing A_i and Bu_i at position i (P,), (P,)
        element_j: tuple containing A_j and Bu_j at position j (P,), (P,)
    Returns:
        new element ( A_out, Bu_out ) &quot;&quot;&quot;
    A_i, Bu_i = element_i
    A_j, Bu_j = element_j
    return A_j * A_i, A_j * Bu_i + Bu_j

def apply_ssm(Lambda_bar, B_bar, C_tilde, D, input_sequence):
    &quot;&quot;&quot; Compute the LxH output of discretized SSM given an LxH input.
    Args:
        Lambda_bar (complex64): discretized diagonal state matrix (P,)
        B_bar (complex64): discretized input matrix (P, H)
        C_tilde (complex64): output matrix (H, P)
        D (float32): feedthrough matrix (H,)
        input_sequence (float32): input sequence of features (L, H)
    Returns:
        ys (float32): the SSM outputs (S5 layer preactivations) (L, H) &quot;&quot;&quot;
    # Prepare elements required to initialize parallel scan
    Lambda_elements = np.repeat(Lambda_bar[None, ...], input_sequence.shape[0], axis=0)
    Bu_elements = jax.vmap(lambda u: B_bar @ u)(input_sequence)
    elements = (Lambda_elements, Bu_elements) # (L, P), (L, P)

    # Compute latent state sequence given input sequence using parallel scan
    _, xs = parallel_scan(binary_operator, elements) # (L, P)

    # Compute SSM output sequence
    ys = jax.vmap(lambda x, u: (C_tilde @ x + D * u).real)(xs, input_sequence)
    return ys

def apply_S5_layer(params, input_sequence):
    &quot;&quot;&quot; Computes LxH output sequence of an S5 layer given LxH input sequence.
    Args:
        params: tuple of the continuous time SSM parameters
        input_sequence: input sequence of features (L, H)
    Returns:
        The S5 layer output sequence (L, H) &quot;&quot;&quot;
    Lambda, B_tilde, C_tilde, D, log_Delta = params
    Lambda_bar, B_bar = discretize(Lambda, B_tilde, np.exp(log_Delta))
    preactivations = apply_ssm(Lambda_bar, B_bar, C_tilde, D, input_sequence)
    return jax.nn.gelu(preactivations)

def batch_apply_S5_layer(params, input_sequences):
    &quot;&quot;&quot; Computes BxLxH output sequence of an S5 layer given BxLxH input sequence.
    Args:
        params: tuple of the continuous time SSM parameters
        input_sequences: batch of input feature sequences (B, L ,H)
    Returns:
        Batch of S5 layer output sequences (B, L, H)
    return jax.vmap(apply_S5_layer, in_axes=(None, 0))(params, input_sequences)

B S5 层细节

B.1 初始化细节
B.1.1 状态矩阵的初始化

HiPPO-LegS。补充第3.2节的讨论,【23,How to train your HIPPO: State space models with generalized orthogonal basis projections,2023,ICLR】解释了S4在使用HiPPO-LegS矩阵时能够捕捉长程依赖的能力,这是通过相对于一个无限长的指数衰减度量来分解输入实现的。HiPPO-LegS矩阵和相应的SISO输入向量定义如下:

公式7
公式7

公式8
公式8

注意,在第4.2节中,推论1中使用的输入矩阵 $B_{\text{LegS}} \in R^{N \times H}$ 是通过拼接H个 $b_{\text{LegS}} \in R^N$ 的副本形成的。

NPLR/DPLR形式。【20,Efficiently modeling long sequences with structured state spaces,2021a,ICLR】的定理1表明,【18,HiPPO: Recurrent memory with optimal polynomial projections,2020a,NeurIPS】中的HiPPO矩阵 $A_{\text{HiPPO}} \in R^{N \times N}$ 可以表示为正规加低秩(NPLR)形式,包括一个正规矩阵 $A_{\text{NormalHiPPO}} = V\Lambda V^* \in R^{N \times N}$ 和一个低秩项:

公式9
公式9

对于酉矩阵 $V \in C^{N \times N}$、对角矩阵 $\Lambda \in C^{N \times N}$ 和低秩分解 $P, Q \in R^{N \times r}$。方程右侧表明HiPPO矩阵可以共轭变换为对角加低秩(DPLR)形式。因此,HiPPO-LegS矩阵可以表示为正规HiPPO-N矩阵和低秩项 $P_{\text{LegS}} \in R^N$ 的形式(【17,It’s raw! Audio generation with state-space models,2022,ICML】):
公式10
公式10

其中
公式11
公式11

公式12
公式12

S5的初始化策略。我们的默认做法是设置S5层状态矩阵 $A = A_{\text{NormalLegS}} \in R^{P \times P}$,并进行特征分解以初始化 $\Lambda = \text{diag}(\lambda_1, \dots, \lambda_P)$ 和 $V$。然后我们使用 $V$ 来初始化 $\tilde{B}$ 和 $\tilde{C}$,如下所述。

块对角初始化。如4.3节所述,我们还发现在许多任务上,将S5状态矩阵初始化为块对角形式可以提升性能,其中对角线上的每个块等于 $A_{\text{NormalLegS}} \in R^{R \times R}$,这里的R小于状态维度P,例如当使用4个块时 $R = P/4$。然后我们对这个矩阵进行特征分解以初始化 $\Lambda$ 以及 $\tilde{B}$ 和 $\tilde{C}$。我们注意到,即使在这种情况下,$\tilde{B}$ 和 $\tilde{C}$ 仍然被初始化为密集形式,并且没有约束要求A在学习过程中保持块对角。在附录G的超参数表中,J超参数表示用于初始化的对角线上HiPPO-N块的数量,其中J=1表示我们使用默认的单个HiPPO-N矩阵进行初始化。我们在附录D.4中进一步讨论了这种块对角初始化的动机。

B.1.2 输入、输出和直通矩阵的初始化

$\tilde{B}$ 和 $\tilde{C}$ 的初始化。通常,我们使用初始状态矩阵对角化得到的特征向量来显式初始化输入矩阵 $\tilde{B}$ 和输出矩阵 $\tilde{C}$。具体来说,我们采样B和C,然后将(复数)可学习参数 $\tilde{B}$ 初始化为 $\tilde{B} = V^{-1}B$,$\tilde{C}$ 初始化为 $\tilde{C} = CV$。

D的初始化。我们通过从标准正态分布中独立采样每个元素来初始化 $D \in R^H$。

B.1.3 时间尺度的初始化

时间尺度初始化方法。先前的工作(【25,Diagonal state spaces are as effective as structured state spaces,2022,NeurIPS】;【23,How to train your HIPPO: State space models with generalized orthogonal basis projections,2023,ICLR】)发现时间尺度参数的初始化很重要。这在【23】中有详细研究。我们遵循S4的方式对这些参数进行采样,从区间 $[\log \delta_{\min}, \log \delta_{\max})$ 上的均匀分布中采样 $\log \Delta \in R^P$ 的每个元素,默认范围是 $\delta_{\min} = 0.001$ 和 $\delta_{\max} = 0.1$。唯一的例外是Path-X实验,我们从 $\delta_{\min} = 0.0001$ 和 $\delta_{\max} = 0.1$ 进行初始化,以考虑更长的时间尺度,如【23】中所讨论。

B.2 S4和S5计算元素的比较

计算流程对比。在图4中,我们展示了S4和S5层在高效、并行的离线处理中的计算细节比较。

图4:S4层 (Gu et al., 2021a)(上)和S5层(下)对序列进行离线应用的计算组件。(a) S4层对输入序列u1:L ∈ RL×H的每个维度应用一个独立的SSM。这需要一个柯西核计算来计算频域中的卷积核系数。使用FFT计算卷积以产生独立的SSM输出y1:L ∈ RL×H。一个包含混合层的非线性激活函数被应用于SSM输出以产生层输出。(b) (从图1复制) S5层在一个对角化的线性SSM上使用并行扫描来计算SSM输出y1:L ∈ RL×H。一个非线性激活函数被应用于SSM输出以产生层输出。
图4:S4层 (Gu et al., 2021a)(上)和S5层(下)对序列进行离线应用的计算组件。(a) S4层对输入序列u1:L ∈ RL×H的每个维度应用一个独立的SSM。这需要一个柯西核计算来计算频域中的卷积核系数。使用FFT计算卷积以产生独立的SSM输出y1:L ∈ RL×H。一个包含混合层的非线性激活函数被应用于SSM输出以产生层输出。(b) (从图1复制) S5层在一个对角化的线性SSM上使用并行扫描来计算SSM输出y1:L ∈ RL×H。一个非线性激活函数被应用于SSM输出以产生层输出。

C S5的计算效率

C.1 理论计算效率

命题1的证明。我们首先考虑整个序列可用的情况,比较S4层的卷积模式和S5层使用的并行扫描。然后我们考虑在线生成的情况,此时每个方法都以循环方式运行。

并行化离线处理。我们考虑将S4和S5层应用于向量值序列 $u_{1:L} \in R^{L \times H}$。S4的总操作数为 $O(H^2L + HL \log L)$。S5层需要 $O(PHL + PL)$ 次操作。因此,当S5状态维度 $P=O(H)$ 时,S5层需要 $O(H^2L + HL)$ 次操作,而S4层为 $O(H^2L + HL \log L)$。关键是,当 $P=O(H)$ 时,S4和S5的并行复杂度均为 $O(\log H + \log L)$(当有 $H^2L$ 个处理器时)。此外,当 $P=O(H)$ 时,并行扫描的空间复杂度为 $O(HL)$,与S4的FFT空间复杂度相匹配。

在线生成。对于在线生成,S4和S5层都以循环方式运行。S4层每步需要 $O(H^2 + HN)$ 次操作。S5层每步需要 $O(PH + P)$ 次操作。因此,当 $P=O(H)$ 且单个S4 SSM状态大小 $N=O(H)$ 时,两种方法的每步复杂度均为 $O(H^2)$。

结论。因此,在两种情况下,S4和S5的计算复杂度和内存需求都处于同一数量级。

C.2 经验性运行时比较

运行时基准测试。表4提供了S4、S4D和S5在一系列LRA任务不同序列长度下的速度和内存的经验性评估。我们使用了JAX实现进行公平比较。
- S5 (P=H)配置:将S5的潜状态大小P设为与输入/输出特征数H相等,经验性地支持了C.1中的复杂度分析。
- S5 (P=N)配置:将S5的潜状态大小P设为与S4中每个SISO SSM的潜状态大小N相等,对应于消融研究中与S4/S4D性能相似的受限版本。
- S5 (实际使用)配置:使用表11中报告的最佳架构尺寸,用于表1中的LRA结果。

结论。这个经验研究的总体结论是,S5和S4/S4D的运行时间和内存使用大体相似,正如复杂度分析所表明的。

表4:在三个不同序列长度的LRA任务上,使用C.2节中描述的参数化对S4、S4D和S5的运行时性能进行基准测试。对于速度,> 1×表示比S4D基线快。对于内存,< 1×表示比S4D基线使用的内存少。每个度量的第五行显示了用于表1中LRA结果的实际S5模型在每个任务上的性能,使用了表11中报告的架构。
表4:在三个不同序列长度的LRA任务上,使用C.2节中描述的参数化对S4、S4D和S5的运行时性能进行基准测试。对于速度,> 1×表示比S4D基线快。对于内存,< 1×表示比S4D基线使用的内存少。每个度量的第五行显示了用于表1中LRA结果的实际S5模型在每个任务上的性能,使用了表11中报告的架构。

D S4和S5之间的关系

D.2 等价动态的不同输出投影

命题2的证明。对于单个S4 SSM,离散化后的潜状态可以表示为:

公式13
公式13

对于S5层,潜状态可以表示为:
公式14
公式14

根据线性性质,我们观察到:
公式15
公式15

这表明(在假设下)MIMO S5 SSM的状态等价于H个不同SISO S4 SSM状态的总和。S5的输出为 $y_k^{S5} = C x_k^{S5}$。将(15)代入,可得:
公式17
公式17

将H个S4 SSM的状态向量垂直拼接为 $x_k^{(1:H)}$,S5的输出可以表示为:
公式18
公式18

因此,S5的输出等价于H个S4 SSM计算的HN个状态的线性组合。S5可以被解释为定义了与S4不同的对H个独立SISO SSMs的投影。S4的等效输出矩阵是块对角的,而S5的等效输出矩阵是将一个密集的输出矩阵C在H个S4 SSMs上进行了绑定。

D.3 可对角化的初始化

推论1的证明。【22,On the parameterization and initialization of diagonal state space models,2022,NeurIPS】中的定理3显示了在 $N \rightarrow \infty$ 时标量输入信号的如下关系:

公式23
公式23

我们希望将其扩展到向量值输入信号的情况。我们利用(15)中的关系,即MIMO S5 SSM的潜状态是H个SISO S4 SSM潜状态的总和,这在连续时间下同样适用:
公式24
公式24

因此,S5状态的导数可以定义为:
公式25
公式25

将(23)代入可得:
公式26
公式26

公式27
公式27

公式28
公式28

这一等价性激励了用可对角化的HiPPO-N矩阵来初始化S5状态矩阵,并表明我们可以期待看到类似的性能提升。

通过块对角结构放宽假设。通过让S5的状态矩阵为块对角,可以将系统看作是J个独立的N维S5 SSM子系统的和。每个子系统都可以有自己的状态矩阵和时间尺度参数。根据命题2,每个子系统的动态又可以关联到一个不同的S4系统。因此,一个JN维的S5 SSM的输出可以等价于J个不同的S4系统潜状态的线性组合。这为使用多个HiPPO-N矩阵初始化块对角S5状态矩阵提供了理论依据,从而放宽了S4 SSM状态矩阵和时间尺度绑定的假设。

D.5 时间尺度参数化

向量化时间尺度。S4可以为每个SSM学习一个不同的时间尺度参数。在S5中,我们发现使用P个不同的时间尺度参数(每个状态一个)比使用单个或J个参数性能更好。这可以看作是为对角化系统中的每个特征值学习一个不同的缩放因子,或者增加初始化时采样的时间尺度参数数量,以避免糟糕的初始化。

E 消融实验

E.1 S5潜状态大小、块对角初始化和时间尺度参数化

实验设置和结果。表5展示了在LRA任务上进行的消融研究,比较了三种S5变体:
1. 受限S5 (P=N, 标量$\Delta$): 潜状态大小P与S4的N相等,使用单个时间尺度参数。在多数任务上表现良好,但在Image和ListOps上不如S4。
2. 受限S5 (P=N, 向量$\Delta$): 与上一版本相同,但使用向量时间尺度参数$\Delta \in R^N$。性能普遍优于标量版本。
3. 无约束S5 (本文方法): 使用更大的潜状态大小P>N和块对角初始化。此版本在所有任务上都取得了最佳或接近最佳的性能。

结论。向量化的时间尺度参数和块对角初始化对S5的性能至关重要,允许其超越受限版本和S4基线。

表5:LRA基准任务的消融实验(Tay et al., 2021)。S4的结果取自Gu et al. (2022; 2021a)。请注意,此表中所有模型的总参数数量是相当的,因此性能变化不能归因于模型参数数量的巨大差异。
表5:LRA基准任务的消融实验(Tay et al., 2021)。S4的结果取自Gu et al. (2022; 2021a)。请注意,此表中所有模型的总参数数量是相当的,因此性能变化不能归因于模型参数数量的巨大差异。

E.2 HiPPO-N和连续时间参数化的重要性

实验设置和结果。表6比较了不同状态矩阵初始化(随机高斯、随机反对称、HiPPO-N)和参数化(连续时间 vs. 离散时间)的S5变体。
- 随机初始化方法在简单任务上可行,但在Path-X等长程任务上失败。
- 仅使用离散时间参数化的模型(即使使用HiPPO-N初始化)性能也较差,并且训练不稳定。
- 只有同时使用连续时间参数化HiPPO-N初始化的S5模型才能在所有任务上持续表现良好,并解决Path-X问题。

结论。连续时间参数化和HiPPO初始化是S5相比于先前并行化线性RNN取得成功的关键因素。

表6:S5初始化和参数化消融研究。✗ 表示模型没有比随机猜测有所改进。
表6:S5初始化和参数化消融研究。✗ 表示模型没有比随机猜测有所改进。

H 并行扫描用于线性递推的背景知识

核心思想。为了并行计算线性递推 $x_k = A x_{k-1} + B u_k$,我们将问题转化为一个扫描操作。
1. 初始元素:为序列中的每个时间步k创建一个元组 $c_k = (A, B u_k)$。
2. 二元结合算子:定义一个算子 ,用于合并两个元素 $q_i = (A_i, v_i)$ 和 $q_j = (A_j, v_j)$:

公式34
公式34

这个算子是结合的,即 $(q_i • q_j) • q_k = q_i • (q_j • q_k)$。

并行计算示例。对于一个长度为4的序列,传统的循环计算需要4个串行步骤。而使用并行扫描,可以首先并行计算 $(c_1 • c_2)$ 和 $(c_3 • c_4)$,然后在第二步中计算 $((c_1 • c_2) • (c_3 • c_4))$ 来得到最终状态。通过这种树状的归约方式,可以将串行步骤从L减少到 $\log L$,从而实现并行加速。当状态矩阵A是对角矩阵时,矩阵乘法变为元素级乘法,使得该操作在计算上非常高效。