Hungry Hungry Hippos: Towards Language Modeling with State Space Models

  • 文章标题:饥饿的河马:迈向基于状态空间模型的语言建模
  • 作者/机构:Daniel Y. Fu∗†, Tri Dao∗†, Khaled K. Saab‡, Armin W. Thomas††, Atri Rudra‡‡, and Christopher Ré†
    • † 斯坦福大学计算机科学系
    • ‡ 斯坦福大学电子工程系
    • †† 斯坦福大学心理学系
    • ‡‡ 纽约州立大学布法罗分校计算机科学与工程系

A1 主要贡献

本文旨在解决状态空间模型(SSM)在语言建模领域相较于Transformer存在的两个核心问题:模型表达能力不足和硬件利用率低下导致的训练速度慢。

核心问题与研究目标:
1. 表达能力差距:尽管SSM在某些模态(如时间序列、音频)上表现出色,但在语言建模上,其性能(以困惑度PPL衡量)显著落后于Transformer。本文旨在探究此差距是否源于注意力机制固有的归纳偏置和能力。
2. 硬件效率障碍:SSM的计算复杂度随序列长度呈近线性增长($O(N \log N)$),优于Transformer的二次方增长($O(N^2)$)。然而,由于未能有效利用现代硬件(如GPU的张量核心),SSM的实际运行速度反而更慢。本文致力于提升SSM在现代加速器上的训练效率。

主要创新与贡献:

  1. 通过合成任务揭示SSM的表达能力缺陷

    • 本文利用被认为是Transformer上下文学习能力基础的合成语言建模任务,来评估SSM与注意力机制的差距。
    • 研究发现,现有SSM在两项关键能力上存在不足:回忆序列中早期的词元(token)跨序列比较词元
  2. 提出新型SSM层——H3 (Hungry Hungry Hippo)

    • 为解决上述表达能力缺陷,本文设计了H3层。该层通过堆叠两个具有特定结构的SSM(一个使用移位矩阵,一个使用对角矩阵),并引入输入投影与其输出之间的乘法交互,来显式地增强模型的回忆和比较能力。
    • 效果:H3在合成语言任务上表现与注意力机制相当。在OpenWebText数据集上,纯H3模型与Transformer的困惑度差距从3.4缩小到0.4。更重要的是,一个包含两个注意力层的混合H3-注意力模型在OpenWebText上的表现甚至比纯Transformer模型还要好1.0 PPL。

    图1:左:H3堆叠了两个带有移位和对角矩阵的离散SSM,并使用输入投影及其输出之间的乘法交互来模拟序列中点与点之间的比较。中:H3可以执行关联回忆——这对注意力来说很容易,但对现有的SSM来说则不然。右:FlashConv使用一种新的状态传递算法,通过融合的分块FFTConv来提高SSM的硬件效率,使H3能够扩展到十亿参数级别的模型。
    图1:左:H3堆叠了两个带有移位和对角矩阵的离散SSM,并使用输入投影及其输出之间的乘法交互来模拟序列中点与点之间的比较。中:H3可以执行关联回忆——这对注意力来说很容易,但对现有的SSM来说则不然。右:FlashConv使用一种新的状态传递算法,通过融合的分块FFTConv来提高SSM的硬件效率,使H3能够扩展到十亿参数级别的模型。

  3. 提出高效训练算法——FlashConv

    • 为提升SSM的硬件效率,本文提出了FlashConv算法,其灵感来源于IO感知的注意力机制【15, Flashattention: Fast and memory-efficient exact attention with io-awareness, 2022, Advances in Neural Information Processing Systems】。
    • 核心技术
      • 融合的分块FFT(Fused Block FFT):对于中等长度序列(在A100上可达8K),该算法将FFT计算分解为一系列矩阵乘法,从而利用GPU的专用矩阵乘法单元(如Tensor Cores),并通过核函数融合(kernel fusion)减少内存读写开销。
      • 状态传递算法(State Passing Algorithm):对于超长序列(>8K),该算法利用SSM的循环特性,将输入序列分块处理。每处理完一个数据块,就计算并传递一个状态向量给下一个数据块,从而在保持近线性计算复杂度的同时,将模型扩展到任意序列长度。
    • 效果:FlashConv在Long Range Arena基准测试上取得了2倍的速度提升,并使混合语言模型的文本生成速度比Transformer快2.4倍。
  4. 大规模语言模型验证

    • 利用FlashConv,本文成功将混合H3-注意力语言模型扩展至2.7B参数,并在The Pile数据集上进行了训练。
    • 结果:这些模型在困惑度上优于同等规模的Transformer,并在SuperGLUE基准测试的大多数任务中,在零样本和少样本学习方面与Transformer相当或更优。

A3 背景知识

本节介绍状态空间模型(SSM)和线性注意力的背景,它们是H3层的设计灵感来源。

2.1 状态空间模型

连续时间状态空间表示
* 定义:一个连续时间的状态空间表示【6, Modern control theory, 1974】通过一个状态变量$x(t) \in R^m$定义了从输入信号$u(t) \in R$到输出信号$y(t) \in R$的线性映射。该过程由以下微分方程描述,其中$A \in R^{m \times m}$, $B \in R^{m \times 1}$, $C \in R^{1 \times m}$, $D \in R^{1 \times 1}$是矩阵参数:$\dot{x}(t) = Ax(t) + Bu(t)$, $y(t) = Cx(t) + Du(t)$。

离散时间状态空间表示
* 定义:类似地,离散时间的状态空间表示定义了从离散输入信号$u_i$(其中$i=1, 2, ...$)到离散输出信号$y_i$的线性映射,通过状态变量$x_i \in R^m$实现:


* SSM作为深度学习层:状态空间模型(SSM)将这些表示用作深度学习流程中的一个层,其中矩阵$A, B, C, D$是通过数据学习得到的(例如,使用基于梯度的优化)。通常会并行设置$d$个这样的SSM,每个对应一个隐藏维度。为了保留序列历史,HiPPO【24, Hippo: Recurrent memory with optimal polynomial projections, 2020, Advances in Neural Information Processing Systems】将历史投影到正交多项式基上,这相当于将SSM的$A, B$矩阵初始化为一些特殊的矩阵。
* SSM的循环形式与推理:SSM的这种循环形式支持高效的推理(即生成):要生成下一个时间步的输出,仅需要当前时间步的状态,而无需整个输入历史。此外,SSM可以自由地外推到比训练时更长的序列。

SSM作为卷积
* 卷积表示:对于高效训练,给定完整的输入序列$u_1, ..., u_N$,输出序列$y_1, ..., y_N$也可以写成输入与滤波器【27, Combining recurrent, convolutional, and continuous-time models with linear state-space layers, 2021, Advances in neural information processing systems】的卷积形式:


* 具体形式:从一个初始条件$x_0$出发,我们有$y_i = CA^iBx_0 + (f * u)_i + Du_i$,其中$(f*u)$表示$f$和$u$之间的线性卷积。如果我们将初始条件$x_0$设为零,那么$y$就完全是$u$的线性卷积,并带有一个残差连接$Du$。更一般地,任何线性时不变系统(SSM是其特例)都可以写成卷积形式。
* 符号表示:给定一个长度为$N$的一维输入序列$u \in R^N$,由矩阵$A, B, C, D$参数化的SSM的一维输出序列$y \in R^N$表示为:

* 简化符号与多维输入:为简化符号,如果上下文清晰,我们省略对$A, B, C, D$的引用,写作$y = SSM(u)$。当$u$是$d$维的多维数据时,我们将$d$个这样的SSM堆叠在一起,定义一个从$u \in R^{N \times d}$到$y \in R^{N \times d}$的映射,并使用相同的符号$y = SSM(u)$。
* 滤波器构建:为了高效地从$A, B, C$构建滤波器$f$,$A$通常被约束为对角矩阵【26, On the parameterization and initialization of diagonal state space models, 2022, Advances in Neural Information Processing Systems; 29, Diagonal state spaces are as effective as structured state spaces, 2022, Advances in Neural Information Processing Systems】,或对角加低秩矩阵【25, Efficiently modeling long sequences with structured state spaces, 2022, ICLR】。

通过FFT实现SSM
* 计算效率:通过传统的矩阵运算天真地计算卷积对于长核来说是昂贵的,其复杂度为$O(N^2)$。作为替代,我们可以使用快速傅里叶变换(FFT):对$f$和$u$进行FFT,将它们逐点相乘,然后进行逆FFT。这产生了一个$O(N \log N)$的算法。

2.2 线性注意力

线性注意力与RNN的联系
* 背景:我们描述线性注意力【35, Transformers are RNNs: Fast autoregressive transformers with linear attention, 2020, PMLR】及其与RNN的联系,这启发了我们的模型设计(第3节)。在标准注意力【62, Attention is all you need, 2017, Advances in neural information processing systems】中,我们有$N$个查询/键/值词元$Q_i, K_i, V_i \in R^d$,其中$i=1, ..., N$,$N$是序列长度,$d$是头维度。对于某个相似度度量$Sim: R^d \times R^d \rightarrow R$,我们希望计算输出:


* 线性注意力的假设:对于标准的softmax注意力,$Sim(q, k) = e^{q^Tk}$(通常点积会用$1/\sqrt{d}$进行缩放)。线性注意力的假设是$Sim$具有$Sim(q, k) = \phi(q)^T\phi(k)$的形式,其中$\phi$是某个(非线性)函数。输出则为$O_i = \frac{\phi(Q_i)^T \sum_{j=1}^i \phi(K_j)V_j^T}{\phi(Q_i)^T \sum_{j=1}^i \phi(K_j)}$。令$S_i = \sum_{j=1}^i \phi(K_j)V_j^T \in R^{d \times d}$,$z_i = \sum_{j=1}^i \phi(K_j) \in R^d$,$d_i = \phi(Q_i)^T z_i \in R$。那么$O_i = \frac{\phi(Q_i)^T S_i}{d_i}$。
* 与RNN的联系:这揭示了线性注意力与RNN的联系:输出$O_i$是$S_i$和$z_i$的函数,而$S_i$和$z_i$都是以累加和的形式增量更新的。

A2 方法细节

3 Hungry Hungry Hippos层:用于建模离散序列

为了理解SSM和注意力在语言建模上的差距,我们研究了两个合成语言建模任务。这些任务启发我们设计了H3层,通过增加一个基于移位矩阵的离散SSM和乘法交互来有效地建模离散序列。我们接着展示了H3层具有足够的表达能力来解决这些合成任务,并且这种理解也带来了在真实语言建模基准上更好的性能。

3.1 动机:合成语言建模任务

合成任务描述
* 任务概述:我们描述了两个密切相关的合成任务,总结在表1中。Olsson等人【49, In-context learning and induction heads, 2022, Transformer Circuits Thread】认为,解决这些任务(的变体)的能力是Transformers上下文学习能力的主要来源。更多直觉见附录E。
表1:合成语言建模任务。


* 归纳头(Induction Head)任务:该任务测试模型在一个特殊标记(例如表1中的`)之后回忆内容的能力。在序列末尾,模型必须回忆起紧跟在特殊标记之后出现的那个词元。
* 关联回忆(Associative Recall)任务:该任务【1, Using fast weights to attend to the recent past, 2016, Advances in neural information processing systems】与归纳头任务类似,但要求模型记住多个键值对。在序列末尾,模型必须回忆起属于特定键的特定值。

现有SSM的局限性
* 实验发现:表2(针对两层模型)显示S4D【26, On the parameterization and initialization of diagonal state space models, 2022, Advances in Neural Information Processing Systems】和Gated State Spaces【42, Long range language modeling via gated state spaces, 2022, arXiv preprint】都无法对这些合成语言进行建模,这表明它们可能不具备通用语言所需的表达能力。
表2:两层模型在合成语言任务上的评估。


* 能力缺失分析:我们认为这些失败表明了两种能力的缺失:(i) 记住在特定事件(例如归纳头任务中的特殊标记)之后出现的词元的能力,以及(ii) 跨序列比较词元的能力(例如,比较键来决定回忆哪个值)。
* 注意力的优势:注意力机制具备这两种能力:它可以通过构建二次注意力矩阵$QK^T$来比较词元,并且可以通过将$\text{softmax}(QK^T)$与$V$相乘来直接复制(回忆)词元。在3.2节中,我们设计了新的H3层,旨在为SSM赋予这些能力,从而缩小SSM与注意力之间的表达能力差距。

3.2 H3层

H3层设计
* 核心组件:H3使用带有移位(shift)和对角(diagonal)矩阵的SSM,并结合与输入投影的乘法操作,来捕获通过合成任务识别出的缺失能力。

高层直觉
* (i) 记忆过去词元:我们希望状态$x_i$能够从输入$u_i$中复制信息,然后将该信息传递给下一个状态$x_{i+1}$。由于$x_{i+1}$通过$Ax_i$与$x_i$相关联,我们使用一个带有移位矩阵$A$的离散SSM(下文将正式描述),该矩阵能够移动状态向量的元素(例如,将$[a, b, c]$映射为$[0, a, b]$)。
* (ii) 跨序列比较词元:我们使用乘法交互:一个SSM的输出(包含来自先前时间步的信息)与当前时间步的输入相乘,从而衡量词元之间的相似性。

H3层与线性注意力的联系
* 设计灵感:H3的设计松散地受到线性注意力(第2节)的启发:我们将输入$u$投影得到三个信号$Q, K, V$。然后,我们用一个$A$为移位矩阵的SSM($SSM_{shift}$)替换非线性项$\phi(K)$,并用一个$A$为对角矩阵的SSM($SSM_{diag}$)替换求和项$S_i$。对于头维度$d_h=1$的情况,输出为:


* 结构解读:其中$\odot$表示逐点乘法。我们可以将这种形式看作是堆叠了两个带有乘法交互的SSM(每个都是一个“饥饿的河马”,因此得名)。附录B中可以找到线性注意力、时变系统和H3之间更正式的联系。

移位和对角SSM:记忆关键Token
* 设计目的:移位SSM和对角SSM旨在解决在特定事件后记录词元的能力。
* 移位SSM (Shift SSM):我们约束$A \in R^{m \times m}$为一个移位矩阵,即当$i-1=j$时$A_{i,j}=1$,否则为0。该矩阵作用于隐藏状态$x_i$时,会将每个坐标向下移动一位,从而创建一个先前状态的“记忆”。例如,如果$B=e_1$(第一个基向量),那么$x_i = [u_i, u_{i-1}, ..., u_{i-m+1}]$就包含了前$m$个时间步的输入。我们学习$B$和$C$(为简单起见,$B$也可以固定为$e_1$,此时输出是一个核大小为$m$的一维卷积)。
* 对角SSM (Diagonal SSM):我们将$A$约束为对角矩阵,并从HiPPO的对角版本(S4D【26, On the parameterization and initialization of diagonal state space models, 2022, Advances in Neural Information Processing Systems】)进行初始化。这种参数化允许模型在整个序列上记住状态。移位SSM可以检测特定事件何时发生,而对角SSM可以在事件发生后记住一个词元并贯穿整个序列。

乘法交互:实现比较功能
* 功能来源:我们从线性注意力中借鉴了乘法交互,但当它们与移位矩阵结合时,提供了另一种缺失的能力:跨序列比较词元。
* 具体作用:移位SSM的输出与$V$投影之间的乘法交互模仿了线性注意力中的局部乘法交互(取决于隐藏状态的大小)。类似地,$Q$投影与对角SSM输出之间的乘法交互允许在整个序列上对词元进行比较。

H3层算法
* 整体结构:H3层的整体流程在算法1中给出,并以示意图形式展示在图1(左)中。我们使用H3层以与Transformer相同的方式构建模型,即将其与MLP交错排列,并通过残差连接和层归一化连接(即pre-norm架构【2, Adaptive input representations for neural language modeling, 2018, ICLR】)。我们还将考虑一个混合H3-注意力模型(保留两个注意力层,其余为H3层,见3.3节和第5节)。

算法1 H3层
* 输入: 来自上一层的输入序列$u \in R^{N \times d}$,权重矩阵$W_Q, W_K, W_V, W_O \in R^{d \times d}$,一个移位SSM $SSM_{shift}$,一个对角SSM $SSM_{diag}$,头维度$d_h$。
* 流程:
1. 计算 $Q = uW_Q$, $K = uW_K$, $V = uW_V \in R^{N \times d}$。
2. 将$K$通过移位SSM:$\hat{K} = SSM_{shift}(K) \in R^{N \times d}$。
3. 将$Q, \hat{K}, V$分割成$H$个“头” ($Q^{(h)}, \hat{K}^{(h)}, V^{(h)}$ for $h=1, ..., H$),每个头是$N$个大小为$d_h=d/H$的向量序列。
4. 对于$1 \le h \le H$:
5. 计算批次化的外积$\hat{K}^{(h)}(V^{(h)})^T \in R^{N \times d_h \times d_h}$(在$N$维度上批处理),并将其通过一个对角SSM:$KV^{(h)} = SSM_{diag}(\hat{K}^{(h)}(V^{(h)})^T) \in R^{N \times d_h \times d_h}$。
6. 通过$Q$进行批次乘法:$O^{(h)} = [Q_1^{(h)}KV_1^{(h)}, ..., Q_N^{(h)}KV_N^{(h)}] \in R^{N \times d_h}$(在$N$维度上批处理)。
7. 结束循环。
8. 拼接每个头的输出$O^{(h)}$,并乘以输出投影矩阵$W_O \in R^{d \times d}$。

效率分析
* 计算复杂度:我们证明H3的计算复杂度随序列长度$N$呈$O(N \log N)$增长,这在渐进意义上比注意力更高效,后者通常需要$O(N^2 d)$的时间和$O(N^2)$的空间(证明见附录D.3)。
* 命题1:设$N$为序列长度,$d$为隐藏维度,并假设头维度$d_h$为$O(1)$阶。则H3层的计算需要$O(d^2N + dN \log N)$的时间和$O(dN)$的空间。

3.3 表达能力

H3的表达能力验证
* 综合任务与自然语言:我们证明了H3可以对我们的合成语言进行建模,并且在OpenWebText【23, Openwebtext corpus, 2019】上也能很好地处理自然语言。我们还提出了一种混合H3-注意力扩展模型,其在OpenWebText上的表现优于Transformers。

H3解决关联回忆任务的机制
* 合成任务上的成功:如表2所示,H3在表达能力上足以解决我们的合成语言建模任务。图1(中)展示了一个单层H3解决特定键值对(a, 3)关联回忆任务的机制。
* 机制详解:移位SSM及其后的乘法交互充当一个门控,根据前一个词元是否为键a来决定是否让值3通过到对角SSM。对角SSM将值3存储在内存中,并持续输出它。最后的乘法交互则根据当前输入词元是否为键a,来决定是否让对角SSM的输出通过。我们在附录D.1中正式构建了一个H3层的权重来解决此任务。

从合成语言到自然语言的转化
* 性能验证:我们验证了当H3能够解决这些合成任务时,它在自然语言(例如,在OpenWebText数据集上)上的建模能力也得到了提升。如表3所示,在OpenWebText上训练50B个词元后,H3的困惑度与Transformers相差不到0.4点,并且比现有的SSM变体(S4D, GSS)性能好得多,差距达到3-3.9点。
表3:SSM变体与Transformers在OpenWebText上的困惑度对比。所有模型均为12层,大小约为125M,并使用相同的超参数训练50B个词元。

扩展:H3-注意力混合模型
* 模型设计与性能:一个简单的混合H3-注意力语言模型出人意料地在OpenWebText上比Transformers性能好1.0个点。我们的混合模型仅保留了两个自注意力层:一个在第二层,一个在中间层(对于一个$N$层模型,$N$为偶数,则在第2层和第$N/2$层)。H3-注意力混合模型也优于GSS-注意力混合模型【42, Long range language modeling via gated state spaces, 2022, arXiv preprint】。

4 FlashConv:高效训练SSM

FlashConv简介
* 目标:为了提高SSM在现代硬件上的效率,我们提出了FlashConv。
* 核心思想:FlashConv融合了FFT、逐点乘法和逆FFT,以减少内存读写。它还使用了一种分块FFT算法,以便在序列长度达到8K时能够利用专门的矩阵乘法单元(例如A100上的张量核心)。
* 长序列处理:对于长度超过8K的序列,计算不再能容纳于GPU SRAM中,因此我们提出了一种新颖的状态传递算法,该算法将序列分成块,每次计算一个块的FFT卷积。FlashConv可以加速任何SSM(不仅仅是H3)。

4.1 融合的分块FFTConv

短序列加速技术
* 两种技术:我们部署了两种技术来加速序列长度小于8K的基于FFT的卷积:核函数融合(kernel fusion)和分块FFT(block FFT)。
* 功能:核函数融合解决了由于读写中间结果导致的IO瓶颈,而分块FFT则允许基于FFT的卷积利用专门的矩阵乘法单元。这些技术使我们能够将FFTConv的速度提高2倍(第6节),适用于长度小于8k的序列。

核函数融合
* 问题:使用标准库(如cuFFT)的FFTConv朴素实现是IO密集型的,因为需要重复读写中间结果。SSM中输入为$u$、滤波器为$f$的FFT卷积形式为$\text{iFFT}(\text{FFT}(u) \odot \text{FFT}(f))$(其中$\odot$表示逐点乘法)。这需要将中间结果读写到GPU内存,这可能主导运行时间。
* 解决方案:借鉴FlashAttention【15, Flashattention: Fast and memory-efficient exact attention with io-awareness, 2022, Advances in Neural Information Processing Systems】,我们首先将整个FFTConv融合成一个单一的内核,并在SRAM中计算,以避免这种开销。

分块FFT
* 目标:为了进一步加速基于FFT的卷积计算,我们利用了现代GPU上的专用矩阵乘法硬件(例如,Nvidia GPU上的Tensor Cores可以执行快速的16×16矩阵乘法)。
* 方法:我们借鉴了经典结果,即FFT可以写成一系列块对角矩阵乘法与置换(permutation)的交错。我们注意到,这类算法并非新生事物,但我们的应用场景(GPU上融合的FFTConv)引入了新的瓶颈——通过消除IO瓶颈,计算成为了瓶颈(注意,GPU上的单个FFT通常是IO密集型的)。
* Cooley-Tukey分解:假设我们要执行一个$N$点FFT,这等同于乘以DFT矩阵$F_N$。假设$N=N_1N_2$,其中$N_1, N_2$为整数。根据DFT的Cooley-Tukey分解【3, FFTs in external or hierarchical memory, 1990, The journal of Supercomputing; 11, An algorithm for the machine calculation of complex fourier series, 1965, Mathematics of Computation】(也称为四步FFT算法),我们可以将$F_N$写为$F_N = P(I_{N_2} \otimes F_{N_1})P^TD(I_{N_1} \otimes F_{N_2})P$,其中$P$表示一个固定的置换,它将输入重塑为$N_1 \times N_2$数组然后转置,$\otimes$表示克罗内克积,$D$是一个$N \times N$的对角矩阵(称为旋转因子)【14, Monarch: Expressive structured matrices for efficient and accurate training, 2022, ICML】,而$I_{N_i}$和$F_{N_i}$分别是大小为$N_i \times N_i$的单位矩阵和DFT矩阵。由于$I_{N_2} \otimes F_{N_1}$和$I_{N_1} \otimes F_{N_2}$只是块对角矩阵,我们可以利用专门的矩阵乘法单元来执行这些乘法。类似地,如果$N=N_1N_2N_3$,我们可以将$N$点FFT分解为一系列大小为$N_1, N_2, N_3$的(分块)FFT,并用置换交错。
* FLOPs分析:如果$N$可以写成$r^p$($r,p$为整数),分块FFT算法对于序列长度$N$会产生$O(N r \log N / \log r)$的FLOPs。这比标准FFT($O(N \log N)$)的FLOPs更多,但当我们使用专用矩阵乘法硬件时,运行速度可以更快。

4.2 状态传递

长序列处理的挑战与解决方案
* 问题:当序列过长以至于无法装入GPU SRAM时(在A100上超过8K),融合的内核无法运行。
* 解决方案:我们展示了如何利用SSM中FFT的特定形式来加速长序列的处理。SSM的循环特性允许我们将一个长度为$N$的序列的FFTConv分解成大小为$N_0$的块($N_0$是我们可以装入SRAM的最长FFT长度,假设$N$是$N_0$的倍数)。我们使用FFTConv计算每个块,并使用一个递推关系来连接这些块。

状态传递算法细节
* 分块处理:具体来说,我们将输入$u$分成$C=N/N_0$个块$u^{(c)} \in R^{N_0}$,其中$c=1, ..., C$。类似地,将状态$x$分成$x^{(c)} \in R^{N_0 \times m}$,输出$y$分成$y^{(c)} \in R^{N_0}$。我们只需要每个块$c$的最终状态$x_{N_0}^{(c)}$。
* 递推关系:令$f = [CB, CAB, CA^2B, ..., CA^{N_0-1}B]$为SSM滤波器。回顾第2节,对于每个块$c$,有$y_i^{(c)} = CA^iBx_{N_0}^{(c-1)} + (f*u^{(c)})_i + Du_i^{(c)}$,因为前一个块$(c-1)$的最终状态$x_{N_0}^{(c-1)}$是当前块$c$的初始条件。用向量表示法,$y^{(c)} = M_{xy}x_{N_0}^{(c-1)} + f*u^{(c)} + Du^{(c)}$,其中$M_{xy} \in R^{N_0 \times m}$是某个矩阵。此外,我们需要用$x_{N_0}^{(c)} = A^{N_0}x_{N_0}^{(c-1)} + M_{ux}u^{(c)}$来更新每个块的最终状态,其中$M_{ux}$是某个$m \times N_0$的矩阵(推导见附录C.2)。
* 算法本质:本质上,只要我们记住前一个块的最终状态,就可以用基于FFT的卷积计算每个块的输出,并且每个块的最终状态可以通过递推方式更新。这产生了一个用于长序列的状态传递算法,其中我们只计算长度为$N_0$的FFT,并在每次迭代中更新一些隐藏状态。
* 算法描述:令BlockFFTConv指代我们的融合分块FFTConv内核。那么,对于一维输入的状态传递算法由算法2给出。对于维度为$d$的输入,我们堆叠$d$个SSM,只需在$d$维度上批处理算法2即可。

算法2 状态传递算法
* 输入: 输入$u \in R^N$,由矩阵$A \in R^{m \times m}, B \in R^{m \times 1}, C \in R^{1 \times m}, D \in R^{1 \times 1}$参数化的SSM,块大小$N_0$($N$是$N_0$的倍数)。
* 流程:
1. 预计算$A^{N_0} \in R^{m \times m}$,$M_{ux} = [A^{N_0-1}B, ..., B] \in R^{m \times N_0}$,$M_{xy} = [C, CA, ..., CA^{N_0-1}]^T \in R^{N_0 \times m}$。
2. 将输入$u_{1:N}$分割成$C = N/N_0$个块$u_{1:N_0}^{(c)}$,其中$c=1, ..., C$。
3. 令初始状态为$x_{N_0}^{(0)} = 0 \in R^m$。
4. 对于$1 \le c \le C$:
5. 计算$y^{(c)} = M_{xy}x_{N_0}^{(c-1)} + \text{BlockFFTConv}(f, u_j) + Du^{(c)} \in R^{N_0}$。
6. 更新状态:$x_{N_0}^{(c)} = A^{N_0}x_{N_0}^{(c-1)} + M_{ux}u^{(c)}$。
7. 结束循环。
8. 返回$y = [y^{(1)} ... y^{(C)}]$。

算法正确性
* 证明:我们证明算法2产生的输出与使用大小为$N$的大FFT计算SSM的输出相同(证明见附录D.4):
* 命题2:对于输入$u \in R^N$和矩阵$A, B, C, D$,算法2返回的输出$y \in R^N$与由$A, B, C, D$参数化的SSM定义的输出相同。

A4 实验

实验环境

  1. 数据集

    • 语言建模
      • The Pile【21, The pile: An 800gb dataset of diverse text for language modeling, 2021, arXiv preprint】: 用于训练125M至2.7B参数模型的核心数据集,训练量达400B tokens。
      • OpenWebText【23, Openwebtext corpus, 2019】: 用于125M模型与Transformer的详细对比。
      • WikiText-103【43, Pointer sentinel mixture models, 2016】: 用于评估模型的零样本迁移能力。
      • PG-19【54, Compressive transformers for long-range sequence modelling, 2019, ICLR】: 用于评估长文本建模能力。
    • 合成任务
      • Induction Head & Associative Recall: 用于诊断模型(特别是H3)的表达能力。
    • 长序列基准
      • Long Range Arena (LRA)【59, Long range arena: A benchmark for efficient transformers, 2020, ICLR】: 用于评估FlashConv的加速效果。
    • 非文本序列建模
      • TUSZ v1.5.2 EEG Corpus【56, The temple university hospital seizure detection corpus, 2018, Frontiers in neuroinformatics】: 用于癫痫分类任务。
      • Speech Commands (SC10)【64, Speech commands: A dataset for limited-vocabulary speech recognition, 2018, arXiv preprint】: 用于原始音频分类任务。
      • fMRI数据集【60, Self-supervised learning of brain dynamics from broad neuroimaging data, 2022, arXiv preprint; 61, The wu-minn human connectome project: an overview, 2013, Neuroimage; 38, Functional boundaries in the human cerebellum revealed by a multi-domain task battery, 2019, Nature neuroscience】: 用于脑功能状态解码。
  2. 模型架构

    • H3-Attention Hybrid Models: 在125M, 355M, 1.3B, 2.7B四个尺寸上进行训练。
      • 125M: 12层,隐层维度1024,MLP维度4096,12个头。注意力层在第1和第7层。
      • 355M: 24层,隐层维度1024,MLP维度4096,16个头。注意力层在第1和第13层。
      • 1.3B: 24层,隐层维度2048,MLP维度8192,16个头。注意力层在第1和第13层。
      • 2.7B: 32层,隐层维度2560,MLP维度10240,20个头。注意力层在第10和第21层。
    • H3特定参数:SSM状态大小为64。混合模型中H3的头维度为1,纯H3模型为8。
    • 基线模型: GPT-2, GPT-Neo, OPT等同等规模的Transformer模型。
  3. 硬件配置

    • GPU: 训练在单个包含16块A100-40GB GPU的节点或一个由8块A100-80GB GPU组成的集群上进行。
    • 基准测试: 速度测试在A100-SMX4-40GB GPU上进行。
  4. 软件配置

    • 代码实现: 基于PyTorch,使用混合精度训练(AMP),其中MLP和注意力部分使用bf16,FFTConv部分使用fp32。
    • Tokenizer: 使用GPT-2 BPE tokenizer。
    • 优化器: AdamW。
    • 依赖库: HuggingFace Transformers【65, Transformers: State-of-the-art natural language processing, 2020, EMNLP】。

实验结果

H3 语言建模评估 (Section 5)

本节评估混合H3-注意力模型在困惑度、零样本/少样本学习以及推理速度方面与Transformer的对比。

  1. 困惑度 (Perplexity)

    • 实验内容:在The Pile、OpenWebText和WikiText-103上评估了125M到2.7B参数的混合H3模型,并与GPT-Neo和GPT-2进行比较。
    • 实验结果
      • 在The Pile上,125M的混合H3模型优于同样在该数据集上训练的GPT-Neo(表4)。
      • 在向OpenWebText和WikiText-103的零样本迁移任务中,混合H3模型同样优于GPT-Neo和GPT-2(表4)。
    • 分析结论:混合H3模型在困惑度指标上全面超越或持平于同等规模的Transformer模型,证明了H3层在提升语言建模能力上的有效性。

    表4:模型在The Pile, OpenWebText和WikiText-103上的困惑度(越低越好)。GPT-Neo和混合H3模型在The Pile上训练,而GPT2在WebText上训练。所有模型使用相同的GPT2分词器。我们报告了GPT-2模型在The Pile上的困惑度(*)作为参考,但由于训练数据不同,性能不直接可比。

  2. 零样本和少样本学习性能

    • 实验内容:在SuperGLUE基准上,比较了混合H3模型与OPT、GPT-Neo、GPT-2的零样本和3样本性能。
    • 实验结果
      • 零样本:混合H3模型在超过一半的任务上表现优于或持平于最好的Transformer基线(表5)。
      • 3样本:趋势与零样本一致,混合模型在多数任务上取得领先或有竞争力的结果(表6)。
    • 分析结论:在下游任务的零/少样本学习上,混合H3模型展现出与Transformer相当甚至更强的能力,表明其学到的语言表示是有效且通用的。

    表5:在SuperGLUE上的零样本准确率(使用logit评分)。最好结果加粗,次好结果下划线。

    表6:在SuperGLUE上的3样本准确率(使用logit评分)。最好结果加粗,次好结果下划线。

  3. 推理速度

    • 实验内容:比较了1.3B参数的混合H3模型和Transformer在文本生成任务上的推理吞吐量。
    • 实验结果:由于SSM的循环特性,混合H3模型的推理吞吐量最高可达Transformer的2.4倍,且序列越长,优势越明显(表7)。
    • 分析结论:H3-Attention混合模型在保持高质量的同时,显著提升了生成任务的效率。

    表7:在A100 80GB上,1.3B模型的推理吞吐量。批量大小为64,提示长度为512、1024或1536,每个序列生成128个token。混合H3的推理速度比同等规模的Transformer快高达2.4倍。序列越长,差异越大。

FlashConv 效率评估 (Section 6)

本节评估FlashConv对SSM的加速效果。

  1. Long Range Arena (LRA) 基准测试

    • 实验内容:使用FlashConv加速S4模型,并在LRA基准上测试其性能。
    • 实验结果:FlashConv将S4的速度提升了2倍,总体性能比Transformer快5.8倍,创造了新的SOTA速度记录(表8)。
    • 分析结论:FlashConv能有效加速现有的SSM模型,在长序列任务上展现出巨大优势。

    表8:在LRA基准测试上的加速效果。

  2. H3模块与Attention的速度对比

    • 实验内容:在不同序列长度(256到32K)下,测试了使用FlashConv的H3模块前向和后向传播的时间,并与FlashAttention进行对比。
    • 实验结果
      • FlashConv(结合分块FFT和状态传递)比基于cuFFT的朴素FFTConv实现快2-3倍(图2)。
      • 短序列 (<=512): 内核融合提供了高达3.4倍的加速。
      • 中序列 (1k-8k): 分块FFT提供了高达2倍的加速。
      • 长序列 (>=16k): 状态传递算法使得FFTConv快了2.3倍。
      • 随着序列长度增加,H3的运行时间呈近线性增长,而Attention则呈二次方增长,使得H3在长序列上比最快的Attention实现快数十倍(图2)。
    • 分析结论:FlashConv通过多层次优化(内核融合、分块FFT、状态传递)显著提升了SSM的硬件效率,使其在各种序列长度上都比Attention更快,尤其是在长序列场景下。

    图2:我们比较了不同算法执行基于FFT的卷积的速度,以及FlashAttention【15, Flashattention: Fast and memory-efficient exact attention with io-awareness, 2022, Advances in Neural Information Processing Systems】(我们所知的最快的注意力实现)。我们使用批量大小8,隐藏维度1024,在A100-SMX4-40GB GPU上测量了从256到32k的不同序列长度。我们看到,对于短序列(最大512),内核融合比朴素的FFTConv快达3.4倍;对于中等长度序列(1k-8k),分块FFT快达2倍;对于长序列(16k及以上),状态传递使得FFTConv快了2.3倍。
    图2:我们比较了不同算法执行基于FFT的卷积的速度,以及FlashAttention【15, Flashattention: Fast and memory-efficient exact attention with io-awareness, 2022, Advances in Neural Information Processing Systems】(我们所知的最快的注意力实现)。我们使用批量大小8,隐藏维度1024,在A100-SMX4-40GB GPU上测量了从256到32k的不同序列长度。我们看到,对于短序列(最大512),内核融合比朴素的FFTConv快达3.4倍;对于中等长度序列(1k-8k),分块FFT快达2倍;对于长序列(16k及以上),状态传递使得FFTConv快了2.3倍。

附加实验 (Appendix F)

  • LRA准确率: H3在LRA基准测试上表现出色,在部分任务上优于S4D,证明其不仅适用于自回归语言建模(表9)。
  • WikiText103 & PG-19: 在这两个数据集上,混合H3模型与同等规模的Transformer相比具有竞争力,并显著优于Performer、Reformer和线性注意力等高效注意力变体(表10,表11)。
  • 长度外推: 在关联回忆合成任务上,H3模型在训练序列长度为20的情况下,能在长度为40的测试序列上保持高准确率,展示了其良好的长度外推能力(表12)。
  • 非文本序列建模: H3在原始音频分类(SC10)和原始EEG信号的癫痫分类任务上表现优于Transformer,并在癫痫分类上达到SOTA性能,显示其作为多模态基础模型的潜力(表18,表19)。

A5 结论

本文旨在理解并缩小注意力机制与状态空间模型(SSM)在语言建模领域的模型能力和硬件效率差距。

主要研究成果:
1. 模型设计 (H3):通过合成语言任务的探索,我们发现现有SSM在回忆比较能力上存在不足。为此,我们设计了H3层,该层通过堆叠两个SSM并引入乘法交互,显著提升了SSM的表达能力,使其在语言建模任务上表现出与注意力机制惊人地具有竞争力。
2. 算法优化 (FlashConv):我们提出了BlockFFTConv算法,该算法利用GPU的矩阵乘法单元以及SSM的循环-卷积双重视图,大幅提升了SSM的计算速度,从而降低了注意力与SSM之间的硬件壁垒。

未来工作展望:
- 更优的模型设计:H3只是两个SSM的简单组合,未来可以探索更复杂的结构以获得更强的表达能力。
- 更大规模的扩展:鉴于在高达2.7B参数模型上的积极结果,将SSM扩展到更大规模是一个充满希望的研究方向。
- 混合模型的潜力:仅在H3模型中加入两个注意力层就能同时超越纯H3模型和纯Transformer模型,这表明SSŠM和注意力具有互补优势。未来结合两者的优点将是一个富有成效的研究方向。

A6 附录

B 线性注意力与时变系统

将线性注意力与LTI系统和SSM建立联系
* 线性注意力作为线性时变系统:我们首先将线性注意力表示为一个线性时变系统,然后展示如何将其转换为一个线性时不变系统,从而与H3建立联系。
* 线性时变系统:通常,序列模型中的层接收一个序列并输出一个序列。许多层可以表示为线性时变系统(根据Picard-Lindelof定理,非线性系统可以由一系列线性系统逼近):


这与SSM(第2节)的形式相同,只是矩阵可以依赖于时间步。
* 线性注意力的递推形式:回顾第2节中线性注意力的形式。为了近似,我们忽略分母(即假设$d_i=1$)。我们看到$S_i$只是一个累加和,满足递推关系$S_{i+1} = S_i + \phi(K_{i+1})V_{i+1}^T$。类似地,$O_i$满足递推关系$O_{i+1} = \phi(Q_{i+1})^T S_{i+1}$。这是一个线性时变系统,其形式为$x_{i+1} = Ax_i + Bu_{i+1}$和$y_{i+1} = C_{i+1}x_{i+1}$(其中$A=I, B=I, u_i = \phi(K_i)V_i^T, C_i = \phi(Q_i)^T$)。也就是说,$A$和$B$是常数,但$C$是时变的。
* 转换为线性时不变系统:为了将其转换为线性时不变版本,我们将时变的$C_i$视为一个后处理步骤。我们为LTI使用一个固定的$C$。这产生了一个LTI:

其中$A, B, C$是学习的矩阵。然后我们通过将$y_{i+1}$与$\phi(Q_i)^T$相乘来进行后处理。
* 与H3的类比:将$\phi(K_i)$替换为一个移位SSM,就得到了一个与H3类似的模型。

C 方法细节

C.1 反向传播
* 融合核中的反向传播:我们展示如何在融合核中计算反向传播。
* 设置:令$y = f * u + Du$。在我们的例子中,$f$和$u$长度相同,因此在卷积方面它们是对称的。
* 目标:给定$dy = \frac{\partial l}{\partial y}$(其中$l$是某个损失函数),我们希望计算$du, df$和$dD$(它们分别是$\frac{\partial l}{\partial u}, \frac{\partial l}{\partial f}, \frac{\partial l}{\partial D}$)。
* 直接梯度:最直接的部分是$dD = dy u^T$。最挑战的部分是通过卷积算子的梯度计算,但我们将看到可以复用我们的FFT基础设施。
* 卷积的梯度:这里我们讨论如何通过对卷积算子$*$积分来计算$df$。一个直接的结果是,我们也能计算$du$。由于$f$和$u$长度相同(设为$L$),$f*u$和$u*f$结果相同。因此我们从$u*f$开始。
* 符号:令$O = u * f$。那么$dO = dy$。回顾$O[i] = \sum_{j=0}^{i-1} u[i-j]f[j]$。
* 零填充与Toeplitz矩阵:我们首先用零将$u$和$f$扩展到长度$2L$,得到$u'$和$f'$。令$O' = u' * f'$,则$O=O'[:N]$。我们构造一个Toeplitz矩阵$H_{u'}$使得$u' * f' = H_{u'}f'$:


由于当$i \ge L$时$u'[i] = f'[i] = 0$,我们可以填入矩阵中的零:

* 链式法则应用:然后,我们可以使用矩阵乘法的链式法则来发现:

其中我们用$u'[-i]$表示$u'[2L-i]$。注意到这个矩阵与$H_{u'}$具有相同的格式!令$u'^* = [u'[0], u'[-1], ..., u'[-(2N-1)]]$。那么:

* 利用DFT性质高效计算:直接计算$u'^*$可能会导致棘手的内存访问问题。但DFT的一个很好的性质拯救了我们。令$U[i]$是信号$u$的DFT的第$i$个元素。注意$U[i]$是复数。我们有:

其中$^*$代表复共轭。我们可以利用这个性质高效地计算$df'$:

其中$FFT^*$表示取FFT的复共轭,$dy'$表示用零填充的$dy$。
* 计算du:我们可以使用同样的技巧来计算$du$,只需要加上来自原始$Du$项的贡献。我们最终得到:

C.2 状态传递矩阵
* M_ux的推导:我们展示如何为我们的状态传递算法中的状态更新推导$M_{ux}$。
* 构造:我们希望构造一个矩阵$M_{ux} \in R^{m \times N_0}$,使得$M_{ux}u = \sum_{i=1}^{N_0} A^{N_0-i}Bu_i$。注意到$A^{N_0-i}B \in R^{d \times 1}$是一个列向量。我们可以简单地将这些列向量堆叠起来形成一个矩阵:$M_{ux} = [A^{N_0-1}B, A^{N_0-2}B, ..., B]$。

D 证明

D.1 H3 表达能力
* 证明H3解决关联回忆任务的参数化:本节正式描述了一个可以解决关联回忆任务的H3参数化。
* 示例语言Λ:考虑一个有4个键$\{k_1, k_2, k_3, k_4\}$和4个值$\{v_1, v_2, v_3, v_4\}$的简单语言。给定一个键值对序列,并在末尾有一个键,模型需要生成与该键关联的值。
* H3模型设置:我们描述一个可以解决Λ的玩具H3模型。模型由嵌入层、H3模型和带softmax的输出投影组成。设模型维度$d=8$,隐藏状态维度$m=2$,头数$H=4$。嵌入层将键$k_i$映射到基向量$e_i$,值$v_i$映射到$e_{4+i}$。
* SSM参数设置:对于移位SSM,设$B_{shift}=e_1$, $C_{shift}=[0, 1]^T$。对于对角SSM,设$B_{diag}=C_{diag}=e_1$, $A_{diag}$为对角线上为1的对角矩阵。这样,移位SSM的作用是将其输入延迟一个时间步,而对角SSM的作用是对其所有输入进行累加。
* 投影矩阵设置
* $W_Q$和$W_K$被设置为:


* $W_V$被设置为:

* 作用:$W_Q$和$W_K$的作用是为每个键分配一个不同的H3头。$W_V$的作用是将输入值编码为“二进制”形式,并发送到所有H3头。
* 命题3:上述描述的模型解决了关联回忆问题。
* 证明概要:不失一般性,假设$x_{N+1}=k_i$。那么$Q^{(i)}=[1,1]$,而对于$j \ne i$,$Q^{(j)}=[0,0]$。因此,由于乘法交互,对于$j \ne i$,$O^{(j)}=[0,0]$。对于头$i$,其输出$O^{(i)}$是对应对角SSM的输出。该对角SSM的输入仅在$x_{t-1}=k_i$时非零,此时$x_t=f_x(k_i)$。因此,对角SSM的输入恰好是$f_x(k_i)$的二进制编码。输出$O^{(i)}$是该编码的倍数,输出投影$W_O$将此输出解码为$f_x(k_i)$的嵌入。

D.2 注意力表达能力
* 两层注意力模型解决关联回忆的草图:我们提供一个非正式的草图,说明一个两层注意力模型如何解决关联回忆任务,灵感来自【49, In-context learning and induction heads, 2022, Transformer Circuits Thread】的构造。
* 第一层:信息移位:第一层注意力输出序列中前一个词元的嵌入,并将其与当前词元的嵌入拼接。通过将$Q_i$映射到$x_{i-1}$的位置嵌入,$K_i$映射到$x_i$的位置嵌入,$\text{softmax}(QK^T)$近似于一个移位矩阵,从而将$V_j$($x_j$的编码)的信息“移位”到$O_{j+1}$。
* 第二层:键值查找:第二层的输入同时编码了$x_{i-1}$和$x_i$。令$Q_i$提取$x_i$的编码,$K_i$提取$x_{i-1}$的编码。那么$\text{softmax}(QK^T)_{i,j}$在$x_i=x_{j-1}$时值较大。令$V_i$提取$x_i$的编码。最终输出$O_i$是所有满足$x_{j-1}=x_i$的$x_j$的值的总和,这正好实现了对先前序列中$x_i$后面出现的词元的查找,从而解决了关联回忆。

D.3 H3 复杂度
* 命题1证明:我们证明了命题1,即H3层对于序列长度$N$和隐藏维度$d$需要$O(d^2N + dN \log N)$的时间和$O(dN)$的空间。
* 时间复杂度分析
1. 输入$u \in R^{N \times d}$与三个$d \times d$权重矩阵相乘需要$O(d^2N)$时间。输出$O$与一个$d \times d$投影矩阵相乘也需要$O(d^2N)$。
2. 第一个SSM涉及对$K \in R^{N \times d}$的卷积,需要$O(dN \log N)$时间。
3. 第二个SSM涉及$H$个卷积,输入大小为$N \times d_h \times d_h$。总时间为$H \times O(d_h^2 N \log N) = H(d/H)^2 N \log N = (d^2/H) N \log N$。由于$d_h=O(1)$,这意味着$H$与$d$成正比,所以这部分是$O(dN \log N)$。
4. 总时间为$O(d^2N + dN \log N)$。
* 空间复杂度分析
1. 矩阵乘法需要$O(dN)$空间。
2. 两个SSM中的FFT和逐点乘法需要$O(dN)$和$O(H d_h^2 N) = O(d d_h N) = O(dN)$的空间(因为$d_h$是常数)。
3. 总空间复杂度为$O(dN)$。

D.4 状态传递正确性
* 命题2证明:我们证明命题2。我们假设BlockFFTConv算法是正确的。
* 证明方法:对块数$C$进行数学归纳。
* 基础情况 (C=1):当只有一个块时,$N=N_0$。算法的输出$y^{(1)}$直接等于$\text{BlockFFTConv}(f, u_1)$,这与整个序列的SSM计算结果相同。状态更新也与SSM的递归定义一致。
* 归纳步骤 (C>1):假设对于前$C-1$个块,算法的输出和状态是正确的。我们需要证明对于第$C$个块,其输出$y^{(C)}$和最终状态$x_{N_0}^{(C)}$也是正确的。根据SSM的递归定义,第$i$个时间步的输出依赖于初始状态$x_t$($t=N_0(C-1)$)和从$t$到$i$的输入。算法中的$y^{(C)} = M_{xy}x_{N_0}^{(C-1)} + \text{BlockFFTConv}(f, u^{(C)})$正是这个过程的向量化表示。状态更新同理可证。

F 附加实验

F.1 LRA 准确率
* 实验:评估H3在LRA上的准确率,并与S4D【26, On the parameterization and initialization of diagonal state space models, 2022, Advances in Neural Information Processing Systems】进行比较。
* 结果:如表9所示,H3在LRA基准测试上表现良好,在两个任务上优于S4D,在其他任务上差距在1个点以内。这表明H3虽为自回归语言建模设计,但在其他序列任务上同样强大。
表9:H3与S4D在LRA上的性能对比。

F.2 WikiText103
* 实验:在WikiText103上训练125M大小的模型,并比较其测试PPL与transformer及其他高效注意力变体。
* 结果:如表10所示,混合H3模型与同等规模的Transformer具有竞争力,并显著优于Performer、Reformer和线性注意力。
表10:WikiText103上的测试PPL。

F.3 PG-19
* 实验:在由书籍文本组成的PG-19数据集上评估模型。
* 结果:如表11所示,混合H3模型优于transformers和线性注意力。
表11:PG-19上的测试PPL。

F.4 长度外推
* 实验:使用关联回忆合成任务来证明H3的长度外推能力。在长度为20的序列上训练一个两层H3模型,然后在长度为20和40的序列上评估其末位词元预测的准确率。
* 结果:如表12所示,H3在长度为40的序列上保持了高准确率,这是训练序列长度的两倍。
表12:在长度为20的序列上训练的H3模型,在长度为20和40的序列上评估的关联回忆准确率。

F.5 按Token数量扩展
* 实验:评估混合H3模型与Transformer在训练token数量增加时的扩展性。
* 结果:如表13所示,随着训练token数量的增加,混合H3模型和Transformer模型的性能都得到提升。
表13:在The Pile上用较少token训练的模型的测试PPL。

F.6 H3语言模型
* 实验:报告纯H3语言模型在NLP评估上的结果。
* 结果:表14和15分别显示了在SuperGLUE上的零样本和少样本性能。
表14:SuperGLUE上使用排名分类的零样本性能。每个模型大小的最佳结果加粗。


表15:SuperGLUE上使用排名分类的3样本性能。每个大小的最佳结果加粗,次佳结果下划线。

F.7 生成性能
* 实验:报告SuperGLUE上的生成结果,通过在模型生成的响应中搜索黄金标签来评估。
* 结果:表16和17报告了结果。少样本学习的趋势与logit结果一致,但混合模型和H3模型在某些任务的零样本性能上表现很差,倾向于生成不相关的长文本。少样本示例有助于模型生成可解析格式的答案。
表16:SuperGLUE上的零样本性能。每个大小的最佳结果加粗,次佳结果下划线。


表17:SuperGLUE上使用生成的3样本性能。每个大小的最佳结果加粗,次佳结果下划线。

F.8 非文本序列建模
* 实验:展示H3在两个非文本序列建模任务上的性能优于Transformer:原始语音分类和基于原始EEG信号的癫痫分类。
* 癫痫分类:在TUSZ v1.5.2 EEG语料库上,对60秒长(12000个时间步)的EEG片段进行二元癫痫分类。Transformer因序列过长而无法处理。H3不仅能处理,还取得了SOTA性能(表18)。
表18:从原始EEG(序列长度12000)进行60秒癫痫分类的性能(AUROC)。


* 原始语音分类:在SC10语音命令任务上,输入为1秒长、16kHz采样的原始音频信号(16000个时间步)。H3的性能与SOTA方法S4相差不到半个点(表19)。
表19:在原始音频(序列长度16000)上的SC 10类分类。

* 功能性磁共振成像(fMRI)数据:在fMRI数据上,将H3作为GPT的替代品进行了预训练和下游任务(心理状态解码)的评估。
- 预训练:在包含11,980次fMRI扫描的大规模数据集上进行预训练。结果显示,H3(dropout=0.2)在平均绝对误差上与GPT模型相当(图3)。
- 下游任务:在HCP和MDTB两个心理状态解码数据集上进行微调。结果表明,H3和GPT模型在解码性能上总体上不相上下(表20)。
- 结论:H3在处理高维、长序列的fMRI数据时,表现出与基于Transformer的模型同等的强大能力。

图3:模型训练过程中,训练和评估数据集的上游平均绝对误差(Lrec)。
图3:模型训练过程中,训练和评估数据集的上游平均绝对误差(Lrec)。

图4:最终预训练模型在每个脑区体素的平均绝对误差(Lrec),投射到FsAverage模板的膨胀皮质表面。
图4:最终预训练模型在每个脑区体素的平均绝对误差(Lrec),投射到FsAverage模板的膨胀皮质表面。
**表20:在fMRI数据上预训练的模型的下游适应性能,对20次不同随机种子的训练运行取平均。F1分数是宏平均。**