Tensor Product Attention Is All You Need

作者/机构:Yifan Zhang (清华大学/普林斯顿大学), Yifeng Liu (UCLA), Huizhuo Yuan (UCLA), Zhen Qin, Yang Yuan (清华大学/上海期智研究院), Quanquan Gu (UCLA), Andrew Chi-Chih Yao (清华大学/上海期智研究院)

A1 主要贡献

随着大语言模型(LLM)处理上下文长度的增加,推理过程中的键值(KV)缓存成为严重的内存瓶颈。本文提出了张量积注意力(Tensor Product Attention, TPA),这是一种新型注意力机制,其核心创新点和贡献如下:

  1. 提出 TPA 机制:通过上下文相关的张量分解(Tensor Decompositions)来压缩查询(Q)、键(K)和值(V)的激活值。与静态权重分解(如 LoRA)不同,TPA 动态分解激活值,构建低秩且具上下文信息的表示。
  2. KV Cache 显著缩减:TPA 在推理阶段大幅减少了 KV Cache 的大小,通常比标准多头注意力(MHA)降低一个数量级,同时在预训练验证损失(困惑度)和下游任务表现上持平或更优。
  3. 高度兼容性:TPA 原生支持旋转位置嵌入(RoPE)及任何位置编码,也支持输出门控和 KV 移位。它可以作为 LLaMA、Qwen 和 Gemma 等现代 LLM 架构中 MHA 层的掉入式(Drop-in)替换方案。
  4. 提出 T6 模型架构:基于 TPA 构建了 Tensor ProducT ATTenTion Transformer (T6) 序列建模架构。
  5. 开发 FlashTPA Decoding 算法:这是一种针对 TPA 优化的自回归推理算法,利用因子化组件直接进行爱因斯坦求和(einsum)计算,避免显式还原全量张量,在长序列推理速度上超越了优化的 MHA、MQA、GQA 和 MLA。

Fig 1 Tensor ProducT ATTenTion Transformer (T6) 中的张量积注意力 (TPA)
Fig 1 展示了在 T6 每一层中,输入隐藏状态 $[x_t]$ 如何通过线性层产生 Q、K、V 的潜在因子矩阵,应用 RoPE 后通过张量积形成各头张量,最后进行缩放点积注意力计算。


A3 背景知识/关键观察/设计原则

符号系统与张量定义。本文使用粗体大写字母表示矩阵,粗体小写表示向量。令 $[d_{model}]$ 为嵌入维度,$[h]$ 为头数,$[d_h]$ 为每头维度。$[x_t \in \mathbb{R}^d]$ 表示第 $[t]$ 个 token 的输入。$[Q, K, V \in \mathbb{R}^{T \times h \times d_h}]$ 分别表示查询、键、值张量。向量 $[a, b]$ 的张量积(外积)定义为 $[a \otimes b = C]$,其中 $[C_{ij} = a_i b_j]$。矩阵 $[C]$ 的向量化记为 $[vec(C)]$。

缩放点积注意力与多头注意力 (MHA)。缩放点积注意力通过比较 $[Q]$ 和 $[K]$ 来确定关注重点,并对 $[V]$ 进行加权组合:

$$\operatorname{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V})=\operatorname{Softmax}\left(\frac{\mathbf{QK}^{\top}}{\sqrt{d_{k}}}\right) \mathbf{V},$$


MHA(【60, Attention is all you need+2017+NIPS+https://arxiv.org/abs/1706.03762】)通过将模型表示划分为多个头来扩展此机制。对于每 个 token 嵌入 $[x_t]$,计算每个头如下:

$$\mathbf{Q}_{t,i} = (\boldsymbol{W}_i^Q)^\top \mathbf{x}_t \in \mathbb{R}^{d_h}, \mathbf{K}_{t,i} = (\boldsymbol{W}_i^K)^\top \mathbf{x}_t \in \mathbb{R}^{d_h}, \mathbf{V}_{t,i} = (\boldsymbol{W}_i^V)^\top \mathbf{x}_t \in \mathbb{R}^{d_h},$$
$$\mathbf{head}_i = \text{Attention}(\mathbf{Q}_i, \mathbf{K}_i, \mathbf{V}_i),$$
其中 $[W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{model} \times d_h}]$ 是可学习的投影矩阵。最后连接所有头并映射回原始维度:
$$\operatorname{MHA}(\mathbf{X})=\operatorname{Concat}\left(\mathbf{head}_{1}, \ldots, \mathbf{head}_{h}\right) \boldsymbol{W}^{O} .$$


A2 方法细节

查询、键和值的张量因子化。TPA 不再通过单一线性映射形成头张量,而是将每个 $[Q_t, K_t, V_t]$ 因子化为秩分别为 $[R_q, R_k, R_v]$ 的(上下文相关)张量积之和。具体而言,对于每个 token $[t]$:

$$\mathbf{Q}_{t}=\frac{1}{R_{Q}} \sum_{r=1}^{R_{Q}} \mathbf{a}_{r}^{Q}\left(\mathbf{x}_{t}\right) \otimes \mathbf{b}_{r}^{Q}\left(\mathbf{x}_{t}\right), \quad \mathbf{K}_{t}=\frac{1}{R_{K}} \sum_{r=1}^{R_{K}} \mathbf{a}_{r}^{K}\left(\mathbf{x}_{t}\right) \otimes \mathbf{b}_{r}^{K}\left(\mathbf{x}_{t}\right),$$

$$\mathbf{V}_{t}=\frac{1}{R_{V}} \sum_{r=1}^{R_{V}} \mathbf{a}_{r}^{V}\left(\mathbf{x}_{t}\right) \otimes \mathbf{b}_{r}^{V}\left(\mathbf{x}_{t}\right),$$
其中 $[a_r^Q(x_t), a_r^K(x_t), a_r^V(x_t) \in \mathbb{R}^h]$,$[b_r^Q(x_t), b_r^K(x_t), b_r^V(x_t) \in \mathbb{R}^{d_h}]$。这意味着查询、键、值在头维度(Head dimension)和特征维度(Feature dimension)上被解耦分解。

潜在因子映射与秩索引合并。张量积中的每个因子都取决于 token 的隐藏状态 $[x_t]$。以查询为例:

$$\mathbf{a}_r^Q(\mathbf{x}_t) = \boldsymbol{W}_r^{a^Q} \mathbf{x}_t \in \mathbb{R}^h, \quad \mathbf{b}_r^Q(\mathbf{x}_t) = \boldsymbol{W}_r^{b^Q} \mathbf{x}_t \in \mathbb{R}^{d_h},$$


其中 $[W_r^{aQ} \in \mathbb{R}^{h \times d_{model}}]$ 和 $[W_r^{bQ} \in \mathbb{R}^{d_h \times d_{model}}]$ 是可学习的权重。为了实现效率,秩索引常被合并到单一输出维度:

$$\mathbf{a}^{Q}(\mathbf{x}_{t})=\boldsymbol{W}^{a^{Q}} \mathbf{x}_{t} \in \mathbb{R}^{R_{q} \cdot h}, \mathbf{b}^{Q}(\mathbf{x}_{t})=\boldsymbol{W}^{b^{Q}} \mathbf{x}_{t} \in \mathbb{R}^{R_{q} \cdot d_{h}}$$
随后重塑为 $[A_Q(x_t) \in \mathbb{R}^{R_Q \times h}]$ 和 $[B_Q(x_t) \in \mathbb{R}^{R_Q \times d_h}]$。则第 $[t]$ 个 token 的查询张量可表示为:
$$\mathbf{Q}_{t}=\frac{1}{R_{Q}} \mathbf{A}_{Q}\left(\mathbf{x}_{t}\right)^{\top} \mathbf{B}_{Q}\left(\mathbf{x}_{t}\right) \in \mathbb{R}^{h \times d_{h}}.$$
该操作等同于 $[Q_t = \frac{1}{R_Q} \sum_{r=1}^{R_Q} a_r^Q(x_t) (b_r^Q(x_t))^\top]$。

多头注意力计算与参数初始化。一旦完成因子化,多头注意力将按照标准 Transformer 进行。对于每个头 $[i]$:

$$\mathbf{head}_i = \text{Softmax} \left( \frac{1}{\sqrt{d_h}} \mathbf{Q}_i (\mathbf{K}_i)^\top \right) \mathbf{V}_i,$$


连接所有头后,通过输出权重 $[W^O]$ 投影回原始嵌入维度:

$$\text{TPA}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Concat}(\mathbf{head}_1, \dots, \mathbf{head}_h) \boldsymbol{W}^O.$$
因子权重矩阵采用 Xavier 初始化(【15, Understanding the difficulty of training deep feedforward neural networks+2010+AISTATS】)。

RoPE 兼容性与加速优化。在标准流程中,RoPE 应用于重构后的张量。TPA 提供了一种优化策略:直接将 RoPE 集成到因子化中。
1. 因子预旋转:可以预先旋转 token 维度的因子,例如:

$$\widetilde{\mathbf{B}}_{K}(\mathbf{x}_{t}) \leftarrow \mathrm{RoPE}_{t}(\mathbf{B}_{K}(\mathbf{x}_{t})),$$


2. 预旋转表示:得到预旋转后的键表示:

$$\widetilde{\mathbf{K}}_t = \frac{1}{R_K} \sum_{r=1}^{R_K} \mathbf{a}_{(r)}^K(\mathbf{x}_t) \otimes \mathrm{RoPE}_t(\mathbf{b}_{(r)}^K(\mathbf{x}_t)) = \frac{1}{R_K} \mathbf{A}_K(\mathbf{x}_t)^\top \mathrm{RoPE}_t(\mathbf{B}_K(\mathbf{x}_t)).$$
这样每个键 $[K_t]$ 在缓存前就已经完成了旋转,推理阶段无需再显式计算旋转,从而加速解码。Theorem 3.1 证明了 TPA 因子化与 RoPE 是数学兼容的,保持了相对平移属性(详见附录 D.1 证明)。

KV 缓存与内存减少分析。标准注意力需要缓存 $[K_t, V_t \in \mathbb{R}^{h \times d_h}]$,总开销为 $[2 T h d_h]$。
- TPA 因子化 KV 缓存:TPA 仅存储因子组件 $[A_K(x_t), \tilde{B}_K(x_t), A_V(x_t), B_V(x_t)]$。
- 开销计算:每个 token 的存储成本为 $[(R_K + R_V)(h + d_h)]$。
- 压缩率:与标准缓存相比,比例为 $[\frac{(R_K + R_V)(h + d_h)}{2 h d_h}]$。对于典型的 $[d_h = 64]$ 或 $[128]$,当秩 $[R_K, R_V]$ 设为 1 或 2 时,KV Cache 可获得量级上的缩减。Table 1 详细对比了各机制的参数量与 KV 缓存大小。

Table 1 不同注意力机制对比
Table 1 不同注意力机制对比

MHA, MQA, GQA 作为非上下文 TPA 的表达。作者指出 MHA、MQA 和 GQA 实际上是 TPA 的特殊非上下文变体,即通过对 TPA 因子施加约束:
- MHA:等同于秩 $[R=h]$ 且头维度因子为固定正交的标准基向量 $[e_i]$ 的 TPA。
- MQA:等同于键值秩 $[R_K, R_V = 1]$ 且头维度因子为全 1 向量 $[1_h]$ 的 TPA。
- GQA:等同于秩为组数 $[G]$ 的 TPA,头维度因子为掩码向量。

T6 模型架构细节。T6 采用了 LLaMA(【58, Llama: Open and efficient foundation language models+2023+arXiv】)的整体架构,但将自注意力块替换为 TPA 块。FFN 部分采用 SwiGLU 层(【47, Glu variants improve transformer+2020+arXiv】):

$$\begin{aligned} \begin{aligned} \mathbf{x} & \leftarrow \mathbf{x}+\operatorname{TPA}(\operatorname{RMSNorm}(\mathbf{x})), \\ \mathbf{x} & \leftarrow \mathbf{x}+\operatorname{SwiGLU-FFN}(\operatorname{RMSNorm}(\mathbf{x})). \end{aligned} \end{aligned}$$


在每个子层前应用层归一化(如 RMSNorm)。


A5 FlashTPA 解码算法

算法核心逻辑。为了高效生成 token,FlashTPA Decoding 利用因子化表示,通过一系列爱因斯坦求和(einsum)计算注意力输出,完全避免了完整 $Q, K, V$ 张量的实体化。
- 数据流过程:如图 2 所示。先计算头共享的特征空间点积 $[P]$,将其与头特定的 $[A_Q, A_K]$ 混合获得 logits,应用 Softmax,最后通过 $[A_V, B_V]$ 聚合值。
- 内存与带宽优势:随着序列长度 $[M]$ 增长,KV Cache 的因子化存储大幅降低了内存带宽压力。
- 实现细节:提供了基于 Triton 的内核实现(Algorithm 3),该实现融合了 Softmax 和值聚合步骤,具有良好的数值稳定性(在线 log-sum-exp)。

Fig 2 FlashTPA Decoding 数据流图
Fig 2 展示了单查询 (N=1) 与 M 个缓存项交互的过程,矩形代表张量,圆圈代表 einsum 缩并或元素级乘法。


A4 实验环境

  • 数据集:使用 FineWeb-Edu 100B 数据集进行预训练(1000 亿 token 训练,1 亿 token 验证)。
  • 模型规模:训练了四种规模:Small (124M)、Medium (353M)、Large (773M) 和 XL (1.5B)。所有模型均在 50B token 上训练 1 个 epoch。
  • 基准对比:对比了标准 LLaMA (MHA) 及其变体 MQA、GQA、MLA。为公平起见,调整头数以确保各注意力机制的参数量一致。
  • 软件配置:基于 nanoGPT 代码库实现。使用 AdamW 优化器,学习率采用余弦退火策略,设置了 2000 步预热。
  • 硬件配置
    • Small: 4x A100 GPU。
    • Medium/Large/XL: 8x A100 GPU。
    • 具体的 batch size 和梯度累积步数见附录 Table 9。

A4 实验结果

语言建模性能

  • 收敛速度与损失:在所有规模(Medium, Large, XL)中,TPA(红线)和 TPA-KVonly(粉线)的收敛速度均优于或持平于基准。如图 4 所示,TPA 在训练各阶段的验证损失均明显低于 MHA 和 MLA。
  • 验证困惑度:TPA-based 方法在预训练结束时达到了最低的困惑度(附录 Fig 9)。
    Fig 4 验证损失曲线对比

下游任务评测
- 在 ARC, BoolQ, HellaSwag, MMLU 等主流基准上测试了 0-shot 和 2-shot 表现。
- Medium (353M):TPA 平均准确率为 51.41%,高于 MHA (50.11%) 和 MLA (50.13%)。
- Large (773M):TPA-KVonly 获得最高平均得分 53.52%。
- XL (1.5B):TPA-KVonly 继续保持领先(55.03%)。
- 详细结果见 Table 2 和 Table 3。

Table 2 Medium 模型下游任务结果
Table 2 Medium 模型下游任务结果

FlashTPA 解码速度
- 可扩展性:在长序列生成上表现卓越。随着序列长度增加,FlashTPA 的解码时间增长速度慢于 MQA/GQA。
- 性能对比:在序列长度超过 $[2^{14} \sim 2^{15}]$ 时,基于 Triton 的 FlashTPA 速度超越了 MLA,并显著快于 MHA(Fig 5)。
Fig 5 不同 batch size 下的解码时间对比


A5 结论

本文提出的 Tensor Product Attention (TPA) 通过上下文相关的低秩分解,有效地解决了 LLM 推理中的 KV Cache 膨胀问题。实验证明 TPA 在大幅压缩存储的同时,不仅没有损失模型性能,反而提升了语言建模的质量。其与 RoPE 的天然兼容性以及高效的 FlashTPA 解码算法,使其成为长文本建模极具竞争力的方案。
未来工作:作者指出目前的实现基于 Triton,未来开发高度优化的 CUDA 内核有望进一步挖掘其性能。此外,TPA 在其他模态(如图像、多模态)的泛化能力也值得深入研究。


A6 附录/补充细节

高阶张量积注意力 (Higher-Order TPA)。附录 C 探讨了将特征向量 $[b_r(x_t)]$ 进一步分解为 $[b_r(x_t) \otimes c_r(x_t)]$ 的可能性(3 阶分解)。这可以增强模型的表达能力,通过额外的上下文调制因子来调整特征。作者证明了高阶 TPA 同样兼容 RoPE(Theorem C.1),并在 Small 模型上初步验证了其可行性。

非上下文因子变体
- Non-contextual B:固定 token 维度因子,仅让头维度因子 $[a(x_t)]$ 随输入变化。实验显示其表现略逊于标准 TPA。
- TPA-KVonly:仅对 K 和 V 进行因子化,Q 使用标准投影。这是一种在压缩率和复杂度之间的平衡方案。
- 共享 B 因子:让 K 和 V 共享 token 维度基向量 $[b_r]$,可进一步减少参数量和缓存开销。

复杂度分析 (Appendix A)。详细对比了显式还原(Naive)与特殊路径(Specialized)的 FLOPs。推导得出当秩 $[R]$ 较小时,TPA 的计算复杂度优势随头数 $[H]$ 或特征维度 $[D]$ 的增加而扩大。在超大规模模型(如 1~2T MoE)中,TPA 在投影成本和注意力成本上均具有显著优势。


汇总方法细节中的引用

编号 参考文献名 发表年份 会议/期刊 引用段落及原文描述
[60] Attention is all you need 2017 NIPS 背景章节:“Scaled dot-product attention [60] determines how to focus on different parts...”
[52] Roformer: Enhanced transformer... 2024 Neurocomputing 方法章节:“A key advantage of TPA is its native compatibility with rotary positional embeddings (RoPE) [52]...”
[15] Understanding the difficulty of training... 2010 AISTATS 方法章节:“We use Xavier initialization [15] for the factor weight matrices...”
[58] Llama: Open and efficient... 2023 arXiv 架构章节:“T6 utilizes the overall architecture of LLaMA [58] while changing the self-attention block...”
[47] Glu variants improve transformer 2020 arXiv 架构章节:“The feed-forward network (FFN) adopts a SwiGLU layer, as in [47, 58].”
[32] Deepseek-v2: A strong, economical... 2024 arXiv 引言/背景:“Multi-Head Latent Attention (MLA) in Deepseek-V2 [32] caches compressed KV representations...”
[57] Triton: An intermediate language... 2019 ACM 解码算法:“Our current FlashTPA implementation utilizes Triton [57].”
[24] NanoGPT 2022 GitHub 实验章节:“All experiments... are implemented based on the nanoGPT codebase [24].”