Nested Learning: The Illusion of Deep Learning Architectures
Nested Learning: The Illusion of Deep Learning Architectures
文章标题:嵌套学习:深度学习架构的幻象
作者/机构:Ali Behrouz, Meisam Razaviyayn, Peiling Zhong, Vahab Mirrokni / Google Research, USA
A1 主要贡献
核心问题与研究目标:近几十年来,人工智能研究的核心在于开发更强大的神经架构和设计有效的优化算法来训练它们。尽管取得了巨大进展,尤其是在语言模型(LMs)方面,但关于这些模型如何持续学习/记忆、自我改进以及找到“有效解”等基本挑战和问题仍未得到解答。现有的大型语言模型(LLMs)在部署后基本上是静态的,它们能够执行在预训练或后训练阶段学到的任务,但无法在其直接上下文之外持续获得新的能力。LLMs唯一的适应性组件是其上下文学习能力。这种静态特性类似于顺行性遗忘症,即患者无法形成新的长期记忆。当前LLM的记忆处理系统也存在类似模式,其知识仅限于上下文窗口内的即时上下文或存储在MLP层中的“预训练结束”前的长期知识。本文旨在提出一种新的学习范式,以解决现有模型的这些局限性。
创新点与主要贡献:本文提出了一种名为嵌套学习(Nested Learning, NL)的新学习范式。NL将一个模型连贯地表示为一组嵌套的、多层次的、和/或并行的优化问题,每个问题都有其自己的“上下文流”。NL揭示了现有的深度学习方法通过压缩其自身的上下文流来从数据中学习,并解释了上下文学习(in-context learning)如何在大模型中涌现。NL为设计更具表达能力的学习算法提供了一条新路径(深度学习的一个新维度),通过增加“层次”来获得更高阶的上下文学习能力。
本文的主要贡献有三点:
1. 深度优化器(Deep Optimizers):基于NL,本文证明了著名的基于梯度的优化器(如Adam、带冲量的SGD等)实际上是旨在通过梯度下降来压缩梯度的联想记忆模块。基于这一见解,本文提出了一系列具有深度记忆和/或更强大升级规则的更具表达力的优化器。
2. 自修改泰坦(Self-Modifying Titans):利用NL对学习算法的见解,本文提出了一种新颖的序列模型,该模型通过学习自己的更新算法来学习如何修改自身。
3. 连续谱记忆系统(Continuum Memory System):本文提出了一种新的记忆系统公式,推广了传统的“长期/短期记忆”观点。
通过将自修改序列模型与连续谱记忆系统相结合,本文提出了一种名为HOPE的学习模块,在语言建模、持续学习和长上下文推理任务中显示出有希望的结果。
A3 背景知识与动机
1.1 人类大脑视角与神经生理学动机
人类大脑的持续学习与记忆巩固。人类大脑在持续学习(即有效上下文管理)方面非常高效和有效,这通常归因于神经可塑性——大脑响应新体验、记忆、学习甚至损伤而改变自身的神奇能力。最近的研究支持,长期记忆的形成至少涉及两个不同但互补的巩固过程:
1. 快速“在线”巩固(也称为突触巩固):在学习后立即或很快发生,甚至在清醒时。此时,新的、最初脆弱的记忆痕迹被稳定下来,并开始从短期存储转移到长期存储。
2. “离线”巩固(也称为系统巩固):重复回放最近编码的模式——在海马体的尖波涟漪(SWRs)期间,与皮层的睡眠纺锤波和慢振荡相协调——从而加强和重组记忆,并支持其向皮层部位的转移。
LLM的局限性与大脑损伤的类比。回到顺行性遗忘症的类比,有证据表明该病症可以影响记忆巩固的两个阶段,尤其是在线巩固阶段,主要是因为海马体是编码新陈述性记忆的门户,其损伤意味着新信息将永远无法存入长期记忆。LLM的设计,特别是基于Transformer的骨干网络,在预训练阶段后也存在类似情况。即,在上下文中提供的信息永远不会影响长期记忆参数(例如前馈层),因此模型无法获得新知识或技能,除非信息仍存储在短期记忆中(例如注意力机制)。因此,尽管第二阶段(离线巩固)对于记忆巩固同样甚至更为关键,但本文的工作重点是第一阶段:将记忆巩固作为一个在线过程。
符号说明。本文使用 $x \in \mathbb{R}^{N \times d_{in}}$ 表示输入, $M_t$ 表示记忆/模型 $M$ 在时间 $t$ 的状态, $K$ 表示键(keys), $V$ 表示值(values), $Q$ 表示查询(query)矩阵。使用带下标 $t$ 的粗体小写字母表示对应于输入 $t$ 的向量(即 $k_t, v_t, q_t$ )。用 $p(f)$ 表示任意实体 $f$ 的分布。本文使用带 $L_M \ge 1$ 层和残差连接的简单MLP作为记忆模块 $M(\cdot)$ 的架构。必要时,用 $\theta_M \supseteq \{W_1, W_2, \dots, W_{L_M}\}$ 对记忆模块进行参数化。使用带括号的上标来指代嵌套学习中不同层次(不同更新频率)的参数:即 $W^{(\ell)}$。
A2 方法细节
2. 嵌套学习 (Nested Learning)
本节讨论嵌套学习(NL)的动机、形式化定义和一般性高层含义。首先从联想记忆的公式化入手,然后通过逐步的例子,建立架构分解及其与将神经网络建模为优化问题集成系统之间联系的直觉。目标是首先展示深度学习中现有方法和概念如何归入NL范式,然后提出超越传统方法的新公式,和/或为如何改进现有算法和设计提供见解。
2.1 联想记忆
联想记忆的定义。联想记忆是形成和检索事件之间联系的能力,是人类学习中不可或缺的基本心智过程。在文献中,记忆和学习的概念常被互换使用;然而,在神经心理学文献中,这两者有明确区分。本文基于神经心理学文献【57,Learning and memory. Proceedings of the National Academy of Sciences, 2000】的定义来构建术语体系:记忆是由输入引起的神经更新,而学习是获取有效和有用记忆的过程。本文旨在证明计算序列模型的所有元素,包括优化器和神经网络,都是压缩自身上下文流的联想记忆系统。广义上讲,联想记忆是一个将一组键映射到一组值的操作符。本文遵循Behrouz等人【58,It’s all connected: A journey through test-time memorization, attentional bias, retention, and online optimization. arXiv preprint arXiv:2504.13173, 2025】对联想记忆的通用定义。
定义1(联想记忆)。给定一组键 $K \subseteq \mathbb{R}^{d_k}$ 和值 $V \subseteq \mathbb{R}^{d_v}$,联想记忆是一个操作符 $M: K \to V$,它映射两组键 $K$ 和值 $V$ 。为了从数据中学习这种映射,一个目标函数 $\tilde{\mathcal{L}}(\cdot; \cdot)$ 用来衡量映射的质量, $M$ 可以被定义为:
联想记忆与学习过程的联系。操作符本身是一个记忆,映射过程扮演着记忆化的角色(即记忆上下文中的事件连接),而基于数据获取这样一个有效的操作符则是一个学习过程。值得注意的是,这里的键和值可以是记忆旨在映射的任何任意事件,不限于词元(tokens)。在序列建模中,当键和值是输入词元时,选择不同的目标函数和优化过程来求解方程1可以得到不同的序列建模架构,如全局/局部softmax注意力【27,Attention is all you need. In NIPS, 2017】或其他现代循环模型【28,Titans: Learning to memorize at test time. arXiv preprint arXiv:2501.00663, 2024; 60; 61】。这个简单的序列模型公式化有助于我们更好地理解其内部过程,并提供一个工具来简单地根据其目标和优化过程比较它们的建模能力。
MLP训练的简单示例。我们从一个简单的例子开始:用梯度下降法在数据集 $\mathcal{D}_{\text{train}} = \{x_1, \dots, x_{|\mathcal{D}_{\text{train}}|}\}$ 上为任务 $\mathcal{T}$ 训练一个单层MLP(由 $W$ 参数化),优化目标函数 $\mathcal{L}(\cdot; \cdot)$ 。训练过程等价于以下优化问题:
用梯度下降法优化该问题,得到的权重更新规则为:
将反向传播视为联想记忆优化。基于上述公式,可以令 $u_{t+1} = \nabla_{y_{t+1}} \mathcal{L}(W_t; x_{t+1})$,并将反向传播过程重新表述为寻找一个最优联想记忆的优化问题,该记忆将输入数据点 $\mathcal{D}_{\text{train}} = \{x_t\}_{t=1}^{|\mathcal{D}_{\text{train}}|}$ 映射到它们对应的 $u_{t+1} = \nabla_{y_{t+1}} \mathcal{L}(W_t; x_{t+1})$ 。具体来说,我们让 $M(\cdot) = W_t \cdot$ 参数化该记忆,并使用点积相似度来衡量 $W_t$ 在 $x_{t+1}$ 和 $\nabla_{y_{t+1}} \mathcal{L}(W_t; x_{t+1})$ 之间映射的质量:
引入动量(Momentum)的梯度下降。接下来,在上述例子中,我们将梯度下降算法替换为其增强的基于动量的变体,得到更新规则:
在方程8中,给定方程7在时间 $t$ 的状态, $\nabla_W \mathcal{L}(W_t; x_{t+1})$ 或类似地 $\nabla_{y_{t+1}}\mathcal{L}(W_t; x_{t+1})$ 的值与方程8中的递归无关,因此可以预先计算。为此,我们令 $u_{t+1} = \nabla_{W_t}\mathcal{L}(W_t; x_{t+1})$,于是方程8可以重新表述为:
$W_{t+1} = W_t - \mathbf{m}_{t+1},$
$\mathbf{m}_{t+1}=\underset{\mathbf{m}}{\arg \min }-\left\langle\mathbf{m}, \nabla_{W_{t}} \mathcal{L}\left(W_{t} ; x_{t+1}\right)\right\rangle+\eta_{t+1}\left\|\mathbf{m}-\mathbf{m}_{t}\right\|_{2}^{2}$
动量项的解释。基于这些公式,动量项可以被解释为:(1) 一个将梯度压缩到其参数中的无键(key-less)联想记忆,或者 (2) 一个学习如何将数据点映射到其对应LSS值的联想记忆。有趣的是,这个公式揭示了带冲量的梯度下降实际上是一个两级优化过程,其中记忆由简单的梯度下降算法优化。这个过程与快速权重编程(Fast Weight Programs, FWPs)【62,Learning to control fast-weight memories: An alternative to recurrent nets. Neural Computation, 1992】密切相关,其中权重更新过程(方程9)是慢网络,其动量权重由快网络(方程10)生成。总结以上例子,我们观察到训练一个单层MLP:(1) 使用梯度下降是一个学习如何将数据点映射到其LSS值的单级联想记忆;(2) 使用带冲量的梯度下降是一个两级联想记忆(或优化过程),其内层学习将梯度值存储到其参数中,然后外层用内层记忆的值更新慢权重(即 $W_t$)。
架构分解示例:线性注意力。在下一个例子中,我们将MLP模块替换为线性注意力【60,Transformers are rnns: Fast autoregressive transformers with linear attention. In ICML, 2020】。我们旨在用梯度下降法在一个序列 $\mathcal{D}_{\text{train}} = \{x_1, \dots, x_{|\mathcal{D}_{\text{train}}|}\}$ 上为任务 $\mathcal{T}$ 训练一个单层线性注意力,优化目标 $\mathcal{L}$。回顾未归一化的线性注意力公式:
$\mathcal{M}_t = \mathcal{M}_{t-1} + \mathbf{v}_t \mathbf{k}_t^{\mathsf{T}},$
$y_t = \mathcal{M}_t \mathbf{q}_t.$
如早期研究【58; 59】所述,方程13中的递归可以被重新表述为一个矩阵值联想记忆 $M_t(\cdot)$ 的优化过程,其中它旨在将键和值的映射压缩到其参数中。具体来说,在定义1中,如果我们令 $\tilde{\mathcal{L}}(M_{t-1}; k_t, v_t) := -\langle M_{t-1}k_t, v_t \rangle$ 并旨在用梯度下降法优化记忆,记忆更新规则为:(注意 $\nabla \tilde{\mathcal{L}}(M_{t-1}; k_t, v_t) = v_t k_t^\top$ 并且我们令学习率 $\eta_t = 1$)
线性注意力的两级优化视角。训练线性注意力的过程可以被看作一个两级优化过程。外循环(即训练过程)用梯度下降法优化投影层 $W_k, W_v, W_q$(其本身是单级联想记忆优化,见方程3),而内循环用梯度下降法优化内部记忆 $M_t$。这里存在两个联想记忆,每个都有自己的优化过程和梯度流。在外层参数 $W_k, W_v, W_q$ 的优化中,没有关于参数 $M(\cdot)$ 的梯度,因此没有通过它的反向传播。类似地,在内层中,没有通过投影层的反向传播,它们被认为是冻结的。这个公式也与线性注意力的FWPs视角【63,Linear transformers are secretly fast weight programmers. In ICML, 2021】密切相关,其中投影被认为是慢权重,而方程13中的记忆更新是快权重更新规则。
更多层次的架构分解。在实践中,我们需要使用更强大的优化算法来训练模型,和/或使用更强大的循环更新规则来更新记忆。例如,假设我们使用带冲量的梯度下降来训练线性注意力模型。这个模型可以表示为一个2级优化问题,其中(1) 内层用梯度下降优化记忆以压缩上下文(见方程15),(2) 外层用带冲量的梯度下降优化投影层。而我们已知“带冲量的梯度下降”本身就是一个2级优化问题,其中动量项本身就是一个联想记忆,将过去的梯度压缩到其参数中。因此,整个系统构成了一个更深层次的嵌套优化结构。
2.2 嵌套优化问题
形式化定义:更新频率。前述例子展示了如何将一个机器学习模型分解为一组嵌套或多层次的优化问题。为了在这些问题之间建立一个层次结构,本文受大脑脑电波层次结构的启发,使用每个优化问题的更新速率来对组件进行排序。我们将一个数据点上的一次更新步骤定义为时间单位,并定义每个组件的更新频率率。
定义2(更新频率)。对于任何组件 $A$(可以是参数化组件如可学习权重,或非参数化组件如注意力块),我们将其频率定义为其每单位时间的更新次数,记为 $f_A$。
组件排序与层次划分。基于更新频率,我们可以用操作符 $(\succ)$ 对机器学习算法的组件进行排序。如果 (1) $f_A > f_B$,或者 (2) $f_A = f_B$ 但计算 $B$ 在时间 $t$ 的状态需要计算 $A$ 在时间 $t$ 的状态,我们称 $A$ 比 $B$ 快,记为 $A \succ B$。当 $A \not\succ B$ 且 $B \not\succ A$ 时,我们记为 $A \stackrel{f}{=} B$,表示 $A$ 和 $B$ 具有相同的更新频率且它们的计算相互独立(例如AdamW优化器中的情况)。基于此操作符,我们将组件排序为一组有序的“层次”,其中 (1) 同一层次的组件具有相同的更新频率,(2) 层次越高,其频率越低。每个组件都有自己的优化问题和上下文,并且拥有专属的梯度流。
神经学习模块。基于嵌套优化问题的定义,我们定义神经学习模块(Neural Learning Module)作为一种新的机器学习模型表示方式,它将模型展示为一个由多个组件互联而成的系统,每个组件都有自己的梯度流。与深度学习正交,嵌套学习允许我们定义具有更多层次的神经学习模型,从而得到更具表达能力的架构。嵌套学习使得由多个(多层)层次组成的计算模型能够以不同的抽象层次和时间尺度从数据中学习和处理数据。
2.3 优化器即学习模块
将现有优化器视为嵌套学习特例。本节首先探讨著名优化器及其变体如何成为嵌套学习的特例。回顾带冲量的梯度下降法:
扩展1:更具表达力的关联。动量是一个无值(value-less)的联想记忆,表达能力有限。为了解决这个问题,遵循联想记忆的原始定义(将键映射到值),我们让值参数 $v_i = P_i$,此时动量旨在最小化:
$\min_{\mathbf{m}} \left\langle \mathbf{m} \nabla \mathcal{L}(W_i; x_i)^{\top}, \mathbf{P}_i \right\rangle,$
使用梯度下降,得到更新规则:
扩展2:更具表达力的目标函数。优化点积相似度的内部目标会导致类赫布(Hebbian-like)更新规则,可能使记忆效率降低。一个自然的扩展是使用 $\ell_2(\cdot)$ 回归损失来衡量键值映射的拟合度,最小化损失函数 $\|\mathbf{m}\nabla\mathcal{L}(W_i; x_i)^\top - \mathbf{P}_i\|_2^2$,得到基于delta规则【64,Neural network capacity using delta rule. Electronics Letters, 1989】的更新:
扩展3:更具表达力的记忆模块。动量可以被看作是使用一个线性层来压缩过去梯度值的元记忆模型。为了增加该模块的学习能力,一个替代方案是使用更强大的持久学习模块,例如用一个MLP替换线性的矩阵值记忆。因此,动量作为过去梯度的记忆,有更大的容量来捕捉梯度的潜在动态。我们将方程17的公式扩展为:
$W_{i+1} = W_i + \mathbf{m}_{i+1}(\mathbf{u}_i), \quad \text{and} \quad \mathbf{m}_{i+1} = \alpha_{i+1}\mathbf{m}_i - \eta_t \nabla \mathcal{L}^{(2)}(\mathbf{m}_i; \mathbf{u}_i, \mathbf{I})$
其中 $u_i = \nabla\mathcal{L}(W_i; x_i)$ 且 $\nabla\mathcal{L}^{(2)}(\cdot)$ 是动量的内部目标(例如点积相似度 $\langle m(u_i^\top), 1\rangle$)。我们称此变体为深度动量梯度下降(Deep Momentum Gradient Descent, DMGD)。
扩展4:非线性输出。将动量视为一个神经架构,一个常见的增强其表示能力的技术是在其输出之上使用非线性函数【28; 65】。我们将方程23重新表述为:
$W_{i+1}=W_i+\sigma(\mathbf{m}_{i+1}(\mathbf{u}_i)), \quad \text{and} \quad \mathbf{m}_{i+1}=\alpha_{i+1}\mathbf{m}_i-\eta_t\nabla\mathcal{L}^{(2)}(\mathbf{m}_i;\mathbf{u}_i,\mathbf{I}),$
其中 $\sigma(\cdot)$ 是任意的非线性函数。例如,如果令 $\sigma(\cdot)$ 为迭代的Newton-Schulz方法【66,Functions of matrices: theory and computation. SIAM, 2008】,且 $m(\cdot)$ 是一个线性层,那么得到的优化器等价于Muon优化器【34,Muon: An optimizer for hidden layers in neural networks, 2024b】。
超越简单反向传播。预训练过程和反向传播是一种联想记忆形式,其中输入数据被映射到其预测输出引起的意外信号 $\nabla_y \mathcal{L}(W_t; x_t)$:
$W_{t+1} = W_t - \eta_{t+1} \nabla_{W_t} \mathcal{L}(W_t; x_t) = W_t - \eta_{t+1} \nabla_{x_t} \mathcal{L}(W_t; x_t) \otimes x_t, \quad \text{where } x_t \sim \mathcal{D}_{\text{train}}$
从联想记忆的角度看,这等价于优化以下目标的一步梯度下降:
上述公式忽略了数据样本(如 $x_t$)之间的依赖关系。为了扩展到一个更强大的公式,使其也考虑数据点的依赖性(这在使用词元空间优化器时尤其重要),我们使用L2回归目标并进行一步梯度下降:
3. HOPE: 一个具有连续谱记忆的自引用学习模块
连续谱记忆系统(Continuum Memory System, CMS)。现有的架构骨干包含一个工作记忆模块(如注意力)和一个前馈层(如MLP),后者作为预训练阶段的持久记忆或知识存储。从NL的角度来看,预训练是学习模块最外层在其有限上下文流上更新的阶段。因此,我们扩展了传统的长短期记忆观点,并为每个层次(频率域)建议一个知识存储前馈层。CMS被形式化为一系列MLP块 $MLP^{(f_1)}(\cdot), \dots, MLP^{(f_k)}(\cdot)$,每个块关联一个块大小 $C^{(\ell)} := \max_{\ell} C^{(\ell)} f$,使得给定输入 $x = \{x_1, \dots, x_T\}$,链的输出计算如下(为清晰起见,忽略了归一化):
$y_t = \text{MLP}^{(f_k)}(\text{MLP}^{(f_{k-1})}(\dots \text{MLP}^{(f_1)}(x_t))))$
其中第 $\ell$ 个MLP块的参数 $\theta^{(f_\ell)}$ 每 $C^{(\ell)}$ 步更新一次:
这里 $\eta_t^{(\ell)}$ 是对应于 $\theta^{(f_\ell)}$ 的学习率,$f(\cdot)$ 是任意优化器的误差分量。常规的Transformer块【27】是此公式的一个特例,其中 $k=1$。方程31提供了一个重要解释:参数 $\theta^{(f_\ell)}_t$ 负责将它们自己的上下文压缩到其参数中,因此它们是其上下文抽象知识的代表。
HOPE架构。我们进一步提出了一个基于Titans【28】和我们在B.1节中提出的梯度下降变体的自引用学习模块。将这个自引用序列模型与连续谱记忆系统相结合,得到了HOPE架构。
A4 实验
实验环境
- 模型:HOPE架构与多种基线模型进行比较,包括Transformer++、RetNet、DeltaNet、TTT、Samba*、Titans (LMM)。
- 模型规模:实验涵盖了三种不同的参数规模:340M、760M和1.3B。
-
数据集与任务:
- 语言建模:在WikiText和LAMBADA数据集上评估,指标为困惑度(ppl),越低越好。
- 常识推理:在一系列下游任务上进行评估,包括PIQA、HellaSwag (Hella.)、Winogrande (Wino.)、ARC-easy (ARC-e)、ARC-challenge (ARC-c)、Social IQA (SIQA)、BoolQ,指标为准确率(acc),越高越好。
-
硬件与软件:论文中未提供详细的硬件和软件配置,仅提及更多细节见附录G。
实验结果
本文在语言建模和常识推理任务上对HOPE进行了评估。详细的实验结果、包括优化器实验、上下文学习的涌现、HOPE的持续学习能力、消融研究、长上下文任务等,均报告在附录中。
语言建模与常识推理结果:
遵循近期的序列建模研究【28, 67, 68】,本文报告了340M、760M和1.3B参数规模的HOPE及基线模型在语言建模和常识推理下游任务上的结果。如表1所示,HOPE在所有规模和基准测试任务中都表现出非常好的性能,优于Transformer和包括Gated DeltaNet和Titans在内的近期现代循环神经网络。与Titans和Gated DeltaNet相比,HOPE根据上下文动态改变键、值和查询的投影以及使用深度记忆模块,使其在基准测试中获得了更低的困惑度和更高的准确率。
表1:HOPE与基线模型在语言建模和常识推理任务上的性能。混合模型用*标记。
A5 结论
本文提出了一种名为嵌套学习(Nested Learning, NL)的新学习范式,它将机器学习模型及其训练过程重新概念化为一组嵌套的、多层次的优化问题,每个问题都有其独立的上下文流和更新频率。这种视角揭示了现有深度学习架构和优化算法的内在机制,例如,证明了常用的优化器(如带冲量的SGD)本质上是用于压缩梯度的联想记忆模块。
基于NL的见解,本文做出了三大贡献:
1. 开发了更具表达力的“深度优化器”,通过引入更强大的记忆模块(如MLP)、更有效的内部目标(如L2损失)和非线性输出来增强传统优化器的能力。
2. 设计了名为HOPE的自修改序列模型,它结合了自引用学习模块(受Titans启发)和新颖的连续谱记忆系统(CMS)。CMS通过在不同频率上更新不同层次的MLP,推广了传统的长短期记忆概念,允许模型在不同时间尺度上整合和巩固知识。
3. 实验验证了HOPE的有效性,在多种规模的语言建模和常识推理任务上,HOPE的性能显著优于现有的Transformer和现代RNN基线模型。
总的来说,嵌套学习为超越传统的“堆叠层数”方法提供了一个新的维度,为设计能够持续学习、自我改进并具备更强上下文推理能力的模型开辟了新的道路。
A6 附录
本文多次提及附录中包含更多的方法细节、实验结果和讨论,但所提供的文档中未包含附录内容。
💬 评论讨论
欢迎在这里分享您的想法和见解!