Linear Attention

韩广云,NVIDIA GPU 加速计算专家团队 高级工程师 | AI Open Day/2025-11-07

目录

议程 (Agenda)

  • 线性注意力 (Linear Attention)

    • 介绍 (Introduction)
    • 实现策略 (Implementation Strategy)
  • 扩展至 Delta Rule (Extension to Delta Rule)

    • 介绍 (Introduction)
    • 实现策略 (Implementation Strategy)

线性注意力 (Linear Attention)

当前状态 (Current Status)

  • 研究活跃

    • 提出了许多线性注意力机制:Linear, Mamba2, Lighting, Delta Rule 等。
  • 工业级大语言模型正在涌现

    • MiniMax-Text-01,于 2025-01 发布,456B
    • Qwen3-Next,于 2025-09 发布,80B
    • 未来的 Qwen3.5¹
  • 关注于核函数级别的优化及其挑战

    • 推理预填充(Inference prefill)的 Hopper 核函数
    • 目标是工业级的 LLM
    • 完全融合的线性注意力(Fully fused linear attention)
    • 扩展至 Delta Rule
    • 开放性问题

基础 (Basics)

  • 最基本的形式
    Page 5, 公式1

    • 其中 M 是一个下三角矩阵,所有元素为1。
    • 可以很容易地以递归形式计算。
  • 可以重写为块形式 (block-wise form)
    Page 5, 公式2

    • 结果是从之前的块中累积而来,而计算只涉及当前块。
  • 在 MiniMax 的工作中¹

    • 第一部分称为 O<sub>intra</sub>
    • 第二部分称为 O<sub>inter</sub>
  • 展开为块级形式 (Unroll for chunk-wise form)
    Page 5, 公式3

    • 这使得能够高效地映射到硬件上。
    • 使用 TensorCore 进行矩阵乘法 (MatMul)。
    • 扩展到 lighting attention。
    • 额外的元素级前/后处理。

并行性 - 回归 Flash Attention (Parallelism - Back to Flash Attention)

  • 调度 (Scheduling)

    • 采用 Flash attention v2 的调度方式:
      (seqlen_q / TileSizeQ) x (seqlen_k / TileSizeK) x num_heads x num_seqs
    • 其中 (seqlen_q / TileSizeQ)num_heads x num_seqs 部分是并行的 (parallelize),而 (seqlen_k / TileSizeK) 部分是串行的 (sequential)。
  • 外层循环处理 Q tiles

    • 由 CTA (Cooperative Thread Arrays) 并行化。
    • 每个 CTA 产生一个 O tile(非持久性)。
  • 内层循环处理 K 或 V tiles

    • 在单个 CTA 内顺序处理。
Page 6, Flash Attention 调度示意图
Page 6, Flash Attention 调度示意图

并行性 - 完全融合的线性注意力 (Parallelism - Fully Fused Linear Attention)

  • 调度 (Scheduling)

    • 调度方式为:(seqlen / TileSize) x num_heads x num_seqs
    • 其中 (seqlen / TileSize) 部分是串行的 (sequential),而 num_heads x num_seqs 部分是并行的 (parallelize)。
  • 只有一个串行循环

    • 计算复杂度更低。
    • 但也降低了并行化的机会。
  • 每个 CTA 产生

    • 多个输出 tiles。
    • 以及一个最终的输出状态。
Page 7, 完全融合线性注意力调度示意图
Page 7, 完全融合线性注意力调度示意图

算法分析 (Algorithm Analysis)

Page 8, 算法分析的公式和表格
Page 8, 算法分析的公式和表格
  • 假设 B 是 Tile Size, N 是序列长度

    • O<sub>inter</sub> 的计算量: 2B²d x (N/B - 1)
    • O<sub>intra</sub> 的计算量: (2B²d + 2B d²) x N/B
  • 总近似计算量

    • O(Nd²)O(N B d)
    • 关于序列长度是线性复杂度,这是好的
  • 分析结论

    • 假设 B=64, d=128
    • 计算每个 tile 所需的平均 FLOP: O<sub>inter</sub>: 2097152, O<sub>intra</sub>: 3145728
    • 每个输出 tile 传输 64KiB 数据
    • 算法的算术强度 (Arithmetic Intensity, AI) 为 80
    • 这是内存受限 (Memory Bounded) 的。
      • 假设有足够的并行性。
      • 应该将所有步骤融合成一个 kernel。

Hopper GPU 上的完全融合线性注意力 (Fully Fused Linear Attention On Hopper GPUs)

Page 9, Hopper GPU 上的流水线示意图
Page 9, Hopper GPU 上的流水线示意图

上图展示了一个理想的指令顺序,旨在实现高效的流水线操作,包括数据加载(Load Q, K, V)、数学计算(Math WG1, WG2)和结果存储(Store O)。图中展示了 Acquire/Release 各种 pipe(Q, K, V, O pipe)以协调数据流,最终计算出完整的 tile。

Hopper GPU 上的完全融合线性注意力 - 实现细节

Page 10, Hopper GPU 实现细节图
Page 10, Hopper GPU 实现细节图
  • 中间状态分布在数学工作组寄存器 (Math WG Registers) 中

    • 因为需要频繁的读/写操作。
    • 由于长范围累积,需要使用 FP32 格式。
    • 进一步减少了 SMEM 的使用。
  • 最大 TileSize 仅为 64

    • 如果是 128,双缓冲的 BF16 Q,K,V 需要 192KiB SMEM。
    • 但需要更多的 SMEM 用于中间结果 (P=QK)。
    • SMEM 不足。
  • WGMMA 的排序由最小化每个缓冲区的生命周期决定

    • O SMEM 用于 TMA (Tensor Memory Accelerator) 存储。

Hopper GPU 上的完全融合线性注意力 - 映射到硬件的细节

Page 11, WGMMA 配置和硬件映射细节
Page 11, WGMMA 配置和硬件映射细节
  • WGMMA 配置

    • 上图表格展示了不同计算步骤(P=QK, O<sub>intra</sub>=SQ, O<sub>inter</sub>=VP, S+=VK)的操作数、M/N/K 维度以及状态所处的空间(寄存器 R 或共享内存 S)。
  • 硬件映射要点

    • 转置(Transposition)的注解被省略了。
    • 仅适用于 MMA 配置。
    • S(状态)存在于寄存器中。

      • S 是操作数 A,用于计算 SQ。
      • 只有操作数 A 可以从输出 D 转换而来。
    • 所有设计都是将 S 放入寄存器的结果。

基准测试 (Benchmarks)

Page 12, 性能基准测试结果
Page 12, 性能基准测试结果
  • 基准测试环境

    • H800 NVL¹
    • GPU Max Clock rate: 1785 MHz (1.78 GHz)
    • 通过 nvidia-smi -lgc 1000,1000 锁定 GPU 时钟。
    • Memory Clock: 2619 Mhz.
  • 通过 CUDA Events 测量

    • 存在方差。
  • 完全融合的线性注意力 (Fully fused LA) 可以更快

    • 仍在改进中。
  • vllm 的速度较慢
    -因为它在 kernel 外部使用一个循环来处理可变长度的输入。

并行性问题 (Parallelism Problem)

  • 基于调度

    • (seqlen / TileSize) x num_heads x num_seqs
    • 串行部分 (seqlen / TileSize),并行部分 num_heads x num_seqs
  • 可以实现的并行度远小于 FMHA (Flash Multi-Head Attention)

    • Q 不再对并行性有贡献。
    • 因此,计算复杂度较低。
  • 可能导致硬件利用率低,为什么?

    • 在线服务要求低 TFST (Time to First Token)。
    • 小的 num_seqs,也称为 batch size
    • 线性注意力最初是为长序列设计的。
    • 但不能轻易地并行化(分割)成块。

并行性问题 - 上下文并行 (Parallelism Problem - Context Parallelism)

  • Linear 和 Lighting Attention 可以进行上下文并行化。
  • 以四个计算单元为例:
    Page 14, 上下文并行方案

    • 如上图所示,可以将计算 O<sub>intra</sub> 和 K<sup>T</sup>V 的任务分配到不同节点(NODE0-NODE3)。
    • 节点之间通过 Send/Recv 操作交换中间结果(K<sup>T</sup>V)。
    • 每个节点计算自己的 O<sub>inter</sub> 部分并累加之前节点的结果。
  • 尚未实现。

扩展至 Delta Rule (Extension to Delta Rule)

基础 (Basics)

Page 15, Delta Rule 公式
Page 15, Delta Rule 公式
  • 仅关注分块形式 (chunk-wise form)

    • 基本上是计算一个新的 V,记为 V<sup>new</sup>。
    • 在 Linear Attention 中替换 V

      • 用新计算出的 V<sup>new</sup> 进行计算。
    • 元素级处理(Elementwise processing)是不同的。

  • 计算流程

    • 首先计算 V<sup>new</sup>。
    • 之后的部分与普通的线性注意力相同,只是使用了 V<sup>new</sup>,并且有不同的元素级处理。

全融合Delta法则 (Fully Fused Delta Rule)

在Hopper GPU上的实现

  • 新特性

    • KK和其逆矩阵的计算。

      • 存在长延迟。
    • 计算NewV。

      • 用NewV替代V。
      • 否则,就像线性注意力机制一样。
    • 但是现在我们有很多 exp2flog2f 的计算。

下图展示了在Hopper GPU上实现全融合Delta法则的时间线图。该图解了不同工作组(Math WG1, Math WG2)中各种计算任务(如加载Q/K/V,计算T=KK,O1=QS,V-SK等)的并行与依赖关系。

Page 16: 全融合Delta法则在Hopper GPU上的执行流程图
Page 16: 全融合Delta法则在Hopper GPU上的执行流程图

针对下三角矩阵的矩阵求逆

  • 直接前向代换很简单

    • 但是非常昂贵!

      • O(n³) 级别的指令数。
      • 不良的内存访问模式。
    • 在对角线上使用小块矩阵进行计算。

  • 分块求逆来解决问题

    • 使用以下公式递归计算完全逆矩阵:
      Page 17: 分块矩阵求逆公式
    • 内存高效。

      • 使用LDSM或STSM。
    • 计算高效。

      • 使用HMMA。

下图展示了分块求逆的过程,其中绿色部分代表原始矩阵,浅绿色部分代表求逆后的矩阵。

Page 17: 下三角矩阵的分块求逆示意图
Page 17: 下三角矩阵的分块求逆示意图

3阶段流水线 (3-Stages Pipelining)

  • 将T和P的计算移至独立的流水线

    • 为T和P使用双缓冲。
    • 共享MUFU。
  • 之前

    • LdSt WG + Math WG
  • 当前

    • LdSt WG + T&P WG + State Related WG
  • 提升了性能

    • 在实践中并不非常显著。

下图展示了3阶段流水线的执行流程,对比了不同工作组(Aux Math WG1, Math WG2, and WG3)的任务调度。

Page 18: 全融合Delta法则的3阶段流水线示意图
Page 18: 全融合Delta法则的3阶段流水线示意图

基准测试 (Benchmarks)

  • 基准测试环境

    • H800 NVL¹
    • GPU最高时钟频率:1785 MHz (1.78 GHz)

      • 锁定至1Ghz via nvidia-smi -lgc 1000,1000
    • 内存时钟频率:2619 Mhz.

  • 通过CUDA Events测量

    • 预计存在方差。
  • 全融合Delta法则可以快很多

    • 仍有改进空间?

下表展示了在固定序列长度和固定批大小两种情况下的基准测试结果,对比了fla和我们(Ours)的实现。

Page 19: 基准测试结果表格


<font size="1">1. For technical discussion and reference only, perf. may vary based on different product portfolio.</font>
<font size="1">2. Flash attention performance is tested with version v2.5.3 commit 49b3c3b</font>
<font size="1">3. Our kernel is still in development</font>

开放性问题 (Open Question)

  • 主循环步骤中的长延迟

    • 3个以上的WGMMA

      • KK, SK 和 (V-SK)T
    • 矩阵求逆

      • 多轮的加载(LD)和存储(ST)
      • 更多的HMMA
    • 由于门控delta法则,存在过多的逐元素处理。

  • 精度问题

    • 主要来自矩阵求逆。
    • 也许TF32可以解决这个问题。
      • 甚至更长的延迟。
  • 并行性问题

    • 没有更多的上下文并行性。