Spark Transformer: Reactivating Sparsity in FFN and Attention
Spark Transformer: Reactivating Sparsity in FFN and Attention
作者/机构: Chong You, Kan Wu, Zhipeng Jia, Lin Chen, Srinadh Bhojanapalli, Jiaxian Guo, Utku Evci, Jan Wassenberg, Praneeth Netrapalli, Jeremiah J. Willcock, Suvinay Subramanian, Felix Chern, Alek Andreev, Shreya Pathak, Felix Yu, Prateek Jain, David E. Culler, Henry M. Levy, and Sanjiv Kumar (Google, xAI, Anthropic)
A1 主要贡献
本文旨在解决现代 Transformer 模型中激活稀疏性缺失的问题。随着模型从 ReLU 激活函数转向门控非 ReLU 激活函数(如 GELU),在训练好的 Transformer 中观察到的“懒惰神经元”现象【索引62,The lazy neuron phenomenon: On emergence of activation sparsity in transformers+2022+ICLR】消失了,这使得利用激活稀疏性提升模型效率变得困难。
核心问题: 我们能否在不损害模型质量的前提下,在最新的 Transformer 模型中重新引入高水平的激活稀疏性?
研究目标: 本文旨在提出一种新的 Transformer 架构,该架构能在前馈网络(FFN)和注意力机制中都实现高水平的激活稀疏性,同时保持模型质量、参数数量,并沿用标准训练流程。
创新点与主要贡献:
本文引入了 Spark Transformer 架构,它由 Spark FFN 和 Spark Attention 两个核心组件构成。
-
统一的稀疏化框架和低成本预测器:
- Spark Transformer 将 FFN 和注意力机制都解释为键值查找表,并为此设计了一个统一的稀疏化框架。
- 它通过重新分配 FFN 参数和注意力键嵌入的一部分维度,构建了一个低成本的预测器来识别应被激活的条目。这一设计避免了引入额外的参数,解决了现有方法中增加训练复杂性和参数量的挑战(挑战 #3)。
-
高效的 Statistical Top-k 算子:
- 为了在不显著拖慢训练速度的情况下引入稀疏性,本文提出了一种名为 Statistical Top-k 的硬件加速器友好、线性时间复杂度的近似算法。
- 该算法通过拟合高斯分布来估计一个阈值,从而选出最重要的条目,避免了传统 top-k 算子所需的高成本排序操作,解决了训练速度慢的问题(挑战 #2)。
- 该算子几乎处处可微,有利于模型优化。
-
高质量与高效率的权衡:
- 通过使用 Gemma-2 的训练方案进行完整的预训练,Spark Transformer 在保持模型质量的同时实现了显著的稀疏性。实验证明,该方法相比于切换回 ReLU、应用 top-k 掩码等现有方法,在 FLOPs 减少和模型质量之间取得了更优的权衡(挑战 #1)。
- 具体而言,只有 8% 的 FFN 神经元被激活,每个 token 最多关注 256 个其他 token,这使得 FLOPs 减少了 2.5 倍。
- 这种稀疏性带来了实际的性能提升,在 CPU 上的解码速度提升高达 1.79 倍,在 GPU 上提升高达 1.40 倍。
(a) 每个 token 的 FLOPs 与质量对比 (完整训练的 1/6)
(b) 下游任务评估
(c) CPU 推理耗时
图 1 | Spark Transformer 通过在 FFN 和注意力机制中实现高水平的激活稀疏性来提高推理效率,同时几乎不影响模型质量。(a) 与先前工作在 8k 序列长度下每个 token 的相对 FLOPs(y 轴)与相对训练损失(x 轴)的比较。我们使用标准的 Gemma-2 作为基线,它没有激活稀疏性。仅在 FFN 层内采用激活稀疏性的方法。我们的 Spark FFN 相比于 ReLU、ReLU2 和 Topk(分别指将标准 Gemma-2 的激活函数切换为 ReLU、ReLU2 以及 Topk 和 GELU 的组合)实现了最有利的权衡。将 Spark FFN(激活参数为 8%)与 Spark Attention(最多关注 256 个 token)相结合,我们的 Spark Transformer 实现了与 Gemma-2 相当的性能,同时将 FLOPs 减少到 40%。(b) 在标准下游任务上的评估证实了 Spark Transformer 几乎不影响质量。(c) Prefill / decode 的耗时展示了由 FLOPs 减少带来的 1.86 倍/1.64 倍加速。结果是在 4 核 CPU 上对 4096 个 token 的提示进行测试得到的。对于 prefill,提示被分块为 64 个 token 的批次,遵循 gemma.cpp 的默认设置。
图 2 | Spark FFN 和 Spark Attention 的架构。(左) 标准 FFN (即公式 (1)) 和标准 Attention (即公式 (6)) 的统一图示。在 FFN 的情况下,$q \in \mathbb{R}^{d_{\text{model}}}$ 是输入,$K$ 和 $V$ 分别是第一层和第二层的权重,$\sigma()$ 是 GELU。在 Attention 的情况下,$q \in \mathbb{R}^{d_{\text{attn}}}$ 是查询,$K$ 和 $V$ 分别是键和值矩阵,$\sigma()$ 是 softmax。(右) Spark FFN (即公式 (2)) 和 Spark Attention (即公式 (7)) 的统一图示。在 Spark FFN 的情况下,$\sigma_1()$ 是 GELU,$\sigma_2()$ 是恒等函数。在 Spark Attention 的情况下,$\sigma_1()$ 是 softmax,$\sigma_2()$ 是 softplus。在这两种情况下,都应用了 Statistical-Topk (即公式 (10)) 来引入稀疏性,这使得与 $K_2$ 和 $V$ 的稀疏矩阵乘法能够减少 FLOPs 数量。
A2 方法细节
2. Spark Transformer
本节描述 Spark Transformer 的两个组件:Spark FFN 和 Spark Attention。
2.1. Spark FFN
标准 FFN 架构。在标准的 Transformer 中,FFN 是一个双层多层感知机,它将输入 token $q \in \mathbb{R}^{d_{\text{model}}}$ 映射到输出。
$$\text{Standard-FFN}(\boldsymbol{q}; K, V) \stackrel{\text{def}}{=} V \cdot \sigma(K^{\top} \boldsymbol{q}) \in \mathbb{R}^{d_{\text{model}}}.$$在上述公式中,$\{K, V\} \subseteq \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$ 是可训练的模型参数,$\sigma()$ 是一个非线性激活函数。为简化符号,我们忽略了对层索引的依赖。
标准 FFN 的计算成本。公式 (1) 中的每个矩阵乘法有 $2d_{\text{model}} \cdot d_{\text{ff}}$ FLOPs,因此总计算成本为 $4d_{\text{model}} \cdot d_{\text{ff}}$。先前的工作表明,当 $\sigma()$ 是 ReLU 时,激活图 $\sigma(K^\top q)$ 在模型训练后非常稀疏。这种稀疏性可以很自然地用于减少其与第二层权重矩阵 $V$ 相乘的计算成本【索引62,The lazy neuron phenomenon: On emergence of activation sparsity in transformers+2022+ICLR】,从而将 FFN 的总 FLOPs 降至 $2d_{\text{model}} \cdot (d_{\text{ff}} + k)$,其中 $k \ll d_{\text{ff}}$ 是激活中的非零条目数。值得注意的是,这种稀疏性无法减少与 $K$ 相关的计算成本,这部分构成了 FFN 总 FLOPs 的一半。
Spark FFN 架构。为了同时减少 FFN 第一层的 FLOPs,我们引入了 Spark FFN,其定义如下:
$$\text{Spark-FFN}(\boldsymbol{q}; K_1, K_2, V, k, r) \stackrel{\text{def}}{=} V \cdot \left( \sigma(\text{Top}_k(K_1^\top \cdot \boldsymbol{q}[:r])) \odot (K_2^\top \cdot \boldsymbol{q}[r:]) \right).$$在上述公式中,$K_1 \subseteq \mathbb{R}^{r \times d_{\text{ff}}}$,$K_2 \subseteq \mathbb{R}^{(d_{\text{model}}-r) \times d_{\text{ff}}}$ 和 $V \subseteq \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$ 是可训练参数,激活函数 $\sigma()$ 遵循 Gemma 的设置,采用 GELU【索引37,Gaussian error linear units (gelus)+2016+arXiv】。Top$_k$ 的引入是为了获得稀疏性,其中 $k$ 是一个指定稀疏级别的超参数。具体来说,Top$_k$ 保留激活向量中最大的 $k$ 个值,同时将其他值设为 0。本文中,我们使用第 3 节介绍的高效 statistical top-k,它避免了对激活值进行排序。
输入拆分与低秩预测器。输入 $q$ 被拆分为 $q[:r]$ 和 $q[r:]$,分别包含前 $r$ 个维度和剩余的维度,其中 $r$ 是一个超参数。引入这种设计的目的是让 $K_1^\top q[:r]$ 这一项充当非零条目位置的低秩预测器,这使我们能够在计算 $K_2^\top q[r:]$ 和与 $V$ 的乘法时获得效率优势。下面将详细讨论这一点。
Spark-FFN 的 FLOPs 计算(无稀疏性)。若不利用稀疏性直接实现 Spark-FFN,其 FLOPs 数量与标准 FFN(公式 (1))相同,即:
其中三个项分别来自与 $K_1$、$K_2$ 和 $V$ 的乘法。
Spark-FFN 的 FLOPs 计算(利用稀疏性)。在 Spark-FFN 中,可以首先计算 $K_1^\top q[:r]$ 作为低秩预测器。其输出经过 Top$_k$(选择大约 $k$ 个最重要的条目)和激活函数 $\sigma()$ 后,我们得到一个稀疏输出。重要的是,获得稀疏输出后,无需执行公式 (2) 中另外两个矩阵乘法(即 $K_2^\top q[r:]$ 和与 $V$ 的乘法)的完整计算。相反,可以执行稀疏矩阵乘法,从而大幅减少 FLOPs:
$$2r \cdot d_{\text{ff}} + 2(d_{\text{model}} - r) \cdot k + 2d_{\text{model}} \cdot k = 2(d_{\text{ff}} - k) \cdot r + 4d_{\text{model}} \cdot k,$$这是一个关于 $r$ 的增函数。换言之,$r$ 控制着计算成本。我们在附录 C.4 节中提供了消融研究,表明当 $r \approx d_{\text{model}}/2$ 时,模型质量最佳。在这种情况下,Spark FFN 的总 FLOPs 约为 $d_{\text{model}} \cdot d_{\text{ff}} + 3 \cdot d_{\text{model}} \cdot k$,当 $k$ 很小时,这比公式 (3) 减少了 4 倍。
与门控激活的关系。包括 Gemma 在内的许多最新 Transformer 模型,都使用标准 FFN(公式 (1))的一个变体,其中激活函数被替换为门控激活:
$$\text{Gated-FFN}(\boldsymbol{q}; \boldsymbol{K}_1, \boldsymbol{K}_2, \boldsymbol{V}) = \boldsymbol{V} \cdot (\sigma(\boldsymbol{K}_1^\top \boldsymbol{q}) \odot (\boldsymbol{K}_2^\top \boldsymbol{q}) ) .$$在上述公式中,$\{K_1, K_2, V\} \subseteq \mathbb{R}^{d_{\text{model}} \times d'_{\text{ff}}}$。注意,为了在质量研究中与公式 (1) 中的 FFN 保持参数数量相同,$d'$ 通常取 $2/3 \cdot d_{\text{ff}}$【索引90,Glu variants improve transformer+2020+arXiv】。
Spark FFN 与门控 FFN 的对比。我们的 Spark FFN(公式 (2))在结构上与门控 FFN 有些相似,两者都在第一层有两个线性映射,在第二层有一个。不同之处在于:1) Spark FFN 增加了一个 Top$_k$ 以获得稀疏性;2) Spark FFN 第一层的输入是通过拆分输入向量的维度得到的。
2.2. Spark Attention
标准多头注意力机制。在标准的多头注意力层中,每个头 $h$ 将输入 $q \in \mathbb{R}^{d_{\text{model}}}$ 映射为一个查询、一个键和一个值向量,维度均为 $d_{\text{attn}}$,即 $q^{(h)} = W_Q^{(h)} q \in \mathbb{R}^{d_{\text{attn}}}$,$k^{(h)} = W_K^{(h)} q \in \mathbb{R}^{d_{\text{attn}}}$,$v^{(h)} = W_V^{(h)} q \in \mathbb{R}^{d_{\text{attn}}}$。这里,$\{W_Q^{(h)}, W_K^{(h)}, W_V^{(h)}\} \subseteq \mathbb{R}^{d_{\text{attn}} \times d_{\text{model}}}$ 是可训练的权重。
标准注意力计算。将上下文 $q$ 中的所有键向量和值向量收集成 $K^{(h)} = [k_1^{(h)}, \dots, k_{n_{\text{ctx}}}^{(h)}] \in \mathbb{R}^{d_{\text{attn}} \times n_{\text{ctx}}}$ 和 $V^{(h)} = [v_1^{(h)}, \dots, v_{n_{\text{ctx}}}^{(h)}] \in \mathbb{R}^{d_{\text{attn}} \times n_{\text{ctx}}}$,注意力计算如下:
Standard-Attention(q; K, V) def= V · softmax(K⊤q)
其中为简化符号,我们省略了对 $h$ 的依赖。与公式 (6) 相关的计算成本为每个头 $4d_{\text{attn}} \cdot n_{\text{ctx}}$。最后,所有头的输出被连接起来,并通过一个线性映射投影回 $d_{\text{model}}$ 维度。
Spark Attention 架构。注意到公式 (6) 的形式与 FFN(公式 (1))相同,只是非线性函数的选择不同。因此,我们采用与获得 Spark FFN 类似的策略,提出了 Spark Attention:
$$\text{Spark-Attention}(\boldsymbol{q}; K, V, k, r) \stackrel{\text{def}}{=} V \cdot \left( \sigma_1 \left( \text{Top}_k^{(-\infty)} (K_1^{\top} \boldsymbol{q}[:r]) \right) \odot \sigma_2 (K_2^{\top} \boldsymbol{q}[r:]) \right)$$在上述公式中,$K_1 \in \mathbb{R}^{r \times n_{\text{ctx}}}$ 和 $K_2 \in \mathbb{R}^{(d_{\text{attn}}-r) \times n_{\text{ctx}}}$ 分别包含来自 $K$ 的前 $r$ 行和剩余的行,而 $\sigma_1$ 和 $\sigma_2$ 是非线性函数。根据经验,我们发现取 $\sigma_1 = \text{softmax}$ 和 $\sigma_2 = \text{softplus}$ 能得到很好的结果。最后,Top$(-\infty)_k$ 指的是一个算子,它保留输入中最大的 $k$ 个值,同时将其余值设置为 $-\infty$。该算子可以使用 statistical top-k 来实现,我们将在第 3 节中解释。
Spark Attention 的 FLOPs 计算。若直接实现,公式 (7) 中的 FLOPs 数量为 $4d_{\text{attn}} \cdot n_{\text{ctx}}$,与公式 (6) 的 FLOPs 相同。然而,通过注意到 softmax 的输出预计是稀疏的,大约有 $k$ 个非零条目,与 $K_2^\top q[r:]$ 和与 $V$ 的乘法相关的计算成本可以大幅降低。特别地,如果我们取 $r = d_{\text{attn}}$,则每个 token 的 FLOPs 变为 $d_{\text{model}} n_{\text{ctx}} + 3 d_{\text{model}} \min{k_{\text{attn}}, n_{\text{ctx}}}$,当 $d_{\text{attn}} \ll n_{\text{ctx}}$ 时,这几乎是 4 倍的减少。
3. Statistical Top-k
本节介绍 Statistical-Top$_k$,一种用于实现 Spark FFN(即公式 (2))和 Spark Attention(即公式 (7))中 Top$_k$ 算子的高效算法。
软阈值算子定义。回顾一下,对于任意向量 $x \in \mathbb{R}^d$ 和标量阈值 $\theta \in \mathbb{R}$,软阈值算子定义为:
$$\text{Soft-Threshold}(\boldsymbol{x}, \theta) \stackrel{\text{def}}{=} \max\{\boldsymbol{x} - \theta \cdot \mathbf{1}, \mathbf{0}\} \in \mathbb{R}^d$$其中 1 和 0 分别是所有条目为 1 和 0 的 $d$ 维向量。软阈值算子将 $x$ 的每个条目向左平移 $\theta$,然后将结果在零处进行阈值处理。
Statistical-Topk 定义。我们将 Statistical-Top$_k$ 定义为从 $\mathbb{R}^d$到$\mathbb{R}^d$的如下映射:
$$\text{Statistical-Top}_k(\boldsymbol{x}) \stackrel{\text{def}}{=} \text{Soft-Threshold}(\boldsymbol{x}, \theta(\boldsymbol{x}, k)),$$其中
$$\theta(x, k) \stackrel{\text { def }}{=} \operatorname{mean}(x)+\operatorname{std}(x) \cdot Q(1-\frac{k}{d}).$$Statistical-Topk 工作原理。在上述公式中,我们定义 mean$(x) \stackrel{\text{def}}{=} \frac{1}{d} \sum_{i=1}^d x_i$ 和 std$(x) \stackrel{\text{def}}{=} \sqrt{\frac{1}{d-1} \sum_{i=1}^d (x_i - \text{mean}(x))^2}$,它们分别计算输入 $x$ 的样本均值和标准差。$Q(\cdot)$ 是标准高斯分布的分位数函数(即累积分布函数的逆)。在 Spark Transformer 中,公式 (10) 被用作公式 (2) 中的 Top$_k$ 算子。对于公式 (7) 中的 Top$(-\infty)_k$ 算子,使用了公式 (10) 的一个轻微变体,其中低于阈值 $\theta(x, k)$ 的条目被设置为 $-\infty$ 而不是 0。Statistical-Top$_k$ 首先计算一个阈值 $\theta(x, k)$,使得大约有 $k$ 个 $x$ 的条目超过它,然后应用带有这个阈值的软阈值算子于 $x$ 以获得稀疏输出。我们将在接下来的两个小节中讨论这两个组件。
3.1. 阈值估计
阈值估计的原理。公式 (10) 中的阈值 $\theta(x,k)$ 的设计是为了,如果 $x$ 的条目是从高斯分布中抽取的,那么 $d$ 个条目中大约有 $k$ 个会超过这个阈值。为了理解这一点,让 $\mu$ 和 $\sigma$ 分别表示底层高斯分布的均值和标准差。其分位数函数由 $\mu + \sigma \cdot Q(p)$ 给出,其中 $p \in (0, 1)$。因此,由于分位数函数的性质,我们期望 $x$ 中大约有 $d \cdot p$ 个条目超过 $\mu + \sigma \cdot Q(1-p)$。在实践中,由于 $\mu$ 和 $\sigma$ 是未知的,它们分别被样本均值 mean($x$) 和样本标准差 std($x$) 替代。
理论保证。以下定理将这一论证形式化。
定理 3.1。设 $x \in \mathbb{R}^d$ 是一个其条目独立同分布于 $N(\mu, \sigma^2)$ 的向量。对于任何 $1 \le k \le d-1$,设 $\theta(x,k)$ 是在公式 (10) 中定义的标量。取任意 $\delta \in (0, 1)$ 并假设 $d \ge \max{2, \log \frac{6}{\delta}}$。以至少 $1-\delta$ 的概率,大于 $\theta(x,k)$ 的 $x$ 的条目数,即 card$({i \in [d] | x_i > \theta(x,k)})$,满足
$$\frac{|\mathbf{card}\left(\left\{i \in[d] \mid x_i>\theta(\boldsymbol{x}, k)\right\}\right)-k|}{d} \leq 4 \sqrt{\frac{\log \frac{6}{\delta}}{d}}\left(1+\sqrt{-2 \log \min \left\{\frac{k}{d}, 1-\frac{k}{d}\right\}}\right).$$误差界分析。定理 3.1 提供了 $k$ 与超过阈值的 $x$ 的真实条目数之间的相对误差界。该界在 $k=1$ 或 $k=d-1$ 时最大化。因此,最坏情况下的界为 $O(\sqrt{\frac{\log d \cdot \log \frac{1}{\delta}}{d}})$,随着 $d$ 的增加而消失。值得注意的是,当 $k = \Theta(d)$ 时,误差界变为 $O(\sqrt{\frac{\log \frac{1}{\delta}}{d}})$,表现出更快的收敛速度。
计算成本。阈值 $\theta(x, k)$ 的计算非常高效,类似于 LayerNorm 层中使用的操作,仅需 $2d$ FLOPs 即可计算样本的均值和标准差。这与朴素的基于排序的方法形成鲜明对比,后者的复杂度为 $O(d \log d)$。
高斯分位数函数的实现。虽然高斯分位数函数 $Q(\cdot)$ 没有封闭形式的解,但在标准软件包如 SciPy【索引99,SciPy 1.0: Fundamental Algorithms for Scientific Computing in Python+2020+Nature Methods】中有高精度的分段近似算法,其复杂度为常数,可以直接应用于我们的需求。
3.2. 稀疏化
硬阈值处理的局限性。给定阈值 $\theta(x,k)$,一个直接的方法是获得一个稀疏向量,即把 $x$ 中低于阈值的所有条目设为零,保留其余的值。这个算子,有时被称为硬阈值处理【索引10,Iterative thresholding for sparse approximations+2008+Journal of Fourier analysis and Applications】,存在不连续性的问题,这可能妨碍其在基于梯度下降的训练中的适用性。
软阈值处理的优势。为了解决这个问题,Statistical-Top$_k$ 采用了公式 (9) 中定义的软阈值算子【索引8,A fast iterative shrinkage-thresholding algorithm for linear inverse problems+2009+SIAM journal on imaging sciences】。这个算子首先将 $x$ 的所有条目减去阈值 $\theta(x, k)$,然后将所有低于 0 的条目设为 0。软阈值处理具有连续且几乎处处可微(除了当 $x$ 的条目与 $\theta(x, k)$ 相等时)的优点。
平滑化以实现完全可微性。为了实现完全可微性,可以使用像 Huber 损失【索引39,Robust estimation of a location parameter+1992+Breakthroughs in statistics: Methodology and distribution】这样的平滑函数,它在输入 $x$ 上逐元素定义为:
$$\begin{aligned} \operatorname{Huber}(x; \delta) \stackrel{\text{def}}{=} \begin{cases}\frac{1}{2} x^2 & \text{for } |x| < \delta, \\ \delta \cdot (|x| - \frac{1}{2} \delta) & \text{otherwise.}\end{cases} \end{aligned}$$可微性证明。映射 $x \mapsto \text{Huber}(\text{Statistical-Top}_k(x); \delta)/\delta$ 的连续可微性在下面得到证实:
定理 3.2。对于任何 $\delta > 0$,函数 $\mathbb{R}^d \to \mathbb{R}^d$ 定义为 $\text{Huber}(\text{Statistical-Top}_k(x); \delta) / \delta$ 是连续可微的。
实际应用中的选择。注意到当 $\delta \to 0$ 时,公式 (14) 收敛到 $\text{Statistical-Top}_k(x)$,因为 $\text{Huber}(x; \delta)/\delta \to |x|$ 并且 $\text{Statistical-Top}_k(x)$ 总是非负的。然而,在实践中,我们发现使用非零的 $\delta$ 并不能改善模型质量,因此为简单起见,我们设置 $\delta = 0$。
软阈值处理的变分形式。最后,软阈值处理有一个变分形式(例如,参见【索引79,Proximal algorithms+2014+Foundations and trends® in Optimization】):
$$\text{Soft-Threshold}(x, \theta) = \arg\min_{\boldsymbol{z} \geq \mathbf{0}} \theta \|\boldsymbol{z}\|_1 + \frac{1}{2} \|\boldsymbol{x} - \boldsymbol{z}\|_2^2.$$该公式旨在寻找一个向量 $z$,它既能最小化与输入 $x$ 的平方 $\ell_2$ 距离,又能最小化其 $\ell_1$ 范数,其中阈值 $\theta$ 平衡了这两项。鉴于 $\ell_1$ 范数的稀疏促进性质,软阈值处理有效地找到了输入 $x$ 的稀疏近似。这种变分形式也将 Statistical-Top$_k$ 与文献中的其他 top-k 算法联系起来;见第 3.3 节。
3.3. Statistical Top-k 与相关 Top-k 算子的比较
与其他 top-k 算法的联系。公式 (15) 的变分形式揭示了 Statistical-Top$_k$ 与文献中其他 top-k 算法的联系。具体来说,Lei 等人 (2023)【索引58,Conditional adapters: Parameter-efficient transfer learning with fast inference+2023+NeurIPS】将 soft top-k 定义为
$$\arg\min \ -\theta \cdot H(\boldsymbol{z})-\langle\boldsymbol{z}, \boldsymbol{x}\rangle, \text { s.t. } \boldsymbol{z}^{\top} \mathbf{1}=k, \mathbf{0} \leq \boldsymbol{z} \leq \mathbf{1}$$其中 $H(z)$ 是熵函数。另一项工作 (Lou et al., 2024)【索引68,Sparser is faster and less is more: Efficient sparse attention for long-range transformers+2024+arXiv】定义了 SparseK 算子
$$\arg \min_{\boldsymbol{z}} -H^G(\boldsymbol{z}) - \langle \boldsymbol{z}, \boldsymbol{x} \rangle, \text{ s.t. } \boldsymbol{z}^{\top} \mathbf{1} = k, \mathbf{0} \le \boldsymbol{z} \le \mathbf{1},$$其中 $H^G(z)$ 是广义基尼熵。
Statistical-Topk 的优势。形式为公式 (15) 的 Statistical-Top$_k$,以及公式 (16) 和公式 (17),都可以解释为在受稀疏化正则化约束的情况下,寻找一个接近输入的输出。它们的主要区别在于稀疏正则化的选择。也就是说,soft top-k 和 SparseK 分别使用熵和基尼熵,而 Statistical-Top$_k$ 使用 $\ell_1$ 范数。$\ell_1$ 的选择使得 Statistical-Top$_k$ 更优越,因为它有由软阈值处理提供的封闭形式解,仅需 $O(d)$ FLOPs。相比之下,soft top-k 和 SparseK 都没有封闭形式解,需要一个迭代算法,其 FLOPs 数量取决于迭代次数。此外,无法保证 soft top-k 和 SparseK 能获得(大约)$k$ 个非零条目作为输出。
与分布式训练中类似思想的比较。最后,我们提到类似于 statistical top-k 的思想已被用于分布式训练问题【索引70,An efficient statistical-based gradient compression technique for distributed training systems+2021+MLSys】【索引92,Understanding top-k sparsification in distributed deep learning+2019+arXiv】(Lin et al., 2018)。然而,我们是第一个将其引入、改编并验证其在激活稀疏性方面有效性的。更多讨论见 D.2 节。
A4 实验环境
模型架构:本文实验基于 Gemma-2 2B 模型,这是一个拥有 20 亿参数的 decoder-only Transformer。实验中,将标准的 FFN 和 Attention 替换为 Spark FFN 和 Spark Attention。
- 模型维度:$d_{\text{model}} = 2304$。
- FFN 配置:原始 Gemma-2 使用门控 FFN(公式 (5)),$d'{\text{ff}} = 9216$。Spark FFN(公式 (2))使用 $d/2$。}} = 13824$ 以保持参数数量不变。稀疏度超参数 $k=1106$(约 8% 稀疏度),低秩预测器维度 $r=1024 \approx d_{\text{model}
- Attention 配置:原始 Gemma-2 交替使用 8192 token 的全局注意力和 4096 token 窗口的局部注意力,头维度 $d_{\text{attn}}=256$。Spark Attention(公式 (7))替换了这两种注意力,局部注意力窗口保持 4096。稀疏度超参数 $k=256$(每个 token 最多关注 256 个 token),低秩预测器维度 $r=128 = d_{\text{attn}}/2$。
- 位置编码:与 Gemma-2 一样使用 Rotary Position Embedding (RoPE)【索引96,Roformer: Enhanced transformer with rotary position embedding+2024+Neurocomputing】,并将其应用于 Spark Attention 中的 $q[:r]$, $q[r:]$, $K_1$ 列向量和 $K_2$ 列向量。
数据集:模型在 2 万亿个主要为英文文本的 token 上进行预训练,与 Gemma-2 2B 模型使用相同的训练流程和数据。
硬件配置:
* CPU:4 核和 16 核 CPU VM。
* GPU:NVIDIA T4 GPU。
* 训练硬件:未明确说明,但 JAX 的使用暗示了在 TPU 上进行训练。
软件配置:
* 代码实现:
* 训练:使用 JAX 实现。
* CPU 推理:基于官方 C++ 推理引擎 gemma.cpp【索引33,Google/gemma.cpp: lightweight, standalone c++ inference engine for google’s gemma models+2025+Github】进行修改,以支持稀疏矩阵乘法,并利用 SIMD 操作和软件预取。
* GPU 推理:基于 llama.cpp【索引31,llama.cpp+2025+Github】进行修改,使用定制的 CUDA 核。
A4 实验结果
质量评估
- 下游任务表现:在 Gemma-2 论文中使用的一系列基准测试上,Spark Transformer 的质量与 Gemma-2 相当,证实了其在大幅降低 FLOPs 的同时保持了近乎无损的质量(如图 1b 所示)。
- FLOPs-质量权衡:为了与其他强制稀疏化方法对比,本文将 Gemma-2 2B 的变体(激活函数分别替换为 ReLU、ReLU2 和 Top-k+GELU)预训练了标准流程的 1/6。如图 1a 所示,这些方法要么导致较大的质量损失(ReLU 和 Topk),要么 FLOPs 减少不充分(ReLU2)。相比之下,Spark FFN 在质量损失更小的情况下实现了更多的 FLOPs 减少。当结合 Spark FFN 和 Spark Attention 后,Spark Transformer 能够在 FLOPs 大幅降低的情况下,实现与 Gemma-2 几乎无差别的质量。
稀疏性验证
- 训练过程中的稀疏度:本文通过实验验证了 statistical top-k 的有效性。如图 3 所示,在训练初期,FFN 中的非零激活比例接近 8%,与超参数设置($k/d_{\text{ff}} = 8\%$)非常吻合。这符合预期,因为模型参数随机初始化时,激活图的条目近似服从高斯分布。即使在训练后期(480k 步),该假设仍然近似成立,statistical top-k 能够稳定地产生接近 8% 的稀疏度。
- Attention 稀疏度:如图 3b 所示,在整个训练过程中,每个 token 关注的 token 数量始终低于超参数设置的 256。平均值远小于 256 是因为许多 token 的上下文长度本身就小于 256。
- 评估阶段的稀疏度:在评估阶段也观察到了类似的稀疏度水平(见附录 C.2)。
图 3 | 在选定的训练步骤中,Spark FFN 和 Spark Attention 在 26 个层中的中间激活稀疏性。对于 FFN,我们报告了在 $d_{\text{ff}} = 13824$ 个条目中非零条目的百分比。对于 Attention,我们报告了非零条目的数量(即被关注的 token 数)。我们的超参数选择是在 Spark FFN 中有 8% 的非零值,在 Spark Attention 中最多有 256 个非零值。
推理效率
本文评估了 Spark Transformer 在 CPU 和 GPU 上的效率优势。通过修改 gemma.cpp 和 llama.cpp 以支持稀疏矩阵乘法,不仅减少了计算 FLOPs,还降低了内存带宽需求(如图 4 所示)。
-
CPU 性能:
- 在 16 核 CPU 上,针对不同长度的 prompt,解码速度提升了 1.35 倍至 1.79 倍。对于短 prompt,Spark FFN 的优化效果最显著;对于长 prompt,Spark Attention 的优化效果最显著(如图 5a、5b 所示)。
- 在 prefill 和 decode 阶段,Spark Transformer 均表现出显著的效率提升。即使在 prefill 阶段将 prompt 分块处理(块大小为 64),仍能保持强大的性能,分别实现了 1.86 倍(prefill)和 1.64 倍(decode)的 wall-time 加速(如图 1c 所示)。
-
GPU 性能:
- 在 NVIDIA T4 GPU 上,解码速度提升了 1.25 倍至 1.40 倍,表现出与 CPU 类似的趋势(如图 5c 所示)。
(a) 向量掩码矩阵乘法 (b) 稀疏向量-矩阵乘法
红色:跳过的内存访问/计算
图 4 | 使用稀疏激活实现矩阵乘法的图示。(a) 向量掩码矩阵乘法接收一个密集向量 $q[r:]$、一个密集矩阵 $K_2^\top$ 以及来自 statistical top-k on $K_1^\top q[:r]$ 的掩码,以计算 $y := (K_2^\top q[r:]) \odot \text{mask}$。它跳过了与被掩码列相关的内存加载和计算。(b) 稀疏向量-矩阵乘法接收一个稀疏激活向量 $y$ 来计算密集矩阵 $V$ 中行的加权和。它跳过了加载和计算对应于 $y$ 中 0 的行。为了优化性能,我们使用 tiling 实现稀疏向量-矩阵乘法,这有助于最小化跨 CPU 核心的同步。
图 5 | Spark Transformer 在各种硬件平台上因激活稀疏性带来的解码加速。我们报告了 Spark Transformer 在没有针对稀疏激活进行硬件优化、仅针对 Spark FFN 中的稀疏性进行硬件优化以及同时针对 Spark FFN 和 Spark Attention 中的稀疏性进行硬件优化时的解码速度。所有实验均使用解码批大小为 1。
训练效率
- Top-k 算子开销对比:如图 6 所示,与 JAX 中标准的近似 top-k 算子
jax.lax.approx_max_k相比,statistical top-k 的训练开销极小。即使在召回率仅为 50% 的情况下,JAX 的 top-k 算子也会导致超过 10 倍的训练减速,这使得使用它进行训练变得不切实际。相比之下,statistical top-k 引入的减速微乎其微,证明了其高效性。
图 6 | 使用我们的 statistical top-k 与标准 top-k(即 jax.lax.approx_max_k)相对于不使用任何 top-k 的训练减速比较。
A7 补充细节
5. 讨论与相关工作
总结与展望。本文介绍了 Spark Transformer 架构,旨在减少 Transformer 模型中 FFN 和 Attention 的 FLOPs。由于这两个组件在处理长上下文的大型 Transformer 中占主导地位,解码单个 token 的总 FLOPs 也大幅降低(见表 C.1),从而在合适的硬件上带来显著的计算效率提升。这种通过选择性激活部分模型参数和限制注意力上下文来实现稀疏性的原则,与神经科学中大脑稀疏活动模式是其高效性关键因素的研究【索引5,An energy budget for signaling in the grey matter of the brain+2001+Journal of Cerebral Blood Flow & Metabolism】【索引7,Experimental evidence for sparse firing in the neocortex+2012+Trends in neurosciences】【索引55,Sparsity of population activity in the hippocampus is task-invariant across the trisynaptic circuit and dorsoventral axis+2024a+Behavioural Brain Research】不谋而合。尽管当前硬件限制了稀疏激活的全部潜力,但 Spark Transformer 在 CPU 和 GPU 上的实际壁钟时间减少,以及相关技术在包括 TPU 在内的平台上的先前证据【索引65,Deja vu: Contextual sparsity for efficient llms at inference time+2023+ICML】【索引94,Powerinfer: Fast large language model serving with a consumergrade gpu+2023+arXiv】【索引106,Hire: High recall approximate top-k estimation for efficient llm inference+2024+arXiv】,凸显了其巨大潜力。我们希望这项工作能为研究更适合稀疏计算的替代硬件开辟道路,从而规避“硬件彩票”【索引38,The hardware lottery+2021+Communications of the ACM】问题,并在未来实现更大的效率提升。
与近期文献的联系。接下来,我们将我们的工作与近期文献联系起来。
5.1. 走向最小专家混合模型?
与 MoE 的关系。专家混合模型 (MoE) 是一类著名的稀疏激活模型,它将 FFN 内的神经元分组,并激活选定组中的所有神经元【索引59,Gshard: Scaling giant models with conditional computation and automatic sharding+2020+ICLR】【索引91,Outrageously large neural networks: The sparsely-gated mixture-of-experts layer+2017+arXiv】。这种结构化分组使其与现代训练加速器(如 GPU 和 TPU)兼容,从而能够训练超大规模模型。然而,可靠地训练 MoE 仍然具有挑战性,需要复杂的算法设计和特殊的硬件支持【索引29,Switch transformers: Scaling to trillion parameter models with simple and efficient sparsity+2022+JMLR】。更关键的是,稀疏性的结构化特性限制了模型的灵活性和表达能力,与同等参数数量的密集模型相比,其质量有所妥协。
迈向“最小专家混合”。为了缓解这一问题,近期的工作趋向于采用包含大量更小专家的架构,并同时激活其中的多个专家【索引18,Deepseekmoe: Towards ultimate expert specialization in mixture-of-experts language models+2024+arXiv】【索引36,Mixture of a million experts+2024+arXiv】。Spark Transformer 可以被看作是将这一趋势推向极致,其中 FFN 中的每个神经元本身就是一个专家,其低秩预测器向量作为专家路由器。这催生了可被描述为“最小专家混合” (Mixture of Minimum Experts) 的概念。
基于自然稀疏性的专家设计。最后,自然出现的非结构化激活稀疏性的发现,激发了另一种设计专家的视角,即通过神经元的自然稀疏性对其进行分组【索引16,Approximating two-layer feedforward networks for efficient transformers+2023+EMNLP】【索引25,Towards structured sparsity in transformers for efficient inference+2023+Workshop on Efficient Systems for Foundation Models@ ICML2023】【索引82,Emergent mixture-of-experts: Can dense pre-trained transformers benefit from emergent modular structures?+2023+arXiv】【索引97,Exploiting activation sparsity with dense to dynamic-k mixture-of-experts conversion+2024+Workshop on Efficient Systems for Foundation Models II@ ICML2024】【索引108,Moefication: Transformer feed-forward layers are mixtures of experts+2022+ACL】【索引112,Learn to be efficient: Build structured sparsity in large language models+2024+arXiv】。
5.2. 重新激活激活稀疏性?
激活稀疏性技术。稀疏激活是提高大模型效率的常用方法,多年来已发展出许多低成本激活预测技术,如低秩分解【索引20,Low-rank approximations for conditional feedforward computation in deep neural networks+2013+arXiv】、量化【索引11,Seernet: Predicting convolutional neural network feature-map sparsity through low-bit quantization+2019+CVPR】、乘积键【索引51,Large memory layers with product keys+2019+NeurIPS】、哈希【索引12,Slide: In defense of smart algorithms over hardware acceleration for large-scale deep learning systems+2020+MLSys】等。随着现代 Transformer 模型的普及,这些技术成为降低其高计算成本的自然选择【索引40,Sparse is enough in scaling transformers+2021+NeurIPS】【索引65,Deja vu: Contextual sparsity for efficient llms at inference time+2023+ICML】【索引94,Powerinfer: Fast large language model serving with a consumergrade gpu+2023+arXiv】【索引107,Lookupffn: making transformers compute-lite for cpu inference+2023+ICML】。特别是,FFN 中激活自然稀疏的发现【索引62,The lazy neuron phenomenon: On emergence of activation sparsity in transformers+2022+ICLR】【索引69,Sparsing law: Towards large language models with greater activation sparsity+2024b+arXiv】【索引108,Moefication: Transformer feed-forward layers are mixtures of experts+2022+ACL】带来了很多激动人心的进展,因为它可以在不牺牲质量的情况下通过激活稀疏性获得效率。
本文工作的定位。我们的工作属于这一方向的最新进展,旨在将这些好处带给没有自然稀疏性的最新一代大语言模型。早期的尝试【索引72,Relu strikes back: Exploiting activation sparsity in large language models+2023+arXiv】【索引81,Theoretical explanation of activation sparsity through flat minima and adversarial robustness+2023+arXiv】【索引110,Relu2 wins: Discovering efficient activation functions for sparse llms+2024+arXiv】试图通过切换回 ReLU 变体来恢复稀疏性,但这通常会导致质量损失。通过更仔细的调优,质量差距可能在很大程度上被弥补,但激活变得不那么稀疏(例如,LLAMA 7B 中有 25% 的非零值【索引93,Prosparse: Introducing and enhancing intrinsic activation sparsity within large language models+2024a+arXiv】)。Top-k 近来已成为获得稀疏性的更流行选择【索引95,Turbo sparse: Achieving llm sota performance with minimal activated parameters+2024b+arXiv】,并能在提供强稀疏性的同时保持中性质量,但仅在选定层中【索引106,Hire: High recall approximate top-k estimation for efficient llm inference+2024+arXiv】。此外,这类方法需要微调来引入稀疏性并获得预测器。在不进行微调的情况下,Lee 等人 (2024b)【索引56,Cats: Contextually-aware thresholding for sparsity in large language models+2024b+arXiv】;Liu 等人 (2024a)【索引64,Training-free activation sparsity in large language models+2024a+arXiv】;Zhang 等人 (2025)【索引111,R-sparse: Rank-aware activation sparsity for efficient llm inference+2025+ICLR】在保持中性质量的情况下最多获得了 50% 的非零值。与这些工作相比,我们的工作不仅在所有 FFN 层的激活中获得了 8% 的非零值,还获得了一个预测器,所有这些都通过单阶段训练完成。我们在 D.1 节中提供了与这些方法的比较摘要。
激活稀疏性的其他作用。我们注意到激活稀疏性的用处超出了效率范畴。例如,理论研究表明其对模型泛化性和可学习性有益【索引6,Learning neural networks with sparse activations+2024+COLT】【索引73,Sparsity-aware generalization theory for deep neural networks+2023+COLT】。此外,被激活的神经元可能与语义概念相关联,这为了解 Transformer 模型的工作机制和操控其输出提供了可能【索引17,Self-conditioning pre-trained language models+2022+ICML】【索引68a,Pace: Parsimonious concept engineering for large language models+2024a+arXiv】。
稀疏注意力。稀疏注意力泛指通过关注上下文中选定的 token 子集来降低计算成本的方法【索引21,Attention is naturally sparse with gaussian distributed input+2024+arXiv】【索引42,Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention+2024+NeurIPS】【索引85,Sparse modular activation for efficient sequence modeling+2023+NeurIPS】。稀疏注意力的工作包括使用手工制作的注意力模式【索引2,Colt5: Faster long-range transformers with conditional computation+2023+EMNLP】【索引9,Longformer: The long-document transformer+2020+arXiv】【索引15,Generating long sequences with sparse transformers+2019+arXiv】【索引24,Longnet: Scaling transformers to 1,000,000,000 tokens+2023+arXiv】,其特点是简单,以及学习的注意力模式【索引49,Reformer: The efficient transformer+2020+ICLR】【索引74,Faster causal attention over large sequences through sparse flash attention+2023+arXiv】【索引87,Efficient content-based sparse attention with routing transformers+2021+TACL】,其特点是更好的建模能力。然而,学习注意力模式通常涉及学习哈希表或 k-means 中心等,这显著增加了建模的复杂性。与我们的 Spark Attention 密切相关的是 top-k 注意力【索引35,Memory-efficient transformers via top-k attention+2021+arXiv】,它仅通过 top-k 阈值处理就获得了数据自适应的注意力。我们的工作在 top-k 注意力的基础上进行了改进,引入了一个低成本预测器,从而能够从稀疏性中获得更大的计算收益。最后,KV 剪枝方法在解码过程中永久丢弃选定的 token【索引66,Scissorhands: Exploiting the persistence of importance hypothesis for llm kv cache compression at test time+2024b+NeurIPS】【索引109,H2o: Heavy-hitter oracle for efficient generative inference of large language models+2023+NeurIPS】,并且无法实现像稀疏注意力方法那样高的压缩比。
5.3. 协同以实现更高的推理效率?
与其他优化技术的结合。这项工作也为未来高效推理大语言模型的研究开辟了有前景的途径,特别是在将 Spark Transformer 的架构效率与其他领先的优化技术相结合方面。我们在这里简要讨论两个关键的协同作用。
与推测解码的协同作用。Spark Transformer 与推测解码高度互补。作为目标模型,其更快的推理速度直接加速了主要的验证瓶颈。作为草稿模型,其近乎中性的质量和高速度使其成为生成高质量草稿的理想选择,可能导致更高的 token 接受率和更大的整体加速。完整的讨论见 D.3 节。
与量化的协同作用。我们还假设与量化有很强的协同作用。其好处预计是乘法性的,因为 Spark Transformer 减少了操作数量,而量化减少了其成本。更重要的是,与保留高幅度异常值的标准剪枝不同,我们的 statistical top-k 算子使用软阈值处理。这会缩小激活分布的动态范围,从而可能降低对激活量化的敏感性。该机制的详细分析见附录 D.4 节。
A5 结论
本文提出的 Spark Transformer 架构,通过在 FFN 和 Attention 机制中引入高水平的激活稀疏性,显著降低了 Transformer 模型的计算成本(FLOPs)。其核心创新在于:
1. 一个统一的稀疏化框架,通过重用部分参数构建低成本预测器,避免了额外参数和复杂的训练流程。
2. 一个高效的 Statistical Top-k 算子,它以线性时间复杂度近似 top-k 操作,避免了传统方法带来的巨大训练开销,同时保持了良好的可微性。
实验结果表明,Spark Transformer 在使用 Gemma-2 2B 模型进行预训练时,能够在仅激活 8% FFN 神经元和每个 token 最多关注 256 个 token 的高稀疏度下,达到与原始模型几乎无差别的质量。这种稀疏性转化为显著的推理速度提升,在 CPU 和 GPU 上分别实现了高达 1.79 倍和 1.40 倍的解码加速。这项工作为开发更高效、更易于部署的大型语言模型提供了新的途径,并有望与推测解码、量化等技术协同,进一步提升推理效率。
A6 附录
A. 证明
A.1. 定理 3.1 的证明
证明概要。本证明首先为经验均值 $\bar{x}$ 和经验标准差 $s$ 建立集中界,证明它们分别接近真实均值 $\mu$ 和真实标准差 $\sigma$。
1. 使用 Laurent-Massart 界【索引52,Adaptive estimation of a quadratic functional by model selection+2000+Annals of statistics】来约束 $s$ 与 $\sigma$ 的偏差。
2. 使用标准正态分布的 Mill 不等式来约束 $\bar{x}$ 与 $\mu$ 的偏差。
3. 结合这两个界,推导出阈值 $\theta(x,k)$ 与其理论期望值 $\mu + \sigma Q(1 - k/d)$ 之间的偏差界。
4. 然后,利用 Dvoretzky-Kiefer-Wolfowitz 不等式【索引28,Asymptotic minimax character of the sample distribution function and of the classical multinomial estimator+1956+The Annals of Mathematical Statistics】【索引71,The tight constant in the dvoretzky-kiefer-wolfowitz inequality+1990+The annals of Probability】,该不等式约束了经验累积分布函数 (ECDF) 与真实累积分布函数 (CDF) 之间的最大偏差。
5. 最后,通过联合界将 ECDF 的偏差和阈值的偏差结合起来,推导出超过阈值 $\theta(x,k)$ 的元素数量与期望值 $k$ 之间的最终相对误差界。
A.2. 定理 3.2 的证明
证明概要。Huber statistical top-k 函数可以写成函数复合的形式:$\text{Huber}(\text{Soft-Threshold}(x, \theta(x,k)))/\delta$。要证明其连续可微,只需证明其各组成部分是连续可微的。函数 $\theta(x,k)$ 是样本均值和标准差的线性组合,而这两者都是关于 $x$ 的 $C^1$ 函数,因此 $\theta(x,k)$ 也是 $C^1$ 的。接着,证明 $\text{Huber}(\text{Soft-Threshold}(x, \theta))$ 是关于 $(x, \theta)$ 的 $C^1$ 函数。通过分段分析 Huber 函数的定义,可以很容易地验证其本身及其梯度都是连续的。由于整个表达式是 $C^1$ 函数的复合,因此最终函数是连续可微的。
$$\begin{aligned} \text{Huber}(\text{Soft-Threshold}(x, \theta)) = \begin{cases} \delta x - \delta \theta - \frac{1}{2}\delta, & \text{if } x > \theta + \delta; \\ \frac{1}{2}(x - \theta)^2, & \text{if } \theta \leq x \leq \theta + \delta; \\ 0, & \text{if } x < \theta. \end{cases} \end{aligned}$$ $$\begin{aligned} \frac{\partial \operatorname{Huber}(\operatorname{Soft-Threshold}(x, \theta))}{\partial(x, \theta)}=\left\{\begin{array}{ll} (\delta,-\delta), & \text { if } x>\theta+\delta ; \\ (x-\theta, \theta-x) & \text { if } \theta \leq x \leq \theta+\delta ; \\ (0,0), & \text { if } x<\theta, \end{array}\right. \end{aligned}$$B. 稀疏矩阵乘法的实现细节
实现平台。我们在 gemma.cpp (CPU) 和 llama.cpp (GPU) 中实现了稀疏矩阵乘法。以下主要关注解码时批处理大小为 1 的情况。
两种稀疏乘法类型。在批大小为 1 的情况下,Spark FFN 和 Spark Attention 都利用了两种稀疏向量-矩阵乘法:
1. 向量掩码矩阵乘法 (Vector-masked matrix multiplication):如图 4a 所示,一个向量 $q$ 与矩阵 $K$ 的未被掩码的列相乘。掩码列的输出为零。在 Spark FFN 中,这用于计算 $K_2^\top q[r:]$。掩码向量由 $\text{Statistical-Top}_k(K_1^\top q[:r])$ 的输出生成。该实现会跳过从 DRAM (CPU) 或 HBM (GPU) 加载被掩码的矩阵列以及相关的计算。
2. 稀疏向量-矩阵乘法 (Sparse vector-matrix multiplication):如图 4b 所示,一个包含许多零的向量与一个密集矩阵相乘。在 Spark FFN 中,这用于将中间激活向量与矩阵 $V$ 相乘。实现时,矩阵 $V$ 按行存储,计算会跳过加载和处理与稀疏向量中零元素对应的行。
性能优化。
* CPU: 利用 SIMD 指令(与原始 Gemma 实现相同),并使用软件 CPU 预取 (__builtin_prefetch) 来重叠 DRAM 加载和计算。对于稀疏向量-矩阵乘法,还使用了 tiling 来最小化跨 CPU 核心的同步。
* GPU: 使用定制的 CUDA 核。
扩展到多批次和 Prefill。对于大于 1 的批处理大小或 prefill 阶段,我们继续使用每个向量各自的掩码来跳过计算。对于内存加载,则使用批内所有向量掩码的并集来创建一个统一的掩码。这意味着批次越大,内存加载的节省可能越少(除非 top-k 位置有显著重叠)。尽管如此,Spark Transformer 仍然能通过跳过基于单个掩码的计算来持续减少 FLOPs。
C. 额外的实验结果和细节
C.1. Statistical Top-k 输入的分布
验证高斯假设。Statistical Top-k 的一个基本假设是其输入向量(Spark FFN 中 GELU 前的激活和 Spark Attention 中 softmax 前的激活)的元素服从独立同分布的高斯分布。图 C.4 (FFN) 和图 C.5 (Attention) 的经验评估表明,这些激活向量的分布与高斯分布非常接近,从而证明了使用 Statistical Top-k 的合理性。图中还比较了根据拟合的高斯分布估计的截止值与在经验分布上获得目标稀疏度所需的真实截止值,两者非常接近。
图 C.4 | Spark FFN 中 statistical top-k 输入激活的条目分布(Spark Attention 的结果见图 C.5)。两行对应于输入在两个位置 0 和 1000 的激活,各列对应于 26 层预训练 Spark Transformer 在四个不同深度层 {0, 8, 16, 24} 的激活。输入是来自 https://paulgraham.com/articles.html 的第一篇文章的前 1000 个 token,前面加上了 BOS token。我们比较了经验分布(Empirical)与均值和标准差(std)为输入样本均值和 std 的高斯分布(Fitted)。我们看到高斯分布能很好地近似经验分布。我们还比较了从高斯分布估计的截止值,即公式 (10) 中使用的 $\theta(x,k)$,其中 $k/d = 5\%$(Cutoff for fitted),与在经验分布上获得 8% 非零值的截止值(Cutoff for empirical)。可以看出这两个截止值很接近。
图 C.5 | Spark Attention 中 statistical top-k 输入激活的条目分布(Spark FFN 的结果见图 C.4)。两行对应于两个不同注意力头的激活,各列对应于 26 层预训练 Spark Transformer 在四个不同深度层 {0, 8, 16, 24} 的激活。模型输入是来自 https://paulgraham.com/articles.html 的第一篇文章的前 1000 个 token,前面加上了 BOS token,我们检查了最后一个 token 的激活(即第 1001 个 token 的查询嵌入与所有 1001 个键嵌入的内积)。我们比较了经验分布(Empirical)与均值和标准差(std)为输入样本均值和 std 的高斯分布(Fitted)。我们看到高斯分布能很好地近似经验分布。我们还比较了从高斯分布估计的截止值,即公式 (10) 中使用的 $\theta(x,k)$,其中 $k=256$(Cutoff for fitted),与在经验分布上获得前 256 个条目的截止值(Cutoff for empirical)。可以看出这两个截止值很接近。
C.2. 评估期间的稀疏度水平
测试时稀疏度。作为对图 3(训练时稀疏度)的补充,图 C.1 展示了评估期间的稀疏度水平,证实了 Statistical Top-k 在测试时也能产生与训练时相同水平的稀疏度。图中展示了在一些任意选择的 token 上的结果。对于 Attention,特别选择了位置 512、1024 和 2048 的 token,这些都超过了我们为 Spark Attention 选择的 $k=256$。
图 C.1 | Spark FFN 和 Spark Attention 在评估期间的中间激活稀疏性(训练期间的结果见图 3)。对于 FFN,我们使用一个简单的提示“test”,并报告生成第 5、10 和 15 个 token 时的非零条目百分比。对于 Attention,我们报告在 prefill 期间第 512、1024 和 2048 个 token 处的非零条目数。
C.3. 批处理分析
不同批次大小下的性能。图 C.2 对比了 Spark Gemma 2 和 Gemma 2 在不同块大小下的 prefill 吞吐量(tokens/sec)。实验在 16 核 CPU VM 上使用 4096-token 的提示进行。
* Gemma-2: 随着批次/块大小增加,吞吐量显著提升,直到批大小为 8 时达到计算瓶颈。这是因为批处理通过在 CPU 缓存中重用权重来减少内存访问。
* Spark Transformer: 在批大小为 1 时加速最明显。从 1 增加到 2 时,由于高稀疏性导致 top-k 位置重叠少,吞吐量变化不大。当批大小超过 4 后,它也开始受益于权重重用。由于 FLOPs 需求降低,它直到批大小约 64 时才达到计算瓶颈,远晚于 Gemma-2。
* 结论: Spark Transformer 在两种场景下增益最显著:批大小为 1(桌面或移动设备解码的常见设置),以及批大小足够大以至于 FLOPs 成为主要瓶颈的场景。
图 C.2 | Spark Transformer vs. Gemma-2 在不同块大小下的 Prefill Token/秒。我们在 16 核 CPU VM 上使用长度为 4096 token 的提示。
C.4. Spark FFN 中 $r$ 和 $k$ 的影响
超参数消融研究。Spark FFN 有两个超参数:$r$ 控制低成本预测器的秩和 FLOPs,而 $k$ 控制激活稀疏度和 FLOPs。图 C.3 展示了对这两个超参数的消融研究(训练前 5% 的步骤)。
* $r$ 的影响: 如图 C.3a 所示,最佳选择是 $r=1024$,接近 $d_{\text{model}}/2=1152$。
* $k$ 的影响: 如图 C.3b 所示,模型质量对稀疏度在 [5%, 10%] 范围内的选择不敏感,但如果更稀疏(例如 3% 非零),则会出现质量损失。
图 C.3 | Spark FFN 中超参数 $r$ 和 $k$ 对训练损失的影响。应用了 $\sigma = 200$ 的高斯滤波器来平滑损失曲线。模型使用标准 Gemma-2 训练迭代次数的 1/20 进行训练。
C.5. 额外的消融研究
各组件效果分析。图 C.6 展示了 Gemma-2 和 Spark Transformer 在前 80k 训练步的训练损失曲线。Spark Transformer 的损失略高于 Gemma-2,但这种微小差异并未导致评估质量上的显著差异(如图 1b 所示)。
* Spark FFN vs Spark Attention: 实验显示,仅将标准 Attention 替换为 Spark Attention (曲线 "Gemma-2 + Spark Attention") 会比 Gemma-2 带来微小的质量提升。这表明 Spark FFN 的引入是导致模型质量轻微下降的原因。但总体而言,Spark FFN 和 Spark Attention 对质量的影响都不显著。
* 稀疏强制 vs 低成本预测器: 为了解这两个组件的独立作用,实验将 Statistical Top-k 直接应用于标准的 Gated FFN(如公式 (C.2))和标准 Attention(如公式 (C.4)),而不使用低成本预测器架构。结果(曲线 "Gemma-2 + Top-k")显示,训练损失显著增大,且与 Gemma-2 的差距随训练增加而扩大。这表明,低成本预测器不仅是为了降低预测非零项的成本,它还有助于弥补因引入 Statistical Top-k 而带来的质量差距。换言之,在 FFN 和 Attention 中加入低秩预测器使得 Transformer 更能适应激活稀疏化而无质量损失。
图 C.6 | 在前 80k 训练步(总共 500k 步)的训练损失消融研究。
C.6. 计算每 Token 的 FLOPs
FLOPs 节省分析。表 C.1 对比了标准 Transformer 和 Spark Transformer 各主要组件的 FLOPs。在标准 Transformer 中,计算成本主要由 FFN 和 Attention 点积决定。Spark Transformer 通过引入稀疏参数 $k_{\text{ff}}$ 和 $k_{\text{attn}}$ 来减少 FLOPs。对于 Gemma-2B,设置 $k_{\text{ff}} = 8\% \times d_{\text{ff}}$ 和 $k_{\text{attn}} = 256$,可以使 FFN 的 FLOPs 减少 3.2 倍,Attention 点积减少 4 倍,总体(假设 $n_{\text{ctx}}=8k$)减少 2.5 倍。
表 C.1 | 每 Token 的 FLOPs 比较:Spark Transformer vs. 标准 Transformer。在一个模型维度为 $d_{\text{model}}$ 的标准 Transformer 中,我们假设多头注意力的头维度之和等于 $d_{\text{model}}$,FFN 使用非门控激活且宽度为 $d_{\text{ff}}$。这里,$n_{\text{ctx}}$ 代表目标 token 的上下文长度。计算成本主要由 FFN(假设 $d_{\text{ff}} \gg d_{\text{model}}$,这很典型)和注意力点积(假设上下文很长)决定。Spark Transformers 引入稀疏参数 $k_{\text{ff}}$ 和 $k_{\text{attn}}$ 来减少 FLOPs。设置 $k_{\text{ff}} = 8\% \times d_{\text{ff}}$ 和 $k_{\text{attn}} = 256$ 可使 FFN 的 FLOPs 减少 3.2 倍,注意力点积减少 4 倍,总体(假设 $n_{\text{ctx}} = 8k$)对 Gemma-2B 减少 2.5 倍。
D. 补充讨论
D.1. 与 FFN 激活稀疏性相关工作的比较
性能对比。表 D.1 总结了近期在最新 LLM 中实现 FFN 激活稀疏性的工作。可以看出,Spark Transformer 实现了 -72% 的 FLOPs 减少,超过了所有其他方法,而质量损失仅为 -0.9%,低于大多数方法,仅次于 HiRE。HiRE 在质量损失上与我们相当,但 FLOPs 减少较少。
方法分类比较。
* 预训练方法: 从头开始预训练模型时就强制稀疏性,如 ReLUification【索引72,Relu strikes back: Exploiting activation sparsity in large language models+2023+arXiv】。Spark Transformer 属于此类。
* 微调方法: 在预训练好的模型上通过额外训练步骤引入稀疏性,如 ProSparse【索引93,Prosparse: Introducing and enhancing intrinsic activation sparsity within large language models+2024a+arXiv】和 HiRE【索引106,Hire: High recall approximate top-k estimation for efficient llm inference+2024+arXiv】。
* “零样本”方法: 无需额外训练,直接在预训练模型上强制稀疏性,如 CATS【索引56,Cats: Contextually-aware thresholding for sparsity in large language models+2024b+arXiv】。
预训练方法的合理性。虽然 Spark Transformer 需要从头预训练,成本高于微调方法,但对于长期服务而言,其更高的稀疏度和 FLOPs 减少带来的推理效率优势,可以摊销额外的预训练成本。
表 D.1 | 与强制 LLM 中 FFN 激活稀疏性的相关工作比较。Spark Transformer 在质量损失最小之一的情况下,实现了最大的 FLOPs 减少。
D.2. 关于 Statistical Top-k 的补充讨论
相较于现有工作的创新性。
1. 首次应用于激活稀疏性: 我们是第一个将 Statistical Top-k 用于 Transformer 激活稀疏性的工作,这可能促进该领域的发展。
2. 与 Transformer 的非平凡协同: 我们识别出应在非线性函数(FFN)和 softmax(Attention)之前的激活上应用 Statistical Top-k,因为这些激活在随机初始化时理论上服从高斯分布,并凭经验验证了其在训练后仍然可靠。我们从硬阈值扩展到软阈值,改善了优化环境。
3. 理论证明与概念联系: 我们首次为 Statistical Top-k 的正确性提供了理论证明(定理 3.1),并揭示了它与文献中其他 top-k 算子的概念联系(第 3.3 节)。
处理激活分片的情况。当训练大型 Spark Transformer,需要对激活向量进行分片时,可以采用两种方法应用 Statistical Top-k:
* 全局 Statistical Top-k: 每个设备计算本地的均值和方差,然后通过通信聚合得到全局的均值和方差。之后在每个设备上独立执行后续步骤。这种方法输出与不分片时完全相同,且额外的计算和通信成本都很低。
* 局部 Statistical Top-k: 每个设备仅在自己的分片上应用 Statistical Top-k'(其中 $k' = k/D$,$D$ 是设备数)。这种方法是次优的,但没有任何额外开销。
在设备数远小于激活向量长度的情况下,全局方法是自然的选择。
D.3. 与推测解码的协同作用
作为目标模型。推测解码的主要瓶颈是并行验证步骤。Spark Transformer 作为目标模型,可以直接降低这一瓶颈的延迟。尽管并行验证多个 token 会激活比单个 token 更多的神经元,但由于草稿规模小和“聚合稀疏性”(连续 token 激活的神经元重叠度高)现象,激活的并集仍然是稀疏的,从而有效降低了计算成本。
作为草稿模型。一个好的草稿模型需要快速且准确。Spark Transformer 以其接近无损的质量和低推理成本,成为理想的草稿模型。使用 Spark Transformer 作为草稿模型,有望比使用标准蒸馏模型获得更高的 token 接受率,从而最大化推测解码的效率。
D.4. 与量化技术的协同作用
正交且协同的优化。Spark Transformer 和量化是正交且高度协同的优化技术。Spark Transformer 减少操作总数和内存访问,而量化降低单个操作的成本,其综合效益预计是乘法性的。
促进激活量化。我们假设 Spark Transformer 可能比密集模型更能抵抗激活量化带来的精度损失。这源于 Statistical Top-k 的设计:
* 动态范围压缩: 其使用的软阈值操作(输出 $\max(x - \theta(x,k), 0)$)会从所有激活值中减去一个阈值,从而压缩激活张量的动态范围。这有助于缓解 LLM 激活量化中由大幅度“异常值”导致的问题。
* 零点稳定性: 大量(92%)未被选中的神经元被精确地设置为 0,这在任何量化格式中都可以无损表示,避免了“近零”噪声。
对于权重分片,我们预计其敏感性与密集的 Gemma-2 相似。总之,Spark Transformer 不仅与量化兼容,还可能主动促进量化,为进一步加速提供了稳健的路径。
💬 评论讨论
欢迎在这里分享您的想法和见解!