SageAttention2: Efficient Attention with Thorough Outlier Smoothing and Per-thread INT4 Quantization
SageAttention2: Efficient Attention with Thorough Outlier Smoothing and Per-thread INT4 Quantization
文章标题:SageAttention2: 通过彻底的异常值平滑和每线程INT4量化实现高效注意力机制
作者/机构:Jintao Zhang, Haofeng Huang, Pengle Zhang, Jia Wei, Jun Zhu, Jianfei Chen
A1 主要贡献
本文针对现有注意力计算加速方法的不足,提出了 SageAttention2,旨在通过使用速度更快的4位矩阵乘法(INT4 Matmul)并结合多种精度增强技术,进一步提升注意力计算的效率和精度。
核心问题与动机:
先前的工作 SageAttention 通过将 QK⊤ 量化为 INT8、对 K 进行平滑处理以及在特定 GPU 上使用 FP16 累加器来加速 PV 计算,实现了相较于 FlashAttention2 约2倍的加速。然而,SageAttention 存在两个主要弱点:
1. (W1) INT8 矩阵乘法的速度只有 INT4 的一半,限制了性能的进一步提升。
2. (W2) 使用 FP16 累加器的 PV 计算加速仅在 RTX 4090 和 RTX 3090 等少数 GPU 上有效,通用性不强。
为了克服这些弱点并利用更快的硬件特性,本文提出将 Q、K 矩阵量化到 INT4,并将 P、V 矩阵量化到 FP8。
研究挑战:
将 Q、K 量化到 INT4 以及 P、V 量化到 FP8 带来了严峻的挑战:
1. (C1) INT4 的数值范围([-7, +7])极其有限,当 Q 和 K 矩阵中存在异常值(outliers)时,会导致巨大的量化误差。例如,仅使用张量级(per-tensor)INT4 量化会导致文生视频模型 CogvideoX 生成完全模糊的视频,Llama3 模型在 MMLU 任务上的准确率降至随机猜测水平。
2. (C2) 经研究发现,硬件中为 FP8 矩阵乘法设计的 FP32 累加器(mma.f32.f8.f8.f32)实际上是 FP22 精度(1个符号位,8个指数位,13个尾数位),这会导致 PV 计算的精度损失。
核心创新点/方法:
为应对上述挑战,本文提出了以下技术:
1. Q 矩阵平滑(Smoothing Q): 针对挑战 (C1),本文在 SageAttention 对 K 矩阵平滑的基础上,进一步提出对 Q 矩阵进行平滑。具体方法是减去 Q 矩阵在通道维度上的平均值 $\overrightarrow{Q_m}$,然后在 QK⊤ 矩阵乘法后加上补偿项 $\overrightarrow{Q_m}K$ 以确保计算的正确性。此举能有效消除 Q 和 K 中的通道维异常值,大幅提升 INT4 量化的精度。
2. 每线程量化(Per-thread Quantization): 针对挑战 (C1),本文提出一种硬件友好的每线程量化方法。该方法比 SageAttention 的每块(per-block)量化更精细,同时避免了每词元(per-token)量化所需的高昂反量化开销。它根据 PTX mma 指令的 GPU 线程与内存布局的映射关系,将对应同一线程的词元分组进行量化,确保每个线程在反量化时只关联一个量化尺度,从而在不增加额外延迟的情况下实现高精度。
3. 两级累加策略(Two-level Accumulation): 针对挑战 (C2) 中 FP22 累加器导致的精度损失,本文提出一种两级累加策略。该策略在每次 PV 的块级矩阵乘法后,使用一个 FP32 缓冲区来累加来自 FP22 累加器的值,从而将误差限制在块范围内,提升了 FP8 PV 计算的整体精度。
4. 可选的 V 矩阵平滑: 作为补充,本文还设计了一种可选的 V 矩阵平滑技术,通过减去 V 在通道维度上的平均值,并在最终输出中加回补偿项,来进一步提高 FP22 累加器的精度。
主要性能:
SageAttention2 在 RTX4090 和 L20 GPU 上实现了高性能内核,其峰值性能在 RTX4090 上达到 481 TOPS,速度分别超过 FlashAttention2 约 3 倍、xformers 约 4.5 倍。为了支持缺乏原生 INT4 张量核的 NVIDIA Hopper 架构 GPU,本文还提供了 SageAttention2-8b(将 Q, K 量化为 INT8)版本。该版本在 Hopper GPU 上的速度与 FlashAttention3(fp8) 相当,但精度远高于后者。大量实验表明,SageAttention2 能够以即插即用的方式在语言、图像和视频生成等多种先进模型上实现加速,且对端到端指标的损失微乎其微。
A3 背景知识
2.1. FlashAttention
Attention 计算公式。Attention 的计算可以表示为:$S = QK^\top/\sqrt{d}$,$P = \sigma(S)$,$O = PV$,其中 $\sigma(S)_{ij} = \exp(S_{ij})/\sum_k \exp(S_{ik})$。矩阵 $Q$、$K$ 和 $V$ 的维度均为 $N \times d$,$S$ 和 $P$ 的维度为 $N \times N$。$d$ 通常较小(如64或128),而 $N$ 可能达到数千甚至数百万。Attention 的时间复杂度为 $O(N^2)$,主要源于两个矩阵乘法($QK^\top$ 和 $PV$),它们的复杂度均为 $O(N^2d)$。FlashAttention (【索引8,Flashattention-2: Faster attention with better parallelism and work partitioning,2024,ICLR】) 是一种对 GPU 友好的 Attention 实现,它将 $Q$、$K$ 和 $V$ 从词元(token)维度切分为块 $\{Q_i\}_{i=1}^{n_q}$、$\{K_i\}_{i=1}^{n_k}$、$\{V_i\}_{i=1}^{n_v}$,块大小分别为 $b_q$、$b_k$、$b_v$ 个词元,其中 $n_q, n_k, n_v$ 是分块数量,且 $b_k = b_v$。FlashAttention 并行地以块为单位计算输出矩阵 $O$。每个流式多处理器(SM)通过迭代加载 $K_j, V_j$ 来计算一个块 $O_i$(对应一个 $Q_i$),并使用在线 softmax (【索引37,Online normalizer calculation for softmax,2018,arXiv】) 更新输出:
$$ S_{ij}=Q_i K_j^{\mathrm{T}}/\sqrt{d}, (m_{ij}, \tilde{P}_{ij}) = \tilde{\sigma}(m_{i,j-1}, S_{ij}), \\ l_{ij} = \exp(m_{i,j-1} - m_{ij})l_{i,j-1} + \text{rowsum}(\tilde{P}_{ij}), \\ O_{ij}=\text{diag}(\exp(m_{i,j-1} - m_{ij})) O_{i,j-1} + \tilde{P}_{ij}V_j. $$
其中 $m_{ij}$ 和 $l_{ij}$ 是 $b_q$ 维的向量,分别初始化为 $-\infty$ 和 0。$\tilde{\sigma}(\cdot)$ 是一个在线 softmax 算子:$m_{ij} = \max\{m_{i,j-1}, \text{rowmax}(S_{ij})\}$,$P_{ij} = \exp(S_{ij} - m_{ij})$。最后,输出计算为 $O_i = \text{diag}(l_{i,n_q})^{-1}O_{i,n_q}$。
2.2. 量化
量化加速矩阵乘法。一个矩阵乘法 $C = AB$ 可以通过量化来加速,过程如下:
$(\delta_A, \hat{A}) = \psi(A)$, $(\delta_B, \hat{B}) = \psi(B)$, $C = \psi_{\delta_A \delta_B}^{-1}(\hat{A}\hat{B})$
$\psi$ 是一个量化器,它将高精度(如 FP32 或 FP16)矩阵 $A$ 转换为一个低精度格式 $\hat{A}$(如 INT4 或 FP8)和一个缩放因子 $\delta_A$;$\psi^{-1}$ 是一个反量化器,用于将结果转换回高精度。我们期望 $\psi^{-1}(\hat{A}) \approx A$。实际的矩阵乘法 $\hat{A}\hat{B}$ 是以低精度执行的。在现代 GPU 中,低精度矩阵乘法通常比高精度快数倍。量化器在数值格式和粒度上有所不同,例如,多少个元素(“量化组”)共享一个共同的缩放因子。例如,一个 INT4、张量级(per-tensor)量化器首先计算缩放因子为整个张量的最大绝对值,然后将元素缩放到 INT4 的最大可表示范围 [-7, +7] 内,最后通过四舍五入转换为 INT4:
$\hat{A} = \lceil A / \delta_A \rfloor, \delta_A = \max(|A|) / 7$。
反量化过程是逐元素的缩放。例如,对于张量级反量化,$\psi^{-1}_{\delta_A \delta_B}(\hat{A}\hat{B}) = \hat{A}\hat{B} \times \delta_A \delta_B$。
表1. 相较于使用FP32累加器的FP16矩阵乘法的加速比。
2.3. SageAttention
SageAttention 的量化策略。基于 FlashAttention (【索引9,Flashattention: Fast and memory-efficient exact attention with io-awareness,2022,NeurIPS】) 的分块方法,SageAttention (【索引61,Sageattention: Accurate 8-bit attention for plug-and-play inference acceleration,2025,ICLR】) 以每块(per-block)的粒度将 $Q$、$K$ 量化为 INT8,即每个 $Q_i$、$K_i$ 都有一个单独的标量缩放因子:$\delta_{Q_i} = \max(|Q_i|)/127$, $\delta_{K_j} = \max(|K_j|)/127$。这样,公式 (1) 中的乘积 $S_{ij}$ 可以近似为 $S_{ij} \approx \hat{Q}_i \hat{K}_j^\top \times (\delta_{Q_i} \delta_{K_j} / \sqrt{d})$。为了保持精度,SageAttention 提出了一种预处理技术,即从 $K$ 中减去词元级(token-wise)的均值。此外,SageAttention 保持 $P_{ij}$ 和 $V_j$ 为 FP16,但在计算乘积 $P_{ij}V_j$ 时使用了 FP16 累加器(而非 FP32)。在 RTX4090 GPU 上,降低累加器精度可以加速矩阵乘法(MM)。然而,如表1所示,其他 GPU(如 L20、L40 或 H100)则没有这种特性。
A2 方法细节
本节将介绍 SageAttention2,一个高效且精确的量化注意力机制。SageAttention2 的工作流程如图3所示。我们为最大化效率,将 Q、K 量化为 INT4,将 P、$\tilde{V}$ 量化为 FP8,并提出了几种技术来保持精度,包括 QK-平滑、每线程量化和两级累加,我们将在后续小节中讨论这些技术。
3.1. Smooth Q
INT4 量化面临的挑战。首先,我们讨论如何用 INT4 精确计算 $QK^\top$。INT4 的数值范围非常有限,这给量化带来了挑战,因为存在异常值(outliers)【索引33,QServe:w4a8KV4 quantization and system co-design for efficient LLM serving,2025,MLSys】。考虑到 INT4 的范围是 [-7, +7],如果一个元素比组内最大元素小14倍以上(0.5 vs 7),它将被量化为零。由于异常值远大于其他元素,很可能导致许多非异常值元素被量化为零,从而造成严重的精度下降。因此,为了保持量化精度,我们需要使最大元素尽可能小,即使元素的量级尽可能均匀。这种技术被称为平滑(smoothing)。
提出 QK 平滑方法。在此,我们提出一种受 SageAttention【索引61,Sageattention: Accurate 8-bit attention for plug-and-play inference acceleration,2025,ICLR】启发的平滑技术。SageAttention 观察到,所有词元的 Q、K 实际上高度相似,不同词元之间只有微小变化(图2展示了从 Llama3.1【索引13,The llama 3 herd of models,2024,arXiv】和 CogvideoX【索引57,Cogvideox: Text-to-video diffusion models with an expert transformer,2025,ICLR】中随机采样的 Q、K 和 V 的热图分布)。我们建议像 SageAttention 一样平滑 K,并进一步通过减去每个块的共同均值来平滑 Q:
$\gamma(Q_i) = Q_i - \bar{q}_i, \gamma(K_j) = K_j - \bar{k},$
其中 $\bar{q}_i = \text{mean}(Q_i)$ 和 $\bar{k} = \text{mean}(K)$ 是 $1 \times D$ 维的向量,均值是沿着词元轴计算的,并且 $\bar{q}_i, \bar{k}$ 会被广播到块内和张量内的所有词元以进行减法操作。
QK 平滑的计算分解与执行流程。通过这种分解,我们得到 $S_{ij} = Q_i K^\top = (\bar{q}_i + \gamma(Q_i))(\bar{k} + \gamma(K_j))^\top = \bar{q}_i \bar{k}^\top + \bar{q}_i \gamma(K_j)^\top + \gamma(Q_i) \bar{k}^\top + \gamma(Q_i) \gamma(K_j)^\top = \gamma(Q_i) \gamma(K_j)^\top + \Delta S_{ij} + b$。这里,$\Delta S_{ij} = \bar{q}_i \gamma(K_j)^\top$ 是一个 $1 \times N$ 的向量,而 $b = \bar{q}_i \bar{k}^\top + \gamma(Q_i) \bar{k}^\top$ 是一个 $N \times 1$ 的向量。我们不需要计算 $b$,因为给 $S$ 的整行加上一个共同的偏置不会影响 softmax 后的结果。因此,我们可以通过以下两个阶段用 INT4 加速 $Q_i K_j^\top$ 的计算:
1. 预处理阶段: 根据公式 (2) 平滑 Q、K,然后进行量化 $(\delta_{Q_i}, \hat{Q}_i) = \psi_Q(\gamma(Q_i)), (\delta_{K_j}, \hat{K}_j) = \psi_K(\gamma(K_j))$,并计算 $\Delta S_{ij} = \bar{q}_i \gamma(K_j)^\top$。平滑、量化和用于计算 $\Delta S$ 的通用矩阵向量乘法(GEMV)可以融合成一个单独的核函数,该核函数只需从片外读取一次 Q 和 K。
2. Attention 计算阶段: 执行低精度的通用矩阵乘法(GEMM),进行反量化,并加回向量 $\Delta S$:$S_{ij} = \psi^{-1}_{\delta_Q \delta_K}(\hat{Q}_i \hat{K}_j^\top) + \Delta S_{ij}$。这些操作都在芯片上完成,反量化和向量加法与昂贵的矩阵乘法 mma 操作相比,只增加了微不足道的开销。重要的是,我们量化的是 $\gamma(Q_i), \gamma(K_j)$ 而不是 $Q_i, K_j$。由于平滑后的矩阵在数值上小得多且包含更少的异常值,量化精度可以得到显著提升。关于平滑好处的理论分析见附录 A.5。
与现有技术的比较。传统的改善激活-权重矩阵乘法(MM)的技术,如逐通道量化或 SmoothQuant【索引66,Smoothquant: Accurate and efficient post-training quantization for large language models,2023,ICML】,在这里不适用于 Attention 中的 query-key MM。逐通道量化不能应用于 Q、K,因为量化必须沿着 $QK^\top$ 的外轴(词元维度)进行。另一方面,Q 和 K 都存在显著的异常值,因此使用 SmoothQuant 在它们之间权衡量化精度无法有效工作,如第4节所示。我们利用了 Attention 中独特的词元相似性模式,为 Q 和 K 推导了一种专门的量化方法。之前的工作 SageAttention 只平滑 K,因此其精度不如我们的方法。
实验结果验证。附录 A.9 中的图20展示了一个来自 CogvideoX 的例子,比较了平滑 Q 前后 $\hat{Q}$ 的分布。我们可以发现,通过平滑 Q,INT4 的数值范围得到了更均匀和充分的利用。表5展示了在 Llama3.1 和 CogvideoX (2b) 上,有无平滑 Q+K 的不同量化方法的端到端指标。结果表明,平滑 Q+K 带来了显著的精度提升。此外,表4和表17显示,方法的有效性排序为:平滑 Q+K > 平滑 Q > 平滑 K > 其他基线方法。
3.2. INT4 每线程量化
量化粒度的权衡。与平滑方法正交,我们可以通过细化量化粒度来减轻异常值问题,使得受异常值影响的元素数量变少。尽管每词元(per-token)量化提供了非常精细的粒度,但它在反量化过程中会产生巨大的开销。具体来说,在每词元量化中,每个 GPU 线程必须处理多个量化尺度,导致量化尺度向量 $\delta_Q$ 和 $\delta_K$ 的点积延迟很高。SageAttention 使用每块(per-block)量化,其中每个块 $Q_i$($b_q$ 个词元)和 $K_i$($b_k$ 个词元)共享一个量化尺度。这种量化策略能够达到接近每词元量化的精度性能,并避免了高昂的反量化开销。然而,将 Q 和 K 量化到 INT4 需要更精细的量化粒度。为了解决这个问题,我们提出了每线程(per-thread)量化,这是一种比每块量化器更精确、更细粒度的方法,并且同样没有 $\delta_Q$ 和 $\delta_K$ 之间向量点积的额外开销。
每线程量化的具体实现。具体来说,SageAttention 中的每个 Q 块,即 $Q_i$,会被分割成 $c_w$ 个段,并由一个 GPU 流处理器(SM)中的 $c_w$ 个 GPU warp 来处理。我们将 $Q_i$ 的每个段称为 $Q_w$,而 $K_w = K_j$,因为 $K_j$ 在 warp 之间是共享的。然后,每个包含32个线程的 warp 使用 mma.m16n8k64 PTX 指令(NVIDIA)来计算 $Q_w K^\top$。根据该指令的布局要求,我们发现 $Q_w[8k + i]$ 可以共享一个量化尺度,而 $K_j[8k + 2i]$ 与 $K_j[8k + 2i + 1]$ 可以共享一个量化尺度。这种量化方法更加细粒度,且没有额外开销。这是因为它根据 MMA 指令布局将不同的 GPU 线程分配到不同的量化组,每个线程只使用单个量化尺度值进行反量化。我们在图4中展示了一个每线程量化的例子。详细的公式见公式8和图18(更多细节请参见附录 A.6)。
实验结果验证。如表6和表15所示,我们使用 CogvideoX 所有层的真实 Q、K、V 矩阵,比较了每词元、每线程、每块和每张量粒度下 INT4 量化的平均和最差精度。结果表明,每线程量化的精度非常接近每词元量化,并且显著优于其他粒度。此外,表19显示,每线程量化几乎没有引入速度下降,而每词元量化由于硬件效率降低,会带来明显的开销。
3.3. P$\tilde{V}$ 的 FP8 量化
P̃ 矩阵的分布特性与量化挑战。现在我们转向矩阵乘法 $P\tilde{V}$,其中根据公式(1),$\tilde{P}_{ij} = \exp(S_{ij} - m_{ij})$ 是未归一化的量。$\tilde{P}$ 的分布是独特的,不同于其他激活值。首先,我们注意到 $S_{ij} - m_{ij} \le 0$,因此 $P_{ij} \in [0, 1]$($\le$ 和 $\in$ 逐元素适用)。我们发现 $\tilde{P}$ 通常由许多小元素组成,但它们的总和不可忽略(例如,5000个值在 $10^{-4}$ 左右)。在这种情况下,我们必须精确地表示小元素。INT 量化不适用于此场景,因为它在数值范围内均匀分布量化点。SageAttention【索引61,Sageattention: Accurate 8-bit attention for plug-and-play inference acceleration,2025,ICLR】选择将 $\tilde{P}$ 和 $V$ 保留为 FP16,并通过降低累加器精度来加速矩阵乘法。然而,这种策略只在极少数 GPU 上有效。
提出 FP8(E4M3) 量化方案。我们提出将 $P$ 和 $V$ 量化为具有4个指数位和3个尾数位(E4M3)的 FP8 格式。E4M3 的数值范围是 [-448, +448]。我们使用静态缩放因子对 $P$ 进行量化:$\delta_P = 1/448$,因为原始 $P$ 的元素已经在 [0, 1] 范围内。我们对 $V$ 进行逐通道量化,以处理图2中显示的通道级异常值。表7和表16中的实验结果显示了在 CogvideoX 所有层中,用于 $P, V$ 的不同数据类型的平均和最差精度。
Algorithm 1 SageAttention2的实现。
FP8(E4M3) 的优势。累加器始终是32位的。我们可以看到 E4M3 的精度非常接近 FP16,并且优于 E5M2 和 INT8。大多数现代 GPU 的张量核都支持 FP8 矩阵乘法运算,其速度是使用 FP16 的两倍。
3.4. 针对 FP22 累加器的 FP32 MMA 缓冲区
发现 FP22 累加器问题。虽然上述对 $P\tilde{V}$ 的 FP8 量化在理论模拟中是准确的,但我们观察到实际的 CUDA 实现会遭受持续的精度下降。经过问题排查,我们发现 Ada 和 Hopper 架构上 mma(f32f8f8f32) 指令的累加器实际上是 FP22,具体为1个符号位、8个指数位和13个尾数位。具体来说,对于指令 $C = AB + D$,其中 A、B 是 FP8 矩阵,C、D 是 FP32 矩阵,我们将 A、B 初始化为零并改变 D 来测试累加器的数据类型。当 D 用1个符号位、8个指数位和13个尾数位初始化时,C 的值与 D 完全匹配。然而,当 D 用超过13个尾数位初始化时,C 的值等于 D 的最低有效10个尾数位被清零(即截断)后的值。因此,与使用 FP32 累加器相比,将 $PV$ 量化为 FP8 的矩阵乘法会产生一定程度的精度损失。
提出两级累加策略。为了减轻这种精度损失,我们采用了一种两级累加策略,即使用一个 FP32 缓冲区来累加 FP22 格式的 $P_{ij}V_j$ 的值。具体来说,我们将公式 (1) 重写为 $R_{ij} = P_{ij}V_j, O_{ij} = \text{diag}(\exp(m_{i,j-1} - m_{ij}))O_{i,j-1} + R_{ij}$。这里,两组累加器 $R_{ij}$ 和 $O_{ij}$ 维护在寄存器中。$R_{ij}$ 使用 mma(f32f8f8f32) 指令计算,提供22个有效位,这已经足够了,因为我们只对少量 $b_k$ 个词元(例如,$b_k = 64$)进行累加。然后,$R_{ij}$ 以高精度的 FP32 累加到 $O_{ij}$ 中。
相关工作与贡献。两级累加策略也已在 CUTLASS【索引39,CUTLASS: CUDA Templates for Linear Algebra Subroutines and Solvers,2023,GitHub】和 DeepGemm【索引10,Deepseek-v3 technical report,2024,arXiv】中实现,用于计算线性层中的权重-激活乘积。据我们所知,我们是第一个发现并研究 FP22 累加器效应,并为 Attention 实现两级累加的工作。
可选的平滑 V 技术。我们还找到了另一种方法来减轻由于 FP22 累加器导致的精度损失,特别是在 V 具有通道级偏差时:$\overrightarrow{V_m} = \text{mean}(V, \text{axis}=0)$, $V = V - \overrightarrow{V_m}$。此外,为保持 Attention 计算的正确性,只需将 $\overrightarrow{V_m}$ 添加到 O 的最终计算中:$O = O + \overrightarrow{V_m}$。这是因为 P 矩阵的每行之和等于1,所以 $P\overrightarrow{V_m} = \overrightarrow{V_m}$。
备注。关于平滑 V 的详细信息,请参见附录 A.3。该技术是可选的,并且未在我们主要实验中使用,因为它仅在 V 表现出通道级偏差时提供显著好处,而某些模型(如 Llama3.1)中不存在这种情况(见图2)。
A4 实验环境
- 模型: 实验覆盖了文本、图像和视频生成领域的十个代表性模型:
- 文本生成: Llama2 (7B), Llama3.1 (8B), GLM4 (9B)
- 文生视频: CogvideoX (2B), CogvideoX (1.5-5B), HunyuanVideo, Mochi
- 文生图像: Flux (schnell), Stable-Diffusion3.5 (turbo)
- 图像分类: TIMM
- 数据集与评估指标:
- 文本: WikiText (Perplexity), LAMBADA (Accuracy), MMLU (Accuracy), Longbench (Longbench score)
- 视频: open-sora prompt sets (CLIPSIM, CLIP-T, VQA-a, VQA-t, FScore)
- 图像: MJHQ-30K (FID, sFID, Clipscore, ImageReward)
- 分类: ImageNet, ImageNet-Sketch, ImageNet-Rendition (Accuracy)
- 硬件配置:
- GPU: RTX4090, L20, L40, H100, H20
- 软件配置:
- 实现: 使用 CUDA 实现
- SageAttention2 变体:
- SageAttn2-4b: Q, K 使用 INT4 每线程量化;P, V 使用 FP8 每块和每通道量化。
- SageAttn2-8b: Q, K 使用 INT8 每线程量化;P, V 使用 FP8 每块和每通道量化。此版本为不支持原生INT4张量核的NVIDIA Hopper GPU适配。
- 对比基线:
- SmoothAttn: 采用 SmoothQuant 方法对 Q, K 进行平滑处理。
- HadmdAttn: 在 INT4 量化前对 Q, K 应用随机哈达玛变换。
- SageAttention: 使用 K 平滑、INT8 每块量化 Q, K,以及 FP16 计算 P, V。
- FlashAttn3(fp8): FlashAttention3 的 FP8 版本,仅在 Hopper GPU 上运行。
A4 实验结果
核心结论: SageAttention2 的速度比 FlashAttention2 和 xformers 快约 3 倍和 4.5 倍。此外,SageAttention2 在 Hopper GPU 上的速度与 FlashAttn3(fp8) 相当,但精度远高于后者。SageAttention2 能够在语言、图像和视频生成模型上保持端到端指标。
4.2. 内核速度与精度
- 内核速度:
- 实验内容: 在 headdim=64 和 headdim=128,以及有无 Causal Mask 的条件下,比较 SageAttention2 与基线的速度。
- 实验结果: 在 RTX4090 上,SageAttn2-4b 和 SageAttn2-8b 的速度分别比 FlashAttention2 快约 3 倍和 2.7 倍,比 xformers 快约 4.5 倍和 4 倍 (图 5)。在 RTX4090, L20, H20, H100 GPU 上的更多速度测试结果见附录 A.2 中的图 10-16。
- 内核精度:
- 实验内容: 在 CogvideoX 的所有层上,比较不同方法使用 INT4 (Q, K) 和 FP8 (P, V) 时的平均精度。
- 实验结果: SageAttn2-4b 的精度优于其他所有基线方法 (表 4, 17)。
4.3. 端到端性能
- 指标损失:
- 实验内容: 在多种模型和任务上评估 SageAttention2 和基线的端到端指标。
- 实验结果: SageAttn2-4b 在所有基线中表现最佳,并在所有模型上基本保持了端到端精度。SageAttn2-8b 在各种模型上几乎没有指标损失 (表 2)。
- 可视化效果:
- 实验内容: 对比 HunyuanVideo, Mochi 和 CogvideoX 等模型的生成结果。
- 实验结果: SageAttn2-8b 与全精度 attention 相比没有引入任何可见差异。SageAttn2-4b 存在微小差异,但远优于基线方法 (图 6, 7, 8, 9)。
- 端到端加速:
- 实验内容: 对比长序列模型使用 SageAttention2 前后的生成延迟。
- 实验结果: 观察到显著的加速效果。例如,在 CogvideoX (1.5-5B) 模型上,SageAttn2-8b 实现了 1.8 倍的加速,且无任何指标损失 (表 8)。
4.4. 消融研究
- 实验内容: 评估所提技术的开销。
- 实验结果: 相较于基础的 Attention 内核,每线程量化、平滑 Q 和两级累加的开销分别为 0.35%、3.7% 和 0%,开销极低 (表 18)。平滑 V 的增益实验见附录 A.4。
A5 结论
本文介绍了 SageAttention2,一个高效且精确的量化注意力机制。
1. 量化策略: 我们提出将 Q, K 矩阵以线程级粒度量化,并将 P, V 矩阵量化为 FP8。
2. 精度增强技术:
- 提出了一种平滑 Q 矩阵的方法,以提升 $QK^\top$ 的精度。
- 提出了一种两级累加策略,以提升 FP8 $PV$ 计算的精度。
3. 性能表现: SageAttention2 的速度比 FlashAttention2 和 xformers 快约 3 倍和 4.5 倍。在 Hopper GPU 上,其速度与 FlashAttention3(fp8) 相当,但精度显著更高。
4. 实验验证: 大量实验证实,我们的方法在语言、图像和视频生成模型上均能保持端到端的性能指标。
A6 附录
A.1. 可视化对比示例
本节提供了更多来自 CogvideoX (2B) 模型的可视化对比结果。
A.2. 额外的内核速度比较
更多速度对比。图10、11、12、13、14、15和16比较了SageAttention2与基线在RTX4090、L20、H100和H20 GPU上,使用headdim=64和headdim=128,以及有无因果掩码(Causal Mask)【索引50,Attention is all you need,2017,NeurIPS】时的速度。
表9总结了不同注意力方法在各种现代GPU上相对于基线的性能增益。
A.3. Smoothing V
平滑V的原理。如图17所示,该策略能提高FP22对$PV$中值的表示精度的原因如下:$P_e$的每一行值范围在0到1之间,而在某些模型中,$V$的每一列始终具有通道级偏差,这些偏差要么是纯正的,要么是纯负的,例如在CogvideoX中范围在8到9之间。因此,$P_V e$的值可能相当大。然而,浮点数的表示范围不是均匀的——它在零附近更密集。因此,通过从$V$中减去沿通道维度的均值$\overrightarrow{V_m}$,$PV$的值会更接近零,从而获得更高的表示精度(见图17的可视化演示)。此外,为了保持注意力计算的正确性,只需将$\overrightarrow{V_m}$加到$O$的最终计算中:$O = O + \overrightarrow{V_m}$。这是因为$P_e$矩阵的每行之和等于1,所以$P_e\overrightarrow{V_m} = \overrightarrow{V_m}$。换句话说,该方法将$V$分解为两部分:$\overrightarrow{V_m}$和$V$。对于$V$,它将每列的值中心化到零附近,这导致量化后的$P$矩阵的一行与量化后的$V$矩阵的一列之间的点积结果更接近零,使得FP22的表示更精确。同时,$\overrightarrow{V_m}$保留为FP16,并在最后加到$O$上,不会对$\overrightarrow{V_m}$部分造成精度损失。
A.4. Smoothing V 的实验
实验结果。表10显示了在从CogvideoX采样的真实张量上,有无平滑V时的注意力精度。结果表明,当将Q, K量化为INT4并将P, V量化为FP8时,平滑V可以提高SageAttention2的精度。我们发现平滑V对扩散模型通常是有效的【索引80,Dpm-solver-v3: Improved diffusion ode solver with empirical model statistics,2023,NeurIPS】等。
A.5. 平滑的理论分析
平滑的理论基础。本节我们从理论角度分析平滑的好处。令$X \in \mathbb{R}^{N \times d}$为$N$个维度为$d$的激活词元。遵循【索引12,QLoRA: Efficient finetuning of quantized LLMs,2023,NeurIPS】,我们假设一个激活词元服从高斯分布$N(\mu, \Sigma^2)$,其中$\mu = (\mu_1, \mu_2, \dots, \mu_d)$,$\Sigma^2$是一个对角矩阵,$\Sigma^2 = \text{diag}(\sigma_1^2, \sigma_2^2, \dots, \sigma_d^2)$。此外,我们假设不同的词元$X_i$是从同一分布中独立同分布采样的。假设一个量化组内的绝对最大值为$M$,位宽为$b$,则有$2^b$个量化级别。在四舍五入到最近值的策略下,期望量化误差为$\frac{1}{2}\frac{2M}{2^b}$,这与量化组中的最大绝对值成正比。因此,较小的绝对最大值会导致较小的量化误差。
平滑后的分布变化。经过平滑后,我们有:
$$Y_{ij} = X_{ij} - \frac{1}{N} \sum_{k=1}^{N} X_{kj}$$
$Y_{ij}$也服从高斯分布。$Y_{ij}$的均值和方差可以计算如下:
$$ E[Y_i] = E[X_i] - \frac{1}{N} \sum_{k=1}^N E[X_k] = \mu_i - \frac{1}{N} \sum_{k=1}^N \mu_k = 0 $$
$$ Var[Y_i] = Var\left[\frac{N-1}{N} X_i\right] + \sum_{k=i, k \neq i}^N Var\left[-\frac{1}{N} X_k\right] = \frac{(N-1)^2}{N^2} \sigma_i^2 + (N-1) \frac{1}{N^2} \sigma_k^2 = \frac{(N-1)}{N} \sigma_k^2 $$
所以与$X_{ij}$相比,$Y_{ij}$的均值和方差更小。根据高斯分布的性质,我们有:
$P(|Y_{ij}| < \epsilon) > P(|X_{ij}| > \epsilon), \forall \epsilon > 0$
因此,平滑后$X_{ij}$的分布更集中于0。然后我们知道:
$P(\mathrm{obmax}(Y) < c) = \prod_{i=1}^d P(Y_i < c) \ge \prod_{i=1}^d P(X_i < c) = P(\mathrm{obmax}(X) < c)$
这使得词元中绝对最大值的分布更集中于0,从而导致更小的量化误差。
A.6. 每线程量化公式
Tensor Core MMA 指令介绍。为了进一步阐明每线程量化,我们首先介绍 Tensor Core 的 INT4 MMA 指令,然后给出每线程量化的公式。Tensor Core 是 NVIDIA Volta 架构中首次引入的专用单元,专为高效的矩阵乘法累加(MMA)操作而设计。Tensor Core 以 warp(32个连续线程)为单位计算小的 MMA 操作块,具体为 $D = A \cdot B + C$。warp 中的每个线程持有输入矩阵的片段,并将获得输出矩阵的片段作为计算结果。INT4 mma.m16n8k64 Tensor Core 操作计算一个 16x64 的 INT4 矩阵 A 和一个 64x8 的 INT4 矩阵 B 的乘积,两者都存储在寄存器中。它将结果累加到一个 16x8 的 INT32 矩阵 C 中,该矩阵也存储在寄存器中,并返回最终的乘积矩阵 D,其形状(16x8)、数据类型(INT32)和存储位置(寄存器)均相同。每个线程只持有输入和输出数据的1/32。图18摘自 PTX 文档(NVIDIA),显示了 warp 中32个线程之间矩阵 C 和 D 的内存布局。每个线程只持有128个结果元素中的4个。
每线程量化的设计原理与优势。通过确保每个线程持有的结果共享一个共同的反量化尺度(属于同一个量化组),我们可以避免与每词元量化相关的开销。利用这一观察,我们设计了如公式8所示的每线程量化,其中$c_w$是GPU Warp的数量,$b_q$和$b_k$是Q、K的块大小,$n$是Q、K的词元索引。对于典型的块大小$b_q=128, b_k=64$和warp数$c_w=4$(如FlashAttention2中所用),每个warp处理一个32个查询词元和64个键词元的瓦片。查询词元$i, 8+i, 16+i, 24+i$ ($i=0, 1, \dots, 7$)可以组成一个量化组,键词元$j, 1+j, 8+j, 9+j, \dots, 56+j, 57+j$ ($j=0, 1, 2, 3$)可以组成一个量化组,如图4所示。这种设计与图18所示的张量核输出矩阵D的内存布局对齐,确保每个线程在反量化时只需要一个Q尺度和一个K尺度。因此,该方法在一个128x64的块中为Q创建了32个量化组(4个warp,每个8个),为K创建了4个量化组,分别比每块量化为查询词元和键词元提供了32倍和4倍的更细粒度。表6和表15显示了使用每线程量化带来的精度提升。每线程量化实现了与每词元量化非常接近的精度,而没有引入任何内核速度下降(见表18和19)。
$$i_{\delta q} = \lfloor (n*8*c_w/b_q) \rfloor$$
$$q_l[i_{\delta q}] = \left\{ 8 \times (n\%8) + \lfloor (n*\frac{b_q}{c_w}) \rfloor \frac{b_q}{c_w} \right\}, n \in [0,N]$$
$$\delta_Q[i_{\delta q}] = \frac{\max(|Q[q_l[i_{\delta q}]]|)}{7}$$
$$\hat{Q}[q_l[i_{\delta q}]] = \left\lceil \frac{Q[q_l[i_{\delta q}]]}{\delta_Q[i_{\delta q}]} \right\rceil$$
$$i_{\delta k} = \lfloor (n*4/b_k) \rfloor$$
$$k_n[i_{\delta k}] = \left\{ 8 \times (n\%8) + \lfloor n/b_k * b_k \rfloor \cup 8 \times (n\%8) + 1 + \lfloor n/b_k * b_k \rfloor \right\}, n \in [0,N]$$
$$\delta_K[i_{\delta k}] = \frac{\max(|K[k_n[i_{\delta k}]]|)}{7}$$
$$\hat{K}[k_n[i_{\delta k}]] = \left\lceil \frac{K[k_n[i_{\delta k}]]}{\delta_K[i_{\delta k}]} \right\rceil$$
A.7. 实验中的数据集和指标
数据集。文生文模型在四个零样本任务上进行评估:WikiText【索引36,Pointer sentinel mixture models,2022,ICLR】用于评估模型的预测置信度,LAMBADA【索引42,The lambada dataset: Word prediction requiring a broad discourse context,2016,ACL】评估上下文理解能力,MMLU【索引18,Measuring massive multitask language understanding,2021,ICLR】用于衡量跨多个学科的知识,以及Longbench【索引2,LongBench: A bilingual, multitask benchmark for long context understanding,2024,ACL】用于全面评估长上下文理解能力。文生视频模型使用open-sora【索引82,Open-sora: Democratizing efficient video production for all,2024,arXiv】提示词集进行评估。文生图模型在MJHQ-30K【索引29,Playground v2.5: Three insights towards enhancing aesthetic quality in text-to-image generation,2024,arXiv】上进行评估。TIMM在三个图像数据集上进行评估:ImageNet【索引11,Imagenet: A large-scale hierarchical image database,2009,CVPR】、ImageNet-Sketch (Sketch)【索引52,Learning robust global representations by penalizing local predictive power,2019,NeurIPS】和ImageNet-Rendition (ImageNetr)【索引17,The many faces of robustness: A critical analysis of out-of-distribution generalization,2021,ICCV】。
端到端指标。对于文生文模型,我们使用WikiText的困惑度(ppl.)【索引22,Perplexity—a measure of the difficulty of speech recognition tasks,1977,JASA】、LAMBADA和MMLU的准确率(Acc.),以及Longbench得分【索引2,LongBench: A bilingual, multitask benchmark for long context understanding,2024,ACL】。对于文生视频模型,我们评估生成视频的质量,使用五个指标:CLIPSIM和CLIP-Temp (CLIP-T)【索引34,Evalcrafter: Benchmarking and evaluating large video generation models,2024,CVPR】衡量文本-视频对齐度;(VQA-a)和(VQA-t)分别评估视频的美学和技术质量;以及Flow-score (FScore)用于时间一致性【索引55,Exploring video quality assessment on user generated contents from aesthetic and technical perspectives,2023,ICCV】。对于文生图模型,生成的图像与MJHQ-30K数据集中的图像在三个方面进行比较:FID【索引20,Gans trained by a two time-scale update rule converge to a local nash equilibrium,2017,NeurIPS】和sFID【索引43,Improved techniques for training gans,2016,NeurIPS】用于保真度评估,Clipscore (CLIP)【索引19,Clipscore: A reference-free evaluation metric for image captioning,2021,EMNLP】用于文本-图像对齐,以及ImageReward (IR)【索引67,Imagereward: Learning and evaluating human preferences for text-to-image generation,2023,NeurIPS】用于人类偏好。对于TIMM,我们使用分类准确率。
精度指标。我们使用三个指标来评估量化注意力输出$O'$与全精度注意力输出$O$的准确性:首先,我们将$O'$和$O$展平成形状为$1 \times n$的向量。然后,余弦相似度:$\text{CosSim} = \frac{\sum OO'}{\sqrt{\sum O^2}\sqrt{\sum O'^2}}$,相对L1距离:$L1 = \frac{\sum |O - O'|}{\sum |O|}$,均方根误差。
A.8. 内核基准测试设置
基准测试配置。我们以批大小为4和32个注意力头,在多种序列长度下对内核速度进行基准测试。基准测试使用64和128的头维度,并分别在有和没有因果掩码的情况下进行【索引50,Attention is all you need,2017,NeurIPS】。为了生成用于基准测试的输入张量,我们遵循了先前工作(如FlashAttention【索引9,Flashattention: Fast and memory-efficient exact attention with io-awareness,2022,NeurIPS】)中采用的标准做法。对于浮点数据类型,输入从均值为0、标准差为1的高斯分布中抽取;对于整数数据类型,输入在其表示范围内均匀采样:INT8为[-128, 127],INT4为[-8, 7]。
A.9. 额外的实验和分析
补充结果。表11、12和13显示了SageAttention2和其他基线在Llama2 (7B)、CogvideoX (2B)和TIMM上的结果。
超长上下文结果。我们进一步使用Llama-3-262k (8B)在InfiniBench【索引75,Bench: Extending long context evaluation beyond 100K tokens,2024,ACL】和Needle-in-a-Haystack (NIAH)【索引23,Llmtest needle in a haystack - pressure testing llms,2023,GitHub】上进行了超长上下文实验,序列长度高达262k词元,在H100 GPU上进行。由于Hopper GPU缺乏原生INT4张量核支持,我们使用SageAttention2-8b进行此评估。我们将其与FlashAttention3(fp8)进行比较,确保两种方法在相同的位宽下运行。结果如表14和图19所示。SageAttention2即使在超长上下文下也能保持模型性能,而FlashAttention3(fp8)则遭受端到端精度下降。
音频任务结果。我们评估了Qwen2-Audio (7b)【索引7,Qwen2-audio technical report,2024,arXiv】,一个语音转文本模型,在Librispeech【索引40,Librispeech: An asr corpus based on public domain audio books,2015,ICASSP】测试集上进行ASR任务,并使用WER(词错误率)指标衡量其性能。如表20所示,SageAttention2持续优于基线,突显了其在音频相关模型和基准测试中的有效性。
补充图表分析。
- Q平滑效果: 图20展示了Q在平滑前后的量化值分布,平滑后分布更均匀,有效利用了INT4的表示范围。
- 量化粒度精度对比: 表15展示了不同量化粒度在CogvideoX上的最差精度,每线程量化与每词元量化相当,远优于其他方法。
- PV数据类型精度对比: 表16展示了不同数据类型用于(P, V)时的最差精度,E4M3表现最佳。
- 平滑方法精度对比: 表17展示了不同平滑方法的最差精度,Smooth Q+K(即SageAttn2-4b)效果最好。
- 技术开销: 表18量化了各项技术对性能的影响,开销很小。
- 量化粒度速度对比: 表19显示每线程量化几乎不影响速度,而每词元量化有明显开销。
💬 评论讨论
欢迎在这里分享您的想法和见解!