Tiled Flash Linear Attention: More Efficient Linear RNN and xLSTM Kernels

作者: Maximilian Beck, Korbinian Pöppel, Phillip Lippe, Sepp Hochreiter

主要贡献

本文旨在解决 Transformer 模型在处理长序列时计算成本呈二次方增长的问题。虽然具有线性计算复杂度的循环神经网络(RNN)变体在理论上更具优势,但要将其转化为实际的性能提升,需要像 Transformer 依赖 Flash Attention 一样的高效定制化核函数(Kernel)。现有的 Flash Linear Attention (FLA) 通过对输入序列分块并行处理,展示了线性 RNN 核函数比 Flash Attention 更快。然而,FLA 的块大小(chunk size)受限于 GPU 的 SRAM 容量,导致在处理长序列时需要将大量中间状态存入高带宽内存(HBM),这不仅增加了内存消耗和 I/O 成本,还降低了算术强度。

为了克服这些限制,本文提出了 Tiled Flash Linear Attention (TFLA),这是一种新颖的线性 RNN 核函数算法。TFLA 的核心创新在于引入了两级序列并行化机制。第一级并行化是在序列块(chunk)的维度上进行,与 FLA 类似;第二级并行化则是在每个块内部,通过对矩阵计算进行分块(tiling)来在序列维度上进一步并行。这种设计使得 TFLA 能够支持任意大的块大小,从而显著提高算术强度,并允许在内存消耗、I/O 成本和计算效率之间进行有效权衡。

本文的主要贡献如下:
1. 提出 Tiled Flash Linear Attention (TFLA):一种新的分块并行核函数算法,通过两级序列并行化支持任意大小的块,提高了算术强度并优化了内存使用。该算法首先被应用于带矩阵内存的 xLSTM,即 mLSTM。
2. 提出 mLSTMsig:一种 mLSTM 的变体,它使用 sigmoid 输入门并减少了计算量。实验表明,在高达 14 亿参数规模的语言模型中,mLSTMsig 在性能上与原版 mLSTM 相当,但核函数运行速度更快。
3. 优化 mLSTM 的训练稳定性与性能:受控制论中传递函数分析的启发,本文进行了一项实证研究,发现两种 mLSTM 变体具有相同的传递行为。该分析表明,应将输入门的偏置初始化为较大的负值。在广泛的语言建模实验中,证实了这种初始化策略能够提高训练稳定性并改善 mLSTM 模型的整体性能。

本文的 TFLA 核函数在速度基准测试中,其性能超越了高度优化的 Flash Attention、Linear Attention 和 Mamba 核函数,为高效的长上下文序列建模基元设定了新的技术水平。

图1. Tiled Flash Linear Attention (TFLA) 包含一个循环核和一个并行核,它们以块 QKV(k) 的形式处理输入序列(第一级序列并行)。循环核为每个块物化内存状态 Ck−1。并行核计算所有块的输出状态 H(k)。TFLA 在并行核中使用分块(tiling)技术处理 3 个矩阵乘法(第二级序列并行),以充分利用硬件并防止物化大量内存状态。
图1. Tiled Flash Linear Attention (TFLA) 包含一个循环核和一个并行核,它们以块 QKV(k) 的形式处理输入序列(第一级序列并行)。循环核为每个块物化内存状态 Ck−1。并行核计算所有块的输出状态 H(k)。TFLA 在并行核中使用分块(tiling)技术处理 3 个矩阵乘法(第二级序列并行),以充分利用硬件并防止物化大量内存状态。

mLSTM 公式

mLSTM 单元是 xLSTM 【索引[5],xLSTM: Extended long short-term memory,2024,NeurIPS】中可完全并行的部分。它拥有一个矩阵内存和指数门控机制。

循环公式

在其循环公式中,mLSTM 单元处理一系列输入向量 $x_t \in \mathbb{R}^d$,其中时间步 $t \in \{1, . . . , T \}$,将一个状态 $(h_{t-1}, C_{t-1}, n_{t-1}, m_{t-1})$ 映射到一个后继状态 $(h_t, C_t, n_t, m_t)$,并给定一个输入 $x_t$。这里,$h_t \in \mathbb{R}^{d_{hv}}$ 是隐藏状态,$C_t \in \mathbb{R}^{d_{qk} \times d_{hv}}$ 是负责长期记忆的单元状态,$n_t \in \mathbb{R}^{d_{qk}}$ 是归一化器状态,$m_t \in \mathbb{R}$ 是最大值状态。归一化器和最大值状态共同控制指数输入门的幅度并确保稳定性(详见附录 D.1)。循环 mLSTM 公式由以下状态更新方程给出:

$$m_t = \max \left\{ \log \sigma(\tilde{\mathrm{f}}_t) + m_{t-1}, \tilde{\mathrm{i}}_t \right\}$$ $$ \boldsymbol{C}_{t} = \mathrm{f}_{t} \boldsymbol{C}_{t-1} + \mathrm{i}_{t} \boldsymbol{k}_{t} \boldsymbol{v}_{t}^{\top} $$ $$ \boldsymbol{n}_{t} = \mathrm{f}_{t} \boldsymbol{n}_{t-1} + \mathrm{i}_{t} \boldsymbol{k}_{t} $$ $$\tilde{\boldsymbol{h}}_{t}=\frac{\boldsymbol{C}_{t}^{\top}\left(\boldsymbol{q}_{t} / \sqrt{d_{q k}}\right)}{\max \left\{\left|\boldsymbol{n}_{t}^{\top}\left(\boldsymbol{q}_{t} / \sqrt{d_{q k}}\right)\right|, \exp \left(-m_{t}\right)\right\}}$$ $$\boldsymbol{h}_t = \mathbf{o}_t \odot \text{NORM}(\widetilde{\boldsymbol{h}}_t)$$

门激活值的计算方式如下:

$$\mathrm{f}_t = \exp \left( \log \sigma (\tilde{\mathrm{f}}_t) + m_{t-1} - m_t \right)$$ $$\begin{aligned} \begin{aligned} \mathbf{i}_t &= \exp(\tilde{\mathbf{i}}_t - m_t) \\ \mathbf{o}_t &= \sigma(\tilde{\mathbf{o}}_t) \end{aligned} \end{aligned}$$

查询(query)、键(key)和值(value)向量 $q_t, k_t \in \mathbb{R}^{d_{qk}}$, $v_t \in \mathbb{R}^{d_{hv}}$ 通过 $\{q_t, k_t, v_t\} = W_{\{q,k,v\}} x_t + b_{\{q,k,v\}}$ 计算得出。标量输入门和遗忘门 $i_t, f_t \in \mathbb{R}$ 由预激活值 $\{\tilde{i}_t, \tilde{f}_t\} = w^T_{\{i,f\}} x_t + b_{\{i,f\}}$ 计算得出,而向量输出门 $o_t \in \mathbb{R}^{d_{hv}}$ 由预激活值 $\tilde{o}_t = W_o x_t + b_o$ 通过 sigmoid 函数 $\sigma$ 计算得出。公式 (5) 中的 NORM 层可以是 RMS norm 【索引[84],Root mean square layer normalization,2019,NeurIPS】或 LayerNorm 【索引[4],Layer Normalization,2016,arXiv】。通常,多个这样的单元会作为并行头(parallel heads)同时运行,类似于 Transformers 【索引[74],Attention is all you need,2017,NeurIPS】。

分块并行公式

分块并行公式是在完全并行和完全循环公式之间的一种权衡。它包含一个循环部分和一个(二次方的)并行部分,整体上在序列长度上具有次二次方的扩展性。与完全并行公式(见附录 B.1)类似,我们假设所有输入都一次性可用。然后,我们将长度为 T 的序列分成 $N_c = \lceil T / L \rceil$ 个长度为 L 的块,并使用 $k \in \{1, . . . , N_c\}$ 作为块索引。我们将输入门、遗忘门以及查询、键和值重排成分块矩阵,其中块索引成为第一个维度。例如,遗忘门预激活值 $\tilde{f} \in \mathbb{R}^T$ 被重排成一个矩阵 $\tilde{f} = (\tilde{f}^{(1)}, \tilde{f}^{(2)}, ..., \tilde{f}^{(N_c)}) \in \mathbb{R}^{N_c \times L}$,其中每一行 $\tilde{f}^{(k)} = (f_{(k-1)N_c+1}, f_{(k-1)N_c+2}, . . . , f_{kN_c}) \in \mathbb{R}^L$ 包含了块 k 的预激活值。输入门预激活值也类似处理。同样地,查询、键和值被重排成分块张量 $Q, K \in \mathbb{R}^{N_c \times L \times d_{qk}}$ 和 $V \in \mathbb{R}^{N_c \times L \times d_{hv}}$。这里,查询矩阵 $Q^{(k)} = (q_{(k-1)N_c+1}, . . . , q_{kN_c}) \in \mathbb{R}^{L \times d_{qk}}$ 包含了块 k 的查询向量。键和值的处理方式类似。为简化符号,我们省略了前导的 $N_c$ 维度,并忽略了归一化层和输出门,即视 $h_t$ 为隐藏状态输出。

分块门控。给定对数遗忘门 $\bar{f}^{(k)} = \log \sigma(\tilde{f}^{(k)}) \in \mathbb{R}^L$ 和输入门 $\bar{i}^{(k)} = \log \exp(\tilde{i}^{(k)}) \in \mathbb{R}^L$,我们可以计算对数分块门控如下:

$$g_k=\operatorname{sum}\left(\overline{\mathbf{f}}^{(k)}\right) \in \mathbb{R},$$ $$\mathbf{b}_{k}=\operatorname{cumsum}\left(\overline{\mathbf{f}}^{(k)}\right) \in \mathbb{R}^{L},$$ $$\mathbf{a}_k = \text{rev\_cumsum}^1 \left( \bar{\mathbf{f}}^{(k)} \right) + \bar{\mathbf{i}}^{(k)} \in \mathbb{R}^L.$$

在图 2 中,我们展示了求和后的遗忘门 $g_k$ 包含了一个块内所有遗忘门的贡献。累积遗忘门向量 $b_k$ 包含了从块开始到当前时间步的遗忘门贡献。累积输入门向量 $a_k$ 包含了每个时间步的输入门以及从当前时间步到块结束的遗忘门贡献。

图2. 块大小 L=4 时分块门 ak、bk 和 gk 的图示。每个箭头表示门向量中的一个元素。更多细节请参见附录 B.2 中的图 8。
图2. 块大小 L=4 时分块门 ak、bk 和 gk 的图示。每个箭头表示门向量中的一个元素。更多细节请参见附录 B.2 中的图 8。

块间循环贡献。块间循环由以下公式给出:

$$\boldsymbol{C}_{k}=\bar{\mathrm{g}}_{k} \boldsymbol{C}_{k-1}+\left(\overline{\mathbf{a}}_{k} \odot \boldsymbol{K}^{(k)}\right)^{\top} \boldsymbol{V}^{(k)}$$ $$\boldsymbol{n}_{k}=\bar{\text{g}}_{k} \boldsymbol{n}_{k-1}+\left(\overline{\mathbf{a}}_{k} \odot \boldsymbol{K}^{(k)}\right)^{\top} \mathbf{1},$$

其中 $\bar{g}_k$ 和 $\bar{a}_k$ 是稳定化的分块门。这个循环部分类似于第 2.1 节中的完全循环公式,但它不是为每个时间步 t 计算中间状态,而是直接为每 L 个时间步计算,而不物化中间的状态。

块内并行贡献。循环部分之后是块内并行贡献:

$$\begin{aligned} \widetilde{\mathbf{D}}^{(k)}=\begin{cases}-\infty & \text { for } i<j \\ \mathbf{b}_{k}-\mathbf{b}_{k}^{\top}+\overline{\mathbf{i}}^{(k)\top} & \text { for } i \geqslant j\end{cases} \end{aligned}$$ $$\mathbf{S}^{(k)}=\frac{1}{\sqrt{d_{qk}}} \boldsymbol{Q}^{(k)} \boldsymbol{K}^{(k)\top}$$ $$\overline{\mathbf{S}}^{(k)}=\mathbf{S}^{(k)} \odot \mathbf{D}^{(k)},$$ <p>其中 $D^{(k)} \in \mathbb{R}^{L \times L}$ 是稳定化的门矩阵。与附录 B.1 中的完全并行部分相比,矩阵 $D^{(k)}, S^{(k)} \in \mathbb{R}^{L \times L}$ 的二次方成本大大降低,因为块大小 L 通常远小于序列长度 T。

输出计算。最后,将块内并行部分的贡献 $H^{(k)}$ 与块间循环部分的贡献 $H^{(k)}_{inter}$ 相结合,得到每个块 k 的隐藏状态 $H^{(k)} \in \mathbb{R}^{L \times d_{hv}}$(见图 1):

$$\mathbf{H}_{\text {inter }}^{(k)}=\left(\overline{\mathbf{b}}_{k} \odot \frac{\boldsymbol{Q}^{(k)}}{\sqrt{d_{q k}}}\right) \boldsymbol{C}_{k-1}=\overline{\boldsymbol{Q}}^{(k)} \boldsymbol{C}_{k-1}$$ $$\mathbf{H}_{\text{intra}}^{(k)}=\overline{\mathbf{S}}^{(k)}\boldsymbol{V}^{(k)}$$ $$\mathbf{H}^{(k)}=\left(\mathbf{H}_{\mathrm{inter}}^{(k)}+\mathbf{H}_{\mathrm{intra}}^{(k)}\right) / \mathbf{h}_{\mathrm{denom}}^{(k)},$$

附录 B.2 和 B.3 提供了分块并行前向和后向传播的详细描述。附录 F 提供了所有公式的 FLOP 和内存操作计数。

Tiled Flash Linear Attention

FlashLinearAttention【索引[81],Gated linear attention transformers with hardware-efficient training,2024,ICML】为线性注意力(Linear Attention)的分块公式(类似于第 2.2 节,但没有门控)引入了一种快速的核函数算法,并证明其实现比优化的 FlashAttention【索引[9],FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning,2024,ICLR】更快。这种加速是通过单级序列并行化实现的,即首先在 GPU 内存中物化状态 $C_k$,然后并行计算输出 $H^{(k)}$。然而,在 FlashLinearAttention 中,块大小参数决定了 SRAM 中的瓦片(tile)大小,因此最大块大小受到 GPU 物理 SRAM 大小的限制(通常 L = 64)。因此,我们必须在 HBM 中物化许多状态,状态数量为 $N_c = \lceil T / L \rceil$。这导致了较低的算术强度和较高的 GPU 内存消耗,特别是在长上下文预训练中带来了挑战。为了解决这个问题,我们在第 3.2 节中引入了 TiledFlashLinearAttention (TFLA),它支持任意大的块大小。

我们首先在第 3.1 节回顾编写高效核函数的基础知识。由于我们的实验在 NVIDIA GPU 上进行,我们的回顾主要针对 NVIDIA 的术语,但这些原则也适用于其他硬件。有关更全面的概述,我们推荐阅读【索引[63],ThunderKittens: Simple, fast, and adorable ai kernels,2024,arXiv】。最后,在附录 G 中,我们分析了 TFLA 核函数的理论运行时。

GPU 基础知识

GPU 简介。GPU(图形处理单元)是一种专门为高效处理大规模并行计算任务(如神经网络中的矩阵乘法)而设计的处理器。这些任务被分解为称为核函数(kernels)的小程序,在 GPU 上执行。核函数从高带宽内存(HBM)加载数据,对其进行处理,并将结果写回 HBM。为了编写高效的核函数,理解 GPU 的软件层次结构至关重要,它紧密遵循其物理硬件层次结构。

GPU 层次结构。在最底层,GPU 运行多个线程(Threads),这些线程并行操作于小而快的寄存器内存。在软件层面,通常将多个(例如 32 个)线程组合成一个线程束(Warp)。多个线程束又被组合成线程块(Thread blocks),它们共同在一个称为流式多处理器(SM)的物理核心上执行一个核函数。同一线程块内的线程束或线程可以通过特殊的片上共享内存(SRAM)进行数据通信。执行核函数时,会在 GPU 上启动一个并行运行的线程块网格(grid,通常是 3 维的)。所有线程块都可以访问大但慢的片外高带宽内存(HBM),HBM 在所有 GPU 内存中延迟最高、带宽最低。对于高效的核函数来说,最小化从 HBM 读取和写入内存的操作至关重要。

专用计算单元。现代 GPU 拥有称为张量核心(tensor cores)的专用计算单元,可加速 GPU 上的矩阵乘法。张量核心占据了 GPU 大部分的计算能力,并在线程束或块级别进行访问。为了实现高效的核函数,最大化张量核心的利用率非常重要。

Triton 语言。Triton 是一种 GPU 核函数编程语言及其关联的编译器,提供了一个基于 Python 的 GPU 编程环境。用户可以通过 tl.load 指令从 HBM 加载数据,通过 tl.store 将数据存储到 HBM。tl.dot 是一条利用张量核心进行矩阵乘法的指令。虽然这个 Triton 接口提高了编写高速自定义核函数的效率,但有时只有通过 CUDA 核函数才能达到峰值性能。我们用 Triton 编写我们的核函数,并将 CUDA 实现留作未来工作。与 NVIDIA 的编程模型 CUDA(它提供了对 GPU 层次结构所有级别的访问)不同,Triton 程序在线程块级别操作,并对用户隐藏了寄存器和线程管理。因此,在下一节中,我们将在更抽象的线程块或程序级别描述 TFLA。

通过两级序列并行化实现更高效的核函数

TFLA 的核心思想。TiledFlashLinearAttention (TFLA) 通过引入两级序列并行化(见图 3),实现了快速的核函数,并能在内存消耗和计算效率之间进行权衡。第一级是在序列的块(chunk)上进行并行化,这需要计算并物化中间状态 $C_k$ 到 GPU HBM 中。为此,我们使用了一个类似于先前工作【索引[81],Gated linear attention transformers with hardware-efficient training,2024,ICML】的循环核函数。第二级是在每个块内部进行并行化,这是通过沿块维度对块内注意力矩阵进行分块(tiling)来实现的。这第二级并行化使得块大小可以很大,从而减少了中间状态的内存消耗,因为我们每次核函数调用只需在 HBM 中存储和加载 $N_c = \lceil T/L \rceil$ 个中间状态,其中 T 是序列长度,L 是块大小。除了这两级序列并行化和在批次及头维度上的朴素并行化外,TFLA 还在嵌入维度上进行并行化。这使得头维度可以任意大,并导致在五个维度上的大规模并行化,这对于在现代 GPU 上实现高性能至关重要。

图3. TFLA 前向传播分块。我们在 BLkv 和 Bdqk 上循环(由箭头表示),并在 BLhq 和 Bdhv 块上并行化(由虚线表示)。⊕ 表示块级累加。
图3. TFLA 前向传播分块。我们在 BLkv 和 Bdqk 上循环(由箭头表示),并在 BLhq 和 Bdhv 块上并行化(由虚线表示)。⊕ 表示块级累加。

前向传播。我们回顾 mLSTM 块内并行部分的矩阵乘法操作,以展示我们如何高效地并行化这些操作。为简单起见,我们省略了门控计算和归一化,因为它们不影响工作分区。我们还省略了批次、头和块维度,因为它们不与矩阵乘法交互,我们可以对它们进行朴素并行化(见表 1)。简化形式下,mLSTM(以及其他线性 RNN)对于一个块 k 的块内并行前向传播可以写成三个矩阵乘法,我们将它们融合成一个单一的核函数:

$$ \underset{(L_{h q} \times d_{h v})}{\mathbf{H}^{(k)}}=\underbrace{\left(\underset{(L_{h q} \times d_{q k})}{\boldsymbol{Q}^{(k)}} \underset{(d_{q k} \times L_{k v})}{\boldsymbol{K}^{(k)\top}}\right) \underset{(L_{k v} \times d_{h v})}{\boldsymbol{V}^{(k)}}}_{\mathbf{H}_{\text {intra }}^{(k)}} +\underbrace{\underset{(L_{h q} \times d_{q k})}{\boldsymbol{Q}^{(k)}} \underset{(d_{q k} \times d_{h v})}{\boldsymbol{C}_{k-1}}}_{\mathbf{H}_{\text {inter }}^{(k)}} $$

在附录 A.3 中,我们展示了 TFLA 可以应用于任何遵循或可以重构为此形式的线性 RNN。为了并行化公式 (20) 中的计算,我们引入了块大小 $B_{L_{hq}}$, $B_{L_{kv}}$, $B_{d_{qk}}$ 和 $B_{d_{hv}}$,分别对应注意力矩阵、查询、键、值和隐藏状态的维度 $L_{hq}$, $L_{kv}$, $d_{qk}$ 和 $d_{hv}$,我们沿着这些维度进行并行化或通过在核函数内部使用循环进行累加。在图 3 中,我们展示了用于前向传播 $H^{(k)}$ 核函数的 TFLA 分块策略。我们跨外部序列维度 $L_{hq}$ 使用 $N_{L_{hq}} = L_{hq}/B_{L_{hq}}$ 个程序进行并行化,并跨外部嵌入维度 $d_{hv}$ 使用 $N_{d_{hv}} = d_{hv}/B_{d_{hv}}$ 个程序进行并行化。我们循环处理内部维度 $L_{kv}$ 和 $d_{qk}$,它们分别由块大小 $B_{L_{kv}}$ 和 $B_{d_{qk}}$ 进行分块。

分块计算。对于 mLSTM,由于指数输入门需要用最大值状态 $m_t$ 进行稳定化,我们不能简单地沿 $L_{kv}$ 维度累加矩阵乘法的结果 $H^{(k)}_{intra}$ 和 $H^{(k)}$。最大值状态 $m_t$ 随时间跟踪遗忘门和输入门的最大值,并用于稳定指数输入门,类似于安全 softmax 计算【索引[28],Online normalizer calculation for softmax,2018,ArXiv】。由于我们沿块大小(即时间)维度 $L_{kv}$ 分块计算隐藏状态输出 $H^{(k)}$,我们需要在累加 $H^{(k)}$ 的块结果和整合到 $H^{(k)}$ 的总结果时进行重新缩放,这与 FlashAttention【索引[9],FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning,2024,ICLR】的方式相同。我们在 B.2 节提供了重新缩放的细节。对于反向传播,则不需要重新缩放,因为我们在前向传播中存储了最大值状态,并在反向传播中重用它们。mLSTM 的 TFLA 前向传播伪代码在算法 1 中列出。

反向传播。TFLA 反向传播的并行化策略比前向传播更复杂,因为我们需要计算三个输出张量——查询、键和值的梯度,每个梯度都包含块内和块间两部分。然而,在 C.3 节中,我们展示了各个梯度可以映射到三个类似于前向传播的矩阵乘法。在 TFLA 中,我们为每个梯度实现一个单独的核函数,并使用与前向传播相同的工作分区,但相应地交换循环和并行化的维度。表 1 总结了我们 TFLA 核函数的工作分区。

表1. TFLA 核函数的并行化和循环维度。并行化维度用 P 表示,循环维度用 L 表示。最后一列显示了 3D 核函数启动网格的前两个维度。所有核函数的最后一个维度是 Nchunk · Nhead · Nbatch。
表1. TFLA 核函数的并行化和循环维度。并行化维度用 P 表示,循环维度用 L 表示。最后一列显示了 3D 核函数启动网格的前两个维度。所有核函数的最后一个维度是 Nchunk · Nhead · Nbatch。

更快的 mLSTM:使用 Sigmoid 输入门

现有 mLSTM 的计算开销。Beck 等人【索引[5],xLSTM: Extended long short-term memory,2024,NeurIPS】引入的带指数门控(即指数输入门)的 mLSTM 需要计算并跟踪两个额外的状态:归一化状态 $n_t$ 和最大值状态 $m_t$,如附录 D.1 所示。这两个状态都会增加核函数的运行时间:归一化器必须通过求和来计算,而在 TFLA 的分块计算中跟踪最大值状态(见 3.2 节和 C.1 节)会阻碍核函数内循环的高效融合(见附录 C.2)。

新门控函数的动机。此外,我们在 4.2 节的分析建议将输入门偏置初始化为较大的负值(例如 -10),以便输入门预激活值可以在训练期间缓慢增长。我们观察到,在训练期间,这些值大多保持在 0 以下(见附录 E 中的图 14)。因此,我们寻求一种替代的激活函数,它在负值范围内与指数函数相似,但在正值范围内有界。这促使我们使用 sigmoid 函数:

$$\sigma(x)=\frac{1}{1+\exp (-x)}=\frac{\exp (x)}{\exp (x)+1},$$

该函数在 $x \rightarrow -\infty$ 时收敛于 $\exp(x)$,在 $x \rightarrow \infty$ 时收敛于 1。

带 Sigmoid 输入门的 mLSTM

mLSTMsig 的提出。sigmoid 函数可以按公式 (21) 中的两种方式计算。根据 x 的符号,可以确保 exp 的参数始终小于 0,以避免数值溢出。因此,我们不需要用最大值状态从外部控制 x 的大小,从而也可以去掉归一化器状态(见附录 D.1)。这样就得到了带 sigmoid 输入门的 mLSTM(下文简称 mLSTMsig),其循环公式如下:

$$\boldsymbol{C}_t = \sigma(\tilde{\mathbf{f}}_t) \boldsymbol{C}_{t-1} + \sigma(\tilde{\mathbf{i}}_t) \boldsymbol{k}_t \boldsymbol{v}_t^\top$$ $$\widetilde{\boldsymbol{h}}_{t}=\boldsymbol{C}_{t}^{\top}\left(\boldsymbol{q}_{t} / \sqrt{d_{q k}}\right)$$ $$\boldsymbol{h}_t=\sigma\left(\tilde{\mathbf{o}}_t\right) \odot \operatorname{NORM}\left(\widetilde{\boldsymbol{h}}_t\right)$$

其中,查询、键和值向量 $q_t, k_t, v_t$,以及门控预激活值 $\tilde{i}_t, \tilde{f}_t, \tilde{o}_t$ 与带指数输入门的 mLSTM(下文简称 mLSTMexp)中的定义(见 2.1 节)保持不变。在 5.2 节中,我们证实了我们的 TFLA mLSTMsig 前向核函数比 mLSTMexp 前向核函数快 30% 以上。我们还展示了 mLSTMsig 在语言建模实验中,在高达 14 亿参数的规模下,其性能与 mLSTMexp 相当(见 5.1 节)。

mLSTM 和线性 RNN 的归一化分析

分析动机。受 mLSTMsig 优异性能的启发,我们试图从经验上理解 mLSTMsig 和 mLSTMexp 之间的差异。为此,我们借鉴了控制系统设计中频率响应和传递函数分析的概念,该领域通常关注不同频率下输出和输入信号的振幅比或增益【索引[31],Modern control engineering,2010,Prentice-Hall】。在我们的案例中,我们分析了 mLSTMsig 和 mLSTMexp 在随机输入 $q_t, k_t, v_t$ 以及不同输入门和遗忘门预激活值 $\tilde{i}_t, \tilde{f}_t$ 下的传递行为。

归一化层的作用。我们将看到,归一化层 $y = NORM(x)$ 在我们的分析中扮演着至关重要的角色。语言建模中默认的归一化层是 RMS norm【索引[84],Root mean square layer normalization,2019,NeurIPS】,其输入向量为 $x \in \mathbb{R}^d$,输出向量为 $y \in \mathbb{R}^d$,定义如下:

$$ \boldsymbol{y} = \frac{\boldsymbol{x}}{\mathrm{RMS}(\boldsymbol{x})} \odot \boldsymbol{\gamma}, \text{ where } \mathrm{RMS}(\boldsymbol{x}) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_{i}^{2} + \epsilon} , $$

其中 $\gamma \in \mathbb{R}^d$ 是一个可学习的缩放参数。epsilon 参数 $\epsilon \in \mathbb{R}$ 是一个小的常数,通常设置为 1e-6 以避免除以零。

mLSTM 的传递行为。我们通过计算 mLSTM 单元从随机输入(从 $N(0, 1)$ 采样)到归一化层前后隐藏状态的增益,来分析其传递行为,其中输入门和遗忘门的值是变化的。具体来说,我们计算增益 $G_{before}$ 和 $G_{after}$ 如下:

$$G_{\text {before }}=\frac{\left\|\widetilde{\boldsymbol{h}}_{t}\right\|_{\max }}{\left\|\boldsymbol{v}_{t}\right\|_{\max }} \text { and } G_{\text {after }}=\frac{\left\|\operatorname{NORM}\left(\widetilde{\boldsymbol{h}}_{t}\right)\right\|_{\max }}{\left\|\boldsymbol{v}_{t}\right\|_{\max }},$$

其中 $|x|_{max} := \max(|x_1|, . . . , |x_d|)$,并且我们在时间维度上取平均。更多细节见附录 D.2。在图 4 中,我们观察到不带归一化器的 mLSTMsig 的传递行为与带归一化器和最大值状态的 mLSTMexp 的传递行为是相同的。两者都在输入门预激活值为较大的负值时,表现出从抑制信号(G = 0)到传递信号(G = 1)的过渡,这可能部分解释了它们在我们的语言建模实验中性能相当的原因。

与其他门控线性 RNN 的关系。有趣的是,几乎所有其他门控线性 RNN 变体也都在 RNN 单元之后放置一个归一化层【索引[10],Transformers are SSMs: Generalized models and efficient algorithms through structured state space duality,2024,ICML;索引[64],Retentive network: A successor to transformer for large language models,2023,ArXiv;索引[41],HGRN2: Gated linear RNNs with state expansion,2024,CoLM;索引[81],Gated linear attention transformers with hardware-efficient training,2024,ICML】。这通常被解释为提高了训练稳定性,但缺乏更深入的讨论【索引[26],Jamba: A hybrid transformer-mamba language model,2024,arXiv】。Qin 等人【索引[40],The devil in linear transformer,2022,EMNLP】分析了在非门控、基于核函数的线性注意力层【索引[21],Transformers are RNNs: Fast autoregressive transformers with linear attention,2020,ICML】之后使用归一化层的效果,并表明这能有效防止梯度无界。我们也证实了归一化层对训练稳定性和训练过程中的梯度范数有显著影响。在 5.1 节中,我们展示了按照图 4 中传递行为分析的建议,将输入门偏置初始化为较大的负值,可以防止训练期间梯度范数出现大的方差和尖峰。

图4. mLSTM 在 RMS-norm 层(ε =1e-6)前后,针对不同输入和遗忘门值的传递行为。颜色显示了在 (26) 中定义的 mLSTM 增益。在归一化层之后,mLSTMexp 和 mLSTMsig 表现出相同的传递行为。
图4. mLSTM 在 RMS-norm 层(ε =1e-6)前后,针对不同输入和遗忘门值的传递行为。颜色显示了在 (26) 中定义的 mLSTM 增益。在归一化层之后,mLSTMexp 和 mLSTMsig 表现出相同的传递行为。

归一化对线性 RNN 门控机制的影响。我们假设,在这一点上,归一化层不仅通过重新缩放来控制层激活的大小从而起到稳定作用,而且还主动参与了线性 RNN 的信息路由或门控机制。例如,如果 $C_t^T q$ 的平方范数(该范数由输入门和遗忘门通过 $C_t^T$ 控制)小于 epsilon,那么 NORM(x) 层中的分母将主要由 $\epsilon$ 主导,输出会趋向于零(如图 4 中的紫色区域所示)。因此,通过在图 4 的 x-y 平面中移动,门控可以学会抑制或放大序列中的任何输入。在 D.2 节中,我们展示了关于改变归一化层 epsilon 和 mLSTM 归一化器不同修改效果的额外实验。

实验环境

  • 数据集:实验在 DCLM 数据集【索引[25],DataCompLM: In search of the next generation of training sets for language models,2024,arXiv】上进行。
  • 模型架构

    • 训练了三种不同规模的模型:1.6亿、4亿和14亿参数。
    • 上下文长度设置为 4096 和 8192。
    • 基础架构遵循 Llama2 风格的仅解码器 Transformer 模型【索引[73],Llama 2: Open foundation and fine-tuned chat models,2023,arXiv】,采用预归一化(pre-norm)、RMS-norm 和 SwiGLU 前馈网络。
    • 使用 mLSTM(mLSTMexpmLSTMsig)操作替代了自注意力层。
    • mLSTM 的关键参数设置:查询和键的头维度($d_{qk}$)是值维度($d_{hv}$)的一半,即 $d_{qk} = 0.5 d_{hv}$。单元内部使用 LayerNorm。
    • 使用了 logits soft-capping(c=30)和门控预激活值 soft-capping(c=15)来增强训练稳定性。
  • 硬件配置:所有实验均在 NVIDIA H100 80GB GPU 上运行。

  • 软件配置
    • 语言建模实验使用 JAX 0.4.34【索引[7],JAX: composable transformations of Python+NumPy programs,2018,GitHub】 和 FLAX 0.9.0【索引[16],Flax: A neural network library and ecosystem for JAX,2024,GitHub】。
    • TFLA 核函数使用 Triton 3.1.0【索引[70],Triton: an intermediate language and compiler for tiled neural network computations,2019,MAPL】实现,并通过 JAX-Triton 0.2.0【索引[75],Jax-triton,2022,GitHub】集成到 JAX 中。
    • 核函数基准测试在 PyTorch 2.5.1【索引[33],Pytorch: An imperative style, high-performance deep learning library,2019,arXiv】中进行。

实验结果

本节评估了两种 mLSTM 变体的性能:mLSTMexp(带指数输入门)和 mLSTMsig(带 sigmoid 输入门),并比较了两种核函数算法:(1) limit_chunk:块大小 L 受限的核函数;(2) xl_chunk:我们提出的 TFLA 核函数,块大小不受限制。

mLSTM 的语言建模

我们在 DCLM 数据集上训练了 160M、400M 和 1.4B 三种参数规模的模型,上下文长度分别为 4096 和 8192。作为参考,我们引入了 Llama2 风格的 Transformer 模型。

语言建模性能。表 2 和表 5 分别展示了在上下文长度为 4096 和 8192 下的验证困惑度(Perplexity)结果。我们发现,对于几乎所有的头维度配置,limit_chunkxl_chunk 核函数都取得了相同的损失(仅有微小的数值偏差)。对于某些头维度,xl_chunk 核函数会出现梯度范数或损失尖峰,影响了最终的损失。一个主要结果是,mLSTMsig 的性能与 mLSTMexp 相当。

表2. 上下文长度 4096 时的验证困惑度。EXP 和 SIG 分别表示 mLSTMexp 和 mLSTMsig。LIMIT 和 XL 对应 limit_chunk 和 xl_chunk 核函数。
表2. 上下文长度 4096 时的验证困惑度。EXP 和 SIG 分别表示 mLSTMexp 和 mLSTMsig。LIMIT 和 XL 对应 limit_chunk 和 xl_chunk 核函数。

输入门偏置初始化的影响。在附录 E.2 中,我们分析了输入门偏置初始化对 mLSTM 模型训练稳定性和性能的影响。如图 11 和 12 所示,将输入门偏置初始化为 -10,能有效缓解 mLSTMexpmLSTMsig 在训练过程中的大梯度范数尖峰和方差。因此,我们得出结论,额外的输入门不仅能提升性能(见表 6),而且如果初始化正确,还能提高训练稳定性。

归一化层 Epsilon 的影响。在附录 E.2 中,我们研究了归一化层 epsilon 对 mLSTMexp 语言建模性能的影响。图 4 的传递行为分析表明,归一化层 epsilon 和输入门偏置初始化之间存在相互作用。我们在图 13 的网格搜索中证实了这一点,发现性能最佳的配置是默认的 epsilon $\epsilon = 1e-6$ 和输入门偏置初始化为 -10。

图5. 在 NVIDIA H100 GPU 上,嵌入维度为 4096、令牌数为 65,536 时的 TFLA 核函数运行时基准测试。在训练中,我们的 TFLA 核函数在较长序列上比 FlashAttention 3 更快,并且在所有序列长度上比 Mamba 2 核函数快 2 倍以上。
图5. 在 NVIDIA H100 GPU 上,嵌入维度为 4096、令牌数为 65,536 时的 TFLA 核函数运行时基准测试。在训练中,我们的 TFLA 核函数在较长序列上比 FlashAttention 3 更快,并且在所有序列长度上比 Mamba 2 核函数快 2 倍以上。

核函数基准测试

我们将 mLSTM limit_chunk 和 TFLA xl_chunk 核函数的运行时与最先进的序列建模基元实现进行了比较,包括 FlashAttention、Mamba 和 Gated Linear Attention (GLA)。

运行时基准测试。我们使用 7B 参数模型的标准嵌入维度 4096,并相应地调整每个核函数的头维度。例如,FlashAttention 使用 32 个头,头维度为 128;mLSTM 使用 16 个头,头维度为 256。我们保持总令牌数恒定为 65,536,并相应地改变序列长度和批次大小。图 5 显示了推理(仅前向传播)和训练(前向+后向传播)的运行时基准测试结果。

  • 我们的 mLSTMexp TFLA xl_chunk 核函数(具有两级序列并行)比我们的 limit_chunk 核函数快约 25%。
  • 通过对 mLSTM 输入门的针对性修改,我们节省了计算量并为 mLSTMsig 的前向传播实现了更高效的核函数实现。这使得 mLSTMsig TFLA 核函数的前向传播速度比 mLSTMexp TFLA 核函数又提升了 30% 以上。
  • 在训练中,我们的 TFLA 核函数在较长序列上比 FlashAttention 3 更快,并且在所有序列长度上都比 Mamba 2 核函数快 2 倍以上。

运行时与内存的权衡。块大小参数 L 平衡了两级序列并行之间的计算(见 3.2 节)。较小的块大小会增加内存消耗,因为有更多的块被物化到内存中,但它减少了并行部分中的二次计算 FLOPs。较大的块大小则效果相反。图 6 展示了我们的 mLSTMsig TFLA xl_chunk 核函数在这种权衡下的表现。

图6. TFLA 前向后向传播的内存与运行时权衡。我们展示了嵌入维度为 4096(8 个头,头维度为 512)、序列长度为 8192、批次大小为 8 的 mLSTMsig。通过改变块大小参数,我们的 TFLA 核函数可以有效地平衡内存与运行时。
图6. TFLA 前向后向传播的内存与运行时权衡。我们展示了嵌入维度为 4096(8 个头,头维度为 512)、序列长度为 8192、批次大小为 8 的 mLSTMsig。通过改变块大小参数,我们的 TFLA 核函数可以有效地平衡内存与运行时。

理论运行时与算术强度

图 6 的经验观察表明,存在一个使运行时最小化的最优块大小(在 128 和 256 之间)。在附录 G 中,我们通过考虑 FLOPs、内存操作和硬件加速器规格,计算了 TFLA 的理论运行时最优块大小和依赖于块大小的算术强度。

我们发现理论上的运行时最优值超过了经验观察到的值(见附录 G.1 中的图 20),这表明我们当前的核函数实现可能尚未完全发挥可用的性能潜力。

我们在附录 F 和 G 中详细分析了块大小参数 L 对运行时、FLOPs 和算术强度的影响,总结如下:
1. 块大小 L 在运行时和 GPU 内存使用之间进行权衡。【图 6】
2. L 决定了总计算量(FLOPs):L = 1 对应循环公式,而 L = T 对应并行公式。【图 18】
3. 存在一个最优块大小 L ∈ [1, T],可以最小化总 FLOP 数量。【公式 (113),图 18,图 19】
4. 增加 L 会提高 TFLA 核函数的算术强度。【公式 (119),图 21】
5. 块大小决定了核函数在给定硬件上是内存密集型还是计算密集型。【图 22,图 21】
6. 仅凭 FLOPs/s 指标可能具有误导性;应根据总运行时来选择最优块大小。【图 23,图 20】
7. 运行时最优的块大小与头维度的平方根以及加速器的计算强度成正比。【图 24,图 25】
8. 更新的硬件代系需要更大的块大小才能接近峰值性能。【图 25,图 20】

结论与未来工作

本文通过 TiledFlashLinearAttention (TFLA) 引入了一种具有两级序列并行机制的线性 RNN 和 mLSTM 核函数算法。我们为带指数输入门(mLSTMexp)的 mLSTM 开发的 TFLA 核函数达到了业界领先的执行速度,同时保留了在 GPU 内存消耗和运行时间之间进行权衡的灵活性。为了进一步提升核函数运行速度,我们提出了 mLSTMsig,这是一种采用 sigmoid 输入门的 mLSTM 变体,它减少了计算量并提高了速度。我们的实验表明,两种 mLSTM 变体在语言建模任务上表现同样出色。

尽管我们通过基于经验性传递行为分析的精心门控初始化增强了训练的稳定性,但未来的工作可以更深入地探索由核函数实现中的数值误差引起的不稳定性问题。最后,正如我们的理论运行时分析所揭示的,用于优化 FlashAttention 的编程技术和硬件特性【索引[58],Flashattention-3: Fast and accurate attention with asynchrony and low-precision,2024,arXiv】也可以应用于我们的 TFLA 算法,以在新一代硬件上接近峰值性能。这使我们相信,TFLA 有潜力成为未来长上下文语言模型的基础构建模块。

附录

相关工作

与 Flash Attention 和 Flash Linear Attention 的关系

TFLA 的核心思想。Tiled Flash Linear Attention (TFLA) 结合了为改善工作分区而对注意力矩阵的一个序列维度进行分块(tiling)的思想【索引[9],FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning,2024,ICLR】和将序列划分为块(chunks)的思想【索引[81],Gated linear attention transformers with hardware-efficient training,2024,ICML】。这两个思想共同构成了 TFLA 的两级序列并行机制。

FlashAttention。FlashAttention【索引[11],FlashAttention: Fast and memory-efficient exact attention with IO-awareness,2022,NeurIPS】是 Vaswani 等人【索引[74],Attention is all you need,2017,NeurIPS】提出的 softmax 注意力的一种 I/O 感知实现。它利用分块(tiling)思想来减少 GPU 高带宽内存(HBM)和 GPU 片上 SRAM 之间的内存读写次数。通过这种方式,二次方的注意力矩阵 $QK^T$ 从不物化到 HBM 中,从而将内存需求从序列长度的二次方降低到线性,并由于减少了内存 I/O 成本而显著加快了核函数的速度。然而,计算复杂度仍然是序列长度的二次方。FlashAttention 2【索引[9],FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning,2024,ICLR】通过在序列维度上并行化注意力计算,改进了工作分区,这补充了在批次和头维度上的朴素并行化。FlashAttention 3【索引[58],Flashattention-3: Fast and accurate attention with asynchrony and low-precision,2024,arXiv】利用了最新 GPU 代(如 NVIDIA Hopper GPU)的新硬件特性,如 FP8 精度或利用张量核心和张量内存加速器(TMA)的异步性来加速 FlashAttention。TFLA 同样是 I/O 感知的,并作为其第二级序列并行机制,在块内 $QK^T$ 矩阵的一个序列维度上进行并行化。新的硬件特性也将加速未来的 TFLA 实现。

FlashLinearAttention。FlashLinearAttention (FLA)【索引[81],Gated linear attention transformers with hardware-efficient training,2024,ICML;索引[77],FLA: A Triton-based library for hardware-efficient implementations of linear attention mechanism,2024,GitHub】利用了线性注意力可以被解释为线性 RNN【索引[21],Transformers are RNNs: Fast autoregressive transformers with linear attention,2020,ICML】这一事实。然后,它利用线性 RNN 的分块并行公式【索引[19],Transformer quality in linear time,2022,ICML;索引[64],Retentive network: A successor to transformer for large language models,2023,ArXiv】来实现高效的核函数,以块(chunks)为单位处理序列。具体来说,Yang 等人【索引[81],Gated linear attention transformers with hardware-efficient training,2024,ICML】提出了两种 FLA 变体:一种是在 HBM 中物化中间状态的版本,另一种是非物化版本。物化版本包含两个核函数:第一个是循环核函数,物化每个块的第一个中间状态。第二个核函数则并行处理所有块,并计算块内的输出。非物化版本由 Qin 等人【索引[41],Lightning attention-2: A free lunch for handling unlimited sequence lengths in large language models,2024,arXiv】同时提出,它不采用序列维度上的并行化,而是以块为单位顺序处理输入。TFLA 采用了序列分块的思想作为其第一级序列并行。

其他相关工作

其他硬件感知优化。优化的、硬件感知的实现使得探索新的基元或新的模型架构成为可能。FlashRNN【索引[38],FlashRNN: I/O-aware optimization of traditional RNNs on modern hardware,2025,ICLR】引入了一个 I/O 感知优化的 CUDA 核函数框架,以简化对传统、不可并行化 RNN 的研究。Mamba【索引[15],Mamba: Linear-time sequence modeling with selective state spaces,2024,ICLR】通过其优化的 CUDA 核函数中的高效并行扫描算法,支持了大规模语言建模实验【索引[76],An empirical study of mamba-based language models,2024,arXiv】。FlashFFTConv【索引[12],FlashFFTConv: Efficient convolutions for long sequences with tensor cores,2024,ICLR】通过减少 I/O 和利用专门的矩阵乘法单元,为现代硬件提供了 FFT 卷积的高效实现。DeltaNet【索引[82],Parallelizing linear transformers with the delta rule over sequence length,2024,NeurIPS;索引[80],Gated delta networks: Improving mamba2 with delta rule,2024,arXiv】为使用 delta 规则训练线性 Transformer 引入了一种高效算法,这使得 DeltaNet 能够扩展到标准的语言建模设置中【索引[53],Linear transformers are secretly fast weight programmers,2021,ICML】。我们的 TFLA 核函数算法提供了一种有效的方法来平衡基于分块并行公式的线性 RNN 核函数的运行时间和内存,为更大规模的模型训练铺平了道路。

线性 RNN 的门控机制。许多不同的线性 RNN 门控技术已经被探索【索引[64],Retentive network: A successor to transformer for large language models,2023,ArXiv;索引[5],xLSTM: Extended long short-term memory,2024,NeurIPS;索引[81],Gated linear attention transformers with hardware-efficient training,2024,ICML;索引[15],Mamba: Linear-time sequence modeling with selective state spaces,2024,ICLR;索引[10],Transformers are SSMs: Generalized models and efficient algorithms through structured state space duality,2024,ICML;索引[65],You Only Cache Once: Decoder-decoder architectures for language models,2024,NeurIPS;索引[42],HGRN2: Gated linear RNNs with state expansion,2024,CoLM;索引[32],Resurrecting recurrent neural networks for long sequences,2023,arXiv;索引[22],GateLoop: Fully data-controlled linear recurrence for sequence modeling,2023,ArXiv;索引[35],Eagle and finch: Rwkv with matrix-valued states and dynamic recurrence,2024,arXiv】。我们提出了 mLSTMsig,一种带有 sigmoid 输入门的 mLSTM 变体,类似于原始的 LSTM【索引[17],Untersuchungen zu dynamischen neuronalen Netzen,1991,Master's thesis;索引[18],Long short-term memory,1997,Neural Computation】,并对其传递行为进行了实证分析。

mLSTM 应用。作为带有矩阵内存的 xLSTM,mLSTM 已经被应用于语言建模之外的多个领域。例如,Alkin 等人【索引[2],Vision-LSTM: xLSTM as generic vision backbone,2025,ICLR】发现 mLSTM 可以作为计算机视觉架构的通用骨干网络。在机器人领域,mLSTM 架构已被探索作为大型循环动作模型,利用了 mLSTM 高效的线性时间推理复杂性【索引[57],A large recurrent action model: xLSTM enables fast inference for robotics tasks,2025,arXiv】,并作为模仿学习策略的骨干网络【索引[20],X-IL: Exploring the design space of imitation learning policies,2025,arXiv】。在生物和化学序列领域,mLSTM 已被用于生成建模、表示学习和上下文学习【索引[55],Bio-xLSTM: Generative modeling, representation and in-context learning of biological and chemical sequences,2025,ICLR】。最后,mLSTM 也在时间序列预测领域得到了检验【索引[1],xLSTMTime: Long-term time series forecasting with xLSTM,2024,AI】,其中 sLSTM【索引[5],xLSTM: Extended long short-term memory,2024,NeurIPS】也得到了应用【索引[24],xlstmmixer: Multivariate time series forecasting by mixing via scalar memories,2024,arXiv;索引[23],Unlocking the power of lstm for long term time series forecasting,2025,arXiv】。这些及其他应用将受益于我们的 TFLA 核函数,它为 mLSTM 提供了高效且可扩展的实现,可以轻松集成到现有模型中。我们相信,我们的 TFLA mLSTM 核函数将促进 mLSTM 在其他应用领域的普及。

TFLA 在其他线性 RNN 上的应用

概述。我们已将 TFLA 应用于具有标量头级门控的线性 RNN,如 mLSTMexp 和 mLSTMsig。本节将展示如何将 TFLA 应用于其他线性 RNN,但实现将留作未来工作。

具有标量头级门控的线性 RNN。只需稍作修改,TFLA 即可应用于其他具有标量头级门控的线性 RNN,例如 Retention【索引[64],Retentive network: A successor to transformer for large language models,2023,ArXiv】、Gated Retention【索引[65],You Only Cache Once: Decoder-decoder architectures for language models,2024,NeurIPS】、Simple GLA【索引[81],Gated linear attention transformers with hardware-efficient training,2024,ICML】或 Mamba 2【索引[10],Transformers are SSMs: Generalized models and efficient algorithms through structured state space duality,2024,ICML】。通常,这些线性 RNN 的 TFLA 实现可以通过修改遗忘门和输入门的参数化(基于 TFLA mLSTMsig)来获得。由于上述模型都没有 mLSTM 意义上的输入门,我们将 TFLA mLSTMsig 的输入门固定为 1(即 $\sigma(\tilde{i}_t) = 1$)。然后,对于 Simple GLA 或 Retention,我们保留头级标量 sigmoid 遗忘门 $\sigma(\tilde{f}_t)$ 或将其设置为一个恒定的衰减参数 $\sigma(\tilde{f}_t) = \gamma$。要实现 Gated Retention——它引入了一个由温度参数 $\tau$ 调制的 sigmoid 遗忘门——我们需要修改遗忘门的激活函数以包含这个温度项。除了通过步长参数 $\Delta$ 将遗忘门参数 A 与键 $k_t$(由参数 B 表示)耦合外,Mamba 2 也采用了公式 (20) 中的线性注意力结构。在此公式中,Mamba 的符号 C、B 和 X 分别对应于我们的查询 Q、键 K 和值 V。

具有 Delta 规则的线性 RNN。最近,具有非对角状态转移矩阵的线性 RNN 因其增强的表达能力而变得流行,使其能够解决综合状态跟踪任务【索引[14],Unlocking state-tracking in linear RNNs through negative eigenvalues,2025,ICLR;索引[61],Deltaproduct: Increasing the expressivity of deltanet through products of householders,2025,ICLR 2025 Workshop;索引[36],Rwkv-7 "goose" with expressive dynamic state evolution,2025,arXiv;索引[30],Fixed-point rnns: From diagonal to dense in a few iterations,2025,arXiv】。实现这类非对角线性 RNN 的一种高效方法是 (Gated) DeltaNet【索引[82],Parallelizing linear transformers with the delta rule over sequence length,2024,NeurIPS;索引[80],Gated delta networks: Improving mamba2 with delta rule,2024,arXiv】,它为使用 delta 规则的线性 Transformer 引入了一种硬件高效的算法。DeltaNet 高效实现的核心是 delta 规则的分块并行公式,该公式使用 Flash Linear Attention 的原语实现。在我们的符号体系中,DeltaNet 的分块并行公式由块间递推给出:

$$\underset{(d_{qk} \times d_{hv})}{\boldsymbol{C}_k} = \underset{(d_{qk} \times d_{hv})}{\boldsymbol{C}_{k-1}} + \underset{(d_{qk} \times L_{kv})}{\boldsymbol{K}^{(k)\top}} \underbrace{\left( \underset{(L_{kv} \times d_{hv})}{\boldsymbol{U}^{(k)}} - \underset{(L_{kv} \times d_{qk})}{\boldsymbol{W}^{(k)}} \underset{(d_{qk} \times d_{hv})}{\boldsymbol{C}_{k-1}} \right)}_{\boldsymbol{V}^{*(k)}},$$

以及块间和块内贡献的组合:

$$\underset{(L_{hq} \times d_{hv})}{\mathbf{H}^{(k)}} = \underset{(L_{hq} \times d_{qk})}{\boldsymbol{Q}^{(k)}} \underset{(d_{qk} \times d_{hv})}{\boldsymbol{C}_{k-1}} + \left( \underset{(L_{hq} \times d_{qk})}{\boldsymbol{Q}^{(k)}} \underset{(d_{qk} \times L_{kv})}{\boldsymbol{K}^{(k)\top}} \right) \underbrace{\left( \underset{(L_{kv} \times d_{hv})}{\boldsymbol{U}^{(k)}} - \underset{(L_{kv} \times d_{qk})}{\boldsymbol{W}^{(k)}} \underset{(d_{qk} \times d_{hv})}{\boldsymbol{C}_{k-1}} \right)}_{\boldsymbol{V}^{*(k)}},$$

为清晰起见,我们省略了 DeltaNet 中应用于 $QK^T$ 的掩码 M。关于矩阵 $U^{(k)}$ 和 $W^{(k)}$ 的定义,请参考 Yang 等人【索引[82],Parallelizing linear transformers with the delta rule over sequence length,2024,NeurIPS】。通过定义新的值 $V^{(k)} = U^{(k)} - W^{(k)}C_{k-1}$,我们可以恢复 TFLA 用于块内并行前向传播的核心公式(见公式 (20)),并对其应用 3.2 节中的 TFLA 分块策略。为了计算矩阵 $U^{(k)}$、$W^{(k)}$ 和新的值 $V^{(k)}$,可以遵循 DeltaNet 中的相同策略,即在循环核函数之前,在一个单独的核函数中计算 $U^{(k)}$ 和 $W^{(k)}$。然后,在循环核函数中计算并存储新的值 $V^{*(k)}$ 和内存状态 $C_k$ 到 HBM。我们得出结论,TFLA 可以应用于任何遵循或可以重新表述为公式 (20) 的简化分块并行形式的线性 RNN。

mLSTM 扩展公式

完全并行公式

公式描述。对于并行公式,假设所有输入一次性可用。然后,可以将查询、键和值 $q_t, k_t, v_t$ 堆叠成矩阵 $Q, K \in \mathbb{R}^{T \times d_{qk}}$, $V \in \mathbb{R}^{T \times d_{hv}}$,以便并行计算所有隐藏状态 $H \in \mathbb{R}^{T \times d_{hv}}$,使用以下方程:

$$\widetilde{\mathbf{D}}=\log \mathbf{F}+\widetilde{\mathbf{I}}$$ $$m = \max_{j} \widetilde{\mathbf{D}}_{ij},$$ $$\mathbf{D} = \exp(\widetilde{\mathbf{D}} - m)$$ $$\mathbf{S} = \frac{1}{\sqrt{d_{qk}}} \boldsymbol{Q} \boldsymbol{K}^{\top}$$ $$\overline{\mathbf{S}}=\mathbf{S} \odot \mathbf{D}$$ $$n=\max(|\overline{\mathbf{S}}\ \mathbf{1}|, \exp(-m))$$ $$\mathbf{H} = \left( \overline{\mathbf{S}} \odot \left( n^{-1} \right) \right) \boldsymbol{V},$$

其中 $1 \in \mathbb{R}^T$ 是一个全为 1 的向量。对数遗忘门激活矩阵 $\log F \in \mathbb{R}^{T \times T}$ 计算如下:

$$\begin{aligned} \log \mathbf{F}_{i j}=\left\{\begin{array}{ll} -\infty & \text { for } i<j \\ 0 & \text { for } i=j \\ \log \left(\prod_{k=j+1}^{i} \sigma\left(\tilde{\mathrm{f}}_{k}\right)\right)=\sum_{k=j+1}^{i} \log \sigma\left(\tilde{\mathrm{f}}_{k}\right) & \text { for } i>j \end{array}\right. \end{aligned}$$

类似地,输入门预激活矩阵 $\tilde{I} \in \mathbb{R}^{T \times T}$ 由以下公式给出:

$$\begin{aligned} \tilde{\mathbf{I}}_{i j}=\left\{\begin{array}{ll}0 & \text { for } i<j \\ \tilde{\mathrm{i}}_{j} & \text { for } i \geqslant j\end{array} .\right. \end{aligned}$$ <p>与循环公式的对比。请注意,与循环公式不同,在并行公式中,状态 $C_t$ 不会被物化,即不会显式计算。这样做的代价是需要计算二次方大小的矩阵 $D, S \in \mathbb{R}^{T \times T}$,导致整体计算复杂度随序列长度 T 呈二次方增长。

详细的分块并行公式

分块门控。本节我们提供第 2.2 节中 mLSTM 分块并行公式的更详细方程。给定对数遗忘门 $\bar{f}^{(k)} = \log \sigma(\tilde{f}^{(k)}) \in \mathbb{R}^L$ 和输入门 $\bar{i}^{(k)} = \log \exp(\tilde{i}^{(k)}) \in \mathbb{R}^L$,我们可以计算对数分块门如下:

$$ \mathrm{g}_{k}=\operatorname{sum}\left(\overline{\mathbf{f}}^{(k)}\right)=\sum_{i=1}^{L} \overline{\mathrm{f}}_{i}^{(k)} \in \mathbb{R}, $$ $$\mathbf{b}_k=\operatorname{cumsum}\left(\overline{\mathbf{f}}^{(k)}\right) \in \mathbb{R}^L, \text { with } \mathrm{b}_{k, j}=\sum_{i=1}^j \overline{\mathrm{f}}_i^{(k)} \text { for } j=1,2, \ldots, L$$ $$\mathbf{a}_k = \text{rev\_cumsum}\left(\bar{\mathbf{f}}^{(k)}[1:]\right) + \bar{\mathbf{i}}^{(k)} \in \mathbb{R}^L, \text{ with } \quad \mathrm{a}_{k,j} = \sum_{i=j+1}^{L} \bar{\mathrm{f}}_i^{(k)} + \bar{\mathrm{i}}_j^{(k)} \text{ for } j = 1, 2, \dots, L,$$

其中 [1:] 表示(在 numpy 表示法中)排除了第一个索引,正如在公式 (40) 的求和符号中所做的那样。此外,在图 8 中我们图示了分块门的计算,并在图 7 中展示了计算分块门的 PyTorch 代码片段。

图7. 用于计算 mLSTMexp 的对数分块门的 PyTorch 函数。
图7. 用于计算 mLSTMexp 的对数分块门的 PyTorch 函数。
图8. 分块门计算图示。
图8. 分块门计算图示。

$$\begin{aligned} \log \mathbf{b}_1 = \begin{pmatrix} \mathrm{f}_1 \\ \mathrm{f}_1 + \mathrm{f}_2 \\ \mathrm{f}_1 + \mathrm{f}_2 + \mathrm{f}_3 \\ \mathrm{f}_1 + \mathrm{f}_2 + \mathrm{f}_3 + \mathrm{f}_4 \end{pmatrix} \quad \log \mathbf{b}_2 = \begin{pmatrix} \mathrm{f}_5 \\ \mathrm{f}_5 + \mathrm{f}_6 \\ \mathrm{f}_5 + \mathrm{f}_6 + \mathrm{f}_7 \\ \mathrm{f}_5 + \mathrm{f}_6 + \mathrm{f}_7 + \mathrm{f}_8 \end{pmatrix} \quad \log \mathbf{b}_3 = \begin{pmatrix} \mathrm{f}_9 \\ \mathrm{f}_9 + \mathrm{f}_{10} \\ \mathrm{f}_9 + \mathrm{f}_{10} + \mathrm{f}_{11} \\ \mathrm{f}_9 + \mathrm{f}_{10} + \mathrm{f}_{11} + \mathrm{f}_{12} \end{pmatrix} \end{aligned}$$ $$\begin{aligned} \log \mathbf{a}_1 = \begin{pmatrix} \mathrm{f}_2 + \mathrm{f}_3 + \mathrm{f}_4 \\ \mathrm{f}_3 + \mathrm{f}_4 \\ \mathrm{f}_4 \\ 0 \end{pmatrix} + \begin{pmatrix} i_1 \\ i_2 \\ i_3 \\ i_4 \end{pmatrix} \quad \log \mathbf{a}_2 = \begin{pmatrix} \mathrm{f}_6 + \mathrm{f}_7 + \mathrm{f}_8 \\ \mathrm{f}_3 + \mathrm{f}_8 \\ \mathrm{f}_8 \\ 0 \end{pmatrix} + \begin{pmatrix} i_5 \\ i_6 \\ i_3 \\ i_8 \end{pmatrix} \quad \log \mathbf{a}_3 = \begin{pmatrix} \mathrm{f}_9 + \mathrm{f}_{10} + \mathrm{f}_{12} \\ \mathrm{f}_{11} + \mathrm{f}_{12} \\ \mathrm{f}_{12} \\ 0 \end{pmatrix} + \begin{pmatrix} i_9 \\ i_{10} \\ i_{11} \\ i_{12} \end{pmatrix} \end{aligned}$$ $$ \log \mathrm{g}_1 = \mathrm{f}_1 + \mathrm{f}_2 + \mathrm{f}_3 + \mathrm{f}_4 \quad \log \mathrm{g}_2 = \mathrm{f}_5 + \mathrm{f}_6 + \mathrm{f}_7 + \mathrm{f}_8 \quad \log \mathrm{g}_3 = \mathrm{f}_9 + \mathrm{f}_{10} + \mathrm{f}_{11} + \mathrm{f}_{12} $$

块间循环贡献。块间循环由以下公式给出:

$$ \boldsymbol{C}_{k} = \exp \left( \mathrm{g}_{k} + m_{k-1}^{(\text{inter})} - m_{k}^{(\text{inter})} \right) \boldsymbol{C}_{k-1} + \left( \exp \left( \mathbf{a}_{k} - m_{k}^{(\text{inter})} \right) \odot \boldsymbol{K}^{(k)} \right)^{\top} \boldsymbol{V}^{(k)} $$ $$\boldsymbol{n}_{k}=\exp \left(\mathrm{g}_{k}+m_{k-1}^{(\text {inter })}-m_{k}^{(\text {inter })}\right) \boldsymbol{n}_{k-1}+\left(\exp \left(\mathbf{a}_{k}-m_{k}^{(\text {inter })}\right) \odot \boldsymbol{K}^{(k)}\right)^{\top} \mathbf{1} .$$

我们可以将块间循环简化为:

$$\begin{aligned} \begin{aligned} \boldsymbol{C}_{k} &=\bar{\text{g}}_{k} \boldsymbol{C}_{k-1}+\left(\overline{\mathbf{a}}_{k} \odot \boldsymbol{K}^{(k)}\right)^{\top} \boldsymbol{V}^{(k)} &=\bar{\text{g}}_{k} \boldsymbol{C}_{k-1}+\overline{\boldsymbol{K}}^{(k) \top} \boldsymbol{V}^{(k)} \\ \boldsymbol{n}_{k} &=\bar{\text{g}}_{k} \boldsymbol{n}_{k-1}+\left(\overline{\mathbf{a}}_{k} \odot \boldsymbol{K}^{(k)}\right)^{\top} \mathbf{1} &=\bar{\text{g}}_{k} \boldsymbol{n}_{k-1}+\overline{\boldsymbol{K}}^{(k) \top} \boldsymbol{V}^{(k)} . \end{aligned} \end{aligned}$$

其中运行的最大值状态已整合到门控中。

块内并行贡献。循环部分之后是块内并行贡献,由以下公式给出:

$$\begin{aligned} \widetilde{\mathbf{D}}^{(k)}=\begin{cases}-\infty & \text { for } i<j \\ \mathbf{b}_k-\mathbf{b}_k^{\top}+\bar{\mathbf{i}}^{(k)^{\top}} & \text { for } i \geqslant j\end{cases} \end{aligned}$$ $$\boldsymbol{m}_{k}^{(\text {intra) }}=\max _{j} \widetilde{\mathbf{D}}_{i j}^{(k)}$$ $$\mathbf{D}^{(k)}=\exp (\widetilde{\mathbf{D}}^{(k)}-\boldsymbol{m}_{k}^{(\text {intra })})$$ $$\mathbf{S}^{(k)}=\frac{1}{\sqrt{d_{qk}}}\boldsymbol{Q}^{(k)}\boldsymbol{K}^{(k)\top}$$ $$\overline{\mathbf{S}}^{(k)}=\mathbf{S}^{(k)} \odot \mathbf{D}^{(k)}$$ <p>其中 exp 是逐元素作用的。

输出计算。块内并行部分的贡献 $H^{(k)}{intra}$ 与块间循环部分的贡献 $H^{(k)}$(见图 1):}$ 相结合,得到每个块 k 的隐藏状态 $H^{(k)

$$m_k^{(\text{combine})} = \max \left\{ \mathbf{b}_k + m_{k-1}^{(\text{inter})}, m_k^{(\text{intra})} \right\}$$ $$\mathbf{H}_{\text {inter }}^{(k)}=\left(\exp \left(\mathbf{b}_{k}+m_{k-1}^{(\text {inter })}-m_{k}^{(\text {combine })}\right) \odot \frac{\boldsymbol{Q}^{(k)}}{\sqrt{d_{q k}}}\right) \boldsymbol{C}_{k-1}$$ $$= \left( \mathbf{\bar{b}}_k \odot \frac{\boldsymbol{Q}^{(k)}}{\sqrt{d_{qk}}} \right) \boldsymbol{C}_{k-1}$$ $$=\overline{\boldsymbol{Q}}^{(k)} \boldsymbol{C}_{k-1}$$ $$\mathbf{H}_{\text{intra}}^{(k)}=\overline{\mathbf{S}}^{(k)} \boldsymbol{V}^{(k)}$$ $$\begin{aligned} \begin{aligned} \mathbf{H}^{(k)} & = \frac{\left(\overline{\mathbf{b}}_{k} \odot\left(\boldsymbol{Q}^{(k)} / \sqrt{d_{q k}}\right)\right) \boldsymbol{C}_{k-1}+\overline{\mathbf{S}}^{(k)} \boldsymbol{V}^{(k)}}{\max \left\{\left|\left(\overline{\mathbf{b}}_{k} \odot\left(\boldsymbol{Q}^{(k)} / \sqrt{d_{q k}}\right)\right) \boldsymbol{n}_{k-1}+\overline{\mathbf{S}}^{(k)} \mathbf{1}\right|, \exp \left(-\boldsymbol{m}_{k}^{(\text {combine })}\right)\right\}} \\ & = \frac{\overline{\boldsymbol{Q}}^{(k)} \boldsymbol{C}_{k-1}+\overline{\mathbf{S}}^{(k)} \boldsymbol{V}^{(k)}}{\max \left\{\left|\overline{\boldsymbol{Q}}^{(k)} \boldsymbol{n}_{k-1}+\overline{\mathbf{S}}^{(k)} \mathbf{1}\right|, \exp \left(-\boldsymbol{m}_{k}^{(\text {combine })}\right)\right\}} \\ & = \left(\overline{\boldsymbol{Q}}^{(k)} \boldsymbol{C}_{k-1}+\overline{\mathbf{S}}^{(k)} \boldsymbol{V}^{(k)}\right) / \mathbf{h}_{\mathrm{denom}}^{(k)} . \end{aligned} \end{aligned}$$

分块并行反向传播

概述。本节我们提供分块并行 mLSTM 反向传播的详细描述。

通过归一化状态的梯度。遵循 Sun 等人【索引[64],Retentive network: A successor to transformer for large language models,2023,ArXiv】的做法,我们不计算通过归一化状态 n 的梯度。由于 mLSTM 单元隐藏状态 H 上的 Layer- 或 RMS-Norm,这些梯度会相互抵消,因为归一化状态在嵌入或特征维度(即归一化维度)上是常数。

块间循环反向传播。给定来自下一个块的输入内存单元状态梯度 $\delta C_k$ 和块 k 的隐藏状态输出梯度 $\delta H^{(k)}$,我们可以计算块间循环反向传播。块间循环部分的查询、键和值梯度 $\delta Q^{(k)}{inter}$, $\delta K^{(k)}$ 计算如下:}$ 和 $\delta V^{(k)}_{inter

$$\delta \widetilde{\mathbf{H}}^{(k)}=\frac{\delta \mathbf{H}^{(k)}}{\mathbf{h}_{\text {denom }}^{(k)}}$$ $$\delta \boldsymbol{V}_{\text {inter }}^{(k)}=\overline{\boldsymbol{K}}^{(k)} \delta \boldsymbol{C}_{k}$$ $$\delta\overline{\boldsymbol{K}}^{(k)}=\boldsymbol{V}^{(k)}\delta\boldsymbol{C}_k^\top$$ $$\delta \boldsymbol{K}_{\mathrm{inter}}^{(k)}=\delta \overline{\boldsymbol{K}}^{(k)} \odot \overline{\boldsymbol{a}}_{k} \mathbf{1}^{\top}$$ $$\delta \overline{\boldsymbol{Q}}^{(k)}=\delta \widetilde{\mathbf{H}}^{(k)} \boldsymbol{C}_{k-1}^{\top}$$ $$\delta \boldsymbol{Q}_{\text {inter }}^{(k)}=\frac{1}{\sqrt{d_{q k}}} \delta \overline{\boldsymbol{Q}}^{(k)} \odot \overline{\boldsymbol{b}}_{k} \mathbf{1}^{\top}$$

内存单元状态梯度 $\delta C_{k-1}$ 接收来自下一个时间步的贡献 $\delta C^{(rec)}{k-1}$ 和输出的贡献 $\delta C^{(out)}$。它们由以下公式给出:

$$\begin{aligned} \begin{aligned} \delta \boldsymbol{C}_{k-1} & =\delta \boldsymbol{C}_{k-1}^{(\mathrm{rec})}+\delta \boldsymbol{C}_{k-1}^{(\mathrm{out})} \\ & =\overline{\mathrm{g}} \odot \delta \boldsymbol{C}_k+\overline{\boldsymbol{Q}}^{(k)^{\top}} \delta \widetilde{\mathbf{H}}^{(k)} \end{aligned} \end{aligned}$$

最后,我们可以计算块 k 的累积门梯度 $\delta g_k, \delta a_k$ 和 $\delta b_k$ 如下:

$$\delta \overline{\text{g}}_{k}=\mathbf{1}^{\top}\left(\boldsymbol{C}_{k-1} \odot \delta \boldsymbol{C}_{k}\right) \mathbf{1}$$ $$\delta \mathrm{g}_{k}=\delta \overline{\mathrm{g}}_{k} \odot \overline{\mathrm{g}}_{k}$$ $$\delta\overline{\boldsymbol{a}}_{k}=(\delta\overline{\boldsymbol{K}}^{(k)}\odot\boldsymbol{K}^{(k)})\boldsymbol{1}$$ $$\delta \boldsymbol{a}_{k}=\delta \overline{\boldsymbol{a}}_{k} \odot \overline{\boldsymbol{a}}_{k}$$ $$\delta \overline{\boldsymbol{b}}_{k} = (\delta \overline{\boldsymbol{Q}}^{(k)} \odot \frac{\boldsymbol{Q}^{(k)}}{\sqrt{d_{qk}}}) \mathbf{1}$$ $$\delta \boldsymbol{b}_{k}=\delta \overline{\boldsymbol{b}}_{k} \odot \overline{\boldsymbol{b}}_{k} .$$

块内并行反向传播。给定 mLSTM 隐藏状态输出梯度 $\delta H^{(k)}$,块内查询、键和值的梯度 $\delta Q^{(k)}{intra}, \delta K^{(k)}$ 计算如下:}$ 和 $\delta V^{(k)}_{intra

$$\delta \widetilde{\mathbf{H}}^{(k)}=\frac{\delta \mathbf{H}^{(k)}}{\mathbf{h}_{\text {denom }}^{(k)}}$$ $$\mathbf{S}^{(k)}=\frac{1}{\sqrt{d_{q k}}} \boldsymbol{Q}^{(k)} \boldsymbol{K}^{(k)\top}$$ $$\overline{\mathbf{S}}^{(k)}=\mathbf{S}^{(k)} \odot \mathbf{D}^{(k)}$$ $$\delta \boldsymbol{V}_{\mathrm{intra}}^{(k)}=\overline{\mathbf{S}}^{(k)\top} \delta \widetilde{\mathbf{H}}^{(k)}$$ $$\delta \overline{\mathbf{S}}^{(k)}=\delta \widetilde{\mathbf{H}}^{(k)} \mathbf{V}^{(k) \top}$$ $$\delta \mathbf{S}^{(k)}=\delta \overline{\mathbf{S}}^{(k)} \odot \mathbf{D}^{(k)}$$ $$\delta \boldsymbol{Q}_{\mathrm{intra}}^{(k)}=\frac{1}{\sqrt{d_{q k}}} \delta \mathbf{S}^{(k)} \boldsymbol{K}^{(k)}$$ $$\delta \boldsymbol{K}_{\text {intra }}^{(k)}=\frac{1}{\sqrt{d_{q k}}} \delta \boldsymbol{S}^{(k) \top} \boldsymbol{Q}^{(k)}$$

为了计算累积的块内门梯度,我们计算通过门矩阵 $D^{(k)}$ 的梯度,该矩阵由累积遗忘门计算得出:

$$\boldsymbol{b}_{k}^{(q)}=\operatorname{cumsum}\left(\overline{\mathbf{f}}_{q}^{(k)}\right) \in \mathbb{R}^{L_{q}}$$ $$\boldsymbol{b}_{k}^{(k v)}=\operatorname{cumsum}\left(\overline{\mathbf{f}}_{k v}^{(k)}\right) \in \mathbb{R}^{L_{k v}},$$

其中我们使用对数遗忘门 $\bar{f} = \log \sigma(\tilde{f})$。我们分别用 $L_q$ 和 $L_{kv}$ 表示查询和键值维度。省略掩码操作,我们计算门矩阵如下:

$$\mathbf{D}^{(k)}=\boldsymbol{b}_k^{(q)} \mathbf{1}_{k v}^{\top}-\mathbf{1}_q \boldsymbol{b}_k^{(k v) \top}+\mathbf{1}_q \overline{\mathbf{i}}_{k v}^{(k) \top},$$

其中 $1_q \in \mathbb{R}^{L_q}$ 和 $1_{kv} \in \mathbb{R}^{L_{kv}}$ 是用于表示广播操作的全一向量,而 $i^{(k)}{kv} \in \mathbb{R}^{L$ 是块 k 的对数输入门。梯度计算如下:}

$$\delta \mathbf{D}^{(k)}=\delta \overline{\mathbf{S}}^{(k)} \odot \mathbf{S}^{(k)}$$ $$\delta \boldsymbol{b}_{k}^{(q)}=\delta \mathbf{D}^{(k)} \mathbf{1}_{k v}$$ $$\delta \boldsymbol{b}_{k}^{(k v)}=-\delta \mathbf{D}^{(k) \top} \mathbf{1}_{q}$$ $$\delta \bar{\mathbf{i}}_{k v}^{(k)}=\delta \mathbf{D}^{(k) \top} \mathbf{1}_{q} .$$

合并的输入和门梯度。块内和块间的梯度通过将它们的贡献相加来合并。对于查询、键和值的梯度,这得到:

$$\delta Q^{(k)} = \delta Q_{\text{inter}}^{(k)} + \delta Q_{\text{intra}}^{(k)}$$ $$\delta \boldsymbol{K}^{(k)}=\delta \boldsymbol{V}_{\text {inter }}^{(k)}+\delta \boldsymbol{K}_{\text {intra }}^{(k)}$$ $$\delta \boldsymbol{V}^{(k)}=\delta \boldsymbol{V}_{\text {inter }}^{(k)}+\delta \boldsymbol{V}_{\text {intra }}^{(k)}$$

输入和遗忘门的梯度 $\bar{i}^{(k)}$ 和 $\bar{f}^{(k)}$ 可以从累积门梯度 $\delta g_k, \delta b_k$ 和 $\delta a_k$ 通过以下等式计算得出:

$$\delta\bar{\mathbf{f}}^{(k)}=\delta g_k$$ $$\delta \overline{\mathbf{f}}^{(k)}=\operatorname{rev\_cumsum}\left(\delta \boldsymbol{b}_{k}\right)$$ $$\delta\bar{\mathbf{f}}^{(k)} = \text{rev\_cumsum}(\delta\boldsymbol{a}_k)$$ $$\delta \bar{\dot{\mathbf{i}}}^{(k)}=\delta \boldsymbol{a}_k$$

Tiled Flash Linear Attention 扩展

分块计算

分块计算过程。对于块内隐藏状态贡献 $H_{intra}$ 在一个块内的分块计算,我们考虑矩阵 $S = [S^{(1)}, S^{(2)}]$ 和门矩阵 $D = [D^{(1)}, D^{(2)}]$ 的块,其中 $S^{(i)}, D^{(i)} \in \mathbb{R}^{B_{L_{hq}} \times B_{L_{kv}}}$。这里的上标 i 表示沿 $L_{kv}$ 维度的块索引(而非块 k 的索引)。类似地,我们考虑值矩阵 $V = [V^{(1)}, V^{(2)}]$ 的块,其中 $V^{(i)} \in \mathbb{R}^{B_{L_{kv}} \times B_{d_{hv}}}$。然后,我们累加未归一化的隐藏状态块 $H^{(i)}{intra,num} \in \mathbb{R}^{B$ 如下:}} \times B_{d_{hv}}}$ 和相应的归一化器 $l^{(i)} \in \mathbb{R}^{B_{L_{kv}}

$$m^{(1)}=\max_j \widetilde{\mathbf{D}}_{ij}^{(1)}$$ $$\boldsymbol{l}^{(1)}=\left(\mathbf{S}^{(1)} \odot \exp \left(\widetilde{\mathbf{D}}^{(1)}-\boldsymbol{m}^{(1)}\right)\right) \mathbf{1}$$ $$\mathbf{H}_{\mathrm{intra,num}}^{(1)} = (\mathbf{S}^{(1)} \odot \exp(\widetilde{\mathbf{D}}^{(1)} - \boldsymbol{m}^{(1)})) \boldsymbol{V}^{(1)}$$ $$ \boldsymbol{m}^{(2)}=\max \left(\boldsymbol{m}^{(1)}, \max _{j} \widetilde{\mathbf{D}}_{i j}^{(2)}\right) $$ $$l^{(2)} = \exp(m^{(1)} - m^{(2)}) \ l^{(1)} + (\mathbf{S}^{(2)} \odot \exp(\widetilde{\mathbf{D}}^{(2)} - m^{(2)})) \ \mathbf{1}$$ $$\mathbf{H}_{\text{intra,num}}^{(2)} = \exp(\boldsymbol{m}^{(1)} - \boldsymbol{m}^{(2)}) \mathbf{H}_{\text{intra,num}}^{(1)} + (\mathbf{S}^{(2)} \odot \exp(\widetilde{\mathbf{D}}^{(2)} - \boldsymbol{m}^{(2)})) \boldsymbol{V}^{(2)}$$

最终输出组合。在计算完这个块内部分后,我们需要再进行一次重新缩放步骤,以组合隐藏状态输出 $H^{(k)}$ 的块内和块间部分,因为 $H^{(k)}{intra}$ 和 $H^{(k)}$ 如下:}$ 是用不同的最大值状态计算的。因此,我们计算最终的隐藏状态输出 $H^{(k)

$$\begin{aligned} \begin{aligned} \boldsymbol{m}_{k}^{(\text {combine})} & =\max \left\{\mathbf{b}_{k}+m_{k-1}^{(\text {inter})}, \boldsymbol{m}_{k}^{(2)}\right\} \\ \mathbf{H}^{(k)} & =\frac{\overline{\boldsymbol{Q}}^{(k)} \boldsymbol{C}_{k-1}+\exp \left(\boldsymbol{m}_{k}^{(2)}-\boldsymbol{m}_{k}^{(\text {combine})}\right) \overline{\mathbf{S}}^{(k)} \boldsymbol{V}^{(k)}}{\max \left\{\left|\overline{\boldsymbol{Q}}^{(k)} \boldsymbol{n}_{k-1}+\exp \left(\boldsymbol{m}_{k}^{(2)}-\boldsymbol{m}_{k}^{(\text {combine})}\right) l_{k}^{(2)}\right|, \exp \left(-\boldsymbol{m}_{k}^{(\text {combine})}\right)\right\}}, \end{aligned} \end{aligned}$$

其中我们假设 $m^{(2)}k$ 是块 k 的块内计算的最后一个 $B_k$ 是相应的归一化器。}}$ 块之后的最大值,而 $l^{(2)

TFLA 前向传播

算法描述。为简化符号,我们省略了查询、键和值矩阵的 k 索引,分别表示为 $Q \in \mathbb{R}^{L_{hq} \times d_{qk}}$, $K \in \mathbb{R}^{L_{kv} \times d_{qk}}$ 和 $V \in \mathbb{R}^{L_{kv} \times d_v}$。我们利用重新加权(如附录 C.1 中讨论)来跟踪门矩阵瓦片上的最大值,这与【索引[11],FlashAttention: Fast and memory-efficient exact attention with IO-awareness,2022,NeurIPS】类似。TFLA 对于一个线程块的前向传播算法在算法 1 中描述。

实现细节。请注意,算法 1 中第 27 行的循环与第 6 行的循环相同。在这两个循环中,我们加载了矩阵 Q 的相同块。融合这两个循环可以避免两次加载这些数据。不幸的是,由于在线计算循环中(第 4 行)的最大值 $m_{old}$ 和 $m_{new}$,以及 $m^{(combine)}k$ 和 $b_k$ 对最终 $m$ 的依赖性(见附录 D.1 和 C.1),高效地融合这些循环存在问题。我们在第 4 节通过修改 mLSTM 的输入门来解决这个问题。

算法1. mLSTMexp 的 TFLA 块内前向传播 ($H^{(k)}$ 核函数)

需要:矩阵 Q ∈ R^(Lhq×dqk), K ∈ R^(Lkv×dqk), V ∈ R^(Lkv×dhv)。状态 Ck−1 ∈ R^(dqk×dv), nk−1 ∈ R^dqk。输入门和累积遗忘门向量 ik, bk ∈ R^Lhq。块大小 Bdqk, Bdhv, BLhq 和 BLkv,其中 BLhq ⩾ BLkv。块 Q 索引 iLq 和块 HV 索引 idhv。
1: 在 SRAM 中初始化 mold, mnew ∈ R^Lq 为 −∞。 ▷ 计算块内贡献
2: 在 SRAM 中初始化累加器 Hintra ∈ R^(BLhq×Bdv) 和 n(intra) ∈ R^BLhq。
3: 从 HBM 加载 b(q)k ∈ R^BLhq 到 SRAM。
4: for i = 1 to j (iLq+1)·BLhq k do BLkv
5:   在 SRAM 中初始化累加器 S(i) ∈ R^(BLhq×BLkv)。
6:   for j = 1 to l dqk/Bdqk m do
7:     从 HBM 加载 Q(j) ∈ R^(BLhq×Bdqk) 和 K(j) ∈ R^(BLkv×Bdqk) 到 SRAM。
8:     累加 S(i) += Q(j)K(j)⊤。
9:   end for
10:  从 HBM 加载 b(kv)k ∈ R^BLkv 和 i(kv)k ∈ R^BLkv 到 SRAM。
11:  计算 De(i) = b(q)k - b(kv)k⊤ + i(kv)k⊤ ∈ R^(BLhq×BLkv)。
12:  if i · BLkv ⩾ iLq · BLhq then
13:    对 De(i) 应用因果掩码。
14:  end if
15:  计算 mnew = maximum{mold, rowmax De(i)}。
16:  计算 D(i) = exp(De(i) - mnew)。
17:  计算 S(i) = 1/sqrt(dqk) * S(i) ⊙ D(i)。
18:  为块 idhv 从 HBM 加载 V(i) ∈ R^(BLkv×Bdhv) 到 SRAM。
19:  累加 Hintra = exp(mold - mnew) · Hintra + S(i)V。
20:  累加 n(intra) = exp(mold - mnew) · n(intra) + S(i)1。
21:  更新 mold = mnew。
22: end for
▷ 计算块间贡献
23: 从 HBM 加载 m(inter)k−1 ∈ R 到 SRAM。
24: 计算 m(combine)k = maximum{ b(q)k + m(i-1)(inter)k-1, mnew }。
25: 计算 bk = exp( b(q)k + mk-1(inter) - mk(combine) )
26: 为块 idhv 在 SRAM 中初始化累加器 Hinter ∈ R^(BLhq×Bdhv) 和 n(inter) ∈ R^BLhq。
▷ 注意:这与上面的内循环相同。由于最大值状态的计算,它们不能合并。
27: for j = 1 to dqk/Bdqk do
28:   为块 idhv 从 HBM 加载 Q(j) ∈ R^(BLhq×Bdqk) 和 C(j)k-1 ∈ R^(Bdqk×Bdhv) 到 SRAM。
29:   计算 Q(j) = 1/sqrt(dqk) * Q(j) ⊙ b(q)k。
30:   累加 Hinter += Q(j)C(j)k-1。
31:   加载 n(j)k-1 ∈ R^Bdqk。
32:   累加 n(inter) += Q(j)n(j)k-1。
33: end for
▷ 组合块间和块内贡献
34: 计算 H(comb) = Hintra + exp(mnew - m(combine)) * Hinter。
35: 计算 n(comb) = n(intra) + exp(mnew - m(combine)) * n(inter)。
36: 计算 H(k) = H / max{|n(comb)|, exp(-m(combine)k)}。
37: 将 H(k), n(comb) 和 m(combine)k 存储到 HBM。

TFLA 反向传播

反向传播梯度计算。对于 TFLA 反向传播,我们需要计算查询、键和值的梯度 $\delta Q^{(k)}, \delta K^{(k)}$ 和 $\delta V^{(k)}$。省略门控计算和归一化,我们可以将这些梯度的简化版本写为:

$$\begin{aligned} \begin{aligned} \underset{(L_{hq} \times d_{qk})}{\delta \boldsymbol{Q}^{(k)}} &= \underbrace{\left( \underset{(L_{hq} \times d_{hv})}{\delta \boldsymbol{H}^{(k)}} \underset{(d_{hv} \times L_{kv})}{\boldsymbol{V}^{(k)\top}} \right) \underset{(L_{kv} \times d_{qk})}{\boldsymbol{K}^{(k)}}}_{\delta \boldsymbol{Q}_{\text{intra}}^{(k)}} + \underbrace{\underset{(L_{hq} \times d_{hv})}{\delta \boldsymbol{H}^{(k)}} \underset{(d_{hv} \times d_{qk})}{\boldsymbol{C}_{k-1}^\top}}_{\delta \boldsymbol{Q}_{\text{inter}}^{(k)}} \\ \underset{(L_{kv} \times d_{qk})}{\delta \boldsymbol{K}^{(k)}} &= \underbrace{\left( \underset{(L_{kv} \times d_{hv})}{\boldsymbol{V}^{(k)}} \underset{(d_{hv} \times L_{hq})}{\delta \boldsymbol{H}^{(k)\top}} \right) \underset{(L_{hq} \times d_{qk})}{\boldsymbol{Q}^{(k)}}}_{\delta \boldsymbol{K}_{\text{intra}}^{(k)}} + \underbrace{\underset{(L_{kv} \times d_{hv})}{\boldsymbol{V}^{(k)}} \underset{(d_{hv} \times d_{qk})}{\delta \boldsymbol{C}_k^\top}}_{\delta \boldsymbol{K}_{\text{inter}}^{(k)}} \\ \underset{(L_{kv} \times d_{hv})}{\delta \boldsymbol{V}^{(k)}} &= \underbrace{\left( \underset{(L_{kv} \times d_{qk})}{\boldsymbol{K}^{(k)}} \underset{(d_{qk} \times L_{hq})}{\boldsymbol{Q}^{(k)\top}} \right) \underset{(L_{hq} \times d_{hv})}{\delta \boldsymbol{H}^{(k)}}}_{\delta \boldsymbol{V}_{\text{intra}}^{(k)}} + \underbrace{\underset{(L_{kv} \times d_{qk})}{\boldsymbol{K}^{(k)}} \underset{(d_{qk} \times d_{hv})}{\delta \boldsymbol{C}_k}}_{\delta \boldsymbol{V}_{\text{inter}}^{(k)}} \end{aligned} \end{aligned}$$

工作分区。我们可以看到,查询、键和值的每个梯度都具有与前向传播中公式 (20) 相似的结构。它们可以使用相同的工作分区方案进行计算,即在矩阵乘法的外部块大小和外部嵌入维度上进行并行化,并在内部维度上进行循环。例如,对于键的梯度 $\delta K^{(k)}$,我们在外部块大小 $L_{kv}$ 和外部嵌入维度 $d_{qk}$ 上进行并行化,并在内部维度 $L_{hq}$ 和 $d_{hv}$ 上进行循环。表 1 总结了前向和反向传播核函数的 TFLA 工作分区方案。

带 Sigmoid 输入门的 mLSTM 扩展

指数输入门的稳定化

稳定化机制概述。在本节中,我们展示指数输入门是如何通过最大值状态 $m_t$ 来稳定化的【索引[5],xLSTM: Extended long short-term memory,2024,NeurIPS】。这种稳定化基于安全 Softmax 的思想【索引[28],Online normalizer calculation for softmax,2018,ArXiv】。我们将看到,最大值状态稳定化确保了指数输入门激活函数的参数总是小于 1。我们还将看到,归一化器状态保证了最大值状态的抵消,从而使 mLSTM 的整体输出不受最大值状态的影响。

推导过程。在没有稳定化的情况下,mLSTM 的隐藏状态输出计算如下:

$$\boldsymbol{h}_t = \tilde{\boldsymbol{o}}_t \odot \frac{\boldsymbol{C}_t^{\top} \boldsymbol{q}_t}{\max \left\{\left|\boldsymbol{n}_t^{\top} \boldsymbol{q}_t\right|, 1\right\}},$$

其中我们省略了 q 的缩放因子 $\sqrt{d_{qk}}$。为了简化,我们还省略了分母中点积的下界和绝对值。我们得到:

$$\boldsymbol{h}_t = \sigma(\tilde{\mathbf{o}}_t) \odot \frac{\boldsymbol{C}_t^\top \boldsymbol{q}_t}{\boldsymbol{n}_t^\top \boldsymbol{q}_t}.$$

代入内存单元状态 $C_t$ 和归一化器状态 $n_t$ 的更新公式,得到:

$$\boldsymbol{h}_t=\sigma\left(\tilde{\mathbf{o}}_t\right) \odot \frac{\left(\sigma\left(\tilde{\mathbf{f}}_t\right) \boldsymbol{C}_{t-1}+\exp \left(\tilde{\mathbf{i}}_t\right) \boldsymbol{k}_t \boldsymbol{v}_t^{\top}\right)^{\top} \boldsymbol{q}_t}{\left(\sigma\left(\tilde{\mathbf{f}}_t\right) \boldsymbol{n}_{t-1}+\exp \left(\tilde{\mathbf{i}}_t\right) \boldsymbol{k}_t\right)^{\top} \boldsymbol{q}_t}.$$

我们现在展示如何从这个非稳定化的 mLSTM 版本分三步推导出稳定化的形式。首先,我们使用恒等式 $\sigma(\tilde{i}) = \exp(\log(\sigma(\tilde{f}_t)))$,将公式 (108) 中的分数乘以 $\exp(-m_t)$,并选择 $m_t = \max{\log(\sigma(\tilde{f}_t)), \tilde{i}_t}$ 作为两个指数函数参数的最大值。这得到:

$$ \boldsymbol{h}_{t}=\sigma\left(\tilde{\mathbf{o}}_{t}\right) \odot \frac{\boldsymbol{C}_{t}^{\top} \boldsymbol{q}_{t} \cdot \exp \left(-m_{t}\right)}{\boldsymbol{n}_{t}^{\top} \boldsymbol{q}_{t} \cdot \exp \left(-m_{t}\right)}=\sigma\left(\tilde{\mathbf{o}}_{t}\right) \odot \frac{\left(\exp \left(\log \left(\sigma\left(\tilde{\mathrm{f}}_{t}\right)\right)-m_{t}\right) \boldsymbol{C}_{t-1}+\exp \left(\tilde{\mathrm{i}}_{t}-m_{t}\right) \boldsymbol{k}_{t} \boldsymbol{v}_{t}^{\top}\right)^{\top} \boldsymbol{q}_{t}}{\left(\exp \left(\log \left(\sigma\left(\tilde{\mathrm{f}}_{t}\right)\right)-m_{t}\right) \boldsymbol{n}_{t-1}+\exp \left(\tilde{\mathrm{i}}_{t}-m_{t}\right) \boldsymbol{k}_{t}\right)^{\top} \boldsymbol{q}_{t}} . $$

通过这种方式,我们确保指数函数的参数总是小于 1,从而永远不会因数值过大而发生溢出。下一步,我们将 $C_t$ 和 $n_t$ 重新参数化为 $\tilde{C}_t$ 和 $\tilde{n}_t$。

$$\begin{aligned} \begin{aligned} \widetilde{\boldsymbol{C}}_{t} & =\boldsymbol{C}_{t} \exp \left(-m_{t}\right) & \rightarrow \widetilde{\boldsymbol{C}}_{t-1} & =\boldsymbol{C}_{t-1} \exp \left(-m_{t-1}\right) \Leftrightarrow \boldsymbol{C}_{t-1}=\widetilde{\boldsymbol{C}}_{t-1} \exp \left(m_{t-1}\right) \\ \widetilde{\boldsymbol{n}}_{t} & =\boldsymbol{n}_{t} \exp \left(-m_{t}\right) & \rightarrow \widetilde{\boldsymbol{n}}_{t-1} & =\boldsymbol{n}_{t-1} \exp \left(-m_{t-1}\right) \Leftrightarrow \boldsymbol{n}_{t-1}=\widetilde{\boldsymbol{n}}_{t-1} \exp \left(m_{t-1}\right) \end{aligned} \end{aligned}$$

最后,我们在递推关系中用稳定化的状态 $C_t$ 和 $n_t$ 替换原来的状态。我们得到:

$$\begin{aligned} \begin{aligned} \boldsymbol{h}_{t} & =\sigma\left(\tilde{\mathbf{o}}_{t}\right) \odot \frac{\left(\exp \left(\log \left(\sigma\left(\tilde{\mathrm{f}}_{t}\right)\right)+m_{t-1}-m_{t}\right) \widetilde{\boldsymbol{C}}_{t-1}+\exp \left(\tilde{\mathrm{i}}_{t}-m_{t}\right) \boldsymbol{k}_{t} \boldsymbol{v}_{t}^{\top}\right)^{\top} \boldsymbol{q}_{t}}{\left(\exp \left(\log \left(\sigma\left(\tilde{\mathrm{f}}_{t}\right)\right)+m_{t-1}-m_{t}\right) \tilde{\boldsymbol{n}}_{t-1}+\exp \left(\tilde{\mathrm{i}}_{t}-m_{t}\right) \boldsymbol{k}_{t}\right)^{\top} \boldsymbol{q}_{t}} \\ & =\sigma\left(\tilde{\mathbf{o}}_{t}\right) \odot \frac{\widetilde{\boldsymbol{C}}_{t}^{\top} \boldsymbol{q}_{t}}{\tilde{\boldsymbol{n}}_{t}^{\top} \boldsymbol{q}_{t}} \end{aligned} \end{aligned}$$

现在我们选择最大值状态为 $m_t = \max{\log(\sigma(\tilde{f}t)) + m}, \tilde{it}$,并通过将分母更改为 $\max{|n_t^T q_t|, \exp(m)$,以便它能够抵消。总而言之,我们看到归一化器对于抵消最大值状态是必要的,而指数输入门的参数通过最大值状态得到了限制。})}$ 来得到稳定化的 mLSTM 公式。我们必须在最大值的右侧也加上 $\exp(m_{t-1

mLSTM 的实证传递行为分析

实验设置。我们提供第 4.2 节中 mLSTMexp 和 mLSTMsig 传递行为分析的细节。我们分析了单个头和单个长度为 T=512 的输入序列下 mLSTM 的传递行为。查询、键和值的输入 $q_t, k_t$ 和 $v_t$ 从标准正态分布 $N(0, 1)$ 中采样。我们将头维度设置为 $d_{qk}=128$ 和 $d_{hv}=128$。作为归一化层 NORM(x),我们使用 RMS-norm。将归一化层改为 layernorm 不会改变结果,因为在本实验中我们将输入的均值设置为零。对于每个图,我们测量了在输入和遗忘门预激活值分别在 [-12, 8] 和 [-5, 12] 范围内变化时,如 (26) 定义的增益 $G_{before}$ 和 $G_{after}$。

归一化层 Epsilon 对传递行为的影响。基于我们在 4.2 节中对门控线性 RNN 操作后归一化层的分析,我们假设归一化层,特别是其 epsilon $\epsilon$,是门控机制的一个组成部分。在本实验中,我们探究了 epsilon 值对 mLSTM 传递行为的影响。图 9a 和图 10a 分别展示了 mLSTMexp 和 mLSTMsig 在 $\epsilon = [1e-2, 1e-6, 1e-8]$ 时的传递行为。我们观察到,epsilon 对 mLSTMexp 和 mLSTMsig 的作用方式相同。增加 $\epsilon$ 会导致增益在 y 轴正方向上偏移,减小 $\epsilon$ 则导致在 y 轴负方向上偏移。我们将默认值设为 $\epsilon=1e-6$,这在我们的实验中取得了最佳性能(见 5.1 节)。

mLSTMexp 和 mLSTMsig 的归一化器。在本实验中,我们测试了公式 34 中不同归一化器 n 对 mLSTMexp 和 mLSTMsig 的影响。B.1 节中的并行公式是针对带指数输入门的 mLSTM 提出的,但同样适用于带 sigmoid 输入门的 mLSTM。对于默认的 mLSTMsig,我们将归一化器设为 n=1,并为 sigmoid 输入门修改了门矩阵 D 的计算。在图 9 中,我们展示了不同归一化器对带指数输入门的 mLSTM 的结果。只有默认的 mLSTMexp 带有正确的归一化器和最大值状态(如图 9a)时,才显示出依赖于输入门的传递行为。相比之下,在图 10a 和 10b 中我们观察到,将类似于 mLSTMexp 的归一化器(不包括最大值状态)并入 mLSTMsig 并不会改变其传递行为。图 10c 和 10d 中的另外两种 mLSTMsig 归一化器变体显示出明显不同的传递行为,并且未能成功训练。同样,图 9b 和 9c 中的变体也未能成功训练。总而言之,我们发现如果 mLSTM 表现出特征性的门依赖传递行为,它就能成功训练并在我们的语言建模实验中表现良好。为了使 mLSTMexp 达到这种行为,我们需要按照 D.1 节推导的方式正确归一化。向 mLSTMsig 添加归一化器不会改变性能和传递行为,只要该归一化器对点积 $n_t^T q_t$ 施加了下界。然而,我们的默认 mLSTMsig 省略了归一化器以降低计算成本和运行时间。

图9. 带指数输入门的 mLSTM 在不同归一化层 epsilon (EPS) 和不同归一化器变体下的传递行为。只有默认的归一化方式显示出输入门依赖的传递行为。改变归一化层 epsilon 会导致增益曲线在 y 轴方向上移动。
图9. 带指数输入门的 mLSTM 在不同归一化层 epsilon (EPS) 和不同归一化器变体下的传递行为。只有默认的归一化方式显示出输入门依赖的传递行为。改变归一化层 epsilon 会导致增益曲线在 y 轴方向上移动。
图10. 带 sigmoid 输入门的 mLSTM 在不同归一化层 epsilon (EPS) 和不同归一化器变体下的传递行为。从 mLSTMsig 中移除归一化器(这是我们在 (a) 中的默认设置)对传递行为没有影响。如果添加归一化器,它应该被限制在 1(见 (b))。改变归一化层 epsilon 会导致增益曲线在 y 轴方向上移动。
图10. 带 sigmoid 输入门的 mLSTM 在不同归一化层 epsilon (EPS) 和不同归一化器变体下的传递行为。从 mLSTMsig 中移除归一化器(这是我们在 (a) 中的默认设置)对传递行为没有影响。如果添加归一化器,它应该被限制在 1(见 (b))。改变归一化层 epsilon 会导致增益曲线在 y 轴方向上移动。

扩展实验

概述。在本节中,我们提供第 5 节的额外实验和细节。

TFLA 核函数的数值验证

验证困惑度匹配(表 3)。在开始语言建模实验之前,我们首先验证我们的核函数与一个基于完全并行公式的纯 JAX 参考实现(见附录 B.1)产生相同的结果。我们比较了在 190 亿个 token 上训练的 160M 参数 mLSTMexp 和 mLSTMsig 模型在训练结束时的验证困惑度。我们使用 4096 的上下文长度,因为并行的 JAX 实现在更长的上下文中会出现内存不足。模型架构和训练方案遵循我们在附录 E.2 中描述的通用设置。在表 3 中,我们确认我们的核函数与 JAX 参考实现产生相同的结果。

表3. 160M 参数模型在上下文长度 4096、训练 190 亿 token 后的验证困惑度。
表3. 160M 参数模型在上下文长度 4096、训练 190 亿 token 后的验证困惑度。

mLSTM 的扩展语言建模实验

实验设置与模型架构。本节我们提供我们的实验设置、模型架构和训练方案的详细信息,并增加了在上下文长度 8192 上的额外性能结果,以及分析了归一化层中 epsilon 参数的影响。我们使用 GPT-NeoX 分词器【索引[6],GPT-NeoX-20B: An open-source autoregressive language model,2022,ACL Workshop】(词汇量 50257),并且不共享嵌入层和最后(反嵌入)层的权重。每个块包含两层,每层都有一个跳跃连接和层输入前的归一化层(即我们使用 pre-norm 块架构)。我们使用 RMS-norm【索引[84],Root mean square layer normalization,2019,NeurIPS】作为归一化层,epsilon $\epsilon=1e-6$。我们对 logits 应用 soft-capping【索引[66],Gemma 2: Improving open language models at a practical size,2024,arXiv】,使 logits 的值保持在 -c 和 c 之间,其中 c=30。logits 通过以下函数进行限制:

$$\operatorname{softcap}(x)=c \cdot \tanh (x / c)$$

我们还对输入和遗忘门的预激活值应用了 soft-capping(公式 112),发现这能提高训练稳定性,此时 c=15。其余模型参数见表 4。

训练方案。我们使用 AdamW 优化器【索引[27],Decoupled weight decay regularization,2019,ICLR】,参数为 $\beta_1=0.9$, $\beta_2=0.95$ 和 $\epsilon=1e-8$。学习率和批次大小如表 4 所示。我们对所有线性层(包括最后的线性层或反嵌入层)应用 0.1 的权重衰减,并从权重衰减中排除了偏置和词嵌入。我们将梯度范数裁剪到 0.5。我们使用余弦学习率调度器,前 750 步进行线性预热,然后衰减到峰值学习率的 0.1,最后 1000 步线性冷却到 0。我们在预训练期间,通过在每个新文档开始时重置内存状态来确保信息不会跨文档边界泄露。我们通过在每个新文档开始时手动将遗忘门预激活值设置为一个大的负值来实现这一点。

表4. 我们的 160M、400M 和 1.4B 模型尺寸的训练和模型架构超参数。
表4. 我们的 160M、400M 和 1.4B 模型尺寸的训练和模型架构超参数。

额外性能结果(表 5)。在表 5 中,我们展示了 mLSTMexp 和 mLSTM 在上下文长度 8192 时的验证困惑度(上下文长度 4096 的结果见表 2)。对于某些头维度配置,我们在训练期间观察到不可恢复的梯度范数尖峰(用 - 表示)。

表5. 上下文长度 8192 时的验证困惑度。EXP 和 SIG 表示 mLSTMexp 和 mLSTMsig。LIMIT 和 XL 对应 limit_chunk 和 xl_chunk 核函数。- 表示该运行在训练期间经历了不可恢复的损失尖峰。
表5. 上下文长度 8192 时的验证困惑度。EXP 和 SIG 表示 mLSTMexp 和 mLSTMsig。LIMIT 和 XL 对应 limit_chunk 和 xl_chunk 核函数。- 表示该运行在训练期间经历了不可恢复的损失尖峰。

可训练输入门的影响(表 6)。我们研究了输入门对性能的影响。表 6 显示,对于 mLSTMexp 和 mLSTMsig,使输入门可学习都能持续提高性能。

表6. 160M mLSTM 在上下文长度 4096 下,使用可学习和固定输入门(偏置初始化为 -10)的验证困惑度。
表6. 160M mLSTM 在上下文长度 4096 下,使用可学习和固定输入门(偏置初始化为 -10)的验证困惑度。

输入门偏置初始化的影响(图 11 和 12)。在 4.2 节的传递行为分析中,我们发现在输入门值为约 -8 的负值时,存在从抑制信号到传递信号的转变(见图 4)。由于我们将门控的权重 $w_$ 初始化为 0,输入门和遗忘门的偏置决定了训练开始时在 x-y 平面上的实际位置。最初,当输入门偏置初始化为 0 时,我们观察到梯度范数方差很大,这在 mLSTMsig 中更为明显(见图 11a 和 12a)。因此,我们尝试将输入门偏置初始化为较大的负值。遗忘门偏置在 [3, 6] 范围内均匀初始化。随着权重 $w_$ 在训练中增长,门控预激活值也随之增长,模型可以学会逐渐进入图 4 的动态区域,在该区域输入信号被传递。确实,如图 11 和 12 所示,将输入门偏置初始化为 -10 能有效地缓解训练期间的梯度范数尖峰并减少高梯度范数方差,这对于 mLSTMexp 和 mLSTMsig 都适用。因此,我们得出结论,额外的输入门不仅能提升性能(见表 6),而且如果初始化正确,还能提高训练稳定性。

图11. 可训练和固定的指数输入门,偏置初始化为 [0, -2, -5, -10],归一化 epsilon ε =1e-6。
图11. 可训练和固定的指数输入门,偏置初始化为 [0, -2, -5, -10],归一化 epsilon ε =1e-6。
11-2
11-2
图12. 可训练和固定的 sigmoid 输入门,偏置初始化为 [0, -2, -5, -10],归一化 epsilon ε =1e-6。
图12. 可训练和固定的 sigmoid 输入门,偏置初始化为 [0, -2, -5, -10],归一化 epsilon ε =1e-6。

归一化层 Epsilon 对性能的影响(图 13)。在我们对 mLSTM 的经验性传递行为分析中(4.2 节和 D.2 节),我们发现传递行为取决于输入和遗忘门的预激活值,以及归一化层的 epsilon(见图 9a 和 10a)。因此,我们对带指数输入门的 mLSTM(160M 参数,6 个头,上下文长度 4096)进行了不同归一化层 epsilon 和输入门偏置初始化的网格搜索。结果如图 13 所示。我们观察到,从归一化层 epsilon 和输入门偏置 $(\epsilon, b_i)=(1e-6, -10)$ 到 $(1e-4, -5)$ 的对角线区域性能有所提升。这表明,如果我们增加归一化层的 epsilon,我们也可以或应该增加输入门的偏置初始化,正如增益曲线在图 9a 中较大 epsilon 时向 y 轴正方向移动所暗示的那样。这支持了我们在 4.2 节中的假设,即归一化层对门控机制很重要。我们使用 $(\epsilon, b_i)=(1e-6, -10)$ 作为我们的默认配置。

训练过程中的输入门激活值(图 14)。我们在图 14 中展示了 160M 参数的 mLSTMexp 和 mLSTMsig 在训练过程中最大的输入门预激活值(在批次、序列和头维度上取最大值)。两个模型的输入门偏置都初始化为 -10。

图13. 160M 参数、6 个头的 mLSTMexp 的验证困惑度。在归一化层 epsilon 和输入门偏置初始化上的网格搜索。性能提升的对角线区域表明,归一化层 epsilon 和输入门偏置初始化之间存在相互作用。这支持了归一化层对门控机制很重要的假设。
图13. 160M 参数、6 个头的 mLSTMexp 的验证困惑度。在归一化层 epsilon 和输入门偏置初始化上的网格搜索。性能提升的对角线区域表明,归一化层 epsilon 和输入门偏置初始化之间存在相互作用。这支持了归一化层对门控机制很重要的假设。

扩展核函数基准测试

基准测试细节。在本节中,我们提供基准测试设置的详细信息并添加额外的基准测试结果。在图 6 和图 15 中,我们使用 PyTorch 的 torch.cuda.max_memory_allocated API 来测量核函数在一次迭代期间分配的峰值内存。我们确保在每次迭代后重置内存统计数据,并在每次基准测试开始前清除 PyTorch 缓存。所有核函数均以 bfloat16 精度运行。我们报告了 10 次预热迭代后 30 次迭代的中位数运行时间。

基准测试设置(图 5)。我们使用的基准测试模型包括:
- mLSTMexp (FLA limit_chunk):我们自己基于 FLA 的 mLSTMexp 基线核函数,块大小受限,仅使用单级序列并行。我们报告了性能最佳的块大小 64。
- mLSTMexp (TFLA xl_chunk):用于 mLSTMexp 的 TFLA 核函数,具有两级序列并行。我们设置了性能最佳的块大小 128。
- mLSTMsig (TFLA xl_chunk):用于 mLSTMsig 的 TFLA 核函数。我们设置了块大小 128,但发现块大小 256 在运行时间上表现同样出色。
- cuDNN FlashAttention:PyTorch 2.5.1 中集成的 NVIDIA cuDNN FlashAttention 2 实现。
- FlashAttention 3:针对 NVIDIA H100 GPU 优化的 FlashAttention 3 实现。
- GLA (FLA):基于 FlashLinearAttention 算法的门控线性注意力 Triton 核函数,具有一级序列并行。
- Simple GLA (FLA):GLA 的一个简化版本,每个头有标量遗忘门。它与 mLSTMsig 相似但没有输入门。
- Mamba & Mamba 2:来自官方 Mamba 仓库的 CUDA 和 Triton 核函数实现。

结果分析。图 5 的结果显示,在仅比较前向传播运行时,Mamba2 和 Simple GLA 核函数略快于我们的 mLSTMsig 核函数。然而,在训练中(测量前向和后向传播运行时),我们的 TFLA 核函数在较长序列长度上比 FlashAttention 3 更快,并且在所有序列长度上比 Mamba 2 核函数快两倍以上。只有 Simple GLA (FLA) 在训练速度上能与我们的 TFLA mLSTM 核函数相媲美。然而,我们在图 15 中发现,这需要付出近 2 倍的 GPU 内存使用量。这些内存节省是通过利用更大的块大小实现的,这得益于 3.2 节中概述的两级序列并行。

图14. 160M 参数的 mLSTMexp 和 mLSTMsig 在训练过程中的最大输入门预激活值 ˜it。最大值取自批次、序列和头维度。两个模型的输入门偏置都初始化为 -10。在大多数情况下,输入门预激活值保持在零以下。
图14. 160M 参数的 mLSTMexp 和 mLSTMsig 在训练过程中的最大输入门预激活值 ˜it。最大值取自批次、序列和头维度。两个模型的输入门偏置都初始化为 -10。在大多数情况下,输入门预激活值保持在零以下。

与 FLA 核函数的运行时和内存比较(图 15)。在这个实验中,我们将我们的 TFLA mLSTM 核函数与 Flash Linear Attention 库中的著名核函数进行了运行时和内存消耗的比较。我们使用了与之前基准测试类似的设置,但在更大的头维度(值维度 512,查询和键维度 256)下进行了比较,因为 Beck 等人【索引[5],xLSTM: Extended long short-term memory,2024,NeurIPS】和 Yang 等人【索引[81],Gated linear attention transformers with hardware-efficient training,2024,ICML】都报告了在较大头维度下语言建模性能更好。除了 GLA (chunk) 和 Simple GLA (chunk) 外,我们还与 GLA (fused) 进行了比较,这是 Gated Linear Attention (GLA) 的非物化版本。有趣的是,我们发现在训练中,非物化版本既不比我们的 TFLA mLSTM 核函数快,也不更节省内存。虽然 Simple GLA 略快(在 3 毫秒或 15% 以内),但它使用的 GPU 内存几乎是我们的 TFLA mLSTM 核函数的两倍。

图15. 与 FLA 核函数的运行时和内存比较。左:运行时(前向后向传播)。右:GPU 内存使用。我们使用 8 个头,值维度为 512,查询和键维度为 256。Simple GLA(我们实验中最快的 FLA 核函数)比我们的 TFLA mLSTMsig 核函数略快,但使用的 GPU 内存几乎是其两倍。
图15. 与 FLA 核函数的运行时和内存比较。左:运行时(前向后向传播)。右:GPU 内存使用。我们使用 8 个头,值维度为 512,查询和键维度为 256。Simple GLA(我们实验中最快的 FLA 核函数)比我们的 TFLA mLSTMsig 核函数略快,但使用的 GPU 内存几乎是其两倍。

与 LightningAttention2 核函数的运行时和内存比较(图 16)。与前一个实验类似,我们比较了我们的 TFLA mLSTM 核函数与 LightningAttention2【索引[41],Lightning attention-2: A free lunch for handling unlimited sequence lengths in large language models,2024,arXiv】的运行时和内存消耗。LightningAttention2 也是使用分块并行公式的,但与 Simple GLA 和 TFLA 不同,它不是将计算分为循环和并行部分,而是完全循环地处理所有块。我们发现 LightningAttention2 只支持查询、键和值具有相同的头维度,最大为 128。我们发现 LightningAttention2 的内存使用量最低,为 6.2 GB。然而,这种 GPU 内存效率的代价是运行时间比我们的 TFLA mLSTMsig 核函数长 3 倍以上,后者使用约 7.3 GB 的 GPU 内存。这突显了基于分块并行公式的线性 RNN 核函数在 GPU 内存使用和运行时间之间存在权衡。

图16. 与 LightningAttention2 的运行时和内存比较。左:运行时(前向后向)。右:GPU 内存。我们使用 32 和 64 个头,查询、键和值的头维度分别为 128 和 64。LightningAttention 的内存使用量最低,但在较大的头维度 128 下,其速度比我们的 TFLA mLSTM 慢 3 倍以上。
图16. 与 LightningAttention2 的运行时和内存比较。左:运行时(前向后向)。右:GPU 内存。我们使用 32 和 64 个头,查询、键和值的头维度分别为 128 和 64。LightningAttention 的内存使用量最低,但在较大的头维度 128 下,其速度比我们的 TFLA mLSTM 慢 3 倍以上。

不同头维度的运行时基准测试(图 17)。已有多项工作报告称,对于线性 RNN,较大的头维度(与常见的自注意力头维度相比)可以提高语言建模性能。因此,线性 RNN 核函数在广泛的头维度范围内保持快速和高效是理想的。我们改变头维度从 32 到 1024,并相应调整头的数量,以保持总嵌入维度为 4096,并测量了在序列长度 8192 和批次大小 4 的输入下的运行时间。我们观察到,对于较小的头维度(即 32 和 64),我们的 mLSTM limit chunk 核函数在训练中与我们的 TFLA mLSTM 核函数一样快。总的来说,图 17 中的结果证实了我们的 TFLA 核函数在广泛的头维度范围内都实现了快速的运行时间。

图17. FLA 和 TFLA mLSTM 核函数的头维度基准测试。左:前向传播。右:前向和后向传播。我们测量了序列长度 8192 和批次大小 4 在不同头维度下的运行时间。我们对查询、键和值使用相同的头维度。我们的 TFLA mLSTM 核函数即使在非常大的头维度下也显示出快速的运行时间。
图17. FLA 和 TFLA mLSTM 核函数的头维度基准测试。左:前向传播。右:前向和后向传播。我们测量了序列长度 8192 和批次大小 4 在不同头维度下的运行时间。我们对查询、键和值使用相同的头维度。我们的 TFLA mLSTM 核函数即使在非常大的头维度下也显示出快速的运行时间。

mLSTM 的 FLOP 和内存操作计数

计数方法。我们计算了 mLSTM 在一次前向传播(批次大小为 1)中带指数和 sigmoid 输入门的浮点运算次数(FLOPs)和内存操作(以字节为单位的加载和存储)。我们使用因子 2 来描述乘法累加的 FLOPs 成本。我们不计算核函数内部发生的重计算所产生的 FLOPs。同样,我们不计算重计算所需的额外内存加载操作。我们使用表 7 中总结的符号。

表7. FLOP 和内存操作计数的符号表示。
表7. FLOP 和内存操作计数的符号表示。

mLSTM 的精确 FLOPs 和内存操作计数

分块并行公式(表 8, 9)。我们计算了分块并行 mLSTM 公式(见 2.2 节和附录 B.2)的 FLOPs(表 8)和内存操作(表 9)。所有计数都是针对单个头和单个块的。要获得完整序列的总计数,我们将这些计数乘以头数 Nhead 和块数 Nchunk = T / L。

表8. 分块并行 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的 FLOP 计数。所有项表示每个头和每个块的 FLOP 计数。
表8. 分块并行 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的 FLOP 计数。所有项表示每个头和每个块的 FLOP 计数。

表9. 分块并行 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的内存操作计数。所有项表示每个头和每个块的内存操作计数。
表9. 分块并行 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的内存操作计数。所有项表示每个头和每个块的内存操作计数。

完全并行公式(表 10, 11)。我们计算了完全并行 mLSTM 公式(见附录 B.1)的 FLOPs(表 10)和内存操作(表 11)。所有计数都是针对单个头和长度为 T 的完整序列的。要获得总计数,需乘以头数 Nhead。

表10. 完全并行 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的 FLOP 计数。所有项表示每个头的完整序列的 FLOP 计数。
表10. 完全并行 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的 FLOP 计数。所有项表示每个头的完整序列的 FLOP 计数。

表11. 完全并行 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的内存操作计数。所有项表示每个头的完整序列的内存操作计数。
表11. 完全并行 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的内存操作计数。所有项表示每个头的完整序列的内存操作计数。

循环公式(表 12, 13)。我们计算了循环 mLSTM 公式(见 2.1 节和 4.1 节)的 FLOPs(表 12)和内存操作(表 13)。对于内存操作,我们假设在每个时间步后物化状态,这是文本生成时的设置。所有计数都是针对单个头和单个时间步的。要获得完整序列的总计数,需乘以序列长度 T 和头数 Nhead。

表12. 循环 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的 FLOP 计数。所有项表示每个头的单个时间步的 FLOP 计数。
表12. 循环 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的 FLOP 计数。所有项表示每个头的单个时间步的 FLOP 计数。

表13. 循环 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的内存操作计数。所有项表示每个头的单个时间步的内存操作计数。我们假设在每个时间步都物化状态。
表13. 循环 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的内存操作计数。所有项表示每个头的单个时间步的内存操作计数。我们假设在每个时间步都物化状态。

mLSTM 的简化 FLOP 计数摘要

简化方法。在本节中,我们通过将所有因子 $F_{OP}$ 设置为 1 来简化 mLSTM 的 FLOP 计数。我们未指定因果因子 $F_{causal}$,但通常将其设置为 0.5 或稍大。由于注意力 logit 矩阵(即二次矩阵 S)总是以块的形式计算,通常会计算上三角矩阵的某些部分然后将其屏蔽掉。为了考虑这一点,可以将因子 $F_{causal}$ 设置为大于 0.5 的值(例如 0.66)。在图 18 中,我们展示了因果因子对 mLSTMsig 整体 FLOP 计数的影响。表 14、15 和 16 总结了分块并行、完全并行和循环 mLSTM 公式的简化 FLOP 计数。

表14. 分块并行 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的简化 FLOP 计数。所有项表示每个头和每个块的 FLOP 计数。我们将所有因子 F_OP 设置为 1。
表14. 分块并行 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的简化 FLOP 计数。所有项表示每个头和每个块的 FLOP 计数。我们将所有因子 F_OP 设置为 1。

表15. 完全并行 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的简化 FLOP 计数。所有项表示每个头的完整序列的 FLOP 计数。我们将所有因子 F_OP 设置为 1。
表15. 完全并行 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的简化 FLOP 计数。所有项表示每个头的完整序列的 FLOP 计数。我们将所有因子 F_OP 设置为 1。

表16. 循环 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的简化 FLOP 计数。所有项表示每个头的单个时间步的 FLOP 计数。我们将所有因子 F_OP 设置为 1。
表16. 循环 mLSTM 公式的 mLSTMexp 和 mLSTMsig 的简化 FLOP 计数。所有项表示每个头的单个时间步的 FLOP 计数。我们将所有因子 F_OP 设置为 1。

总 FLOP 计数摘要(表 17)。在表 17 中,我们总结了 mLSTM 的分块并行、完全并行和循环公式的总 FLOP 计数,包括指数和 sigmoid 输入门,针对单个头和长度为 T 的完整序列(即批次大小 Nbatch = 1)。

表17. mLSTMexp 和 mLSTMsig 的分块并行、完全并行和循环公式的总简化 FLOP 计数。所有项表示每个头的完整序列 T 的 FLOP 计数。我们将所有因子 F_OP 设置为 1。
表17. mLSTMexp 和 mLSTMsig 的分块并行、完全并行和循环公式的总简化 FLOP 计数。所有项表示每个头的完整序列 T 的 FLOP 计数。我们将所有因子 F_OP 设置为 1。

mLSTMexp 和 mLSTMsig 之间的 FLOP 比较

FLOP 计数差异。带 sigmoid 输入门的 mLSTM 没有归一化器和最大值状态。因此,与带指数输入门的 mLSTM 相比,它的 FLOPs 和内存操作更少。我们计算了 mLSTMexp 和 mLSTMsig 在表 17 最后一列的 FLOP 差异。我们观察到,在主导项中没有差异,并得出结论,mLSTMexp 和 mLSTMsig 之间的 FLOP 差异很小。例如,对于头维度 $d_{qk} = d_{hv} = 64$,我们发现 mLSTMexp 的 FLOPs 不到 2% 的增加,而对于 $d_{qk} = d_{hv} = 512$,mLSTMexp 的 FLOPs 仅增加约 0.2%。

mLSTMsig 更快的原因。我们发现 mLSTMexp 的 FLOPs 仅略多于 mLSTMsig。因此,mLSTMexp 和 mLSTMsig 之间的速度差异不能仅由 FLOP 计数差异来解释。然而,尽管绝对和相对 FLOP 计数差异很小,但不同的 FLOPs 是更“昂贵”的 FLOPs,即逐点操作和向量操作,这些操作比在张量核心上执行的矩阵乘法更昂贵。因此,可以预期实际的相对运行时间差异会大于相对 FLOP 差异。对于前向传播,这表明 mLSTMsig 相对于 mLSTMexp 30% 的速度提升的主要原因是 mLSTMsig 核函数中循环的高效融合(见附录 C.2)。

不同 mLSTM 公式之间的 FLOP 比较

分块并行 FLOPs 在循环和完全并行 FLOPs 之间变化(图 18)。TFLA 相对于 Flash Linear Attention 的主要优势在于可自由配置的块大小,这允许在内存消耗和运行时间之间进行有效权衡(见图 6),并且如图 18 所示,还可以在总 FLOPs 数量之间进行权衡。我们在图 18 中绘制了 mLSTMsig 的循环、完全并行和分块并行公式在不同块大小 L 下的 FLOP 计数。我们观察到,当块大小 L 从 1(循环)变化到 T(完全并行)时,分块并行的 FLOP 计数在循环和完全并行的 FLOP 计数之间过渡。

FLOP 最优的块大小

FLOP 最优块大小的计算。mLSTM 分块并行公式的 FLOPs 数量取决于块大小 L。在本节中,我们计算 FLOP 最优的块大小,即最小化 FLOP 计数的块大小 $L_{opt,FLOP}$。我们使用表 17 中 mLSTMsig 的简化 FLOP 计数进行分析。为了计算 $L_{opt,FLOP}$,我们设置 FLOP 计数对 L 的导数为零并求解 L。这得到:

$$L_{\text{opt,FLOP}} = \sqrt{\frac{2d_{hv}^2 p_{qk} + 5}{2F_{\text{causal}} (d_{hv}(1 + p_{qk}) + 3) + 1}}$$

FLOP 最优的块大小取决于头维度 $d_{hv}$ 和投影因子 $p_{qk}$,并与 $d_{hv}$ 的平方根成正比(即 $L_{opt,FLOP} \propto O(\sqrt{d_{hv}})$)。

FLOP 最优块大小随头维度增长(图 19)。为了最小化计算的 FLOPs,$L_{opt,FLOP}$ 表明我们应该使用相当小的块大小(例如,对于 $d_{hv}=512$,L=16)。然而,块大小 L 不仅影响 FLOPs,还影响内存 I/O。因此,为了找到最小化运行时的最优块大小,我们还需要考虑内存 I/O,这将在下一节中进行(见附录 G)。

图18. mLSTMsig 的循环、完全并行和分块并行公式的 FLOP 计数。左:dqk=256, dhv=512。中:dqk=128, dhv=256 右:dqk=64, dhv=128。我们为一个头和一个长度 T=8192 的序列计算了不同头维度下的 FLOP 数量,并改变了块大小 L。通过改变块大小 L,分块并行公式的 FLOP 计数在循环和完全并行 FLOP 计数之间过渡。较小的头维度减少了总 FLOP 计数。
图18. mLSTMsig 的循环、完全并行和分块并行公式的 FLOP 计数。左:dqk=256, dhv=512。中:dqk=128, dhv=256 右:dqk=64, dhv=128。我们为一个头和一个长度 T=8192 的序列计算了不同头维度下的 FLOP 数量,并改变了块大小 L。通过改变块大小 L,分块并行公式的 FLOP 计数在循环和完全并行 FLOP 计数之间过渡。较小的头维度减少了总 FLOP 计数。
图19. mLSTMsig 的 FLOP 最优块大小 L_opt,FLOP。左:pqk=0.5。右:pqk=1.0。我们绘制了 FLOP 最优块大小随头维度 dhv 的变化,针对不同的投影因子 pqk,该因子决定了查询-键头维度 dqk = pqk*dhv。FLOP 最优块大小与 dhv 的平方根成正比(即 O(sqrt(dhv))),但在典型头维度下仍然很小。
图19. mLSTMsig 的 FLOP 最优块大小 L_opt,FLOP。左:pqk=0.5。右:pqk=1.0。我们绘制了 FLOP 最优块大小随头维度 dhv 的变化,针对不同的投影因子 pqk,该因子决定了查询-键头维度 dqk = pqk*dhv。FLOP 最优块大小与 dhv 的平方根成正比(即 O(sqrt(dhv))),但在典型头维度下仍然很小。

TFLA mLSTM 核函数的理论运行时分析

分析背景。在本节中,我们的目标是理论上理解我们的 TFLA 核函数算法(以 mLSTMsig 为例)属于哪个区域(计算密集型或内存密集型),以及在给定硬件物理约束下的最优块大小。我们将看到,考虑到硬件约束,最优块大小将远大于上一节中找到的 FLOP 最优块大小。

理论运行时

运行时模型。核函数的理论运行时包括计算 FLOPs 的时间 $\tau_{FLOPs}$ 和从 GPU 内存加载/存储输入、输出和中间变量的时间 $\tau_{Bytes}$。给定算法的浮点运算次数 $FLOPs_{algo}$ 和内存操作字节数 $Bytes_{algo}$,以及加速器速度 $\alpha_{FLOPs}$ (FLOPS/s) 和内存带宽 $\beta_{Bytes}$,我们可以计算运行时(秒):

$$\tau_{\text {FLOPs,algo }}=\frac{\text { FLOPs }_{\text {algo }}}{\alpha_{\text {FLOPs }}} \quad \text { and } \quad \tau_{\text {Bytes,algo }}=\frac{\text { Bytes }_{\text {algo }}}{\beta_{\text {Bytes }}}$$

如果计算和内存操作没有重叠,总运行时是两者之和:

$$\tau_{\text{algo,upper}} = \tau_{\text{FLOPs,algo}} + \tau_{\text{Bytes,algo}} \cdot$$

如果可以完全重叠,总运行时是两者的最大值:

$$\tau_{\text {algo,lower }}=\max \left(\tau_{\text {FLOPs,algo }}, \tau_{\text {Bytes,algo }}\right) .$$

表18. 本分析中使用的 NVIDIA GPU 硬件加速器规格。数值为非稀疏值。如果只有稀疏值已知,我们除以 2。
表18. 本分析中使用的 NVIDIA GPU 硬件加速器规格。数值为非稀疏值。如果只有稀疏值已知,我们除以 2。

TFLA mLSTMsig 前向传播的理论运行时。我们假设内存操作不与计算重叠,因为 (1) 在我们当前的 TFLA 实现中,我们首先在循环核函数中物化所有状态,然后启动并行核函数(见图 1),以及 (2) 我们(目前)未使用 NVIDIA GPU 的高级硬件特性,如异步内存加载。因此,总的理论运行时是 FLOP 和内存操作运行时的和,即:

$$\tau_{\text {mLSTMsig }}=\frac{\text { FLOPs }_{\text {mLSTMsig }}}{\alpha_{\text {FLOPs }}}+\frac{\text { Bytes }_{\text {mLSTMsig }}}{\beta_{\text {Bytes }}} .$$

通过代入上述表达式并乘以头数 $N_{head}$ 和批次大小 $N_{batch}$,我们得到 mLSTMsig 前向传播的总运行时:

$$\begin{aligned} \begin{aligned} \tau_{\mathrm{mLSTMsig}} & =N_{\text {batch }} \cdot N_{\text {head }} \cdot \frac{T}{L} \cdot\left(\frac{6 L d_{h v}\left(1+p_{q k}\right)+8 L+2 \text { bytes }_{C m n} d_{h v}^{2} p_{q k}}{\beta_{\text {Bytes }}}\right. \\ & \left.+\frac{L^{2} F_{\text {causal }}\left(2 d_{h v} p_{q k}+2 d_{h v}+6\right)+L^{2}+L\left(4 d_{h v}^{2} p_{q k}+2 d_{h v} p_{q k}+d_{h v}+11\right)+2 d_{h v}^{2} p_{q k}+5}{\alpha_{\text {FLOPs }}}\right), \end{aligned} \end{aligned}$$

理论运行时与块大小的关系(图 20)。我们观察到,更新的 GPU(例如 H100 或 B200)速度更快,内存带宽更高,导致运行时更短。此外,对于更新的 GPU,运行时对块大小 L 的敏感性降低,曲线变得更平坦。此外,存在一个最小化运行时的最优块大小,它由硬件的物理约束决定。这个最优块大小在更新的 GPU(例如 B200)上会增加。

测量运行时与理论运行时的差异(图 20)。在图 20 中,我们还将 NVIDIA H100 上 mLSTMsig 核函数的测量运行时与理论运行时进行了比较。虽然测量运行时高于理论运行时,但运行时随块大小变化的定性趋势是相似的。差异的原因有几点:首先,我们的运行时模型存在近似误差。其次,我们的核函数尚未优化到达到峰值性能。最后,硬件规格指定的是矩阵乘法的峰值性能,而我们的计算包含较慢的逐点和向量操作。

图20. TFLA mLSTMsig 前向传播在不同加速器上随块大小 L 变化的理论运行时。我们绘制了 7B 大小的 mLSTMsig(即 dhv=512, pqk=0.5, Nhead=8)在 Nbatch=8, T=8192 和 bytesCmn=4 时在 NVIDIA A100, H100 和 B200 GPU 上的理论运行时(根据表 18)。我们还绘制了 mLSTMsig 核函数在 NVIDIA H100 上的测量运行时。更新的 GPU(如 H100, B200)速度更快,内存带宽更高,导致理论运行时更低。
图20. TFLA mLSTMsig 前向传播在不同加速器上随块大小 L 变化的理论运行时。我们绘制了 7B 大小的 mLSTMsig(即 dhv=512, pqk=0.5, Nhead=8)在 Nbatch=8, T=8192 和 bytesCmn=4 时在 NVIDIA A100, H100 和 B200 GPU 上的理论运行时(根据表 18)。我们还绘制了 mLSTMsig 核函数在 NVIDIA H100 上的测量运行时。更新的 GPU(如 H100, B200)速度更快,内存带宽更高,导致理论运行时更低。

算术强度

算术强度定义。算术强度(或操作强度)直接关联了 FLOPs 数量与 GPU 内存流量的字节数【索引[77],Roofline: an insightful visual performance model for multicore architectures,2009,Commun. ACM】。我们可以计算加速器的算术强度 $I_{acc}$ 或算法的算术强度 $I_{algo}$:

$$\mathcal{I}_{\mathrm{acc}}=\frac{\alpha_{\text {FLOPs }}}{\beta_{\text {Bytes }}} \text { and } \mathcal{I}_{\text {algo }}=\frac{\mathrm{FLOPs}_{\text {algo }}}{\text { Bytes }_{\text {algo }}} .$$

如果 $I_{algo} > I_{acc}$,核函数可能是计算密集型;如果 $I_{algo} < I_{acc}$,则可能是内存密集型。

TFLA mLSTMsig 前向传播的算术强度。我们计算 TFLA mLSTMsig 前向传播的算术强度 $I_{mLSTMsig}$,通过计算总 FLOP 计数 $FLOPs_{mLSTMsig}$ 和内存操作字节数 $Bytes_{mLSTMsig}$ 的比值得到:

$$\mathcal{I}_{\mathrm{mLSTMsig}}=\frac{L^2 F_{\text {causal }}\left(2 d_{h v} p_{q k}+2 d_{h v}+6\right)+L^2+L\left(4 d_{h v}^2 p_{q k}+2 d_{h v} p_{q k}+d_{h v}+11\right)+2 d_{h v}^2 p_{q k}+5}{6 L d_{h v}\left(1+p_{q k}\right)+8 L+2 \text {bytes}_{C m n} d_{h v}^2 p_{q k}}$$

算术强度与块大小的关系(图 21)。我们观察到,算术强度随块大小的增加而增加,并且核函数的算术强度曲线在较大的块大小处与加速器的算术强度相交。这意味着,我们可以通过增加块大小参数 L 从内存密集型区域转移到计算密集型区域。

图21. TFLA mLSTMsig 前向传播随块大小 L 变化的算术强度。左:状态为 float32,即 bytesCmn=4。右:状态为 bfloat16,即 bytesCmn=2。算术强度随块大小的增加而增加。
图21. TFLA mLSTMsig 前向传播随块大小 L 变化的算术强度。左:状态为 float32,即 bytesCmn=4。右:状态为 bfloat16,即 bytesCmn=2。算术强度随块大小的增加而增加。

Roofline 分析

Roofline 模型。Roofline 模型是一个性能模型,它结合了算术强度和加速器规格来确定核函数峰值性能的上限。对于每个硬件加速器,我们绘制其 roofline,即在给定算术强度下加速器的最大性能:

$$\text{Roofline}(\alpha_{\text{FLOPs}}, \beta_{\text{Bytes}}) = \min (\beta_{\text{Bytes}} \cdot \mathcal{I}_{\text{algo}}, \alpha_{\text{FLOPs}})$$

TFLA mLSTMsig 前向传播的 Roofline 模型(图 22)。我们对 TFLA mLSTMsig 前向核函数进行了 roofline 分析。我们观察到,较小的块大小是内存密集型的,而较大的块大小是计算密集型的(与图 21 类似)。此外,我们发现我们在 NVIDIA H100 上基准测试的 TFLA Triton 核函数(用于 mLSTMsig)仍远未达到可实现的峰值性能(与 H100 roofline 的交点),这凸显了进一步优化的潜力。

图22. TFLA mLSTMsig 前向传播随块大小 L 变化的 Roofline 模型。我们测量了 dhv=512, pqk=0.5, Nhead=8, 批次大小为 8,序列长度为 8192 的 mLSTMsig 的性能,并将其与 NVIDIA V100, A100 和 H100 GPU 的 roofline 进行比较(根据表 18)。较小的块大小是内存密集型的,而较大的块大小是计算密集型的。我们的 TFLA Triton 核函数(用于 mLSTMsig)仍远未达到可实现的峰值性能(与 H100 roofline 的交点)。
图22. TFLA mLSTMsig 前向传播随块大小 L 变化的 Roofline 模型。我们测量了 dhv=512, pqk=0.5, Nhead=8, 批次大小为 8,序列长度为 8192 的 mLSTMsig 的性能,并将其与 NVIDIA V100, A100 和 H100 GPU 的 roofline 进行比较(根据表 18)。较小的块大小是内存密集型的,而较大的块大小是计算密集型的。我们的 TFLA Triton 核函数(用于 mLSTMsig)仍远未达到可实现的峰值性能(与 H100 roofline 的交点)。

FLOPs/s 不是 TFLA 的最优性能指标(图 23)。在图 22 中,我们观察到并且 roofline 模型也表明,我们可以通过增加块大小 L 来提高 TFLA 核函数的性能。然而,虽然这可能会提高 FLOPs/s 的性能,但可能不会减少总运行时间,因为随着块大小 L 的增加,FLOPs 的数量也会因二次项的增加而增加(见图 18)。我们在图 23 中证实了这一点,其中我们绘制了 FLOPs/s 和总 FLOPs 随块大小 L 的变化,并与实际运行时间进行了比较。我们观察到,FLOPs/s 随块大小 L 的增加而持续增加,但总 FLOPs 也在增加。因此,对于我们的 TFLA 核函数,我们应该使用实际运行时间作为最终的性能指标,并以此确定运行时最优的块大小。

图23. TFLA mLSTMsig 前向传播随块大小 L 变化的 FLOPs/s 和 FLOPs。我们比较了不同块大小下的 FLOPs/s、总 FLOPs 与实际运行时间。我们测量了 dhv=512, pqk=0.5, Nhead=8, 批次大小为 8,序列长度为 8192 的 mLSTMsig 的性能。FLOPs/s 随块大小 L 的增加而增加,但总 FLOPs 也在增加。实际运行时间不一定通过增加块大小 L 来最小化。
图23. TFLA mLSTMsig 前向传播随块大小 L 变化的 FLOPs/s 和 FLOPs。我们比较了不同块大小下的 FLOPs/s、总 FLOPs 与实际运行时间。我们测量了 dhv=512, pqk=0.5, Nhead=8, 批次大小为 8,序列长度为 8192 的 mLSTMsig 的性能。FLOPs/s 随块大小 L 的增加而增加,但总 FLOPs 也在增加。实际运行时间不一定通过增加块大小 L 来最小化。

运行时最优的块大小

运行时最优块大小的计算。为了计算 TFLA mLSTMsig 前向传播的运行时最优块大小 $L_{opt,Runtime}$,我们使用公式 (118) 的理论运行时 $\tau_{mLSTMsig}$。我们首先对理论运行时关于块大小 L 求导,并令导数为零以找到最小运行时:

$$\frac{\partial \tau_{\text {mLSTMsig }}}{\partial L}=0.$$

然后我们求解该方程得到运行时最优的块大小 $L_{opt,Runtime}$:

$$L_{\text {opt,Runtime }}=\sqrt{\frac{2 d_{h v}^{2} p_{q k}+5+2 \mathcal{I}_{\text {acc }} d_{h v}^{2} p_{q k} \text { bytes }_{C m n}}{2 F_{\text {causal }}\left(d_{h v}\left(1+p_{q k}\right)+3\right)+1}},$$

其中 $I_{acc} = \alpha_{FLOPs}/\beta_{Bytes}$ 是来自公式 (119) 的加速器算术强度。与 FLOP 最优块大小 $L_{opt,FLOP}$(见公式 (113))相比,运行时最优块大小 $L_{opt,Runtime}$ 还取决于我们硬件的算术强度 $I_{acc}$ 和我们状态的精度 $bytes_{Cmn}$。

运行时最优块大小依赖于模型架构(图 24)。运行时最优块大小与头维度 $d_{hv}$ 的平方根成正比(即 $O(\sqrt{d_{hv}})$),并且远大于 FLOP 最优块大小。我们在 NVIDIA H100 上测得的 mLSTMsig 的运行时最优块大小约为 256,这小于理论上的运行时最优块大小。这种差异是由于我们理论运行时模型中的近似以及我们的核函数尚未优化到在 NVIDIA H100 GPU 上达到峰值性能的事实。

图24. mLSTMsig 的运行时最优块大小 L_opt,Runtime 随头维度 dhv 的变化。左:内存单元状态为 float32。右:内存单元状态为 bfloat16。我们绘制了运行时最优块大小随头维度 dhv 的变化,其中 dqk = 0.5dhv (pqk=0.5)。与 FLOP 最优块大小(图 19)类似,运行时最优块大小与 dhv 的平方根成正比(即 O(sqrt(dhv))),但远大于 FLOP 最优块大小。
图24. mLSTMsig 的运行时最优块大小 L_opt,Runtime 随头维度 dhv 的变化。左:内存单元状态为 float32。右:内存单元状态为 bfloat16。我们绘制了运行时最优块大小随头维度 dhv 的变化,其中 dqk = 0.5dhv (pqk=0.5)。与 FLOP 最优块大小(图 19)类似,运行时最优块大小与 dhv 的平方根成正比(即 O(sqrt(dhv))),但远大于 FLOP 最优块大小。

运行时最优块大小依赖于硬件加速器(图 25)。运行时最优块大小与加速器强度 $I_{acc}$ 的平方根成正比(即 $O(\sqrt{I_{acc}})$)。更新的 GPU(如 H100)具有更高的加速器强度,这导致更大的运行时最优块大小。如果加速器强度增加的趋势持续下去,那么支持任意大块大小的 TFLA 将变得越来越重要。

图25. mLSTMsig 的运行时最优块大小 L_opt,Runtime 随硬件加速器强度的变化。我们绘制了不同头维度 dhv 下运行时最优块大小随硬件加速器强度的变化。我们突出了常见 GPU(如 V100, A100, H100)的 NVIDIA GPU 加速器强度。运行时最优块大小随加速器强度的增加而增加(即 L_opt,Runtime ∝ sqrt(O(I_acc)))。
图25. mLSTMsig 的运行时最优块大小 L_opt,Runtime 随硬件加速器强度的变化。我们绘制了不同头维度 dhv 下运行时最优块大小随硬件加速器强度的变化。我们突出了常见 GPU(如 V100, A100, H100)的 NVIDIA GPU 加速器强度。运行时最优块大小随加速器强度的增加而增加(即 L_opt,Runtime ∝ sqrt(O(I_acc)))。

FLOP 和理论运行时分析总结

核心优势。可配置的块大小 L 是 Tiled Flash Linear Attention 的核心优势。我们总结关于块大小的论述如下:
1. 块大小 L 在运行时和 GPU 内存使用之间进行权衡。【图 6】
2. L 决定了总计算量(FLOPs):L = 1 匹配循环公式,而 L = T 匹配并行公式。【图 18】
3. 存在一个最优块大小 L ∈ [1, T],可以最小化总 FLOP 数量。【公式 (113),图 18,图 19】
4. 增加 L 会提高 TFLA 核函数的算术强度。【公式 (119),图 21】
5. 块大小决定了核函数在给定硬件上是内存密集型还是计算密集型。【图 22,图 21】
6. 仅凭 FLOPs/s 指标可能具有误导性;应根据总运行时来选择最优块大小。【图 23,图 20】
7. 运行时最优的块大小与头维度的平方根以及加速器的计算强度成正比。【图 24,图 25】
8. 更新的硬件代系需要更大的块大小才能接近峰值性能。【图 25,图 20】