Mixtures of Subspaces for Bandwidth Efficient Context Parallel Training
Mixtures of Subspaces for Bandwidth Efficient Context Parallel Training
作者/机构: Sameera Ramasinghe, Ajanthan Thalaiyasingam, Hadi Mohaghegh Dolatabadi, Gil Avraham, Violetta Shevchenko, Yan Zuo, Chamin Hewa Koneputugodage, Alexander Long, Pluralis Research
主要贡献
本文针对在低带宽网络下进行去中心化上下文并行(Context Parallelism, CP)训练的通信瓶颈问题,提出了一种高效的激活压缩方法。随着大型语言模型(LLM)的上下文窗口不断扩展,CP 训练变得至关重要,但其在每个层、每个步骤中都需要在设备间广播大量的 attention 激活(keys 和 values),这对去中心化环境中的低带宽连接构成了严峻挑战。
核心问题:上下文并行训练需要在所有设备之间进行 all-to-all 通信,这在去中心化训练场景(如普通互联网连接)中会因带宽限制而变得极其昂贵,严重拖慢训练速度。现有方法都假设存在数据中心级别的高速互连。
研究目标:开发一种通信高效的上下文并行训练方法,使其能够在低带宽的去中心化环境中运行,同时保持与高速、集中式环境相当的收敛性能和模型质量。
创新点与核心洞察:
1. 利用激活的内在低秩结构:本文的核心洞察是,Transformer 模型中的 attention 激活(queries, keys, values)本质上存在于一个低维流形上,即具有显著的低秩结构。
2. 动态子空间混合:基于上述洞察,本文提出通过高效的重参数化方法,将激活输出动态地约束到学习到的低维子空间混合中。具体来说,通过将 attention 权重分解为低秩形式 $W = BUU^T$,并联合优化 B 和子空间基 U,从而实现对激活的压缩。
3. 高效的优化方案:为了避免在 Stiefel 流形上直接优化 U 所带来的高昂计算成本(如 SVD/QR 分解),本文提出了一种重参数化方案,使用一个固定的正交基和一个可学习的旋转矩阵 $R(\theta)$ 来表示 U。这种方法在保留优化理论保证(如收敛性)的同时,显著降低了计算开销。
4. 数据自适应压缩:为了提高表达能力,模型为每个输入序列的块(chunk)预测一个唯一的旋转参数 $\theta$,从而生成数据依赖的动态子空间。这通过一个轻量级的线性预测头实现,几乎不增加额外开销。
5. 与标准架构兼容:该方法引入的额外组件(旋转预测头和投影矩阵)可以在训练完成后被移除,使模型恢复为标准的 Transformer 架构。这确保了与现有推理基础设施和部署框架的完全兼容性。
主要成果:本文提出的方法实现了超过 95% 的通信压缩率,且训练开销可忽略不计,模型收敛性没有损失。实验证明,该方法能够将在 300Mbps 带宽网络下训练的十亿参数级别模型的上下文长度扩展到超过 100K tokens,其墙钟收敛速度与在 100Gbps 高速网络下训练的集中式模型相当。
背景知识与关键发现
2.1 上下文并行训练与通信瓶颈
上下文并行训练(CP)的机制。我们首先简要阐述CP训练,并建议读者阅读【11,The llama 3 herd of models,2024,arXiv】以获得更详尽的介绍。Transformer的注意力机制要求每个query与所有的key-value对进行交互,导致计算复杂度随上下文窗口呈二次方增长。这对于长序列来说尤其难以承受,因此并行化策略至关重要。在上下文并行设置中,输入序列 $X \in R^{n \times d}$(其中n是上下文长度,d是模型维度)会沿着上下文维度被划分到m个设备上:
$$X = [X_1^\top \dots X_m^\top]^\top, \quad X_i \in \mathbb{R}^{n_i \times d}, \quad \sum_{i=1}^m n_i = n.$$每个设备 i 为每个注意力头计算局部的 queries、keys 和 values。为清晰起见,我们省略了头的索引;除非另有说明,所有量都被理解为每个注意力头的量:
$$Q_{i}=X_{i} W_{q} \in \mathbb{R}^{n_{i} \times d}, \quad K_{i}=X_{i} W_{k} \in \mathbb{R}^{n_{i} \times d}, \quad V_{i}=X_{i} W_{v} \in \mathbb{R}^{n_{i} \times d} .$$通信瓶颈。在本地计算注意力需要全局访问 keys 和 values:
$$K_{\mathrm{g}}=\left[K_{1}^{\top} \dots K_{m}^{\top}\right]^{\top} \in \mathbb{R}^{n \times d}, \quad V_{\mathrm{g}}=\left[V_{1}^{\top} \dots V_{m}^{\top}\right]^{\top} \in \mathbb{R}^{n \times d},$$通常,CP 会执行某种形式的 all-gather 操作,即每个设备广播其本地的 $K_i, V_i$ 以形成全局的 $K_g, V_g$,这会给每个设备带来 $O(nd)$ 的通信成本,其中 $d \ll n$。最近提出的 Ring Attention 【29,Ring attention with blockwise transformers for near-infinite context,2024,ICLR】在环形拓扑中对此通信进行流水线处理,在每个阶段增量地交换本地的 key-value 块并计算部分注意力。然而,上述方法从根本上依赖于昂贵的大型 K、V 矩阵的通信。
2.2 注意力输出展现出低秩结构
核心观察。我们的压缩方案受到一个观察的启发:预训练 Transformer 的注意力输出位于一个低维流形上。为了支持这一观点,我们分析了公开可用的大规模预训练 LLM 的检查点,并检查了它们的注意力激活。图 1 展示了 LLAMA 70B 的一个示例。
稳定秩分析。具体来说,我们测量了每个注意力层中 query (Q)、key (K) 和 value (V) 激活的稳定秩。一个矩阵 $A \in R^{n \times d}$ 的稳定秩定义为:$srank(A) = \frac{\|A\|_F^2}{\|A\|_2^2}$,其中 $\|A\|_F$ 表示 Frobenius 范数,$\|A\|_2$ 表示谱范数。与传统的矩阵秩(对小扰动和数值噪声高度敏感)不同,稳定秩提供了一个稳健、连续的有效维度度量。这使得它特别适合于表征学习到的神经表示,其中许多奇异值通常很小但由于噪声或过参数化而非零。
低秩现象。如图 1 所示,注意力激活的稳定秩在所有层中都保持较低水平。有趣的是,Q 和 K 的秩通常略低于 V,这表明它们的压缩程度更高。这一观察是我们方法的基础,即利用低秩分解进行高效压缩。附录 B 中提供了其他架构中此现象的进一步证据。接下来,我们将这个想法形式化。
方法
基于低秩结构的压缩。现在我们介绍我们提出的高效上下文并行 Transformer 训练方法。首先,我们形式化地说明了在注意力激活中凭经验观察到的低秩结构如何实现有效压缩。在 §2.2 中,我们看到大型预训练 Transformer 的 Q、K、V 激活表现出显著的低秩结构(图 1)。这一发现意味着,在设备之间仅传输这些激活的低维分量是可行的,从而在实践中实现近乎无损的压缩。形式上,让一个正交矩阵 $U \in R^{d \times r}$(其中 $r \ll d$)的列张成激活的主要子空间。我们不是通信完整的局部激活矩阵 $Z = X^{(i)}W \in R^{n_i \times d}$(其中 $Z \in \{K, V\}$ 表示 key/value 激活, $W \in \{W_k, W_v\}$,$X$ 是注意力层的输入),而是只传输其压缩表示:$Z_{comp} = X^{(i)}W U \in R^{n_i \times r}$。原始激活随后可以在接收节点上重建为:$Z \approx Z_{comp}U^\top$。这种压缩方法保留了由 U 张成的子空间内的所有信息,并且当激活完全位于该子空间时是无损的。等价地,这个投影可以被整合到注意力权重中,并被解释为将它们分解成一个低秩表示:$W = B(U U^\top)$,$B \in R^{d \times d}$。
固定子空间的次优性。上述公式隐含地假设,先验选择的 U 对于每个层、每个块以及每个优化阶段都具有足够的表达能力。很容易看出这个假设在何处会失效。即使存在一个最优的低秩注意力权重矩阵,将权重限制为 $W = BU U^\top$ 的形式会将搜索范围限制在 U 的列空间内。如果这个空间不包含真正的最优解,模型可能会收敛到一个次优解。简而言之,固定 U 可能会阻止模型达到最佳性能。
3.1 在乘积流形上的联合优化
联合优化策略。为了解决固定 U 的局限性,我们提出联合优化分解 $W = B U U^\top$。具体来说,我们同时在乘积流形 $M = R^{d \times d} \times St(n, r)$ 上学习子空间表示 U 和矩阵 B。这里,$B \in R^{d \times d}$ 在标准的欧几里得空间中进行优化,而 U 位于 Stiefel 流形 $St(n, r)$ 上,其更新可以通过黎曼梯度下降自然地执行。以下结果表明,这种联合优化在梯度下降下能实现线性收敛。
收敛性保证。设 $\Phi(W, \vartheta)$ 是一个平滑损失函数,并考虑注意力权重的分解 $W = B U U^\top$,其中 $B \in R^{d \times d}$,$U \in St(d, r)$,$\vartheta \in R^p$ 表示所有其他参数。在乘积流形 $M := R^{d \times d} \times St(d, r) \times R^p$ 上使用黎曼梯度下降最小化重参数化目标 $\Phi(\hat{B}, U, \vartheta) = \Phi(BUU^\top, \vartheta)$,在温和的假设下,可以得到 Q-线性(几何)收敛到一个一阶驻点。正式结果和证明请参见附录中的引理 1。请注意,由于 $\|U\| = 1$,分解后的目标函数仍然是 Lipschitz 平滑的,收敛结果自然地源于欧几里得和黎曼流形上的标准梯度下降理论。我们在附录 A 中提供了一个完整的证明,明确处理了乘积流形结构,并假设了 Polyak–Łojasiewicz (PL) 条件。
3.2 通过 U 的重参数化降低计算成本
重参数化的动机与方法。通过黎曼梯度下降直接在 Stiefel 流形 $St(n, r)$ 上优化 U 提供了强大的理论保证,但成本高昂:每次欧几里得更新后,U 必须被重新正交化(即在流形上的标准“收缩”操作),这通常通过 QR 或 SVD 分解来恢复 $U^\top U = I_r$。为了缓解这个问题,我们使用一个固定的正交基 $\bar{U}$ 和一个可学习的旋转矩阵 $R(\theta) \in O(d)$ 对 U 进行高效的重参数化:
$$U(\theta) = R(\theta) \overline{U},$$其中 $O(d)$ 表示由所有 $d \times d$ 正交矩阵构成的正交流形。如果映射 $\theta \to R(\theta)$ 具有足够的表达能力,旋转矩阵 $R(\theta)$ 可以完全参数化 $O(d)$,从而在保持流形表示能力的同时,显著减少计算开销。
几何结构和驻点的保持。将 U 重参数化为 $U(\theta) = R(\theta)\bar{U}$ 将正交约束转移到了一个无约束的欧几里得变量 $\theta$ 上,从而消除了昂贵的 QR/SVD 步骤,并允许我们在 $\theta$ 空间中运行普通的 SGD/Adam。一个自然的担忧是,这种变量变换可能会扭曲损失函数的景观并阻碍优化;然而,我们证明情况并非如此。链式法则表明 $\nabla_\theta\Phi(\hat{B}, \theta, \vartheta) = D_\theta U(\theta)^\top \text{grad}_U \Phi(B, U(\theta), \vartheta)$,因此 $\nabla_\theta\Phi$ 正是原始黎曼梯度的拉回(pull-back)。因此,一阶临界点保持不变。以下陈述将此结果形式化。
驻点的等价性。在重参数化 $U = R(\theta)\bar{U}$ 下,最小化 $\hat{\Phi}$ 与最小化 $\Phi$ 拥有完全相同的局部最小值和严格鞍点。正式结果和证明请参见附录中的定理 1。因此,我们可以通过旋转有效地隐式表示和优化投影子空间,而不会损害解的质量或最优性。
3.3 降低通信成本
紧凑的旋转表示。重参数化 $U(\theta) = R(\theta)\bar{U}$ 允许我们在每个节点上本地缓存固定的正交框架 $\bar{U}$,并仅传输参数 $\theta$。然而,为了完全参数化正交群 $O(d)$ 中的旋转,通常需要 $\frac{1}{2}d(d-1)$ 个参数,即 $\theta \in R^{d(d-1)/2}$。我们接下来将展示,在实践中对所有可能的旋转进行密集搜索是不必要的。具体来说,我们可以通过控制 $\theta$ 的维度来在搜索空间和通信效率之间取得权衡。
低维李子群。为了获得更紧凑的表示以降低通信成本,我们选择一小组固定的斜对称矩阵 $\{A_1, \dots, A_k\} \subset \mathfrak{o}(d)$,$A_i^T = -A_i$(其中 $\mathfrak{o}(d)$ 表示正交群的李代数),并定义相应的 k 维李子群 【16, Lie Groups, Lie Algebras, and Representations: An Elementary Introduction, 2015, Springer】, 【12, The geometry of algorithms with orthogonality constraints, 1998, SIAM Journal on Matrix Analysis and Applications】:
$$\mathcal{R}_k = \left\{ R(\theta) = \exp \left( \textstyle \sum_{l=1}^k \theta(l) A_i \right) \mid \theta \in \mathbb{R}^k \right\},$$其中 $\theta(l)$ 是 $\theta$ 的第 l 个元素。因为指数映射是 $\theta = 0$ 附近的一个局部微分同胚,对于足够小的 $\|\theta\|$,集合 $U(\theta) = R(\theta)\bar{U}$ 构成了 $St(d, r)$ 的一个 k 维子流形。因此,选择 $k \ll \frac{1}{2}d(d-1)$ 可以在通信成本和表示灵活性之间提供一个有利的权衡。重要的是,我们之前关于不存在伪局部最小值的结论仍然有效,只要一个最优框架 $U^*$ 位于可达流形 $\{R\bar{U} : R \in \mathcal{R}_k\}$ 之内(或足够接近),因为映射 $\theta \to U(\theta)$ 在这个流形上仍然是局部满射的。
3.4 通过逐块旋转实现动态子空间混合
权衡与先验知识。§3.3 指出,旋转维度 k 控制着表示灵活性和通信效率之间的权衡。通过一个精心选择的先验 U,使用较小的 k 变得可行,将优化限制在 U 的一个局部邻域内。
生成先验子空间。我们通过一个短暂的、不压缩的预热阶段来生成这个先验。在此阶段,模型使用较短的上下文长度进行少量迭代(< 500次)训练,以避免通信瓶颈。此阶段后,每个节点计算其局部注意力权重的 top-r 主成分,并将它们存储为固定的子空间基 $\bar{U} \in St(d, r)$。先前关于权重-子空间稳定性的工作(例如,【13, Rank diminishing in deep neural networks, 2022, NeurIPS】, 【18, SP3: Enhancing structured pruning via pca projection, 2024, ACL (Findings)】)的经验证据表明,主导激活子空间在训练早期就会出现,这支持了我们的策略。
逐样本自适应。对所有输入使用单一的全局旋转可能会对异构数据产生欠拟合。为了在不增加 k 的情况下保持表达能力,我们引入了一个轻量级机制来为每个序列块预测一个唯一的旋转参数 $\theta$。回想一下,在上下文并行训练中,每个节点 i 处理来自输入序列的不同块 $X_i \in R^{n_i \times d}$。对于一个注意力输出块 $Z_i = W X_i$,我们采用一个小的线性预测头:$\psi: R^d \to R^k$,$\theta = \psi(Z_{\text{avg},i})$,其中 $Z_{\text{avg},i}$ 是该块的平均注意力输出,从而生成特定于块的旋转参数。给定一组在每个节点上本地缓存的预共享斜对称生成元 $\{A_l\}_{l=1}^k \subset \mathfrak{o}(d)$,我们构造旋转矩阵为:$R(\theta_i) = \exp(\sum_{l=1}^k \theta_i(l)A_l) \in \mathcal{R}_k \subset O(d)$。在本地,keys 和 values 被压缩为 $Z_{\text{comp},i} = Z_i R(\theta_i) \bar{U} \in R^{n \times r}$。然后广播 $(Z_{\text{comp},i}, \theta_i)$,接收节点将 keys/values 重建为:$Z_i \approx Z_{(i, \text{comp})} \bar{U}^\top R(\theta_i)^\top$。注意,峰值内存由注意力计算主导,其规模为 $O(n^2)$,这使得线性头的开销 $O(dk)$ 可以忽略不计——我们也在经验上证明了这一观察。整个过程总结在算法 1 中。
带宽成本。在我们的方法中,每个节点传输 $nr$ 个浮点数(激活)在 $\tilde{Z}$ 中,以及 $\theta$ 中的 k 个额外标量。通常,我们有 $k \ll nr \ll nd$,确保了较低的通信开销。值得注意的是,我们发现即使使用 $k=1$——即一个定义了一个平面的单一旋转角度——也足以保持训练稳定性和输入自适应的灵活性,实现了与固定全局旋转相当的带宽效率。在实现中,我们设置 $S \sim N(0, 1)^{d \times d}$ 为固定的,并定义斜对称生成元 $A := \frac{\theta}{\|S - S^\top\|}(S - S^\top)$,$A^\top = -A$,$\theta \ge 0$。对于 $\theta \in R$,我们设置旋转 $R(\theta) = \exp(\theta A) \in O(d)$,因此 A 固定了旋转平面,而 $\theta$ 设置其大小。
二阶近似。由于 A 是斜对称的,其谱范数满足 $\|A\|_2 = \theta$。对于足够小的 $|\theta| \le \epsilon \ll 1$,旋转矩阵 $R(\theta)$ 允许一个二阶泰勒近似:
$$R(\theta) \approx I+\theta A+\frac{1}{2} \theta^2 A^2.$$这种近似提供了两个关键优势。1) 计算成本:其计算复杂度为 $O(d^2)$,而精确的矩阵指数计算(例如,通过 Padé 或 Schur 分解)的复杂度为 $O(d^3)$。2) 近单位矩阵偏置:它引入了一个有益的近单位矩阵偏置,当 $\theta$ 很小且中心在零附近时(通过裁剪强制实现),它有效地充当了单位矩阵 I 的近似无偏估计量。在这种情况下,高阶项在期望中消失,得到 $E[R(\theta)] \approx I$。这个属性允许旋转保持在初始预热子空间 U 附近,从而促进受控的局部自适应而不会发生显著漂移。通过固定 A 并使用标量 $\theta$,我们实现了 $O(nr)$ 的通信复杂度,远低于朴素的 $O(nd)$。
权重同步。注意力权重仍然需要在设备之间同步,但它们的演化速度远慢于激活 【6, Synchronize only the immature parameters: Communication-efficient federated learning by freezing parameters adaptively, 2023, IEEE Transactions on Parallel and Distributed Systems】, 【5, Using fast weights to attend to the recent past, 2016, NeurIPS】。因此,我们仅每 c 步平均一次相应的权重;在所有实验中,我们使用 c = 200,这带来的通信开销可以忽略不计。
3.5 移除投影组件
动机。我们的方法通过两个非标准组件增强了 Transformer 架构:(i)一个预测旋转参数 $\theta$ 的小型线性旋转头,以及(ii)用于压缩激活的低秩投影矩阵 U。尽管这些组件在训练期间的开销很小,但对于某些下游应用,可能需要与现成的 Transformer 模型保持严格的 API 兼容性。
权重坍缩机制。随着训练的进行,与我们的辅助投影头相关的可学习权重会坍缩到它们所引导的数据相关子空间上。因此,当模型接近收敛时,我们可以完全丢弃这些头,恢复到一个普通的 Transformer,而不会损失在训练期间积累的预测增益。以下结果形式化了这种坍缩机制。
数据依赖投影下“空闲”注意力方向的界限。设样本投影矩阵为 $P(x) = U(x)U(x)^\top$。选择任何其他投影到任意子空间的投影矩阵 Q。定义平均重叠度 $p_Q := E_x[\|P(x)Q\|^2] \in [0, 1]$。使用权重衰减 $\lambda > 0$ 运行随机梯度下降。那么,位于 Q-子空间内的注意力权重满足 $\lim_{t\to\infty} \|W(t)Q\|_F \le \frac{p_Q L}{\lambda}$,其中 L 是损失函数的 Lipschitz 界。因此,如果数据几乎从不激发这些方向($p_Q \ll 1$),相应的权重就会缩减。也就是说,空闲方向被免费剪枝了。正式的定理和证明,请参见附录中的定理 3。
恢复标准架构。一旦权重坍缩到它们的数据对齐子空间上,旋转头及其基矩阵 U 就变得冗余了。因此,我们可以分离这些组件,并进行一个短暂的、低学习率的微调过程来打磨剩余的参数。如图 4 所示,损失曲线在这一转变过程中保持平滑,表明没有引入优化冲击。在推理时,该模型现在与标准 Transformer 无法区分:它不增加任何额外参数,不需要自定义核函数,并且与现有的部署流水线完全兼容。
相关工作
去中心化训练。去中心化学习摒弃了中心协调器,而是依赖于一个由自主设备组成的集体,这些设备通过网状网络合作训练大规模模型。这些设备通常是异构且地理分散的,面临着延迟和带宽不均匀的连接。关于收敛性和鲁棒性的基础理论已由【24, Can decentralized algorithms outperform centralized algorithms? a case study for decentralized parallel stochastic gradient descent, 2017, NeurIPS】, 【22, Decentralized stochastic optimization and gossip algorithms with compressed communication, 2019, ICML】, 【21, A unified theory of decentralized sgd with changing topology and local updates, 2020, ICML】建立,而互补的系统工作则展示了在真实集群上的实际可行性【41, Towards crowdsourced training of large neural networks using decentralized mixture-of-experts, 2020, NeurIPS】, 【7, Distributed deep learning in open collaborations, 2021, NeurIPS】。然而,大多数现有技术仅限于 DDP(分布式数据并行)设置【24, 22, 21, 7】,将模型大小限制在单个节点的总内存内。请注意,这是一个相对研究充分的领域,与我们探索的未被探索的去中心化上下文并行设置是正交的。DDP 领域一个值得注意的工作是 Power Gossip【48, Powergossip: Practical low-rank communication compression in decentralized deep learning, 2020, arXiv】,它用邻近副本之间以 gossip 方式交换信息取代了同步的 all-to-all 通信。其关键洞见是,当每个副本通过本地 SGD 独立训练时,成对的权重差异在一个低秩子空间中演化,使它们能够在 gossip 期间有效地压缩权重差异。另一个有趣的 DDP 方法是 Photon【42, Photon: Federated llm pre-training, 2024, arXiv】,其通信节省主要源于不频繁的梯度交换,而非任何显式的压缩方案。这种跳过同步的方法在上下文并行流水线中是不可行的,因为在每个前向和后向传播中,激活必须在节点之间传输。尽管如此,这些 DDP 风格的技术与我们的方法是正交的,并且可以在混合设置中与我们的方法结合使用。
调度导向的方法。像 SWARM parallelism【39, Swarm parallelism: Training large models can be surprisingly communication-efficient, 2023, ICML】和 Tasklets【53, Decentralized training of foundation models in heterogeneous environments, 2022, NeurIPS】这样的调度导向方法缓解了掉队者效应和网络随机性,但它们仍然继承了去中心化设置固有的通信开销。相比之下,我们引入了第一个专为 CP 定制的通信压缩策略,消除了一个关键的带宽瓶颈,这个瓶颈迄今为止阻碍了去中心化模型在更长上下文窗口上的扩展。
上下文并行注意力。对于单设备长序列处理,稀疏近似(如 BigBird【54, Big bird: Transformers for longer sequences, 2020, NeurIPS】)将注意力复杂度降低到 $O(n)$,而 IO 感知的精确核函数(如 FlashAttention【33, Faster causal attention over large sequences through sparse flash attention, 2023, arXiv】)通过分块和片上缓存最大化了硬件吞吐量。最近的系统研究将序列维度本身并行化【25, Torchtitan: One-stop pytorch native solution for production ready llm pre-training, 2024, arXiv】, 【37, Deepspeed: System optimizations enable training deep learning models with over 100 billion parameters, 2020, KDD】, 【15, The llama 3 herd of models, 2024, arXiv】。Blockwise Parallel Transformers【28, Blockwise parallel transformers for large context models, 2023, NeurIPS】通过重叠计算和 ring-all-reduce,在 32K token 的序列上实现了近线性的加速,而 RingAttention【29, Ring attention with blockwise transformers for near-infinite context, 2024, ICLR】通过流水线化的块交换将这一思想扩展到几乎无限的上下文。然而,这些方法仍然广播完整的 key/value 张量。我们的方法则传输一个紧凑的低秩表示和一个轻量级的旋转参数,在保留精确注意力语义的同时减少了带宽,从而补充了现有的上下文并行框架。
实验环境
我们在三个大规模语料库上评估仅解码器模型——FineWeb (FW)【30, Pointer sentinel mixture models, 2016】,C4【35, Exploring the limits of transfer learning with a unified text-to-text transformer, 2019, arXiv】和 BookCorpus (BC)【56, Aligning books and movies: Towards story-like visual explanations by watching movies and reading books, 2015, ICCV】。对于每个数据集,我们保留了10%的训练集用于验证。所有模型骨干都遵循 LLAMA 3【11, The llama 3 herd of models, 2024, arXiv】;具体的模型规格在相应章节中给出。我们使用基础学习率 = $3 \times 10^{-4}$,并采用线性预热和衰减,应用权重衰减 = 0.01。除了最后一个块外,每个 transformer 层的 K 和 V 投影分别被压缩了 98% 和 95%(总体压缩率 96.5%),这是通过相对于 d 适当选择 r 来实现的。我们对所有模型使用 GPT2 tokenizer。
实验结果
带宽效率在去中心化设置中的表现
我们在两种网络设置下训练一个 8 层、8 亿参数的模型(嵌入大小 = 2048,注意力头 = 8):一个集中的 100Gbps 网络和一个去中心化的 300Mbps 互联网级链路。我们使用上下文并行(CP)在八个 A100 GPU 上处理 132K token 的序列长度,这些 GPU 分别以相应的带宽连接。图 2 显示,在 300Mbps 链路上使用普通 CP 比在集中的 100Gbps 网络上慢了 20 倍以上。通过我们的压缩技术,同样的 300Mbps 设置的收敛速度几乎与集中式基线一样快。
验证结果。表 2 报告了训练后模型的测试性能。为此,我们遵循 Chinchilla 缩放定律【17, Training compute-optimal large language models】将每个模型训练到其计算最优点。具体来说,对于我们 8 亿参数的模型,我们使用了 1:20 的模型-token 比例,并在每个数据集上训练了 160 亿个 token。值得注意的是,我们的压缩去中心化模型在相同的训练迭代次数下,其困惑度(perplexity)与集中式模型相当,甚至略优,同时在普通链路上比原生(未压缩)CP 提供了显著更高的吞吐量。在低带宽链路上将未压缩模型训练至完成在计算上是不可行的(估计超过 150 天),因此我们仅报告其在此设置下的吞吐量(TPS)。
消融实验
设计消融。我们在八个 A100 GPU 上对 8 亿参数模型进行消融实验,上下文长度为 132K(见表 1)。使用学习到的旋转的模型性能优于使用固定或随机投影的模型。二阶指数近似不影响性能,证实了其充分性。省略主方向的预热初始化($\bar{U}$)会明显降低结果,突显了此先验的重要性。
扩展性。我们的基于压缩的 CP 方案具有良好的扩展性,可以与其他并行训练策略无缝融合。我们将模型扩展到 32 层(30 亿参数),在 32 个 A100 上同时启用流水线并行和 CP(图 3),并实现了显著的吞吐量增益。
重参数化。我们方法的一个关键步骤是重参数化 U,这绕过了昂贵的黎曼操作(QR/SVD 回调)。如表 4 所示,这种重参数化显著提高了吞吐量(TPS)。更多关于架构选择的消融实验见附录 C。
预热步数。为了衡量预热步数的影响,我们进行了一项消融研究,改变了预热时长并评估了在 FineWeb 数据集上的困惑度。结果如表 3 所示。如图所示,即使预热步数减少到 300 步,模型也取得了相当的性能,表明没有显著的性能下降。在实践中,我们默认使用 500 步来提供一个安全稳定的基线。这项研究进一步强调了我们预热策略的轻量级和鲁棒性,特别是与现代 LLM 预训练中常用的更复杂的调度机制相比。请注意,困惑度的差异很小且稳定,表明在 300 步预热后性能是稳定的。
移除投影和旋转头
如 §3.5 所述,从业者可能希望在预训练后恢复到标准的 Transformer 架构,以兼容下游框架。我们凭经验验证了我们的理论预测,即注意力权重会逐渐与投影子空间对齐,从而允许在训练末期安全地移除投影层和旋转头。图 4 显示,在训练后期移除这些组件可以保持收敛,而过早移除会中断训练。
与基线方法的比较
由于目前没有针对 CP 压缩的基线方法,我们构建了两个:(i)稀疏化——一个 Top-10% 方案(90% 压缩),只传输 K、V 块中幅度最大的条目,灵感来自常见的 DDP 压缩方法;(ii)量化——在传输前对 K、V 激活进行 4-bit 量化(75% 压缩),遵循激活压缩的标准实践。如图 5(左)所示,即使我们使用了更激进的 96.5% 压缩率,我们的方法也全面优于这些基线(132k 上下文窗口)。
为求全面,我们还与长上下文模型 BigBird【54, Big bird: Transformers for longer sequences, 2020, NeurIPS】和 CosFormer【34, cosformer: Rethinking softmax in attention, 2022, arXiv】进行了比较,它们并非为 CP 设计,在 A100 上最多只能处理 32K token。为了公平比较,我们将我们的压缩应用于跨四个 GPU 的 CP,每个 GPU 处理 8K token。如图 5(右)所示,两个基线的收敛性都明显差于我们的方法。所有实验都是在 8 亿参数模型上进行的。
结论
我们提出了第一种压缩方法,使得在具有低带宽互连的去中心化环境中能够对语言模型进行上下文并行训练。我们的方法支持在通过互联网级链路(例如 300Mbps)连接的独立 GPU 上,使用超过 100K token 的上下文长度进行训练,同时其墙钟收敛速度与具有高速(100Gbps)连接的集中式系统相当。此外,我们的方法通过允许在训练后移除投影层,保持了与标准 Transformer 架构的兼容性,从而便于在下游框架中无缝部署。我们对我们方法的关键特性进行了理论分析,并通过广泛的经验评估验证了其有效性。
补充细节
局限性。我们的压缩方法在上下文并行训练中实现了近乎无损的收敛,但仍有几个悬而未决的问题。首先,除了简单的子空间旋转之外,其他的重参数化方法可能会带来进一步的准确性或效率提升。其次,该方法即使在搜索空间被大幅缩减(通过非常低维的 $\theta$)的情况下仍能找到好的最小值的惊人能力,缺乏一个严谨的解释;其与最近关于隐式正则化和彩票假设(lottery-ticket-style)现象的研究之间的联系值得更深入的探讨。尽管存在这些不足,这项工作为上下文并行压缩建立了第一个基准,我们希望它能激发更深入的理论和实证探索。
💬 评论讨论
欢迎在这里分享您的想法和见解!