Virtual Width Networks
Virtual Width Networks
作者/机构: 1ByteDance Seed (贡献者见文末)
A1 主要贡献
核心问题: 现有研究表明,根据缩放定律【15, Training compute-optimal large language models, 2022, arXiv】, 【20, Scaling laws for neural language models, 2020, arXiv】,增加模型宽度(隐藏维度)可以显著提升模型性能,因为它允许在每个向量中编码更丰富的信息。然而,简单地增加隐藏维度会导致参数和计算量呈二次方增长,这在资源受限的情况下构成了重大挑战。混合专家(MoE)架构虽然扩展了模型容量,但其主干网络的隐藏维度仍然是固定的,这成为了模型表征能力的瓶颈。本文旨在解决一个核心问题:我们能否在获得更宽表示所带来的好处的同时,避免传统宽度缩放所导致的二次方成本爆炸?
研究目标: 本文的目标是提出一个通用框架,该框架能够扩展词元嵌入的宽度,同时保持Transformer主干网络的隐藏维度不变。通过解耦表示宽度和主干网络宽度,研究旨在以接近恒定的计算成本来扩展嵌入空间,从而提升模型的训练效率和性能。此外,研究还旨在探索虚拟宽度缩放作为一种新的大模型效率提升维度,并寻找其与模型性能之间的缩放关系。
创新点:
本文的主要贡献如下:
* 虚拟宽度网络 (Virtual Width Networks, VWN): 提出了一种名为VWN的新框架,它通过广义超连接(Generalized Hyper-Connections, GHC)将嵌入宽度与主干网络宽度解耦,实现了以最小的额外计算成本进行r倍的虚拟宽度扩展。通过系统的扩展实验,本文发现并揭示了虚拟宽度因子r与损失之间存在一个对数-线性(log-linear)的缩放定律,为虚拟宽度如何影响模型性能提供了新的见解。
-
广义超连接 (Generalized Hyper-Connections, GHC): 本文形式化了GHC,它是一个统一的公式,涵盖了之前的变体(如超连接和分数连接),并提供了在虚拟隐藏状态和主干网络隐藏状态之间进行灵活路由和混合的机制。
-
与多词元预测 (Multi-Token Prediction, MTP) 的协同作用: 本文证明了VWN与MTP能够产生协同效应,在下游任务的准确性上取得了一致的改进。直观上,MTP更密集的监督信号能够充分利用扩展的虚拟空间,而VWN提供的额外表示自由度则改善了模型的短程组合建模能力。
图 1 使用虚拟宽度网络(VWN)在一个33亿激活参数的MoE模型上进行的大规模实验结果。我们比较了基线MoE-A3.3B和配置了虚拟宽度因子r=8的MoE-A3.3B-VWNx8。左图和中图:针对下一词元和下两词元预测的训练损失与已见词元数的关系。VWN分别使用比基线少2.5倍和3.5倍的词元达到了相同的损失。右图:在一系列开源基准测试上的平均准确率(见表3),分数使用内部定义的任务权重进行聚合。在这种加权方案下,1个点的差异对应着显著的性能差距。
图 2 标准Transformer与虚拟宽度网络(VWN)的对比。(a) 标准Transformer的嵌入和主干网络使用相同的宽度。(b) 简单的宽度缩放会按比例扩展两者,导致参数和计算量二次方增长。(c) VWN将嵌入宽度与主干网络宽度解耦。通过广义超连接,超宽嵌入(例如1.5倍)可以与标准宽度的主干网络相结合,以最小的计算开销增加表示能力。
A3 背景知识
2 相关工作
模型容量的扩展。Transformer模型通过增加模型宽度、深度和数据规模展现了强大的性能提升【15, Training compute-optimal large language models, 2022, arXiv】, 【20, Scaling laws for neural language models, 2020, arXiv】。然而,增加隐藏维度通常会导致参数和计算量呈二次方增长,在资源受限的环境中构成挑战。为将模型容量与计算量解耦,研究者们提出了几种方法。例如,混合专家(MoE)模型【9, Switch transformers: Scaling to trillion parameter models with simple and efficient sparsity, 2022, Journal of Machine Learning Research】, 【22, Gshard: Scaling giant models with conditional computation and automatic sharding, 2020, arXiv】, 【27, The sparsely-gated mixture-of-experts layer, 2017, Outrageously large neural networks】通过条件性地激活子网络来有效地扩展模型规模。我们的方法在增加有效容量的同时,避免了通常与加宽隐藏维度相关的二次方计算成本。这是通过将嵌入宽度与主干网络的隐藏大小解耦来实现的。
超连接与分数连接。超连接(Hyper-Connections, HC)【36, Hyper-connections, 2024, arXiv】和AltUp【2, Alternating updates for efficient transformers, 2023, Advances in Neural Information Processing Systems】通过跨层的低成本组合链接来扩展隐藏维度,从而增强模型的表达能力。然而,在大的隐藏空间中,HC通常未能充分利用扩展的表示,因为每个扩展部分仅由少数标量权重更新,难以完全发掘额外的容量。分数连接(Frac-Connections, FC)【37, Frac-connections: Fractional extension of hyper-connections, 2025, arXiv】采取了相反的策略:它们不是扩大隐藏大小,而是将现有的隐藏维度划分为多个更小的段,从而在不增加模型宽度的情况下实现类似HC的连接性。我们提出的广义超连接(GHC)整合了两者的优点——既扩展了隐藏维度,又将其进一步细分为结构化的子状态。这种设计为容量使用提供了细粒度的控制,并能更有效地利用扩展的表示空间。此外,它还引入了额外的灵活性:隐藏维度的扩展比例不必是整数倍,这种分数扩展已被经验性地验证为有效(见5.1节)。
图 3 虚拟宽度网络(VWN)概览。(a) 标准Transformer在输入嵌入、每层的中间隐藏向量和最终层输出之间保持一致的宽度。(b) VWN通过超宽嵌入来扩展嵌入维度,同时使用轻量级的广义超连接(GHC)来维持层维度。这些维度通过小的矩阵Al和Bl(l代表层号)进行灵活交互。(c) 我们启用了多词元监督(多词元预测),从而允许更丰富的词元表示。
嵌入扩展。近期的研究强调了词汇表扩展在大型语言模型中的重要性【29, Scaling laws with vocabulary: Larger models deserve larger vocabularies, 2024, Advances in Neural Information Processing Systems】,表明扩大输入词汇表——特别是通过分层的n-gram词元嵌入——能够系统性地提高模型的表达能力和训练效率,且计算开销可忽略不计【17, Over-tokenized transformer: Vocabulary is generally worth scaling, 2025, arXiv】。“超词元化Transformer”框架【17, Over-tokenized transformer: Vocabulary is generally worth scaling, 2025, arXiv】引入了超编码(Over-Encoding, OE)来利用多n-gram词元化扩展输入表示,并引入了超解码(Over-Decoding, OD)通过多词元预测目标来增强输出监督。值得注意的是,多词元预测(Multi-Token Prediction, MTP)【10, Better & faster large language models via multi-token prediction, 2024, Forty-first International Conference on Machine Learning】被认为是OD在实际训练中的一种有效实现。
A2 方法细节
3.1 重新思考模型宽度
标准Transformer的宽度与计算成本。在一个具有L层和模型宽度D的标准Transformer模型中,初始词元表示$h_0 \in \mathbb{R}^D$是通过嵌入查找获得的。这个表示随后经过多个Transformer层处理,每个层由一个注意力块和一个前馈网络(FFN)块组成。具体来说,在第l层,中间隐藏向量$h_l \in \mathbb{R}^D$由$h_{l-1}$计算得出。最后一层输出词元表示$h_L \in \mathbb{R}^D$,然后通过一个线性头将其投影到维度为|V|的词汇空间。Transformer的计算复杂度为$O(D^2)$,这意味着扩展模型宽度D会导致计算成本呈二次方增长。
解耦嵌入维度与隐藏层维度。然而,嵌入查找操作只占总计算成本的一小部分。利用这一洞察,我们将嵌入维度与隐藏层维度解耦,使得嵌入维度可以被显著扩展,同时保持中间层计算的原始隐藏维度。因此,这种方法在显著增强词元嵌入表示能力的同时,几乎保持了原有的计算成本。
3.2 超宽嵌入
提出超宽嵌入技术。为了增加嵌入维度,我们提出了超宽嵌入(Over-Width Embedding)技术。给定一个固定的隐藏大小D,我们将输入的嵌入维度扩大到一个更宽的维度D',从而在不显著增加计算开销的情况下获得更丰富的词元嵌入。
形式化定义。形式上,令$h^l \in \mathbb{R}^D$表示第l层的隐藏状态。我们将这个隐藏向量均匀地划分为m个不相交的段:
$\mathbf{h}^{l} = (\mathbf{h}^{l \mathsf{T}}_{1} \ \mathbf{h}^{l \mathsf{T}}_{2} \ \ldots \ \mathbf{h}^{l \mathsf{T}}_{m})^{\mathsf{T}}, \text{ where } \mathbf{h}^{l}_{k} \in \mathbb{R}^{D/m}, k=1,2,\ldots,m.$
定义扩展嵌入向量。接下来,我们定义一个扩展的嵌入向量$e \in \mathbb{R}^{D'}$,其中$D' = \frac{n}{m} D$,且n > m为整数:
$\mathbf{e} = (\mathbf{e}_1^\mathrm{T} \ \mathbf{e}_2^\mathrm{T} \ \ldots \ \mathbf{e}_n^\mathrm{T})^\mathrm{T}$, where $\mathbf{e}_k \in \mathbb{R}^{D'/n}$ with $D' = \frac{n}{m}D.$
应用宽嵌入。最后,在输入层,我们设置$h'^0 = e$,从而利用更宽的词元嵌入。
可选的线性投影。当扩展比率n很大时,可以选择使用单个线性投影将原始的1倍嵌入映射到更宽的维度:
$\mathbf{E}_{\text{wide}} = \mathbf{W}_{\text{expand}} \mathbf{E}_{\text{base}};$
这类似于对一个非常宽的嵌入表应用低秩分解。此外,可以采用输入增强策略【17, Over-tokenized transformer: Vocabulary is generally worth scaling, 2025, arXiv】,通过在每个输入中注入比单个孤立词元嵌入更多的信息来进一步丰富加宽的表示。
Unembedding(逆嵌入)。对于逆嵌入过程,模型需要在进入逆嵌入层之前,将最后的超宽隐藏状态映射回原始的隐藏宽度D。我们引入一个规约算子$W_{reduce} \in \mathbb{R}^{D \times D'}$,它执行从超宽维度D'到原始宽度D的线性投影:
$\mathbf{h}_{\text{reduce}}^L = \mathbf{W}_{\text{reduce}} \mathbf{h}'^L.$
训练稳定性。为了稳定训练,在规约算子之前应用了归一化,如图3(c)所示。当扩展比率r=n/m很大时,超宽维度D'可能会变得非常大(例如,将4096维的隐藏大小扩展8倍会产生32K维的表示)。我们不直接对所有D'维度进行归一化,而是采用分组归一化(Group Normalization)【31, Group normalization, 2018, Proceedings of the European conference on computer vision (ECCV)】,其中组大小等于原始隐藏大小D。
3.3 广义超连接
提出广义超连接 (GHC)。我们提出了广义超连接(GHC),这是一种新颖的方法,可以有效地利用更宽的词元嵌入,同时在中间层计算中保持原始的隐藏维度。具体来说,在每一层l,GHC引入一个轻量级的转换矩阵$\mathcal{GHC}^l$,该矩阵编码了原始隐藏表示的段与扩展词元嵌入之间的加权关系。形式上,该矩阵定义如下:
GHC的输出。考虑第l个网络层$\mathcal{T}^l$,它在Transformer中集成了自注意力层或前馈网络。GHC的输出表示为$H'^l = \text{Reshape}(h'^l, (n, D'/n))$,代表超宽隐藏状态(Over-Width Hidden States),其公式可以表示为:
动态广义超连接 (DGHC)。为了进一步增强前向过程的适应性,我们引入了GHC方法的动态扩展,称为动态GHC(DGHC),其中转换矩阵根据输入表示H'进行自适应条件化:
DGHC的实践。在实践中,我们采用了Zhu等人【36, Hyper-connections, 2024, arXiv】, 【37, Frac-connections: Fractional extension of hyper-connections, 2025, arXiv】提出的混合策略,该策略集成了静态和动态参数,并做了一些微调以更好地适应我们的VWN框架。动态参数通过一个轻量级的线性投影网络生成。为保证训练稳定性,输入特征首先被归一化。随后,应用一个线性变换和一个tanh激活函数。输出然后被一个小的、可学习的矩阵缩放,并与相应的静态矩阵结合:
$\overline{\mathrm{H}^{\prime}}=\operatorname{norm}\left(\mathrm{H}^{\prime}\right),$
$\mathcal{A}(\mathbf{H}') = \mathbf{S}_\alpha \circ \tanh \left(\frac{\overline{\mathbf{H}'\mathbf{W}_\alpha}}{\tau}\right) + \mathbf{A}.$
其中 $\tau = \sqrt{D/m}$,$S_{\beta} \in \mathbb{R}^{m \times n}$ 和 $S_{\alpha} \in \mathbb{R}^{n \times (m+n)}$ 是可学习的缩放矩阵,初始化为1(分别与B和A形状相同)。令 $d_b := D'/n = D/m$ 表示每个块的宽度,并将 $H'$ 视为一个 $n \times d_b$ 的矩阵。投影权重 $W_{\beta} \in \mathbb{R}^{d_b \times m}$ 和 $W_{\alpha} \in \mathbb{R}^{d_b \times (m+n)}$ 是生成动态系数的可学习参数。根据这些形状,$H'W_{\beta} \in \mathbb{R}^{n \times m}$ 且 $H'W_{\alpha} \in \mathbb{R}^{n \times (m+n)}$;方程(9)中的转置使其变为 $m \times n$ 以匹配 $B \in \mathbb{R}^{m \times n}$,而方程(10)已经与 $A \in \mathbb{R}^{n \times (m+n)}$ 对齐。
初始化和实现。方程(9)和(10)中的动态参数$W_\beta$和$W_\alpha$初始化为0,而静态矩阵则按如下方式初始化。值得注意的是,我们没有对初始化进行任何专门的调优,因此仍有提升学习效率的空间。
静态矩阵B的初始化。静态矩阵 $B \in \mathbb{R}^{m \times n}$ 以循环模式初始化:
静态矩阵A的初始化。静态矩阵 $A \in \mathbb{R}^{n \times n}$ 初始化为块矩阵:
权重衰减策略。静态组件B和A不使用权重衰减,而动态组件则使用。实现细节可在附录B中找到,算法见算法1。
算法1 虚拟宽度网络(VWN)前向传播
需要: 超宽词元嵌入 $e \in \mathbb{R}^{D'}$
需要: 分数率 m,扩展宽度 n,主干网络维度 D
需要: 网络层 $\{T^1, \dots, T^L\}$ 和路由矩阵 $\{A^l, B^l\}_{l=1}^L$
需要: 压缩矩阵 $R \in \mathbb{R}^{n \times m}$
确保: 最终输出 y
- 初始化:
- $H'_0 \leftarrow \text{Reshape}(e, (n, D'/n))^\intercal \in \mathbb{R}^{D'/n \times n}$
- for $l = 1$ to $L$ do
- $X^l \leftarrow A^l H'_{l-1}$
- $z^l \leftarrow T^l(\text{Reshape}(X^l, (D, )))$ ▷ Transformer块中FFN或注意力层的输入
- $Z^l \leftarrow \text{Reshape}(z^l, (m, D/m))^\intercal$
- $H'_l \leftarrow B^{l\intercal} Z^l + \hat{A}^{l\intercal} H'_{l-1}$
- end for
- $h^L \leftarrow \text{Linear}(\text{GroupNorm}(H'_L))$
- $y \leftarrow \text{Unembedding}(\text{Norm}(h^L))$
- return y
3.4 多词元预测
利用MTP增强输出层。对于输出层,先前的研究【17, Over-tokenized transformer: Vocabulary is generally worth scaling, 2025, arXiv】已经证明多词元预测(MTP)可以作为k-gram解码的近似。基于这一见解,我们利用MTP来提供细粒度的监督信号,通过在主干模型之上增加额外的VWN层来构建一个增强的预测头。具体来说,我们遵循DeepSeek-AI等人【7, Deepseek-v3 technical report, 2025, https://arxiv.org/abs/2412.19437】的做法,将下一词元的嵌入与前一词元最后一层的嵌入连接起来,并应用一个线性投影来生成logits,如图3(c)的上半部分所示 。
解决MTP在VWN中的成本问题。然而,像DeepSeek-AI等人【7, Deepseek-v3 technical report, 2025, https://arxiv.org/abs/2412.19437】那样采用单个密集的线性层来混合隐藏状态和嵌入(即一个2D→D的投影)在VWN下会变得极其昂贵,因为宽度被扩展了r倍。一个简单的密集混合层将扩展为2rD→rD;对于r=8的情况,参数和FLOPs都会大幅增长,难以承受。为了解决这个问题,我们使用一个块级线性层进行混合。我们将rD维的向量划分 为n = r × m个大小为D/m的段,并对每个段应用相同的、形状为(2D/m)→(D/m)的小型线性层。换句话说,我们在每个段内局部地融合隐藏状态和嵌入特征,并在所有块之间共享这个线性投影器。这保留了更宽的VWN表示的好处,同时使混合成本与r=1的情况相当。
3.5 成本分析
计算成本。VWN的理论计算开销相对较低。我们关注主要的计算成本。归一化操作(例如RMSNorm)每个词元需要4nmD FLOPs。计算动态参数A和B每个词元需要2(2m+n)nmD FLOPs。宽度连接的成本为2(m+n)nD FLOPs,深度连接需要2nD FLOPs。在m=2和n=3的适度设置下,归一化、动态参数计算和宽度连接步骤总共需要42D FLOPs,而深度连接需要6D FLOPs。这些计算成本对于基于GPU的训练/推理系统来说是次要的,特别是考虑到激活内存访问带来的I/O开销,这成为VWN的瓶颈。为了最小化I/O,归一化、动态参数计算和宽度连接操作被融合成一个单一的GPU核。此外,宽度连接可以与Transformer层中后续的层归一化融合。当m较小时,由于超宽隐藏状态的存在,VWN大约增加了n-1倍的层归一化和残差加法的成本。在这种设置下,这个开销可以忽略不计,但对于较大的m,其影响会随配置而变化。
内存成本。在模型训练期间,必须存储中间激活以进行反向传播。VWN引入了额外的内存开销用于保存VWN的输入激活。然而,这可以通过廉价的重计算来缓解。在一个典型的训练框架如Megatron-LM中,采用选择性激活重计算【21, Reducing activation recomputation in large transformer models, 2023, Proceedings of the 6th MLSys Conference】,一个普通Transformer层中的每个词元需要34D字节的激活存储。VWN主要增加了保存A和B输入的成本,需要2 × 2 × (nm + 1)D字节,假设每个数字用2字节表示(16位浮点数),并且每个Transformer层有两个宽度和深度连接。虽然注意力和FFN的输入通常会为权重梯度计算而保存,但它们可以从宽度连接中高效地重计算出来。通过保存在宽度连接中A的输入和在深度连接中B的输入,后续的宽度连接输入可以以低成本重计算。用一个因子η表示被保存的宽度连接输入的比例,VWN对一个Transformer层的额外激活内存消耗为4ηnmD字节。在m=2,n=3和η=0.5(保存注意力的宽度连接输入,并为FFN重计算它)的适度设置下,增加的内存消耗为3D字节,约占普通Transformer层内存足迹的8.8%。在模型推理期间,额外的内存开销仅来自于额外的参数,与其他内存消耗相比可以忽略不计。
4. 连接性视角
VWN作为深度轴上的注意力机制。我们从连接性的角度重新诠释虚拟宽度网络(VWN),将其视为沿深度轴的注意力机制。将层堆栈看作一个“深度序列”,其中每个层索引就像一个词元位置,而隐藏状态则充当“垂直KV缓存”。在这种视角下,常见的连接模式可以映射到对先前层的类似注意力的窗口:(1) 一个没有残差连接的普通前馈堆栈对应于大小为1的滑动窗口(每层只处理当前输入并忘记前一个);(2) 残差连接【11, Deep residual learning for image recognition, 2016, Proceedings of the IEEE conference on computer vision and pattern recognition】实现了一个大小为2的窗口(当前输入加上紧邻的前一个);(3) 密集连接【16, Densely connected convolutional networks, 2017, Proceedings of the IEEE conference on computer vision and pattern recognition】, 【23, Denseformer: A dense transformer framework for person re-identification, 2023, IET Computer Vision】, 【32, Muddformer: Breaking residual bottlenecks in transformers via multiway dynamic dense connections, 2025, arXiv】将窗口大小扩展到包含所有先前的层,允许每层重用所有先前的表示。带有广义超连接(GHC)的VWN则介于两者之间:它实现了一种学习到的、固定成本的、类似线性注意力的深度机制,从而扩展了可访问的深度上下文。
GHC的展开形式。形式上,令第l层的加宽状态为一个槽矩阵 $H'_l \in \mathbb{R}^{(D/m) \times n}$,它有n个大小为D/m的槽,并令 $r := n/m$ 为以D为单位测量的宽度扩展。GHC的递推关系在方程(6)中明确给出了主干映射:$H'_l = B^{l\intercal} \mathcal{T}^l(\mathring{A}^{l\intercal} H'_{l-1}) + \hat{A}^{l\intercal} H'_{l-1}$,其中 $\hat{A}^{l\intercal}$ 传输/衰减存储在槽中的信息(一个学习到的携带/遗忘算子),而 $B^l$ 将当前层的主干摘要写入选定的槽中。将方程(6)显式展开得到:
按照惯例,空乘积等于单位矩阵。方程(13)显示,$H'_l$ 线性地聚合了来自早期层的、经过主干网络转换的特征,这些特征通过“携带算子”$\hat{A}$传播,并在每一步通过B写入——这捕捉了对一个压缩的深度缓存进行线性注意力的精髓。
选择m的考量。用于存储深度信息的内存预算——以D为单位衡量——是 $r=n/m$。GHC将此预算分配在每层保真度和记忆的层数之间:
* 当m=1时,模型以完整的D维保真度存储多达r层(层数较少,每层带宽较高)。
* 当m>1时,模型存储多达n=rm层,每层被压缩到D/m维度(层数较多,每层带宽较低)。
因此,m控制每层的压缩率,n控制名义上的深度窗口,而r固定了总内存预算。然后,学习到的、依赖于输入的路由通过衰减而不是硬截断,提供了一个超越名义窗口的软扩展。
m的选择直觉。直观上,更大的m以较低的每层保真度为代价,扩展了有效记忆的层数。对于更宽的模型,增加的表示能力提供了足够的带宽来容纳更大的m。同样,更深的网络也从更大的m中受益,因为使每层能够访问更长范围、更浅层的信息可以缓解优化困难并改善梯度流。
硬深度窗口与软深度窗口。
* 硬路由。如果 $\hat{A}^l$ 和 $B^l$ 近似于置换/二元门,则更新行为类似于一个固定大小的深度滑动窗口。当m=1时,有r=n个维度为D的槽,因此模型可以全保真度地保留最近的r层。当m>1时,有n=rm个大小为D/m的槽;每层的D维状态被压缩到D/m并写入一个槽中,从而形成一个大小为n的压缩形式的硬窗口。
* 软路由。使用实值的、可能依赖于输入的 $\hat{A}^l$ 和 $B^l$(动态GHC)时,信息在步骤间被部分保留和混合。当 $\hat{A}^l$ 的谱半径小于1时,方程(13)意味着来自前序层的贡献呈指数衰减。有效的深度感受野可以超过名义上的硬窗口(m=1时>r,m>1时>n),尽管信息会逐渐被衰减和混合。
一个具体配置示例。考虑(m, n)=(8, 64),因此r=8。模型维护n=64个宽度为D/8的槽。在硬路由下,当前层可以利用最近的64层,每层以原始维度的1/8表示。在软路由下,早于64层的贡献可能会以衰减的方式持续存在,有效地扩大了“深度感受野”。
关于注意力类比的范围。我们与注意力的类比主要借用了沿深度的KV缓存视角。这并不意味着层间连接是通过标准自注意力中的相似性得分或成对相关性来构建的。GHC使用学习到的(静态或输入条件的)路由矩阵以固定成本在各层之间携带、压缩和写入信息,而不是计算点积得分或对层索引进行softmax。
A4 实验
实验环境
- 模型: 实验主要在内部开发的混合专家(MoE)模型上进行,涵盖多种规模,包括MoE 0.4B/4B、MoE 2.5B/25B、MoE-A0.8B(8亿激活参数)和MoE-A3.3B(33亿激活参数)。
-
数据集: 所有模型均在内部大规模数据集上进行预训练。下游性能评估使用了两个基准集合:
- Collection A: 用于1.5倍虚拟宽度实验,包含ARC_Challenge, BBH, DROP, WinoGrande, Hellaswag, MMLU, MMLU-Pro, C-Eval, TriviaQA, Ape210K, GSM8K, MATH, MBPP, HumanEval, AGIEval, GPQA等基准(见表2)。
- Collection B: 用于大规模虚拟宽度实验,包含MMLU, MMLU-Pro, C-Eval, AGIEval, BBH, DROP, KOR-Bench-Easy, MATH, MBPP+, HumanEval, McEval, TriviaQA, Chinese SimpleQA等基准(见表3)。
-
硬件配置: 未明确说明具体的GPU型号和数量,但提及了基于GPU的训练系统,暗示使用了大规模GPU集群。
- 软件配置: 训练框架提及了Megatron-LM。VWN的实现细节在附录中以PyTorch风格的伪代码给出。
实验结果
5.1 VWN 1.5倍
我们联合评估了VWN和多词元预测(MTP)在1.5倍虚拟宽度(配置为(m, n) = (2, 3))下的效果。
-
在0.4B/4B MoE模型上的表现 (图4):
- 训练损失: 与基线相比,单独使用VWN能持续降低下一词元预测(NTP)损失。单独使用MTP会轻微增加NTP损失。VWN+MTP组合在增强版变体中实现了最低的最终损失,但与基线相比仍有0.016的差距。
- 下游性能 (Collection A): MTP单独使用与基线相当,而VWN+MTP在整个训练过程中均取得了最高的平均准确率增益。
图 4 VWN和MTP在0.4B/4B MoE模型上的性能。左图:训练损失与已见词元数(十亿)的关系。VWN降低了下一词元预测损失,而MTP轻微损害了NTP损失;VWN和MTP的组合(VWN+MTP)在增强变体中取得了最低的最终损失,但在包含MTP时与基线指标仍有0.016的小差距。右图:下游平均准确率(%)与词元数的关系。VWN和MTP都提高了下游准确率,它们的组合在整个训练过程中带来了最大的增益。
-
在2.5B/25B MoE模型上的表现 (图5):
- 训练损失: VWN降低了下一词元损失。在此规模下,在VWN之上增加MTP并未损害优化过程。VWN和VWN+MTP都达到了相似的低最终损失,均比基线低约0.015。
- 下游性能 (Collection A): 两种变体都优于基线,其中VWN+MTP在整个训练过程中始终产生最高的平均准确率。
图 5 VWN和MTP在2.5B/25B MoE模型上的性能。左图:训练损失与已见词元数(十亿)的关系。VWN相对于基线减少了下一词元预测损失,在此规模上,在VWN之上添加MTP并未损害损失,VWN+MTP达到了最低的最终损失,训练结束时与基线有0.015的差距。右图:下游平均准确率(%)与词元数的关系。VWN和VWN+MTP都优于基线,并且VWN+MTP在整个训练过程中提供了最高的准确率。
5.2 大虚拟宽度
我们研究了在更强的内部基线上进行虚拟宽度扩展的效果,所有模型默认包含MTP头。
-
分数率m的消融研究 (图6): 在MoE-A0.8B模型上,针对不同的虚拟宽度因子r(2, 4, 8)进行了m值的消融实验。结果显示,当r=2时,将m从2增加到4有适度改善。当r=4或r=8时,改变m值带来的差异很小,表明在此模型规模下,当m约等于4之后,进一步细化隐藏分区带来的好处基本饱和。
图 6 在MoE-A0.8B上针对不同虚拟宽度因子r的分数率m的消融实验。每个面板绘制了VWN×2(左)、VWN×4(中)和VWN×8(右)的下一词元训练损失与已见词元数(十亿)的关系。当r=2时,将m从2增加到4产生了适度但可见的改进。当r=4或r=8时,在测试值之间改变m只会导致微小的差异,这表明在这个模型规模上,当m≈4之后,更精细的隐藏分区效果基本饱和。 -
虚拟宽度因子的缩放定律 (图7, 图8, 表1): 在MoE-A0.8B上,固定m=8,改变虚拟宽度因子r(2, 4, 8)。
- 性能增益: 随着r的增大,VWN带来了持续且单调的收益。在500B词元时,与基线相比,VWN×8在下一词元损失上减少了0.035,在下两词元损失上减少了0.058,在Collection B上的下游准确率提高了4.16个点。
- 缩放关系: 观察到的损失减少与虚拟宽度因子r之间呈现对数-线性关系(图8)。拟合函数为 $y = -0.0069 \cdot \log_2(x) + 1.6212$ ($R^2=0.9986$),表明虚拟宽度每翻一倍,损失大约减少0.0069。
图 7 VWN在MoE-A0.8B上固定分数率m=8时的词元效率。我们通过设置r∈{2, 4, 8}和n=r·m={16, 32, 64}来改变虚拟宽度因子。左/中:下一词元和下两词元预测的训练损失与已见词元数的关系。右:在Collection B上的平均准确率与词元数的关系。VWN持续提高样本效率;在500B词元时,VWN×8相比非VWN基线,在不增加主干网络宽度的情况下,利用超宽嵌入和GHC,实现了Δ=0.035(下一词元损失)、Δ=0.058(下两词元损失)的降低,以及+4.16个点的准确率增益(Collection B,表3)。
图 8 虚拟宽度因子r与损失之间关系的缩放定律分析。观测数据(红点)用对数-线性函数 y = -0.0069 · log₂(x) + 1.6212 拟合,决定系数 R² = 0.9986。
表 1 VWN在MoE-A0.8B上固定m=8时的扩展行为。所有数字表示在500B词元训练后相对于非VWN基线的改进。每个Δ值表示与基线相比的损失减少量,准确率增益在集合B上测量。
-
在大规模模型上的VWN表现 (图1): 在3.3B激活参数的MoE模型(MoE-A3.3B)上使用8倍虚拟宽度扩展((m, n) = (8, 64))。
- 训练加速: VWN显著加速了优化。达到基线相同的下一词元损失,VWN只需2.5倍少的词元;达到相同的下两词元损失,只需3.5倍少的词元。
- 优势放大: VWN的优势随着训练的进行而放大。下一词元损失差距从初期的0.025扩大到3T词元时的0.032。
- 下游性能: 在Collection B上的峰值平均准确率比基线高出2.16个点,证实性能差距持续存在并随训练扩大。
A5 结论
本文介绍了虚拟宽度网络(VWN),它是一种将表示宽度与二次方计算成本解耦的实用机制。实验表明,即使是1.5倍的适度扩展也能带来持续的改进。当扩展到8倍虚拟宽度时,优化速度显著加快:与基线宽度相比,下一词元预测损失的收敛速度快2倍以上,多词元预测损失快3倍以上。
更重要的是,VWN的性能表现出清晰的缩放行为。我们观察到虚拟宽度因子r与损失减少之间存在近似的对数-线性关系,r每翻一倍,平均损失下降约0.0069。尽管增益幅度适中,但这表明虚拟宽度可以作为补充现有深度、宽度和数据缩放定律的一个新的、可预测的模型效率扩展维度。
VWN能够干净地集成到标准的Transformer堆栈和训练流程中,为研究容量/计算权衡以及如何通过受控的宽度扩展有效提高质量提供了一个具体的参考点。然而,将这些算法增益转化为生产效率取决于系统现实。尽管VWN在质量与计算的权衡上很有前景,但它面临实际限制:随着隐藏宽度的增长,通信和内存访问开销变得不可忽视,且当代硬件对非常宽的激活和跨设备路由并不友好。目前,对极宽配置的工程支持仍然有限,这限制了其可部署性。在实践中,1.5倍至4倍的虚拟宽度扩展在当今的技术栈上更为可行,而更大的扩展可能需要软件、内存布局和互连策略的协同设计才能完全发挥其潜力。
A6 附录
A MoE-A0.8B模型的详细下游结果
VWN×8在下游基准上的表现。如图9所示,绘制了跨基准的词元效率曲线,VWN×8(r=8, n=64)在所有任务上都统一地使学习曲线左移,表明样本效率更高。在知识和推理密集型基准(如DROP, MATH)上改进最大,这表明扩展的超宽嵌入在不增加核心计算的情况下改善了组合推理和信息检索能力。HumanEval的增益较小,这与其有限的测试集规模一致。这些优势在训练后期依然存在,未观察到性能回归,表明VWN被持续利用而不是早期饱和。值得注意的是,VWN在具有较长上下文的任务(如DROP和TriviaQA)上取得了特别强的增益,在这些任务中,对长距离依赖和多句证据聚合的建模最能从扩大的嵌入空间中受益。总体而言,VWN将其在词元级别的效率增益一致地转移到多样化的下游领域,在不增加主干网络宽度的情况下增强了泛化能力。
图 9 VWN在MoE-A0.8B模型上跨下游基准的性能。我们比较了非VWN基线与VWN×8(r = 8; n = r · m = 64)。VWN×8在整个训练过程中始终优于基线;在500B词元时,它在DROP上准确率提升+8.92点,HumanEval上+2.44点,MATH上+4.20点,MMLU上+3.95点,MMLU-Pro上+5.25点,TriviaQA上+7.45点。
B 广义超连接的实现
算法2 PyTorch风格的广义超连接伪代码
# h: 隐藏向量 (BxLxD)
class GHyperConnection(nn.Module):
def __init__(self, dim, m, n_in=3, n_out=2):
super().__init__()
self.m, self.n_in, self.n_out = m, n_in, n_out
self.factor = 1.0 / math.sqrt(dim // self.m)
# 初始化静态 beta: 循环模式
static_beta_tensor = torch.zeros(self.m, n_in)
for j in range(n_in):
static_beta_tensor[j % self.m, j] = 1.0
self.static_beta = nn.Parameter(static_beta_tensor.T.contiguous())
# 初始化静态 alpha: 块矩阵
init_alpha = torch.cat([torch.eye(self.m), torch.eye(self.m), torch.zeros((self.m, self.n_in - self.m))], dim=1)
if self.n_in > self.m:
part2 = torch.cat([torch.zeros((self.n_in - self.m, self.m * 2)), torch.eye(self.n_in - self.m)], dim=1)
init_alpha = torch.cat([init_alpha, part2], dim=0)
self.static_alpha = nn.Parameter(init_alpha.contiguous())
# # 动态参数
self.dynamic_alpha_fn = nn.Parameter(torch.zeros((dim // self.m, self.m + self.n_in)))
self.dynamic_alpha_scale = nn.Parameter(torch.ones_like(self.static_alpha))
self.dynamic_beta_fn = nn.Parameter(torch.zeros((dim // self.m, self.m)))
self.dynamic_beta_scale = nn.Parameter(torch.ones_like(self.static_beta))
self.layer_norm = RMSNorm(hidden_size=dim // self.m)
def _base_width_connection(self, h, dynamic_fn, dynamic_scale, static_scale):
h_shape = h.shape
N, NMM = static_scale.shape
M = (NMM - N) // 2
h_reshape = h.reshape((h_shape[:-1].numel(),) + (N, h_shape[-1] // N))
norm_h = self.layer_norm(h_reshape)
alpha_beta = (safe_tanh(norm_h @ dynamic_fn.T.to(dtype=norm_h.dtype) * self.factor) * dynamic_scale[None, ...] + static_scale[None, ...])
alpha, beta = torch.split(alpha_beta, (M + N, M), dim=-1)
mix_h = (h_reshape.transpose(1, 2) @ alpha.to(dtype=h_reshape.dtype)).transpose(1, 2)
return mix_h.reshape(h_shape[:-1] + mix_h.shape[1:]), beta
def width_connection(self, h):
dynamic_fn = torch.concat([self.dynamic_alpha_fn.T, self.dynamic_beta_fn.T], dim=0)
dynamic_scale = torch.concat([self.dynamic_alpha_scale, self.dynamic_beta_scale], dim=-1).contiguous()
static_scale = torch.concat([self.static_alpha, self.static_beta], dim=-1)
return self._base_width_connection(h, dynamic_fn.to(dtype=h.dtype), dynamic_scale.to(dtype=h.dtype), static_scale.to(dtype=h.dtype))
def depth_connection(self, mix_h, h_o, beta):
h_o_shape = h_o.shape
h_o = h_o.reshape(h_o_shape[:-1] + (self.m, h_o_shape[-1] // self.m))
h_i = beta.view(h_o.shape[:2] + beta.shape[1:]).to(dtype=h_o.dtype) @ h_o
h = h_i + mix_h[..., self.m:, :]
h_shape = h.shape
return h.reshape(h_shape[:-2] + (h_shape[-2] * h_shape[-1],)).contiguous()
算法3 PyTorch风格的带广义超连接的Transformer伪代码
# h: 隐藏向量 (BxLxD)
# atten_ghyper_connection, ffn_ghyper_connection: ghyper-connection 模块
# attn_norm, ffn_norm: 归一化模块
# # 注意力块
mix_h, beta = atten_ghyper_connection.width_connection(h)
mix_h_shape = mix_h.shape
h = mix_h[...,:self.rate,:].reshape(mix_h_shape[:-2] + (mix_h_shape[-2] // 2 * mix_h_shape[-1], ))
h = attn_norm(h)
h = self_attention(h)
h = atten_ghyper_connection.depth_connection(mix_h, dropout(h), beta)
# # FFN 块
mix_h, beta = ffn_ghyper_connection.width_connection(h)
mix_h_shape = mix_h.shape
h = mix_h[...,:self.rate,:].reshape(mix_h_shape[:-2] + (mix_h_shape[-2] // 2 * mix_h_shape[-1], ))
h = ffn_norm(h)
h = ffn(h)
h = ffn_ghyper_connection.depth_connection(mix_h, dropout(h), beta)
表 2 下游基准集合 A.
表 3 下游基准集合 B.
方法细节中的引用汇总
-
[7] Deepseek-v3 technical report, 2025, https://arxiv.org/abs/2412.19437
- 引用段落: 3.4 Multi-token Prediction
- 引用描述: 论文在描述MTP的实现时,引用了这篇文献的做法。原文描述为:“Specifically, following DeepSeek-AI et al. [7], we concatenate the embedding of the next token with the last-layer embedding of the preceding token, applying a linear projection to generate logits...” (具体来说,我们遵循DeepSeek-AI等人的做法[7],将下一个词元的嵌入与前一个词元最后一层的嵌入拼接起来,然后应用一个线性投影来生成logits...)。这表明本文的MTP预测头设计借鉴了DeepSeek-V3的技术。
-
[17] Over-tokenized transformer: Vocabulary is generally worth scaling, 2025, arXiv
- 引用段落: 3.2 Over-Width Embedding, 3.4 Multi-token Prediction
- 引用描述: 在3.2节,引用该文献以支持“输入增强策略”可以进一步丰富加宽的表示。在3.4节,引用该文献以说明MTP是k-gram解码的一种近似,为本文采用MTP提供了理论依据。原文描述为:“previous research [17] has demonstrated that Multi-Token Prediction (MTP) serves as an approximation of k-gram decoding.” (先前的研究[17]已经证明,多词元预测(MTP)可以作为k-gram解码的一种近似。)
-
[21] Reducing activation recomputation in large transformer models, 2023, Proceedings of the 6th MLSys Conference
- 引用段落: 3.5 Cost Analysis
- 引用描述: 在分析内存成本时,引用该文献来说明在典型的训练框架(如Megatron-LM)中,通过选择性激活重计算可以优化激活的存储。原文描述为:“In a typical training framework like Megatron-LM, each token in a vanilla transformer layer requires 34D bytes for activation storage, employing selective activation recomputation [21].” (在一个典型的训练框架如Megatron-LM中,采用选择性激活重计算[21],一个普通Transformer层中的每个词元需要34D字节的激活存储。)
-
[31] Group normalization, 2018, Proceedings of the European conference on computer vision (ECCV)
- 引用段落: 3.2 Over-Width Embedding
- 引用描述: 在描述如何稳定训练时,引用该文献来说明当虚拟宽度非常大时,采用分组归一化(Group Normalization)来代替对整个宽向量的直接归一化。原文描述为:“Instead of directly normalizing across all D′ dimensions, we adopt Group Normalization [31], where the group size equals the original hidden size D.” (我们不直接对所有D'维度进行归一化,而是采用分组归一化[31],其中组大小等于原始隐藏大小D。)
-
[36] Hyper-connections, 2024, arXiv 和 [37] Frac-connections: Fractional extension of hyper-connections, 2025, arXiv
- 引用段落: 3.3 Generalized Hyper-Connections
- 引用描述: 在介绍动态GHC(DGHC)的具体实现时,引用了这两篇文献,表明本文采用了它们提出的静态和动态参数相结合的混合策略。原文描述为:“In practice, we adopt the hybrid strategy from Zhu et al. [36, 37], which integrates both static and dynamic parameters, while making slight adjustments to better fit our VWN framework.” (在实践中,我们采用了Zhu等人[36, 37]的混合策略,该策略集成了静态和动态参数,并做了一些微调以更好地适应我们的VWN框架。) 这说明GHC的设计是在Hyper-Connections和Frac-Connections工作基础上的演进和调整。
💬 评论讨论
欢迎在这里分享您的想法和见解!