Efficiently Modeling Long Sequences with Structured State Spaces

作者/机构: Albert Gu, Karan Goel, and Christopher Ré, Department of Computer Science, Stanford University

A1 主要贡献

本文旨在解决序列建模中的一个核心问题:高效处理包含长距离依赖(LRDs)的数据。现有的主流模型,如循环神经网络(RNNs)、卷积神经网络(CNNs)和Transformers,尽管有专门为长序列设计的变体,但在处理长度超过10000的时间步时仍然面临困难。

最近,一种基于状态空间模型(State Space Model, SSM)的方法展现了潜力。SSM是一种基础科学模型,通过方程 $x'(t) = Ax(t) + Bu(t), y(t) = Cx(t) + Du(t)$ 描述系统。研究【18,Combining recurrent, convolutional, and continuous-time models with the structured learnable linear state space layer,NeurIPS 2021】发现,通过为状态矩阵A选择特殊的HiPPO矩阵【16,Hippo: Recurrent memory with optimal polynomial projections,NeurIPS 2020】,SSM在理论和实践上都能很好地捕捉长距离依赖。

然而,这种被称为线性状态空间层(LSSL)的方法存在严重的计算和内存瓶颈。对于状态维度为N、序列长度为L的模型,其计算和存储需求过高(计算复杂度为 $O(N^2L)$,空间复杂度为 $O(NL)$),使其无法作为通用的序列建模解决方案。此外,虽然理论上存在更高效的算法,但本文证明这些算法在数值上是不稳定的。

本文的主要贡献是提出了结构化状态空间序列模型(S4),它解决了LSSL的计算瓶颈,同时保留了其理论优势。
- 技术创新:S4的核心是一种新的SSM参数化方法。它将HiPPO矩阵A分解为一个正规矩阵(Normal Matrix)和一个低秩矩阵(Low-Rank Matrix)之和。通过这种分解,S4可以在频域中高效计算,利用Woodbury恒等式修正低秩项,并将正规项稳定地对角化,最终将核心计算归结为一个经过充分研究且数值稳定的柯西核(Cauchy Kernel)计算问题。这使得模型的计算复杂度降至 $O((N+L)\log(N+L))$,空间复杂度降至 $O(N+L)$,接近理论最优。
- 经验验证:S4在多个基准测试中取得了显著的成果,证明了其作为通用序列模型的潜力。
- 长距离依赖:在为高效序列模型设计的长距离竞技场(LRA)基准测试中,S4在所有任务上都超越了基线模型,并且是第一个成功解决长度为16384的Path-X任务的模型。在原始语音分类任务上,S4的错误率比专门的语音CNN模型降低了一半。
- 通用序列建模能力
- 大规模生成建模:在CIFAR-10密度估计上,S4的效果与顶尖的自回归模型相当;在WikiText-103语言建模上,它显著缩小了与Transformers的性能差距。
- 快速自回归生成:利用其循环表示,S4的生成速度比标准自回归模型快60倍。
- 适应采样率变化:S4能够适应时间序列采样频率的变化而无需重新训练。
- 弱归纳偏见下的学习:在没有领域特定架构修改的情况下,S4在语音分类、时间序列预测和序列图像分类等任务上均取得了与专门模型相当甚至更好的性能,例如在序列CIFAR-10上达到了91%的准确率,与一个更大的2D ResNet相媲美。

下图(Figure 1)展示了SSM的不同计算视角及其参数表示,以及S4如何通过新的参数化方法高效地在这些表示之间切换。

图1:(左)由矩阵A, B, C, D参数化的状态空间模型(SSM)通过一个潜在状态x(t)将输入信号u(t)映射到输出y(t)。(中)最近关于连续时间记忆的理论推导出了特殊的A矩阵,使SSM能够在数学上和经验上捕捉长距离依赖。(右)SSM可以作为循环(左)或卷积(右)来计算。然而,实现这些概念视图需要利用其参数的不同表示(红、蓝、绿),这些表示的计算成本非常高。S4引入了一种新颖的参数化方法,可以有效地在这些表示之间切换,使其能够处理广泛的任务,在训练和推理上都高效,并擅长处理长序列。
图1:(左)由矩阵A, B, C, D参数化的状态空间模型(SSM)通过一个潜在状态x(t)将输入信号u(t)映射到输出y(t)。(中)最近关于连续时间记忆的理论推导出了特殊的A矩阵,使SSM能够在数学上和经验上捕捉长距离依赖。(右)SSM可以作为循环(左)或卷积(右)来计算。然而,实现这些概念视图需要利用其参数的不同表示(红、蓝、绿),这些表示的计算成本非常高。S4引入了一种新颖的参数化方法,可以有效地在这些表示之间切换,使其能够处理广泛的任务,在训练和推理上都高效,并擅长处理长序列。

A3 背景知识

2.1 状态空间模型:一种连续时间的潜在状态模型

状态空间模型(SSM)的定义。SSM由一个简单的线性常微分方程组定义,它将一个一维输入信号 $u(t)$ 映射到一个 $N$ 维的潜在状态 $x(t)$,然后再将该状态投影到一个一维输出信号 $y(t)$。

SSM作为深度学习模块。SSM在许多科学领域被广泛使用,并与隐马尔可夫模型(HMM)等潜在状态模型相关。本文的目标是将其作为一个黑盒表示,用于深度序列模型中,其中矩阵A、B、C、D是通过梯度下降学习的参数。为了简化论述,本文后续将忽略参数D(或等效地假设D=0),因为 $Du(t)$ 项可以看作是一个易于计算的跳跃连接。

2.2 使用HiPPO解决长距离依赖问题

标准SSM的局限性。先前的工作发现,基础的SSM(方程1)在实践中表现非常差。一个直观的解释是,线性一阶常微分方程的解是指数函数,因此可能会遭受梯度随序列长度指数级缩放的问题(即梯度消失/爆炸问题【32,On the difficulty of training recurrent neural networks,ICML 2013】)。

HiPPO理论的引入。为了解决这个问题,LSSL模型利用了关于连续时间记忆的HiPPO理论【16,Hippo: Recurrent memory with optimal polynomial projections,NeurIPS 2020】。HiPPO理论指定了一类特殊的矩阵 $A \in \mathbb{R}^{N \times N}$,当将这些矩阵整合到方程(1)中时,状态 $x(t)$ 能够记忆输入 $u(t)$ 的历史。这个类别中最重要的矩阵由方程(2)定义,我们称之为HiPPO矩阵。例如,LSSL发现,仅仅将SSM中的随机矩阵A修改为方程(2)中的HiPPO矩阵,其在序贯MNIST基准测试上的性能就从60%提高到了98%。

2.3 离散时间SSM:循环表示

连续到离散的转换。为了将模型应用于离散输入序列 $(u_0, u_1, \ldots)$ 而非连续函数 $u(t)$,必须通过一个代表输入分辨率的步长 $\Delta$ 对方程(1)进行离散化。概念上,可以将输入 $u_k$ 视为对一个隐含的潜在连续信号 $u(t)$ 的采样,其中 $u_k = u(k\Delta)$。

离散化方法。为了离散化连续时间SSM,我们遵循先前工作使用双线性方法【43,A method of analysing the behaviour of linear systems in terms of time series,Journal of the Institution of Electrical Engineers-Part IIA: Automatic Regulators and Servo Mechanisms 1947】,该方法将状态矩阵A转换为一个近似的离散矩阵 $\bar{A}$。离散化的SSM如下所示:

循环计算形式。方程(3)现在是一个从序列到序列的映射 $u_k \rightarrow y_k$,而不是函数到函数的映射。更重要的是,状态方程现在是关于 $x_k$ 的一个递推关系,这使得离散SSM可以像RNN一样进行计算。具体来说,可以将 $x_k \in \mathbb{R}^N$ 视为一个隐藏状态,其转移矩阵为 $\bar{A}$。在本文中,我们用 $\bar{A}, \bar{B}, \ldots$ 表示由方程(3)定义的离散化SSM矩阵,这些矩阵是A和步长 $\Delta$ 的函数。

2.4 训练SSM:卷积表示

循环表示的训练瓶颈。循环SSM(方程3)由于其顺序性,在现代硬件上训练并不实用。然而,线性时不变(LTI)SSM(如方程1)与连续卷积之间存在着众所周知的联系。相应地,方程(3)实际上可以写成一个离散卷积形式。

推导卷积形式。为简单起见,假设初始状态为 $x_{-1} = 0$。那么,显式地展开方程(3)的递推关系可以得到:

卷积核的显式公式。这个关系可以被向量化成一个卷积(方程4),其卷积核有显式公式(方程5)。

卷积的高效计算。换句话说,方程(4)是一个单一的(非循环)卷积,只要卷积核 $\bar{K}$ 已知,就可以用快速傅里叶变换(FFT)非常高效地计算。然而,计算方程(5)中的 $\bar{K}$ 并非易事,这也是本文在第3节中技术贡献的焦点。我们将 $\bar{K}$ 称为SSM卷积核或滤波器。

A2 方法细节

我们的技术成果重点在于发展S4参数化方法,并展示如何高效地计算SSM的所有视图(第2节):连续表示(A, B, C)(方程1)、循环表示($\bar{A}, \bar{B}, \bar{C}$)(方程3)和卷积表示 $\bar{K}$(方程4)。

3.1 动机:对角化

核心计算瓶颈。计算离散时间SSM(方程3)的根本瓶颈在于它涉及到对矩阵 $\bar{A}$ 的重复矩阵乘法。例如,像LSSL那样朴素地计算方程(5)需要对 $\bar{A}$ 进行L次连续乘法,这需要 $O(N^2L)$ 的操作和 $O(NL)$ 的空间。

通过共轭变换简化SSM。为了克服这个瓶颈,我们利用一个结构性结果来简化SSM。根据引理3.1,SSM在共轭变换下是等价的,即 $(A, B, C) \sim (V^{-1}AV, V^{-1}B, CV)$。这意味着我们可以通过选择合适的矩阵V将A变换为一个更结构化的规范形式,从而实现更快的计算。理想情况下,如果A是可对角化的,计算将变得非常容易。特别是,期望的卷积核 $\bar{K}$(方程4)将成为一个范德蒙德乘积,理论上只需要 $O((N + L) \log^2(N + L))$ 次算术运算【29,Structured matrices and polynomials: unified superfast algorithms,Springer Science & Business Media 2001】。

  • 引理 3.1:共轭是SSMs上的一个等价关系,即 $(A, B, C) \sim (V^{-1}AV, V^{-1}B, CV)$。
    • 证明:写出两个SSM的状态分别为x和 $\tilde{x}$:


      将右侧SSM乘以V后,两个SSM变得相同,且 $x = V\tilde{x}$。因此,它们计算的是完全相同的算子 $u \rightarrow y$,只是状态x的基底通过V发生了变化。

朴素对角化的数值不稳定性。不幸的是,朴素地应用对角化方法由于数值问题而行不通。我们推导了HiPPO矩阵(方程2)的显式对角化形式,并表明其特征向量矩阵的元素随状态大小N呈指数级增长,使得对角化在数值上不可行(例如,引理3.1中的CV将无法计算)。我们注意到,Gu等人【18,Combining recurrent, convolutional, and continuous-time models with the structured learnable linear state space layer,NeurIPS 2021】曾提出一种不同的(未实现的)算法来比朴素算法更快地计算 $\bar{K}$。在附录B中,我们证明了由于相关原因,该算法也存在数值不稳定性。

  • 引理 3.2:方程(2)中的HiPPO矩阵A由矩阵 $V_{ij} = \binom{i+j}{i-j}$ 对角化。特别地,$V_{3i,i} = \binom{4i}{2i} \approx 2^{4i}$。因此,V的元素大小可达 $2^{4N/3}$。

3.2 S4参数化:正规加低秩

寻找稳定的共轭变换。前面的讨论意味着我们应该只使用条件良好的矩阵V进行共轭变换。理想情况是当矩阵A可以被一个条件完美的(即酉)矩阵对角化时。根据线性代数的谱定理,这正是正规矩阵(normal matrices)的类别。然而,这类矩阵是受限的;特别地,它不包含HiPPO矩阵(方程2)。

核心思想:NPLR分解与组合技术。我们观察到,尽管HiPPO矩阵不是正规的,但它可以被分解为一个正规矩阵和一个低秩矩阵的和。然而,这本身仍然没有用:与对角矩阵不同,对这个和进行幂运算(在方程5中)仍然很慢且不易优化。我们通过同时应用三种新技术来克服这个瓶颈。

  • 从计算核到计算谱:我们不直接计算卷积核 $\bar{K}$,而是通过在单位根 $\zeta$ 上评估其截断生成函数 $\sum_{j=0}^{L-1} \bar{K}_j \zeta^j$ 来计算其谱。然后可以通过应用逆FFT来找到 $\bar{K}$。
  • 利用矩阵求逆和Woodbury恒等式:这个生成函数与矩阵的预解式(resolvent)密切相关,现在它涉及矩阵求逆而不是求幂。低秩项现在可以通过应用Woodbury恒等式进行修正,该恒等式将 $(A+PQ^*)^{-1}$ 的计算归约为 $A^{-1}$ 的计算,从而真正地简化为对角矩阵的情况。
  • 归约为柯西核计算:最后,我们证明了对角矩阵的情况等价于计算一个柯西核 $\frac{1}{\omega - \zeta}$,这是一个经过充分研究的问题,存在稳定的近线性时间算法【30,Fast approximate computations with cauchy matrices and polynomials,Mathematics of Computation 2017】,【31,Transformations of matrix structures work again,Linear Algebra and Its Applications 2015】。

HiPPO矩阵的NPLR表示。我们的技术适用于任何可以分解为正规加低秩(Normal Plus Low-Rank, NPLR)的矩阵。定理1指出,所有来自【16,Hippo: Recurrent memory with optimal polynomial projections,NeurIPS 2020】的HiPPO矩阵都具有NPLR表示,其低秩部分的秩r为1或2。

  • 定理 1:所有来自【16】的HiPPO矩阵都具有NPLR表示:

    对于酉矩阵 $V \in \mathbb{C}^{N \times N}$,对角矩阵 $\Lambda$,以及低秩分解 $P, Q \in \mathbb{R}^{N \times r}$。这些矩阵(HiPPO-LegS, LegT, LagT)都满足 $r=1$ 或 $r=2$。特别地,方程(2)中的HiPPO矩阵是秩为1的NPLR矩阵。

3.3 S4算法与计算复杂度

DPLR形式与计算效率。根据方程(6),NPLR矩阵可以通过共轭变换为对角加低秩(DPLR)形式(现在是在复数域 $\mathbb{C}$ 而非实数域 $\mathbb{R}$ 上)。定理2和定理3描述了当A为DPLR形式时SSM的计算复杂度。S4在循环和卷积两种表示上都达到了最优或近最优的效率。

S4循环计算。定理2指出,对于任意步长 $\Delta$,计算循环表示(方程3)的一步只需要 $O(N)$ 的操作。这是因为DPLR矩阵的逆也是DPLR矩阵(同样可通过Woodbury恒等式证明)。这意味着离散化后的矩阵 $\bar{A}$ 是两个DPLR矩阵的乘积,因此其矩阵-向量乘法可以在 $O(N)$ 时间内完成。附录C.2给出了 $\bar{A}$ 的DPLR闭式形式。
* 定理 2 (S4 Recurrence):给定任意步长 $\Delta$,计算递推式(3)的一步可以在 $O(N)$ 次操作内完成,其中N是状态大小。

S4卷积计算。定理3是本文的核心技术贡献,它指出计算SSM卷积滤波器 $\bar{K}$ 可以归约为4次柯西乘法,只需要 $\tilde{O}(N+L)$ 的操作和 $O(N+L)$ 的空间。计算柯西矩阵是一个在数值分析中被广泛研究的问题,有基于快速多极子方法(FMM)的快速算术和数值算法【29,Structured matrices and polynomials: unified superfast algorithms,Springer Science & Business Media 2001】,【30,Fast approximate computations with cauchy matrices and polynomials,Mathematics of Computation 2017】,【31,Transformations of matrix structures work again,Linear Algebra and Its Applications 2015】。这个算法在算法1中被正式地勾勒出来。
* 定理 3 (S4 Convolution):给定任意步长 $\Delta$,计算SSM卷积滤波器 $\bar{K}$ 可以归约为4次柯西乘法,只需要 $\tilde{O}(N+L)$ 次操作和 $O(N+L)$ 的空间。

主流序列模型复杂度对比。表1比较了S4与其他主流序列模型(RNN, CNN, Transformer)在参数量、训练计算量、训练空间需求、训练并行性和推理计算量等方面的复杂度。SSM结合了卷积(训练高效)和循环(推理高效)的优点。

表1:各种序列模型在序列长度(L)、批量大小(B)和隐藏维度(H)方面的复杂度比较;波浪号表示对数因子。指标包括参数数量、训练计算量、训练空间需求、训练并行性以及推理计算量(针对单个样本和时间步)。为简化起见,S4的状态大小N与H相关联。粗体表示模型在该指标上理论上是最佳的。卷积对于训练是高效的,而循环对于推理是高效的,SSM则结合了两者的优点。
表1:各种序列模型在序列长度(L)、批量大小(B)和隐藏维度(H)方面的复杂度比较;波浪号表示对数因子。指标包括参数数量、训练计算量、训练空间需求、训练并行性以及推理计算量(针对单个样本和时间步)。为简化起见,S4的状态大小N与H相关联。粗体表示模型在该指标上理论上是最佳的。卷积对于训练是高效的,而循环对于推理是高效的,SSM则结合了两者的优点。

3.4 深度S4层的架构细节

S4层的参数化。具体来说,一个S4层按以下方式参数化。首先,用HiPPO矩阵(方程2)初始化一个SSM。根据引理3.1和定理1,这个SSM通过酉变换等价于某个 $(\Lambda - PQ^*, B, C)$,其中 $\Lambda$ 是对角矩阵,而P, Q, B, C是 $\mathbb{C}^{N \times 1}$ 的向量。这些构成了S4的5N个可训练参数。

深度S4网络架构。S4的整体深度神经网络(DNN)架构与先前的工作相似。如上定义,S4定义了一个从 $\mathbb{R}^L \rightarrow \mathbb{R}^L$ 的映射,即一个一维序列映射。通常,DNN操作的是H维的特征图,而不是1维。S4通过定义H个独立的自身副本来处理多个特征,然后用一个位置无关的线性层来混合这H个特征,每层总共有 $O(H^2) + O(HN)$ 个参数。在这些层之间还插入了非线性激活函数。总的来说,S4定义了一个形状为(批量大小,序列长度,隐藏维度)的序列到序列映射,这与相关的序列模型如Transformers、RNNs和CNNs完全相同。

与CNN的类比。请注意,核心的S4模块是一个线性变换,但通过网络深度的非线性变换使得整个深度SSM变为非线性。这类似于一个普通的CNN,因为卷积层也是线性的。本节描述的跨H个隐藏特征的广播也类似于深度可分离卷积。因此,整个深度S4模型与一个带有全局卷积核的深度可分离CNN密切相关。

参数化的改进。最后,我们注意到后续工作【14,It’s raw! audio generation with state-space models,arXiv 2022】发现,当A矩阵的特征值位于右半平面时,这个版本的S4有时会遇到数值不稳定性。该工作对S4的NPLR参数化做了一个微小的改动,从 $\Lambda - PQ^*$ 变为 $\Lambda - PP^*$,从而解决了这个潜在问题。

A4 实验环境

  • 数据集:
    • 长距离竞技场 (LRA): 包含6个任务(列表、文本、检索、图像、Path-X、Path-Finder),序列长度从1K到16K不等,用于评估模型处理长距离依赖的能力。
    • Speech Commands 10 (SC10): Speech Commands数据集的一个10类子集,用于原始音频(长度16000)分类任务。
    • CIFAR-10: 用于序列化图像分类(图像被展平为长度3072的序列)和像素级密度估计。
    • WikiText-103: 一个标准的语言建模基准。
    • 时间序列预测: 包含ETT、Electricity、Exchange-Rate、Traffic、Weather等5个数据集,用于与Informer模型进行比较。
  • 模型架构:
    • S4: 核心是SSM模块。深度S4模型由S4层和位置前馈网络交替堆叠而成,类似于Transformer块。状态大小N根据任务调整,例如LRA任务中N=64,与Transformer对标的实验中N=256。模型参数量从100K(消融实验)到2.5亿(生成模型)不等。
    • 基线模型: LSSL、多种高效Transformer变体(如Linear Transformer, Performer)、RNNs(ExpRNN, LipschitzRNN)、CNNs(WaveGAN-D, ResNet-18)以及专门模型(如Informer)。
  • 硬件配置:
    • 实验主要在单个A100 GPU上进行,特别是在进行基准测试和Path-X等内存消耗大的实验时。
  • 软件配置:
    • S4的实现利用了pykeops库来进行内存高效的核操作,尽管使用的是计算上较朴素的 $O(NL)$ 算法。
    • 代码基于PyTorch实现。

A4 实验结果

4.1 S4效率基准测试

  • 与LSSL的比较(表2): 实验证实,在实际的层尺寸下,S4比LSSL在速度上快30倍,内存效率高400倍,显示出数量级的提升。
  • 与高效Transformer的比较(表3): 在参数匹配的设置下,S4的速度和内存使用与最高效的Transformer变体(如Linear Transformer和Performer)相当,具有竞争力。

表2:深度SSM:S4参数化与算法1在渐进上比LSSL更高效。
表2:深度SSM:S4参数化与算法1在渐进上比LSSL更高效。

表3:与高效Transformers的基准测试对比
表3:与高效Transformers的基准测试对比

4.2 学习长距离依赖

  • 长距离竞技场(LRA) benchmark(表4): S4在LRA的所有6个任务上均超越了所有基线模型,平均准确率达到80.48%,而所有基线模型均低于60%。特别地,S4是第一个解决极具挑战性的Path-X任务(序列长度16384)的模型,准确率达到88%,而之前所有模型都停留在随机猜测水平(50%)。
  • Path-X任务的可视化分析(图2): 通过可视化S4学习到的卷积核,发现模型学到了具有空间一致性的结构。底层网络学习局部特征,而高层网络则能整合跨越整个图像(16384像素)的全局信息,这证实了S4学习长距离依赖的能力。
  • 原始语音分类(表5): 在SC10数据集上处理长度为16000的原始语音信号,S4取得了98.3%的准确率,错误率比专门为原始语音设计的WaveGAN-D模型(参数量是S4的90倍)降低了一半。相比之下,所有RNN和Transformer基线模型都无法在该任务上学习(错误率≥70%)。

表4:(长距离竞技场) (上) LRA中的原始Transformer变体。完整结果见附录D.2。(下) 文献中报道的其他模型。引用此表前请阅读附录D.5。
表4:(长距离竞技场) (上) LRA中的原始Transformer变体。完整结果见附录D.2。(下) 文献中报道的其他模型。引用此表前请阅读附录D.5。

图2:在LRA Path-X上训练的S4模型的可视化。SSM卷积核K∈R^16384被重塑为128×128的图像。(左)Path-X任务的一个例子,涉及推断标记是否由路径连接。(上)第一层的滤波器。(下)最后一层的滤波器。
图2:在LRA Path-X上训练的S4模型的可视化。SSM卷积核K∈R^16384被重塑为128×128的图像。(左)Path-X任务的一个例子,涉及推断标记是否由路径连接。(上)第一层的滤波器。(下)最后一层的滤波器。

表5:(SC10分类) Transformer、CTM、RNN、CNN和SSM模型。(MFCC) 标准预处理的MFCC特征(长度161)。(Raw) 未处理的信号(长度16000)。(0.5×) 测试时频率变化。✗表示不适用或在单个GPU上计算不可行。引用此表前请阅读附录D.5。
表5:(SC10分类) Transformer、CTM、RNN、CNN和SSM模型。(MFCC) 标准预处理的MFCC特征(长度161)。(Raw) 未处理的信号(长度16000)。(0.5×) 测试时频率变化。✗表示不适用或在单个GPU上计算不可行。引用此表前请阅读附录D.5。

4.3 S4作为通用序列模型

  • 大规模生成建模:
    • CIFAR-10密度估计(表7): 在不使用任何2D归纳偏见的情况下,S4达到了2.85 bits/dim,与为该任务设计的最佳自回归模型相当。
    • WikiText-103语言建模(表8): S4显著缩小了与Transformers的性能差距(在0.8 perplexity以内),并为无注意力模型设定了新的SOTA。
  • 快速自回归推理: 由于SSM具有状态化的特性,S4可以切换到其循环表示进行推理。在CIFAR-10和WikiText-103的生成任务中,S4的速度比标准Transformer快约60倍。
  • 适应采样率变化(表5): S4作为连续时间模型,只需调整内部步长参数 $\Delta$,就能在不重新训练的情况下适应不同的采样率。在SC10任务中,当测试数据采样率降为0.5倍时,S4仍能达到96.3%的准确率。
  • 在弱归纳偏见下学习:
    • 时间序列预测(表9): S4在一个简单的序列到序列框架下,在50个设置中的40个上优于专为时间序列预测设计的复杂模型Informer。
    • 序列图像分类(表6): 在序列化的CIFAR-10上,S4在没有数据增强的情况下达到了91.12%的准确率,与一个更大的2D ResNet-18(89.46%)相当,并且远超其他序列模型。

表6:(像素级1-D图像分类) 与先前工作报道的测试准确率对比(Transformer, RNN, CNN, 和 SSM模型)。扩展结果和引文见附录D。
表6:(像素级1-D图像分类) 与先前工作报道的测试准确率对比(Transformer, RNN, CNN, 和 SSM模型)。扩展结果和引文见附录D。

表7:(CIFAR-10密度估计) 作为一个通用序列模型,S4与之前的自回归模型(以bits per dim为单位)相比具有竞争力,同时没有引入2D归纳偏见,并且通过其循环模式实现了快速生成。
表7:(CIFAR-10密度估计) 作为一个通用序列模型,S4与之前的自回归模型(以bits per dim为单位)相比具有竞争力,同时没有引入2D归纳偏见,并且通过其循环模式实现了快速生成。

表8:(WikiText-103语言建模) S4以更快的生成速度接近Transformers的性能。(上) Transformer基线,我们的实现基于该基线,将注意力替换为S4。
表8:(WikiText-103语言建模) S4以更快的生成速度接近Transformers的性能。(上) Transformer基线,我们的实现基于该基线,将注意力替换为S4。

表9:单变量长序列时间序列预测结果。完整结果见附录D.3.5。
表9:单变量长序列时间序列预测结果。完整结果见附录D.3.5。

4.4 SSM消融实验:HiPPO的重要性

在序列CIFAR-10任务上进行的消融实验表明:
- HiPPO初始化的重要性(图3): 使用HiPPO矩阵初始化SSM至关重要。与随机高斯或随机对角初始化相比,HiPPO初始化能显著缩小训练和验证准确率之间的泛化差距(超过15%)。尽管所有可训练的模型都能达到完美的训练准确率,但只有HiPPO初始化能在验证集上取得高分。
- 训练SSM的必要性: 无论何种初始化方法,训练SSM参数总能比固定参数获得更好的性能。
- NPLR参数化的作用(图4a): 实验验证了S4的有效性主要来源于HiPPO初始化所蕴含的结构信息,而非NPLR参数化本身。随机的NPLR矩阵表现不佳。
- S4的综合优势(图4b): 完整的S4方法(HiPPO初始化+NPLR参数化+可训练)在仅有100K参数的情况下,通过简单的正则化就能在序贯CIFAR-10上达到84.27%的测试准确率。

图3:使用不同初始化的非约束实值SSM进行CIFAR-10分类。(左) 训练准确率。(右) 验证准确率。
图3:使用不同初始化的非约束实值SSM进行CIFAR-10分类。(左) 训练准确率。(右) 验证准确率。

图4:具有不同初始化和参数化的SSM在CIFAR-10上的验证准确率。(左) NPLR参数化下随机初始化与HiPPO初始化的对比。(右) 本节考虑的所有方法,包括轻微的Dropout正则化。S4仅用100K参数就在序贯CIFAR-10上实现了SotA准确率。
图4:具有不同初始化和参数化的SSM在CIFAR-10上的验证准确率。(左) NPLR参数化下随机初始化与HiPPO初始化的对比。(右) 本节考虑的所有方法,包括轻微的Dropout正则化。S4仅用100K参数就在序贯CIFAR-10上实现了SotA准确率。

A7 补充细节

相关工作

与本文最相关的工作。本文的工作与一系列最初由生物启发性SSM推动的研究紧密相关,这些研究最终发展为解决长距离依赖的数学模型。Voelker等人【45,Legendre memory units: Continuous-time representation in recurrent neural networks,NeurIPS 2019】,【46,Dynamical systems in spiking neuromorphic hardware,PhD thesis, University of Waterloo 2019】推导了一个受神经形态脉冲模型启发的不可训练SSM。Chilkuri和Eliasmith【7,Parallelizing legendre memory unit training,ICML 2021】展示了该模型可以通过卷积视图在训练时加速。Gu等人【16,Hippo: Recurrent memory with optimal polynomial projections,NeurIPS 2020】将这一特例扩展为一个通用的连续时间函数逼近框架,并提出了更多为长距离依赖设计的特殊A矩阵。然而,所有这些工作都固定了A的选择,并围绕它构建RNN,而不是使用一个真正的SSM。最近,Gu等人【18,Combining recurrent, convolutional, and continuous-time models with the structured learnable linear state space layer,NeurIPS 2021】明确地将完整的SSM(方程1)用作深度SSM模型,探索了SSM的新概念视图,并允许训练A矩阵。如第1节所述,他们的方法使用了SSM的朴素实现,导致内存开销增加了N倍,计算开销增加了 $N^2$ 倍。本文的技术贡献(第3节)关于S4的参数化和算法,适用于更广泛的SSM家族,包括先前研究中探讨的那些,我们处理这些模型的技术可能具有独立的价值。

实现

S4的计算核心。S4训练算法的计算核心是在3.2节、3.3节和附录C.3中讨论的柯西核。如附录C.3命题5所述,存在许多具有不同计算复杂度和复杂性的算法。我们目前S4的实现实际上使用了朴素的 $O(NL)$ 算法,该算法易于在GPU上并行化,并且有更易于获取的库支持;我们利用pykeops库进行内存高效的核操作。然而,这个库是一个更通用的库,可能没有针对这里使用的柯西核进行优化,我们相信专用的CUDA实现可以更高效。此外,如本文所讨论,存在渐近更快且数值稳定的柯西核算法(命题5)。然而,由于之前缺乏需要它们的应用,这些算法目前尚未在GPU上实现。我们相信,这些独立的计算核可以有更高效的实现,S4(以及广义的SSM)在效率上可能还有很大的提升空间。

局限性与未来方向

模型的适用性与未来工作。在本文中,我们展示了S4可以有效地处理各种数据。然而,它不一定对所有类型的数据都是最合适的模型。例如,表8中发现在语言建模方面与Transformers相比仍有差距。一个有趣的未来方向是探索将S4与其他序列模型结合,以互补它们的优势。我们对其他方向也感到兴奋,包括继续探索S4在音频数据上的优势(例如,预训练或生成设置),以及将HiPPO和S4推广到更高维数据,用于图像和视频应用。

A5 结论

我们介绍了S4,一个利用状态空间模型(SSM)的连续时间、循环和卷积视图的新参数化方法,以一种有原则的方式高效地建模长距离依赖(LRDs)的序列模型。在一系列评估不同数据模态和模型能力的既定基准上的结果表明,S4有潜力成为一个有效的通用序列建模解决方案。

A6 附录

B LSSL的数值不稳定性

本节证明了第3.1节中关于先前工作的论断。我们首先推导了HiPPO矩阵的显式对角化,证实了由于存在指数级大的元素,其存在不稳定性。然后我们讨论了【18】中提出的理论上快速的算法(定理2),并表明它也涉及到指数级大的项,因此无法实现。

B.1 HiPPO对角化

引理3.2的证明。HiPPO矩阵(方程2)在符号和对角矩阵共轭的意义上,等于


我们的目标是证明这个A矩阵被以下矩阵对角化:

换言之,这个矩阵的列是A的特征向量。具体地,我们将证明这个矩阵的第j列 $v^{(j)}$,其元素为 $v^{(j)}_k = \binom{k+j}{k-j}$,是特征值为 $j+1$ 的特征向量。也就是说,我们必须证明对于所有索引 $k \in [N]$,$(Av^{(j)})_k = (j+1)v^{(j)}_k$。
当 $k < j$ 时,方程两边都为0。
当 $k \ge j$ 时,通过对k进行归纳证明。展开方程得到

在基准情况 $k=j$ 时,等式成立。归纳步骤通过代数简化和应用归纳假设来证明。

B.2 LSSL的快速但不稳定算法

LSSL快速算法的数值问题。Gu等人【18,Combining recurrent, convolutional, and continuous-time models with the structured learnable linear state space layer,NeurIPS 2021,定理2】提出了一个复杂的快速算法来计算 $\bar{K} = (C^T\bar{B}, \ldots, C^T\bar{A}^{L-1}\bar{B})$。该算法在 $O(N \log^2 N + L \log L)$ 操作和 $O(N+L)$ 空间内运行。然而,我们现在证明该算法在数值上也是不稳定的。该算法的不稳定性有几个原因,但最直接的是,我们可以指出其使用的一个特定中间量。
* 定义 1:快速LSSL算法将A的特征多项式 $p(x)$ 的系数作为中间计算。此外,它还计算其逆 $p(x)^{-1} \pmod{x^L}$ 的系数。

不稳定性证明。我们现在声称这个量在数值上是不可行的。我们聚焦于 $A=I$(单位矩阵)的情况。注意,这种情况在某种意义上是最典型的情况:当通过步长 $\Delta$ 将连续时间SSM离散化为离散时间时,离散化的转移矩阵 $\bar{A}$ 会更接近单位矩阵。例如,使用欧拉离散化 $\bar{A} = I + \Delta A$,当步长 $\Delta \to 0$ 时,我们有 $\bar{A} \to I$。

  • 引理 B.1:当 $A=I$ 时,快速LSSL算法需要计算以N为指数级增长的项。
    • 证明:I的特征多项式是 $p(x) = (x-1)^N = \sum_{k=0}^N \binom{N}{k}(-1)^{N-k} x^k$。这些系数的大小可达 $\binom{N}{N/2} \approx \frac{2^N}{\sqrt{\pi N / 2}}$。
      $p(x)$ 的逆有更大的系数。它可以通过广义二项式公式以闭合形式计算:

      取模 $x^L$ 后,最大的系数是 $\binom{L-1+N}{N}$。当 $L=N-1$ 时,这个系数已经比 $(1-x)^N$ 的系数大,并且随着L的增长而增加。

C S4算法细节

本节证明第3.3节的结果,提供我们高效S4算法的完整细节。附录C.1至C.3分别证明定理1至3。

C.1 HiPPO矩阵的NPLR表示

定理1的证明。我们分别考虑HiPPO-LagT、HiPPO-LegT和HiPPO-LegS三种情况。注意,本文中主要定义的HiPPO矩阵(方程2)是HiPPO-LegT矩阵。

  • HiPPO-LagT:该矩阵是一个上海森堡矩阵。通过加上一个全1的秩1矩阵,可以将其变为一个斜对称矩阵。斜对称矩阵是正规矩阵的一种特例。
  • HiPPO-LegS:该矩阵通过加上一个秩1矩阵可以变为 $\frac{1}{2}I + S$ 的形式,其中S是斜对称矩阵。这个矩阵可以被对角化S的同一个酉矩阵对角化。
  • HiPPO-LegT:该矩阵可以通过加上一个秩2矩阵修正为一个斜对称矩阵。另一种方法是考虑其逆矩阵,其逆矩阵可以通过加上一个秩2项变为斜对称矩阵,再利用Woodbury恒等式,表明原矩阵也是一个正规矩阵加上秩2修正。

C.2 计算S4的循环视图

定理2的证明。本节的目标是显式地写出离散化矩阵 $\bar{A}$ 的闭合形式,证明S4参数化对于计算循环表示一步是高效的。
不失一般性,我们假设状态矩阵 $A=\Lambda-PQ^*$ 是对角加低秩(DPLR)形式。离散化矩阵 $\bar{A} = (I - \frac{\Delta}{2}A)^{-1}(I + \frac{\Delta}{2}A)$。
我们分别简化这两个项。
* 前向离散化: $(I + \frac{\Delta}{2}A) = (I + \frac{\Delta}{2}\Lambda) - \frac{\Delta}{2}PQ^*$,这也是DPLR形式。
* 后向离散化: $(I - \frac{\Delta}{2}A)^{-1}$,使用Woodbury恒等式可以将其简化为另一个DPLR矩阵。
最终,$\bar{A}$ 可以表示为两个DPLR矩阵的乘积,因此其矩阵-向量乘法可以在 $O(N)$ 时间内完成,证明了定理2。

C.3 计算卷积视图

定理3的算法流程。这是使用SSM最复杂的部分,也是S4参数化的主要动机。算法大致分为三个阶段:

  1. 约简0:对角化。根据引理3.1,我们可以通过酉共轭将A转换为对角加低秩(DPLR)形式。DPLR矩阵的幂运算仍然很困难,因此需要进一步的技术。

  2. 约简1:SSM生成函数。我们不直接计算卷积核 $\bar{K}$,而是计算其生成函数 $\hat{\mathcal{K}}(z) = C^T(I - z\bar{A})^{-1}z\bar{B}$ 在单位根上的取值。生成函数将 $\bar{A}$ 的幂运算转换为了求逆运算。之后可以通过逆FFT从生成函数的求值中恢复出 $\bar{K}$,这需要 $O(L \log L)$ 的操作。

  3. 约简2:Woodbury修正。由于 $\bar{A}$ 本身是DPLR形式,其生成函数中的求逆项 $(I - z\bar{A})^{-1}$ 仍然复杂。通过代数变换和应用Woodbury恒等式,可以将对DPLR矩阵的求逆问题转化为对其对角部分的求逆问题,低秩部分通过一个修正项来处理。

  4. 约简3:柯西核。最终,问题被简化为计算在对角矩阵情况下的SSM生成函数。我们证明这等价于一个柯西核(Cauchy Kernel)的计算。

    • 定义3:节点 $\Omega=(\omega_i) \in \mathbb{C}^M$ 和 $\Lambda=(\lambda_j) \in \mathbb{C}^N$ 上的柯西矩阵或核是 $K(\Omega, \Lambda)_{ij} = \frac{1}{\omega_i - \lambda_j}$。
      这是一个在数值分析中被充分研究的问题,存在快速($\tilde{O}(L+N)$)且稳定的数值算法,如基于快速多极子方法(FMM)的算法。
    • 命题 5 (Cauchy):一个柯西核需要 $O(M+N)$ 空间,操作次数为:

最终,计算卷积核 $\bar{K}$ 的工作主要由常数次调用黑盒的柯西核计算主导,从而证明了定理3。

D 实验细节和完整结果

本节包含了所有实验的详细过程、超参数设置和扩展结果。

  • D.1 基准测试: 提供了与LSSL和高效Transformer比较的基准测试设置,均在单个A100 GPU上进行。
  • D.2 长距离依赖: 提供了LRA和Speech Commands实验的复现细节,包括S4及所有基线模型的超参数、优化器设置等。
    • LRA完整结果见表10。
    • Speech Commands基线模型(Transformer, Performer, ExpRNN, LipschitzRNN, WaveGAN-D)的详细配置。
  • D.3 通用序列建模: 提供了CIFAR密度估计、WikiText-103语言建模、自回归生成速度、序列图像分类和时间序列预测等实验的详细设置。
    • 时间序列预测: 图5展示了S4与Informer架构的对比。S4使用简单的序列到序列方法,而Informer使用复杂的编解码器结构。表13和14提供了在所有50个设置上的完整结果,S4在其中40个上取得了最佳效果。
  • D.4 可视化: 附录中包含了在Pathfinder和CIFAR-10任务上S4学习到的卷积滤波器 $\bar{K}$ 的可视化(图6)。
  • D.5 复现说明: 提供了关于论文发表后部分实验结果更新的说明。
    • LRA: 结果已更新为后续论文【19,On the parameterization and initialization of diagonal state space models,arXiv 2022】,【20,How to train your hippo: State space models with generalized basis projections,arXiv 2022】中的数据。
    • Speech Commands: 建议引用后续论文【19】中使用完整SC数据集(35类)的结果,而非本文中的SC10(10类)子集。
    • WikiText-103: 结果已更新为使用更大批量和上下文长度重新训练后的更优数据。

图5:S4与专业时间序列模型在预测任务上的比较。(左上) 预测任务涉及根据过去的情境预测未来的时间序列值。(左下) 我们使用像S4这样的序列模型作为黑盒进行简单预测。(右) Informer使用专为预测问题设计的编码器-解码器架构,其中包含一个定制的注意力模块(图取自Zhou等人 [50])。
图5:S4与专业时间序列模型在预测任务上的比较。(左上) 预测任务涉及根据过去的情境预测未来的时间序列值。(左下) 我们使用像S4这样的序列模型作为黑盒进行简单预测。(右) Informer使用专为预测问题设计的编码器-解码器架构,其中包含一个定制的注意力模块(图取自Zhou等人 [50])。

图6:(Pathfinder上的卷积滤波器) S4在最佳模型的第一层(上2行)和最后一层(下2行)学习到的滤波器的随机选择。
图6:(Pathfinder上的卷积滤波器) S4在最佳模型的第一层(上2行)和最后一层(下2行)学习到的滤波器的随机选择。