Fast Attention Mechanisms: A Tale of Parallelism

作者/机构: Jingwen Liu, Hantao Yu, Clayton Sanford, Alexandr Andoni, Daniel Hsu
* 哥伦比亚大学计算机科学系,美国纽约
* 谷歌研究院,美国旧金山

A1 主要贡献

核心问题:尽管Transformer模型因其强大的并行处理能力,在理论上能够模拟大规模并行计算(MPC)算法,但它们面临着一个关键瓶颈:注意力机制的二次方时间复杂度。这一问题严重限制了模型处理长序列的可扩展性。为解决此问题,学术界提出了多种计算效率更高的替代方案,如低秩近似和高效的最近邻搜索,但这些新方法是否能保留标准注意力机制的表征优势尚不明确。此外,先前研究建立的Transformer与MPC模型之间的理论关系较为粗糙,特别是在模拟Transformer时,MPC算法需要多达 N² 台机器,这暗示了Transformer可能比MPC更强大,两者之间存在显著差距。

研究目标与创新点:本文旨在提出一种新的高效注意力机制,它不仅具有亚二次方的时间复杂度,还能更紧密地捕捉MPC模型的计算能力,并保留标准注意力的关键表征优势。

本文的主要贡献如下:

  1. 提出近似最近邻注意力(ANNA)机制:本文引入了一种名为ANNA的高效注意力机制,其核心是利用近似最近邻搜索技术(如局部敏感哈希LSH)来替代标准的全局注意力计算,从而实现亚二次方的时间复杂度。

  2. 建立ANNA-Transformer与MPC模型的紧密等价关系(定理1.1):本文证明了ANNA-Transformer与MPC模型之间存在比标准Transformer更强的等价关系,显著缩小了先前研究中存在的理论差距。具体来说:

    • 任何R轮的MPC算法都可以被一个O(R)层、宽度为O(N^(ε+δ))的ANNA-Transformer模拟。
    • 任何L层的ANNA-Transformer都可以被一个O(L)轮的MPC算法模拟,而该MPC算法仅需N^(1-δ+O(1/c²))台机器,这是一个远小于N²的亚二次方(甚至接近线性)数量。这一发现回答了引言中提出的问题,即是否存在一种注意力机制能更紧密地匹配MPC的计算能力。
  3. 统一低秩与最近邻高效注意力模型(定理1.2):本文证明了ANNA-Transformer的表达能力足以模拟基于低秩近似的Transformer。具体地,任何L层的低秩注意力Transformer都可以被一个O(L)层的ANNA-Transformer(具有可比宽度)模拟。这为理解和统一两类主流的高效注意力方法提供了一个理论框架,表明ANNA至少与低秩方法同样强大。

  4. 在关键推理任务上验证ANNA的有效性:本文通过理论构建和实验验证了ANNA-Transformer在两个关键推理任务上的能力:

    • Match2:证明了单层ANNA-Transformer即可解决此任务(定理5.2)。
    • k-hop归纳头(k-hop induction heads):证明了O(log k)深度的ANNA-Transformer即可解决此任务(定理5.6),这几乎达到了标准Transformer所能实现的最佳效率,并优于需要Ω(k)深度的低秩注意力和循环网络。
    • 实验结果:通过将训练好的softmax注意力模型蒸馏到ANNA-Transformer中,实验表明ANNA-Transformer能够以极低的错误率解决Match2和1-hop归纳头任务,证明了其理论上的表征能力在实践中是可行的。

A3 背景知识

2.1 标准注意力与Transformer

标准注意力头的定义。首先定义(标准)注意力机制和Transformer。一个(标准)注意力头AttnQ,K,V由查询(query)、键(key)、值(value)的嵌入函数 Q, K, V : R^d → R^m 指定。对于输入 X ∈ R^(N×d),它的计算如下:

$$\operatorname{Attn}_{Q, K, V}(X):=\operatorname{softmax}\left(Q(X) K(X)^{\top}\right) V(X) \in \mathbb{R}^{N \times m}$$

其中Q、K、V和softmax是逐行应用的。我们称N为上下文长度,m为嵌入维度。Q(X)、K(X)、V(X)的行分别是查询、键和值。

注意力头单行输出的计算方式。如果 $q_i = Q(X)_i$, $k_j = K(X)_j$, 且 $v_j = V(X)_j$,那么 AttnQ,K,V(X) 的第i行计算如下:

$$\operatorname{Attn}_{Q, K, V}(X)_i = \sum_{j=1}^N w_{i, j} v_j, \quad where \quad w_{i, j} = \frac{\exp(\langle q_i, k_j \rangle)}{\sum_{j'=1}^N \exp(\langle q_i, k_{j'} \rangle)}.$$

多头注意力层的定义。一个H头的注意力层 f : R^(N×d) → R^(N×d) 由H个注意力头 $(Attn_{Q_h,K_h,V_h}){h=1}^H$ 和H个 m × d 的矩阵 $(W_h)(X)W_h$。}^H$ 组成;它的计算方式为 $f(X) := \sum_{h=1}^H Attn_{Q_h,K_h,V_h

Transformer的定义。一个L层的Transformer T由L个注意力层 $f_1, ..., f_L : R^{N×d} → R^{N×d}$ 和一个输出函数 $ψ : R^d → R^d$ 指定。给定输入 X ∈ R^(N×d),定义如下计算过程:

$$X^{(0)}:=X \text { and } X^{(\ell)}:=f_{\ell}(X^{(\ell-1)}) \text { for } \ell=1, \ldots, L.$$

T在输入X上的输出是 $ψ(X^{(L)})$,其中ψ是逐行应用的。

Transformer的计算瓶颈。在本文中,我们将上下文长度N(输入token的数量)视为主要的扩展参数。这反映了现代长上下文大语言模型的范式,其中上下文长度可以超过10^6(GeminiTeam et al., 2023【GeminiTeam, et al. Gemini: a family of highly capable multimodal models. arXiv preprint arXiv:2312.11805, 2023.】),从而支持书籍长度的文本输入。因此,我们通常希望尺寸参数m、H和L相对于N是亚线性的(理想情况下L是常数)。因此,注意力的O(N²)运行时复杂度仍然是Transformer架构的主要瓶颈。

Transformer作为并行计算模型的视角。遵循(Sanford et al., 2024c【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Transformers, parallel computation, and logarithmic depth." In Forty-First International Conference on Machine Learning, 2024c.】)的做法,我们允许逐元素操作($Q_h, K_h, V_h, ψ$)是任意函数(仅受位精度限制;参见(Sanford et al., 2024c, 附录A.1))。因此,我们可以将Transformer视为一个计算模型,它在任意的逐token计算和token之间的通信之间交替进行。这启发了与大规模并行计算(Massively Parallel Computation)的联系,接下来将对此进行定义。

2.2 大规模并行计算

MPC模型概述。大规模并行计算(MPC)框架(Im et al., 2023【Im, Sungjin, et al. "Massively parallel computation: Algorithms and applications." Foundations and Trends® in Optimization 5.4 (2023): 340-417.】)通过一个分布式计算系统来模拟对大型输入的计算,该系统在本地计算轮次和受限的全局通信轮次之间交替进行。

MPC协议的正式定义。我们正式陈述MPC协议(具有亚线性内存)的定义如下。对于常数 $γ, ε > 0$,一个R轮的(γ, ε)-MPC协议规定了在 P = Θ(N^(1+γ-ε)) 台机器上对N个词(其中一个词是 p = Θ(log N) 位,由 Z_2p 的一个元素表示)的输入进行如下计算,每台机器的本地内存为 s = O(N^ε) 个词:

MPC协议的计算步骤

  1. 初始化:输入任意分布在前 $\lceil N/s \rceil$ 台机器上。
  2. 本地计算与消息准备:在每一轮中,每台机器根据其本地内存的任意函数,准备要发送给其他机器的消息。任何一台机器准备的消息总大小最多为s个词。
  3. 通信:在每一轮结束时,消息被放置在目标接收方的本地内存中。协议确保任何一台机器接收到的消息总大小最多为s个词。
  4. 输出:在第R轮之后,输出存储在前 $\lceil N/s \rceil$ 台机器的内存中。

MPC协议的功能定义。如果对于任何 $X \in Z_{2^p}^N$,协议 $\pi(X) = f(X)$,其中 $\pi(X)$ 是协议 $\pi$ 在输入X下的输出,我们就说一个MPC协议 $\pi$ 计算一个函数 $f : Z_{2^p}^N \rightarrow Z_{2^p}^N$。

MPC的复杂度衡量。本文考虑的主要复杂度度量是轮数R。许多经典算法问题的轮数复杂度已经得到了很好的理解。例如,对于图连通性有简单的MPC协议(O(log N)轮),对于排序也有(O(1)轮)(Ghaffari, 2019【Ghaffari, Mohsen. "Massively parallel algorithms." Lecture notes by Davin Choo, Computer Science, ETH Zurich, June 2019.】)。

A2 方法细节

3 近似最近邻注意力

本节介绍近似最近邻注意力(ANNA),这是一种受近似最近邻(ANN)搜索问题启发的注意力机制。我们首先概述近似最近邻搜索问题,并介绍局部敏感哈希(LSH),这是ANN的一种核心技术(3.1节)。然后,我们正式定义ANNA,并提供一个基于LSH的亚二次方时间算法来计算ANNA,并附有理论保证(3.2节)。

3.1 近似最近邻与局部敏感哈希

ANN搜索问题定义。我们首先定义近似最近邻(ANN)搜索问题。给定一个包含N个点的数据集D,这些点位于度量空间Y中,以及参数c, r > 0,目标是构建一个数据结构,对于任何距离D不超过r的查询点q ∈ Y,该数据结构能返回D中距离q不超过cr的任何一个点(Kushilevitz et al., 1998【Kushilevitz, Eyal, Rafail Ostrovsky, and Yuval Rabani. "Efficient search for approximate nearest neighbor in high dimensional spaces." In STOC, 1998.】; Indyk and Motwani, 1998【Indyk, Piotr, and Rajeev Motwani. "Approximate nearest neighbors: towards removing the curse of dimensionality." In Proceedings of the Thirtieth Annual ACM Symposium on Theory of Computing, STOC ’98, 1998.】)。

LSH作为ANN的解决方案。在现代机器学习背景下,我们希望为高维度量空间中的ANN搜索开发具有亚线性查询时间的快速算法。局部敏感哈希(LSH)是一个众所周知的工具,它能实现这样的运行时间并提供可证明的近似保证。为简单起见,我们假设度量空间是m维欧几里得空间。

局部敏感哈希(LSH)定义。固定参数 r > 0,近似因子 c > 1 和一个集合 U。如果对于任意 x, y ∈ R^m,一个哈希函数族 H(其中 h : R^m → U)满足以下条件,则称其为 (r, cr, p1, p2)-敏感的:
* 如果 $∥x - y∥ \le r$,那么 $Pr_{h \in H}[h(x) = h(y)] \ge p_1$。
* 如果 $∥x - y∥ > cr$,那么 $Pr_{h \in H}[h(x) = h(y)] \le p_2$。
该族H被称为LSH族,其质量为 $ρ = \frac{\log(1/p_1)}{\log(1/p_2)}$ (Indyk and Motwani, 1998【Indyk, Piotr, and Rajeev Motwani. "Approximate nearest neighbors: towards removing the curse of dimensionality." In Proceedings of the Thirtieth Annual ACM Symposium on Theory of Computing, STOC ’98, 1998.】)。

LSH算法的性能。一个典型的基于LSH的算法可以解决ANN搜索问题,其空间复杂度为O(N^(1+ρ)),查询时间为O(N^ρ),并且ρ可以小到1/c²(Andoni and Indyk, 2006【Andoni, Alexandr, and Piotr Indyk. "Near-optimal hashing algorithms for approximate nearest neighbor in high dimensions." In 2006 47th Annual IEEE Symposium on Foundations of Computer Science (FOCS’06), 2006.】)。

3.2 基于近似最近邻注意力的Transformer

ANNA模型族的概念。我们首先定义一个模型族,其中只有查询和键足够“邻近”的token才会相互关注。ANNA注意力单元将注意力查询向量视为近似最近邻搜索的查询,将键向量视为数据点。ANNA根据近似最近邻阈值来检索和加权值向量。以下定义形式化了这一模型族。

ANNA的正式定义。一个近似最近邻注意力(ANNA)机制ANNA_Q,K,V,其查询、键、值嵌入函数为 Q, K, V : R^d → R^m,并带有(非负)参数 r, c, ℓ, η,是一个(可能随机的)机制,它对输入 X ∈ R^(N×d) 执行以下计算:

$$\mathrm{ANNA}_{Q, K, V}(X)_i:=\sum_{j=1}^N w_{i, j} v_j \quad \text { for all } i \in[N],$$

对于一些非负权重 $w_{i,j} \ge 0$ 且 $\sum_j w_{i,j} = 1$,这些权重满足以下条件。

ANNA权重的约束条件。以至少 $1 - \eta$ 的概率,对于所有 $i \in [N]$,权重 $w_{i,j}$ 满足:
* $w_{i,j} > 0 \implies k_j \in N(q_i, cr)$
* $k_j \in N(q_i, r) \implies w_{i,j} \ge \frac{1}{(|N(q_i, cr)|-1)\ell+1}$

其中 $q_i := Q(X)i$, $k_i := K(X)_i$, $v_i := V(X)_i$,并且 $N(q, t) := {k \in {k_j}^N : |q - k| \le t}$。

ANNA层与ANNA-Transformer的定义。我们以与定义(标准)注意力层和Transformer(定理2.1和2.2)完全类似的方式来定义ANNA层和ANNA-Transformer。

ANNA参数的解释。参数r和c的语义与ANN搜索中相同。参数ℓ捕捉了“注意力权重”在查询的r-近邻之外的键上的分布程度。失败概率η允许随机化,这在像LSH这样的ANN搜索算法中是典型的。

ANNA的LSH实现。上述定义代表了所有ANNA单元必须满足的一组约束,而不是一个特定的算法实现。因此,多种模型都满足该定义,包括具有有界查询和键向量的softmax注意力,以及当且仅当 $q_i = k_j$ 时 $w_{i,j} > 0$ 的精确匹配注意力(EMA)。并非所有这类模型都有计算上高效的实现。为了识别亚二次方的ANNA模型,我们提出了一个基于LSH的ANNA实现,该实现为特定参数r, c, ℓ, η计算满足条件的权重向量 $w_{i,j}$。

LSH实现细节(算法1)。我们定义一个哈希函数族 $G = {g : p \in R^d \rightarrow (h_1(p), h_2(p), \dots, h_z(p)) \in U^z | h_i \in H, \forall i \in [z]}$,并从G中独立均匀随机地采样ℓ个哈希函数 $g_1, \dots, g_l$,从而得到ℓ个哈希表。每个哈希码对应哈希表中的一个哈希桶,每个哈希桶维护落入该桶的v的总和和k的计数。我们通过将所有的键、值对存储在哈希表中来预处理它们。对于每个 $(k_i, v_i), i \in [N]$,计算 $k_i$ 的哈希码 $g_1(k_i), g_2(k_i), \dots, g_ℓ(k_i)$,并分别更新对应于 $g_1(k_i), g_2(k_i), \dots, g_ℓ(k_i)$ 的桶的总和与计数。对于每个查询 $q_i, i \in [N]$,从 $g_1(q_i), g_2(q_i), \dots, g_ℓ(q_i)$ 中搜索并检索所有的值和计数。然后通过将ℓ个桶中的所有值相加,再除以计数的总和来计算平均值。具体细节见算法1。

输入: 输入 X ∈ R^(N×d)
输出: 每个查询的ANNA输出。
1: 对于所有 i ∈ [N],令 qi = Q(X)i, ki = K(X)i, vi = V(X)i。
2: for u = 1 to ℓ do ▷ 预处理阶段
3: 从H中独立同分布地采样z个哈希函数 hu,1, hu,2, . . . , hu,z。
4: 创建一个由哈希码索引的空哈希表 Tu。
5: for 每个键值对 (kj, vj) do
6: 计算哈希码 gu(kj) = (hu,1(kj), hu,2(kj), . . . , hu,z(kj))。
7: if Tu[gu(kj)] 为空, then Tu[gu(kj)] := (vj, 1) else Tu[gu(kj)] += (vj, 1)。
8: 初始化一个字典 attn ← {(q1, 0), (q2, 0), . . . , (qN, 0)}
9: for 每个查询 qi do ▷ 查询阶段
10: vsum ← 0; count ← 0
11: for u = 1 to ℓ do
12: 计算哈希码 gu(qi) = (hu,1(qi), . . . , hu,z(qi))
13: if Tu[gu(qi)] = (v, a) 不为空, then vsum += v and count += a
14: attn[qi] ← vsum/count
15: return attn

LSH算法的理论保证。固定 $c > 3$,LSH族 H 是 (r, cr, p1, p2)-敏感的,其质量 $ρ < 1/3$,设 $ℓ = Θ(N^{3ρ} \log N)$,以及 $z = Θ(\log_{1/p2} N)$。那么算法1(使用H, ℓ, z)实现了一个ANNA机制,其参数为 r, c, ℓ 且 $η = O(1/N^{1-3ρ})$。

算法1的性能分析。定理3.4的完整证明见附录A。算法1的总运行时间为 $O(mN^{1+3ρ} \log_{1/p2} N)$,假设从LSH族中采样和评估一个哈希函数需要O(m)时间,并且算法1的数值输入以 $p = Θ(\log N)$ 位的精度指定。使用的总空间为 $O(mN^{1+3ρ})$ 位。

内存优化。算法1的内存复杂度可以通过只存储一个哈希表,并让每个条目跟踪每个查询的值,进一步提高到 $O(\tilde{m}N)$ 位,同时保持相同的时间复杂度。详细实现见附录A中的算法2。

权重与距离的关系。权重 $w_{i,j}$ 取决于 $q_i$ 和 $k_j$ 之间的哈希碰撞次数,通常是距离 $\Delta := |q_i - k_j|$ 的函数。例如,如果我们使用(Andoni et al., 2015【Andoni, Alexandr, et al. "Practical and optimal lsh for angular distance." In Proceedings of the 29th International Conference on Neural Information Processing Systems, NIPS’15, 2015.】)中的随机超平面LSH,那么 $w_{i,j} \propto \exp(-\Delta^2 \log(m)/(4 - \Delta^2))$。

后续约定。在本文的其余部分,我们基于ANNA的构造指的是通过算法1及其适当选择的r和任意大的c来实现的高效实现。

4 高效Transformer与MPC

我们证明了ANNA-transformer与MPC在亚线性本地内存和亚二次方机器数量的机制下存在一个精确的等价关系(4.1节和4.2节)。我们还表明,ANNA涵盖了基于低秩的亚二次方注意力替代方案(4.3节)。

4.1 ANNA-Transformer可以模拟MPC

ANNA模拟MPC的定理。以下定理表明,任何具有亚线性本地内存的R轮MPC协议都可以由一个具有O(R)层、亚线性数量的头和嵌入维度的ANNA-transformer来模拟。定理4.1的完整证明在附录B中。定理4.1 (ANNA模拟MPC). 固定常数 $0 < ε < ε' < 1$。对于任何确定性的R轮 $(ε, ε)$-MPC协议 $\pi$,存在一个ANNA-transformer T,其层数 L = O(R),每层头数 H = O(N^((ε'-ε)/4)),嵌入维度 m = O(N^ε'),使得对于所有输入 $input \in Z_{2^p}^N$,$T(input) = \pi(input)$。

证明思路。实际上,我们证明了ANNA-transformer的一个特例,即近似因子 $c \rightarrow \infty$ 且 r = 0 的情况,就已经足以模拟MPC。在这种情况下,对每个查询的ANNA等价于只寻找与查询完全匹配的键;我们称之为精确匹配注意力(Exact-Match Attention, EMA),并在附录B中正式定义。在我们的模拟中,我们将每个输入token视为一个本地机器,所有的本地计算都由逐元素函数Q, K, V处理。证明的核心在于使用EMA处理机器间的消息传递。根据(Sanford et al., 2024a【Sanford, Clayton, et al. "Understanding transformer reasoning capabilities via graph algorithms." Advances in Neural Information Processing Systems 37 (2024a): 78320-78370.】)的命题24,我们可以假设每台机器只向最多 $α = O(N^δ)$ 台机器发送消息,其中 $δ < ε$。我们为每台机器分配一个唯一的 positional encoding 或标识符,这个编码作为在每台机器中检索消息的唯一键。核心思想是为每台机器创建一个查询,为每个目标机器创建一个键,相关联的值是协议中发送到目标机器的消息的嵌入。由于每台机器最多可以向其他机器发送α条消息,我们创建α个EMA头,每个头负责所有N台机器的一条出站消息。每台机器通过在每个头中设置一个查询来检索发送给它们的消息。由于消息会被平均在一起,我们使用(Sanford et al., 2024c【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Transformers, parallel computation, and logarithmic depth." In Forty-First International Conference on Machine Learning, 2024c.】)中引理3.2的相同嵌入机制,以允许在逐元素操作中进行错误校正。□

模拟的意义。这为我们提供了一个从MPC到ANNA-transformer的亚二次方时间归约:即,通信过程可以在近线性时间内实现,而对于标准注意力来说是二次方的。此外,这将ANNA-transformer与现有的MPC层次结构联系起来(Sanford et al., 2024a【Sanford, Clayton, et al. "Understanding transformer reasoning capabilities via graph algorithms." Advances in Neural Information Processing Systems 37 (2024a): 78320-78370.】):任何可由R轮、O(N^ε)内存的MPC协议解决的问题,都可以由O(R)层、mH = O(N^(ε+δ))的ANNA-transformer解决,对于某个 δ > 0。例如,根据(Hajiaghayi et al., 2019【Hajiaghayi, MohammadTaghi, et al. "Mapreduce meets fine-grained complexity: Mapreduce algorithms for apsp, matrix multiplication, 3-sum, and beyond." arXiv preprint arXiv:1905.01748, 2019.】)的定理3.1,O(1)层的ANNA-transformer可以用 mH = O(N^(1/2+δ)) 解决3-SUM问题。

4.2 MPC可以模拟ANNA-Transformer

MPC模拟ANNA的定理。以下定理(在附录C中证明)表明,任何L层的ANNA-transformer(如算法1实现)都可以由一个O(L)轮的MPC协议模拟。由于算法1是随机的,它使用一个随机种子从LSH族中采样哈希函数。该模拟假设可以访问ANNA-transformer中所有层所需的随机种子。定理4.2 (MPC模拟ANNA). 固定常数 $0 < ε < ε' < 1$。对于任何L层的ANNA-transformer T(如算法1实现),其 mH = O(N^ε),存在一个 O(L/(ε' - ε)) 轮的MPC协议π,其本地内存 s = O(N^ε'),机器数量 P = O(N^(1+ε-ε'+3/c²)),使得对于所有输入 $input \in Z_{2^p}^N$,$\pi(input) = T(input)$。

结果分析。观察到,在模拟ANNA-transformer时使用的机器数量可以是强亚二次方的(当c很大时,实际上接近线性)。相比之下,来自(Sanford et al., 2024c【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Transformers, parallel computation, and logarithmic depth." In Forty-First International Conference on Machine Learning, 2024c.】)的标准transformer的模拟需要N²台机器。如前所述(1.2节),这表明ANNA-transformer比标准transformer更精确地刻画了高效的MPC协议。

对下界的启示。另一方面,根据定理4.2,MPC的轮数复杂度下界直接意味着ANNA-transformer的深度下界。这一论证在(Sanford et al., 2024c【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Transformers, parallel computation, and logarithmic depth." In Forty-First International Conference on Machine Learning, 2024c.】)中被用来为标准transformer在图连通性和k-hop归纳头等问题上建立(有条件的)深度下界;这些下界同样适用于ANNA-transformer。

4.3 ANNA-Transformer可以模拟低秩Transformer

ANNA与低秩注意力的比较。如第1节所述,有许多关于高效注意力替代方案的提议。在本节中,我们关注基于注意力矩阵低秩近似的亚二次方替代方案。具体来说,我们提出以下问题:哪些问题对ANNA来说本质上是容易的,但对低秩近似注意力来说是困难的,反之亦然?

低秩注意力的定义。低秩注意力由两个特征映射 Q', K' : R^d → R^r 指定,其中 r ≪ N(意图是用 Q'(X)K'(X)^T 来近似 softmax(Q(X)K(X)^T))。对于输入 X ∈ R^(N×d),它通过首先计算 K'(X)^TV(X) ∈ R^(r×m),然后左乘 Q'(X) 来计算 Q'(X)K'(X)^TV(X)。

ANNA与低秩注意力的能力分离。注意,(Sanford et al., 2024c【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Transformers, parallel computation, and logarithmic depth." In Forty-First International Conference on Machine Learning, 2024c.】)给出了任何低秩注意力在k-hop问题上的一个下界。稍后在第5节,我们给出了一个解决k-hop问题的O(log k)深度ANNA-transformer的构造,这直接提供了一种对ANNA容易但对低秩注意力困难的问题类型。然而,是否存在任何对低秩注意力容易但对ANNA困难的问题?

ANNA模拟低秩注意力的定理。以下定理通过证明任何L层的低秩注意力transformer都可以由一个O(L)层的ANNA-transformer模拟来回答这个问题。因此,在时间和参数高效的机制下(亚线性的秩和嵌入维度),低秩注意力transformer并不比ANNA-transformer更强大。定理4.4 (ANNA模拟低秩注意力). 对于常数 $0 < ε < ε' < 1$,任何基于低秩注意力的transformer,其深度为L,秩为r,嵌入维度为m,且 rm = O(N^ε),都可以由一个ANNA-transformer模拟,其深度为O(L),头数为 H = O(N^((ε'-ε)/4)),嵌入维度为 m = O(N^ε')。

证明思路。我们通过首先使用O(L/(ε' - ε))轮的MPC来模拟L层的低秩transformer,然后定理4.1通过MPC为我们提供了用ANNA-transformer模拟L层低秩transformer的方法来证明定理4.4。完整的证明在附录D中给出。

与其他基于最近邻搜索的高效注意力机制的比较。Reformer(Kitaev et al., 2020【Kitaev, Nikita, Lukasz Kaiser, and Anselm Levskaya. "Reformer: The efficient transformer." In International Conference on Learning Representations, 2020.】)是另一种基于LSH的高效注意力。在Reformer中,输入token按其(标量)哈希值排序。然后,这个排序后的列表被分成大小相等的块,每个块只包含O(1)个token。标准注意力在每个块内应用。我们证明,Reformer的表达能力必须来自排序操作:没有排序,将注意力限制在每个常数大小的块内会阻止Reformer甚至用O(1)层计算像“平均值”这样的基本函数(无论嵌入维度如何);细节在附录E中给出。

与混合方法的比较。KDEformer(Zandieh et al., 2023【Zandieh, Amir, et al. "KDEformer: Accelerating transformers via kernel density estimation." In Proceedings of the 40th International Conference on Machine Learning, 2023.】)和HyperAttention(Han et al., 2024【Han, Insu, et al. "Hyperattention: Long-context attention in near-linear time." In The Twelfth International Conference on Learning Representations, 2024.】)将softmax注意力矩阵近似为稀疏矩阵和低秩矩阵的和。它们使用LSH技术来寻找稀疏元素(即注意力矩阵中的重元素),并对剩余部分使用低秩注意力。定理4.4表明,这个低秩部分不会显著增加其表征能力。

5 ANNA-Transformer用于推理任务

在本节中,我们研究ANNA-transformer在两个具体的推理任务上的表现:Match2(Sanford et al., 2023【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Representational strengths and limitations of transformers." In Thirty-seventh Conference on Neural Information Processing Systems, 2023.】)和k-hop(Sanford et al., 2024c【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Transformers, parallel computation, and logarithmic depth." In Forty-First International Conference on Machine Learning, 2024c.】)。这些任务是评估transformer推理能力的基准,它们在表征强度方面区分了不同的神经架构。

5.1 ANNA-Transformer解决Match2任务

Match2任务介绍。Match2任务(Sanford et al., 2023【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Representational strengths and limitations of transformers." In Thirty-seventh Conference on Neural Information Processing Systems, 2023.】)衡量模型将成对元素相互关联的能力。我们表明,单个ANNA机制可以解决Match2。

Match2任务定义。给定一个输入序列 $X = (x_1, \dots, x_N) \in [M]^N$,其中 $M \le poly(N)$,Match2(X)的第i个输出是 $1{\exists j, x_i + x_j = 0 \mod M}$,对于所有 $i \in [N]$。

ANNA解决Match2的定理。对于任何 $N, M = N^{O(1)}$,存在一个ANNA-transformer T,它有一个层,一个注意力头,嵌入维度为1,使得对于所有 $X \in [M]^N$,$T(X) = Match2(X)$。

5.2 ANNA-Transformer解决k-hop任务

归纳头任务介绍。归纳头(又名关联回忆)任务(Elhage et al., 2021【Elhage, Nelson, et al. "A mathematical framework for transformer circuits." Transformer Circuits Thread, 2021.】)是一个通过完成最近的二元组来预测下一个token的推理任务。它被认为是LLM涌现“上下文学习”能力的一个重要机制。

归纳头任务定义。设Σ为一个有限字母表,w ∈ Σ^N。对于每个 i ∈ [N],定义

$$\sigma(w, i)=\max \{\{0\} \cup\{j \in \mathbb{N}: j \leq i, w_{j-1}=w_{i}\}\}.$$

归纳头任务是为每个 1 ≤ i ≤ N,计算 $w_{σ(w,i)}$ 的值。

归纳头任务示例。例如,设 Σ = {a, b, c} 且 w = aabcbabca。那么 $w_{σ(w,9)} = b$,因为第9个token是a,在位置9之前a的最后一次出现(在位置6)之后是b。

ANNA解决归纳头任务的定理。以下定理表明,我们的ANNA-transformer可以使用常数数量的层以及亚线性的嵌入维度和头数来解决归纳头问题。定理5.4:固定常数 $0 < ε < ε' < 1$。存在一个ANNA-transformer T,其层数 L = O(1),每层头数 H = O(N^((ε'-ε)/4)),嵌入维度 m = O(N^ε'),使得对于所有 w ∈ Σ^N 和所有 i ∈ [N], $T(w)i = w$。

证明思路。我们通过为归纳头构建一个常数轮的MPC算法来证明定理5.4(在附录F.2中),然后应用定理4.1将其转换为一个ANNA-transformer。

k-hop任务介绍。归纳头任务被(Sanford et al., 2024c【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Transformers, parallel computation, and logarithmic depth." In Forty-First International Conference on Machine Learning, 2024c.】)推广为一个k步的变体,称为“k-hop”。

k-hop任务定义。设Σ为一个有限字母表,w ∈ Σ^N。设 $σ^k(w, i)$ 表示前述定义中 $σ(w, \cdot)$ 的k次复合。k-hop归纳头任务是为每个 1 ≤ i ≤ N 计算 $w_{σ^k(w,i)}$。

k-hop任务示例。使用相同的例子,其中 Σ = {a, b, c},w = aabcbabca 且 k = 2,我们有 $w_{σ(σ(w,9))} = a$,因为在位置7之前b的最后一次出现之后是a。

ANNA解决k-hop任务的扩展。正如在(Sanford et al., 2024c【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Transformers, parallel computation, and logarithmic depth." In Forty-First International Conference on Machine Learning, 2024c.】)中所做的,我们使用函数复合为k-hop构建了一个O(log k)轮的MPC算法,从而为此任务上的ANNA-transformer产生了对数深度的扩展。

ANNA解决k-hop任务的定理定理5.6:固定常数 $0 < ε < ε' < 1$,任何 k ∈ N 和字母表 Σ 且 |Σ| = O(N)。存在一个ANNA-transformer T,其层数 L = O(log k),每层头数 H = O(N^((ε'-ε)/4)),嵌入维度 m = O(N^ε'),使得对于所有 w ∈ Σ^N 和所有 i ∈ [N], $T(w)i = w$。

与先前工作的关系。完整的构造在附录F.2中给出。我们注意到,尽管先前的工作已经给出了k-hop(Bietti et al., 2023【Bietti, Alberto, et al. "Birth of a transformer: A memory viewpoint." Advances in Neural Information Processing Systems 36 (2023): 1560-1588.】; Sanford et al., 2024c【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Transformers, parallel computation, and logarithmic depth." In Forty-First International Conference on Machine Learning, 2024c.】)和Match2(Sanford et al., 2023【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Representational strengths and limitations of transformers." In Thirty-seventh Conference on Neural Information Processing Systems, 2023.】)的transformer构造,但鉴于架构的差异,这些结果并不直接意味着ANNA-transformer的构造。

ANNA在k-hop任务上的优势。先前的工作(Sanford et al., 2024c【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Transformers, parallel computation, and logarithmic depth." In Forty-First International Conference on Machine Learning, 2024c.】)表明,多层循环网络和低秩亚二次方注意力(Choromanski et al., 2021【Choromanski, Krzysztof Marcin, et al. "Rethinking attention with performers." In International Conference on Learning Representations, 2021.】; Kacham et al., 2024【Kacham, Praneeth, Vahab Mirrokni, and Peilin Zhong. "PolySketchFormer: Fast transformers via sketching polynomial kernels." In Proceedings of the 41st International Conference on Machine Learning, 2024.】)除非深度为Ω(k)或其内存大小/嵌入维度为Ω(N/k^6),否则无法解决k-hop问题。相比之下,ANNA-transformer同时实现了O(log k)的深度和亚线性(相对于N)的宽度。在这个意义上,k-hop任务将ANNA-transformer与其他这些高效的神经架构区分开来。

A4 实验环境与结果

实验环境

任务与数据集:
* Match2: 数据集按照(Kozachinskiy et al., 2025【Kozachinskiy, Alexander, et al. "Strassen attention: Unlocking compositional abilities in transformers based on a new lower bound method." 2025.】)的方式生成,上下文长度 N=32,模数 M=37。数据集包含10000个样本,采用平衡批次生成策略以确保标签分布均匀。
* 归纳头 (Induction Heads): 使用(Sanford et al., 2024c【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Transformers, parallel computation, and logarithmic depth." In Forty-First International Conference on Machine Learning, 2024c.】)的数据生成算法,设置跳数 k=1,上下文长度 N=100,字母表大小 |Σ|=4。采用在线训练方式。

模型架构:
* Match2: 单层ANNA-Transformer,包含1个注意力头,嵌入维度 m=64,以及一个宽度为 4m、使用GeLU激活函数的MLP。
* 归纳头: 两层ANNA-Transformer,每层1个注意力头,嵌入维度 m=128,以及一个宽度为 4m、使用GeLU激活函数的MLP。

硬件配置:
* GPU: NVIDIA Titan RTX 和 NVIDIA Titan Xp 各一块。

软件与训练配置:
* 训练策略: 由于基于LSH的ANNA(算法1)不可微,本文采用一种蒸馏策略:首先训练一个可微的softmax注意力模型作为代理,然后将学习到的参数用于ANNA-Transformer进行推理。该softmax模型对查询和键进行单位范数归一化,并使用可调的温度参数 β
* LSH类型: 采用(Andoni et al., 2015【Andoni, Alexandr, et al. "Practical and optimal lsh for angular distance." In Proceedings of the 29th International Conference on Neural Information Processing Systems, NIPS’15, 2015.】)提出的角距离LSH。
* 优化器与损失函数: 使用Adam优化器和交叉熵损失函数。
* 学习率: 0.01。
* 训练步数: Match2任务为20000步,归纳头任务为400000步。

实验结果

Match2任务实验:
* 实验内容: 在Match2数据集上评估单层ANNA-Transformer的性能。
* 实验结果: 在测试集上,当使用 ℓ=8 个哈希表和每个表 z=1 个哈希函数时,ANNA-Transformer能够达到零错误率。
* 分析结论: 实验证明,一个简单的ANNA-Transformer结构能够完美解决Match2任务,验证了其理论上的表达能力。图1a展示了不同哈希表数量()和哈希函数数量(z)下的错误率变化。

归纳头任务实验:
* 实验内容: 在1-hop归纳头任务上评估两层ANNA-Transformer的性能。
* 实验结果: 模型取得了显著优于随机猜测(错误率0.75)的性能。当第一层使用 ℓ=32 个哈希表、第二层使用 ℓ=4 个哈希表,并且每层 z=2 个哈希函数时,错误率约为0.2。通过增加哈希表的数量,错误率可进一步降低至0.1以下。
* 分析结论: ANNA-Transformer能够有效地学习并解决归纳头任务。图1b展示了在两层中不同哈希表数量组合下的最低错误率,表明增加哈希表的数量可以提高模型性能。

图1: 所有误差均为10次运行的平均值。(a) Match2上的错误率:x轴表示哈希表的数量ℓ,不同颜色对应每个哈希表的不同哈希函数数量z。(b) 归纳头上的错误率:行对应第一层中的哈希表数量,列对应第二层中的哈希表数量。报告的错误率是在z∈{1, 2, 3, 4}的选择中达到的最佳值。
图1: 所有误差均为10次运行的平均值。(a) Match2上的错误率:x轴表示哈希表的数量ℓ,不同颜色对应每个哈希表的不同哈希函数数量z。(b) 归纳头上的错误率:行对应第一层中的哈希表数量,列对应第二层中的哈希表数量。报告的错误率是在z∈{1, 2, 3, 4}的选择中达到的最佳值。

A5 结论

工作总结。本文提出了一类更高效的神经架构——ANNA-Transformer。研究证明,该架构不仅保留了由MPC框架所刻画的标准Transformer的表征能力,而且与MPC模型建立了更紧密的等价关系。此外,本文还证明了常数层的ANNA-Transformer可以模拟常数层的低秩Transformer,并能在接近最优的深度下解决诸如Match2和k-hop等推理任务。

未来工作展望。本文为未来的研究留下了几个有趣的方向。
1. 开发更优的训练方法:虽然本文所用的临时训练方法作为概念验证是有效的,但开发一种能够直接优化ANNA-Transformer(或其可微变体)性能的原则性训练方法是十分必要的,而不是依赖于代理模型。
2. 扩展到大规模真实世界基准:本文的实证验证仅限于小型合成数据集。将这些实验扩展到大规模、真实世界的基准测试是重要的下一步。

A6 附录

A 定理3.4的证明

定理3.4的重述。此处重述定理3.4,该定理为算法1提供了理论保证。定理A.1 (LSH算法对ANNA的保证;定理3.4). 固定 $c > 3$,LSH族 H 是 (r, cr, p1, p2)-敏感的,其质量 $ρ < 1/3$,设 $ℓ = Θ(N^{3ρ} \log N)$,以及 $z = Θ(\log_{1/p2} N)$。那么算法1(使用H, ℓ, z)实现了一个ANNA机制,其参数为 r, c, ℓ 且 $η = O(1/N^{1-3ρ})$。

证明。我们的算法适用于具有大近似因子的情况,即 c > 3。由于我们只关心距离查询点cr范围内的最近邻,我们需要限制距离大于cr的两点落入同一桶的概率。考虑族G,使得如果 $∥x - y∥ > cr$,则 $Pr_{g \in G}[g(x) = g(y)] \le 0.1/N^3$。那么对于每个桶,碰撞(即x, y落入同一桶且 $∥x - y∥ > cr$)的期望次数小于 $N \cdot Pr_{g \in G}[g(x) = g(y)] \le 0.1/N^2$。因此,通过马尔可夫不等式,对于每个桶,以大于 $1 - 0.1/N^2$ 的概率,桶内没有碰撞。然后,通过对所有非空桶(最多有N个)应用联合界,以大于 $1 - 0.1/N$ 的概率,一个哈希表中没有碰撞。根据(Indyk and Motwani, 1998【Indyk, Piotr, and Rajeev Motwani. "Approximate nearest neighbors: towards removing the curse of dimensionality." In Proceedings of the Thirtieth Annual ACM Symposium on Theory of Computing, STOC ’98, 1998.】),$z = O(\log_{1/p2} N)$,即每个哈希函数 $g \in G$ 由从LSH族H中采样的 $O(\log_{1/p2} N)$ 个哈希函数组成,这足以在 $∥x - y∥ > cr$ 时实现 $Pr_{g \in G}[g(x) = g(y)] \le 0.1/N^3$。另一方面,由于碰撞概率非常小,成功概率(当 $∥x - y∥ \le r$ 时),即 $p = Pr_{g \in G}[g(x) = g(y)] = N^{-3ρ}$(回想 $ρ = \frac{\log 1/p1}{\log 1/p2}$),也相当小。我们需要大量的哈希表。那么对于每个 $q_i$,它的r-最近邻k($k \in N(q_i, r)$)在所有ℓ个表中都与 $q_i$ 落入不同桶的概率上限为 $(1 - p)^ℓ$。通过对所有可能的最近邻和所有 $q_i$ 应用联合界,失败概率的界为 $N^2(1 - p)^ℓ$。假设我们希望失败概率小于某个 $δ > 0$,那么我们希望 $N^2(1 - p)^ℓ \le δ$。对两边取对数,并对足够小的x使用log(1 - x)的泰勒展开,我们发现 $ℓ = O(N^{3ρ}(\log N + \log 1/δ))$ 对于成功概率 $1 - δ$ 来说是足够的。因此,通过对所有ℓ个哈希表应用联合界,以 $1 - 0.1N^{1-3ρ}$ 的概率,所有哈希表中都没有碰撞,这意味着如果 $∥k_j - q_i∥ > cr$,则 $w_{i,j} = 0$。通过设置 $δ = 0.1N^{1-3ρ}$,我们得到 $ℓ = O(N^{3ρ} \log N)$。因此,总失败概率η的界为 $δ + 0.1N^{1-3ρ}$,即 $O(1/N^{1-3ρ})$。如果 $∥k_j - q_i∥ \le r$,根据上述保证,我们知道 $k_j$ 在ℓ个哈希桶中至少与 $q_i$ 碰撞一次。这意味着 $w_{i,j} \ge 1/count$,其中count是 $q_i$ 检索到的ℓ个哈希桶中所有碰撞的总数。在最坏的情况下,除了 $k_j$ 只碰撞一次外,所有 $k \in N(q_i)$ 都在所有ℓ个哈希表中与 $q_i$ 碰撞。因此,$count \le (|N(q_i)| - 1) \cdot ℓ$,这给了我们 $w_{i,j} \ge \frac{1}{(|N(q_i)|-1) \cdot ℓ+1}$。□

运行时间和内存使用。可以看出,对于每个查询,我们需要评估 $O(N^{3ρ} \log_{1/p2} N)$ 个哈希函数并计算m维向量的和,所以总运行时间是 $O(mN^{1+3ρ} \log_{1/p2} N)$。在预处理期间,我们需要存储 $N^{3ρ}$ 个哈希表和值的总和,每个表最多有N个桶,所以总内存是 $O(mN^{1+3ρ} \log N)$ 位。实际上,使用的空间可以进一步提高到 $O(mN \log N)$ 位。可以只存储1个大小为 $O(mN \log N)$ 的哈希表,而不是维护ℓ个哈希表,每个条目负责跟踪每个查询的值。对于每一轮哈希(总共ℓ轮),使用哈希函数对所有查询进行哈希,并为它们创建空桶。然后,对每个键进行哈希,如果该键哈希到现有的查询桶中,则将其值相加(并增加计数)。处理完键之后,每个查询从其对应的桶中累积值和计数。我们在算法2中给出了内存高效的实现。

算法2:使用LSH族H、ℓ个哈希表和z个哈希函数/表的线性内存ANNA实现
算法2:使用LSH族H、ℓ个哈希表和z个哈希函数/表的线性内存ANNA实现

B ANNA-Transformer可以模拟MPC

基于精确匹配注意力的模拟。我们使用ANNA的一个特例来模拟MPC,我们称之为精确匹配注意力(EMA)。在EMA中,我们要求键与查询完全相同,才会在注意力矩阵中被考虑。我们证明这个特例已经足以模拟MPC。

EMA定义。设 $X \in R^{N×d}$ 为输入嵌入,Q, K, V : $R^{N×d} \rightarrow R^{N×d}$ 为查询/键/值嵌入函数。对于任何查询q,令 $N(q) = {k_j \in K : k_j = q}$。对于每个查询 $q_i$,精确匹配注意力计算:

$$\begin{aligned} \mathrm{EMA}_{K, V}\left(q_{i}\right)=\left\{\begin{array}{cl} \frac{1}{\left|\mathcal{N}\left(q_{i}\right)\right|} \sum_{j \in \mathcal{N}\left(q_{i}\right)} v_{j} & \text { if } \mathcal{N}\left(q_{i}\right) \neq \emptyset \\ \mathbf{0} & \text { otherwise. } \end{array}\right. \end{aligned}$$

EMA层和EMA-transformer的定义类似。要看到EMA是ANNA的一个特例,只需在ANNA中设置r = 0, $c \rightarrow \infty$ 和 $w_{i,j} = \frac{1}{|N(q)|}$,它就变得与EMA完全相同。EMA也允许一个近线性的时间算法:首先对所有键进行排序(使用字典序),时间复杂度为O(dN log N),空间复杂度为O(dN);在查询时,对每个查询执行二分搜索,时间复杂度为O(d log N)。

大嵌入维度下的初步模拟。我们首先给出一个直接模拟R轮(ε, ε)-MPC的模拟,使用L = R + 1层但嵌入维度较大,以展示证明的核心思想。定理B.2 (EMA模拟MPC). 对于常数 $0 < ε < 1$,任何具有N台机器且每台机器本地内存 s = O(N^ε) 的确定性R轮MPC协议π,存在一个EMA-transformer T,其深度 L = R + 1,头数 H = O(N^ε),嵌入维度 m = O(N^(5ε) log N),使得对于所有输入 $input \in Z_{2^p}^N$,$T(input) = \pi(input)$。

证明。对于任何将输入映射到输出的R轮MPC协议π,我们为本地计算阶段和消息传输阶段定义中间步骤。我们将本地计算前的所有机器输入表示为 MachineIn1, MachineIn2, ..., MachineInR,并将确定性本地计算($Local_i^r$)后的信息表示为 MachineOut1, ..., MachineOutR,其中 $MachineOut_i^r = Local_i^r(MachineIn_i^r)$。在通信(消息传输)阶段,我们需要将消息从 MachineOut^r 路由到正确的机器,即 MachineIn^(r+1)。在我们的模拟中,EMA-transformer的每个输入token扮演MPC协议中一台机器的角色。我们通过架构中的逐元素函数Q(·), K(·), V(·)来模拟本地计算函数($Local_i^r$)。因此,模拟过程可以分为3个部分:
1. 初始化:输入到EMA-transformer的数据分布在N个token中,我们需要将其转移到前 $\lceil N/s \rceil$ 个token/机器以匹配MachineIn1。
2. 路由(消息传输):在每轮r的本地计算之后,我们需要将消息从MachineOut^r传达到MachineIn^(r+1)。
3. 最终输出:MPC输出分布在前 $\lceil N/s \rceil$ 个token/机器中,我们需要将它们分发回N个token。
以下3个引理为这3个部分构建了要素。

路由引理。我们首先证明MPC的消息传输部分可以由EMA-transformer模拟。回想一下,在r轮本地计算之后,每台机器i都有一组要发送给其他机器的消息,表示为 $MachineOut_i^r = {(Msg_{i \rightarrow dest}, dest) : dest \in sent_i}$,其中 $sent_i$ 是机器i将发送消息的机器索引集合,而 $Msg_{i \rightarrow dest}$ 是机器i发送给机器dest的消息。在消息通信阶段之后,每台机器i都有一组从其他机器接收到的消息,表示为 $MachineIn_i^{r+1} = {(Msg, Src) : (Msg, i) \in MachineOut_{Src}^r}$。由于每台机器只能发送/接收s个词,我们有 $\sum_{dest \in sent_i} |Msg| \le s$ 和 $\sum_{(Msg,i) \in MachineOut_{Src}^r} |Msg| \le s$ 对于所有机器i。我们称这个过程为MPC的路由过程。以下引理表明,MPC的每个路由轮都可以由一层EMA-transformer模拟。引理B.3 (路由). 对于任何具有q台机器且每台机器本地内存为s的R轮MPC协议π和任何r ∈ [R-1],存在一个EMA-transformer路由router,其头数 H = O(s),Q和K的嵌入维度 m = O(log q),V的嵌入维度 m = O(s^5 log q),它接收输入 X = MachineIn^r 并产生输出 router(X) = MachineIn^(r+1)。

证明。遵循(Sanford et al., 2024c【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Transformers, parallel computation, and logarithmic depth." In Forty-First International Conference on Machine Learning, 2024c.】)中的假设,我们将本地计算编码到transformer的逐元素操作Q(·), K(·), V(·)中。证明的主要部分将集中于使用EMA将MachineOut^r路由到MachineIn^(r+1)。我们为每台机器i分配一个唯一的 positional encoding 或标识符,表示为 $p_i$。这可以用O(log q)位来完成。该编码作为检索每台机器中消息的唯一键。核心思想是为每台机器i创建一个查询,为每个 $dest \in sent_i$ 创建一个键,相关的值是协议中发送给dest的消息 $Msg_{i \rightarrow dest}$。由于每台机器最多可以向其他机器发送s条消息,我们创建s个EMA头,每个头负责所有q台机器的一条消息。每台机器通过在每个头中设置一个查询来检索发送给它们的消息。因为每个查询只能关注完全匹配的键,所以每个不同的出站消息必须由不同的注意力头传递,但多个入站消息可能由同一个注意力头接收。具体来说,设 $Q^h, K^h, V^h$ 为每个头h ∈ [s]在机器本地计算后的查询、键、值嵌入。对所有h,设置 $q_i^h = p_i$,所以

$$\begin{aligned} Q^1=Q^2=\cdots=Q^s=\begin{pmatrix}p_1^\top\\p_2^\top\\\vdots\\p_q^\top\end{pmatrix}. \end{aligned}$$

设 $k_i^h = p_{dest_i^h}$ 其中 $dest_i^h \in sent_i = {dest_{i1}, dest_{i2}, \dots, dest_{is}}$,$dest_{ij}$ 是机器i发送的第j个词消息的目标机器索引。键矩阵构造如下:

$$\begin{aligned} K^{1}=\begin{pmatrix}p_{\text{dest}_{1}^{1}}^{\top} \\ p_{\text{dest}_{1}^{2}}^{\top} \\ \vdots \\ p_{\text{dest}_{1}^{\text{q}}}^{\top}\end{pmatrix}, \quad K^{2}=\begin{pmatrix}p_{\text{dest}_{2}^{1}}^{\top} \\ p_{\text{dest}_{2}^{2}}^{\top} \\ \dots \\ p_{\text{dest}_{2}^{\text{q}}}^{\top}\end{pmatrix}, \quad \dots, \quad K^{\text{s}}=\begin{pmatrix}p_{\text{dest}_{\text{s}}^{1}}^{\top} \\ p_{\text{dest}_{\text{s}}^{2}}^{\top} \\ \vdots \\ p_{\text{dest}_{\text{s}}^{\text{q}}}^{\top}\end{pmatrix}. \end{aligned}$$

设 $v_i^h$ 是 $(Msg_{i \rightarrow dest_i}, dest_i^h, i)$ 的某种嵌入,表示为 $v_i^h = emb_i^h(Msg_{i \rightarrow dest_i}, dest_i^h, i)$,对于某个稍后定义的 $emb_i^h$,且

$$\begin{aligned} V^1 = \begin{pmatrix} \text{emb}_1^1(\text{Msg}_{\text{dest}_1^1}^1, \text{dest}_1^1, 1) \\ \text{emb}_2^1(\text{Msg}_{\text{dest}_1^2}^2, \text{dest}_1^2, 2) \\ \vdots \\ \text{emb}_q^1(\text{Msg}_{\text{dest}_1^q}^q, \text{dest}_1^q, q) \end{pmatrix}, \dots, V^s = \begin{pmatrix} \text{emb}_1^s(\text{Msg}_{\text{dest}_s^1}^1, \text{dest}_s^1, 1) \\ \text{emb}_2^s(\text{Msg}_{\text{dest}_s^2}^2, \text{dest}_s^2, 2) \\ \vdots \\ \text{emb}_q^s(\text{Msg}_{\text{dest}_s^q}^q, \text{dest}_s^q, q) \end{pmatrix} \end{aligned}$$

通过这样构造Q, K, V,在我们的EMA中,每个查询将检索其键与查询完全匹配的消息的平均值。然而,以这种方式设置值矩阵,当有多个 $k_i^h \in K^h$ 等于同一个查询时,可能会损坏消息。为了解决这个问题,我们可以应用(Sanford et al., 2024c【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Transformers, parallel computation, and logarithmic depth." In Forty-First International Conference on Machine Learning, 2024c.】)中引理3.2的相同基于多哈希的编码,该编码将每条消息编码在由稀疏二进制矩阵生成的多个固定位置,并有一个额外的“有效性位”指示消息是否被损坏。我们在此重述其引理3.2的改编版本。引理B.4 ((Sanford et al., 2024c)的引理3.2;稀疏平均中的消息编码). 对于任何消息大小 $\Delta \in N$,消息计数界限 $α \in N$,存在一个编码函数 $ϕ$,它接收上面定义的 $(Msg_{i \rightarrow dest_i^h}, dest_i^h, i)$,其中其大小对所有 i ∈ [q] 和 h ∈ [α] 都以 $\Delta$ 为界,并将其编码为 $emb_i^h(Msg_{i \rightarrow dest_i^h}, dest_i^h, i) \in R^m$,其中 m = O(α^4Δ log q),以及一个解码器函数 $φ$,它接收上面定义的Q, K, V的EMA输出,并将其解码为 $(Msg_{i \rightarrow dest_i^h}, dest_i^h, i)$。设 $rcvd_i = {src_{i1}, src_{i2}, \dots, src_{is}}$,其中 $src_{ij}$ 是机器i接收消息的第j个源机器索引。因为 $|sent_i| \le s$ 和 $|rcvd_i| \le s$,在EMA的每个头中,对于每个查询,最多有s个值被检索和平均。因此,这里我们可以直接应用引理B.4,其中 α = Δ = s,这给了我们一个嵌入维度界限 m = O(s^5 log q)。□

初始化引理。然后我们证明,用一层EMA-transformer,我们可以正确地初始化MPC的设置,即将输入 Input = (input1, input2, ..., inputn) 转换为MachineIn1,即MPC计算第一轮之前的输入,也就是输入均匀分布在前 $\lceil n/s \rceil$ 台机器上。引理B.5 (初始化). 对于任何具有q台机器且每台机器本地内存为s的R轮MPC协议π和n个词的输入,存在一个EMA-transformer init,其头数 H = 1,Q, K的嵌入维度 m = O(log q),V的嵌入维度 m = O(s),它接收输入并输出 init(input) = MachineIn1。

证明。输入应分布在 $1 \le i \le \lceil n/s \rceil$ 的每台机器上,其中 $MachineIn_i^1 = {(input_{idx}, idx) : idx \in s(i - 1) + 1, \dots, \min{n, si}}$。令 $q_{in} = \lceil n/s \rceil$ 为存储初始输入的机器数。由于给予init的输入是n个token(这里我们将每个token视为一台机器),我们需要重新安排内存,使输入分布在前 $q_{in}$ 个token上。和之前一样,我们使用 positional encoding $p_i$ 作为每台机器的唯一标识符。我们为每个输入token创建一个键值对,键对应于 $input_{idx}$ 去往的机器的标识符,值为 $(input_{idx}, idx)$。同时,为每台机器 $i \in [q_{in}]$ 创建一个查询。对于每台机器 $i \in [q_{in}]$,定义查询嵌入 $q_i = p_i$,

$$\begin{aligned} Q = \begin{pmatrix} p_{1}^{\top} \\ p_{2}^{\top} \\ \vdots \\ p_{q_{in}}^{\top} \end{pmatrix} \end{aligned}$$

对于每个token $input_{idx}, idx \in [n]$,令 $dest_{idx} = \lceil idx/s \rceil$ 为存储该token的机器,定义键嵌入 $k_{idx} = p_{dest_{idx}}$,

$$\begin{aligned} K = \begin{pmatrix} p_{\text{dest}_1}^{\top} \\ p_{\text{dest}_2}^{\top} \\ \vdots \\ p_{\text{dest}_{\text{n}}}^{\top} \end{pmatrix} \end{aligned}$$

令 $i' = idx \mod s$。对于每个token $input_{idx}, idx \in [n]$,定义值嵌入 $v_{idx} \in R^{2s}$ 为在第 $2i' - 1, 2i'$ 个条目中为 $(input_{idx}, idx)$,在所有其他条目中为0,

$$\begin{aligned} V = \begin{pmatrix} \text{input}_1 & 1 & 0 & 0 & 0 & \dots & 0 & 0 \\ 0 & 0 & \text{input}_2 & 2 & 0 & \dots & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & \dots & \text{input}_s & s \\ \text{input}_{s+1} & s+1 & 0 & 0 & 0 & \dots & 0 & 0 \\ \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots \end{pmatrix} \end{aligned}$$

通过这样设置值矩阵,我们可以避免消息被损坏。

最终输出引理。最后,我们证明用额外的一层EMA-transformer,我们可以将最后一轮的MachineInR映射到MPC协议的输出,其中输出存储在前 $\lceil n/s \rceil$ 台机器中。引理B.6 (最终输出). 对于任何具有q台机器且每台机器本地内存为s的R轮MPC协议π和n个词的输入,存在一个EMA-transformer out,其头数 H = 1,Q, K的嵌入维度 m = O(log q),V的嵌入维度 m = O(s),它接收MachineInR并输出 out(MachineInR):n = π(input) = output。

证明。首先,逐元素操作可以从MachineInR计算出MachineOutR。输出分布在 $1 \le i \le \lceil n/s \rceil = q_{out}$ 的每台机器上,机器i的内存为 $output_i = {(output_{idx}, idx) : idx \in s(i - 1) + 1, \dots, \min{n, si}}$。然后,我们只需要从所有 $q_{out}$ 台机器中检索输出token,并将它们分发回n个token。这一步是init的逆操作。我们为所有 $idx \in [n]$ 的每个token $output_{idx}$ 创建一个查询。令 $src_{idx} = \lceil idx/s \rceil$ 为token $output_{idx}$ 所在的机器。对于每个token $output_{idx}, idx \in [n]$,定义查询嵌入 $q_{idx} = p_{src_{idx}}$,

$$\begin{aligned} Q = \begin{pmatrix} p_{\mathrm{src}_1}^\top \\ p_{\mathrm{src}_2}^\top \\ \vdots \\ p_{\mathrm{src}_{\mathrm{n}}}^\top \end{pmatrix} \end{aligned}$$

对于每台机器 $i \in [q_{out}]$,创建一个键 $k_i = p_i$,

$$\begin{aligned} K = \begin{pmatrix} p_{1}^{\mathsf{T}} \\ p_{2}^{\mathsf{T}} \\ \vdots \\ p_{q_{out}}^{\mathsf{T}} \end{pmatrix} \end{aligned}$$

与每个键i关联的值是存储在每台机器i中的内存MsgOut。定义值嵌入 $v_i = MsgOut_i$,

$$\begin{aligned} V = \begin{pmatrix} \mathtt{MsgOut}_1 \\ \mathtt{MsgOut}_2 \\ \vdots \\ \mathtt{MsgOut}_{\mathtt{q_{out}}} \end{pmatrix} \end{aligned}$$

通过选择一个合适的逐元素函数 $φ$,out(MachineInR)i,1 = outputi。该定理通过堆叠这三个引理的元素得出。每个引理都为我们提供了最终EMA-transformer T的一层,其嵌入维度 m = O(N^(5ε) log N):

$$T = \text{out} \circ \text{route}_{\text{R}-1} \circ \cdots \circ \text{route}_{1} \circ \text{init}$$

对通用MPC的扩展。上述模拟同样适用于(γ, ε)-MPC,只需在输入中填充max(0, O(N^(1+γ-ε)) - N)个空的思维链token。

关于头数的备注。标准transformer可以用相同的嵌入维度但只有一个注意力头来模拟MPC(Sanford et al., 2024c,a)。这里的EMA需要O(N^ε)个头,我们把如何改进头数留给未来的工作。

ANNA模拟MPC的推论。由于EMA是ANNA在r = 0和 $c \rightarrow \infty$ 时的特例,用ANNA-transformer模拟MPC自然地从定理B.2得出。推论B.9 (ANNA模拟MPC). 对于常数 $0 < ε < 1$,任何具有N台机器且每台机器本地内存 s = O(N^ε) 的确定性R轮MPC协议π,存在一个ANNA-transformer T,其深度 L = R + 1,头数 H = O(N^ε),嵌入维度 m = O(N^(5ε) log N),使得对于所有输入 $input \in Z_{2^p}^N$,$T(input) = \pi(input)$。

改进嵌入维度的模拟。定理B.2只在亚线性本地内存机制下 s = O(N^(1/5-δ))(对于任何δ > 0)时给出了一个MPC模拟。然而,许多MPC协议算法需要 s = Ω(N^(1/2)) 的本地内存,例如用于3-SUM的MPC算法(Hajiaghayi et al., 2019【Hajiaghayi, MohammadTaghi, et al. "Mapreduce meets fine-grained complexity: Mapreduce algorithms for apsp, matrix multiplication, 3-sum, and beyond." arXiv preprint arXiv:1905.01748, 2019.】)和图算法(Sanford et al., 2024a【Sanford, Clayton, et al. "Understanding transformer reasoning capabilities via graph algorithms." Advances in Neural Information Processing Systems 37 (2024a): 78320-78370.】)。上述使用EMA-transformer的模拟并未产生亚线性的嵌入维度。(Sanford et al., 2024a【Sanford, Clayton, et al. "Understanding transformer reasoning capabilities via graph algorithms." Advances in Neural Information Processing Systems 37 (2024a): 78320-78370.】)进一步给出了一个用亚线性嵌入维度的transformer模拟具有亚线性本地内存的MPC,方法是用O(1)层transformer模拟一轮MPC协议,而不仅仅是一层。他们的改进也适用于此。定理B.10 (用改进的嵌入维度模拟MPC的EMA). 对于常数 $0 < ε < ε' < 1$,任何具有N台机器且每台机器本地内存 s = O(N^ε) 的确定性R-轮MPC协议π,存在一个深度 L = O(R) 的EMA-transformer,其头数 H = O(N^((ε'-ε)/4)),嵌入维度 m = O(N^ε'),使得对于所有输入 $input \in Z_{2^p}^N$,$T(input) = \pi(input)$。

证明。证明依赖于通过一个受限版本的MPC协议来模拟任何MPC协议(Sanford et al., 2024a【Sanford, Clayton, et al. "Understanding transformer reasoning capabilities via graph algorithms." Advances in Neural Information Processing Systems 37 (2024a): 78320-78370.】),该协议限制了每台机器可以发送消息的机器数量。然后,使用修改版的定理B.2来模拟这个受限版本的MPC。

受限MPC的定义。对于常数 $γ, ε, ρ > 0$,一个(γ, ε, ρ)-MPC协议是一个(γ, ε)-MPC协议,但有一个额外的约束:在每一轮中,每台机器只能向/从 k = O(n^ρ) 台机器发送/接收消息,而它能发送和接收的消息总大小仍然是 s = O(N^ε)。我们称k为通信容量。

受限MPC模拟通用MPC。(Sanford et al., 2024a【Sanford, Clayton, et al. "Understanding transformer reasoning capabilities via graph algorithms." Advances in Neural Information Processing Systems 37 (2024a): 78320-78370.】)给出了一个用O(R)轮的(γ, ε, ρ)-MPC协议模拟R轮(γ, ε)-MPC协议的构造。我们在此重述他们的命题。引理B.12 ((Sanford et al., 2024a)的命题24;(γ, ε, ρ)-MPC模拟(γ, ε)-MPC). 对于常数 $γ, ε > 0$ 和 $ρ \in (0, ε/2)$,如果函数f可以由一个R轮的(γ, ε)-MPC计算,它也可以由一个 $O(\frac{R(1+γ)}{2ρ^2})$ 轮的(γ, ε, ρ)-MPC协议计算。

用EMA模拟受限MPC。因此,我们只需要用我们的EMA-transformer来模拟(γ, ε, ρ)-MPC协议。模拟遵循与定理B.2相同的配方,我们有初始化、消息传递和最终输出阶段。由于(γ, ε, ρ)-MPC的初始化和输出遵循与(γ, ε)-MPC相同的规则,我们只需要修改模拟的消息传递部分,这对应于路由引理B.3。引理B.13 (EMA模拟(γ, ε, ρ)-MPC). 对于常数 $0 < ρ < ε < 1$,任何具有N台机器且每台机器本地内存 s = O(N^ε) 和通信容量 k = O(N^ρ) 的确定性R轮MPC协议π,存在一个深度 L = R + 1 的EMA-transformer,其头数 H = O(N^ρ),嵌入维度 m = O(N^(ε+4ρ) log N),使得对于所有输入 $input \in Z_{2^p}^N$,$T(input) = \pi(input)$。

证明。对于初始化和最终输出部分,我们只使用引理B.5和引理B.6中构造的相同的initout。在路由部分(引理B.3),因为 $|sent_i| \le k$, $|rcvd_i| \le k$,我们只需要k个头,在EMA的每个头中,最多有k个键匹配每个查询,因此对于单个查询,最多有k个值被平均。因此,我们可以应用引理B.4,其中 α = k,Δ = s,得到嵌入维度 m = O(N^(ε+4ρ) log N)。这为我们提供了一个新的route'_r,其头数和嵌入维度在每轮r都减少了。同样地,我们堆叠单层EMA-transformer的3个构建块,得到一个(R + 1)层的EMA transformer。

$$T = \text{out} \circ \text{route}'_{\text{R}-1} \circ \cdots \circ \text{route}'_1 \circ \text{init}$$

这就完成了引理的构造。设 $ρ = \min(ε/2, (ε' - ε)/4)$。在这种设置下,γ = ε。通过引理B.12,我们可以用一个 $R' = O(\frac{R(1+ε)}{2\min(ε^2, (ε'-ε)^2)})$ 轮的(γ, ε, ρ)-MPC来模拟R轮的(γ, ε)-MPC。然后,通过引理B.13,我们可以用一个 $R' + 1$ 层的EMA transformer来模拟这个 $R'$ 轮的(γ, ε, ρ)-MPC,其头数为O(N^ρ),嵌入维度为 O(N^(ε+4ρ) log N) = O(N^ε')。□

ANNA模拟MPC的改进结果。再次,ANNA-transformer的改进模拟结果来自定理B.10。推论B.14 (ANNA用改进的嵌入维度模拟MPC). 对于常数 $0 < ε < ε' < 1$,任何具有N台机器且每台机器本地内存 s = O(N^ε) 的确定性R轮MPC协议π,存在一个深度 L = O(R) 的ANNA-transformer,其头数 H = O(N^((ε'-ε)/4)),嵌入维度 m = O(N^ε'),使得对于所有输入 $input \in Z_{2^p}^N$,$T(input) = \pi(input)$。

C MPC可以模拟ANNA-Transformer

MPC模拟EMA(热身)。作为热身,我们首先用MPC模拟EMA-transformer,然后将其推广到模拟ANNA-transformer。定理C.1 (MPC模拟EMA). 固定常数 $0 < ε < ε' < 1$。对于任何L层EMA-transformer T,其 mH = O(N^ε),存在一个 $O(\frac{L}{ε'-ε})$ 轮的MPC协议π,其本地内存 s = O(N^ε'),机器数量 P = O(N^(1+ε-ε')),使得对于所有输入 $input \in Z_{2^p}^N$,$T(input) = \pi(input)$。

证明。我们首先展示如何使用MPC模拟一层EMA-transformer。在宏观层面,对于每个token $x_i$,我们有一个token机器i,负责计算 $x_i$ 的键、查询和值嵌入以及对 $x_i$ 的其他逐元素计算。证明的核心是为每个查询搜索完全匹配的键,并将与匹配键相关联的平均值发送给token机器。为了做到这一点,我们按键定义的顺序对所有的键值对 $(k_i, v_i)$ 进行排序。我们将排序后的键值对分成多个桶,使得每个桶包含相同的键。对于每个桶,我们有一个“元信息”机器来存储包含桶中键的机器的索引。然后,我们在每个桶内计算平均值,将平均值存储到“元信息”机器中,并将该值传播给所有与该键匹配的查询。首先,让X表示查询和键的空间,我们定义一个比较器<来排序。不失一般性,我们将其定义为字典序比较器。基于此比较器,我们定义[N]的一个查询排名排列 $σ = (σ_1, σ_2, \dots, σ_N)$ 和一个键排名排列 $σ' = (σ'_1, σ'_2, \dots, σ'_N)$,使得

$$\begin{aligned} \begin{aligned} q_{\sigma_{1}} &< q_{\sigma_{2}} < \cdots < q_{\sigma_{N}} \\ \text{and } k_{\sigma_{1}^{\prime}} &< k_{\sigma_{2}^{\prime}} < \cdots < k_{\sigma_{N}^{\prime}} \end{aligned} \end{aligned}$$

对于“元信息”机器,我们使用一个均匀哈希函数 h : X → [N] 将查询和键映射到它们对应的“元信息”机器。回想一下,对于一个均匀哈希函数h,对于任何 a, b ∈ X 且 a ≠ b,$P(h(a) = h(b)) = 1/N$。因此,

$$\begin{aligned} \begin{aligned} & \mathbb{P}\left(\exists i \text { such that the size of bucket } h\left(q_i\right) \geq s\right) \\ & \leq \mathbb{P}(\exists s \text { different elements fall into one bucket }) \\ & \leq\left(\begin{array}{c} N \\ s \end{array}\right) \frac{1}{N^s} \leq \frac{1}{s !}=\frac{1}{N^{\varepsilon^{\prime}} !} \end{aligned} \end{aligned}$$

以高概率,每个“元信息”机器最多负责s个键或查询。我们将机器分为不同类型,并在此总结每种类型机器的角色:
* 对于 i ∈ [N],机器i是 $x_i$ 的token机器。该机器执行token i的所有逐元素计算。具体来说,它计算查询、键、值嵌入 $q_i, k_i, v_i$ 以及注意力层之后的逐元素操作。
* 对于 $i \in [\lceil mN/s \rceil]$,机器 (i, Q) 是一个用于排序查询和存储排序后查询列表第i块的数据结构机器。换句话说,设 $n_q = \lfloor s/m \rfloor$ 是每台机器可以存储的查询数,在排序结束时,机器 (i, Q) 存储 ${q_{σ((i-1)·n_q+1)}, \dots, q_{σ(i·n_q)}}$。
* 对于 $i \in \lceil 2mN/s \rceil$,机器 (i, KV) 是一个用于排序键值对列表的数据结构机器。换句话说,设 $n_k = \lfloor s/2m \rfloor$ 是每台机器可以存储的键值对数,在排序结束时,机器 (i, KV) 存储 ${(k_{σ'((i-1)·n_q+1)}, v_{σ'((i-1)·n_q+1)}), \dots, (k_{σ'(i·n_q)}, v_{σ'(i·n_q)})}$。
* 对于 i ∈ [N],机器 (i, hq) 是哈希值为i的查询的“元信息”机器。设 $h_i^q = {q_j | j \in [N], h(q_j) = i}$。该机器存储 $q \in h_i^q$ 在排序列表中的位置信息。具体来说,对于所有 $q \in h_i^q$,该机器存储起始机器索引,即(start, Q)其中 start = arg minj{q ∈ 机器(j, Q)},和结束机器索引,即(end, Q)其中 end = arg maxj{q ∈ 机器(j, Q)}。
* 对于 i ∈ [N],机器 (i, hk) 是哈希值为i的键的“元信息”机器。设 $h_i^k = {k_j | j \in [N], h(k_j) = i}$。该机器存储 $k \in h_i^k$ 在排序列表中的位置信息。具体来说,对于所有 $k \in h_i^q$,该机器存储起始机器索引,即(start, KV)其中 start = arg minj{k ∈ 机器(j, KV)},和结束机器索引,即(end, Q)其中 end = arg maxj{k ∈ 机器(j, KV)}。
* 消息传播所需的辅助机器。

MPC协议流程。我们继续讨论计算单层单头EM-attention transformer输出的MPC协议。在第一轮中(与(Sanford et al., 2024c)的token分散阶段相同),将每个token $x_i$ 路由到其对应的token机器i。在第二轮中,每个token机器i计算查询、键、值嵌入 $q_i = Q(x_i), k_i = K(x_i), v_i = V(x_i)$ 并将 $(q_i, i)$ 发送到排序查询数据结构机器 $(\lceil mi/s \rceil, Q)$,将 $(k_i, v_i, i)$ 发送到排序键数据结构机器 $(\lceil 2mi/s \rceil, KV)$。然后,排序查询数据结构机器(所有i的(i, Q))对查询进行排序。MPC中的排序已得到很好的研究,这可以在常数轮内完成(Goodrich et al., 2011【Goodrich, Michael T., Nodari Sitchinava, and Qin Zhang. "Sorting, searching, and simulation in the mapreduce framework." In Proceedings of the 22nd International Conference on Algorithms and Computation, ISAAC’11, 2011.】)。引理C.2 (MPC排序). 存在一个本地内存 s = O(N^ε') 的MPC协议,可以在 $O(\frac{1}{ε'-ε})$ 轮内用 O(N^(1+ε-ε')) 台机器对N个项进行排序,每个项的大小为 O(N^ε),ε < ε'。

信息聚合。每个查询数据结构机器(包含 $q_i$ 的i)将其位置信息发送到其“元信息”机器 $(h(q_i), h_q)$。思想是构建一个s/m-ary树结构来聚合信息,每个查询数据结构机器是该树的一个叶节点。回想一下,每台机器 (i, Q) 存储查询集 $S = \{q_{σ((i-1)·n_q+1)}, \dots, q_{σ(i·n_q)}\}$。如果S包含某个特定查询向量 $q_l$ 的开始和结束,则 (i, Q) 发送一条消息 $(q_l, (i, Q))$ 到机器 $(h(q_l), h_q)$。机器 (i, Q) 还将其第一个和最后一个查询发送到树中的父机器,即发送消息 $(q_{σ((i-1)·n_q+1)}, (i, Q), \text{first})$ 和 $(q_{σ(i·n_q)}, (i, Q), \text{last})$。父节点收集完所有来自叶节点的消息后,它会做与其子节点相同的事情:如果它包含某个查询q的开始和结束,它会将查询的位置信息(存储它的第一台和最后一台机器)发送到其对应的“元信息”机器 $(h(q), h_q)$,并将第一个和最后一个查询及其位置信息发送到其父机器。这是递归完成的,由于总共有 $\lceil mN/s \rceil$ 个查询数据结构机器,这个s/m-ary树的深度为 $O(\log_{s/m} mN/s) = O(\frac{1}{ε'-ε})$,这意味着 $O(\frac{1}{ε'-ε})$ 轮和 O(mN/s) 台机器就足够了。

处理键值对。我们对 (k, v) 对做同样的事情。排序数据结构机器 (i, KV) 根据k的顺序对 (k, v) 对进行排序。和之前一样,我们构建一个s/(2m)-ary树将位置信息发送到每个键的“元信息”机器。与查询的不同之处在于,我们合并具有相同键的值。对于s/(2m)-ary树中的每台机器,它计算其包含的每个键相关联的平均值,并将平均值发送到相应的“元信息”机器。特别是,对于每个 $k_i$,其“元信息”机器 $(h(k_i), h_k)$ 包含信息 $(k_i, \bar{v})$,其中 $\bar{v}$ 是满足 $k_j = k_i$ 的 $v_j$ 的平均值。

信息交换与回传。接下来,查询和键的“元信息”机器需要交换信息以检索每个查询对应的值。每个 (i, hk) 将其拥有的 $(k, \bar{v})$ 对发送到机器 (i, hq)。然后,每个 (i, hq) 机器匹配q和k,并将关联的值 $\bar{v}$ 发送给q。注意,这一步可以通过反向传播为将q的位置信息发送到 $(h(q), h_q)$ 而构建的s/m-ary树来完成。换句话说,我们可以只反转这棵树中的消息发送方向。因此,查询数据结构机器中的每个查询都会收到它检索到的值,从每个查询数据结构机器 (i, Q),我们可以将每个查询检索到的值发送到其对应的token机器,这是第二轮的逆操作。

总结。总共需要的轮数是 $O(\frac{1}{ε'-ε})$,需要的机器数是 O(mN/s) = O(N^(1+ε-ε'))。为了使其适用于H个头,我们可以创建H个副本并让它们并行运行。由于 mH = O(N^ε),轮数和机器数的界仍然成立。通过为L层中的每一层创建这个MPC模拟,我们按层的顺序将它们堆叠起来,从而产生对L层EMA-transformer的完整模拟。□

MPC模拟ANNA。接下来,我们推广上述算法,以模拟可通过算法1计算的ANN注意力。由于算法1是一个随机算法,我们假设MPC协议共享ANNA-transformer所有层所需的所有随机种子。定理C.3 (MPC模拟ANNA). 固定常数 $0 < ε < ε' < 1$。对于任何L层ANNA-transformer T(如算法1实现),其 mH = O(N^ε),存在一个 O(L/(ε' - ε)) 轮的MPC协议π,其本地内存 s = O(N^ε'),机器数量 P = O(N^(1+ε-ε'+3/c²)),使得对于所有输入 $input \in Z_{2^p}^N$,$T(input) = \pi(input)$。

证明。模拟ANNA-transformer的宏观思想与模拟EMA非常相似。我们有与之前相同类型的机器。最大的区别是,我们现在有ℓ个哈希表,每个哈希轮次一个,而不是一个用于查询和键的哈希表,我们根据查询和键的哈希值对它们进行排序。同样,我们首先概述我们将使用的不同类型的机器。
* 对于 i ∈ [N],机器i是 $x_i$ 的token机器。该机器执行token i的所有逐元素计算。具体来说,它计算查询、键、值嵌入 $q_i, k_i, v_i$ 以及注意力层之后的逐元素操作。
* 对于 $i \in [\lceil mN/s \rceil], t \in [ℓ]$,机器 (i, Q, $h_t$) 是一个数据结构机器,用于第t个哈希表的排序查询和查询排序列表的第i块,排序顺序基于算法1中的 $g_t(q)$。
* 对于 $i \in \lceil 2mN/s \rceil$,机器 (i, KV, $h_t$) 是一个数据结构机器,用于第t个哈希表的键值对排序列表和键值对排序列表的第i块,排序顺序基于 $g_t(k)$。
* 对于 i ∈ [N], t ∈ [ℓ],机器 $(g_t(q_i), h_q, t)$ 是第t个哈希值为 $g_t(q_i)$ 的查询的“元信息”机器。设 $h_t^q(i) = \{q_j | j \in [N], g_t(q_j) = g_t(q_i)\}$。该机器存储 $q \in h_t^q$ 在第t个哈希表中的位置信息。
* 对于 i ∈ [N], t ∈ [ℓ],机器 $(g_t(k_i), h_k, t)$ 是第t个哈希值为 $g_t(k_i)$ 的键的“元信息”机器。设 $h_t^k(i) = \{k_j | j \in [N], g_t(k_j) = g_t(k_i)\}$。该机器存储 $k \in h_t^k$ 在第t个哈希表中的位置信息。
* 消息传播所需的辅助机器。

MPC协议流程(ANNA)。和以前一样,我们仍然使用每个token机器来计算嵌入 $q_i, k_i, v_i \in R^m$。然后,每个token机器需要将 $(q_i, i)$ 和 $(k_i, v_i, i)$ 发送到数据结构机器,即机器 $(\lceil mi/s \rceil, Q, h_t)$ 和机器 $(\lceil 2mi/s \rceil, KV, h_t)$,对于所有 t ∈ ℓ。因为 $ℓ = N^{3ρ}$,我们使用s/m-ary树将查询和键传播到相应的数据结构机器。这需要 $O(\frac{1}{ε'-ε})$ 轮和 O(N^(1+3ρ+ε-ε')) 台机器。然后,对于每个查询哈希表 t ∈ [ℓ],数据结构机器根据查询的哈希值对查询进行排序。与定理C.1相同,我们使用s/m-ary树将每个哈希桶的位置信息发送到其对应的“元信息”机器。对于每个键值对哈希表 t ∈ [ℓ],数据结构机器根据键的哈希值对键值对进行排序。之后,使用s/(2m)-ary树将信息传播到相应的“元信息”机器。与EMA模拟的不同之处在于,这个s/(2m)-ary树中的每台机器维护具有相同哈希值的键的值的总和,而不是平均值,并且还维护键的计数。这些可以在 $O(\frac{1}{ε'-ε})$ 轮和 O(N^(1+3ρ+ε-ε')) 台机器内完成。接下来,键的“元信息”机器将值的总和和计数发送到相应的查询“元信息”机器,即机器 $(g_t(k_i), h_k, t)$ 发送到机器 $(g_t(q_i), h_q, t)$。每个查询“元信息”机器然后沿着s/m-ary树,将值的总和和计数广播到哈希表中的查询。最后,哈希表中的每个查询需要将信息传播回其原始的token机器。由于每个token机器将从 $ℓ = N^{3ρ}$ 台机器接收消息,我们再次反转将查询发送到每个数据结构机器的s/m-ary树。在聚合期间,s/m-ary树中的每台机器仍然维护它接收到的值的总和和计数的总和。接收到值的总和和计数后,每个token机器i然后计算 ANNA($q_i$) = 值的总和除以计数。上述模拟在 $O(\frac{1}{ε'-ε})$ 轮内模拟了一层ANNA-transformer,使用了 O(N^(1+3ρ+ε-ε')) 台机器,其中 $ρ = 1/c^2$。因此,通过堆叠L层的模拟,总共需要 $O(\frac{L}{ε'-ε})$ 轮。为了扩展到H个头,我们只需要为H个并行副本实例化上述模拟,并且由于 mH = O(ε),总轮数和机器数保持不变。□

D ANN/EM Attention可以通过MPC模拟低秩Attention

模拟思路。我们通过首先给出一个计算低秩注意力的MPC算法,然后将其转换为ANNA-transformer来用ANN注意力模拟低秩注意力。定理D.1 (ANNA/EMA模拟低秩Attention). 对于常数 $0 < ε < ε' < 1$,任何基于低秩注意力的transformer,其深度L,秩r,嵌入维度m,且 rm = O(N^ε),都可以由一个EMA/ANNA-transformer模拟,其深度为 $O(\frac{L}{ε'-ε})$,头数为 H = O(N^ε'),嵌入维度为 m = O(N^(5ε') log N)。

证明。我们通过首先证明任何单层的低秩注意力可以由常数轮的MPC模拟来证明该定理。引理D.2 (MPC模拟低秩Attention). 对于常数 $0 < ε < ε' < 1$,任何单层低秩注意力,其秩r,嵌入维度m,且 rm = O(N^ε),都可以由一个 $O(\frac{1}{ε'-ε})$ 轮的MPC协议模拟,其本地内存 s = O(N^ε'),机器数量为 O(N)。

证明。假设 rm = O(N^ε) 且MPC的本地内存 s = O(N^ε'),其中 ε < ε'。与我们在MPC模拟EMA中所做的相同,对于每个token $x_i, i \in [N]$,我们有一个token机器i来计算 $x_i$ 的嵌入,但我们需要在核空间中计算它,即 $q'_i = Q'(x_i), k'_i = K'(x_i)$ 和 $v_i = V(x_i)$。为了计算 K'(X)^TV(X),回想一下

$$K'(X)^{\top}V(X) = \sum_{i=1}^{N} k_i v_i^{\top}$$

我们只需要计算N个大小为 r × m 的矩阵的和。每个token机器i计算矩阵 $k_i v_i^T$,我们构建一个 $\lfloor \frac{s}{rm} \rfloor$-ary树的机器来计算总和。树的叶子是所有的token机器,每个节点负责计算 $\lfloor \frac{s}{rm} \rfloor$ 个矩阵的和。我们从之前的模拟中知道,树的深度是 $O(\frac{1}{ε'-ε})$。在我们得到矩阵 M = K'(X)^TV(X) ∈ R^(r×m) 之后,为了计算 Q(X)K'(X)^TV(X),我们只需要将矩阵M传播到所有的token机器。每个token机器i计算 $q_i^T M$。通过反转计算和树中消息传播的方向,我们可以在 $O(\frac{1}{ε'-ε})$ 轮内将M传播到所有的token机器。因此,我们总共可以用 $O(\frac{1}{ε'-ε})$ 轮来模拟核注意力。□
对于L层的低秩注意力transformer,我们使用引理D.2为每一层构建MPC,并再次使用每个token机器的本地计算来模拟逐元素计算。我们将L个MPC堆叠在一起,总共有 $O(\frac{L}{ε'-ε})$ 轮。该定理通过应用定理B.2和推论B.9得出。□

改进嵌入维度的模拟。由于证明的核心是通过MPC模拟低秩Attention,我们也可以应用定理B.10和推论B.14,它们用更好的嵌入维度模拟MPC,从而为模拟低秩注意力transformer获得改进的嵌入维度。推论D.3 (ANNA/EMA用改进的嵌入维度模拟低秩Attention). 对于常数 $0 < ε < ε' < 1$,任何基于低秩注意力的transformer,其深度L,秩r,嵌入维度m,且 rm = O(N^ε),都可以由一个EMA/ANNA-transformer模拟,其深度为 $O(\frac{L}{(ε'-ε)·\min(ε^2, (ε'-ε)^2)})$,头数为 H = O(N^((ε'-ε)/4)),嵌入维度为 m = O(N^ε')。

E 关于Reformer的讨论

Reformer的定义。我们在此将Reformer正式定义为一个计算模型。定义E.1 (Reformer attention). 给定查询、键、值嵌入 Q(X), K(X), V(X) ∈ R^(N×m),使得 $q_i := k_i = Q(X)[i, :] = K(X)[i, :], v_i = V(X)[i, :]$,Reformer attention按以下步骤进行:
1. 对 {$q_1, \dots, q_N$} 应用一个哈希函数 h : R^m → U;
2. 按 $h(q_i)$ 对所有 $q_i$(以及因此的 $k_i$)进行排序,并将所有 $q_i$ 分成大小为 B ≤ O(1) 的块,令 $h'(q_i)$ 为 $q_i$ 所在块的标签(每个块中的查询可以有不同的哈希值);
3. 对于每个 $q_i$,只关注与它在同一个块中的 $k_j$。

因此,$q_i$ 的输出嵌入为

$$\sum_{j:h'(k_{j})=h'(q_{i})} \frac{\exp(\langle q_{i}, k_{j}\rangle)}{\sum_{j':h'(k_{j'})=h'(q_{i})} \exp(\langle q_{i}, k_{j'}\rangle)} \cdot v_{j}.$$

固定通信模式的Reformer。我们将 $f_ℓ : [N] \rightarrow [N]^B$ 定义为指定每个查询在第ℓ层应与哪些键计算内积的函数。根据Reformer的约束,我们有 $\forall i \in [N]$:
1. $f_ℓ(i) = \{a_1, a_2, \dots, a_B\} \in [N]^B$ 是一个集合(无重复)。
2. $i \in f_ℓ(i)$。
3. 对于任何 $j \in f_ℓ(i)$,$f_ℓ(j) = f_ℓ(i)$。
在第ℓ层对每个查询 $q_i$ 的注意力计算中,Reformer计算

$$\sum_{j \in f_{\ell}(i)} \frac{\exp \left(\left\langle q_{i}, k_{j}\right\rangle\right)}{\sum_{j^{\prime} \in f_{\ell}(i)} \exp \left(\left\langle q_{i}, k_{j^{\prime}}\right\rangle\right)} \cdot v_{j}$$

我们首先研究一个受限版本的Reformer,该版本预先固定了通信模式,即对于所有 $ℓ \in [L]$,$f_ℓ$ 是与输入无关的,并证明它不能计算所有输入token的和。

SUM任务定义。给定输入 X = ($x_1, x_2, \dots, x_N$),$x_i \in [M]$,且 M = N^(O(1)),SUM任务定义为 $SUM(X) = \sum_{i=1}^N x_i$。我们说一个Reformer T计算SUM,如果对于所有X,$T(X)_N = SUM(X)$。这里 $T(X)_N$ 是T在输入X下的第N个输出。可以选择任何位置作为最终输出位置,这里我们不失一般性地选择最后一个token,以遵循自回归生成模型的惯例。

Reformer的局限性命题E.3. 固定 L = O(1) 和 {$f_ℓ$}. 任何具有L层且每层注意力模式由 {$f_ℓ$} 指定的Reformer T都不能计算SUM(X):存在一个X,使得 $|T(X)_N - SUM(X)| \ge \epsilon$,对于任何 $0 < \epsilon < M/2$。

证明。我们用 {$ϕ_ℓ$} 表示每层的逐元素计算。令 $T^ℓ(X)_i$ 表示T在经过ℓ层计算后的第i个输出。我们通过归纳法证明该命题。
归纳假设: $T^ℓ(X)_i$ 是最多 B^ℓ 个不同 $x_i \in X$ 的函数。
基础情况: ℓ = 1

$$\begin{aligned} \begin{aligned} T^{1}(X)_{i} & =\sum_{j \in f_{1}(i)} \frac{\exp \left(\left\langle q_{i}, k_{j}\right\rangle\right)}{\sum_{j^{\prime} \in f_{1}(i)} \exp \left(\left\langle q_{i}, k_{j^{\prime}}\right\rangle\right)} \cdot v_{j} \\ & =\sum_{j \in f_{1}(i)} \frac{\exp \left(\left\langle Q\left(x_{i}\right), K\left(x_{j}\right)\right\rangle\right)}{\sum_{j^{\prime} \in f_{1}(i)} \exp \left(\left\langle Q\left(x_{i}\right), K\left(x_{j^{\prime}}\right)\right\rangle\right)} \cdot V\left(x_{j}\right) \\ & =\phi_{1}\left(x_{a_{1}}, x_{a_{2}}, \ldots, x_{a_{B}}\right) \text { where } f_{1}(i)=\left\{a_{1}, \ldots, a_{B}\right\} \end{aligned} \end{aligned}$$

它最多是X中B个 $x_i$ 的函数。
归纳步骤: 考虑

$$\begin{aligned} \begin{aligned} T^{\ell+1}(X)_{i} & =\sum_{j \in f_{\ell+1}(i)} \frac{\exp \left(\left\langle Q\left(T^{\ell}(X)_{i}\right), K\left(T^{\ell}(X)_{j}\right)\right\rangle\right)}{\displaystyle \sum_{j^{\prime} \in f_{\ell+1}(i)} \exp \left(\left\langle Q\left(T^{\ell}(X)_{i}\right), K\left(T^{\ell}(X)_{j^{\prime}}\right)\right\rangle\right)} \cdot V\left(T^{\ell}(X)_{j}\right) \\ & =\phi_{\ell+1}\left(T^{\ell}(X)_{a_{1}}, \ldots, T^{\ell}(X)_{a_{B}}\right) \text { where } f_{1}(i)=\left\{a_{1}, \ldots, a_{B}\right\} \end{aligned} \end{aligned}$$

由于每个 $T^ℓ(X)_a$ 最多是来自X的 B^ℓ 个变量的函数,所以 $T^{ℓ+1}(X)_i$ 最多是来自X的 $B \cdot B^ℓ = B^{ℓ+1}$ 个变量的函数。因此,如果 $T^L(X)_i$ 是所有 {$x_1, \dots, x_N$} 的函数,我们需要 $B^L \ge N$,因此 L = Ω($\log_B N$)。在 B = O(1) 和 L = O(1) 的情况下,$T^L(X)_i$ 是 $B^L \ll N$ 个变量的函数。不失一般性,考虑 $T^L(X)_N$ 是 {$x_1, \dots, x_{B^L}$} 的函数。那么,$x_{B^L+1}$ 可以是[M]中的任何数字,使得 $T^L(X)_N$ 远离SUM(X)。□

结论。因此,如果Reformer有任何能力,它必须来自排序部分,因为排序算法可以访问所有token输入的信息。尽管常数层的Reformer不能计算SUM,但可以轻易证明,一层ANNA-transformer可以通过设置 $v_i = N x_i$ 和 $k_1 = k_2 = \dots = k_N = q_N$ 对于所有 $i \in [N]$ 来计算SUM,从而检索所有的 $v_i$ 并将它们平均。

F ANNA-Transformer解决k-hop和Match2问题

F.1 ANNA/EMA-Transformer解决Match2问题

定理F.1。对于任何 $N, M = N^{O(1)}$,存在一个EMA-transformer T,它有一个层,一个注意力头,嵌入维度为1,使得对于所有 $X \in [M]^N$,$T(X) = Match2(X)$。

证明。给定输入 $X \in [0, M]^{N×1}$。令 $Q(X) = ϕ(X)Q, K(X) = ϕ(X)K, V(X) = ϕ(X)V$,其中Q, K, V是 $R^{2×1}$ 中的矩阵。定义 $ϕ$ 为 $ϕ(x) = (x, 1)$,并定义

$$\begin{aligned} Q = \begin{pmatrix} 1 \\ 0 \end{pmatrix}, K = \begin{pmatrix} -1 \\ M \end{pmatrix}, V = \begin{pmatrix} 0 \\ 1 \end{pmatrix} \end{aligned}$$

使得

$$\begin{aligned} \phi(X) Q=\begin{pmatrix}x_1 \\ x_2 \\ \vdots \\ x_N\end{pmatrix}, \phi(X) K=\begin{pmatrix}M-x_1 \\ M-x_2 \\ \vdots \\ M-x_N\end{pmatrix}, \phi(X) V=\begin{pmatrix}1 \\ 1 \\ \vdots \\ 1\end{pmatrix}. \end{aligned}$$

结果,对于每个 $1 \le i \le N$,如果存在 $1 \le j \le N$ 使得 $x_i + x_j = M$,那么

$$ ((\phi(X)Q)(\phi(X)K)^{\top})[i, j] = \frac{1}{|\{j \in [N] : x_i + x_j = M\}|} $$

否则,对于所有 $1 \le j \le N$, $((ϕ(X)Q)(ϕ(X)K)^T)[i, j] = 0$。最后,我们可以计算出,如果 $|\{j \in [N] : x_i + x_j = M\}| \neq 0$,那么

$$\text{EMA}(\phi(X)Q, \phi(X)K, \phi(X)V)[i] = \frac{1}{|\{j \in [N] : x_i + x_j = M\}|} \cdot |\{j \in [N] : x_i + x_j = M\}| = 1,$$

如果 $|\{j \in [N] : x_i + x_j = M\}| = 0$,那么

$$\text{EMA}(\phi(X)Q, \phi(X)K, \phi(X)V)[i] = 0.$$

这对于ANNA-transformer也给出了相同的结果。

推论F.2。对于任何 $N, M = N^{O(1)}$,存在一个ANNA-transformer T,它有一个层,一个注意力头,嵌入维度为1,使得对于所有 $X \in [M]^N$,$T(X) = Match2(X)$。

F.2 ANNA-Transformer解决k-hop问题

ANNA解决1-hop。我们首先证明ANNA-transformer可以解决归纳头(1-hop)。引理F.3。固定常数 $0 < ε < ε' < 1$,且 $|Σ| \le N$。存在一个ANNA-transformer T,其层数 $L = O(\frac{1}{ε·(ε'-ε)^2})$,每层头数 $H = O(N^{((ε'-ε)/4)})$,嵌入维度 $m = O(N^{ε'})$,使得对于所有 $w \in Σ^N$ 和所有 $i \in [N]$,$T(w)_i = w_{σ(w,i)}$ 如果 $σ(w, i) \neq 0$;$T(w)_i = \perp$ 如果 $σ(w, i) = 0$。

证明。我们通过设计一个具有本地内存 s = O(N^ε) 和 N/s 台机器的常数轮MPC算法来解决1-hop来证明该引理。由于 $|Σ| \le N$,每个token可以用 O(log N) 位(O(1)个词)来嵌入。表示输入 $w^N = (x_1, x_2, \dots, x_N)$。MPC算法工作如下:
1. 对于每个 $x_i$,检索下一个token $x_{i+1}$,机器上的每个token存储为 $(x_i, i, x_{i+1}, i + 1)$ 的嵌入。
2. 为对象 $(x_i, i, x_{i+1}, i + 1)$ 定义一个比较器<。对于两个元组 $(x_i, i, x_{i+1}, i + 1)$ 和 $(x_j, j, x_{j+1}, j + 1)$,如果 $x_i \neq x_j$,那么 $x_i < x_j \Rightarrow (x_i, i, x_{i+1}, i + 1) < (x_j, j, x_{j+1}, j + 1)$;如果 $x_i = x_j$,那么 $i < j \Rightarrow (x_i, i, x_{i+1}, i + 1) < (x_j, j, x_{j+1}, j + 1)$。按比较器<对 $(x_i, i, x_{i+1}, i + 1)$ 进行排序。
3. 排序列表中的每个token $(x_i, i, x_{i+1}, i + 1)$ 检索其在排序列表中的前一个token,表示为 $(x_j, j, x_{j+1}, j + 1)$。更新token的嵌入:如果 $x_j = x_i$,则token $x_i$ 的嵌入变为 $(i, x_{j+1}, j + 1)$,即 $(i, w_{σ(w,i)}, σ(w, i))$;如果 $w_j \neq w_i$,则token $x_i$ 的嵌入变为 $(i, \perp, 0)$。
4. 将每个 $(i, w_{σ(w,i)}, σ(w, i))$ 发送到正确的输出机器 $\lceil i/s \rceil$ 并为token i输出 $w_{σ(w,i)}$。
对于步骤1,每台机器只需要向其邻居机器发送消息:机器i向机器i-1发送消息,这只需要1轮。在步骤2中,每个元组只有O(log N)位,所以根据定理C.2,排序需要 $O(\frac{1}{ε})$ 轮。在步骤3中,同样每台机器只需要向其邻居机器发送消息:机器i向机器i+1发送消息,这只需要1轮。在步骤4中,排序列表的每台机器将其存储的最多s个元组发送到正确的输出机器,这需要1轮。因此,MPC算法总共有 $O(\frac{1}{ε})$ 轮。然后,我们将这个MPC算法转换为一个ANNA-transformer。根据定理B.14,这给了我们一个ANNA-transformer,其头数 H = O(N^((ε'-ε)/4)),嵌入维度 m = O(N^ε'),层数 $L = O(\frac{1}{ε·(ε'-ε)^2})$。□

ANNA解决k-hop。现在我们证明ANNA-transformer可以用O(log k)层解决k-hop。定理F.4。固定常数 $0 < ε < ε' < 1, |Σ| \le N$ 和任何 $k \in N$。存在一个ANNA-transformer T,其层数 $L = O(\frac{1}{ε·(ε'-ε)^2} + \frac{\log k}{(ε'-ε)^2})$,每层头数 H = O(N^((ε'-ε)/4)),嵌入维度 m = O(N^ε'),使得对于所有 $w \in Σ^N$ 和所有 $i \in [N]$,$T(w)_i = w_{σ^k(w,i)}$。

证明。我们通过构建一个具有 s = O(N^ε) 本地内存和 O(N/s) 台机器的 O(log k) 轮MPC来证明该定理。我们通过归纳法证明该MPC算法可以计算k-hop。设 $k = \sum_{j=0}^{\lfloor\log k\rfloor} k_j 2^j$ 和 $k_{:ℓ} = \sum_{j=0}^{ℓ-1} k_j 2^j$,其中 $k_j \in \{0, 1\}$。
归纳假设: 经过 $O(\frac{1}{ε}) + 2ℓ$ 轮MPC计算后,每个token i的token嵌入编码了该元组的信息

$$(i, w_{\sigma^{2^{\ell}}(w, i)}, \sigma^{2^{\ell}}(w, i), w_{\sigma^{k: \ell}(w, i)}, \sigma^{k: \ell}(w, i))$$

基础情况: ℓ = 0, k = 1由定理F.3蕴含。在步骤3之后,我们有 $(i, w_{σ(w,i)}, σ(w, i))$。现在考虑k = ℓ+1。对于每个i,包含 $(i, w_{σ^{2^ℓ}(w,i)}, σ^{2^ℓ}(w, i), w_{σ^{k_{:ℓ}}(w,i)}, σ^{k_{:ℓ}}(w, i))$ 的机器(机器 $\lceil i/s \rceil$)将消息 $(i, σ^{2^ℓ}(w, i))$ 发送到包含 $σ^{2^ℓ}(w, i)$ 作为元组第一项的机器,即机器 $\lceil σ^{2^ℓ}(w,i)/s \rceil$。机器 $\lceil σ^{2^ℓ}(w,i)/s \rceil$ 然后将以下元组发送到机器 $\lceil i/s \rceil$:

$$\begin{aligned} \begin{aligned} & (\sigma^{2^\ell}(w, i), w_{\sigma^{2^\ell}(w, \sigma^{2^\ell}(w, i))}, \sigma^{2^\ell}(w, \sigma^{2^\ell}(w, i)), w_{\sigma^{k:\ell}(w, \sigma^{2^\ell}(w, i))}, \sigma^{k:\ell}(w, \sigma^{2^\ell}(w, i))) \\ & = (\sigma^{2^\ell}(w, i), w_{\sigma^{2^{\ell+1}}(w, i)}, \sigma^{2^{\ell+1}}(w, i), w_{\sigma^{k:\ell+2^\ell}(w, i)}, \sigma^{k:\ell+2^\ell}(w, i)) \\ & := (\sigma^{2^\ell}(w, i), t_1, t_2, t_3, t_4) \end{aligned} \end{aligned}$$

由于每台机器最多有s个元组,且函数 $σ(w, i)$ 除 $\perp$ 外是一对一的,所以每台机器发送和接收的消息数以s为界。机器 $\lceil i/s \rceil$ 接收到上述消息后,它会更新token i的元组:
1. 如果 $k_ℓ = 0$,token i更新为:$(i, t_1, t_2, w_{σ^{k_{:ℓ}}(w,i)}, σ^{k_{:ℓ}}(w, i))$
2. 如果 $k_ℓ = 1$,token i更新为:$(i, t_1, t_2, t_3, t_4)$
根据定义,token i的嵌入现在是:

$$(i, w_{\sigma^{2^{\ell+1}}(w, i)}, \sigma^{2^{\ell+1}}(w, i), w_{\sigma^{k:\ell+1}(w, i)}, \sigma^{k:\ell+1}(w, i))$$

上述归纳步骤只需要2轮MPC。因此,总轮数是 $O(\frac{1}{ε}) + 2(ℓ + 1)$。当 $ℓ = \lfloor\log k\rfloor + 1$ 时,该算法计算k-hop的输出。同样,我们可以将这个MPC算法转换为一个ANNA-transformer。根据定理B.14,这给了我们一个ANNA-transformer,其头数 H = O(N^((ε'-ε)/4)),嵌入维度 m = O(N^ε'),层数 $L = O(\frac{1}{ε·(ε'-ε)^2} + \frac{\log k}{(ε'-ε)^2})$。□

G 实验细节

实验设置。以下是实验设置的详细信息。所有实验都在2个GPU上启动:NIVIDIA Titan RTX和NVIDIA Titan Xp。我们训练了一个修改版的注意力矩阵,然后使用(Andoni et al., 2015【Andoni, Alexandr, et al. "Practical and optimal lsh for angular distance." In Proceedings of the 29th International Conference on Neural Information Processing Systems, NIPS’15, 2015.】)中的角距离LSH族实现的ANNA从训练好的模型中进行蒸馏。我们的softmax注意力将Q(X)和K(X)中的所有查询和键归一化为单位范数,并计算 $softmax(β \cdot Q(X)K(X)^T)V(X)$,其中β > 0是一个超参数。

G.1 Match2实验

数据集生成。受(Kozachinskiy et al., 2025【Kozachinskiy, Alexander, et al. "Strassen attention: Unlocking compositional abilities in transformers based on a new lower bound method." 2025.】)为Match3任务(Match2的三元组版本)生成数据的方式启发,我们使用相同的算法为Match2生成数据,但在计算标签时改为成对关系。每个样本是一个元组(X, Y),其中 $X = (x_1, x_2, \dots, x_N)$,每个 $x_i$ 是从 {1, 2, ..., 36} 中采样的整数;$Y = (y_1, y_2, \dots, y_N)$,每个 $y_i = 1\{\exists j, x_i + x_j = 0 \mod 37\}$。序列长度N设置为32。在采样数据时,我们确保每个批次通过保持Y中1的分布相同来达到平衡:每个批次有4个箱,每个箱对应于Y中1的百分比范围 [0, 25%), [25%, 50%), [50%, 75%), [75%, 100%];每个箱的大小是批次大小的1/4。详情见算法3。

算法3 Match2数据集生成
算法3 Match2数据集生成

训练细节。我们分别用 $β \in \{0.1, 1, 10\}$ 训练了3个模型,使用Adam优化器对交叉熵损失进行优化,学习率为0.01。每个模型有一层,一个注意力头,嵌入维度 m = 64,以及一个宽度为4m、激活函数为GeLU的MLP。数据集大小、批次大小、训练步数分别为10000、32、20000。我们在所有3个训练好的模型上应用ANNA,哈希表数 $ℓ \in \{1, 2, \dots, 16\}$,每个表的哈希函数数 $z \in \{1, 2, \dots, 6\}$,$β = 0.1$ 时的性能最好(错误可以为0)。由于ANNA的实现是随机的,对于每个(ℓ, z)的组合,我们运行10次并报告10次运行的平均误差。当 $β = 0.1$ 时的性能图见图1a。在这种设置下,$ℓ \ge 8, z = 1$ 可以在包含256个测试样本的测试集上实现0错误。

G.2 归纳头实验

数据集生成。我们使用(Sanford et al., 2024c【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Transformers, parallel computation, and logarithmic depth." In Forty-First International Conference on Machine Learning, 2024c.】)的数据生成算法。每个样本(X, Y)的形式为 X = (k, X'),其中对于训练样本k ∈ {0, 1},对于测试样本k = 1,X' ∈ Σ^(N-1),Y = (0, Y'),其中Y'是X'的k-hop标签。这里的序列长度N = 100,字母表大小 |Σ| = 4。

训练细节。我们分别用 $β \in \{0.1, 1, 10\}$ 训练了3个模型,使用Adam优化器对交叉熵损失进行优化,学习率为0.01。每个模型有2层,每层有一个注意力头,嵌入维度 m = 128,以及一个宽度为4m、激活函数为GeLU的MLP。我们使用在线训练:在每个训练步骤,采样全新的数据进行训练。批次大小和训练步数分别为32、400000。

评估细节。我们在所有3个训练好的模型上应用ANNA。对于第一层,哈希表数ℓ从 {32, 40, 48, ..., 96} 中选择,z从 {1, 2, 3, 4} 中选择。对于第二层,哈希表数ℓ从 {4, 8, 12, ..., 32} 中选择,z从 {1, 2, 3, 4} 中选择。在评估测试误差时,我们计算所有token上的误差。注意这与(Sanford et al., 2024c【Sanford, Clayton, Daniel Hsu, and Matus Telgarsky. "Transformers, parallel computation, and logarithmic depth." In Forty-First International Conference on Machine Learning, 2024c.】)不同,后者只计算其归纳头存在的token上的误差,以避免在k较大且有大量空输出时高估性能。在我们的设置中,k = 1,没有很多空输出,模型学习何时输出空token是很重要的。我们发现 $β = 1$ 的性能最好,所以我们在图1b中报告了它。同样,误差是每种组合10次运行的平均值,并取z上的最小值。可以看出,第一层32个哈希表和第二层4个哈希表已经给出了非常不平凡的性能:在100个样本上,每个样本有100个token预测,错误率为0.2,而随机猜测会给出0.75的错误率。随着第一层哈希表的增多,错误率可以降到0.1以下。