Hyper-Connections

  • 文章标题: HYPER-CONNECTIONS
  • 作者/机构: Defa Zhu, Hongzhi Huang, Zihao Huang, Yutao Zeng, Yunyao Mao, Banggu Wu, Qiyang Min, Xun Zhou (Seed-Foundation-Model Team, ByteDance)

A1 主要贡献

深度学习在各个领域取得了巨大成功,其中残差连接(【14, Deep residual learning for image recognition, He et al., 2016, aaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, 2016.】)在现代神经网络架构(包括Transformer和CNN)中发挥了重要作用。残差连接有助于缓解梯度消失问题,从而能够有效训练非常深的网络。然而,残差连接并非万能,其固有的局限性仍待解决。

核心问题: 残差连接的两种主要变体,即前置归一化(Pre-Norm)和后置归一化(Post-Norm),各自在梯度消失和表示坍塌之间做出了不同的权衡。Pre-Norm有效地解决了梯度消失问题(【3, Learning long-term dependencies with gradient descent is difficult, Bengio et al., 1994, Yoshua Bengio, Patrice Simard, and Paolo Frasconi. Learning long-term dependencies with gradient descent is difficult. IEEE transactions on neural networks, 5(2), 1994.】; 【12, Understanding the difficulty of training deep feedforward neural networks, Glorot & Bengio, 2010, Xavier Glorot and Yoshua Bengio. Understanding the difficulty of training deep feedforward neural networks. In Proceedings of the thirteenth international conference on artificial intelligence and statistics. JMLR Workshop and Conference Proceedings, 2010.】),但可能导致深层表示的坍塌(【19, Understanding the difficulty of training transformers, Liu et al., 2020, Liyuan Liu, Xiaodong Liu, Jianfeng Gao, Weizhu Chen, and Jiawei Han. Understanding the difficulty of training transformers. arXiv preprint arXiv:2004.08249, 2020.】),即深层网络的隐藏特征变得高度相似,削弱了增加层数带来的效果。相反,Post-Norm减轻了表示坍塌问题,但又重新引入了梯度消失的风险。这两种问题就像跷跷板的两端,而这两种变体只是在其中进行权衡。关键问题在于,残差连接(包括Pre-Norm和Post-Norm)预先定义了层内输出与输入之间的连接强度。

研究目标与创新点: 针对残差连接的局限性,本文提出了一个重要问题:神经网络能否自主学习最优的连接强度以提升性能?为此,我们提出了超连接(Hyper-connections, HC),这是一种新方法,在计算和参数增加可忽略不计的情况下,显著提升了模型性能。HC的核心思想是提出可学习的深度连接(depth-connections)宽度连接(width-connections),如图2(b)所示。与图2(a)中的残差连接相比,这些连接能够灵活地在垂直方向上整合不同深度的特征。深度连接可以看作是广义的残差连接,为每层输入和输出之间的连接分配权重。为了使网络能够同时建模不同的深度连接,我们将网络输入扩展为n个副本,每个副本都有自己的深度连接(图2(b))。这种设计允许多个隐藏向量保留与前序层连接的多种模式。此外,我们在n个隐藏向量之间建立了宽度连接,允许同一层内的隐藏向量进行信息交换(图2(b))。作者认为,n>1个隐藏状态是必要的,因为当n=1时,跷跷板效应依然存在,实验也表明性能没有提升。而当n>1时,超连接不仅可以学习调整残差的强度,还可以重新排列层(顺序或并行)。为了进一步增强灵活性,我们引入了动态超连接(Dynamic Hyper-connections, DHC),使网络能够根据输入调整连接权重。值得注意的是,尽管HC看似将网络宽度增加了n倍,但额外的参数和计算成本几乎可以忽略不计。


图 2: 扩展率为n=2的超连接(HC)。(a) 残差连接。(b) 超连接:β1, β2, α0,0, α0,1, α1,0, α1,1, α2,1, 和 α2,2 是可学习的标量或由网络预测的标量,取决于具体的HC版本。这些连接实现了横向信息交换和跨深度的特征垂直整合。带有HC的Transformer如图17所示。它们可以解耦为深度连接和宽度连接。(c) 深度连接在层输出和隐藏向量h1之间执行加权求和。(d) 宽度连接允许隐藏向量h1和h2之间进行信息交换。

主要贡献与实验结果: 本文的研究主要集中在大型语言模型(LLM)的预训练上,同时也扩展到了视觉生成和分类任务。以Pre-Norm为基线,我们证明了超连接的显著优势,包括1B和7B的密集模型以及7B的MoE模型。对于OLMoE模型(【23, OLMoE: Open Mixture-of-Experts Language Models, Muennighoff et al., 2024, Niklas Muennighoff, Luca Soldaini, Dirk Groeneveld, Kyle Lo, Jacob Morrison, Sewon Min, Weijia Shi, Pete Walsh, Oyvind Tafjord, Nathan Lambert, Yuling Gu, Shane Arora, Akshita Bhagia, Dustin Schwenk, David Wadden, Alexander Wettig, Binyuan Hui, Tim Dettmers, Douwe Kiela, Ali Farhadi, Noah A. Smith, Pang Wei Koh, Amanpreet Singh, and Hannaneh Hajishirzi. Olmoe: Open mixture-of-experts language models, 2024. URL https://arxiv.org/abs/2409.02060.】),DHC的优势尤为突出,如图1所示。使用DHC的模型收敛速度快1.8倍,并且在训练5000亿个token后,在ARC-Challenge上的得分比基线高出6分。可视化分析(图3)显示,基线模型倾向于表示坍塌,相邻层特征相似度高。而采用HC的模型,相邻层特征相似度显著降低,且相似度范围更广,表明HC增强了每一层的效果。这些证据表明了超连接原理的普适性,并有望应用于更多AI挑战中 。


图 1: 基线模型OLMoE-1B-7B和带有超连接的模型OLMoE-1B-7B-DHC×4的性能对比。(1)和(2)分别显示了训练损失(0.99 EMA平滑)和C4-en验证损失。我们的方法收敛速度比基线快1.8倍,并在5000亿tokens时保持显著优势。(3)和(4)显示了在HellaSwag和ARC-Challenge上的准确率曲线,证明了OLMoE-1B-7B-DHC×4模型的卓越性能。


图 3: OLMo-1B模型(【13, Olmo: Accelerating the science of language models, Groeneveld et al., 2024, Dirk Groeneveld, Iz Beltagy, Pete Walsh, Akshita Bhagia, Rodney Kinney, Oyvind Tafjord, Ananya Harsh Jha, Hamish Ivison, Ian Magnusson, Yizhong Wang, et al. Olmo: Accelerating the science of language models. arXiv preprint arXiv:2402.00838, 2024.】)中当前层和前一层输入的余弦相似度。曲线表示相似度的中位数,阴影区域表示第5和第95百分位数之间的范围。红色曲线显示使用Pre-Norm的模型,蓝色曲线显示使用超连接的模型。

A2 方法细节

2.1 静态超连接

超隐藏矩阵的定义与初始化。我们将第k层的输入隐藏向量记为 $h_{k-1} \in R^d$(或 $h_{k-1} \in R^{d \times 1}$),网络的初始输入为 $h_0$。首先,我们将 $h_0 \in R^d$ 复制n次,形成初始的超隐藏矩阵 $H_0 = [h_0 \ h_0 \ ... \ h_0] \in R^{n \times d}$。这里的n是扩展率。对于第k层,其输入是来自前一层的超隐藏矩阵 $H_{k-1} = [h_{k-1}^1 \ h_{k-1}^2 \ ... \ h_{k-1}^n]^\intercal \in R^{n \times d}$。最后,我们将最后的超隐藏矩阵按行求和,得到所需的隐藏向量,然后通过一个最终的投影器(在Transformer中是归一化层和unembedding层)来产生网络的最终输出。为了简化后续分析中的符号,我们省略了层的索引,简单地将超隐藏矩阵表示为 $H = (h_1 \ h_2 \ ... \ h_n)$。

超连接矩阵的结构。超连接(HC)可以用一个矩阵HC来表示,其中每个元素定义了连接权重。该矩阵的结构如下:

$$\mathcal{HC} = \begin{pmatrix} \mathbf{0}_{1 \times 1} & \mathbf{B} \\\mathbf{A}_{\mathbf{m}} & \mathbf{A}_{\mathbf{r}} \end{pmatrix} = \begin{pmatrix} 0 & \beta_1 & \beta_2 & \dots & \beta_n \\\alpha_{1,0} & \alpha_{1,1} & \alpha_{1,2} & \dots & \alpha_{1,n} \\\alpha_{2,0} & \alpha_{2,1} & \alpha_{2,2} & \dots & \alpha_{2,n} \\\vdots & \vdots & \vdots & \ddots & \vdots \\\alpha_{n,0} & \alpha_{n,1} & \alpha_{n,2} & \dots & \alpha_{n,n} \end{pmatrix} \in \mathbb{R}^{(n+1) \times (n+1)}.$$

超连接的输出公式。考虑一个网络层T,它在Transformer中集成了自注意力层和前馈网络。HC的输出,记为 $\hat{H}$,可以简单地表示如下:

$$\hat{\mathbf{H}}=\mathcal{H} \mathcal{C}(\mathcal{T}, \mathbf{H})=\mathbf{B}^{\top} \mathcal{T}\left(\mathbf{H}^{\top} \mathbf{A}_{\mathbf{m}}\right)^{\top}+\mathbf{A}_{\mathbf{r}}^{\top} \mathbf{H} .$$

输入加权求和。我们使用权重 $A_m$ 对输入 $H = (h_1 \ h_2 \ ... \ h_n)$ 进行加权求和,以获得当前层T的输入 $h_0$,其计算公式为:

$$\mathbf{h}_0^\intercal = \mathbf{A_m}^\intercal \mathbf{H},$$

隐藏矩阵的映射。而 $A_r$ 则用于连接H并将其映射到一个超隐藏矩阵 $H'$,如下所示:

$$\mathbf{H}^{\prime}=\mathbf{A}_{\mathbf{r}}^{\top} \mathbf{H} .$$

最终输出计算。随后,输出由以下公式给出:

$$\hat{\mathbf{H}}=\mathbf{B}^{\intercal}(\mathcal{T} \mathbf{h}_{0})^{\intercal}+\mathbf{H}^{\prime} .$$

深度连接的解耦。深度连接可以解耦为以下矩阵,如图2(a)所示:

$$\begin{aligned} \mathcal{DC} = \begin{pmatrix} \mathbf{B} \\ \text{diag}(\mathbf{A_r}) \end{pmatrix} = \begin{pmatrix} \beta_1 & \beta_2 & \cdots & \beta_n \\ \alpha_{1,1} & \alpha_{2,2} & \cdots & \alpha_{n,n} \end{pmatrix} \in \mathbb{R}^{2 \times n}, \end{aligned}$$

其中,第一行B表示当前层T输出的权重,最后一行diag($A_r$)表示输入的权重。我们用diag($A_r$)表示$A_r$对角线元素的扁平化向量。

宽度连接的定义。宽度连接矩阵可以定义如下,如图2(b)所示:

$$\mathcal{WC} = (\mathbf{A_m} \quad \mathbf{A_r}) \in \mathbb{R}^{n \times (n+1)}.$$

算法伪代码。采用超连接的算法在算法1中呈现。

2.2 动态超连接

动态超连接矩阵表示。HC的条目可以动态地依赖于输入H,动态超连接(DHC)的矩阵表示定义如下:

$$\begin{aligned} \mathcal{HC}(\mathbf{H})=\begin{pmatrix}\mathbf{0}_{1\times 1} & \mathcal{B}(\mathbf{H}) \\ \mathcal{A}_m(\mathbf{H}) & \mathcal{A}_r(\mathbf{H})\end{pmatrix} \end{aligned}$$

DHC输出公式。类似地,给定一个层T和输入H,我们得到DHC的输出如下:

$$\hat{\mathbf{H}}=\mathcal{H} \mathcal{C}(\mathbf{H})(\mathcal{T}, \mathbf{H}) .$$

动态参数的计算。在实践中,我们结合动态和静态矩阵来实现DHC。动态参数通过线性变换获得。为了稳定训练过程,我们在线性变换前引入归一化,并在其后应用tanh激活函数,再通过一个小的初始可学习因子进行缩放。以下方程详细说明了这些动态参数的计算方式:

$$\overline{\mathbf{H}} = \operatorname{norm}(\mathbf{H})$$ $$\mathcal{B}(\mathbf{H})=s_{\beta} \circ \tanh \left(\overline{\mathbf{H}} \mathbf{W}_{\beta}\right)^{\top}+\mathbf{B} \in \mathbb{R}^{1 \times n}$$ $$\mathcal{A}_m(\mathbf{H})=s_\alpha \circ \tanh (\overline{\mathbf{H}} \mathbf{W}_m)+\mathbf{A}_m \in \mathbb{R}^{n \times 1}$$ $$\mathcal{A}_r(\mathbf{H}) = s_\alpha \circ \tanh(\overline{\mathbf{H}}\mathbf{W}_r) + \mathbf{A}_r \in \mathbb{R}^{n \times n}$$

DHC性能与实现。我们在§4中的实验表明,动态超连接在语言建模任务中优于静态超连接。超连接的静态和动态变体的PyTorch实现在算法2和3中有详细说明。

2.3 初始化

与Pre-Norm等价的初始化策略。为了使超连接的初始化等同于Pre-Norm残差连接,我们采用了以下初始化策略。公式11、12和13中的动态参数 $W_{\beta}$, $W_m$, 和 $W_r$ 被初始化为0,而静态矩阵则按如下方式初始化:

$$\begin{aligned} \begin{pmatrix} \mathbf{0}_{1 \times 1} & \mathbf{B}^{k} \\ \mathbf{A}_{\mathbf{m}}{ }^{k} & \mathbf{A}_{\mathbf{r}}{ }^{k} \end{pmatrix}=\begin{pmatrix} \mathbf{0}_{1 \times 1} & \mathbf{1}_{1 \times n} \\ \mathbf{e}_{k \bmod n} & \mathbf{e}_{n \times n} \end{pmatrix}, \end{aligned}$$

其中k是层的索引,mod表示取模运算。

A3 背景知识/关键Observation/设计原则

本节阐述了超连接背后的原理。我们探讨了残差连接的变体,即Pre-Norm和Post-Norm,如何被视为非可训练的超连接,并引入了顺序-并行二元性的概念,展示了超连接如何动态优化层排列以增强网络性能。通过展开视图对超连接的可视化分析在§4.5中讨论。

3.1 残差连接作为非可训练的超连接

Pre-Norm和Post-Norm作为非可训练超连接。Pre-Norm和Post-Norm残差连接可以表示为扩展率n=1的以下超连接矩阵:

$$\begin{aligned} \mathcal{HC}_{PreNorm} = \begin{pmatrix} 0 & 1 \\ 1 & 1 \end{pmatrix}, \qquad (15) \qquad \mathcal{HC}_{PostNorm} = \begin{pmatrix} 0 & \frac{1}{\sqrt{\sigma_{i}^{2}+\sigma_{o}^{2}+2\sigma_{io}}} \\ 1 & \frac{1}{\sqrt{\sigma_{i}^{2}+\sigma_{o}^{2}+2\sigma_{io}}} \end{pmatrix}, \end{aligned}$$

其中,$\sigma_i$ 和 $\sigma_o$ 分别表示神经网络层输入和输出的标准差,$\sigma_{io}$ 是它们之间的协方差。

具体矩阵形式。对于Pre-Norm,其超连接矩阵是一个2×2矩阵,其右下三角部分填充为1,其余部分为占位符0。对于Post-Norm,权重取决于输入和输出的方差和协方差,形成一个2×2矩阵。因此,它们的超连接矩阵是不可训练的。在这项工作中,我们提出的超连接可以是(n+1)×(n+1)的矩阵,其权重是可训练的,甚至可以根据输入进行预测。完整的推导在附录G中提供。

3.2 顺序-并行二元性

可学习的层排列方式。给定一系列神经网络模块,我们可以选择将它们按顺序或并行排列。然而,超连接提供了一种方法,可以学习将这些层重新排列成一种融合了顺序和并行排列的配置。


图 4: 扩展率为 n = 2 时超连接的顺序和并行排列。

顺序排列的矩阵形式。不失一般性,我们将扩展率设为n=2。如果超连接学习为以下矩阵,神经网络将按顺序排列:

$$\begin{aligned} \mathcal{HC} = \begin{pmatrix} 0 & 1 & 1 \\ 1 & 1 & 0 \\ 0 & 0 & 1 \end{pmatrix}. \end{aligned}$$

在这种情况下,深度连接退化为残差连接,如图4(a)所示。

并行排列的矩阵形式。当奇数层和偶数层(层编号从1开始)的超连接由以下矩阵定义时,神经网络将每两个连续层并行排列,类似于Transformer中并行Transformer块的排列(【33, Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX, Wang, 2021, Ben Wang. Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX. https://github.com/kingoflolz/mesh-transformer-jax, May 2021.】),如图4(b)所示。一般和完整的推导在附录H中提供。

$$\begin{aligned} \mathcal{HC}_{odd} = \begin{pmatrix} 0 & 1 & 0 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \end{pmatrix}, \end{aligned}$$ $$\begin{aligned} \mathcal{HC}_{even} = \begin{pmatrix} 0 & 0 & 1 \\ 0 & 1 & 0 \\ 1 & 0 & 1 \end{pmatrix}. \end{aligned}$$

动态排列的优势。因此,学习各种形式的超连接矩阵可以创建超越传统顺序和并行配置的层排列,从而产生一种软混合甚至动态的排列。对于静态超连接,网络内的层排列在训练后保持固定。相比之下,动态超连接允许排列为每个token动态适应。

A4 实验环境

  • 数据集

    • 语言模型

      • 密集模型:使用 dolmap-v1.5-sample(【30, Dolma: An open corpus of three trillion tokens for language model pretraining research, Soldaini et al., 2024, Luca Soldaini, Rodney Kinney, Akshita Bhagia, Dustin Schwenk, David Atkinson, Russell Authur, Ben Bogin, Khyathi Chandu, Jennifer Dumas, Yanai Elazar, et al. Dolma: An open corpus of three trillion tokens for language model pretraining research. arXiv preprint arXiv:2402.00159, 2024.】)作为训练数据集。
      • MoE模型:在 OLMOE-MIX 数据集上训练。
      • 训练规模:所有实验均在5000亿个tokens上进行训练。
    • 视觉任务

      • 图像生成与分类:使用 ILSVRC-2012 ImageNet 数据集(【9, Imagenet: A large-scale hierarchical image database, Deng et al., 2009, Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition. Ieee, 2009.】)。
  • 模型架构

    • 密集语言模型:基于 OLMo(【13, Olmo: Accelerating the science of language models, Groeneveld et al., 2024, Dirk Groeneveld, Iz Beltagy, Pete Walsh, Akshita Bhagia, Rodney Kinney, Oyvind Tafjord, Ananya Harsh Jha, Hamish Ivison, Ian Magnusson, Yizhong Wang, et al. Olmo: Accelerating the science of language models. arXiv preprint arXiv:2402.00838, 2024.】)框架,在1B和7B规模上进行实验。
    • MoE语言模型:基于 OLMoE(【23, OLMoE: Open Mixture-of-Experts Language Models, Muennighoff et al., 2024, Niklas Muennighoff, Luca Soldaini, Dirk Groeneveld, Kyle Lo, Jacob Morrison, Sewon Min, Weijia Shi, Pete Walsh, Oyvind Tafjord, Nathan Lambert, Yuling Gu, Shane Arora, Akshita Bhagia, Dustin Schwenk, David Wadden, Alexander Wettig, Binyuan Hui, Tim Dettmers, Douwe Kiela, Ali Farhadi, Noah A. Smith, Pang Wei Koh, Amanpreet Singh, and Hannaneh Hajishirzi. Olmoe: Open mixture-of-experts language models, 2024. URL https://arxiv.org/abs/2409.02060.】)框架,训 练 OLMoE-1B-7B 模型(总参数7B,激活参数1.3B)。
    • 图像生成模型:使用 DiT 框架(【24, Scalable diffusion models with transformers, Peebles & Xie, 2022, William Peebles and Saining Xie. Scalable diffusion models with transformers. arXiv preprint arXiv:2212.09748, 2022.】)。
    • 图像分类模型:使用 ViT/16-Base 和 ViT/16-Large 模型(【11, An image is worth 16x16 words: Transformers for image recognition at scale, Dosovitskiy et al., 2020, Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.】)。
  • 硬件配置

    • 论文中未明确提及具体的硬件配置(如GPU型号、数量等)。
  • 软件与实现细节

    • 实现框架:实验设置遵循 OLMo 和 OLMoE 的框架。
    • 核心改动:将基线模型中的残差连接替换为超连接。
    • 训练配置

      • 静态超连接部分不使用权重衰减(weight decay),动态部分使用。
      • 为保持输出的标准差(std)与原始模型一致,在初始化时,将所有层的输出模块权重(包括FFN的第二个线性层和注意力模块的输出投影器)的std乘以因子 $1/\sqrt{n}$,其中n是扩展率。
      • 引入的额外参数和计算开销可忽略不计(见附录表7和表8)。
    • 评估指标

      • 语言模型:在V2和V3验证集上报告平均困惑度(PPL)和损失;在下游基准上报告零样本评估的平均指标。对于MoE模型,还报告了下游基准的准确率。

A4 实验结果

我们主要对大型语言模型的预训练进行实验,包括密集模型和专家混合(MoE)模型,并扩展到视觉生成和分类任务。

4.1 消融研究

实验默认使用扩展率为n=4的动态超连接(DHC),并包含tanh函数。

  • 扩展率n的影响:如表1和图5所示,当扩展率n=1时,DHC的性能劣于基线。但当n>1时,DHC显著优于基线,在n=4时达到最佳效果,增加到n=8带来的额外收益很小。值得注意的是,不带tanh的OLMo-1B-DHC×8在V2和V3验证集上表现出色。此外,DHC(n≥2)的训练损失下降速度比基线更快,且训练过程更稳定,未出现尖峰。


图 5: 不同扩展率的训练损失曲线比较。左子图包括不同扩展率下的动态超连接(DHC)模型,右子图显示了省略tanh函数的效果。两个子图都说明了增加扩展率如何在5000亿tokens上带来更好的训练损失性能。结果使用系数为0.99的指数移动平均进行平滑。

Table 1: 扩展率n在5000亿tokens训练上的消融研究。

  • 静态与动态超连接对比:表2显示,所有超连接(HC)变体都显著优于基线。在扩展率为2时,DHC和SHC的改进相似。但在扩展率为4时,DHC的性能明显优于SHC。

Table 2: 静态和动态超连接在5000亿tokens训练上的消融研究。

  • B和WC的重要性:如表3所示,不训练宽度连接(WC)会导致显著的性能下降。相比之下,不训练B(输出权重)的影响较小。因此,确保WC和B都是可训练的至关重要。

Table 3: OLMo-1B-DHC×4的消融研究。在B或WC列中,符号“✗”表示参数从初始化开始就不可训练。

4.2 与相关工作的比较

我们将Altup(【2, Alternating updates for efficient transformers, Baykal et al., 2024, Cenk Baykal, Dylan Cutler, Nishanth Dikkala, Nikhil Ghosh, Rina Panigrahy, and Xin Wang. Alternating updates for efficient transformers. Advances in Neural Information Processing Systems, 36, 2024.】)和ResiDual(【36, Residual: Transformer with dual residual connections, Xie et al., 2023, Shufang Xie, Huishuai Zhang, Junliang Guo, Xu Tan, Jiang Bian, Hany Hassan Awadalla, Arul Menezes, Tao Qin, and Rui Yan. Residual: Transformer with dual residual connections. arXiv preprint arXiv:2304.14802, 2023.】)方法在OLMo中实现。尽管这些方法在训练初期显示出优势,但最终被基线超越。如表4和图15所示,本文提出的方法性能更优。

Table 4: 相关方法在OLMo-1B模型上的性能。

4.3 7B模型

在7B模型规模上,我们训练了OLMo-7B-DHC×4模型。根据表5和图6,该模型在所有平均指标上均显著优于基线OLMo-7B模型。V2评估中,损失改进0.022,PPL改进0.293。下游基准的平均得分从0.701提升到0.710。此外,DHC模型训练过程更加稳定,没有出现基线模型中频繁的训练尖峰,并且即使在处理了4000亿个token后,性能优势依然持续。


图 6: (1)和(2) OLMo-7B和OLMo-7B-DHC×4模型的训练损失(0.99 EMA平滑)和C4-en验证损失。(3)和(4) hellaswag和sciq上的准确率曲线,展示了OLMo-7B-DHC×4模型的卓越性能。

Table 5: 7B模型的性能。FLOPs指前向传递中每个token的计算量。

4.4 MoE模型

在MoE模型上,我们训练了应用DHC(n=4)的OLMoE-1B-7B模型。结果(图1,图9,表6)显示,超连接在几乎所有指标上都优于残差连接。在许多指标上,我们的方法仅用一半的训练tokens就能达到基线的性能。具体来说,训练损失降低约0.027,C4-en验证集损失降低0.028,ARC-Challenge得分提高6分,MMLU Var得分提高1.2分。

Table 6: 在OLMoE评估设置下,使用5000亿tokens训练的MoE模型的下游评估。ARC-C代表ARC-Challenge,ARC-E代表ARC-Easy。MMLU Var是MMLU的修改版本,包含变化的少样本示例,在早期训练中提供稳定反馈,如OLMoE设置中所述(【23, OLMoE: Open Mixture-of-Experts Language Models, Muennighoff et al., 2024, Niklas Muennighoff, Luca Soldaini, Dirk Groeneveld, Kyle Lo, Jacob Morrison, Sewon Min, Weijia Shi, Pete Walsh, Oyvind Tafjord, Nathan Lambert, Yuling Gu, Shane Arora, Akshita Bhagia, Dustin Schwenk, David Wadden, Alexander Wettig, Binyuan Hui, Tim Dettmers, Douwe Kiela, Ali Farhadi, Noah A. Smith, Pang Wei Koh, Amanpreet Singh, and Hannaneh Hajishirzi. Olmoe: Open mixture-of-experts language models, 2024. URL https://arxiv.org/abs/2409.02060.】) 。

4.5 可视化分析

本节通过将超连接展开为跨层的密集连接来研究其学习到的权重。


图 7: 超连接及各种相关基线方法的连接矩阵可视化。奇数id的注意力层用绿色勾号标记。

  • 基线方法的连接模式:如图7所示,Pre-Norm的连接矩阵是除了对角线外的下三角矩阵;Post-Norm的连接仅限于相邻层;Pre-Norm PTB(并行Transformer块)呈现锯齿状;两跳残差连接呈现垂直条纹状。

  • Λ形连接模式:超连接的连接矩阵呈现出一种Λ形模式。它结合了两种模式:一种是类似Post-Norm的长期衰减模式,即层倾向于依赖少数相邻层的输出;另一种是类似Pre-Norm的模式,即底层(如第0、2层)的输出被大多数后续层频繁使用。这表明超连接实现了Pre-Norm和Post-Norm架构的自由混合。

  • 输入词嵌入被从模型输出中消除:连接矩阵的第一列显示,输入词嵌入对大多数层都有贡献,但对最后一层(用于预测下一个token)没有贡献。这有利于下一个token的预测。

  • 观察到并行的Transformer块:超连接在实践中学习到了类似PTB的模式(可通过局部锯齿状模式识别)。例如,第11层对第12层输入的贡献极小,表明这两层可以并行操作。

  • 注意力层倾向于有更少的长期连接:底部的注意力层几乎没有长期贡献,这种趋势一直持续到第17层。这类似于两跳残差连接的设计,即注意力输出仅贡献给下一个FFN层的输入,而不加入主残差路径。

A5 结论

我们引入了超连接作为Transformer中残差连接的有效替代方案。我们的分析表明,超连接不仅克服了残差连接的局限性,还能够实现网络架构的动态调整。实验结果证实了其在大型语言模型预训练、图像生成和图像分类等多种任务中的巨大优势。

A6 附录

B 参数、计算和内存占用分析

静态超连接(SHC)参数。一个SHC模块的参数数量为 $|\theta_{SHC}| = n \cdot (n+2)$。对于有L层的网络,每层有两个HC模块(一个用于自注意力,一个用于前馈网络),因此总额外参数为 $P_{extra} = |\theta_{SHC}| \times 2 \times L$。例如,在OLMo-1B-SHC×4中,额外参数仅为768。

动态超连接(DHC)参数。一个DHC模块的参数数量为 $|\theta_{DHC}| = |\theta_{norm}| + d_{model} \times (n+2) + n \times (n+2) + 2$。总额外参数为 $P_{extra} = |\theta_{DHC}| \times 2 \times L$。例如,对于OLMo-1B-DHC×4,额外参数约为39.4万。

参数与计算开销总结。如表7和表8所示,无论使用SHC还是DHC,引入的额外参数和计算开销都非常小,可以忽略不计。

Table 7: 参数数量对比。

计算分析。HC的主要计算成本在于宽度连接,其复杂度为 $O(d_{model} \times n \times (n+1))$,远小于FFN的 $O(2 \times d_{model} \times d_{ffn})$ 和注意力投影的 $O(4 \times d_{model} \times d_{model})$。因此,HC的计算成本可以忽略不计。

Table 8: 前向传递中每个token的FLOPs。

内存占用。HC会少量增加训练期间的激活内存。例如,对于n=2,它对标准Transformer的总内存使用量贡献不到15%。通过在反向传播时重新计算HC生成的隐藏状态,可以进一步优化内存。表9显示了在8个GPU上测量的实际内存占用情况。

Table 9: 在8个GPU上测量的内存占用。

E 视觉实验

图像生成。实验使用DiT框架(【24, Scalable diffusion models with transformers, Peebles & Xie, 2022, William Peebles and Saining Xie. Scalable diffusion models with transformers. arXiv preprint arXiv:2212.09748, 2022.】)在ImageNet 256×256上进行。如表10所示,集成了超连接的DiT模型(DiT-XL/2-SHC×2)在参数量减少的情况下,性能指标与参数量多50%的DiT模型(DiT-1B/2)相当,甚至更好。

Table 10: ImageNet 256×256上的类条件图像生成基准测试,cfg=1.50。NP、P和R分别代表数值精度、精确率和召回率。

图像分类。实验在ViT/16-Base和ViT/16-Large模型上进行,使用扩展率n=2的SHC和DHC。如表11所示,SHC和DHC均显著提高了准确率。对于Base模型,DHC将准确率提升了0.88%。对于Large模型,DHC表现最佳,将准确率提升了2.69%。图11显示了训练损失曲线。

Table 11: ImageNet上的准确率。ViT*/16指(【11, An image is worth 16x16 words: Transformers for image recognition at scale, Dosovitskiy et al., 2020, Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.】)报告的结果,而ViT/16表示我们重新实现的基线。SHC和DHC表示残差连接分别被静态和动态超连接替换。


图 11: ViT/16-Large和ViT/16-Large-DHC×2的训练损失曲线,使用衰减率为0.999的指数移动平均(EMA)进行平滑。超连接带来的增益随着训练的进行而减小,这可能是因为在多个epoch中重复遍历同一数据集,导致超连接提供的额外容量的收益递减。

DHC可视化。如图12所示,我们可视化了ViT-Base/16-DHC×2模型最后一层DHC的动态连接权重。观察到,同一类别样本的beta值分布高度集中,而alpha值的分布则不那么集中,但不同类别之间的分布差异更明显。


图 12: ViT-Base/16-DHC×2模型中最后一个DHC的权重分布。

F 更多可视化与分析

展开超连接。我们首先介绍如何确定超连接的连接矩阵$C^{(0)}$。为了简化书写,层输出$T^k(h_0^k)$简写为$T^k$。公式2中的超连接递归形式展开如下:

$$\begin{aligned} \begin{aligned} \mathbf{h_{0}}^{k} & ={\mathbf{H}^{k}}^{\mathsf{T}}\mathbf{A_{m}}^{k}=(\mathcal{T}^{k-1}\mathbf{B}^{k-1}+{\mathbf{H}^{k-1}}^{\mathsf{T}}\mathbf{A_{r}}^{k-1})\mathbf{A_{m}}^{k} \\ & =\sum_{j=0}^{k-1}\mathcal{T}^{j}\mathbf{B}^{j}(\mathbf{A_{r}}^{j+1}\mathbf{A_{r}}^{j+2}...\mathbf{A_{r}}^{k-1})\mathbf{A_{m}}^{k} \\ & =\sum_{j=0}^{k-1}\mathcal{T}^{j}\mathbf{B}^{j}(\prod_{t=j+1}^{k-1}\mathbf{A_{r}}^{t})\mathbf{A_{m}}^{k}. \end{aligned} \end{aligned}$$

因此,我们得到连接矩阵 $c_{kj}^{(0)} = B^j(\prod_{t=j+1}^{k-1} A_r^t)A_m^k$。类似地,第k层第i个超隐藏向量的连接矩阵$C^{(i)}$可以通过将公式27中的最后一个$A_m^k$替换为$A_r^k$来计算。

超隐藏向量的可视化。我们在图13中可视化了超隐藏向量的连接矩阵,以揭示超连接如何维持中间层的输出。首先,四个超隐藏向量彼此不同,显示出完全不同的连接模式。其次,我们可以看到FFN层的输出在超隐藏向量中被长期保留,而注意力层的输出则保留较少。还观察到,长期连接通常成对存储在超隐藏向量中,其中一个超隐藏向量中的连接是正的,而另一个是负的,例如$C^{(1)}$和$C^{(3)}$中的第0和第2列。通过这种策略,这些连接可以在unembedding层之前的求和池化操作中被轻松消除。


图 13: 展开连接矩阵的可视化。从左到右的矩阵分别是 $C^{(0)}$ (对于 $\{h_j^0\}_{j=0}^{L+1}$ 的连接),$C^{(i)}$ (对于 $\{h_{ji}'\}_{j=0}^{L+1}$ 的连接),其中 $i \in \{1, 2, 3, 4\}$。

SHC与DHC共享相似的连接模式。我们在图13b中展示了OLMo-1B-SHC×4模型的连接矩阵。与图13a中的DHC相比,SHC共享完全相同的连接模式。此外,我们在SHC中观察到更多类似PTB的块,例如从第13层到第18层。注意,SHC的连接关系与token无关,因此这种类似PTB的块可以被物理上重组以进行并行计算。

HC×1为何失败。在我们的实验中,OLMo-1B-DHC×1模型的表现比基线差。其连接矩阵如图14所示,揭示了其失败的原因。首先,我们观察到第17层被浪费了,它与后续层完全没有连接。其次,与HC×2和HC×4模型相比,Λ形模式没有出现。需要注意的是,HC×1在其数学公式中不支持Λ模式,其中与前几层的连接必须同时被减弱或加强。因此,从早期层到最终层缺乏连接可能会遭受梯度消失的影响,类似于post-norm风格的transformer,从而导致性能下降。


图 14: OLMo-1B-DHC×1, OLMo-1B-DHC×2 和 OLMo-1B-DHC×4 模型展开连接矩阵的比较。

G 残差连接的非可训练超连接矩阵推导

Pre-Norm残差连接。Pre-Norm可以表示为 $\hat{h} = T(Norm(h)) + h$。通过将归一化操作符合并到层T中,即 $T := T \circ Norm$,我们可以将整个过程表示为 $\hat{h} = T(h) + h$。要使用超连接来表示这一点,Pre-Norm的矩阵可以构造成如下形式:

$$\begin{aligned} \mathcal{HC}_{PreNorm} = \begin{pmatrix} 0 & 1 \\ 1 & 1 \end{pmatrix} \end{aligned}$$

Post-Norm残差连接。Post-Norm可以表示为 $\hat{h} = Norm(h+h')$,其中 $h' = T(h)$。要用扩展率为n=1的超连接来表示此操作,需要一个封装此操作的超连接矩阵HC:

$$\begin{aligned} \mathcal{HC}_{PostNorm} = \begin{pmatrix} 0 & \frac{1}{\sqrt{\sigma_{\mathbf{h}}^{2}+\sigma_{\mathbf{h}'}^{2},+2\sigma_{\mathbf{hh}'}}} \\ 1 & \frac{1}{\sqrt{\sigma_{\mathbf{h}}^{2}+\sigma_{\mathbf{h}'}^{2},+2\sigma_{\mathbf{hh}'}}} \end{pmatrix} = \begin{pmatrix} 0 & \mathbf{B} \\ \mathbf{A}_{m} & \mathbf{A}_{r} \end{pmatrix}. \end{aligned}$$

其中 $\sigma_h^2$, $\sigma_{h'}^2$ 和 $\sigma_{hh'}$ 分别是输入h的方差、层输出h'的方差以及它们之间的协方差。

H 顺序-并行二元性

顺序排列的超连接矩阵。本节证明,以下超连接矩阵将产生n个相同的网络,这些网络按顺序排列,并且它们之间有残差连接:

$$\begin{aligned} \mathcal{HC} = \begin{pmatrix} \mathbf{0}_{1 \times 1} & \mathbf{1}_{1 \times n} \\ \mathbf{e}_1 & \mathbf{e}_{n \times n} \end{pmatrix}, \end{aligned}$$

其中 $e_{n \times n}$ 是一个n×n的单位矩阵,$e_i \in R^{n \times 1}$ 表示 $e_{n \times n}$ 的第i列,而 $1_{1 \times n}$ 表示一个1×n的全1矩阵。

并行排列的超连接矩阵。本节展示,以下超连接矩阵将产生一个网络,其中每n个相邻层并行排列,每层都包含残差连接。我们定义一个并行排列的网络,使得n个相邻层形成一个组,组内的层是并行的,而组与组之间是顺序排列的。可以证明,这种排列可以通过以下超连接矩阵来描述。首先,对于k满足 $k-1 \equiv 0 \pmod n$ 的层:

$$\begin{aligned} \mathcal{H}\mathcal{C}^{\{k \mid k-1 \equiv 0 \quad(\bmod n)\}}=\begin{pmatrix} \mathbf{0}_{1 \times 1} & \mathbf{e}_{1}^{\intercal} \\ \mathbf{1}_{n \times 1} & \mathbf{1}_{n \times n,} \end{pmatrix} \end{aligned}$$

接下来,对于k满足 $k-1 \equiv i \pmod n$ 且 $i \neq 0$ 的层:

$$\begin{aligned} \mathcal{HC}^{\{k | k-1 \equiv i \pmod n), i \neq 0\}} = \begin{pmatrix} \mathbf{0}_{1 \times 1} & \mathbf{e}_i^{\intercal} \\ \mathbf{e}_i & \mathbf{e}_{n \times n} \end{pmatrix}. \end{aligned}$$

I 超连接的伪代码

J 超连接的PyTorch实现

算法2:类似PyTorch风格的超连接伪代码。

# h: 超隐藏矩阵 (BxLxNxD)
class HyperConnection(nn.Module):
    def __init__(self, dim, rate, layer_id, dynamic, device=None):
        super(HyperConnection, self).__init__()
        self.rate = rate
        self.layer_id = layer_id
        self.dynamic = dynamic
        self.static_beta = nn.Parameter(torch.ones((rate,), device=device))
        init_alpha0 = torch.zeros((rate, 1), device=device)
        init_alpha0[layer_id % rate, 0] = 1.
        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye((rate), device=device)], dim=1))
        if self.dynamic:
            self.dynamic_alpha_fn = nn.Parameter(torch.zeros((dim, rate + 1), device=device))
            self.dynamic_alpha_scale = nn.Parameter(torch.ones(1, device=device) * 0.01)
            self.dynamic_beta_fn = nn.Parameter(torch.zeros((dim,), device=device))
            self.dynamic_beta_scale = nn.Parameter(torch.ones(1, device=device) * 0.01)
            self.layer_norm = LayerNorm(dim)
    
    def width_connection(self, h):
        # 获取 alpha 和 beta
        if self.dynamic:
            norm_h = self.layer_norm(h)
            if self.dynamic:
                wc_weight = norm_h @ self.dynamic_alpha_fn
                wc_weight = F.tanh(wc_weight)
                dynamic_alpha = wc_weight * self.dynamic_alpha_scale
                alpha = dynamic_alpha + self.static_alpha[None, None, ...]
            else:
                alpha = self.static_alpha[None, None, ...]
            if self.dynamic:
                dc_weight = norm_h @ self.dynamic_beta_fn
                dc_weight = F.tanh(dc_weight)
                dynamic_beta = dc_weight * self.dynamic_beta_scale
                beta = dynamic_beta + self.static_beta[None, None, ...]
            else:
                beta = self.static_beta[None, None, ...]
        
        # 宽度连接
        mix_h = alpha.transpose(-1, -2) @ h
        return mix_h, beta

    def depth_connection(self, mix_h, h_o, beta):
        h = torch.einsum("blh,bln->blnh", h_o, beta) + mix_h[..., 1:, :]
        return h

算法3:类似PyTorch风格的带超连接的Transformer伪代码。

# h: 超隐藏矩阵 (BxLxNxD)
# atten_hyper_connection, ffn_hyper_connection: 超连接模块
# attn_norm, ffn_norm: 归一化模块

# Attention Block
mix_h, beta = atten_hyper_connection.width_connection(h)
h = attn_norm(mix_h[..., 0, :])
h = self_attention(h)
h = atten_hyper_connection.depth_connection(mix_h, dropout(h), beta)

# FFN Block
mix_h, beta = ffn_hyper_connection.width_connection(h)
h = ffn_norm(mix_h[..., 0, :])
h = ffn(h)
h = ffn_hyper_connection.depth_connection(mix_h, dropout(h), beta)

K 验证集和下游任务

Table 13: OLMo的默认配置使用多种指标进行评估。V2和V3验证集使用困惑度(PPL)和损失,而下游基准测试则应用零样本测试。然而,由于性能指标不稳定,灰色基准测试被排除在我们的分析之外。

Table 14: OLMoE的下游基准。

L 1B模型实验补充图表


图 15: 相关工作的训练损失曲线,使用衰减率为0.99的指数移动平均(EMA)进行平滑。


图 16: 带有tanh的DHC在5000亿tokens上的训练损失曲线,使用衰减率为0.99的指数移动平均(EMA)进行平滑。


图 17: 不带tanh的DHC在5000亿tokens上的训练损失曲线,使用衰减率为0.99的指数移动平均(EMA)进行平滑。


图 18: 与并行Transformer块(PTB)比较的训练损失曲线,使用衰减率为0.99的指数移动平均(EMA)进行平滑。

Table 15: 1B模型在下游基准上的结果。