Probe Pruning: Accelerating LLMs Through Dynamic Pruning via Model-Probing
Probe Pruning: Accelerating LLMs Through Dynamic Pruning via Model-Probing
作者/机构: Qi Le1, Enmao Diao, Ziyan Wang2, Xinran Wang1, Jie Ding1, Li Yang2, Ali Anwar1; 1University of Minnesota 2University of North Carolina at Charlotte
主要贡献
核心问题与研究目标:大型语言模型(LLMs)的推理过程计算量巨大且延迟高,对部署构成了挑战。结构化剪枝是一种有前景的硬件友好型加速方法,但应用于LLM时,与密集模型相比常存在性能差距。这主要是因为LLM内部表征存在显著的离群点现象,且这些离群点具有批次依赖性,即不同输入批次的离群点特性不同。现有的先进结构化剪枝方法通常使用校准数据集来评估权重重要性,这种静态方法无法应对推理时动态变化的离群点,导致性能下降。此外,使用微调来恢复模型性能不仅资源消耗大,还可能损害模型的泛化能力。因此,本文的目标是设计一种无需大量微调即可有效处理批次依赖性离群点、缩小剪枝模型与密集模型性能差距的方法。
创新点(Probe Pruning, PP):为解决上述问题,本文提出了Probe Pruning (PP),一个在线的动态结构化剪枝框架,它在推理期间根据每个批次的隐藏状态进行剪枝,且无需额外的神经网络模块或微调。其核心贡献和创新点如下:
-
探测(Probing)机制:PP提出了一种探测方法,利用一小部分关键的隐藏状态(基于残差重要性选择)提前运行模型几个层,以捕获指导当前批次剪枝决策所必需的中间隐藏状态信息。这种方法解决了在不引入大量计算开销的情况下,在线获取剪枝所需中间信息的难题。例如,该探测过程仅使用5%的样本和50%的词元,只消耗了密集模型推理1.5%的浮点运算(FLOPs),却能有效指导剪枝。
-
历史信息指导的剪枝与重要性缩放融合:为了克服仅依赖校准数据集可能引入的偏见,并增强小探针(probe)提供信息的鲁棒性,PP提出了一种历史信息指导的剪枝策略。该策略通过“重要性缩放融合”方法,将当前批次的探测状态与历史状态(来自校准数据和先前推理批次)进行策略性整合。这既利用了校准数据集的全局信息,又最小化了其带来的偏见,从而能更准确地捕捉每个批次的独有特性。
-
新型剪枝度量(PPsp):本文提出了一种新的结构化剪枝度量PPsp,它通过保留单个权重的内在重要性(Wanda度量的平方)并计算输入通道重要性分数的L2范数,比现有度量更有效地选择待剪枝的通道。
-
无需微调的高性能:PP的一个主要优势是它完全依赖于原始模型结构和隐藏状态,无需任何微调即可在各种模型(LLaMA-2/3, OPT)和剪枝率下,显著优于现有的、甚至包括经过微调的先进剪枝方法。例如,在LLaMA-2-7B上以40%的剪枝率进行评估时,PP的性能下降与运行时间减少的比率比当前最先进方法低2.56倍。
图 1: Probe Pruning (PP) 分四个阶段执行:(1) PP 基于残差重要性,从层归一化的隐藏状态中选择关键样本和词元,创建一个小而关键的探针。(2) PP 部署此探针提前运行模型几个层,并获取探针的中间隐藏状态。(3) PP 将探测状态与历史状态整合,并使用整合后的状态计算剪枝度量并剪枝权重通道。(4) PP 在剩余的权重上执行完整推理。
符号和准备知识
一个大型语言模型(LLM) M 由 L 个块组成,每个块可以是注意力块或多层感知器(MLP)块。每个注意力块由四个线性投影层表征:查询(Q)、键(K)、值(V)和输出(O)。类似地,每个MLP块包括两个线性层:全连接层1(FC1)和全连接层2(FC2)。
每个块 l 将输入隐藏状态 $X^l \in R^{N \times S \times D}$ 转换为输出隐藏状态 $X^{l+1} \in R^{N \times S \times D}$。这里,N、S 和 D 分别表示批次大小、序列长度和特征维度。每个块 l 中的转换可以表示为:
$$\mathbf{X}^{l+1}=\mathbf{X}^{l}+\mathcal{F}^{l}(\mathbf{X}^{l}),$$其中,$F^l$ 包含了块 l 内的所有转换。该函数可以进一步分解为:
$$\mathcal{F}^l(\mathbf{X}^l) = \mathbf{X}^{l,\text{int}}(\boldsymbol{W}^{l,\text{final}})^T, \quad \mathbf{X}^{l,\text{int}} = \mathcal{T}^l(\text{LN}^l(\mathbf{X}^l)),$$其中,$T^l$ 表示应用于输入隐藏状态 $X^l$ 的所有中间转换,不包括层归一化 $LN^l$ 和最终的权重矩阵 $W^{l,final} \in R^{C_{out} \times C_{in}}$。最终的权重矩阵在注意力块中是输出投影(O),在MLP块中是FC2。中间隐藏状态 $X^{l,int} \in R^{N \times S \times C_{in}}$ 是将这些中间转换应用于 $X^l$ 的结果。此外,我们将残差重要性定义为输入隐藏状态 $X^l$ 在特定维度上的L2范数,这一概念在4.2节中有更详细的说明。
在LLM的结构化剪枝中,整个耦合结构被剪枝【索引36, Xinyin Ma, Gongfan Fang, and Xinchao Wang. Llm-pruner: On the structural pruning of large language models. Advances in neural information processing systems, 2023.】、【索引1, Yongqi An, Xu Zhao, Tao Yu, Ming Tang, and Jinqiao Wang. Fluctuation-based adaptive structured pruning for large language models. Proceedings of the AAAI Conference on Artificial Intelligence, 2024.】。具体来说,在块 l 中,前面的权重矩阵通过剪枝其输出通道进行调整,这些输出通道与最终权重矩阵剪枝的输入通道一一对应。例如,在一个MLP块中,权重矩阵根据未剪枝通道索引集合 $C^l \subseteq \{1, 2, . . . , C_{in}\}$ 进行如下调整:
$$\tilde{\boldsymbol{W}}^{l, \mathrm{FC} 1}=\boldsymbol{W}^{l, \mathrm{FC} 1}\left[\mathbb{C}^{l},:\right], \quad \tilde{\boldsymbol{W}}^{l, \mathrm{FC} 2}=\boldsymbol{W}^{l, \mathrm{FC} 2}\left[:, \mathbb{C}^{l}\right]$$其中 $\tilde{W}^{l,FC1} \in R^{|C^l| \times C_{in}}$ 和 $\tilde{W}^{l,FC2} \in R^{C_{out} \times |C^l|}$。符号 $|C^l|$ 表示 $C^l$ 的基数。类似地,在注意力块中,注意力头可以被视为耦合结构【索引36, Xinyin Ma, Gongfan Fang, and Xinchao Wang. Llm-pruner: On the structural pruning of large language models. Advances in neural information processing systems, 2023.】、【索引1, Yongqi An, Xu Zhao, Tao Yu, Ming Tang, and Jinqiao Wang. Fluctuation-based adaptive structured pruning for large language models. Proceedings of the AAAI Conference on Artificial Intelligence, 2024.】,并且整个注意力头被剪枝。
方法细节
Probe Pruning (PP) 的目标是以批次为单位实现在线动态结构化剪枝。我们工作的核心思想如图1所示。我们的核心策略包括:(1) 探测(Probing)(4.1和4.2节),包含两个步骤:首先,基于残差重要性生成一个探针;其次,使用该探针运行未剪枝的模型以收集有价值的中间隐藏状态信息。(2) 历史信息指导的剪枝(History-informed pruning)(4.3节),它使用重要性缩放融合(importance-scaled fusion)将探测状态与历史状态谨慎地合并,以捕捉每个批次的基本特征。之后,我们使用一种新颖的剪枝度量(4.4节)来剪枝模型,该度量比现有度量更有效地选择要剪枝的通道。
4.1 探测(PROBING)
引入探测概念。我们引入一个名为“探测”(probing)的新概念,它利用现有的模型结构和隐藏状态来形成一个预测机制。具体来说,当输入隐藏状态到达块 $l$ 时,探测首先利用残差重要性选择关键的样本和词元,从$LN^l(X^l)$中形成探针$P^l$。$LN^l$表示块$l$处的层归一化。探针生成过程在下一节详述。然后,它运行块$l$中的中间转换,记为$T^l(P^l)$。值得注意的是,有效的探测消耗很少的计算资源,并且可以获得重要的中间状态信息来指导剪枝决策。
探测的上限。作为一种替代方案,我们可以通过使用当前批次中的所有输入隐藏状态来生成探针,即$P^l = LN^l(X^l)$,我们称这种方法为全批量探测(Full-Batch Probing)。通过利用整个批次而不减少维度N或S,全批量探测能捕获完整的中间隐藏状态信息,这可能导致最优的剪枝性能。然而,这种方法显著增加了计算资源需求和延迟。因此,全批量探测可作为我们方法的理论上限。我们对PP的目标是选择与全批量探测所选通道相似的剪枝通道。我们认为,PP和全批量探测之间共同剪枝通道的比例越高,表明模型性能越好,探针质量也越高。
探测为何有效。探测之所以有效,是因为并非所有样本和词元对模型输出的贡献都相等,并且LLM中的大幅度离群点对模型的行为有显著影响。在自然语言序列中,某些词元比其他词元携带更多的语义或句法重要性【索引62, Guangxuan Xiao, Yuandong Tian, Beidi Chen, Song Han, and Mike Lewis. Efficient streaming language models with attention sinks. arXiv preprint arXiv:2309.17453, 2023.】、【索引53, Mingjie Sun, Xinlei Chen, J Zico Kolter, and Zhuang Liu. Massive activations in large language models. arXiv preprint arXiv:2402.17762, 2024.】、【索引32, Ruikang Liu, Haoli Bai, Haokun Lin, Yuening Li, Han Gao, Zhengzhuo Xu, Lu Hou, Jun Yao, and Chun Yuan. Intactkv: Improving large language model quantization by keeping pivot tokens intact. arXiv preprint arXiv:2403.01241, 2024.】。通过基于残差重要性选择关键样本和词元,探针能够聚焦于批次内信息最丰富的部分。这种有针对性的方法使得探针能够捕获对决定哪些通道可以被剪枝最有影响力的基本中间隐藏状态信息。因此,尽管探针处理的是批次的缩减子集,它仍能提供足够的洞察力来指导剪枝决策,并可能以显著降低的计算成本实现与全批量探测相当的结果。
计算复杂性。探测仅需极小的计算复杂性。具体来说,对于一个每个注意力块和MLP块包含六个线性变换(Q/K/V/O和FC1/FC2)并结合了权重变换和注意力机制的LLM,其密集矩阵计算复杂度总计为$O(6NSC_{in}C_{out} + 2NS^2C_{in})$。对于探测,通过将批次大小减少到其原始大小的x%,序列长度减少到y%,复杂度降低到$O(4x\% \cdot y\% \cdot NSC_{in}C_{out} + 2x\% \cdot (y\%)^2 \cdot NS^2C_{in})$。
4.2 探针生成
基于残差重要性生成探针。PP通过测量残差重要性,即$X^l$在特定维度上的L2范数,来识别关键的样本和词元。一旦识别出来,这些关键的样本和词元将从$LN^l(X^l)$中被选中,以生成块l的探针,其中$LN^l$表示块l处的层归一化。我们不利用源自$LN^l(X^l)$的重要性来识别关键样本和词元,因为层归一化会显著改变输入隐藏状态。
计算重要性分数。为了测量目标维度上的残差重要性,我们计算$X^l$在非目标维度上的L2范数。目标维度可以是批次维度或序列维度。
$$\mathbf{U}_{i}^{l, \text{batch}} = \|\mathbf{X}_{i, :, :, :}^{l}\|_{2}, \quad \text{for } i = 1, \dots, N,$$ $$\mathbf{U}_{j}^{l, \text { seq }}=\left\|\mathbf{X}_{:, j,:}^{l}\right\|_{2}, \quad \text { for } j=1, \ldots, S .$$排序与索引。计算完重要性分数后,我们按降序对其进行排序,并将索引存储在I中:
$$\begin{aligned} \begin{aligned} \mathbb{I}^{l, \text { batch }} & =\operatorname{argsort}\left(-\mathbf{U}^{l, \text { batch }}\right), \\ \mathbb{I}^{l, \text { seq }} & =\operatorname{argsort}\left(-\mathbf{U}^{l, \text { seq }}\right) . \end{aligned} \end{aligned}$$生成探针。使用排序后的索引,我们接着通过从层归一化后的$X^l$中选择前x%的样本或y%的词元来生成探针:
$$\begin{aligned} \mathbf{P}^{l}=\left\{\begin{array}{ll}\mathrm{LN}^{l}\left(\mathbf{X}^{l}\right)_{\mathbb{I}_{: x \%}^{l, \text { batch }},:,:} & \text { if selecting top } x \% \text { of samples, } \\ \mathrm{LN}^{l}\left(\mathbf{X}^{l}\right)_{:, \mathbb{I}_{: y \%}^{l, \text { seq }},:} & \text { if selecting top } y \% \text { of tokens. }\end{array}\right. \end{aligned}$$这种方法确保探针由按其重要性分数排名的最重要的样本和词元组成。
序列化剪枝方法。PP采用一种序列化方法来有效剪枝序列和批次维度。首先,根据公式(5)和(7),利用当前批次内的序列分布,从残差$X^l$中选择前y%的词元:$X^l_{:,I^{l,seq},:}$。接着,我们应用这个缩减后的序列集,使用公式(4)和(6)来确定前x%的样本,得到索引$I^{l,batch|seq}$。最后,我们选择关键的样本和词元用于探针生成,即$LN^l(X^l)_{I^{l,batch|seq},I^{l,seq},:}$。
算法1:Probe Pruning
输入:一个LLM模型M,包含L个块,每个块含有变换$F^l$、中间变换$T^l$和层归一化$LN^l$;校准数据集D;推理批次B。
系统执行:
1. 使用模型M运行校准数据集D,以获得历史状态V。
2. 对于第t个批次$B_t$:
1. 初始化批次$B_t$的隐藏状态$X_0$。
2. 对于每个块 $l = 0, ..., L-1$:
a. 利用残差重要性从$LN^l(X^l)$生成一个探针$P^l$(4.2节)。
b. 使用$P^l$执行块$l$的中间变换,并收集得到的中间隐藏状态,记为$X^{l,int,probe} = T^l(P^l)$。
c. 使用重要性缩放融合将探测状态$X'^{l,int,probe}$与历史状态整合(4.3节)。
d. 从整合后的状态计算PPsp剪枝度量(4.4节),并相应地剪枝权重通道。
e. 使用剪枝后的权重$\tilde{W}^l$对$X^l$执行完整推理,记为$\tilde{F}^l(X^l)$。
4.3 历史信息指导的剪枝与重要性缩放融合
融合的动机。探针的中间隐藏状态,由下式给出:
$$\mathbf{X}^{l, \text { int, probe }}=\mathcal{T}^{l}\left(\mathbf{P}^{l}\right)$$包含了指导剪枝决策的关键信息。然而,当探针非常小——例如,当N和S减少到5%时——由于上下文有限,它们可能会导致不恰当的剪枝决策。为了解决这个问题并提升性能,我们引入了历史信息指导的剪枝与重要性缩放融合。
状态压缩。为简化符号,本节省略表示块编号的上标l。对于形状为(N, S, $C_{in}$)的中间隐藏状态$X_{int}$,以下关系成立:
$$\sum_{j=1}^{S} \sum_{i=1}^{N}(\mathbf{x}_{i,j,k}^{\text{int}})^2 = \sum_{j=1}^{S} \|\mathbf{x}_{:,j,k}^{\text{int}}\|_2^2 = \|\mathbf{x}_{:,:,k}^{\text{int}}\|_2^2$$我们在公式10的第一步中压缩批次维度以存储历史状态,因为内存限制不允许存储所有样本的中间隐藏状态。我们在公式10的第二步中对序列维度求和,以获得形状为$R^{C_{in}}$的张量,该张量用于计算剪枝度量(见4.4节)。
历史状态的初始化与计算。如同以往的研究【索引52, Mingjie Sun, Zhuang Liu, Anna Bair, and J Zico Kolter. A simple and effective pruning approach for large language models. arXiv preprint arXiv:2306.11695, 2023.】、【索引1, Yongqi An, Xu Zhao, Tao Yu, Ming Tang, and Jinqiao Wang. Fluctuation-based adaptive structured pruning for large language models. Proceedings of the AAAI Conference on Artificial Intelligence, 2024.】,我们使用模型M处理校准数据集D以获得初始历史状态。对于每个块,初始历史状态表示为$V|_{0} \in R^{S \times C_{in}}$,通过公式10的第一步计算以减少批次维度:$V|_{0} = ||X_{int:,j,k}||_2^2 = \sum_{i=1}^{N}(X_{int_{i,j,k}})^2$。类似地,为了减少探针的中间隐藏状态$X_{int,probe} \in R^{x\% \cdot N \times y\% \cdot S \times C_{in}}$的批次维度,我们计算探测状态为$||X_{int,probe:,j,k} ||_2^2 = \sum_{i=1}^{x\% \cdot N} (X_{int,probe_{i,j,k}})^2$。
重要性缩放融合。由于探测可以使用选定的词元执行,因此有必要对齐序列维度。我们定义$V_{probe} = V_{I_{seq}}$,其中$V_{probe} \in R^{y\% \cdot S \times C_{in}}$且$I_{seq:y\%}$(从公式7获得)表示前y%词元的索引。然后我们应用重要性缩放融合来获得整合状态:
$$\hat{\mathbf{X}}^{\text{int,probe}} = \frac{\|\mathbf{X}_{:,j,k}^{\text{int,probe}}\|_2^2}{\|\mathbf{X}_{:,j,k}^{\text{int,probe}}\|_2^2 + \mathbf{V}^{\text{probe}}} \cdot \|\mathbf{X}_{:,j,k}^{\text{int,probe}}\|_2^2 + \frac{\mathbf{V}^{\text{probe}}}{\|\mathbf{X}_{:,j,k}^{\text{int,probe}}\|_2^2 + \mathbf{V}^{\text{probe}}} \cdot \mathbf{V}^{\text{probe}}$$其中 $\hat{X}_{int,probe} \in R^{y\% \cdot S \times C_{in}}$。遵循公式10的第二步,我们将$\hat{X}_{int,probe}$在序列维度上求和,得到 $\sum_{j=1}^{y\% \cdot S} \hat{X}_{int,probe_{j,k}}$。请注意,若没有重要性缩放融合,$\sum_{j=1}^{y\% \cdot S} \hat{X}_{int,probe_{j,k}}$ 可以简化为 $||X_{int:,:,k}||_2^2$。然后,我们使用$W_{final}$和$\sum_{j=1}^{y\% \cdot S} \hat{X}_{int,probe_{j,k}}$作为$||X_{int:,:,k}||_2^2$的代理,根据公式(15)计算剪枝度量,并相应地剪枝权重通道。最后,我们在剩余的权重上运行完整推理。
通过完整推理更新历史状态。为了增强对中间隐藏状态属性的跟踪,我们在对所选权重通道C进行完整推理时,采用指数移动平均。更新公式表示为:
$$\mathbf{V}_{:,\mathbb{C}}|^t = \lambda \mathbf{V}_{:,\mathbb{C}}|^{t-1} + (1-\lambda) \|\tilde{\mathbf{X}}_{:,j,\mathbb{C}}^{\text{int}}\|_2^2|^t,$$V的值在第t个推理批次时更新,$\tilde{X}_{int}$表示完整推理期间的中间隐藏状态。我们在所有实现中始终将λ设置为0.99。
4.4 剪枝度量
提出PPsp度量。我们提出了一种新的结构化剪枝度量,名为PPsp,其中“sp”代表结构化剪枝。与现有度量相比,该度量能更有效地选择要剪枝的通道。我们将非结构化剪枝度量Wanda【索引52, Mingjie Sun, Zhuang Liu, Anna Bair, and J Zico Kolter. A simple and effective pruning approach for large language models. arXiv preprint arXiv:2306.11695, 2023.】调整到结构化剪枝场景中。PPsp引入了两项增强:(1) 我们保留了单个权重的内在重要性,表现为Wanda度量的平方值;(2) 我们计算MLP输入通道和注意力头的重要性分数的L2范数,以确定剪枝结构的整体重要性,而不是在剪枝结构之间求和这些分数。
Wanda度量回顾。我们为一般情况引入剪枝度量。为清晰起见,本节省略了表示块编号的上标l。在每个块中,给定形状为(N, S, $C_{in}$)的中间隐藏状态$X_{int}$(其中N和S分别代表批次和序列维度)和形状为($C_{out}$, $C_{in}$)的权重矩阵$W_{final}$,Wanda【索引52, Mingjie Sun, Zhuang Liu, Anna Bair, and J Zico Kolter. A simple and effective pruning approach for large language models. arXiv preprint arXiv:2306.11695, 2023.】将单个权重$W_{final_{i,k}}$的重要性定义为:
$$\mathbf{I}_{i,k} = |\boldsymbol{W}_{i,k}^{\text{final}}| \cdot ||\mathbf{X}_{:,:,k}^{\text{int}}||_2,$$其中$|\cdot|$表示绝对值操作,$||X_{int:,:,k}||_2$评估第k个特征在(N, S)维度上的L2范数。然后将这两个标量值相乘得到最终的重要性。
内在重要性与PPsp公式。然而,正如在Wanda【索引52, Mingjie Sun, Zhuang Liu, Anna Bair, and J Zico Kolter. A simple and effective pruning approach for large language models. arXiv preprint arXiv:2306.11695, 2023.】中推导的,单个权重的内在重要性定义为:
$$\mathbf{l}_{i,k} = (| \boldsymbol{W}_{i,k}^{\text{final}} | \cdot \| \mathbf{X}_{:,:,k}^{\text{int}} \|_2)^2 = | \boldsymbol{W}_{i,k}^{\text{final}} |^2 \cdot \| \mathbf{X}_{:,:,k}^{\text{int}} \|_2^2.$$Wanda在局部权重重要性排序中舍弃了公式(14)中的平方,因为$|W_{final_{i,k}}|$和$||X_{int:,:,j}||_2$的非负性不影响重要性的相对排序。然而,当涉及到结构化剪枝时,保持单个权重的内在重要性是至关重要的。因此,我们将Wanda度量进行平方,并计算输入通道在$C_{out}$维度上的欧几里得距离。公式如下:
$$\mathbf{I}_{k}=\left\|\left\{\left|\boldsymbol{W}_{i, k}^{\text {final }}\right|^{2} \cdot\left\|\mathbf{X}_{:,:, k}^{\text {int }}\right\|_{2}^{2}\right\}_{i=0}^{C_{\text {out }}}\right\|_{2},$$其中$\{\cdot\}$表示元素集合,且$I \in R^{C_{in}}$。
实验环境
表 1: LLM结构化剪枝方法的比较。我们的实现加载完整模型进行动态剪枝,而其他方法仅加载剪枝后的版本。
硬件配置:所有实验均在NVIDIA A100 GPU上进行。
模型与评估:
- 模型架构:在三个流行的模型家族上评估PP:LLaMA-2 7B/13B【索引54, Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, et al. Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023.】,LLaMA-3 8B【索引38, Meta AI. LLaMA-3. https://llama.meta.com/llama3/, 2024.】,以及OPT-13B【索引66, Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen, Christopher Dewan, Mona Diab, Xian Li, Xi Victoria Lin, et al. Opt: Open pre-trained transformer language models. arXiv preprint arXiv:2205.01068, 2022.】。
-
评估任务与数据集:
- 文本生成(困惑度):使用WikiText2【索引37, Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. Pointer sentinel mixture models. arXiv preprint arXiv:1609.07843, 2016.】数据集评估困惑度。
- 常识推理(准确率):在多个零样本任务上评估准确率,包括BoolQ, PIQA, HellaSwag, WinoGrande, ARC-Easy, ARC-Challenge, 和OpenbookQA。
-
实验参数:所有任务的批次大小设置为20。对于常识推理任务,序列长度设置为批次中最长样本的长度。对于文本生成任务,序列长度设置为1024。PP的默认探针大小设置为批次大小的5%和序列长度的50%,其FLOPs成本约占密集模型推理的1.5%。
软件配置与基线:
* 基线方法:将PP与四种先前的方法进行比较:Wanda-sp, FLAP【索引1, Yongqi An, Xu Zhao, Tao Yu, Ming Tang, and Jinqiao Wang. Fluctuation-based adaptive structured pruning for large language models. Proceedings of the AAAI Conference on Artificial Intelligence, 2024.】,LoRAPrune【索引65, Mingyang Zhang, Chunhua Shen, Zhen Yang, Linlin Ou, Xinyi Yu, Bohan Zhuang, et al. Pruning meets low-rank parameter-efficient fine-tuning. arXiv preprint arXiv:2305.18403, 2023.】,以及LLM-Pruner【索引36, Xinyin Ma, Gongfan Fang, and Xinchao Wang. Llm-pruner: On the structural pruning of large language models. Advances in neural information processing systems, 2023.】。同时,还将PP与其理论上限“全批量探测”进行比较。
* 校准数据集:所有方法均使用C4【索引44, Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, and Peter J Liu. Exploring the limits of transfer learning with a unified text-to-text transformer. Journal of machine learning research, 2020.】数据集作为校准数据集。PP, Wanda-sp, 和FLAP使用2,000个校准样本,而LoRAPrune和LLM-Pruner使用20,000个样本进行微调。
* 剪枝率:评估了20%和40%两种剪枝率。
实验结果
主要结果:
* 性能对比:如表2和表3所示,在文本生成和常识推理任务上,Probe Pruning (PP) 在各种模型和剪枝率下,其零样本性能(未经微调)始终优于所有基线方法。例如,在WikiText2数据集上,当剪枝率为40%时,PP在LLaMA-2-7B上实现了16.8的困惑度,LLaMA-2-13B上为11.3,OPT-13B上为26.7,均低于竞争方法。在LLaMA-3-8B上,PP在几乎所有任务中都超过了Wanda-sp和FLAP。
* 与理论上限对比:实验结果证实,全批量探测(Full-Batch Probing)作为PP的理论上限,在所有测试场景中表现最佳,支持了最初的假设。相比之下,PP仅使用约1.5%的密集模型推理FLOPs,便取得了与全批量探测相当的结果,证明了PP的有效性,并表明探针的中间隐藏状态能有效帮助识别处理不同批次的重要权重。
表 2: LLaMA-2-7B/13B 和 OPT-13B 在剪枝注意力和MLP块后(无微调)的零样本性能:PP 在几乎所有场景中都表现出优越的性能。箭头表示指标方向(↓:越低越好;↑:越高越好)。
表 3: LLaMA-3-8B 在剪枝MLP块后(无微调)的零样本性能:PP 在几乎所有场景中都表现出优越的性能。
共同剪枝通道的Jaccard指数:
* 实验内容:为了验证“PP与全批量探测的剪枝通道重叠度越高,模型性能越好”的假设,计算了PP与全批量探测之间,以及固定剪枝模型与全批量探测之间共同剪枝通道的Jaccard指数。
* 实验结果与结论:如图2所示,与固定剪枝模型相比,PP选择的剪枝通道在几乎所有注意力和MLP块中都与全批量探测选择的通道更为相似。这种更高的一致性有助于提升整体性能,并表明探针的中间隐藏状态能够有效指导剪枝决策。
图 2: 共同剪枝通道的Jaccard指数:比较PP与全批量探测,以及比较固定剪枝模型(无PP)与全批量探测在每个批次上的情况。
探针组合对性能的影响:
* 实验内容:通过改变探针的批次大小(1, 5, 20)和序列比例(0.05到1.0),研究不同探针大小对PP性能的影响。
* 实验结果与结论:如图3所示,即使是小探针(批次大小为1,序列比例为0.05)也能带来显著的性能提升。随着探针批次大小和序列比例的增加,性能进一步提升。特别地,序列比例从0.05增加到0.3时,性能改善最为迅速,这表明当序列信息超过某个相对于当前批次序列长度的阈值时,对剪枝的有效性会显著增强。
图 3: 40%剪枝率下不同探针组合的性能。
计算成本与推理速度:
* 计算开销:如表4所示,使用DeepSpeed包测量的结果显示,探测的计算开销约占密集模型推理FLOPs的1.5%,与理论分析一致。
* 推理加速:如表5所示,在40%剪枝率下,PP的推理速度与其他结构化剪枝基线相当,但性能更优。在注意力块和MLP块上,PP分别实现了1.46倍和1.30倍的加速。MLP块的轻微延迟可归因于权重提取等系统开销,这种差距在更大批次或更长序列下会缩小。
表 4: 密集模型推理与探测的FLOPs比较。
表 5: 在40%剪枝率下,WikiText2所有批次的推理运行时间分解。加速比通过将密集模型的推理运行时间除以各方法的推理运行时间计算得出。
性能运行时比率 (PRR):
* 实验内容:引入PRR指标,用于量化性能下降与运行时间减少之间的权衡,其值越小越好。公式如下:
* 实验结果与结论:如表5所示,PP的PRR值为37.37,远低于FLAP的95.65和Wanda-sp的106.48。这表明PP比FLAP和Wanda-sp分别高效2.56倍和2.85倍,即在减少运行时间的同时,性能下降率显著更低。
与微调基线的比较:
* 实验内容:将无需微调的PP与经过微调的基线LoRAPrune和LLM-Pruner进行比较。
* 实验结果与结论:如表6所示,PP在无需微调的情况下,性能始终优于或接近微调后的模型。在20%剪枝率下,PP在两类任务上均表现出色。在40%剪枝率下,PP的困惑度与微调基线相当,但在推理任务上的准确率显著更高。
表 6: PP与在LLaMA-2-7B/13B模型上进行微调的基线比较,剪枝了注意力和MLP层:PP在无微调的情况下,在各种场景中始终表现更优。
重要性缩放融合消融实验:
* 实验内容:将重要性缩放融合与三种固定融合比例(0.1, 0.5, 0.9)进行比较。
* 实验结果与结论:如图4所示,重要性缩放融合能够有效利用校准数据集的优势,同时最小化其带来的偏见,性能优于所有固定比例融合。
图 4: 重要性缩放融合研究。
剪枝度量消融实验:
* 实验内容:在固定剪枝模型上,比较本文提出的PPsp度量与Wanda-sp和FLAP的性能。
* 实验结果与结论:如表7所示,PPsp在各种剪枝场景下始终优于Wanda-sp和FLAP。在LLaMA-2-7B和OPT-13B模型上,PPsp显著降低了困惑度,取得了最低分。
表 7: 在由校准数据集剪枝的模型上,不同度量在WikiText2上的困惑度,显示PPsp在三种度量中表现最佳。
结论
本文提出了Probe Pruning (PP),一个新颖的在线动态剪枝框架,它使用一小部分关键的隐藏状态来运行模型,从而获取关键的剪枝信息以指导完整推理。值得注意的是,PP仅依赖于原始的模型结构和隐藏状态,不需要额外的神经网络模块或微调。此外,PP在几乎所有实验设置中都持续优于所有基线,包括那些经过微调的方法。未来的研究方向包括:优化探针生成和探测过程,将PP与先进的解码和对齐技术相结合,以及探索其对被投毒模型或对抗性提示的鲁棒性。
附录
A 实现细节
通用设置。对于所有方法,我们都保持前三层不变,这与Ma等人【索引36, Xinyin Ma, Gongfan Fang, and Xinchao Wang. Llm-pruner: On the structural pruning of large language models. Advances in neural information processing systems, 2023.】和Zhang等人【索引65, Mingyang Zhang, Chunhua Shen, Zhen Yang, Linlin Ou, Xinyi Yu, Bohan Zhuang, et al. Pruning meets low-rank parameter-efficient fine-tuning. arXiv preprint arXiv:2305.18403, 2023.】的做法类似,因为剪枝这些层的参数对模型有重大影响。剪枝率表示模型中所有注意力和MLP块的平均剪枝率。例如,当LLaMA-2-7B的目标剪枝率为20%和40%时,我们分别从第4到32个注意力和MLP块中剪枝22%和44%。
校准数据集。为了公平比较,我们使用了完全相同的C4【索引44, Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, and Peter J Liu. Exploring the limits of transfer learning with a unified text-to-text transformer. Journal of machine learning research, 2020.】数据集子集作为校准数据集。
PP, FLAP, 和 Wanda-sp。对于PP, FLAP【索引1, Yongqi An, Xu Zhao, Tao Yu, Ming Tang, and Jinqiao Wang. Fluctuation-based adaptive structured pruning for large language models. Proceedings of the AAAI Conference on Artificial Intelligence, 2024.】, 和 Wanda-sp【索引1, Yongqi An, Xu Zhao, Tao Yu, Ming Tang, and Jinqiao Wang. Fluctuation-based adaptive structured pruning for large language models. Proceedings of the AAAI Conference on Artificial Intelligence, 2024.】,我们使用2000个序列长度为1024的样本作为文本生成任务的校准数据集,以及2000个序列长度为512的样本作为常识推理任务的校准数据集。
LLM-Pruner。对于LLM-Pruner【索引36, Xinyin Ma, Gongfan Fang, and Xinchao Wang. Llm-pruner: On the structural pruning of large language models. Advances in neural information processing systems, 2023.】,我们遵循了其原始实现细节。我们使用10个随机选择的样本(每个截断到128个token)来构建重要性度量,并使用20000个序列长度为256的样本进行恢复性重训练。具体地,在恢复阶段,我们使用AdamW【索引21, Yang He, Yuhang Ding, Ping Liu, Linchao Zhu, Hanwang Zhang, and Yi Yang. Learning filter pruning criteria for deep convolutional neural networks acceleration. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, 2020.】优化器,设置100个预热步骤,LoRA【索引23, Edward J Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, and Weizhu Chen. Lora: Low-rank adaptation of large language models. arXiv preprint arXiv:2106.09685, 2021.】秩r为8,学习率为$1 \times 10^{-4}$,批次大小为64,并进行2个周期的恢复性重训练。
LoRAPrune。对于LoRAPrune【索引65, Mingyang Zhang, Chunhua Shen, Zhen Yang, Linlin Ou, Xinyi Yu, Bohan Zhuang, et al. Pruning meets low-rank parameter-efficient fine-tuning. arXiv preprint arXiv:2305.18403, 2023.】,我们遵循了其原始实现细节。我们从C4数据集中随机抽取20000个句子,每个长度为512个token。训练超参数包括设置LoRA秩为8,学习率为$1 \times 10^{-4}$,批次大小为128,共进行2个训练周期。在融合剪枝与微调时,我们采用立方稀疏度调度器【索引47, Victor Sanh, Thomas Wolf, and Alexander Rush. Movement pruning: Adaptive sparsity by fine-tuning. Advances in neural information processing systems, 2020.】迭代剪枝模型,直到达到目标稀疏度。如果只进行剪枝而不进行微调(以匹配其他一次性剪枝方法),我们使用10个选定的序列长度为512的样本来估计重要性,并进行一次性剪枝,不更新权重。所有训练过程都使用带有线性学习率衰减的AdamW优化器进行优化。
B 消融研究
B.1 校准数据集
校准数据集引入偏见。我们展示了FLAP【索引1, Yongqi An, Xu Zhao, Tao Yu, Ming Tang, and Jinqiao Wang. Fluctuation-based adaptive structured pruning for large language models. Proceedings of the AAAI Conference on Artificial Intelligence, 2024.】使用不同校准数据集在WikiText2困惑度测试上的性能,如表8所示。结果表明,仅依赖校准数据集的结构化剪枝方法可能会引入偏见。例如,当使用WikiText2验证集作为校准数据集时,FLAP在WikiText2上以40%的剪枝率达到了18.5的困惑度。然而,当使用C4数据集作为校准数据集时,困惑度恶化到38.9。
B.2 手动平方注意力度量
平方注意力度量的影响。在FLAP的实现中(https://github.com/CASIA-IVA-Lab/FLAP),注意力度量被手动平方。表9展示了在FLAP中手动平方注意力度量与不平方的影响。研究发现,平方该度量会导致对注意力块的剪枝更为保守。例如,对于LLaMA-2-7B,在20%的总体剪枝率下,未平方的FLAP方法剪枝了17.8%的注意力权重,而实施平方后仅剪枝了0.6%。这意味着平方显著减轻了对注意力的剪枝 。
保守剪枝提升性能。此外,对注意力块进行较不激进的剪枝与更好的模型性能相关。具体来说,在LLaMA-2-7B上,在40%的总体剪枝率下,未平方的FLAP剪枝了35.4%的注意力权重,导致WikiText2的困惑度为38.9。相反,平方后的FLAP以17.6%的较低速率剪枝,实现了29.1的较低困惑度。这些结果表明,对注意力块进行更保守的剪枝可以提升模型性能。
表 8: 在 LLaMA-2-7B 和 LLaMA-2-13B 模型上,FLAP 在不同剪枝率和校准数据集下的性能比较。
表 9: 在LLaMA-2-7B和LLaMA-2-13B模型上,比较FLAP在平方和不平方注意力度量(同时保持MLP度量始终不平方)下的表现。
B.3 词元位置上的批次依赖离群点
批次依赖离群点的存在。在正文第1节中,我们指出大型语言模型(LLM)表现出批次依赖的离群点,这需要在线动态剪枝来处理这些动态离群点。图5展示了在第10层和第20层,输入隐藏状态在批次和特征维度上每个词元位置计算出的L2范数。结果显示,每个词元位置都存在批次依赖的离群点,这与现有工作【索引32, Ruikang Liu, Haoli Bai, Haokun Lin, Yuening Li, Han Gao, Zhengzhuo Xu, Lu Hou, Jun Yao, and Chun Yuan. Intactkv: Improving large language model quantization by keeping pivot tokens intact. arXiv preprint arXiv:2403.01241, 2024.】、【索引53, Mingjie Sun, Xinlei Chen, J Zico Kolter, and Zhuang Liu. Massive activations in large language models. arXiv preprint arXiv:2402.17762, 2024.】的观察结果一致。
图 5: 在 LLaMA-2-7B 模型上使用 WikiText2 数据集进行的实验中,第10层和第20层输入隐藏状态在批次和特征维度上每个词元位置的L2范数可视化。
B.4 残差重要性
残差重要性的有效性。在正文4.2节中,我们提到层归一化会显著改变输入隐藏状态,从而使其重要性无法准确识别关键样本和词元。为了验证这一观察,表10比较了基于残差重要性识别关键样本和词元与基于层归一化输入隐藏状态重要性进行识别(PP不使用残差重要性)的效果。实验结果证明了残差重要性的有效性。
表 10: 残差重要性对LLaMA-2-7B探针生成的影响。应用残差重要性可以带来更好的探针性能。
B.5 历史状态整合
历史状态整合提升性能。表11中的结果说明了将历史状态纳入剪枝决策过程如何增强了PP的有效性。具体来说,当PP利用历史状态时,与仅利用探测状态的场景(PP w/o historical states)相比,所有模型和剪枝率下的性能指标都有一致的提升。例如,在40%的剪枝率下,使用从5%批次和50%序列生成的探针,带有历史状态的PP将WikiText2上的困惑度从20.1降低到16.9,并将平均准确率从51.2%提高到56.6%,相比于仅使用当前探测状态而无历史数据。
表 11: 在LLaMA-2-7B上,不同探针组合下整合历史状态的性能。历史状态可以增强PP的性能。
B.6 并行探测
并行探测的可行性。我们验证了在较早剪枝块的实际计算过程中并行运行探针的可能性。我们在下方的表12中展示了结果。其中,PP-Parallel代表一种方法,即当实际计算在较早的剪枝块上进行时,我们从这些块的残差中生成探针并执行探测。PP代表论文正文中使用的默认设置。结果表明,我们仍然可以获得性能提升,并取得与PP相当的结果。例如,在40%的剪枝率下,PP-Parallel在WikiText2上实现了17.9的困惑度,接近PP的水平,远低于FLAP的38.9。此外,PP-Parallel在ARC-e上实现了61.4%的准确率,接近PP的水平,远高于FLAP的52.5%。然而,我们在此只是展示了进一步提升PP推理速度的可行性;实际的并行化依赖于硬件和具体实现。
表 12: LLaMA-2-7B在剪枝注意力和MLP块后(无微调)的零样本性能。
B.7 剪枝注意力与MLP之间的差异
注意力与MLP的剪枝敏感性差异。我们发现注意力和MLP层的剪枝率应独立考虑,因为它们可能在不同的点达到饱和。表13显示了剪枝注意力和MLP之间性能的明显差异,尤其是在剪枝率增加时。虽然较低的剪枝率(20%)对两个组件的性能影响相似,但较高的比率(40%,60%)表明注意力头达到了饱和,特别是在像WikiText2和HellaSwag这样要求较高的任务中。例如,当注意力剪枝率为60%时,WikiText2上的性能急剧下降到33.5,而当MLP以相同水平剪枝时为21.1。同样,剪枝注意力时HellaSwag上的性能显著下降到42.2,而剪枝MLP时为55.3。此外,考虑到每个模块的实际FLOPs,性能差距更大,这强调了需要一种策略性的方法来剪枝神经网络组件。
表 13: LLaMA-2-7B在不同比率下剪枝注意力头与MLP的性能比较,比较了仅剪枝注意力头或仅剪枝MLP的效果。
C 附加实验结果
本节展示了每个任务的详细实验结果。未经微调的性能如表14、15、16和17所示。PP与经过微调的基线的比较见表18和19。PP在几乎所有实验设置中都持续优于所有基线,包括那些经过微调的基线。
表 14: LLaMA-2-7B在剪枝注意力和MLP块后(无微调)的零样本性能:PP在几乎所有场景中都表现出优越的性能。
表 15: LLaMA-2-13B在剪枝注意力和MLP块后(无微调)的零样本性能:PP在几乎所有场景中都表现出优越的性能。
表 16: OPT-13B在剪枝注意力和MLP块后(无微调)的零样本性能:PP在几乎所有场景中都表现出优越的性能。
表 17: 对LLaMA-3-8B进行MLP剪枝的零样本性能。PP在几乎所有测试场景中都持续表现出优越的性能。
表 18: 在LLaMA-2-7B模型上,PP与经过微调的基线在剪枝注意力和MLP层后的比较:PP在无微调的情况下,在各种场景中始终表现更优。
表 19: 在LLaMA-2-13B模型上,PP与经过微调的基线在剪枝注意力和MLP层后的比较:PP在无微调的情况下,在各种场景中始终表现更优。
💬 评论讨论
欢迎在这里分享您的想法和见解!