Weight-sparse Transformers Have Interpretable Circuits

作者/机构:Leo Gao, Achyuta Rajaram, Jacob Coxon, Soham V. Govande, Bowen Baker, Dan Mossing

A1 主要贡献

在语言模型中寻找人类可理解的回路是机制可解释性领域的核心目标。本文通过约束模型的大部分权重为零来训练模型,使其拥有更易于理解的回路,这样每个神经元只有少数几个连接。为了恢复支持多个手工任务的细粒度回路,我们对模型进行剪枝,以分离出负责该任务的部分。这些回路通常包含与自然概念相对应的神经元和残差通道,它们之间存在少量直接可解释的连接。我们研究了这些模型的扩展性,发现增加权重稀疏性是在能力和可解释性之间进行权衡,而扩大模型规模可以改善能力-可解释性的前沿。然而,在保持可解释性的同时将稀疏模型扩展到数千万个非零参数以上仍然是一个挑战。除了从头开始训练权重稀疏模型,我们还展示了初步结果,表明我们的方法也可以用于解释现有的密集模型。我们的工作产生了达到了前所未有的人类可理解性水平的回路,并通过相当严格的验证证实了它们。


图 1. 我们整体设置的示意图。我们首先训练权重稀疏模型。然后,对于一系列精心策划的简单行为,我们将模型剪枝到执行该任务所需的节点子集。我们通过将节点剪枝到其在预训练分布上的平均激活值来对其进行消融。

核心问题与研究目标

  • 核心问题:大型语言模型等神经网络虽然能力强大,但其内部工作机制(即它们实现的算法)仍然很不清晰。一个主要的困难是激活和权重不易直接理解,例如神经元的激活模式与人类可理解的概念不对应。一个假设的原因是“叠加”(superposition)【20, Toy models of superposition, 2022, arXiv】,即密集模型是一个更大、更清晰的稀疏网络计算的近似。
  • 研究目标:本文旨在通过引入一种新的范式来产生更简单、更通用的回路,这些回路即使在最低的抽象层次上也能够被完全理解。

创新点

  1. 权重稀疏训练 (Weight-sparse training):本文提出训练权重绝大多数为零的Transformer模型(即权重的L0范数很小)。这种约束极大地简化了模型的计算。由于每个神经元只能从少数几个残差通道读取或写入,模型被阻止将概念表示分布在多个残差通道上,或使用超过表示单个概念所需的最少神经元。

  2. 可分离且紧凑的任务回路:通过隔离执行每个任务的最小回路,我们表明模型为不同任务学习了解耦的回路。在这些回路中,我们发现神经元激活通常对应于简单的概念,例如“跟在单引号后面的词元”或“列表嵌套的深度”,并且权重编码了概念之间直观的联系。

  3. 严格的回路验证:我们证明了这些解耦的回路对于模型在这些任务上的行为是必要且充分的。保留回路中的少数节点而对其他所有节点进行平均值消融,可以保持任务损失不变;而删除回路中的少数节点则会严重损害任务损失。


    图 2. 我们的权重稀疏模型学习到比密集模型更简单的特定任务回路。我们检查了一个稀疏模型和一个具有相同预训练损失的密集模型。我们扫描目标损失,并找出每个模型中能够达到该损失的最小回路大小,然后在各个任务上取平均值。在任何给定的损失下,稀疏模型的回路大约小16倍。

  4. 可扩展性分析:我们发现,增加权重稀疏模型的总参数数量可以改善其在“能力-可解释性”上的帕累托前沿。在固定总参数数量的情况下,使模型更稀疏(即减小权重的L0范数)会在这两者之间进行权衡,损害能力但提高可解释性。


    图 3. 扩展权重稀疏模型的总参数数量改善了能力-可解释性的帕累托前沿。在总参数数量固定的情况下,使模型更稀疏(即减小权重的L0范数)会在这两者之间进行权衡,损害能力但提高可解释性。我们将能力定义为预训练损失;可解释性的定义见第2.2节。左下方表示更好。

  5. 解释现有密集模型:本文展示了初步结果,表明权重稀疏训练可以通过在每一层使用“桥接”(bridges)将其应用于理解现有的密集模型。这些桥接将权重稀疏模型的表示与目标密集模型的表示耦合起来,使得我们的模型可以作为原始密集模型的可解释替代品。

A2 方法细节

我们首先训练权重稀疏模型——即大部分参数为零的Transformer模型。我们所有的模型都在一个Python代码数据集上进行预训练。然后,我们在一个精心策划的简单、明确的任务套件上检查模型的行为,在这些任务中,模型被迫在两种可能的补全之间做出选择。

为了评估我们模型的可解释性,我们使用一种新颖的剪枝方法来分离出模型用于执行每项任务的小型稀疏回路。由于可解释的模型应该易于解耦,单个行为应该由紧凑的独立回路来实现。

稀疏回路被定义为一组由边连接的节点。我们对节点的定义是最大程度细粒度的,对应于权重矩阵的行和列:我们将节点定义为单个神经元、注意力通道、残差通道读取或残差通道写入。因此,边是权重矩阵中的一个非零项,连接两个节点。


图 4. 字符串闭合回路。我们省略了任何细节,展示了近乎完美完成任务所需的全部12个节点和9条边。首先,0.mlp将词元嵌入转换为“引号检测器”和“引号类型分类器”残差通道,它们分别被10.attn中的键和值通道读取。随后的词元关注该键并复制该值以预测相应的闭合引号。在图中,每个输入词元下方的垂直线束是其残差流。左侧显示了在精心挑选的任务示例上重要节点的激活。水平虚线标记了层的边界。⊗表示标量乘法;直接合并的线表示标量加法。黑色数字表示通道或神经元索引。红色和蓝色数字标记正负权重(或偏置)。此图仅显示相关的注意力路径。回路的非活动部分被灰色化,不相关的层被省略。

我们将手工策划任务中回路的边的几何平均数作为我们主要的定量可解释性指标。


图 5. 计数嵌套深度的电路简化示意图,使用了图4的约定。单个注意力值通道作为从词元 [ 的嵌入派生出的“开括号检测器”。然后,注意力头对上下文中该检测器的值进行平均,并将其写入每个词元的残差流(即“嵌套深度”)。随后的注意力头使用一个查询通道读出嵌套深度,并对其进行阈值处理,使其仅在嵌套列表中激活。该电路使用7个节点和4条边。理解这个算法使我们能够用“干扰项”对模型进行对抗性攻击。


图 6. 跟踪变量类型的电路粗略示意图。该模型使用一个包含2个注意力头的两跳算法,总共使用4个查询/键通道和3个值通道。首先,它将变量名 current 复制到 set() 词元中。然后,它将此作为键,允许模型将 set() 词元的值复制到最终的词元位置,在那里它读出正确的答案。

2.1. 稀疏训练

模型架构。我们使用一个类似于Radford等人【Language models are unsupervised multitask learners, 2019, OpenAI blog】的GPT-2风格的仅解码器Transformer,并做了一些微小修改。我们对所有权重和偏置(包括词元嵌入)强制施加稀疏性,这样我们可以在保持非零参数数量(L0范数)完全不变的情况下增加模型的宽度。我们最稀疏的模型大约每1000个权重中有一个非零。我们还在所有节点位置强制实施温和的激活稀疏性,每4个激活中有一个非零。更多细节和消融研究请参见A.1节。

优化过程。我们使用AdamW优化器【56, Decoupled weight decay regularization, 2019, ICLR】来最小化交叉熵损失。为了强制执行L0权重稀疏性约束,我们在每个训练步骤中应用AdamW后,将每个权重矩阵中除了幅度最大的条目之外的所有条目都置为零,以使每个矩阵具有相同的非零元素比例。我们在整个训练过程中将L0范数从完全密集退火到目标L0范数。有关用于确保优化稳定性的技术的更多细节和消融研究,请参见A.2节。

2.2. 测量可解释性

任务分布。我们手动构建了一组20个简单的Python下一词元二元预测任务。例如,一个任务(单双引号)是预测用单引号还是双引号来闭合一个字符串,其中上下文中唯一的区别是开头的引号词元是单引号还是双引号。另一个任务(集合或字符串)通过询问一个变量名后面应该是.add还是+=来衡量模型跟踪变量类型的能力,其中示例仅在变量的初始化上有所不同。所有任务的描述见表1。

剪枝方法。对于每个任务,我们对模型进行剪枝,以获得在任务分布上达到目标损失的最小回路。除非另有说明,目标损失在所有地方都为0.15。我们通过删除所有词元位置上的部分节点子集来进行剪枝,这类似于Cao等人【11, Low-complexity probing via finding subnetworks, 2021, arXiv】的方法。被删除的节点被“均值消融”——即它们的激活被冻结在预训练分布上的平均激活值。关于各种消融方法的影响的讨论,请参见E节。我们提出了一种新颖的结构化剪枝算法。我们学习一组掩码$τ_i$(由节点索引),用于门控相应的节点位置$x_i \to x_i \odot \sigma(\tau_i)$,其中$\sigma$是亥维赛德阶跃函数。我们通过使用sigmoid导数代理梯度来反向传播通过亥维赛德阶跃函数(类似于直通估计器(Straight-Through Estimator)【2, Estimating or propagating gradients through stochastic neurons for conditional computation, 2013, arXiv】),从而训练掩码参数$τ_i$,最小化任务损失和回路大小的联合目标。更多细节请参见A.5节。

2.3. 桥接 (Bridges)

引入桥接以理解密集模型。在第3.3节中,我们将我们的方法扩展到理解已经训练好的密集模型。我们训练一个权重稀疏模型,同时训练一系列桥接,这些桥接在每个子层(即每个注意力块和每个MLP块之前)将密集模型和稀疏模型的激活进行映射。每个桥接由一个将密集模型激活映射到稀疏模型激活的编码器和一个反向映射的解码器组成。

损失函数。我们希望权重稀疏模型能够匹配密集模型的计算,同时桥接能够准确地在稀疏和密集激活之间进行转换。为此,除了常规的预训练损失外,我们还使用了多个桥接损失项(图7)。我们使用一个归一化均方误差(MSE)项,训练桥接编码器从密集激活中准确预测稀疏激活(桥接解码器则反之)。我们还运行稀疏模型和密集模型的混合前向传播,在不同的单个位置使用桥接将一种激活转换为另一种。我们训练稀疏模型的权重,使得这些混合前向传播与原始密集模型相比具有较低的KL散度。关于我们设置的更精确陈述,请参见A.3节。

3. 实验结果

3.1. 权重稀疏性提升可解释性

首先,我们测量稀疏性是否允许模型学习更小的回路。我们针对每个任务,在一个密集模型和一个预训练损失相当的稀疏模型中计算最小回路,并对所有任务的回路大小取平均值(图2)。结果显示,对我们的权重稀疏模型进行剪枝,得到的回路比对预训练损失相当的密集模型剪枝得到的回路小约16倍。我们还能够以更多边的代价构建任意精确的回路。这表明,在权重稀疏模型中,简单行为的回路比在密集模型中更加解耦和局部化。


图 7. 从一个现有的密集模型开始,我们可以将权重稀疏模型与桥接(一系列允许我们在稀疏和密集模型表示之间转换的线性映射)联合训练,从而使得所有通过稀疏和密集层混合的路径在预训练中仍然表现良好。

为了进一步验证我们回路的忠实性,我们还表明它们不仅是充分的,而且是任务所必需的。当我们消融(ablate)构成我们回路的极小部分节点,而保持网络其余部分不变时,模型性能会受到严重损害(图32)。

我们还发现我们的方法随着模型规模的扩大而改进。当我们增加模型的隐藏维度,同时保持层数不变时,我们改善了可解释性-能力的前沿(图3)。改变L0范数会沿着这个前沿移动,在能力和可解释性之间进行权衡。最后,如果我们固定L0范数并使用一个更大的模型,我们发现能力和可解释性都得到了提升。具有相同L0范数的更大模型表达能力更强,并且每个神经元或残差通道的非零权重更少。

可解释性的另一个方面是特征质量,根据经验,这似乎与激活稀疏性密切相关。我们发现增加权重稀疏性会自然地增加残差流激活的稀疏性(图10)。

3.2. 定性回路研究

我们的最终目标是机械地完全理解模型的计算过程。虽然我们在任务上找到了小回路,并通过均值消融验证了它们忠实地实现了模型的底层计算,但回路大小仅仅是可解释性的一个代理指标。为了验证模型是否实现了人类可理解的算法,我们手动解释了两个模型中三个不同任务的剪枝回路(选择它们是因为它们看起来很简单)。对于每个回路,我们大约花费了一个研究员日的工作量。这包括手动移除无关的节点,以及通过手动激活修补实验来验证我们对节点的自然语言描述。

稀疏回路的一个独特前景是,我们可以在没有任何特定任务数据或假设的情况下,期望能解释它们。乐观地看,在将节点与自然概念联系起来之后,我们可以通过直接检查它们的边来提取相关回路。为此,对于以下回路,我们报告了每个回路中组件的总边数相对于被解释子集的比例。通常,如果组件的总边数较少,我们预计追踪它们的回路会更容易,它们的边也可能更具可解释性。


图 10. 我们发现权重稀疏性会诱导残差流中的激活稀疏性。由于残差流的条目很少完全为零,我们改为测量峰度。随着权重L0范数变小,或总参数数量变大,最终残差流激活的峰度增加。


图 8. 使用我们的回路发现的一个令人惊讶的针对括号计数回路的对抗性样本。因为我们的模型使用同一特征的更强和更弱的激活来表示不同的括号嵌套深度,并且还对上下文中该特征取平均值,所以应该可以通过在上下文中放入更多词元来欺骗模型。我们发现我们的模型(即使未剪枝)在较长的列表上会经历显著的“上下文稀释”,无法正确预测 ]] 的补全。这是3.2.2节中所理解的回路的自然结果;随着平均值在更多词元上计算,激活的幅度减小(与上下文长度的倒数成正比)。引人注目的是,我们发现这种攻击可以推广到能力相似的密集模型。

3.2.1. 闭合字符串

我们从一个简单的任务(单双引号)开始,输入中包含一个以单引号或双引号开头的字符串,该字符串必须以相应类型的引号闭合。我们发现这个任务的回路使用两个步骤,涉及一个MLP层中的两个神经元和一个注意力头(使用一个QK通道和一个V通道)。我们相信以下机制反映了模型在我们任务分布上闭合字符串的方式。

在第一步中,最早的MLP层(0.mlp)将("和('的嵌入组合成一个“引号检测器”神经元(通道985,在("和('上均为正值),以及一个“引号类型分类器”神经元(通道460,在("上为正值,在('上为负值)。在第二步中,第10层的注意力头(10.attn.head82)使用“引号检测器”作为键(通道1),使用“引号类型分类器”作为值(通道0)。由于最后一个词元有一个恒定的正值查询,注意力头的输出成功地闭合了字符串。该回路的详细示意图见图4。

我们还检查了我们找到的节点在整个预训练分布上是否具有相同的解释。我们发现,即使在预训练分布上,也只有部分而非全部节点是完全单义的——一个例子见图39。

总的来说,所描述的四个组件(两个MLP神经元、一个QK通道和一个V通道)总共有41条边连接到网络的其余部分,其中9条被该回路使用——基于此,我们乐观地认为,在没有特定任务数据的情况下,有可能理解这个回路。

3.2.2. 计算嵌套深度

接下来,我们研究一个更复杂的任务,探究模型是否能分别用 ] 或 ]] 适当地闭合一个扁平列表或嵌套列表(括号计数)。我们手动为这个任务提取了一个最小回路。我们相当有信心地认为,模型在我们任务分布上通过以下三个步骤来计算列表的嵌套深度:

  1. 嵌入。最后一个词元的词元嵌入被忽略。[ 词元的词元嵌入写入到残差通道759、826和1711。这些通道成为 [ 词元的“括号检测器”。
  2. 计数。有了这些括号检测器激活,模型接着使用第2层中的一个值通道来计算括号数量。模型将嵌入通道加总到一个值通道(头125通道12),该通道作为一个“开括号检测器”,在每个开括号词元处激活。注意力头125在整个序列中具有接近零的查询和恒定的键,这意味着softmax注意力操作相当于对上下文进行平均。然后,头125将得到的平均开括号检测器值写入残差通道1249。残差通道1249通过其幅度编码“列表深度”。
  3. 阈值化。然而,为了确定我们是否要输出 ]] ,我们需要将我们的列表深度阈值化为一个二元输出。第4层中的第二个注意力头(头80)通过使用一个强注意力汇(attention sink)来实现这一点,将列表深度用作查询通道激活(通道4)。注意力softmax起到阈值的作用;扁平列表和列表外的词元的$q \cdot k \ll \text{sink}$,而嵌套列表的$q \cdot k \gg \text{sink}$。因此,头80仅在嵌套列表上(查询足够大时)向残差流通道1079(“嵌套列表闭合”)输出一个正值,并输出 ]] 。

我们对括号计数的机理理解可以用来准确预测模型在未见过的相关输入上的表现。该回路依赖于对上下文中先前见过的词元的简单平均,这在列表前出现“干扰项”未匹配的开括号时会失败。使用一个对抗性的代码注释,我们可以成功地欺骗模型在一个扁平列表上预测双括号补全。这种攻击,连同一个回路示意图,呈现在图5中。此外,我们知道注意力回路使用残差通道1249的激活幅度来表示不同的括号嵌套深度,并且还对上下文长度取该特征的平均值。因此,可以通过让列表变得非常长来使(未剪枝的)模型在嵌套列表上错误地预测 ] ,从而“稀释”上下文。这种效应与残差通道1249的激活幅度完全相关(更多细节见图8)。我们发现这种攻击甚至可以推广到能力相当的密集模型,这表明这种攻击并非权重稀疏模型的特有现象。

这个回路总共使用了6个通道,有283条边连接到网络的其余部分。上面省略的另外11个第3层注意力通道和48个第7层MLP神经元还贡献了1217条边。如果没有特定任务的数据,要追踪这个回路可能会很困难。

3.2.3. 跟踪变量类型

即使描述模型行为的回路没有完美地解耦成少数几个可单独解释的激活,我们的模型也倾向于学习部分可解释的计算图。在“set or string fixedvarname”任务中,通过检查回路,我们认为模型使用以下两步算法来跟踪一个名为current的变量是集合还是字符串。首先,给定输入current = set()current = "",第4层中一个关注最近词元的头(头73)通过值通道8将current的嵌入复制到set()""词元中。然后,当模型需要在序列后面回忆变量“current”的值时,第6层中的一个头(头26)使用“current”的嵌入作为查询和键激活。因此,该头将完成任务所需的信息从set()""词元复制到最终词元。这个算法的示意图在图6中概述。

所描述的回路在连续两层的两个注意力头中使用了4个QK通道和3个V通道。这些通道总共有100条边连接到网络的其余部分。如果没有特定任务的数据,可能很难理解这个回路。

3.3. 使用桥接从现有模型中提取回路

到目前为止,我们所有的结果都来自从头开始训练的权重稀疏模型。然而,如果我们也能用我们的方法来理解已经训练好的模型,那将非常有价值。密集模型在计算上比稀疏模型高效得多,并且确认稀疏模型中的回路与密集模型中的回路在机理上是类似的,这将很有价值。

我们在这个方向上做了一个初步的探索,通过训练一个权重稀疏模型,其计算与同一层的密集模型计算相对应(图7)。然后,对权重稀疏模型激活的“可解释”扰动可以映射到密集模型激活的相应扰动,以实现期望的行为改变。


图 9. 使用稀疏模型和桥接来编辑现有密集模型的表示。我们在两个任务(单双引号和while return true)中对密集模型的激活进行“可解释的扰动”,以测试桥接的稀疏模型是否忠实于密集模型。在左图中,我们尝试编辑密集模型的“引号类型分类器”表示,以诱导模型表现得像提示补全单引号字符串而非双引号字符串。在右图中,我们尝试编辑密集模型关于当前行是以 if、while 还是 except 开头的表示,以诱导模型表现得像提示补全 return True 而非 while True。在这两种情况下,密集模型的行为都与至少部分成功的编辑相一致。

使用剪枝,我们识别出能够在特定任务上表现良好的最小稀疏回路,如2.2节所述。我们从这个回路中手动选择一个节点,这个节点似乎(1)基于消融对任务很重要,并且(2)编码了某个感兴趣的特征。我们在稀疏模型中扰动这些节点,并使用桥接将扰动线性映射到原始的密集模型(见A.4节)。

如图9所示,这个过程允许我们构建对密集模型激活的扰动,这些扰动与改变感兴趣的特征是一致的。

在第一个案例中,我们在一个4层密集模型和一个桥接的4层稀疏模型中研究了“单双引号”任务。这个稀疏模型的回路在性质上类似于图4中描述的那个。我们扰动了稀疏模型最后一个注意力层输入处的一个残差通道,该通道充当“引号类型分类器”(激活见图40)。我们用一个双引号字符串提示模型,并引导这个通道,使其激活类似于单引号字符串的激活,以构建一个可解释的扰动。将这个扰动通过桥接应用后,密集模型输出单引号的概率急剧增加。这与编辑模型存储在引号词元中的引号类型表示是一致的。

在第二个案例中,我们研究了另一个任务,“while return true”,其中模型被期望在while True后输出一个:词元,但在return True后输出一个换行符。对于这个任务,我们检查了与同一密集模型耦合的第二个桥接的4层稀疏模型,并操纵了最后一个MLP层输入处的一个通道。这个通道的激活在以ifwhileexcept开头的行中(这些行都应该以冒号结尾)始终是高度负值的(激活见图41)。我们用以return True结尾的代码提示模型,并将这个通道引导向其反事实的while True激活。将这个扰动通过桥接应用后,密集模型输出冒号而非换行符的概率增加(尽管不如前一个任务那么陡峭)。这与部分编辑密集模型关于当前行是否应根据第一个词元以冒号结尾的表示是一致的。

A5 结论

由于基本限制,非结构化的权重稀疏神经网络不太可能达到密集网络的效率(见B节)。因此,使用我们的方法来完全解释前沿模型,或者从头开始训练可解释的前沿模型将是不可行的。我们需要克服这一障碍,以帮助我们提高对前沿模型的理解。我们对两个主要途径感到兴奋。

首先,我们可以扩展我们的方法,创建一系列可解释的模型生物(model organisms),能力水平达到GPT-3。Transformers似乎学习了在稀疏和密集模型中跨尺度出现的通用回路基序(universal circuit motifs)。如果是这样,那么研究这些模型生物的回路基序将让我们了解在前沿模型中寻找什么样的基序,并帮助我们更好地定位我们的研究。特别地,如果我们创建的模型生物的计算通过桥接与密集模型耦合,比较它们的计算对于研究密集模型中的叠加和干扰权重等现象将是有价值的【Olah et al., 2025】。

其次,尽管在整个预训练分布上理解一个前沿模型成本高得令人望而却步,但我们可以通过寻求理解更少的内容来节省计算资源。特别是,我们可以在一个狭窄但重要的任务分布上(例如,欺骗、拒绝、目标寻求)训练一个稀疏的桥接模型。尽管这可能无法实现对前沿模型行为的雄心勃勃的逆向工程,但这可能成为安全案例的一个有价值的工具。

我们也对使用稀疏回路支持自动化可解释性感到兴奋。稀疏回路,像字典学习方法一样,为理解模型计算提供了新的原语——一种计算表达更简单的语言。我们怀疑自动化可解释性的瓶颈在于这类原语,这使得稀疏回路成为自动化的自然补充。

局限性与未来工作

我们相信我们的方法还有许多改进空间。
* 计算(低)效率:与同等能力的密集模型相比,稀疏模型需要多出100-1000倍的训练和推理计算。在优化和系统方面都有很大的改进空间。我们对更好的重初始化技术、更高效地探索不同稀疏性掩码的方法以及系统改进(如使用稀疏核函数)感到兴奋。我们也对权重稀疏的专家混合模型(MoE)感兴趣。
* 多义性特征 (Polysemantic features):我们的回路,特别是对于更复杂的任务,并非完全由单义性节点或边组成。一个可能是模型利用少量叠加是有利的,另一个可能是将稀疏模型的宽度扩展到与典型SAE相近的规模可能会解决这个问题。
* 非二元值特征 (Non-binary-valued features):我们的特征并非都是可二元化的,它们有时在其幅度中携带超出开关状态的信息。如果一个特征不能被离散化,我们需要解释其幅度。
* 定义忠实性 (Defining faithfulness):均值消融不是衡量忠实性的完美标准。最终,需要某种形式的因果擦洗(causal scrubbing)【12, Causal scrubbing, a method for rigorously testing interpretability hypotheses, 2022, AI Alignment Forum】才能对我们回路的忠实性获得完全的信心。
* 定义可解释性 (Defining interpretability):我们使用的“可解释性”概念(拥有紧凑的特定任务回路)并未完全捕捉直观的可解释性概念。我们的定性研究指向了一个更强的人类可理解性概念,我们可以尝试将其编入一个改进的可解释性度量中。
* 其他可解释的归纳偏置:权重稀疏性可能不是完全解开叠加所需的唯一归纳偏置。例如,专家稀疏性不仅可以提高效率,还可以提供重要的归纳偏置。
* 更好的剪枝:我们的剪枝方法专注于剪枝节点,但理想情况下应该直接剪枝边。此外,我们的算法通常不能完全消除所有可剪枝的节点,需要额外的手动剪枝。
* 超越小模型和简单任务的扩展:我们的方法是否能扩展尚不确定。即使在乐观的假设下,解释更强大模型中的复杂行为也会产生极其庞大和复杂的回路,可能需要依赖自动化可解释性。悲观地看,更强大的语言模型可能以根本上无法简单描述的方式执行复杂任务,这将限制整个雄心勃勃的机制可解释性的前景。

A6 附录

A. 方法细节

我们的许多实验在不同时间点使用了略有不同的设置。因此,不同图表中的数字难以比较,这些方法细节是一般性规则,但并不完全适用于我们训练的每一个模型。

A.1. 架构

我们的架构非常接近标准的GPT-2 Transformer架构【Radford et al., 2019】。大多数实验使用nlayer = 8, dmodel = 2048, nctx = 256,除非另有说明。为了确保零值在残差流中具有特殊意义,我们使用RMSNorm【80, Root mean square layer normalization, 2019, arXiv】而不是LayerNorm。使用RMSNorm还使我们能够将所有归一化权重折叠到MLP/注意力权重中,而不改变权重的L0范数。我们的嵌入和解嵌矩阵是不共享的。对于某些模型,我们使用注意力汇(attention sinks)【79, Efficient streaming language models with attention sinks, 2023, arXiv】(一种每个头可学习的注意力分母偏置),我们发现这能产生更清晰的注意力回路,而对损失影响不大(图18)。


图 11. 带有1%预热和前50%训练中L0衰减的“鲨鱼鳍”学习率调度。

为了强制激活的小L0范数,我们在模型的多个位置应用一个AbsTopK激活函数,将除了幅度最大的k个值之外的所有值都置为零;我们通常将k设置为该位置维度的1/4,除非另有说明。图12展示了AbsTopK激活函数在注意力和MLP层中的放置位置。我们发现,在权重稀疏性的基础上,一定量的激活稀疏性有所帮助,但过多的激活稀疏性会再次损害能力-可解释性前沿(图37)。

我们还有一些证据表明激活稀疏性是权重稀疏性的自然结果:我们观察到激活的峰度随着权重稀疏性的增加而增加(图10)。

我们还添加了一个二元组表(bigram table)——一个单一的密集$d_{vocab} \times d_{vocab}$矩阵,将最近一个词元的条目添加到最终的logits中——以避免稀疏参数需要记忆二元组频率。直观地说,向模型添加具有非常简单解释的密集组件是可取的(二元组的先验对数概率);它们有助于改善损失(图24),并避免宝贵的MLP和注意力空间被这些信息占用。

在大多数实验中,我们完全不使用位置嵌入【35, Transformer language models without positional encodings still learn positional information, 2022, arXiv】;我们发现这对损失大致是中性的(图26)。在一些早期的实验中,我们使用了学习的绝对位置编码,作为只读通道连接到残差流:我们发现这提高了可解释性,导致注意力头内部的激活模式更稀疏。

我们使用较小的d_head(本文大多数实验为16),我们发现这在经验上有助于提高注意力头的单义性,但代价是系统效率。

A.2. 权重稀疏模型优化

我们使用AdamW【56, Decoupled weight decay regularization, 2019, ICLR】【42, Adam: A method for stochastic optimization, 2014, arXiv】,其中$\beta_1 = 0.9, \beta_2 = 0.95, \lambda = 0.1, \epsilon = 0.1$,并对每个实验都扫描了学习率lr。这些超参数的消融研究见图27、图28、图29和图30。

我们还避免将那些会导致神经元或注意力通道非零值少于$j=4$个的值清零,以减少死神经元的几率(图22)。我们在训练的前50%线性地退火L0范数(图17),使得模型在训练过程中变得越来越稀疏【81, To prune, or not to prune: Exploring the efficacy of pruning for model compression, 2017, arXiv】。我们的学习率调度由一个正常的预热-衰减调度和一个因子$1/\sqrt{L0}$的乘积定义,因为我们发现较小的L0范数需要较大的学习率。我们发现这一改变对于从L0退火中获得任何好处是必要的(图25)。

我们对梯度的均方根进行裁剪,上限为1。我们发现梯度裁剪对于确保训练稳定性至关重要(图16)。

在某个时候,我们发现我们的L0调度存在一个错误,即虽然我们的权重具有预期的L0调度,但嵌入、解嵌和偏置矩阵会在训练中期的一个相对较小的步数内从密集变为稀疏。然而,修复这个错误似乎在能力和可解释性方面都略微损害了模型质量,所以我们保留了它(图21)。

我们发现学习率预热(在大多数实验中为训练的前1%)对于在较高学习率下的稳定性至关重要,显著改善了最优损失(图13)。我们还发现进行更长的预热稍微有益(图14)。

我们通常没有发现通过增加词元预算在当前边际上改善能力-可解释性帕累托前沿;虽然增加词元预算改善了预训练损失,但通常会损害剪枝后的回路大小。

表 1. 我们创建的所有手工任务列表。

A.3. 桥接损失项

我们希望权重稀疏模型能够匹配密集模型的计算,同时桥接能够准确地在稀疏和密集激活之间进行转换。为此,除了正常的预训练损失外,我们还使用了三个桥接损失项。设$h_i^d$和$h_i^s$分别是密集模型和稀疏模型在第i层的残差激活,$M^d, M^s$是模型的子层(MLP或注意力块),因此$h_0^d, h_0^s$是密集模型和稀疏模型的词元嵌入,且$M_i^d(h_i^d) = h_{i+1}^d, M_i^s(h_i^s) = h_{i+1}^s$,$f_i$和$g_i$是桥接的编码器和解码器,$y^d = M_{unemb}^d(h_L^d)$是最终的logits。首先,我们有一个归一化均方误差项:

$$\mathcal{L}_{\text{NMSE}} = \sum_i^L \text{NMSE}(f_i(h_i^{\text{d}}), h_i^{\text{s}}) + \text{NMSE}(g_i(h_i^{\text{s}}), h_i^{\text{d}})$$

其次,我们有一个KL项来训练稀疏模型接受来自密集模型的激活:

$$\mathcal{L}_{\text{KL}, d \to s} = \sum_{i} \text{KL}\left(y^d, (M^s_{\text{unemb}} \circ M^s_L \circ \dots \circ M^s_i \circ f_i)(h_i^d)\right)$$

还有一个反向的项——密集模型接受来自稀疏模型的激活:

$$\mathcal{L}_{\text{KL},s \to d} = \sum_i \text{KL} \left(y^{\text{d}}, \left(M_{\text{unemb}}^{\text{d}} \circ M_L^{\text{d}} \circ \dots \circ M_i^{\text{d}} \circ g_i\right) (h_i^s)\right)$$

(图7)。

理想情况下,我们希望在所有$2^L$种稀疏和密集层的组合上进行训练。我们可以将我们的KL项视为对此的一阶近似,其中我们只考虑了稀疏和密集之间只有一次转换的组合子集。

A.4. 桥接干预

桥接被训练用于在密集模型和稀疏模型的残差流位置之间进行转换。我们希望重用这些桥接,将稀疏模型中的单节点干预转换为密集模型中“可解释的”密集干预。

桥接作用于残差流位置(RMSNorm之前),因此它们学会了补偿稀疏模型和密集模型之间残差流尺度的差异。因为这些位置在我们的框架中不是“节点”,所以在进行干预时,我们转而使用下一个存在的节点,即RMSNorm之后的残差流读取。残差流尺度考虑不适用于RMSNorm之后,此时两个模型的激活都具有$d_{model}$尺度的范数。为了补偿这一点,在进行干预之前,我们将桥接权重乘以在参考数据集上平均的RMS残差流激活的比率。


图 12. 分别是我们注意力和MLP块的示意图。AbsTopK激活函数被插入到每个节点位置,在每个操作之间。AbsTopK分别应用于Q、K和V。

为了构建稀疏模型干预,我们首先将在给定条件下(例如双引号)感兴趣通道的激活从反事实条件下(例如单引号)的激活中减去,这在所有词元上进行。然后,我们将其与桥接的相应行(按上述方式缩放)进行外积,并构建一个形状为(词元数,密集模型隐藏维度)的张量。我们将其乘以一个介于0.0(无干预)和1.0(完全修补)之间的“引导强度”。


图 13. 学习率预热消融


图 14. 学习率预热比例消融(1%)


图 15. 学习率衰减消融


图 16. 梯度裁剪消融


图 17. 用于退火的训练比例消融(基线为50%)


图 18. 注意力汇消融


图 19. 稀疏偏置(基线)与密集偏置消融


图 20. d_head扫描(基线=16)


图 21. 偏置和嵌入矩阵的L0衰减调度


图 23. L0退火消融


图 22. 每个神经元强制激活的权重数量消融(基线为 j = 4)


图 24. Bigram表消融


图 25. 鲨鱼鳍(基线)与正常预热-衰减学习率调度


图 26. 位置嵌入消融


图 27. Adam β1消融(基线=0.9)


图 28. Adam β2消融(基线=0.95)


图 29. Adam λ消融(基线=0.1)


图 30. Adam ϵ消融(基线=0.1)

A.5. 剪枝算法

设置。首先,我们在模型中的几个位置插入布尔掩码:
* 在注意力和MLP块中每个RMSNorm之后立即插入
* 在每个注意力和MLP块的末尾,将其结果加回到残差流之前
* 在每个MLP激活之后
* 在注意力q, k, v激活之后
我们将这些位置中掩码的每个元素视为一个“节点”,掩码决定了该节点是包含在回路中还是被排除。我们在本文的许多其他部分重用了这个节点的定义。请注意,我们在节点级别应用掩码,而不是像一些以前的剪枝方法那样在边(即节点之间的交互)级别应用【4, Finding transformer circuits with edge pruning, 2025, arXiv】。我们在所有提示和所有词元上统一应用掩码。


图 31. 学习剪枝的性能显著优于基于归因的剪枝基线,在所有损失目标下都能找到更小的回路。

我们为每个节点学习一个被限制在[-1, 1]范围内的参数,该参数用于通过亥维赛德阶跃函数在每一步计算一个布尔掩码,从而决定哪些节点被包含在最终的回路中。通过优化掩码,我们可以学到极小的回路(与基于梯度归因于任务损失选择顶部节点的基线相比,如图31所示)。

初始化。我们使用高斯噪声进行初始化,噪声由一个init_noise_scale因子缩放,并以init_noise_bias为中心,然后将掩码裁剪到[-1, 1]。我们在每个训练步骤后都强制执行这个裁剪约束。

优化。我们使用AdamW和梯度裁剪来优化掩码。我们在整个训练过程中线性衰减学习率,没有任何预热。为了通过亥维赛德阶跃函数进行反向传播,我们使用一个sigmoid估计器来计算反向传播中的有偏梯度近似,温度为heaviside_temp

损失函数。我们在任务交叉熵和k(掩码中非零元素的数量)的线性组合上优化我们的掩码。我们通过一个k_coef项来加权这些损失。

掩码离散化。训练后,我们通过二分查找来找到恰好达到阈值的k;我们观察到有时这与训练过程输出的最终k有很大差异。由于我们发现离散化后的模型通常校准得不好,我们使用16步LBFGS【57, On the limited memory BFGS method for large scale optimization, 1989, Mathematical Programming】来优化最终logits的缩放+平移变换。目前尚不清楚这在一般情况下是否是原则性的做法。

超参数优化。由于超参数数量众多且设置困难,我们使用CARBS【24, Tune as you scale: Hyperparameter optimization for compute efficient training, 2023, arXiv】为每个模型和任务的组合重新调整超参数。我们运行32次CARBS迭代,每次迭代有8个并行的剪枝任务,从表2中给出的初始超参数开始。我们通常发现,在最初的几次迭代中结果很差,而在CARBS调整的后期结果会显著改善。

我们使用的批次大小为64个任务数据点(每个数据点由一个正序列和一个负序列组成),总共128个序列,长度最多为256个词元。

表 2. CARBS的初始搜索中心

A.6. 数据集

我们的数据由GPT-4-base和GPT-4o-base混合生成的Python代码组成。数据包括两个部分:一个100亿词元的Python代码部分,旨在包含特别简单和重复的习语;以及一个250亿词元的部分,旨在更接近Python代码的完整分布。

我们使用更简单数据部分的决定受到了Eldan & Li【18, Tinystories: How small can language models be and still speak coherent english?, 2023, arXiv】的启发,我们希望我们的混合数据能让我们在较小规模上更容易观察到有趣的回路,同时又不会使数据过于玩具化,尽管我们没有对此进行消融研究。本文中一些较小规模的实验仅在较简单的部分上进行训练。

我们使用我们的数据集训练了一个2048个词元的BPE分词器。

A.7. 任务分词考虑

对于“单双引号”任务,我们考虑词元("和(',因为它们总是作为独立的词元出现,不像"和'可能会以多种方式被分词。对于闭合引号词元也是如此。

对于“括号计数”任务,],是一个词元,所以]不太可能表示非最后一行行的末尾。


图 32. 反向剪枝:消融通过剪枝找到的回路会严重损害整个模型的损失。随着找到的回路变大(以及任务性能提高),模型在没有它的情况下会变得(稍微)更差。

B. 扩展的系统考虑

我们的权重稀疏训练和剪枝过程是为了可解释性而优化的,而不是为了硬件效率。特别是,用于强制执行L0约束的(每个张量的)top-k幅度掩码会引入非结构化稀疏性。从系统角度来看,非结构化稀疏性对现代GPU不友好:gather/scatter操作破坏了高性能GEMM核所依赖的平铺数据流,减少了片上SRAM的数据重用,并且通常无法使用Tensor Cores。由于Tensor Cores提供了绝大多数可用的数学吞吐量,默认的密集路径几乎总是比尝试在CUDA核上实现矩阵乘法更可取。

Hopper和Blackwell架构上的GEMM数据流。现代高性能GEMM核通常使用围绕密集子矩阵的批量移动和Tensor Cores计算而设计的平铺流水线:
1. 通过张量内存加速器(TMA)进行分块移动。TMA启动从全局内存(HBM)到共享内存(SMEM)的连续张量分块的异步传输。
2. 在分块上进行MMA。Tensor Core指令对SMEM中的大数据分块进行操作,并计算最大尺寸为128x256的矩阵乘法。
3. Epilogue和全局存储。累加器可选择性地与偏置/激活的epilogue融合,并写回HBM。
这个数据流假设了规整、连续的分块。非结构化掩码(我们的top-k选择)创建了不规则的稀疏模式,这破坏了批量的TMA复制和稳定的Tensor Core指令发行。

即使在非GPU硬件上,由于实现稀疏GEMM所需的根本复杂性,稀疏模型也预计会比密集模型效率低得多。使用大于一的分块大小的平铺数据流会导致内存带宽浪费在移动大部分零上,因此需要直接单独路由权重。从根本上说,GPU芯片上需要“额外空间”来布线任意到任意的连接(将权重矩阵中的每个条目从内存移动到其对应的硬件算术电路),因此即使是为稀疏模型高度优化的硬件实现也无法达到使用类脉动阵列架构的计算密度。这一限制意味着本文所理解的权重稀疏模型在核函数效率方面不可能赶上密集模型。尽管如此,我们提出了一些在现代GPU上提高效率的想法,超越了基线的朴素实现。

选项1:CUDA核GEMM。根据加速器上CUDA核FLOPS与Tensor Core FLOPS的比率,我们可以决定调度到为在CUDA核上运行而编写的稀疏GEMM核,而不是Tensor Core。CUDA核没有分块要求,因此可以高效地操作稀疏数据,但缺点是它们要慢得多。我们发现在低于1%的稀疏度水平上,CUDA核GEMM比它们的Tensor Core对应物更高效。

选项2:半结构化(2:4)稀疏性。当一个操作数满足2:4模式时(在归约维度上,每四个连续元素中至少有两个零),稀疏Tensor Cores可以加速GEMM,理论上最大可获得2倍的数学吞吐量加速和减少的内存流量。我们尝试了2:4半结构化稀疏性,发现它在吞吐量上有所改善,尤其是在稀疏度较低的区域(即1%到50%)。然而,改善通常小于理论上的2倍。将剪枝和压缩融合到矩阵乘法中是具有挑战性的,因此为了最小化内存带宽使用,我们理想情况下会将前向和后向(转置)布局的2:4剪枝和压缩融合到top-k核中。

我们通过二分搜索一个阈值来找到k,这需要非常少的内存写入,从而实现了巨大的加速。这种加速在激活稀疏性期间最为明显,因为小行可以独立地适应于不同的warp中,从而允许在批处理维度上实现大规模并行。这种类型的实现可以进一步加速;我们对近似top-k的方法感到兴奋,因为模型在训练期间似乎对此不特别敏感,只要最后几步是精确完成的。

尽管这些核函数理论上可以提供巨大的加速,但我们当前的优化栈不适合在训练期间使用它们。需要进行大量的进一步优化研究,才能在所有地方使用稀疏核函数,以及所有模型参数的稀疏表示(以节省内存)。在项目早期,我们尝试了各种优化更改以启用核函数的更改,但许多这些系统更改需要以不利于快速迭代的方式来架构我们的代码库,因此我们最终决定在这项工作中专注于研究速度。如果没有对系统方法的额外研究,我们认为我们不太可能大幅扩展我们的模型,但我们相信下面概述的所有问题都是可以解决的。

TopK核函数。我们发现大量时间花费在AbsTopK操作上,包括前向传播(强制激活稀疏性)和训练步骤(强制权重稀疏性)。PyTorch的朴素TopK实现使用基数排序操作,速度很慢。

dW计算。虽然前向传播的矩阵乘法和dx的矩阵乘法都是稀疏权重乘以密集激活矩阵,但dW的计算需要两个密集矩阵的乘积。因此,dW计算将渐近地成为限制因素。重要的是,只有dW的L0个元素会实际用于计算,但这些元素是通过W + dW的top-k选择的(假设我们使用SGD;我们将在下一节讨论Adam)。虽然内存使用问题比较容易解决(我们可以将TopK融合到矩阵乘法核中),但计算使用是一个更大的挑战。我们认为可能可以近似W + dW的top-k操作,或者将其摊销到许多步骤中,类似于Evci等人【22, Rigging the lottery: Making all tickets winners, 2021, arXiv】的方法。

Adam动量。密集内存和计算使用的另一个来源是Adam动量。在早期的实验中,我们探索了将Adam动量剪枝到$m \cdot L0$个最重要的条目,针对不同的m值和动量重要性度量,但总发现这在合理的m值下是一个不可忽略的优化损失。我们可以采取几种方法:我们可以使用像SGD或signSGD【3, signsgd: Compressed optimisation for nonconvex problems, 2018, arXiv】这样的优化器;如果系统整体上的收益是值得的,我们可以接受Adam动量剪枝的成本;我们可以使用稀疏动量Adam与内存高效的类Adam优化器Adafactor【71, Adafactor: Adaptive learning rates with sublinear memory cost, 2018, arXiv】的某种混合。

退火。我们目前发现在大量训练中将我们的模型从密集退火到稀疏对优化有很大帮助。如果我们要大规模扩展,这将是一个巨大的成本。我们也尝试过在最初几个训练步骤中快速衰减的退火计划,并且仍然达到了相似的最终测试损失。


图 33. 稀疏模型和密集模型在单个词元上的损失是强相关的。这种相关性与使用不同种子训练的密集模型之间的相关性相当。

C. 特征幅度的二值化

当一个SAE特征激活时,它可以取非常广泛的不同可能强度。因此,要完全理解一个特征在做什么,仅仅展示它在某个概念存在时才激活是不够的;我们还必须解释为什么它激活到那个程度,或者声称其幅度没有意义【12, Causal scrubbing, a method for rigorously testing interpretability hypotheses, 2022, AI Alignment Forum】。

因此,我们通过在每个节点位置插入一个阶跃函数来对模型进行二值化,然后测量平均任务损失。

二值化算法。我们构建一个激活函数$\psi_{t, l, r}$,使得$\psi_1$是恒等函数,而$\psi_0$是一个阶跃函数

$$\begin{aligned} \psi_{0}(x)=\left\{\begin{array}{ll} \ell & \text { if } x<(\ell+r) / 2 \\ r & \text { otherwise, } \end{array}\right. \end{aligned}$$


这样我们就可以在这两者之间进行插值,并且使得$\psi_t(l) = l$和$\psi_t(r) = r$,通过将恒等函数与一个适当平移的、温度为t的sigmoid函数混合。在整个训练过程中,我们以$(1 - \text{progress})^5$的方式对t进行退火。

为了初始化每个节点的参数$l$和$r$,我们遍历每个节点,并搜索可能的选择,以找到导致性能下降最小的选择。具体来说,我们尝试在观察到的最大和最小激活之间取1/4、1/2和3/4位置的阈值,并取阈值之上和之下激活子集的平均值作为$l$和$r$。


图 34. 在回路中所有特征二值化后的任务损失,在所有任务上取平均值。由于显而易见的原因,L0 > 总参数的数据不可用。

结果。我们发现我们通常可以对任务进行二值化,虽然结果非常嘈杂,但总体趋势是,随着总参数增加而L0保持不变,可二值化性会增强(图34)。然而,与回路大小度量(其中减小L0会提高可解释性)不同,我们通常发现在总参数数量不变的情况下增加L0会提高可二值化性。所有基线任务损失开始时都相同,因为我们剪枝到一个固定的目标任务损失。

从定性探索的经验来看,我们的一些特征似乎不可二值化,但仍然看起来是可理解的——例如,一些节点,比如括号计数中的那个,取3个不同的语义值(分别表示在列表外、在单层嵌套列表内、或在双层嵌套列表内);其他节点,特别是注意力键通道,在整个上下文中持续增长,以便更多地关注最近的词元。

D. L0范数的平滑近似

先前的工作已经发现了几种在语言模型中强制稀疏性的技术。特别是,Louizos等人【57, Learning sparse neural networks through L0 regularization, 2018, arXiv】发现使用基于HardConcrete分布的L0范数的可微估计器取得了成功。在这项工作中,我们在一个玩具语言建模任务上将这种方法与我们的TopK变体进行了基准比较。

为了提高性能,我们对Louizos等人【57, Learning sparse neural networks through L0 regularization, 2018, arXiv】中概述的技术进行了两项主要修改。即,
* 初始化:Louizos等人【57, Learning sparse neural networks through L0 regularization, 2018, arXiv】选择将所有参数初始化为大致相同的值。相反,我们从一个缩放的伯努利分布中采样原始参数值,并具有固定的初始稀疏度。
* 稀疏性下限:通常,特别是在高稀疏性正则化水平下,稀疏模型会学会停用其所有权重,从而阻止模型进一步学习。为了缓解这个问题,并更好地控制最终的稀疏度,我们将稀疏性惩罚裁剪到一个固定的最小值。

我们发现该技术的性能始终比TopK差,在所有稀疏度水平上损失都略高(图35)。这与之前的工作,如Gale等人【31, The state of sparsity in deep neural networks, 2019, arXiv】的研究结果一致,后者在神经机器翻译上对这些技术进行了基准测试。

我们还尝试使用这种技术进行剪枝,类似于Cao等人【11, Low-complexity probing via finding subnetworks, 2021, arXiv】,但未能超越我们的基线。


图 35. 比较我们的方法(基线)与Louizos等人(2018)的方法。

E. 用均值消融验证回路假设

一旦我们相信在我们的稀疏模型中找到了一个回路,我们需要确保该回路确实反映了模型的内部行为。没有忠实性,解释可能看起来合理,但不一定准确反映模型的底层计算【6, An interpretability illusion for BERT, 2021, arXiv】【58, Is this the subspace you are looking for? an interpretability illusion for subspace activation patching, 2023, arXiv】【29, Interpretability illusions in the generalization of simplified models, 2024, arXiv】。因此,为了从可解释性中获得安全保证,我们必须有某种程序来验证我们的解释确实是忠实的。

我们的工作在非常细粒度的节点上,大大提高了电路忠实性的现有技术水平,但它绝不是最大程度忠实的。

我们深受Causal Scrubbing(因果擦洗)【12, Causal scrubbing, a method for rigorously testing interpretability hypotheses, 2022, AI Alignment Forum】的启发。不幸的是,按照因果擦洗的标准,我们的回路并非完全忠实。因果擦洗有两个组成部分:
1. 我们声称在语义上相同的任何两个节点值应该是可以互换的。例如,如果我们声称一个神经元的值大于x当且仅当该词元是一个包含字符串的变量,那么必须是这样:将该节点更改为任何大于x的分布内值都不会影响性能。
2. 我们声称不相关的任何节点必须可以用从预训练期间观察到的值的分布中抽取的任何节点值来替代。

正如我们在C节中看到的,虽然我们的模型在某些任务上满足条件1,但它们远非普遍能够做到。在这方面,我们的方法显示出一些希望的迹象,但表现不佳。

至于条件2,仅使用均值是条件2的一个严格弱化版本。在早期的结果中,我们发现即使仅仅是因果擦洗的条件2也比均值消融难满足得多。然而,我们声称,使用均值消融代替完整的预训练分布所带来的忠实性损失并没有那么糟糕。

Chan等人【12, Causal scrubbing, a method for rigorously testing interpretability hypotheses, 2022, AI Alignment Forum】声称均值消融是次优的,因为它(a)可能以一种无原则的方式将模型带出分布,(b)可能对测量的性能产生不可预测的影响,以及(c)可能移除你的模型可能依赖于其性能的变异。然而,(a)只有在你的回路确实依赖于不相关的节点时才重要,即使如此,偏离合理激活的流形应该比帮助性能更频繁地损害性能;同样,(b)似乎会增加方差并平均损害可解释性分数,但似乎不太可能高估可解释性;为了减少(但不是消除)(c)成为主要问题的可能性,我们还验证了消融我们声称相关的网络部分确实会破坏性能(图32)。

重要的是,均值消融比激活修补(activation patching)【36, How to use and interpret activation patching, 2024, arXiv】能产生更完整的回路。如果回路的某些部分对于执行任务至关重要,但其激活在提示对之间而不是在每个提示对内部有所不同,那么激活修补将无法注意到它。具体来说,例如,由于setstring任务中的每对提示使用不同的变量名,但变量名在每对内部是相同的,激活修补将完全忽略将变量名复制到最终词元的回路部分。在早期的实验中,我们发现激活修补更容易获得好分数,但产生的回路在质量上不令人满意。

我们期待未来的工作能进一步推动电路忠实性的前沿,达到完全的因果擦洗甚至更高水平。

F. 定性结果的细节

F.1. 恒定查询

在“单双引号”任务中,查询通道似乎是一个常数,而不是数据依赖的。为了验证这是一个有效的简化,我们展示了在整个预训练分布上将其设置为一个恒定值,损失仅增加了2.6e-5 nats/token,这与将Q置零时1.47e-4 nats/token的预训练损失相比非常小。

F.2. 重缩放消融和冗余性

对于第3.2.2节中概述的回路,为了获得完整的描述,我们发现重缩放少量节点有助于移除模型的冗余组件。然而,当广泛应用时,这种干预非常强大,并且可以在模型内部“隐藏”叠加。

我们小心确保只在模型计算确实是冗余的,并且没有隐藏模型中存在的叠加时才执行重缩放干预。对于括号计数回路,我们重缩放了两个激活标量:4.attn.resid_delta_idx_10792.attn.resid_delta_idx_1249

对于残差流通道1079,回路在其最终的MLP层中使用它来计算正确的输出logits。如果这个最终的MLP被消融,损失会受到影响。然而,如果直接重缩放4.attn.resid_delta_idx_1079,这个最后一层可以被消融。这等同于剪枝中用来重新校准剪枝后回路输出logits的logit缩放和logit偏置变换。我们还发现,与零消融的基线相比(5.8e-5 nats/token vs 2.69e-4),线性替换MLP的未剪枝输出能够恢复大部分损失。

对于2.attn.resid_delta_idx_1249,似乎在括号计数回路内,残差通道1249同时被注意力层3和注意力层2写入。2.attn.resid_delta_idx_12493.attn.resid_delta_idx_1249的激活模式在预训练分布和括号计数任务分布上都呈现出高度相关性(图36)。因此,我们怀疑模型正在使用3.attn来放大来自2.attn的激活。为了验证这一点,我们干预3.attn.resid_delta_idx_1249,将其激活替换为2.attn.resid_delta_idx_1249的线性函数,并计算预训练损失。与对该节点进行零消融的基线(损失增加4e-3 nat/token)相比,线性替换的损失增加相对可以忽略不计,为7e-4 nats/token。因此,这个通道不太可能处于跨层叠加状态,并且3.attn在索引1249处是冗余的。为了简化回路描述,我们消融了3.attn,并相应地重缩放了2.attn.resid_delta_idx_1249

我们还发现3.attn正在实现高度可解释的计算,在一个注意力头(头85)中基本上是与2.attn“相同的回路”,在另一个头中则复制2.attn.resid_delta_idx_1249。我们主要是为了简化而将其消融。


图 36. 2.attn.resid_delta_idx_1249 与 3.attn.resid_delta_idx_1249 在预训练期间的激活对比。


图 37. 变化{L0, 总参数, 激活稀疏度}之一并保持另外两个不变的等高线图。L0沿着前沿移动。第一张图变化L0,第二张图变化总参数,第三张图变化激活稀疏度。总参数推动了前沿。增加激活稀疏度最初有帮助,但最终会被帕累托支配。


图 38. 我们模型中一些随机(RMSNorm后)残差流特征的激活模式截图。值得注意的是,在选择随机节点之前,我们进行了一个预训练剪枝步骤以移除死节点。对于每个节点,显示了激活值最高和最低5%的随机文档。随机节点在某种程度上是可解释的,特别是考虑到我们的模型不太可能表示极其复杂的概念。


图 39. 我们回路可视化工具的截图,显示了来自整个预训练分布(而不仅仅是单双引号任务)的图4中10.attn.resid_delta.83节点的各种激活百分位数的文档。即使在预训练分布上,激活也惊人地单义——它在双引号字符串内为正激活,在单引号字符串内为负激活,在字符串外接近于零。这个节点是精心挑选的;并非所有节点都如此单义。


图 40. 用于构建可解释扰动的桥接稀疏模型中,一个引号类型分类器第3层注意力输入通道的可视化激活截图。顶部是大多数正负预训练激活,下方是成对的任务激活。


图 41. 用于构建可解释扰动的桥接稀疏模型中,一个“当前行以if/while/except开头”的第3层MLP输入通道的可视化激活截图。顶部是大多数负预训练激活,下方是成对的任务激活。

G. 贡献

  • Leo Gao 确定了研究方向并领导了项目。Leo设计并实现了稀疏模型训练代码库和剪枝代码库。Leo研究了扩展性、优化、架构、剪枝、桥接和特征二值化。Leo负责预训练系统和核函数。Leo创建了几个剪枝任务。Leo为可视化工具做出了贡献。Leo为论文文本的撰写做出了重大贡献。Leo在整个项目期间为Achyuta提供了技术指导。
  • Achyuta Rajaram 研究了预训练架构、优化和回路剪枝。Achyuta改进了剪枝算法。Achyuta为论文中的示例进行了定性分析。Achyuta创建了许多剪枝任务。Achyuta尝试了top-k以外的权重稀疏性替代方案。Achyuta为论文文本的撰写做出了重大贡献。
  • Jacob Coxon 研究了优化和回路剪枝。Jacob对L0退火和基于归因的剪枝进行了初步探索。Jacob设计了数据集的初始版本。Jacob创建了几个剪枝任务。Jacob实现了回路可视化工具。
  • Soham V. Govande 实现了一个优化的CARBS实现,并为核函数和系统分析做出了贡献。
  • Bowen Baker 在项目的第一部分管理了Leo。
  • Dan Mossing 研究了桥接并提供了日常技术反馈。Dan为论文文本的撰写做出了重大贡献。Dan在整个项目期间管理了Jacob和Achyuta,并在项目的后半部分管理了Leo。