SLA: BEYOND SPARSITY IN DIFFUSION TRANSFORMERS VIA FINE-TUNABLE SPARSE–LINEAR ATTENTION
SLA: BEYOND SPARSITY IN DIFFUSION TRANSFORMERS VIA FINE-TUNABLE SPARSE–LINEAR ATTENTION
SLA:通过可微调的稀疏-线性注意力超越扩散 Transformer 中的稀疏性
Jintao Zhang, Haoxu Wang, Kai Jiang, Shuo Yang, Kaiwen Zheng, Haocheng Xi, Ziteng Wang, Hongzhou Zhu, Min Zhao, Ion Stoica, Joseph E. Gonzalez, Jun Zhu, Jianfei Chen
清华大学,加州大学伯克利分校
{zhang-jt24@mails., jianfeic@, dcszj@}http://tsinghua.edu.cn
A1 主要贡献
核心问题: 在扩散 Transformer(DiT)模型中,尤其是在视频生成领域,由于序列长度很长,注意力机制的二次方复杂度成为主要的延迟瓶颈。
研究目标: 降低 DiT 模型中注意力机制的计算成本,从而提高其效率。
现有方法的局限性:
1. 线性注意力 (L1): 在实践中,尤其是在视频扩散模型上,线性注意力方法通常会严重降低视频质量。相关研究很少,且主要局限于图像生成。
2. 稀疏注意力 (L2): 现有的稀疏注意力方法难以实现非常高的稀疏度。在序列长度低于 50K 时,通常只能达到 40-60% 的稀疏度。即使在更长的序列上(100K-300K),稀疏度也仅能达到 80-85%。
关键观察: 本文发现扩散 Transformer 中的注意力权重可以被分解为两个矩阵:一小部分具有高秩的大权重,和大部分具有极低秩的剩余权重。这一观察解释了为什么单独使用稀疏注意力或线性注意力都无法取得满意效果,并自然地引出了将稀疏加速应用于前者、低秩加速应用于后者的策略。
本文方法 (SLA): 基于上述观察,本文提出了SLA(稀疏-线性注意力),一种可训练的混合稀疏和线性注意力方法,用于加速DiT模型。
* 权重分类: SLA 将注意力权重块动态地分为三类:关键(critical)、边缘(marginal)和可忽略(negligible)。
* 混合计算:
* 对关键权重块,使用标准的 FlashAttention 进行精确计算。
* 对可忽略的权重块,直接跳过计算。
* 对边缘权重块,采用线性注意力进行处理。
* 设计优势: 这种设计使得稀疏度能够大幅提升(例如,从70%提升到95%),同时保持准确性。由于线性注意力的计算成本(在视频生成模型中不到全注意力的0.5%)几乎可以忽略不计,SLA 比单独的稀疏注意力快数倍。
* 实现: 本文为SLA实现了高效的前向和后向传播过程。
核心成果:
* SLA 在不降低视频生成质量的前提下,将注意力计算量减少了95%,即使在30K的中等序列长度下也能实现。
* 本文实现的SLA在注意力核计算上实现了13.7倍的加速,在Wan2.1-1.3B模型的视频生成任务中实现了2.2倍的端到端加速,使得注意力时间几乎可以忽略不计。
* 在生成质量和效率方面,SLA 均优于基线方法。
A3 背景知识/关键观察/设计原则
2.1 块稀疏注意力
标准注意力计算: 对于给定的查询 $Q$、键 $K$ 和值 $V \in \mathbb{R}^{N \times d}$,标准注意力机制计算得分矩阵 $S = QK^\top / \sqrt{d}$ 和注意力权重 $P = \text{Softmax}(S)$,最终得到输出 $O = PV$。当 $N$ 很大时,这种计算方式效率低下,需要 $O(N^2d)$ 的操作。稀疏注意力的思想是通过对注意力权重应用一个掩码 $M \in \{0, 1\}^{N \times N}$ 来减少计算量:$P \leftarrow P \odot M$,其中 $\odot$ 是逐元素乘积。一种常见的策略是选择一个阈值 $\tau$,如果 $P_{ij} > \tau$,则设置 $M_{ij} = 1$。对于 $M_{ij}=0$ 的条目,可以跳过 $Q_i K_j^\top$ 和 $P_{ij}V_j$ 的乘法计算,其中 $Q_i = Q[i, :]$, $K_j = K[j, :]$, $V_j = V[j, :]$。
块级稀疏注意力的实际实现: 逐元素的稀疏注意力在现代 GPU 上效率不高。像 FlashAttention 【【5】Dao, Tri. "Flashattention-2: Faster attention with better parallelism and work partitioning." arXiv preprint arXiv:2307.08691 (2023).】这样的实用实现是在块级别上操作的。具体来说,稀疏 FlashAttention 首先将 $Q, K, V, S, P, M$ 划分为块 $\{Q_i\}, \{K_j\}, \{V_j\}, \{S_{ij}\}, \{P_{ij}\}, \{M_{ij}\}$,其中 $Q_i \in \mathbb{R}^{b_q \times d}, K_j, V_j \in \mathbb{R}^{b_{kv} \times d}$,以及 $S_{ij}, P_{ij}, M_{ij} \in \mathbb{R}^{b_q \times b_{kv}}$。每个块掩码 $M_{ij}$ 的所有元素要么是0要么是1,如果 $M_{ij}[:, :] = 0$,我们就跳过 $Q_i K_j^\top$ 和 $P_{ij}V_j$ 的计算。
2.2 线性注意力
线性注意力的核心思想: 线性注意力方法将标准注意力的复杂度从 $O(N^2d)$ 降低到 $O(Nd^2)$。一个关键思想是通过引入一个应用于 $Q$ 和 $K$ 的特征映射 $\phi(\cdot)$ 来解耦 softmax 操作。具体来说,它用 $\text{rowsum}(\phi(Q)\phi(K)^\top)$ 替换标准注意力中的注意力权重。$\phi(Q)\phi(K)^\top$ 这种重新表述使得矩阵乘法的顺序可以重排:它不是显式地计算注意力权重,而是首先计算 $\phi(K)^\top V$,然后将这个中间结果应用于 $\phi(Q)$:
$H = \phi(K)^\top V, Z = \text{rowsum}(\phi(K)^\top) \in \mathbb{R}^{d \times 1}, O = \frac{\phi(Q)H}{\partial(Q)Z}.$
映射 $\phi(\cdot)$ 通常是一个激活函数(例如,ELU + 1 或 ReLU 【【3】Clevert, Djork-Arne, Thomas Unterthiner, and Sepp Hochreiter. "Fast and accurate deep network learning by exponential linear units (elus)." Proceedings of the International Conference on Learning Representations (ICLR), 2016.】【【41】Xavier, Glorot, Antoine Bordes, and Yoshua Bengio. "Deep sparse rectifier neural networks." Proceedings of the Fourteenth International Conference on Artificial Intelligence and Statistics, pp. 315-323, 2011.】)。这种形式避免了显式构造 $N \times N$ 的矩阵 $S$ 和 $P$,从而实现了线性的计算复杂度。
3.1 SLA的动机
注意力权重分布的观察: 由于 softmax 算子,注意力权重 $P$ 的值在 $[0, 1]$ 区间内,并且每行之和为1。此外,由于 softmax 中的指数缩放, $P$ 中只有一小部分条目相对较大,而绝大多数都接近于零。图1(左)展示了从 Wan2.1 模型【【38】Wan, Team, et al. "Wan: Open and advanced large-scale video generative models." arXiv preprint arXiv:2503.20314 (2025).】中采样的典型注意力权重 $P$ 的分布。我们强调两个关键观察:(1) 只有大约8.1%的权重值大于平均值 $1/N$。(2) 相当大比例的权重值非常小。在我们的案例中,大约45%的权重值低于 $1/(100N)$。如图1(右)所示,在稀疏注意力中跳过这最小的45%的权重(即,将 $M$ 中相应条目设为0),与全注意力输出相比,引入的相对 L1 误差小于3%。相比之下,仅保留最大的8.1%的权重(稀疏度=92%)会导致误差急剧增加,达到约33%。这解释了为什么现有的稀疏注意力方法难以实现超过90%的稀疏度。
图 1: 左图显示了从 Wan2.1 模型中采样的典型注意力权重分布。右图显示了不同稀疏度下稀疏注意力的准确性。
中间值的处理困境与SLA的启发: 介于 $1/(100N)$ 和 $1/N$ 之间的中间值(图1中的黄色柱)带来了一个两难选择:忽略它们会引入显著的精度损失,而用全注意力计算它们则会导致稀疏度大幅下降。幸运的是,这些值的重要性远低于最大的那些值。这一发现启发我们将注意力权重分为三类:关键(critical)、边缘(marginal)和可忽略(negligible)。对于关键权重,我们使用稀疏 FlashAttention 计算输出,因为它们主导了注意力分布;对于可忽略的权重,我们跳过计算;对于边缘权重,我们采用线性注意力方法将计算复杂度降低到 $O(Nd^2)$,并增强稀疏注意力的性能。
图 2: 在 Wan2.1 上使用全注意力、线性注意力、稀疏注意力和 SLA 进行微调的视频生成示例。SLA 可以实现 95% 的高稀疏度和无损的视频质量。
实验结果验证: 在图2中,我们展示了由 Wan2.1 模型通过不同注意力方法微调后生成的一些视频:仅使用线性注意力、90%稀疏度的稀疏注意力,以及95%稀疏度的SLA。值得注意的是,SLA在95%稀疏度下的计算复杂度几乎是90%稀疏注意力的一半,因为线性注意力的成本几乎可以忽略不计。例如,在Wan2.1模型中,线性注意力占全注意力成本的不到0.5%。这些实验结果表明,SLA在视频质量上显著优于其他两种方法。
3.2 分离注意力权重:稀疏的少数,低秩的多数
关键观察: 如图3所示,全注意力权重可以被分解为两个部分:(1)一个小的子集(< 10%),其秩与全注意力相当;以及(2)一个大的子集(> 90%),其秩非常低。由于加速注意力的方法主要关注稀疏性或低秩结构,这自然地提出了一种优雅的策略:对第一部分应用稀疏注意力,对第二部分应用低秩近似。
图 3: 注意力权重的分解。我们从 Wan2.1 模型中采样注意力权重:左图显示了全部权重,中图是前 8% 的权重,右图是后 92% 的权重。
线性注意力失败的原因与新思路: 之前线性注意力的失败很大程度上是由于全注意力权重的高秩特性【【8】Fan, Qihang, Huaibo Huang, and Ran He. "Breaking the low-rank dilemma of linear attention." CVPR (2025).】,而线性注意力的秩最多被限制为 $d$。图3(左)使用稳定秩【【23】Rudelson, Mark, and Roman Vershynin. "Sampling from large matrices: an approach through geometric functional analysis." (2006).】的概念说明了这一点。我们观察到,在移除注意力权重 $P$ 中的最大值后,剩余的矩阵变得秩极低。这启发我们使用稀疏掩码 $M$ 对 $P$ 进行分解:
$$P = \underbrace{P \odot M}_{\text{sparse component}} + \underbrace{P \odot (1-M)}_{\text{low-rank component}}.$$
低秩部分的处理: 由于线性注意力本质上是注意力的一种低秩版本,这为我们提供了一种可能性,即用线性注意力来替代低秩分量 $P \odot (1 - M)$。
A2 方法细节
4. SLA
SLA在一个统一的框架内有效地集成了稀疏注意力和线性注意力,使它们能够相互补充。特别地,我们将这两种注意力融合到一个高效的GPU核中。在本节中,我们介绍SLA的稀疏和线性注意力组件。
压缩注意力权重的预测: SLA首先预测一个压缩的注意力权重矩阵 $P_c \in \mathbb{R}^{N/b_q \times N/b_{kv}}$:
$P_c = \text{Softmax}(\text{pool}(Q)\text{pool}(K)^T/\sqrt{d}).$
其中 pool(·) 是沿 token 维度的平均池化操作。对于 $P_c$ 的每个元素,我们将其分为三类,并将结果记录在一个压缩掩码 $M_c \in \mathbb{R}^{N/b_q \times N/b_{kv}}$ 中。
权重分类: 具体来说,前 $k_h\%$ 的位置被标记为关键(critical,标记为1),后 $k_l\%$ 的位置被标记为可忽略(negligible,标记为-1),其余位置被标记为边缘(marginal,标记为0)。形式上:
$M[i, j]= [1 \text{ (top } k\text{\%), } -1 \text{ (bottom } k\text{\%), } 0 \text{ (otherwise)}]$
我们根据 $M_c$ 应用不同的方法。
4.1 SLA中的稀疏注意力
稀疏注意力的计算: 在掩码 $M_c$ 的指导下,使用稀疏 FlashAttention 来计算稀疏注意力的输出。对于每个 $Q$ 块 $Q_i$,我们遍历所有的 $K, V$ 块 $K_j, V_j$,其中 $j=0, \dots, N/b_{kv}$。当 $M_c[i, j] = 1$ 时,我们执行以下操作:
$S_i = Q K^T / \sqrt{d_k}, P_j = \text{Online-Softmax}(S_i), O_i = O'_i + P_j V_j$
这里的 OnlineSoftmax(·) 算子【【17】Milakov, Maxim, and Natalia Gimelshein. "Online normalizer calculation for softmax." arXiv preprint arXiv:1805.02867 (2018).】以分块方式计算矩阵的 softmax(实现见算法1的10-11行)。每个 $O_i^s$ 的初始值设为零。算法1描述了稀疏注意力组件的前向计算,我们将其最终输出表示为 $O^s$。
4.2 SLA中的线性注意力
线性注意力的应用: 受低秩近似思想的启发,我们将方程1中的低秩分量 $P \odot (1 - M)$ 替换为第2.2节中介绍的线性注意力,如下所示:
$$\frac{\phi(Q)\phi(K)^{\top}}{\text{rowsum}(\phi(Q)\phi(K)^{\top})} \odot (1-M).$$
线性注意力的计算: 具体来说,$M_c$ 中值为0的条目决定了由线性注意力处理的块。对于每个查询块 $Q_i$,我们计算相应的线性注意力输出:
$$
\mathbf{u}_{i}=\sum_{j-M_{i, l}=0} \mathbf{c}\left(\mathbf{K}_{j}\right)^{T} \mathbf{v}_{j}, \mathbf{z}_{i}=\sum_{j-M_{i, l}=0} \operatorname{rowsum}\left(\phi\left(\mathbf{K}_{j}\right)^{T}\right) \lambda_{i}, \mathbf{o}_{i}=\frac{\phi\left(\mathbf{Q}_{i}, \mathbf{u}_{i}\right)}{\phi\left(\mathbf{Q}_{i}, \mathbf{z}_{i}\right)}
$$
这里,如第2.2节所述,$\phi(\cdot)$ 表示激活函数,而 $H_i \in \mathbb{R}^{d \times d}, Z_i \in \mathbb{R}^{d \times 1}$ 是类似于 $H$ 和 $Z$ 的中间结果。算法1描述了线性注意力组件的前向传播过程,该组件的最终输出表示为 $O^l$。
最终输出的组合: 最后,SLA的整体注意力输出定义为:
$O = O^s + \operatorname{Proj}(O^l).$
其中 Proj 是一个可学习的线性变换 $\mathbb{R}^d \to \mathbb{R}^d$。将此投影应用于 $O^l$ 有助于减少 softmax 注意力和线性注意力之间的分布不匹配。其计算成本为 $O(Nd^2)$,与计算 $O^l$ 相同,并且与全注意力的 $O(N^2d)$ 成本相比可以忽略不计。
设计洞察: SLA中的线性注意力并非旨在近似边缘注意力权重的输出,而是作为一种可学习的补偿,以增强稀疏注意力的效果。这是因为仅靠线性注意力本身难以近似全注意力的输出【【2】Choromanski, Krzysztof Marcin, et al. "Rethinking attention with performers." International Conference on Learning Representations, 2020.】【【20】Qin, Zhen, et al. "cosformer: Rethinking softmax in attention." International Conference on Learning Representations, 2022.】。因此,我们需要微调目标模型的参数,使其能够适应线性注意力的使用。
图 4: SLA 概览。左图阐述了高层思想:注意力权重被分为三类,并分配给不同复杂度的计算。右图展示了使用预测的压缩注意力权重的 SLA 详细前向算法。
5. 使用SLA进行微调
要将SLA应用于扩散模型,我们可以简单地将原始注意力替换为SLA,并在与预训练数据一致的数据集上对模型进行少量步骤的微调。本节我们描述SLA的前向和后向传播过程。此外,附录A.3中详述了一些针对SLA的额外效率优化。
5.1 前向传播
前向计算: 前向计算的公式已在第4节中介绍。SLA前向传播的完整算法在算法1中呈现。值得注意的是,我们为每对 $(K_j, V_j)$ 预先计算了 $h_j = \phi(K_j)^\top V_j$ 和 $z_j = \text{rowsum}(\phi(K_j)^\top)$(算法1第4行)。这样的设计确保了当 $M_c[i, j] = 0$ 时,相应的操作仅涉及一次矩阵加法(算法1第13行),从而提高了效率。为简化表示,下文中我们记 $Q_\phi = \phi(Q)$ 和 $K_\phi = \phi(K)$。
算法1:SLA的前向传播
1: 输入:矩阵 Q, K, V, Qϕ, Kϕ ∈ R^(N×d),块大小 bq, bkv,超参数 kh, kl。
2: 将 Q, Qϕ 划分为 Tm = N/bq 个块 {Qi} 和 {Qϕi};
3: 将 K, V, Kϕ 划分为 Tn = N/bkv 个块 {Ki}, {Vi} 和 {Kϕi};
4: h = {hj} = {(Kϕj)⊤Vj};z = {zj} = {rowsum((Kϕj)⊤)}; // 为线性注意力预计算
5: Pc = Softmax(pool(Q)pool(K)⊤/sqrt(d));初始化 Mc = 0;
6: 如果 Pc[i, j] ∈ TopK(Pc[i, :], kh) 则 Mc[i, j] = 1;如果 Pc[i, j] ∈ BottomK(Pc[i, :], kl) 则 Mc[i, j] = 0;
7: for i = 1 to Tm do
8: for j = 1 to Tn do
9: if Mc[i, j] = 1 then
10: Sij = QiK⊤j / sqrt(d);mij = max(mi,j-1, rowmax(Sij));Pij = exp(Sij - mij);
11: lij = e^(mi,j-1 - mij) * li,j-1 + rowsum(Pij);Osij = diag(e^(mi,j-1 - mij)) * Osi,j-1 + Pij * Vj;
12: else if Mc[i, j] = 0 then
13: Hi ← Hi + hj;Zi ← Zi + zj;
14: end if
15: end for
16: Osi = diag(l_Tni)^(-1) * Osi,Tn;Oli = Qϕi * Hi / (Qϕi * Zi);Li = mi,Tn + log(li,Tn);
17: end for
18: return Os = {Osi}, Ol = {Oli};
算法2:SLA的后向传播
1: 输入:来自前向传播的 Q, K, V, Qϕ, Kϕ, Mc, {Li}, {Hi}, {Zi}, Os, Ol,以及 dOs, dOl ∈ R^(N×d)。
2: Ds = rowsum(dOs ⊙ Os), Dl = rowsum(dOl ⊙ Ol),将 Ds, Dl 划分为 Tm 个块 {Dsi}, {Dli};
3: for i = 1 to Tm do
4: dHi = (Qϕ/(QϕZi))⊤ * dOli;dZi = -(Qϕ/(QϕZi))⊤ * Dli;
5: dQϕi = (dOli * (Hi)⊤ - Dli * (Zi)⊤) / (Qϕi * Zi);
6: end for
7: for j = 1 to Tn do
8: 初始化 dH = 0, dZ = 0;
9: for i = 1 to Tm do
10: if Mc[i, j] = 1 then
11: Sij = QiK⊤j / d;Pij = exp(Sij - Li);dVj ← dVj + P⊤ij * dOsi;dPij = dOsij * V⊤j;
12: dSij = Pij ⊙ (dPij - Dsi);dQi ← dQi + dSij * Kj;dKj ← dKj + dS⊤ij * Qi;
13: else if Mc[i, j] = 0 then
14: dH ← dH + dHi;dZ ← dZ + dZi;
15: end if
16: end for
17: dKϕj = Vj * (dH)⊤ + (dZ)⊤;dVj = Kϕj * dH;
18: end for
19: return dQ = {dQi}, dK = {dKi}, dV = {dVi}, dQϕ = {dQϕi}, dKϕ = {dKϕi};
5.2 后向传播
后向传播的融合: 后向传播计算稀疏和线性两个组件的梯度,这两个过程也被融合到一个单一的GPU核中以提高效率。
梯度表示法: 前缀 d 或 d 用于表示梯度,例如,dOs, dOl 分别是 Os, Ol 相对于某个损失函数 ℓ 的梯度。
稀疏注意力的梯度: 输出梯度 dOs 被反向传播以计算 dQ, dK, 和 dV,其推导过程与 FlashAttention 【【5】Dao, Tri. "Flashattention-2: Faster attention with better parallelism and work partitioning." arXiv preprint arXiv:2307.08691 (2023).】相同。给定 dOs,后向传播过程如下:
$dP_{ij} = dO_i^{\text{T}}V, \quad D_j^{\prime} = \text{rowsum}(dO_i^{\text{T}} \odot O_j), \quad dS_{ij} = \mathbf{P}_{ij} \odot (dP_{ij} - D_j^{\prime}),$
$dQ_i = dS_{ij}K_j, \quad dK_j = dS_{ij}Q_j, \quad dV_j = P_{ij}^{\text{T}} dO_j^{\prime}.$
这里,我们将 $D_i^s \in \mathbb{R}^{b_q \times 1}$ 视为列向量。
线性注意力的梯度: 梯度 dOl 通过链式法则产生 dQϕ, dKϕ, dV:
$$
\begin{array}{l}
d H_{i}=\left(\frac{Q_{i}^{g}}{Q_{i}^{g} Z_{i}}\right)^{\top} d O_{i}^{\prime}, \quad D_{i}^{\prime}=\operatorname{rowsum}\left(d O_{i}^{\prime} \odot O_{i}^{g}\right), \quad d Z_{i}=-\left(\frac{Q_{i}^{g}}{Q_{i}^{g} Z_{i}}\right)^{\top} D_{i}^{\prime} \\
d Q_{i}^{g}=\frac{\left(d O_{i}^{\prime}\left(H_{i}\right)^{\top}-D_{i}^{\prime} Z_{i}^{\top}\right)}{Q_{i}^{g} Z_{i}}, \quad d K_{j}^{g}=V_{j}\left(d H_{i}\right)^{\top}+\left(d Z_{i}\right)^{\top}, \quad d V_{j}=K_{j}^{g} d H_{i}
\end{array}
$$
这里,dKϕj 和 dVj 是通过聚合 dHi 和 dZi 得到的。与前向传播类似,每个 dHi 和 dZi 都被预先计算,使得剩余的计算简化为简单的矩阵加法。详细算法见算法2。
A4 实验
6.1 实验设置
-
模型与数据集:
- 视频生成: 使用 Wan2.1-1.3B 模型【【38】Wan, Team, et al. "Wan: Open and advanced large-scale video generative models." arXiv preprint arXiv:2503.20314 (2025).】。微调数据集为一个私有数据集,包含从Pexels【【19】Pexels. "Pexels: Free stock photos and videos."】和Common Crawl【【4】Common Crawl. "Common crawl."】等网站收集的20,000个5秒长、480p分辨率的视频。
- 图像生成: 使用 LightningDiT【【50】Yao, Jingfeng, Bin Yang, and Xinggang Wang. "Reconstruction vs. generation: Taming optimization dilemma in latent diffusion models." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2025).】模型,在ImageNet【【7】Deng, Jia, et al. "Imagenet: A large-scale hierarchical image database." 2009 IEEE conference on computer vision and pattern recognition. Ieee, 2009.】数据集上进行,分辨率为 512×512。
-
基线方法:
- VSA 【【54】Zhang, Peiyuan, et al. "Vsa: Faster video diffusion with trainable sparse attention." arXiv preprint arXiv:2505.13389 (2025).】
- VMoBa 【【40】Wu, Jianzong, et al. "Vmoba: Mixture-of-block attention for video diffusion models." arXiv preprint arXiv:2506.23858 (2025).】
- 训练无关的 SparseAttn (Sparge-F) 【【53】Zhang, Jintao, et al. "Spargeattn: Accurate sparse attention accelerating any model inference." International Conference on Machine Learning (ICML) (2025).】
- 可训练的 SparseAttn (Sparge-T)
- 消融研究基线: Linear Only (仅线性注意力), Sparse Only (仅SLA的稀疏部分), L+S (直接将Linear Only和Sparse Only的输出相加)。
-
评估指标:
- 视频质量: VBench【【51】Zhang, Fan, et al. "Evaluation agent: Efficient and promptable evaluation framework for visual generative models." arXiv preprint arXiv:2412.09645 (2024).】的四个维度:成像质量(IQ)、整体一致性(OC)、美学质量(AQ)、主体一致性(SC)。以及 Vision Reward (VR)【【47】Xu, Jiazheng, et al. "Visionreward: Fine-grained multi-dimensional human preference learning for image and video generation." arXiv preprint arXiv:2412.21059 (2024).】、美学视频质量(VA)和技术视频质量(VT)【【16】Liu, Yaofang, et al. "Evalcrafter: Benchmarking and evaluating large video generation models." arXiv preprint arXiv:2310.11440 (2023).】。
- 图像质量: FID。
- 计算复杂度与效率: FLOPs (浮点运算次数),FLOPS (每秒浮点运算次数,用于注意力核效率,定义为 O(全注意力)/t),以及端到端生成延迟(秒)。
-
超参数:
- 训练批量大小为64,对 Wan2.1 模型微调2000步。
- 激活函数 $\phi$ 使用 softmax。
- $k_h\%$ 为 5%,$k_l\%$ 为 10%。
- 块大小 $b_q = b_{kv} = 64$。
6.2 有效性
视频生成质量与效率: 如表1所示,SLA 在保持与全注意力相当的视频质量的同时,实现了约19.3倍的效率提升。与基线方法相比,SLA在更高的稀疏度下仍然能实现更高的质量。例如,SLA在95%的稀疏度下,效率比85%稀疏度的基线方法高出约3倍,同时视频质量更优。
表 1: SLA 与其他基线方法的质量和效率比较。
6.3 效率
核速度与端到端延迟: 图6比较了SLA在RTX5090上与基线方法的核速度和端到端延迟。需要注意的是,即使VSA在89%稀疏度和VMoBa在85%稀疏度下,其生成质量已不如SLA,因此在更高稀疏度(如95%)下进行比较对它们并不公平。
* 前向传播: SLA 比 FlashAttention2 快13.7倍,比95%稀疏度的VSA快1.93倍,比95%稀疏度的VMoBa快3.36倍。
* 后向传播: SLA 比 FlashAttention2 快6.8倍,同样优于VSA和VMoBa。
* 端到端视频生成: SLA 将注意力延迟从97秒减少到11秒(减少了8.8倍),实现了2.2倍的端到端加速。
* 微调开销: 微调开销极小,仅为预训练成本的不到0.1%。
(a) 在 RTX5090 上的注意力 GPU 核速度比较。
(b) 端到端视频生成延迟比较。
图 6: 在 Wan2.1-1.3B 和 RTX5090 上,SLA 及基线方法的注意力核速度和端到端生成延迟。FlashAttn 指的是 FlashAttn2,是 RTX5090 上最快的可用版本。
6.4 消融研究
稀疏与线性注意力的融合: 如表2所示,为了评估SLA融合稀疏和线性注意力的有效性,我们将其与Sparse Only、Linear Only和S+L进行了比较。结果显示SLA在生成质量上达到最佳,并且比Sparse Only和S+L更高效,证实了我们融合策略的有效性。
线性注意力中的激活函数: 我们评估了softmax、elu+1和hedgehog作为激活函数 $\phi$ 的效果。表2的结果显示,softmax通常能提供更好的生成质量和效率。
参数 $k_h$ 的影响: 我们将 $k_h$ 从5%变化到20%。结果(表2)发现,$k_h=5\%$ 时生成的质量已接近全注意力。由于 $k_h=5\%$ 的计算量分别比 $k_h=10\%$ 和 $k_h=20\%$ 节省约一半和四分之三,因此它在效率和质量之间提供了最佳的权衡。
表 2: SLA 的消融实验结果。
6.5 可视化示例
视频生成效果: 图5和图7展示了使用SLA和基线方法微调的Wan2.1-1.3B生成的视频示例。即使在95%的稀疏度下,SLA生成的视频质量也与全注意力相当,而其他方法即使在低于90%的稀疏度下也表现出明显的失真。
图 5: 使用经SLA和基线方法微调的Wan2.1生成的视频示例。对于Linear Only, Sparse Only, Sparge-F, VSA, 和VMoBa,每个提示只显示一帧,因为它们的视频质量不足。完整的可见比较在附录A.1的图7中。
A5 结论
本文提出了SLA,一种统一了稀疏和线性注意力的可训练注意力机制,用于加速扩散Transformer。SLA根据重要性分配计算资源:对关键权重执行 $O(N^2)$ 的注意力计算,对边缘权重执行 $O(N)$ 的注意力计算,并跳过可忽略的计算。这种设计能够在保持有效性的同时大幅降低注意力成本。实验表明,仅需少量微调步骤,SLA就能有效地加速模型。具体来说,SLA在Wan2.1-1.3B上实现了注意力计算量减少20倍,GPU核速度提升13.7倍,端到端速度提升2.2倍,且视频生成质量没有下降。
A6 附录
A.1 更多可视化示例
图 7: 使用经SLA和基线方法微调的Wan2.1生成的完整视频示例。第一个提示是“一只北极熊在弹吉他”。第二个提示是“太平洋海岸,卡梅尔海边的海洋和波浪”。第三个提示是“一只鸟用树枝和树叶筑巢”。
视频示例分析: 图7展示了由Wan2.1模型经过SLA和其他注意力方法微调后生成的更多视频示例。我们可以发现,SLA即使在比基线方法稀疏度更高的情况下,也始终能实现更高的质量。
A.2 图像生成实验
实验设置: 如第6.1节所述,我们在LightningDiT【【50】Yao, Jingfeng, Bin Yang, and Xinggang Wang. "Reconstruction vs. generation: Taming optimization dilemma in latent diffusion models." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2025).】的预训练任务上评估了SLA和基线方法。具体来说,我们使用了LightningDiT-1p0B/1模型,该模型包含10.3亿参数,在ImageNet【【7】Deng, Jia, et al. "Imagenet: A large-scale hierarchical image database." 2009 IEEE conference on computer vision and pattern recognition. Ieee, 2009.】数据集上以512×512的分辨率进行训练。
超参数: 所有超参数遵循【【50】Yao, Jingfeng, Bin Yang, and Xinggang Wang. "Reconstruction vs. generation: Taming optimization dilemma in latent diffusion models." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2025).】的设置,除了我们将训练步数设为100,000,批量大小为128。对于SLA,我们将$\phi$设为softmax,并使用块大小$b_q = b_{kv} = 64$。
评估指标: 遵循【【50】Yao, Jingfeng, Bin Yang, and Xinggang Wang. "Reconstruction vs. generation: Taming optimization dilemma in latent diffusion models." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2025).】,我们采用FID来评估图像质量,用FLOPs来衡量计算复杂度。
结果: 结果总结在表3中。在最高的稀疏度水平上,SLA在FID指标上超过了所有其他基线方法,甚至超过了全注意力,证实了SLA在保持图像质量方面的优势。这一发现与第6.2节中报道的在Wan2.1上的视频实验结果一致。
表 3: SLA 与其他基线在图像生成任务上的质量和效率比较。
A.3 额外的效率优化
优化策略: 由于SLA的效率在很大程度上取决于稀疏模式,我们引入了几种针对不同稀疏度水平的补充优化。这些优化带来了显著的计算效率提升:
-
查找表 (Lookup table): 当 $M_c$ 非常稀疏时(例如,稀疏度 > 90%),扫描整行或整列来读取掩码值会导致显著的内存开销。为了缓解这个问题,我们预处理每行和每列的非零位置,并将它们存储在查找表中。在计算过程中,只访问查找表,从而大幅减少了内存流量。
-
线性注意力的预聚合 (Pre-aggregation for linear attention): 尽管算法1中的第13行和算法2中的第14行只需要一次矩阵加法,但当 $M_c$ 中许多条目为0时(例如,> 90%),重复执行此类加法会产生高昂的开销。为了解决这个问题,我们预先计算行/列的和 $\sum_j h_j$ 和 $\sum_j z_j$,然后减去对应于 $M_c[i, j] \ne 0$ 的贡献。通过这种方式,90%的加法可以被替换为仅10%的减法。
-
四俄国人方法 (Method of Four Russians): 当 $M_c[i, j] = 0$ 的块数既不是非常小也不是非常大时(例如,约50%),我们为算法1的第13行和算法2的第14行提供了一种高效的实现。具体来说,我们采用了四俄国人方法【【1】Arlazarov, V. L., et al. "On economical construction of the transitive closure of an oriented graph." Soviet Mathematics Doklady 11 (1970): 1209-1210.】。其核心思想是将 $h_j$ 和 $z_j$ 分组成连续 $g$ 个块的段,并预计算每个段内所有 $2^g$ 种可能的子集和。在前向传播过程中,任何 $g$ 个块的子集和都可以通过单次查找获得,而不是即时求和。该方案理论上可以将计算量减少 $1/g$。
💬 评论讨论
欢迎在这里分享您的想法和见解!