Gated Linear Attention Transformers with Hardware-Efficient Training

作者/机构: Songlin Yang, Bailin Wang, Yikang Shen, Rameswar Panda, Yoon Kim

A1 主要贡献

本文致力于解决现有线性注意力模型相较于标准Softmax注意力在性能和实际运行速度上的不足。

核心问题:
1. 性能差距: 尽管线性注意力通过其循环(RNN)形式实现了线性时间复杂度的推理,但在语言建模等任务上,其性能通常显著低于标准的Softmax注意力。
2. 效率问题: 现有的线性注意力实现(包括循环、并行和分块并行形式)没有充分考虑硬件的I/O特性,导致在实际运行中,特别是在中等序列长度上,其速度慢于高度优化的Softmax注意力实现(如FLASHATTENTION)。
3. 门控机制:现有的线性注意力变体虽然引入了衰减因子以提升性能,但这些因子通常是全局且数据无关的,而数据依赖的门控机制在传统RNN中已被证明至关重要。

研究目标与创新点:
1. 提出硬件高效的线性注意力算法 (FLASHLINEARATTENTION)
* 本文设计了一种I/O感知的硬件高效算法,用于训练线性注意力模型。该算法通过在内存移动和并行性之间进行权衡,实现了显著的速度提升。
* 该算法的实现(名为FLASHLINEARATTENTION)作为一个独立的层,即使在短序列(如1K)上,其速度也超过了FLASHATTENTION-2。

  1. 提出门控线性注意力 (Gated Linear Attention, GLA)

    • 将FLASHLINEARATTENTION算法推广到一种更具表现力的、带有数据依赖门控机制的线性注意力变体。
    • 这种门控机制允许模型根据输入动态地“遗忘”或保留信息,解决了标准线性注意力因缺乏衰减项而难以处理长上下文任务的潜在不稳定性。
  2. 构建并验证GLA Transformer模型:

    • 将GLA层作为标准注意力层的替代品,构建了GLA Transformer模型。
    • 在中等规模的语言建模实验(在15B/100B tokens上训练340M/1.3B参数模型)中,GLA Transformer的性能与强大的LLaMA架构以及最新的线性时间推理模型(如RetNet和Mamba)相当。
    • GLA Transformer在长度泛化方面表现尤为出色,一个在2K长度上训练的模型能够泛化到超过20K的序列,而困惑度没有显著下降。
    • 在训练速度方面,GLA Transformer的吞吐量高于同等规模的Mamba模型。

A3 背景知识

2.1 并行与循环形式

标准Softmax注意力。自回归Transformer采用Softmax注意力机制,它接收一个输入序列 $X \in R^{L \times d}$($L$为长度,$d$为隐藏维度),并通过以下公式计算输出 $O \in R^{L \times d}$:

公式
公式

其中 $W_Q, W_K, W_V \in R^{d \times d}$ 是可学习的矩阵,$M \in \{-\infty, 1\}^{L \times L}$ 是一个掩码,防止模型关注未来的词元。这种并行形式可以在给定完整输入 $X$ 的情况下并行计算 $O$,从而实现高效训练。然而,在推理过程中,Transformer必须使用循环形式,即维护一个不断增长的键(key)和值(value)集合(KV缓存),并对它们执行注意力操作。
公式
公式

线性注意力。线性注意力机制将 $exp(q_t k_i^T)$ 替换为一个核函数 $k(x,y)$,该函数有关联的特征图 $\phi$(即 $k(x,y) = \langle\phi(x), \phi(y)\rangle$)。这简化了输出 $o_t$ 的计算,因为求和操作可以与矩阵乘法交换顺序:

公式
公式

通过定义 $S_t = \sum_{i=1}^{t} \phi(k_i)^T v_i$ 和 $z_t = \sum_{i=1}^{t} \phi(k_i)^T$,其中 $S_t \in R^{d \times d}, z_t \in R^{d \times 1}$,上述计算可以重写为一个RNN:
公式
公式

尽管已有多种核函数被探索,但最近的研究发现,使用线性核(即设置 $\phi$ 为恒等函数)且不带归一化项在实践中效果很好【索引83,Retentive network: A successor to transformer for large language models,2023a】。这产生了一个(未归一化的)线性注意力层,其更新方程如下:
公式
公式

从公式1可以清楚地看到,线性注意力层本质上是一个线性循环层,其拥有矩阵值的隐藏状态 $S_t$,该状态通过外积 $k_t^T v_t$ 进行更新。因果线性注意力的并行形式的复杂度仍然是 $L$ 的平方,其形式为 $O = (QK^T) \odot M V$,其中 $M$ 是一个下三角掩码矩阵。由于掩码 $M$ 的存在,无法利用矩阵乘法的结合律将并行形式的复杂度从二次降为线性。

2.2 分块并行形式

平衡并行与循环。线性注意力的分块并行形式在并行形式和循环形式之间取得了平衡【索引38,Transformer quality in linear time,2022;索引83,Retentive network: A successor to transformer for large language models,2023a】,并允许进行次二次方、部分并行的训练。

分块计算过程。具体来说,假设输入 $X$ 被分成多个不重叠的块,每个块的长度为 $C$。令 $S[i] \in R^{d \times d}$ 是处理完 $i$ 个块后的块级隐藏状态,即 $S[i] := S_{iC}$。令 $Q[i], K[i], V[i]$ 分别是第 $i$ 个块对应的查询、键和值向量。我们有以下块间的递推关系:

公式
公式

这里 $S[0]$ 可以初始化为零或来自前一个段落的隐藏状态。一个块内所有RNN输入的总和(即 $K[i]^T V[i]$)可以并行计算,复杂度为 $O(C^2d)$。

块内并行计算。输出的块内并行计算由以下公式给出:

公式
公式

其中,“块内”分量 $O_{intra}[i+1]$ 的计算形式与标准并行形式完全相同,耗时 $O(C^2d + Cd^2)$。“块间”分量 $O_{inter}[i+1]$ 计算了来自前一个块的隐藏状态的贡献,耗时 $O(Cd^2)$。因此,总的训练复杂度为 $O(\frac{L}{C}(C^2d + Cd^2)) = O(LCd + Ld^2)$,当 $L > d$ 时,这小于 $O(L^2d)$。值得注意的是,设置 $C = L$ 恢复了并行形式,而 $C = 1$ 则恢复了循环形式。

A2 方法细节

3 硬件高效的线性注意力

本节描述了FLASHLINEARATTENTION,这是一种I/O感知、硬件高效的线性注意力算法,其精神类似于FLASHATTENTION【索引21,Flashattention: Fast and memory-efficient exact attention with io-awareness,2022b;索引18,Flashattention-2: Faster attention with better parallelism and work partitioning,2023】。

3.1 硬件高效算法的原则

高效算法的设计考量。一个高效的算法应该了解现代硬件的计算模型、内存层次结构和专用计算单元。
- 占用率 (Occupancy):GPU拥有大量并行执行的线程,这些线程被分组为线程块,在流式多处理器(SM)上执行。为了保持高GPU占用率,必须使用足够数量的SM。在批次大小较小的大规模训练和长序列建模场景中,沿时间维度进行并行化可以实现高GPU占用率【索引18,Flashattention-2: Faster attention with better parallelism and work partitioning,2023】。
- 专用计算单元 (Specialized compute units):现代神经网络训练硬件通常配备专用计算单元(如NVIDIA GPU上的张量核心),这些单元可以显著加速矩阵乘法。例如,在A100上,使用张量核心的半精度矩阵乘法比使用CUDA核心快大约16倍。
- 内存层次结构 (Memory hierarchy):GPU拥有一个内存层次结构,包括容量较大但速度较慢的全局GPU内存(HBM)和容量较小但速度更快的共享内存(SRAM)。优化SRAM的使用以减少HBM的I/O成本可以带来显著的速度提升。

3.2 线性注意力的硬件考量

不同形式的硬件效率分析。本节讨论与不同形式线性注意力效率相关的硬件因素。
- 循环形式 (Recurrent form):一个基础的实现会将所有时间步的二维隐藏状态存储在HBM中,导致高I/O成本【索引62,Fine-tuning pre-trained transformers into decaying fast weights,2022】。虽然可以通过在反向传播中重新计算隐藏状态来减少I/O,但循环更新中的逐元素操作无法利用张量核心,导致算术强度低。因此,尽管循环形式的总浮点运算量(FLOPs)最低,但这并未转化为实际的运行时间效率。通过并行扫描算法并行化线性递推在理论上是可行的,但需要物化每个时间步的隐藏状态,这会带来巨大的内存I/O负担,抵消了并行化的好处【索引44,Gateloop: Fully data-controlled linear recurrence for sequence modeling,2023】。
- 并行形式 (Parallel form):并行形式可以像FLASHATTENTION一样高效,采用类似的I/O优化技术【索引76,Scaling transnormer to 175 billion parameters,2023b】。然而,由于其二次方的复杂度,大量的FLOPs使得长序列训练成本高昂。
- 分块形式 (Chunkwise form):分块并行形式通过一个额外的参数 $C$ 在并行和循环形式之间进行插值,使得进行细粒度优化变得更容易。与循环形式不同,大多数操作可以通过矩阵乘法完成,从而能够使用张量核心。尽管分块训练算法在文献中已被讨论【索引38,Transformer quality in linear time,2022;索引83,Retentive network: A successor to transformer for large language models,2023a】,但大多数实现并非I/O感知,因此在中等序列长度(例如2K-4K)上比FLASHATTENTION慢。

3.3 FLASHLINEARATTENTION: 采用分块形式的硬件高效线性注意力

I/O感知的实现。我们描述了分块形式的I/O感知、硬件高效实现。我们提供了两个版本,其前向和后向传播根据块级隐藏状态 $S[n]$ 是否物化到HBM中而有所不同。算法1和图1展示了前向传播过程。核心思想是使用分块(tiling)技术逐块加载张量,并在芯片上(SRAM)重复使用张量块,以尽可能避免多次HBM I/O。例如,当 $Q[n]$ 加载到SRAM时,可以在芯片上同时计算 $Q[n]S$ 和 $(Q[n]K^T[n] \odot M)V[n]$,避免了两次加载 $Q[n]$,从而节省了HBM I/O。

算法1:FLASHLINEARATTENTION:前向传播
算法1:FLASHLINEARATTENTION:前向传播

图1:(a) FLASHLINEARATTENTION 非物化版本。此版本更节省内存。(b-c) FLASHLINEARATTENTION 物化版本。此版本支持序列级的分块并行。
图1:(a) FLASHLINEARATTENTION 非物化版本。此版本更节省内存。(b-c) FLASHLINEARATTENTION 物化版本。此版本支持序列级的分块并行。

公式
公式

两个版本的对比
- 非物化版本 (non-materialization):按顺序计算每个块的输出 $O[n]$,使用SRAM临时存储隐藏状态 $S[n]$,因此内存效率高。这个版本在批次大小、头数量和头维度上进行并行化,但缺少序列级别的并行性。当批次大小很大时,这种并行度足以实现高GPU占用率。
- 物化版本 (materialization):首先执行块间递推(公式2)并将所有 $S[n]$ 存储在HBM中。然后,所有块的 $O[n]$ 可以并行计算。这种方法提供了更好的并行性,但内存占用增加了约10-20%。我们通过重计算来缓解这个问题:隐藏状态在前向传播后被丢弃,在反向传播时重新计算。这引入了微小的运行时开销,但显著减少了内存占用,我们默认采用此策略。

性能评估。图2显示了我们实现的运行速度和内存占用。两个版本的FLASHLINEARATTENTION都明显快于FLASHATTENTION-2【索引18,Flashattention-2: Faster attention with better parallelism and work partitioning,2023】和纯PyTorch(即非I/O感知)实现的分块线性注意力,展示了I/O感知的优势。

图2:在单张H100 GPU上的速度比较,批次大小为32,头数量为16,头维度为64,块大小为64。x轴和y轴均为对数尺度。w/ m. 和 w/o m. 分别表示使用和不使用在HBM中物化隐藏状态的FLASHLINEARATTENTION。
图2:在单张H100 GPU上的速度比较,批次大小为32,头数量为16,头维度为64,块大小为64。x轴和y轴均为对数尺度。w/ m. 和 w/o m. 分别表示使用和不使用在HBM中物化隐藏状态的FLASHLINEARATTENTION。

4 门控线性注意力

引入门控机制的动机。标准线性注意力(公式1)缺少衰减项或遗忘门,这在RNN中被证明是至关重要的【索引35,Long short-term memory,1997;索引14,Learning phrase representations using rnn encoder-decoder for statistical machine translation,2014;索引86,The unreasonable effectiveness of the forget gate,2018】。缺少衰减项使得模型难以“遗忘”信息,并被认为是导致线性注意力在长上下文任务中不稳定的部分原因【索引11,Linear Transformers Are Faster After All,2024】。虽然最近的工作通过引入一个全局的、非数据依赖的衰减因子 $\gamma \in (0,1)$ 来提升性能【索引83,Retentive network: A successor to transformer for large language models,2023a;索引76,Scaling transnormer to 175 billion parameters,2023b】,但本文考虑使用一种数据依赖的门控机制。我们证明,尽管门控因子更具表现力,由此产生的门控线性注意力(GLA)层仍然支持硬件高效的分块形式以进行高效训练。

4.1 GLA的循环与并行形式

循环形式。GLA拥有一个随时间变化的二维遗忘门 $G_t \in (0,1)^{d_k \times d_v}$:

公式
公式

这种基于哈达玛积的循环形式非常通用,涵盖了许多近期的具有二维隐藏状态的RNN,如表1所示。
表1:近期模型在门控线性注意力框架下的表述,它们在Gt的参数化上有所不同。偏置项被省略。
表1:近期模型在门控线性注意力框架下的表述,它们在Gt的参数化上有所不同。偏置项被省略。

门控矩阵Gt的参数化。GLA设计的核心在于 $G_t$ 的参数化,需要在参数效率、状态大小和训练效率之间取得平衡。一个简单地将 $x_t$ 映射到 $G_t$ 的方法需要一个大小为 $d \cdot d_k \cdot d_v$ 的矩阵,参数效率低下。Mao (2022)【索引62,Fine-tuning pre-trained transformers into decaying fast weights,2022】提出了一种更高效的基于外积的低秩参数化 ($G_t = \alpha_t^T \beta_t$)。在Mamba中【索引28,Mamba: Linear-time sequence modeling with selective state spaces,2023】, $G_t$ 是通过将一个数据无关的可学习矩阵 $A$ 与一个数据依赖的向量 $\alpha_t$ 结合得到的,这使得矩阵可以是满秩的,但也阻止了张量核心的使用。Mamba-2【索引19,Transformers are ssms: Generalized models and efficient algorithms through structured state space duality,2024】通过一个更受限的门控机制 $G_t = \gamma_t \mathbf{1}^T \mathbf{1}$(其中 $\gamma_t$ 是一个标量)解决了这个问题,使其能够利用张量核心。

本文采用的参数化。本文在标量和完全低秩参数化之间采取了折中方案,使用 $G_t = \alpha_t^T \mathbf{1}$。这得到了以下循环形式:

公式
公式

其中 $\alpha_t$ 是通过一个低秩线性层后接sigmoid函数在 $x_t$ 上参数化的。这个形式是通用的,可以涵盖多个近期的RNN模型【索引44,Gateloop: Fully data-controlled linear recurrence for sequence modeling,2023;索引79,Hgrn2: Gated linear rnns with state expansion,2024b;索引71,Eagle and finch: Rwkv with matrixvalued states and dynamic recurrence,2024】。

并行形式。展开上述循环公式,可以得到一个并行形式:

公式
公式

令 $b_t := \prod_{j=1}^{t} \alpha_j$,我们可以重写为:
公式
公式

这可以表示为一个并行形式:
公式
公式

然而,这种形式在数值上不稳定,因为 $b_t$ 是门控值 $\alpha_j \in (0,1)^{1 \times d}$ 的累积乘积,当 $t$ 很大时可能变得非常小。为了解决这个问题,可以在对数空间中计算:
公式
公式

但与普通线性注意力不同,公式4不能通过标准矩阵乘法表示,因此无法利用张量核心上的半精度矩阵乘法。

4.2 GLA的分块并行形式

推导分块形式。我们推导了GLA的分块形式,类似于基础线性注意力的分块形式(§2.2)。块内操作在块级别上实现上述并行形式以获得 $O_{intra}$。对于块间计算,我们有:

公式
公式

直观上,$\Lambda[i+1]$ 编码了从块开始的累积衰减,用于传播前一个块的隐藏状态 $S[i]$;而 $\Gamma[i+1]$ 编码了到块结束的衰减,用于累积信息以添加到下一个隐藏状态 $S[i+1]$。
图3:用于说明GLA中分块计算的注意力风格图。块间依赖(灰色)在分块形式中不直接计算。块内依赖通过二级分块/平铺建模,其中子块间部分(橙色)由半精度矩阵乘法计算,而子块内部(粉色)在对数空间中以全精度计算。
图3:用于说明GLA中分块计算的注意力风格图。块间依赖(灰色)在分块形式中不直接计算。块内依赖通过二级分块/平铺建模,其中子块间部分(橙色)由半精度矩阵乘法计算,而子块内部(粉色)在对数空间中以全精度计算。

4.3 硬件高效的GLA

适配FLASHLINEARATTENTION。有了分块形式,我们可以将§3中提出的FLASHLINEARATTENTION算法应用于门控情况。这种适配还依赖于以下两个关键技术。
- 二级分块 (Secondary-level chunking)。由于GLA的块内计算因对数空间计算(公式4)而无法利用半精度矩阵乘法,我们采用二级分块方案,将一个块进一步划分为子块(另一层次的分块)。如图3所示,子块之间的交互通过半精度矩阵乘法计算,而子块内部的计算则采用公式4,在全精度下进行以保证稳定性。这种两级分块策略大大减少了非半精度矩阵乘法的FLOPs,从而带来了运行时间的改善。
- 内存高效的dαt计算 (Memory-efficient dαt computation)。过去的工作声称类似GLA的模型必须在HBM中物化大小为 $L \times d \times d$ 的隐藏状态来计算所有梯度 $d\alpha_t$【索引62,Fine-tuning pre-trained transformers into decaying fast weights,2022】。我们为此给出了一个关于 $d\log\alpha_t$ 的闭式解公式:

公式
公式

这个公式可以通过对公式4求导轻易得到,从而避免了大规模的物化操作。

4.4 GLA Transformer

模型架构。我们将GLA层推广到多头情况,并将其与前馈网络(FFN)交错构建成一个Transformer风格的模型。
- 多头GLA层:每个头独立计算门控线性注意力,然后将各头的输出拼接起来。在每个头的输出之后应用LayerNorm。

公式
公式

- Transformer块:模型块的结构如下,其中FFN层采用SwiGLU变体【索引85,Llama: Open and efficient foundation language models,2023】。
公式
公式

公式
公式

- 参数分配:为了保持参数效率,用于预测门控值 $\alpha_t$ 和 $r_t$ 的矩阵 $W_\alpha, W_r$ 采用了低秩参数化。
公式
公式

通过这种设计,一个GLA层的总参数量(大约 $4d^2$)与常规的Softmax注意力层相当。

A4 实验环境

数据集:
- 名称: SlimPajama【索引81,SlimPajama: A 627B token cleaned and deduplicated version of RedPajama,2023】
- 规模: 使用了其中一个100B词元(token)的子集。
- 用途: 用于语言模型的预训练。

模型架构:
- 规模: 训练了两种规模的模型:340M和1.3B参数。
- 基线模型:
- Transformer++【索引85,Llama: Open and efficient foundation language models,2023】: LLaMA架构,使用旋转位置嵌入(RoPE)、SWiGLU激活函数和RMSNorm。
- RetNet【索引83,Retentive network: A successor to transformer for large language models,2023a】: 为了公平比较,将其原始FFN替换为SwiGLU。
- Mamba【索引28,Mamba: Linear-time sequence modeling with selective state spaces,2023】: 使用其开源代码实现。
- 训练细节: 所有模型均在相同数据集上从头开始训练相同数量的词元。使用AdamW优化器,最大学习率为3e-4,余弦学习率衰减。340M模型在15B词元上训练,批次大小为0.5M;1.3B模型在100B词元上训练,批次大小为2M。

硬件配置:
- GPU: H100 GPU(根据图2和图6的描述)。

软件配置:
- 代码实现: 使用PyTorch实现,并为FLASHLINEARATTENTION开发了自定义CUDA核。
- 分词器: Mistral tokenizer【索引40,Mistral 7b,2023】。
- 评估工具: LM evaluation harness【索引26,A framework for few-shot language model evaluation,2021】。

A4 实验结果

主要结果:
- 与基线模型的比较 (表2):
- GLA vs. RetNet: 带有数据依赖门的GLA Transformer在所有任务上都优于使用数据无关衰减率的RetNet。
- GLA vs. Transformer++/Mamba: GLA Transformer和Mamba的性能与强大的Transformer++基线相当。

表2:GLA Transformer与Transformer++、RetNet和Mamba的对比结果。所有模型均在SlimPajama数据集的同一子集上使用Mistral分词器进行训练。340M/1.3B模型分别训练15B/100B词元。单个任务的性能是零样本测试。我们报告了与Gu & Dao (2023)相同的任务集上的主要结果。最后一列显示了所有使用(归一化)准确率作为指标的基准测试的平均值。
表2:GLA Transformer与Transformer++、RetNet和Mamba的对比结果。所有模型均在SlimPajama数据集的同一子集上使用Mistral分词器进行训练。340M/1.3B模型分别训练15B/100B词元。单个任务的性能是零样本测试。我们报告了与Gu & Dao (2023)相同的任务集上的主要结果。最后一列显示了所有使用(归一化)准确率作为指标的基准测试的平均值。

召回密集型任务 (Recall-intensive tasks):
- 合成MQAR任务 (图4): 在这个多查询关联召回任务中,具有矩阵值隐藏状态的模型(Mamba/RetNet/GLA)优于Hyena/RWKV。GLA的性能优于RetNet,证实了数据依赖门的益处。

图4:在合成MQAR任务上的准确率(%)。
图4:在合成MQAR任务上的准确率(%)。

- 真实召回任务 (表3): 在FDA和SWDE等信息提取任务上,虽然所有次二次方模型都显著落后于Transformer,但GLA优于其他次二次方模型。这可能归因于其比Mamba更大的循环状态和比RetNet更优的选择机制。
表3:不同模型在Arora et al. (2024)测试的三个召回密集型任务中的比较。所有任务分数越高越好。
表3:不同模型在Arora et al. (2024)测试的三个召回密集型任务中的比较。所有任务分数越高越好。

长序列训练与长度外推:
- 外推能力 (图5): 在2K长度上下文上训练后,GLA在PG19测试集上的外推能力优于Mamba和RetNet。Mamba在4K后性能下降,而GLA/RetNet能在SlimPajama测试集上泛化到18K。
- 长序列训练: 在更长的上下文(8K直接训练,或通过截断式BPTT在24K上训练)上进行预训练,可以持续提升所有三个模型的性能。对于GLA,使用截断式BPTT是一种更经济的长序列训练方法。

图5:在SlimPajama和PG19测试集上的长度外推。我们从头开始在SlimPajama上使用不同的训练长度预训练1.3B模型100B词元。*表示模型使用截断式BPTT,超过12个段,每个段长度为2K。
图5:在SlimPajama和PG19测试集上的长度外推。我们从头开始在SlimPajama上使用不同的训练长度预训练1.3B模型100B词元。*表示模型使用截断式BPTT,超过12个段,每个段长度为2K。

消融研究:
- 门控机制的重要性 (表4): 实验结果表明,虽然数据依赖的标量门控(Scalar gates)比RetNet有显著改进,但更细粒度的向量门控(GLA)是必要的,能带来更好的性能。
- 头维度的影响 (表4): 默认GLA使用4个头。增加到8个头(更小的头维度)导致困惑度明显下降;减少到1个头(更大的头维度)性能最好,但改进微小且需要更多GPU内存。因此,4个头是一个很好的权衡。

表4:在340M模型上训练7B词元的消融研究结果。我们通过最后200个训练步骤的平均困惑度来评估模型变体。
表4:在340M模型上训练7B词元的消融研究结果。我们通过最后200个训练步骤的平均困惑度来评估模型变体。

训练效率:
- 吞吐量与内存 (图6): 所有模型都具有线性空间复杂度,GPU内存占用差异很小。在训练吞吐量方面,Mamba落后于Transformer++和GLA。GLA在训练长度超过4096时显示出更大的优势。

图6:在H100上的训练吞吐量和内存占用。
图6:在H100上的训练吞吐量和内存占用。

A7 补充细节

5.4 局限性与未来工作

当前研究的局限性。由于计算资源有限,本文的实验未能进行更大规模的验证。因此,GLA如何扩展到更大的模型和数据集上尚不清楚。

未来工作的展望。预计随着模型规模的扩大,GLA的训练效率优势会比Mamba更加明显,因为它能更好地利用张量核心,并且与张量并行兼容。未来的一个有趣方向是将GLA应用于其他模态,特别是那些具有长程依赖性的数据,以利用线性注意力的效率优势。

6 相关工作

线性RNNs/SSMs/Transformers。传统的RNN由于隐藏状态之间的非线性依赖和昂贵的顺序更新而难以扩展。线性RNNs、状态空间模型(SSMs)和线性Transformer消除了非线性依赖,使得沿时间维度的训练可以并行化,成为Transformer架构有竞争力的次二次方替代方案。

数据依赖的衰减率。数据依赖的衰减率一直被认为是RNN的关键。与以往依赖于前一隐藏状态和当前输入的遗忘门不同,Martin & Cundy (2018)【索引63,Parallelizing linear recurrent neural nets over sequence length,2018】提出遗忘门应仅依赖于当前输入以实现并行训练,这一策略在HGRN【索引76,Scaling transnormer to 175 billion parameters,2023b】等模型中被证明有效。RWKV-v6【索引71,Eagle and finch: Rwkv with matrixvalued states and dynamic recurrence,2024】和Mamba【索引28,Mamba: Linear-time sequence modeling with selective state spaces,2023】也使用了类似的数据依赖衰减率。

隐藏状态维度的扩展。RNNs依赖于固定维度的隐藏状态来编码历史信息,其维度大小影响模型的表达能力。线性Transformer通过外积参数化来扩展隐藏维度。而线性SSM则通过单输入单输出(SISO)策略来扩展。当SSM参数是数据无关时,可以通过快速傅里叶变换(FFT)高效训练。但对于数据依赖的SSM参数,如Mamba【索引28,Mamba: Linear-time sequence modeling with selective state spaces,2023】,则无法使用FFT,需要使用并行扫描算法的自定义CUDA核进行训练。为了将所有隐藏状态放入SRAM,Mamba的隐藏状态扩展率最高只能到16。相比之下,本文提出的硬件感知训练算法为扩展隐藏维度提供了一种替代的高效方法,使其可以扩展到更宽的范围,这在召回密集型任务中被证明是有用的。

A5 结论

本文提出了一种用于训练带有数据依赖门控机制的线性注意力Transformer的高效算法。该算法通过在浮点运算量(FLOPs)和并行性之间取得平衡,同时允许使用现代GPU上的张量核心单元进行半精度矩阵乘法,从而实现了高效训练。在语言建模上的实验表明,门控线性注意力(GLA)Transformer的性能与强大的基线模型相当。

A6 附录

A 扩展相关工作

A.1 线性注意力
- 特征图 $\phi$: 线性注意力用核函数 $k(x,y) = \langle\phi(x),\phi(y)\rangle$ 替代指数函数。$\phi$ 通常由两部分组成:一个线性映射 $\phi_1$ 和一个逐元素的激活函数 $\phi_2$。本文遵循Sun等人【索引83,Retentive network: A successor to transformer for large language models,2023a】和Mao【索引62,Fine-tuning pre-trained transformers into decaying fast weights,2022】的做法,使用恒等映射 $\phi=I$。
- 注意力尖锐度 (Spikiness): 线性注意力存在“注意力稀释”问题,即注意力分布过于均匀。解决方法包括增加局部注意力层,或使用缩放的逐元素指数映射 $\phi(x) = \exp(t \cdot x)$ 来集中注意力。
- 内存容量 (Memory capacity): 线性注意力的内存大小是有限的,而Softmax注意力则无界。增加内存大小的关键是高效地增加键的维度 $d_{key}$ 或特征维度 $d_{dot}$。后者可以通过高阶多项式核或确定性无参数投影(DPFP)等参数高效的方式实现。
- 带衰减或门的线性注意力: 先前工作已探索了位置相关的标量门或矩阵值的门来控制内存。标量衰减可以轻松地整合到分块线性注意力中以实现高效训练,但矩阵值的门在训练效率上面临更大挑战。本文提出的硬件高效训练算法解决了这一问题,减少或消除了对隐藏状态的物化,并能够利用张量核心。
- I/O感知的块状线性注意力: 块状线性注意力的形式在文献中是已知的,但大多数实现并非I/O感知。据我们所知,只有LIGHTNINGATTENTION-2【索引78,Lightning attention-2: A free lunch for handling unlimited sequence lengths in large language models,2024a】(与本文同期工作)是I/O感知的,且与我们的FLASHLINEARATTENTION的非物化版本相似。我们额外提出了一个物化版本,利用序列级并行性来提高训练吞吐量。

A.2 序列并行
- 与并行前缀和的相似性: 线性Transformer的分块并行形式类似于两阶段并行前缀和(并行扫描)算法,都结合了块内并行计算和块间通信。
- 与注意力序列并行的区别: 它也类似于用于加速基于注意力的Transformer的序列并行【索引50,Sequence parallelism: Long sequence training from system perspective,2023b】。主要区别在于:(i)线性Transformer的块级并行形式由于其线性复杂性只需要单次传递,而Transformer中的序列并行由于其固有的二次复杂性需要 L/C 次传递;(ii)矩阵乘法的顺序不同。分块线性注意力还可以大大减少分布式训练中设备间的通信成本。

A.3 硬件感知算法
- 理论与实践的差距: 许多算法在理论上很快,但在实践中由于与硬件属性不匹配而变慢。
- I/O感知的重要性: 减少HBM I/O成本(通过分块和重计算等技术)并尽可能利用张量核心至关重要。FLASHLINEARATTENTION在精神上类似于FLASHATTENTION【索引21,Flashattention: Fast and memory-efficient exact attention with io-awareness,2022b;索引18,Flashattention-2: Faster attention with better parallelism and work partitioning,2023】和FLASHCONVFFT【索引25,Flashfftconv: Efficient convolutions for long sequences with tensor cores,2023d】,它们都实现了神经网络层的I/O感知版本,以实现实际的运行时间加速。

B 分块(门控)线性注意力的细节

FLASHLINEARATTENTION的反向传播。算法2列出了线性注意力的反向传播伪代码。

算法2 FLASHLINEARATTENTION: 反向传播
算法2 FLASHLINEARATTENTION: 反向传播

GLA的伪代码。算法3和4展示了物化版本GLA的前向/后向传播;算法5和6展示了非物化版本。列表1以Pytorch风格代码展示了二级分块算法。
- 算法3:门控线性注意力的前向传播(带物化)

算法3:门控线性注意力的前向传播(带物化)
算法3:门控线性注意力的前向传播(带物化)

- 算法4:门控线性注意力的反向传播(带物化)
算法4:门控线性注意力的反向传播(带物化)
算法4:门控线性注意力的反向传播(带物化)

- 算法5:门控线性注意力的前向传播(不带物化)
算法5:门控线性注意力的前向传播(不带物化)
算法5:门控线性注意力的前向传播(不带物化)

- 算法6:门控线性注意力的反向传播(不带物化)
算法6:门控线性注意力的反向传播(不带物化)
算法6:门控线性注意力的反向传播(不带物化)

- 列表1:用于训练GLA的两级分块算法的类Pytorch代码片段

def gated_linear_attention_forward (Q , K , V , a , C , c ) :
    # Q/K/V: query / key / value
    # a: log forget gate
    # C/c: chunk size , subchunk size
    # L: sequence length , d: head dimension
    L , d_k = Q.shape
    d_v = V.shape[-1]
    S = torch.zeros(d_k, d_v)
    O = torch.empty_like(V)
    # cumsum of log decay within a chunk
    B = torch.empty_like(a)
    # local compute of cumulative product of decay within a chunk
    for i in range(0, L // C):
        b = torch.zeros(d_k)
        for j in range(0, C):
            b += a[i]
            B[i] = b
    for i in range(0, L // C):
        r = range(i*C ,(i +1) * C)
        # (C, d) chunking
        bq , bk , bv , bb = Q[r] , K[r] , V[r] , B[r]
        b = bb[-1, None]
        #inter-chunk w/ matmul
        q, k, g = bq * (bb.exp()), bk * ((b - bb).exp()), b.exp()
        o = q @ S
        # hidden state update
        S = g.t() * S + k.t() @ bv
        #intra-chunk (secondary chunking)
        for j in range(0, C // c):
            t = range(j *c , (j +1) * c)
            #(c, head_dim ) subchunking
            q , k , v , b = bq[t] , bk[t] , bv[t] , bb[t]
            p = torch.zeros(c, c)
            #intra-subchunk w/o matmul .
            for m in range(c):
                for n in range(m +1):
                    p[m, n] = torch.sum(q[m] * k[n] * ((b[m] - b[n]).exp()))
            o[t] += p @ v
            # inter-subchunk w/ matmul
            z = b[0, None]
            q = q * (b - z).exp()
            for u in range(0, j):
                y = range(u *c , (u +1) * c)
                p = q @ (bk[y] * (z - bb[y]).exp()).t()
                o[t] += p @ bv[y]
        O[r] = o
    return O

dlogαt的推导。通过展开循环公式,并对 $o_t$ 关于 $\log b_t$(其中 $b_t$ 是累积门控值)求导,可以得到 $d\log b_t$ 的表达式。由于 $\log b_t = \sum_{i=1}^{t} \log\alpha_i$,那么 $d\log\alpha_t$ 就是 $d\log b_i$ 从 $i=t$ 到 $L$ 的累积和。这个过程避免了直接计算与 $S_{t-1}$ 相关的梯度,从而实现了内存高效。

dlogαt的梯度形式
dlogαt的梯度形式

展开循环
展开循环

关于查询/键向量的梯度
关于查询/键向量的梯度

关于累积门控对数的梯度
关于累积门控对数的梯度

dlogbt的最终形式
dlogbt的最终形式

C 通用门控线性注意力

通用形式。在主论文中,我们使用了一个简化的参数化,其中 $G_t = \alpha_t^T \mathbf{1}$。这里探讨了更通用的形式 $G_t = \alpha_t^T \beta_t$。

公式
公式

尽管实验发现学习 $\beta$ 并没有带来性能提升,但该通用形式仍然支持并行和分块形式,这可能对未来的线性注意力模型发展有用。

C.1 并行形式。通过展开循环并利用克罗内克/外积的混合积性质,可以推导出该通用形式的并行形式。最终的并行形式可以表示为 $O = ( (Q \odot B) (K/B)^T \odot M ) (V \odot D)$,其中 $B$ 和 $D$ 分别是 $\alpha$ 和 $\beta$ 的累积乘积矩阵。

公式推导过程
公式推导过程

C.2 分块并行形式。该通用形式同样可以推导出分块并行形式,用于高效训练。块间递推和块内并行计算的公式被相应地推广,以包含 $\alpha$ 和 $\beta$ 两个门控向量。这个分块形式是现有多种线性注意力形式的泛化,例如,通过设置不同的门控值,它可以退化为普通线性注意力或RetNet的分块形式。

内存高效的梯度计算。对于通用形式,关于 $\alpha$ 和 $\beta$ 的梯度也存在闭式解,允许在不实例化隐藏状态 $S$ 的情况下计算 $d\alpha$ 和 $d\beta$,从而实现内存高效。

梯度公式
梯度公式

D 额外的实验结果

扩展性能结果。表5展示了在所有11个任务上的完整结果,包括1.3B模型的5-shot结果。这些结果进一步证实了GLA模型与强基线模型(如Transformer++和Mamba)相比具有竞争力。

表5:扩展的零样本和五样本性能结果。所有模型都在SlimPajama数据集的同一子集上使用Mistral分词器进行训练。340M/1.3B模型分别训练了15B/100B词元。最后一列显示了所有准确率的平均值。
表5:扩展的零样本和五样本性能结果。所有模型都在SlimPajama数据集的同一子集上使用Mistral分词器进行训练。340M/1.3B模型分别训练了15B/100B词元。最后一列显示了所有准确率的平均值。