Various Lengths, Constant Speed: Efficient Language Modeling with Lightning Attention

作者/机构: Zhen Qin, Weigao Sun, Dong Li, Xuyang Shen, Weixuan Sun, Yiran Zhong

A1 主要贡献

本文旨在解决现有线性注意力机制在大型语言模型中未被广泛采用的两个核心问题:1) 性能不佳:与顶尖的基于Softmax注意力的模型相比,存在明显的性能差距;2) 训练速度慢:在因果语言建模中,由于需要使用累积求和操作(cumsum),导致实际训练速度缓慢,失去了线性注意力的理论优势。

为了应对这些挑战,本文提出了以下核心贡献:

  1. 提出 Lightning Attention:这是首个能够在线性注意力中实现其理论计算优势的实现。其核心思想是采用“分而治之”的策略,将注意力计算分解为块内(intra-block)块间(inter-block)两部分。

    • 块内计算:采用传统的注意力计算方式(左乘积),即先计算查询(Q)和键(K)的点积。
    • 块间计算:利用线性注意力的“核技巧”(右乘积),即先计算键(K)和值(V)的乘积。
    • 通过这种方式,成功消除了因果线性注意力计算中对 cumsum 操作的依赖,从而解决了速度瓶颈。此外,结合了类似于FlashAttention的分块(tiling)技术,在整个前向和后向传播过程中优化GPU的IO效率,使其对硬件友好。
  2. 提出 TransNormerLLM (TNL) 架构:为了充分发挥 Lightning Attention 的性能,本文设计了一个专门为其定制的新型架构 TNL。该架构在先前的 TransNormer 基础上进行了多项改进:

    • 位置编码:采用带有指数衰减的线性化相对位置编码(LRPE-d),既保留了token间的全局交互能力,又避免了注意力稀释问题。
    • 门控机制:引入门控线性注意力(GLA)和简化的门控线性单元(SGLU),以平滑训练过程并提升模型性能。
    • 张量归一化:提出一种新的简单均方根归一化(SRMSNorm),在保持精度的同时加速模型。
    • 并行化:为TNL实现了高效的模型并行方案,使其能无缝部署在大型集群上。
  3. 全面的实验验证

    • 效率:实验证明,随着序列长度的增加,Lightning Attention 能够保持恒定的训练速度和较低的内存消耗,显著优于传统的PyTorch实现和FlashAttention-2(如图1所示)。
    • 精度:在不同模型尺寸(从44M到15B)和多种数据集(标准和自建)上的测试表明,TNL的性能不仅优于其他高效Transformer结构,而且与使用传统Transformer的SOTA大型语言模型(如LLaMA)不相上下(如图1和基准测试表所示)。

Figure 1. 训练速度和准确性比较。我们将TNL的训练速度和损失与最先进的Transformer模型(带FlashAttention-2的LLaMA)和高效的非Transformer模型(HGRN (Qin et al., 2023c) 和 TNN (Qin et al., 2023a))进行了比较。TNL实现了最低的训练损失,并且无论序列长度如何,都保持一致的训练速度。
Figure 1. 训练速度和准确性比较。我们将TNL的训练速度和损失与最先进的Transformer模型(带FlashAttention-2的LLaMA)和高效的非Transformer模型(HGRN (Qin et al., 2023c) 和 TNN (Qin et al., 2023a))进行了比较。TNL实现了最低的训练损失,并且无论序列长度如何,都保持一致的训练速度。

A3 背景知识/关键Observation/设计原则

2. 相关工作

高效语言建模的探索。为解决传统Transformer结构的高时间复杂度问题,研究人员正在探索新的高效模型架构。四种有前景的替代方案正在开发中,以取代自注意力模块用于长序列建模,包括线性Transformer、状态空间模型、长卷积和线性循环网络。

线性注意力。线性注意力将Softmax注意力分解为隐藏表示的内积,从而可以使用“核技巧”,即首先计算键和值的乘积,以避免二次方的 $n \times n$ 矩阵运算。不同的方法使用不同的隐藏表示。例如,【28,Transformers are rnns: Fast autoregressive transformers with linear attention,2020,ICML】使用 1+elu 作为激活函数,【43,cosformer: Rethinking softmax in attention,2022,ICLR】使用余弦函数来近似Softmax的性质,而【7,Rethinking attention with performers,2021,ICLR】、【56,Linear complexity randomized self-attention mechanism,2022,ICML】和【57,Efficient attention via control variates,2023,ICLR】则通过理论方法来近似Softmax。尽管其理论复杂度为 $O(nd^2)$,但在因果注意力中使用时,由于需要 cumsum 操作,线性注意力的实际计算效率会变低【24,Transformer quality in linear time,2022,arXiv】。此外,大多数线性注意力与传统Transformer相比仍存在一定的性能差距【28,Transformers are rnns: Fast autoregressive transformers with linear attention,2020,ICML】、【30,Neural architecture search on efficient transformers and beyond,2022,arXiv】。

状态空间模型。状态空间模型基于状态空间方程进行序列建模【20,Efficiently modeling long sequences with structured state spaces,2022,ICLR】,使用特殊的初始化【18,Hippo: Recurrent memory with optimal polynomial projections,2020,NeurIPS】、【21,On the parameterization and initialization of diagonal state space models,2022,NeurIPS】、对角化假设【22,Diagonal state spaces are as effective as structured state spaces,2022,NeurIPS】和混合技术【12,Hungry hungry hippos: Towards language modeling with state space models,2022,CoRR】来实现与Transformer相当的性能。由于状态空间方程的特性,推理可以以恒定的复杂度进行【20,Efficiently modeling long sequences with structured state spaces,2022,ICLR】,而训练速度与FlashAttention相比可能较慢。

长卷积。长卷积模型【44,Toeplitz neural network for sequence modeling,2023,ICLR】、【15,Simple hardware-efficient long convolutions for sequence modeling,2023,CoRR】利用等于输入序列长度的核大小,相比传统卷积,能够获得更宽的上下文。这些模型的训练涉及快速傅里叶变换(FFT)算法,将计算复杂度降低到 $O(n \log n)$。然而,长卷积模型在进行因果卷积推理时需要缓存所有历史计算,这使得它们在处理长序列方面不如RNN理想。

线性RNN。相比之下,线性RNN【35,Resurrecting recurrent neural networks for long sequences,2023,ICML】、【46,Hierarchically gated recurrent neural network for sequence modeling,22023,NeurIPS】在长序列建模中是更适合替代Transformer的模型。一个显著的例子是HGRN模型【46,Hierarchically gated recurrent neural network for sequence modeling,22023,NeurIPS】,这是一个基于线性RNN的大语言模型,其性能已显示出与同等规模的GPT模型相当的竞争力。

IO感知的注意力机制。FlashAttention系列【11,FlashAttention: Fast and memory-efficient exact attention with IO-awareness,2022,NeurIPS】、【10,Flashattention-2: Faster attention with better parallelism and work partitioning,2023,arXiv】专注于在GPU平台上高效实现标准注意力算子的系统级优化。这些方法采用分块(tiling)策略,以最小化GPU的高带宽内存(HBM)和片上SRAM之间的内存读写量。尽管这些方法优化了注意力计算中的IO通信,并且比以前的Softmax注意力实现更快,但它们的理论计算复杂度仍然是 $O(n^2d)$,这使得它们不适合长序列语言建模。

3.1. 预备知识

线性注意力的回顾。我们首先回顾线性注意力的公式,然后介绍我们提出的Lightning Attention。在TransNormer【42,The devil in linear transformer,2022,EMNLP】的NormAttention中,注意力计算不同于传统的Transformer结构【50,Attention is all you need,2017,NeurIPS】,它避开了昂贵的softmax和缩放操作。NormAttention机制可以表示如下:

公式1
公式1

其中 $Q, K, V \in \mathbb{R}^{n \times d}$ 分别是查询、键和值矩阵,n为序列长度,d为特征维度。该方程可以通过右矩阵乘法转换为其线性变体:
公式2
公式2

线性公式使得在训练期间能够以 $O(nd^2)$ 的复杂度进行高效的循环预测。此外,线性注意力保证了推理时无论序列长度如何,计算复杂度都为常数 $O(d^2)$。这是通过循环更新 $K^T V$ 实现的,无需重复计算整个注意力矩阵。相比之下,标准softmax注意力的推理复杂度为 $O(nd^2)$。

因果预测中的挑战。然而,在处理因果预测任务时,右乘积的有效性会受到影响,导致需要计算累积和(cumsum)【24,Transformer quality in linear time,2022,arXiv】。这个障碍阻碍了高效并行计算的潜力。在本节中,我们展示了通过在线性注意力计算中利用“分而治之”的概念,可以消除对 cumsum 的需求。为方便讨论,后续将忽略Norm操作。

两种因果场景计算方法。在因果场景下,有两种计算方法。一种是使用传统的注意力计算(左乘积),即首先计算 $QK^T$。完整的计算公式如下:

公式3
公式3

其中,如果 $t \geq s$,则 $M_{ts} = 1$,否则为0。完整的算法在算法1中详细说明。请注意,该算法是可并行的,但其时间复杂度为 $O(n^2d)$。另一种选择是先计算 $k_t v_t^T$(右乘积),它利用递归公式进行计算:
公式4
公式4

完整的算法在算法2中详细说明。该算法的时间复杂度为 $O(nd^2)$,但它对GPU不友好,使其比第一种方法慢。

算法1:线性注意力左乘积
算法1:线性注意力左乘积

算法2:线性注意力右乘积
算法2:线性注意力右乘积

A2 方法细节

3.2. 带分块(Tiling)的线性注意力

分块计算策略。我们使用分块技术来计算因果设置下的线性注意力。具体来说,我们首先按行将Q、K、V分成两个块:

公式5
公式5

然后,通过展开公式3,我们得到(注意 $kv_0 = 0$):
公式6
公式6

写成块形式,我们有:
公式7
公式7

分治计算思想。上述公式表明,前向因果线性注意力可以分为两部分:
- 块内计算:$[(Q_1 K_1^T) \odot M]V_1$(intra blocks)可以使用左乘积;
- 块间计算:$Q_1 KV_0$(inter blocks)可以使用右乘积。

迭代计算过程。值得注意的是,第二个块可以使用相同的思想进行计算,如下所示:

公式8
公式8

注意,要计算第二个块,我们必须使用 $KV_1 = kv_m$,这可以通过以下方式计算:
公式9
公式9

其中 $KV_0 = kv_0$。通过使用上述策略将矩阵划分为多个块,我们得到了Lightning Attention的前向传播算法。更详细的推导可以在附录C中找到。
算法3:Lightning Attention 前向传播
算法3:Lightning Attention 前向传播

反向传播的分块实现。对于反向传播,根据【28,Transformers are rnns: Fast autoregressive transformers with linear attention,2020,ICML】,我们可以将过程重写为:

公式10
公式10

因此,反向传播的计算与前向传播的公式4一致,同样可以使用分块技术得到Lightning Attention的反向传播算法。详细证明可以在附录C中找到。

3.3. 复杂度分析

定理 3.1。Lightning Attention的时间复杂度为 $O(nd^2 + nBd)$。

定理 3.1 的证明。对于前向传播,根据算法3,每个块内(intra)部分的时间复杂度为 $O(B^2d)$,每个块间(inter)部分的时间复杂度为 $O(Bd^2)$,更新KV的时间复杂度为 $O(Bd^2)$,所以每个循环中的时间复杂度为 $O(B^2d + Bd^2)$。由于我们循环 $T = n/B$ 次,总时间复杂度为 $O((B^2d + Bd^2)n/B) = O(nd^2 + nBd)$。因为反向传播的计算与前向传播相似,所以反向传播的时间复杂度也是 $O(nd^2 + nBd)$。 □

3.4. 精确的IO感知实现

IO优化实现流程。Lightning Attention在其整个计算过程中采用了上述的分块方法,并利用不同的方法来优化GPU内HBM和SRAM之间的内存带宽利用率。具体来说,在每次迭代t中,矩阵 $Q_t, K_t, V_t$ 被分割成块,然后传输到SRAM进行计算。块内和块间的操作是分开的,块内操作使用左乘积,块间操作使用右乘积。这种方法优化地利用了与右乘积相关的计算和内存效率,提高了整体执行速度。中间激活值KV在SRAM中被迭代保存和累积。随后,块内和块间的输出在SRAM内相加,结果写回HBM。Lightning Attention的结构如图2所示。Lightning Attention实现的复杂细节通过算法3(前向传播)和算法4(反向传播)进行了解释。

Figure 2. Lightning Attention的结构框架在其算法示意图中详细说明。在第t次迭代期间,矩阵Qt、Kt、Vt的分块从高带宽内存(HBM)传输到静态随机存取存储器(SRAM)。在SRAM内部,独立计算输出O_intra和O_inter,然后更新KV矩阵。随后,将O_intra和O_inter之和的最终输出Ot从SRAM写回HBM。
Figure 2. Lightning Attention的结构框架在其算法示意图中详细说明。在第t次迭代期间,矩阵Qt、Kt、Vt的分块从高带宽内存(HBM)传输到静态随机存取存储器(SRAM)。在SRAM内部,独立计算输出O_intra和O_inter,然后更新KV矩阵。随后,将O_intra和O_inter之和的最终输出Ot从SRAM写回HBM。

# 算法 4 Lightning Attention 反向传播
Input: Q, K, V, dO  R^{n×d}, block sizes B.
Divide X into T = n/B blocks X1, X2, ...XT of size B × d
each, where X  {Q, K, V}.
Divide dX into T = n/B blocks dX1, dX2, ...dXT of size
B × d each, where X  {Q, K, V, O}.
Initialize mask M  R^{B×B}, where M_ts = 1, if t  s, else 0.
Initialize KV = 0, dKV = 0  R^{d×d}.
for t = 1, . . . , T do
    Load K_t, V_t, O_t, dO_t  R^{B×d} from HBM to on-chip SRAM.
    On chip, compute dQ_intra = [(dO_t V_t^T)  M]K_t.
    On chip, compute dQ_inter = dO_t KV^T.
    On chip, compute KV = KV + K_t^T V_t.
    Write dQ_t = dQ_intra + dQ_inter to HBM as the t-th block of dQ.
end for
for t = T , . . . , 1 do
    Load Q_t, K_t, V_t, O_t, dO_t  R^{B×d} from HBM to on-chip SRAM.
    On chip, compute dK_intra = [(dO_t V_t^T)  M]^T Q_t.
    On chip, compute dK_inter = V_t dKV^T.
    On chip, compute dV_intra = [(Q_t K_t^T)  M]^T dO_t.
    On chip, compute dV_inter = K_t dKV.
    On chip, compute dKV = dKV + Q_t^T dO_t.
    Write dK_t = dK_intra + dK_inter, dV_t = dV_intra + dV_inter to HBM as the t-th block of dK, dV.
end for
Return dQ, dK, dV.

4. TransNormerLLM

4.1. 整体结构

TNL模型架构。我们的结构基于TransNormer【42,The devil in linear transformer,2022,EMNLP】的发现,但进行了定制化修改以平衡效率和性能。我们在图3中展示了整体结构。输入X通过两个连续的步骤进行更新:1). 它经过带有SimpleRMSNorm(SRMSNorm)归一化的门控线性注意力(Gated Linear Attention, GLA)。2). 它经过带有SRMSNorm归一化的简单门控线性单元(Simple Gated Linear Unit, SGLU)。我们对这两个模块都应用了Pre-norm。

Figure 3. TransNormerLLM (TNL) 的架构概览。每个Transformer块由用于token混合的门控线性注意力(GLA)和用于通道混合的简单门控线性单元(SGLU)组成。我们对两个模块都应用了Pre-norm。
Figure 3. TransNormerLLM (TNL) 的架构概览。每个Transformer块由用于token混合的门控线性注意力(GLA)和用于通道混合的简单门控线性单元(SGLU)组成。我们对两个模块都应用了Pre-norm。

4.2. 定制化修改

设计与动机。在本节中,我们概述了每个定制化修改背后的关键设计和灵感,包括位置编码、门控机制和张量归一化。

位置编码。在TransNormer中,低层使用DiagAttention来避免稀释问题。然而,这导致了token之间缺乏全局交互。在TNL中,我们利用带有指数衰减的LRPE【45,Linearized relative positional encoding,2023,TMLR】、【40,Train short, test long: Attention with linear biases enables input length extrapolation,2022,ICLR】、【44,Toeplitz neural network for sequence modeling,2023,ICLR】、【39,Rwkv: Reinventing rnns for the transformer era,2023,arXiv】来解决这个问题,在低层保留完整的注意力。我们的位置编码表达式如下:

公式11
公式11

我们称之为LRPE-d——带指数衰减的线性化相对位置编码。与原始LRPE类似,我们将 $θ$ 设置为可学习的。我们根据经验发现,相比于在每一层应用LRPE-d,只在第一层应用它并在其他层保持指数衰减,可以将训练速度提高约15-20%,而对性能的影响微乎其微。

与线性注意力的兼容性。请注意,这种位置编码与线性注意力完全兼容,因为它可以分别对s和t进行分解。第l层(假设总共有H个头和L层)中第h个头的 $λ$ 值由下式给出:

公式12
公式12

这里,$8h/H$ 对应第h个头的衰减率,而 $1-l/L$ 对应第l层的衰减率。项 $1-l/L$ 确保了低层的理论感受野(TRF)【47,Exploring transformer extrapolation,2024,AAAI】比高层小,这与TransNormer的动机一致。我们选择 $λ$ 为不可学习的,因为我们根据经验发现,当 $λ$ 可学习时,梯度会变得不稳定,导致NaN值。请注意,这种位置编码仍然与Lightning Attention兼容,具体算法在附录A B中有详细说明。

门控机制。门控可以增强模型的性能并平滑训练过程。在TNL中,我们采用了Flash【24,Transformer quality in linear time,2022,arXiv】的方法,在token混合中使用门控线性注意力(GLA):

公式13
公式13

我们选择 $\phi$ 为Swish【49,Searching for activation functions,2017,arXiv】激活函数,因为我们根据经验发现它的性能优于其他激活函数。

简化的门控线性单元。为了进一步加速模型,我们提出了简单GLU(SGLU),它从原始GLU结构中移除了激活函数,因为门控本身就可以引入非线性。因此,我们的通道混合变为:

公式14
公式14

我们根据经验发现,在GLU中不使用激活函数不会导致任何性能损失。

张量归一化。TransNormer【42,The devil in linear transformer,2022,EMNLP】中引入的原始NormAttention如下:

公式15
公式15

在TransNormerLLM中,我们将原始的RMSNorm替换为一个名为SimpleRMSNorm的新简单归一化函数,缩写为SRMSNorm:
公式16
公式16

我们根据经验发现,使用SRMSNorm不会导致任何性能损失。

A4 实验环境

  • 软件配置:
    • 模型实现:基于 Metaseq 框架【52,Opt: Open pre-trained transformer language models,2022,arXiv】和 Pytorch【37,Pytorch: An imperative style, high-performance deep learning library,2019,NeurIPS】。
    • 核心模块:Lightning Attention 使用 Triton【48,Triton: an intermediate language and compiler for tiled neural network computations,2019,MAPL】实现。
  • 硬件配置:
    • GPU:所有实验均在 A100 80G GPU集群上进行。
  • 数据集:
    • 小规模模型: 在 Wikitext-103 数据集上训练40M参数量的模型。
    • 大规模模型: 使用自建语料库进行预训练。原始数据超过700TB,经过清洗、去重和过滤后,得到约6TB、2万亿个token的高质量语料库。
    • 评测基准: 在多个公开基准上进行评测,包括:
      • 常识推理任务:BoolQ, PIQA, SIQA, HellaSwag, WinoGrande, ARC, OpenBookQA。
      • 多任务语言理解:MMLU, C-Eval (中文)。
      • 长序列任务:SCROLLS。
  • 模型架构:
    • 本文模型: TNL,训练了44M, 385M, 1B, 7B, 15B等多种参数规模。
    • 对比模型:
      • 高效结构:LLaMA-FA2, HGRN, TNN。
      • SOTA大模型:OPT, Pythia, BLOOM, GPT-Neo, Falcon, LLaMA, OpenLLAMA, Baichuan, ChatGLM, RWKV。

A4 实验结果

5.1. Lightning Attention 评估

  • 实验内容:对比了Lightning Attention与原始PyTorch实现(Vanilla)以及当前最先进的Softmax注意力实现FlashAttention-2(Flash2)的速度和内存使用情况。
  • 实验结果:如图4所示,Lightning Attention在处理时间和内存占用上均表现出显著优势。
    • 速度:在前向和后向传播中,Lightning Attention的处理时间呈线性增长,而Vanilla和Flash2则呈二次方增长
    • 内存:Vanilla实现很快耗尽内存。Lightning Attention的内存占用趋势与Flash2相似,但所需内存更少
  • 分析结论:Lightning Attention成功地将线性注意力的理论优势转化为实际的计算效率,实现了线性的时间和空间复杂度。

Figure 4. 速度和内存使用情况的比较分析:Vanilla代表PyTorch中的规范线性注意力(Qin et al., 2022a),Flash2代表FlashAttention-2。左侧两个子图:不同序列长度下前向和后向传播的运行时间(毫秒)。右侧两个子图:不同序列长度下前向和后向传播的内存利用率(GB)。
Figure 4. 速度和内存使用情况的比较分析:Vanilla代表PyTorch中的规范线性注意力(Qin et al., 2022a),Flash2代表FlashAttention-2。左侧两个子图:不同序列长度下前向和后向传播的运行时间(毫秒)。右侧两个子图:不同序列长度下前向和后向传播的内存利用率(GB)。

5.2. TNL 评估

  • 小模型性能评估 (表1):
    • 实验内容:在Wikitext-103数据集上,对比了约40M参数量的TNL与其他多种架构(注意力、MLP、RNN、FFT等)的性能。
    • 实验结果:TNL在测试集上取得了最低的困惑度(Perplexity)
  • 大模型训练性能 (图1):
    • 实验内容:将TNL扩展到1B和3B参数,与LLaMA-FA2、HGRN、TNN在相同的30B语料库上进行训练,并比较训练损失。
    • 实验结果:无论是在1B还是3B规模下,TNL都取得了最低的训练损失
  • 效率评估 (图1):
    • 实验内容:比较了TNL、LLaMA-FA2、HGRN和TNN在不同序列长度(从1K到128K)下的训练速度(TGS:每GPU每秒处理的token数)。
    • 实验结果:TNL的TGS保持稳定,而其他三个模型的TGS随着序列长度的增加而迅速下降。
    • 分析结论:Lightning Attention使TNL在处理超长序列方面具有显著优势。
  • 推理评估 (图5):
    • 实验内容:在A100 80G GPU上比较了多个7B大语言模型的推理吞吐量。
    • 实验结果:TNL的推理吞吐量比Transformer结构的模型高出最多11倍
  • 基准测试结果 (表2, 表3):
    • 实验内容:在常识推理、MMLU、C-Eval和SCROLLS等多个基准上测试了385M、1B、7B和15B的TNL模型,并与多个开源SOTA模型进行比较。
    • 实验结果:TNL在所有测试中都表现出极强的竞争力,性能与顶尖的Transformer模型相当或更优。

Figure 5. 推理吞吐量比较。我们在A100 80G GPU上测量了各种7B LLM模型的推理吞吐量。为模型选择的批量大小旨在优化GPU利用率而不超过内存限制。每个模型都使用512个token的输入提示进行测试,并且可以生成多达1024个新token。报告的吞吐量是20次尝试的平均值。
Figure 5. 推理吞吐量比较。我们在A100 80G GPU上测量了各种7B LLM模型的推理吞吐量。为模型选择的批量大小旨在优化GPU利用率而不超过内存限制。每个模型都使用512个token的输入提示进行测试,并且可以生成多达1024个新token。报告的吞吐量是20次尝试的平均值。

Table 1. 在Wikitext-103上的结果(TNN(Qin et al., 2023a)的设置)。↓表示越低越好。
Table 1. 在Wikitext-103上的结果(TNN(Qin et al., 2023a)的设置)。↓表示越低越好。

Table 2. 在常识推理和聚合基准上的性能比较。为了公平比较,我们报告了使用其发布模型复现的竞争方法的结果。官方结果以斜体表示。PS:参数大小(十亿)。T:token(十亿)。HS:HellaSwag。WG:WinoGrande。
Table 2. 在常识推理和聚合基准上的性能比较。为了公平比较,我们报告了使用其发布模型复现的竞争方法的结果。官方结果以斜体表示。PS:参数大小(十亿)。T:token(十亿)。HS:HellaSwag。WG:WinoGrande。

Table 3. 在SCROLLS上的性能比较 (Shaham et al., 2022):对2048预训练序列长度下高达10亿参数的模型的评估。PS:参数大小(十亿)。T:token(十亿)。
Table 3. 在SCROLLS上的性能比较 (Shaham et al., 2022):对2048预训练序列长度下高达10亿参数的模型的评估。PS:参数大小(十亿)。T:token(十亿)。

5.3. TNL 消融实验

  • 位置编码 (表4):LRPE-d效果最佳。但为了提升训练速度(约20%),最终采用了一种混合策略(仅在第一层使用LRPE-d),性能损失微小。
  • 衰减温度 (表5):在位置编码的指数衰减项中加入与层数相关的衰减温度 $1-l/L$ ,有效降低了模型的困惑度。
  • 门控机制 (表6):启用门控机制(GLA)后,模型损失从2.263降至2.248,证明了其有效性。
  • 归一化函数 (表7):SRMSNorm、RMSNorm和LayerNorm的性能差异很小。但SRMSNorm通过Triton优化后,在处理大维度数据时速度更快。
  • GLA激活函数 (表8):Swish和1+elu表现相似。但由于1+elu在7B模型中会导致NaN问题,最终选择了Swish。
  • GLU激活函数 (表9):从GLU中移除激活函数(形成SGLU)对结果影响甚微,因此为了简化模型而采用SGLU。

Table 4. 位置编码探索。LRPE-d 带来了最优结果。
Table 4. 位置编码探索。LRPE-d 带来了最优结果。

Table 5. 衰减温度的消融实验。带有衰减温度的结果被证明更优。
Table 5. 衰减温度的消融实验。带有衰减温度的结果被证明更优。

Table 6. 门控机制的消融实验。带门控的性能被证明更优。
Table 6. 门控机制的消融实验。带门控的性能被证明更优。

Table 7. 归一化函数探索。以下归一化函数之间的结果偏差很小。
Table 7. 归一化函数探索。以下归一化函数之间的结果偏差很小。

Table 8. GLA激活函数的消融实验。不同激活函数得到的结果几乎相同。
Table 8. GLA激活函数的消融实验。不同激活函数得到的结果几乎相同。

Table 9. GLU激活函数的消融实验。排除激活函数对结果没有负面影响。
Table 9. GLU激活函数的消融实验。排除激活函数对结果没有负面影响。

A5 结论

本文成功推出了 Lightning Attention,这是首个能完全释放线性注意力潜力的实现。其核心创新在于通过“分而治之”的策略,将注意力计算分解为块内(intra-blocks)和块间(inter-blocks)计算,并为它们分别应用不同的计算技术,从而消除了 cumsum 瓶颈。这使得模型在处理不同长度的序列时,能够在固定的内存占用下保持恒定的训练速度

此外,本文还提出了专为 Lightning Attention 定制的 TransNormerLLM (TNL) 架构。实验结果表明,TNL 在效率和准确性方面均优于现有的高效语言模型,并且其性能与采用传统Transformer架构的顶尖大型语言模型相比具有很强的竞争力。这项工作为开发更高效、可扩展性更强的大型语言模型开辟了新的道路。

A6 附录

A. 带衰减的线性注意力

LRPE-d位置编码的形式。TransNormerLLM使用LRPE-d位置编码,其格式如下:

公式17
公式17

根据【45,Linearized relative positional encoding,2023,TMLR】,Lrpe可以分解为q和k,因此我们考虑以下简化形式:
公式18
公式18

我们称之为带衰减的线性注意力,并证明它等价于递归形式:
公式19
公式19

B. 带衰减的Lightning Attention

算法扩展。我们将Lightning Attention扩展以适应带衰减的线性注意力。完整的算法可以在算法5和6中找到,正确性的证明在附录C中提供。

算法5:带衰减的Lightning Attention前向传播
算法5:带衰减的Lightning Attention前向传播

C. 证明

证明概述。这里我们直接讨论带衰减的线性注意力,因为原始的线性注意力是 $λ=1$ 的特例。

公式20
公式20

以递归形式,上述方程可以重写为:
公式21
公式21

公式22
公式22

公式23
公式23

其中
公式24
公式24

C.0.1. 前向传播

数学归纳法证明。在带衰减的线性注意力的前向传播中,第t个输出可以表示为。我们将使用数学归纳法来证明 $kv_t = \widetilde{kv_t}$。
基础情况 (n = 1):

公式25
公式25

归纳步骤。假设当 $n = m - 1$ 时命题成立,即 $kv_{m-1} = \widetilde{kv_{m-1}}$。那么,当 $n=m$ 时:
公式26
公式26

命题成立。因此,通过归纳法,该命题对所有 $n \geq 1$ 都成立。

分块形式推导。为了进行分块,我们以块形式写出方程。给定总序列长度n和块大小B,X被划分为 $T=n/B$ 个块 $\{X_1, X_2, \dots, X_T\}$,每个块大小为 $B \times d$,其中 $X \in \{Q, K, V, O\}$。
我们首先定义

公式27
公式27

给定 $KV_t$,第 $(t+1)$ 个块的输出,即 $tB+r$ 位置($1 \leq r \leq B$),为
公式28
公式28

以矩阵形式重写,我们得到
公式29
公式29

其中
公式30
公式30

第 $(t+1)$ 个块的KV可以写为
公式31
公式31

带衰减的Lightning Attention前向传播的完整表达式见算法5。

C.0.2. 反向传播

反向过程推导。对于反向传播,我们考虑逆过程。首先给定 $d\mathbf{o}_t$,我们有

公式32
公式32

以递归形式写出 $dkv_t$,我们得到
公式33
公式33

分块形式推导。为便于理解分块,我们以块形式考虑上述方程。给定总序列长度n和块大小B,X被划分为 $T = n/B$ 个块 $\{X_1, X_2, \dots, X_T\}$,每个块大小为 $B \times d$,其中 $X \in \{Q, K, V, O, dO\}$。
我们首先定义
公式34
公式34

那么对于第 $(t+1)$ 个块,即 $tB+r$, $0 \leq r < B$,我们有
公式35
公式35

以矩阵形式,我们有
公式36
公式36

由于 $dK_t$ 的递归从 $t+1$ 步进到 $t$,给定 $KV_{t+1}$,第 $t$ 个块的 $dK_t$,即在位置 $(t-1)B+r$ 处,
公式37
公式37

以矩阵形式,我们得到
公式38
公式38

考虑第 $t$ 个块的 $dV_t$,即在位置 $(t-1)B+r, 0 < r \leq B$ 处,我们有
公式39
公式39

以矩阵形式,我们得到
公式40
公式40

最后,$dKV_t$ 的递归关系为
公式41
公式41

算法6更详细地描述了带衰减的Lightning Attention的反向传播过程。

算法6:带衰减的Lightning Attention反向传播
算法6:带衰减的Lightning Attention反向传播

D. 语料库

数据收集与处理。我们从互联网上收集了超过700TB的公开文本数据。这些数据经过我们的数据预处理流程(如图6所示)处理后,留下了6TB的清洁语料库,大约有2万亿个token。我们对数据来源进行了分类,以提供更好的透明度和理解。这些类别的具体信息在表10中列出。

Figure 6. 数据预处理流程。收集的数据经过基于规则的过滤和去重过程,然后是我们自创的自清洁数据处理策略:基于模型的过滤、人工评估和评估模型。经过上述循环的多次迭代,我们获得了约2T tokens的高质量训练数据。
Figure 6. 数据预处理流程。收集的数据经过基于规则的过滤和去重过程,然后是我们自创的自清洁数据处理策略:基于模型的过滤、人工评估和评估模型。经过上述循环的多次迭代,我们获得了约2T tokens的高质量训练数据。

Table 10. 我们的语料库统计。对于每个类别,我们列出了在2万亿个token上训练时在该子集上执行的epoch数,以及token数和磁盘大小。我们还在右侧根据语言分布列出了表格。
Table 10. 我们的语料库统计。对于每个类别,我们列出了在2万亿个token上训练时在该子集上执行的epoch数,以及token数和磁盘大小。我们还在右侧根据语言分布列出了表格。

D.1. 数据预处理

预处理流程。我们的数据预处理流程包括三个步骤:1) 基于规则的过滤,2) 去重,和 3) 自清洁方案。在被添加到训练语料库之前,清洁的语料库需要经过人工评估。
- 基于规则的过滤
* 移除HTML标签和URL:使用正则表达式去除文本中的HTML标签和网址。
* 消除无用或异常字符串:去除如乱码等无价值的字符串。
* 标点符号去重:将连续的多个标点符号缩减为单个。
* 处理特殊字符:移除或替换不常见的特殊字符。
* 数字标准化:将各种格式的数字统一为标准格式。
* 保留Markdown/LaTeX格式:为保留结构化信息,特别保留这两种格式。
- 去重
* 方法:为避免过拟合并确保数据独特性,我们在文档或行级别使用MinHash和局部敏感哈希(LSH)算法进行高效去重。
- 自清洁方案
* 核心思想:利用模型本身来过滤数据。我们假设大部分预处理后的数据是高质量的,因此在一个包含所有预处理数据的集合上训练一个评估模型,该模型能够自动平滑数据流形分布,剔除低质量数据,同时保留大部分多样性。
* 迭代步骤
1. 评估模型:在预处理后的语料库上训练一个385M的模型作为数据质量过滤器。
2. 基于模型的数据过滤:使用评估模型计算每条数据的困惑度,只保留得分高于特定阈值的数据。
3. 人工评估:对过滤后的数据进行小部分抽样,进行人工质量评估。
* 循环优化:这些步骤循环进行,每次迭代都提高数据质量。

D.2. 分词

分词方法。我们使用字节对编码(BPE)算法对数据进行分词。值得注意的是,为了增强对中文内容的支持,我们将大量常见和不常见的汉字加入了词汇表。对于词汇表中不存在的词项,单词会被分解为其组成的UTF-8字符。这一策略确保了在模型训练期间对不同语言输入的全面覆盖和灵活性。

E. 分布式系统优化

系统优化策略。我们优化了系统以有效执行TNL的大规模预训练。我们采用完全分片数据并行(FSDP)【53,Pytorch fsdp: experiences on scaling fully sharded data parallel,2023,arXiv】、激活检查点【50,Megatron-lm: Training multibillion parameter language models using model parallelism,2019,arXiv】和自动混合精度(AMP)【33,Mixed precision training,2017,arXiv】技术来减少内存占用并加速计算。我们使用BFloat16【27,A study of bfloat16 for deep learning training,2019,arXiv】来增强训练稳定性。我们为Lightning Attention量身定制了模型并行。受Megatron-LM【50,Megatron-lm: Training multibillion parameter language models using model parallelism,2019,arXiv】模型并行的启发,该并行独立处理自注意力和MLP块,我们分别对SGLU和GLA应用模型并行。我们的模型并行策略的细节如下。

SGLU模型并行。回顾公式(12)中的SGLU结构:

公式42
公式42

SGLU的模型并行适应如下:
公式43
公式43

这将权重矩阵 $W_v$ 和 $W_u$ 沿其列进行拆分,并得到一个也沿其列拆分的输出矩阵。然后,拆分的输出 $[O_1, O_2]$ 与另一个沿其行拆分的矩阵相乘,如下:
公式44
公式44

与Megatron-LM中的模型并行类似,这整个过程将SGLU块内的三个通用矩阵乘法(GEMM)分布在多个GPU上,并且在前向和后向传播中分别只引入一个all-reduce集体通信操作。

GLA模型并行。回顾公式(11)中的GLA块,其模型并行版本是:

公式45
公式45

其中:
公式46
公式46

请注意,在我们的实现中,我们使用组合的QKVU投影来提高线性注意力的计算效率。获得的拆分输出矩阵 $[O_1, O_2]$ 再次与一个沿其列拆分的权重矩阵相乘,这与公式(40)类似。

F. 额外的TNL消融实验

Transformer vs TNL。我们对TNL和Transformer进行了一系列细致的比较测试,涵盖了不同的大小。这些模型的比较性能在表11中清晰地展示出来。在相同配置下,我们的TNL明显表现出比Transformer更优的性能。我们观察到,在385M的尺寸下,TNL比Transformer性能高出5%。更重要的是,当尺寸达到1B时,这种优势变得更加明显,TNL比Transformer有9%的优势。

Table 11. Transformer vs TNL. 在相同配置下,TNL在385M和1B尺寸下分别比Transformer性能好5%和9%。
Table 11. Transformer vs TNL. 在相同配置下,TNL在385M和1B尺寸下分别比Transformer性能好5%和9%。

TransNormer vs TNL。我们比较了原始的TransNormer和改进后的TNL,结果如表12所示。TNL分别表现出2%和1%的提升。

Table 12. TransNormer vs TNL. TNL表现优于TransNormer。
Table 12. TransNormer vs TNL. TNL表现优于TransNormer。

归一化函数速度。我们使用Triton增强了SRMSNorm,使得在处理更大维度时处理速度显著提升,如图7所示,性能优于传统的PyTorch实现。

Figure 7. SRMSNorm实现的性能评估。上图展示了在固定特征维度为3072时,不同序列长度的前向传播(左部分)和后向传播(右部分)的运行时间比较。下图展示了在固定序列长度为4096时,不同特征维度的运行时间比较。
Figure 7. SRMSNorm实现的性能评估。上图展示了在固定特征维度为3072时,不同序列长度的前向传播(左部分)和后向传播(右部分)的运行时间比较。下图展示了在固定序列长度为4096时,不同特征维度的运行时间比较。