Tensor Product Attention Is All You Need
Tensor Product Attention Is All You Need
作者/机构:Yifan Zhang (清华大学/普林斯顿大学), Yifeng Liu (UCLA), Huizhuo Yuan (UCLA), Zhen Qin, Yang Yuan (清华大学/上海期智研究院), Quanquan Gu (UCLA), Andrew Chi-Chih Yao (清华大学/上海期智研究院)
A1 主要贡献
随着大语言模型(LLM)处理上下文长度的增加,推理过程中的键值(KV)缓存成为严重的内存瓶颈。本文提出了张量积注意力(Tensor Product Attention, TPA),这是一种新型注意力机制,其核心创新点和贡献如下:
- 提出 TPA 机制:通过上下文相关的张量分解(Tensor Decompositions)来压缩查询(Q)、键(K)和值(V)的激活值。与静态权重分解(如 LoRA)不同,TPA 动态分解激活值,构建低秩且具上下文信息的表示。
- KV Cache 显著缩减:TPA 在推理阶段大幅减少了 KV Cache 的大小,通常比标准多头注意力(MHA)降低一个数量级,同时在预训练验证损失(困惑度)和下游任务表现上持平或更优。
- 高度兼容性:TPA 原生支持旋转位置嵌入(RoPE)及任何位置编码,也支持输出门控和 KV 移位。它可以作为 LLaMA、Qwen 和 Gemma 等现代 LLM 架构中 MHA 层的掉入式(Drop-in)替换方案。
- 提出 T6 模型架构:基于 TPA 构建了 Tensor ProducT ATTenTion Transformer (T6) 序列建模架构。
- 开发 FlashTPA Decoding 算法:这是一种针对 TPA 优化的自回归推理算法,利用因子化组件直接进行爱因斯坦求和(einsum)计算,避免显式还原全量张量,在长序列推理速度上超越了优化的 MHA、MQA、GQA 和 MLA。
Fig 1 展示了在 T6 每一层中,输入隐藏状态 $[x_t]$ 如何通过线性层产生 Q、K、V 的潜在因子矩阵,应用 RoPE 后通过张量积形成各头张量,最后进行缩放点积注意力计算。
A3 背景知识/关键观察/设计原则
符号系统与张量定义。本文使用粗体大写字母表示矩阵,粗体小写表示向量。令 $[d_{model}]$ 为嵌入维度,$[h]$ 为头数,$[d_h]$ 为每头维度。$[x_t \in \mathbb{R}^d]$ 表示第 $[t]$ 个 token 的输入。$[Q, K, V \in \mathbb{R}^{T \times h \times d_h}]$ 分别表示查询、键、值张量。向量 $[a, b]$ 的张量积(外积)定义为 $[a \otimes b = C]$,其中 $[C_{ij} = a_i b_j]$。矩阵 $[C]$ 的向量化记为 $[vec(C)]$。
缩放点积注意力与多头注意力 (MHA)。缩放点积注意力通过比较 $[Q]$ 和 $[K]$ 来确定关注重点,并对 $[V]$ 进行加权组合:
MHA(【60, Attention is all you need+2017+NIPS+https://arxiv.org/abs/1706.03762】)通过将模型表示划分为多个头来扩展此机制。对于每 个 token 嵌入 $[x_t]$,计算每个头如下:
A2 方法细节
查询、键和值的张量因子化。TPA 不再通过单一线性映射形成头张量,而是将每个 $[Q_t, K_t, V_t]$ 因子化为秩分别为 $[R_q, R_k, R_v]$ 的(上下文相关)张量积之和。具体而言,对于每个 token $[t]$:
潜在因子映射与秩索引合并。张量积中的每个因子都取决于 token 的隐藏状态 $[x_t]$。以查询为例:
其中 $[W_r^{aQ} \in \mathbb{R}^{h \times d_{model}}]$ 和 $[W_r^{bQ} \in \mathbb{R}^{d_h \times d_{model}}]$ 是可学习的权重。为了实现效率,秩索引常被合并到单一输出维度:
多头注意力计算与参数初始化。一旦完成因子化,多头注意力将按照标准 Transformer 进行。对于每个头 $[i]$:
连接所有头后,通过输出权重 $[W^O]$ 投影回原始嵌入维度:
RoPE 兼容性与加速优化。在标准流程中,RoPE 应用于重构后的张量。TPA 提供了一种优化策略:直接将 RoPE 集成到因子化中。
1. 因子预旋转:可以预先旋转 token 维度的因子,例如:
2. 预旋转表示:得到预旋转后的键表示:
KV 缓存与内存减少分析。标准注意力需要缓存 $[K_t, V_t \in \mathbb{R}^{h \times d_h}]$,总开销为 $[2 T h d_h]$。
- TPA 因子化 KV 缓存:TPA 仅存储因子组件 $[A_K(x_t), \tilde{B}_K(x_t), A_V(x_t), B_V(x_t)]$。
- 开销计算:每个 token 的存储成本为 $[(R_K + R_V)(h + d_h)]$。
- 压缩率:与标准缓存相比,比例为 $[\frac{(R_K + R_V)(h + d_h)}{2 h d_h}]$。对于典型的 $[d_h = 64]$ 或 $[128]$,当秩 $[R_K, R_V]$ 设为 1 或 2 时,KV Cache 可获得量级上的缩减。Table 1 详细对比了各机制的参数量与 KV 缓存大小。
MHA, MQA, GQA 作为非上下文 TPA 的表达。作者指出 MHA、MQA 和 GQA 实际上是 TPA 的特殊非上下文变体,即通过对 TPA 因子施加约束:
- MHA:等同于秩 $[R=h]$ 且头维度因子为固定正交的标准基向量 $[e_i]$ 的 TPA。
- MQA:等同于键值秩 $[R_K, R_V = 1]$ 且头维度因子为全 1 向量 $[1_h]$ 的 TPA。
- GQA:等同于秩为组数 $[G]$ 的 TPA,头维度因子为掩码向量。
T6 模型架构细节。T6 采用了 LLaMA(【58, Llama: Open and efficient foundation language models+2023+arXiv】)的整体架构,但将自注意力块替换为 TPA 块。FFN 部分采用 SwiGLU 层(【47, Glu variants improve transformer+2020+arXiv】):
在每个子层前应用层归一化(如 RMSNorm)。
A5 FlashTPA 解码算法
算法核心逻辑。为了高效生成 token,FlashTPA Decoding 利用因子化表示,通过一系列爱因斯坦求和(einsum)计算注意力输出,完全避免了完整 $Q, K, V$ 张量的实体化。
- 数据流过程:如图 2 所示。先计算头共享的特征空间点积 $[P]$,将其与头特定的 $[A_Q, A_K]$ 混合获得 logits,应用 Softmax,最后通过 $[A_V, B_V]$ 聚合值。
- 内存与带宽优势:随着序列长度 $[M]$ 增长,KV Cache 的因子化存储大幅降低了内存带宽压力。
- 实现细节:提供了基于 Triton 的内核实现(Algorithm 3),该实现融合了 Softmax 和值聚合步骤,具有良好的数值稳定性(在线 log-sum-exp)。
Fig 2 展示了单查询 (N=1) 与 M 个缓存项交互的过程,矩形代表张量,圆圈代表 einsum 缩并或元素级乘法。
A4 实验环境
- 数据集:使用 FineWeb-Edu 100B 数据集进行预训练(1000 亿 token 训练,1 亿 token 验证)。
- 模型规模:训练了四种规模:Small (124M)、Medium (353M)、Large (773M) 和 XL (1.5B)。所有模型均在 50B token 上训练 1 个 epoch。
- 基准对比:对比了标准 LLaMA (MHA) 及其变体 MQA、GQA、MLA。为公平起见,调整头数以确保各注意力机制的参数量一致。
- 软件配置:基于 nanoGPT 代码库实现。使用 AdamW 优化器,学习率采用余弦退火策略,设置了 2000 步预热。
- 硬件配置:
- Small: 4x A100 GPU。
- Medium/Large/XL: 8x A100 GPU。
- 具体的 batch size 和梯度累积步数见附录 Table 9。
A4 实验结果
语言建模性能。
- 收敛速度与损失:在所有规模(Medium, Large, XL)中,TPA(红线)和 TPA-KVonly(粉线)的收敛速度均优于或持平于基准。如图 4 所示,TPA 在训练各阶段的验证损失均明显低于 MHA 和 MLA。
- 验证困惑度:TPA-based 方法在预训练结束时达到了最低的困惑度(附录 Fig 9)。
下游任务评测。
- 在 ARC, BoolQ, HellaSwag, MMLU 等主流基准上测试了 0-shot 和 2-shot 表现。
- Medium (353M):TPA 平均准确率为 51.41%,高于 MHA (50.11%) 和 MLA (50.13%)。
- Large (773M):TPA-KVonly 获得最高平均得分 53.52%。
- XL (1.5B):TPA-KVonly 继续保持领先(55.03%)。
- 详细结果见 Table 2 和 Table 3。
FlashTPA 解码速度。
- 可扩展性:在长序列生成上表现卓越。随着序列长度增加,FlashTPA 的解码时间增长速度慢于 MQA/GQA。
- 性能对比:在序列长度超过 $[2^{14} \sim 2^{15}]$ 时,基于 Triton 的 FlashTPA 速度超越了 MLA,并显著快于 MHA(Fig 5)。
A5 结论
本文提出的 Tensor Product Attention (TPA) 通过上下文相关的低秩分解,有效地解决了 LLM 推理中的 KV Cache 膨胀问题。实验证明 TPA 在大幅压缩存储的同时,不仅没有损失模型性能,反而提升了语言建模的质量。其与 RoPE 的天然兼容性以及高效的 FlashTPA 解码算法,使其成为长文本建模极具竞争力的方案。
未来工作:作者指出目前的实现基于 Triton,未来开发高度优化的 CUDA 内核有望进一步挖掘其性能。此外,TPA 在其他模态(如图像、多模态)的泛化能力也值得深入研究。
A6 附录/补充细节
高阶张量积注意力 (Higher-Order TPA)。附录 C 探讨了将特征向量 $[b_r(x_t)]$ 进一步分解为 $[b_r(x_t) \otimes c_r(x_t)]$ 的可能性(3 阶分解)。这可以增强模型的表达能力,通过额外的上下文调制因子来调整特征。作者证明了高阶 TPA 同样兼容 RoPE(Theorem C.1),并在 Small 模型上初步验证了其可行性。
非上下文因子变体。
- Non-contextual B:固定 token 维度因子,仅让头维度因子 $[a(x_t)]$ 随输入变化。实验显示其表现略逊于标准 TPA。
- TPA-KVonly:仅对 K 和 V 进行因子化,Q 使用标准投影。这是一种在压缩率和复杂度之间的平衡方案。
- 共享 B 因子:让 K 和 V 共享 token 维度基向量 $[b_r]$,可进一步减少参数量和缓存开销。
复杂度分析 (Appendix A)。详细对比了显式还原(Naive)与特殊路径(Specialized)的 FLOPs。推导得出当秩 $[R]$ 较小时,TPA 的计算复杂度优势随头数 $[H]$ 或特征维度 $[D]$ 的增加而扩大。在超大规模模型(如 1~2T MoE)中,TPA 在投影成本和注意力成本上均具有显著优势。
汇总方法细节中的引用
| 编号 | 参考文献名 | 发表年份 | 会议/期刊 | 引用段落及原文描述 |
|---|---|---|---|---|
| [60] | Attention is all you need | 2017 | NIPS | 背景章节:“Scaled dot-product attention [60] determines how to focus on different parts...” |
| [52] | Roformer: Enhanced transformer... | 2024 | Neurocomputing | 方法章节:“A key advantage of TPA is its native compatibility with rotary positional embeddings (RoPE) [52]...” |
| [15] | Understanding the difficulty of training... | 2010 | AISTATS | 方法章节:“We use Xavier initialization [15] for the factor weight matrices...” |
| [58] | Llama: Open and efficient... | 2023 | arXiv | 架构章节:“T6 utilizes the overall architecture of LLaMA [58] while changing the self-attention block...” |
| [47] | Glu variants improve transformer | 2020 | arXiv | 架构章节:“The feed-forward network (FFN) adopts a SwiGLU layer, as in [47, 58].” |
| [32] | Deepseek-v2: A strong, economical... | 2024 | arXiv | 引言/背景:“Multi-Head Latent Attention (MLA) in Deepseek-V2 [32] caches compressed KV representations...” |
| [57] | Triton: An intermediate language... | 2019 | ACM | 解码算法:“Our current FlashTPA implementation utilizes Triton [57].” |
| [24] | NanoGPT | 2022 | GitHub | 实验章节:“All experiments... are implemented based on the nanoGPT codebase [24].” |
💬 评论讨论
欢迎在这里分享您的想法和见解!