SageAttention: Accurate 8-Bit Attention for Plug-and-Play Inference Acceleration

作者/机构: Jintao Zhang, Jia Wei, Pengle Zhang, Jun Zhu, Jianfei Chen∗, Dept. of Comp. Sci. & Tech., Institute for AI, BNRist Center, Tsinghua-Bosch Joint ML Center, THBI Lab, Tsinghua University

A1 主要贡献

本文旨在解决Transformer模型中注意力机制的计算效率问题。注意力机制是Transformer的核心,但其计算复杂度为$O(N^2)$,在处理长序列(如8K~128K)时,其计算成本远超其他操作,成为性能瓶颈(如图2所示)。尽管量化是提升神经网络效率的有效方法,但现有工作主要集中在线性层的量化,而注意力机制通常仍以FP16等高精度格式计算,缺乏系统性的注意力机制量化研究。此外,许多量化方法需要额外的训练,对于大规模模型来说成本高昂。虽然最近发布的FlashAttention3提供了FP8版本,但它仅限于Nvidia Hopper架构,且直接使用可能导致性能下降(如表1所示)。

量化注意力机制面临两大挑战:
1. (C1) K矩阵存在显著的通道级离群值,这在量化过程中会导致严重的精度损失。
2. (C2) 简单地将P矩阵和V矩阵量化为INT8,无法在各种场景下稳定保证PV矩阵乘法结果的准确性。

为应对这些挑战,本文提出了SageAttention,一种高效且准确的注意力机制量化方法,旨在实现推理时即插即用,无需重新训练。
主要贡献和创新点如下:
1. 选择INT8作为量化格式:研究发现,在RTX4090和3090等常用GPU上,INT8矩阵乘法的速度是FP16的四倍,是FP8的两倍。同时,对于注意力的Q和K矩阵,INT8量化比FP8更精确(见表2)。
2. K矩阵平滑化技术:针对挑战(C1),提出了一种通过减去所有token的均值来平滑K矩阵的方法。该方法基于K矩阵的通道离群值主要源于一个共享的大偏差(bias)这一观察,能显著提升精度,而时间开销可忽略不计(<0.2%)。
3. 针对PV矩阵乘法的高效精确方案:为解决挑战(C2),本文不将P和V量化为8位,而是将它们保持在FP16格式,并使用低精度的FP16累加器进行矩阵乘法。这一策略在不损失任何精度的情况下,将矩阵乘法的速度提高了一倍。
4. 自适应量化策略:实现了多种具有不同速度-精度权衡的注意力核,并提出一种自适应方法,为模型的每一层选择最快的实现版本,同时保证整体精度。
5. 高性能实现与显著加速:使用Triton在RTX4090和3090 GPU上实现了SageAttention。其内核融合了ROPE与量化过程,并采用了FlashAttention风格的分块(tiling)技术,充分利用了硬件的mma(u8.u8.s32)mma(f16.f16.f16)指令。SageAttention的性能(OPS)比FlashAttention2和xformers分别高出约2.1倍和2.7倍,在RTX4090上达到了340 TOPS,接近理论INT8吞吐量的52%,远超FlashAttention2的165 TOPS。

全面的实验验证了SageAttention在大型语言模型、图像生成和视频生成等多种模型上几乎没有端到端指标损失,可以直接替代原有注意力实现,达到显著的加速效果。

图1:SageAttention在视频生成(CogvideoX on RTX4090)上的一个例子。
图1:SageAttention在视频生成(CogvideoX on RTX4090)上的一个例子。
图2:注意力的延迟。
图2:注意力的延迟。
图3:一个比较示例。
图3:一个比较示例。

A3 背景知识

3.1 FLASHATTENTION

自注意力计算公式。自注意力的计算可以表示为:$S = QK^T/\sqrt{d}$,$P = \sigma(S)$,$O = PV$,其中$\sigma(S){ij} = \exp(S$。然后,为避免(S、P)的内存I/O,它使用在线softmax【【40】Milakov & Gimelshein, 2018, Online normalizer calculation for softmax】来逐步计算O的每个块,即$O_i$。}) / \sum_k \exp(S_{ik})$是softmax操作。矩阵Q、K和V的维度均为$N \times d$,而矩阵S、P的维度为$N \times N$。虽然$d$通常很小(例如64或128),但$N$可能达到数千甚至数百万。因此,$N \times N$的矩阵(S、P)比(Q、K、V)大得多,朴素的实现会因读写(S、P)而产生大量的全局内存I/O。FlashAttention【【6】Dao, 2023, Flashattention-2: Faster attention with better parallelism and work partitioning】提出将Q、K和V从token维度分块为${Q_i}$、${K_i}$、${V_i}$,块大小分别为$b_q$、$b_{kv}$、$b_{kv

在线Softmax计算过程。首先,对于每个块${K_j}$、${V_j}$,它迭代计算以下方程:

$$S_i^j=Q_i K_j^{\top} / \sqrt{d}, \quad(m_i^j, \widetilde{P}_i^j)=\tilde{\sigma}(m_i^{j-1}, S_i^j),$$ $$l_{i}^{j}=\exp(m_{i}^{j-1}-m_{i}^{j})l_{i}^{j-1}+\operatorname{rowsum}(\widetilde{P}_{i}^{j}),\ O_{i}^{j}=\operatorname{diag}\left(\exp(m_{i}^{j-1}-m_{i}^{j})\right)O_{i}^{j-1}+\widetilde{P}_{i}^{j}V_{j}$$

其中$m_j$和$l_j$是$b_q \times 1$的向量,分别初始化为$-\infty$和0。$\tilde{\sigma}()$是一个在线softmax算子:$m_{ji} = \max{m_{j-1,i}, \text{rowmax}(S_{ji})}$,$P^e_{ij} = \exp(S_{ji} - m_{ji})$。

最终输出计算。最后,在所有迭代(即$j = b_{kv}$)之后,可以通过$O_i = \text{diag}(l_{ji})^{-1} O_{ji}$来计算输出$O_i$。

3.2 动态量化

矩阵乘法的量化加速。一个矩阵乘法$C=AB$可以通过量化加速为:

$$(\delta_A, \hat{A}) = \psi(A), \quad (\delta_B, \hat{B}) = \psi(B), \quad \hat{C} = \hat{A}\hat{B}, \quad C = \psi_{\delta_A \delta_B}^{-1}(\hat{C})$$

在这里,$\psi$是一个量化器,它将一个高精度(例如FP32)矩阵A转换为一个低精度格式$\hat{A}$(例如INT8或FP8)和一个缩放因子$\delta_A$。$\psi^{-1}$是一个反量化器,用于转换回高精度。我们应有$\psi^{-1}(\hat{A}) \approx A$。实际的矩阵乘法$\hat{A}\hat{B}$是在低精度下进行的。在现代GPU中,低精度矩阵乘法通常比高精度乘法快几倍。

不同粒度的量化器。许多量化器依赖于数值格式和粒度,例如,多少元素共享一个共同的缩放因子。
* per-tensor动态量化器:首先计算整个张量的最大绝对值作为缩放因子,然后将元素缩放到INT8的可表示范围[-127, +127],最后四舍五入转换为INT8:$\hat{A} = \lceil A/\delta_A \rfloor$,$\delta_A = \max(|A|)/127$。
* per-token量化器:为张量的每个token分配一个缩放因子:$\hat{A}[i, :] = \lceil A[i, :]/\delta_A \rfloor$,$\delta_A[i, :] = \max(|A[i, :]|)/127$。
* per-channel量化器:为张量的每个通道(即沿通道维度)分配一个缩放因子:$A[:, i] = \lceil A[:, i]/\delta_A \rfloor$,$\delta_A = \max(|A[:, i]|)/127$。
* per-block量化器:基于FlashAttention的分块方法,我们可以相应地应用逐块量化。它为每$b=m-n$个token分配一个缩放因子:$\hat{A}[m:n, :] = \lceil A[m:n, :]/\delta_A \rfloor$,$\delta_A = \max(|A[m:n, :]|)/127$。
反量化仅涉及逐元素的缩放:$\psi^{-1}_{\delta}(\hat{A}) = \delta_A \hat{A}$。

A2 方法细节

4. SAGE ATTENTION

本节提出了SageAttention,一种利用8位量化来加速注意力计算同时保持高精度的方法。考虑到大多数网络并非原生使用量化注意力进行训练,SageAttention被设计为即插即用型。与易于量化的线性层不同,量化注意力更为复杂,需要额外处理以确保良好的准确性和速度。

4.1 公式化

量化注意力公式。基于第3.1节和3.2节中对FlashAttention和动态量化的描述,我们将量化注意力公式化如下:
* 量化: $(\delta_Q, \hat{Q}) = \psi_Q(Q/\sqrt{d})$, $(\delta_K, \hat{K}) = \phi_K(K)$, $(\delta_P, \hat{P}) = \psi_P(P)$, $(\delta_V, \hat{V}) = \psi_V(V)$ (4)
* 注意力计算: $S = \psi^{-1}{\delta_Q \delta_K}(\hat{Q}\hat{K}^\top)$, $(m', P) = \tilde{\sigma}(m, S)$, $O = \text{diag}(\exp(m' - m))O + \psi^{-1})$ (5)}(\hat{P}\hat{V
其中$\phi_K$是一个用于获得量化K的变换,我们将在后续章节中讨论。为简化起见,我们省略了所有的上标和下标,但注意力中使用的矩阵仍然是分块的,计算仍然按照第3.1节描述的FlashAttention方式组织。与原始全精度版本相比,如公式4和5所示,SageAttention为Q、K、P、V添加了量化器,并为乘积添加了反量化器,以加速$QK^\top$和$PV$这两个矩阵乘法。在线softmax部分则保留在全精度下进行。

图4:(Q, K, V)数据分布的典型示例。
图4:(Q, K, V)数据分布的典型示例。

表1:不同量化方法的端到端指标比较。
表1:不同量化方法的端到端指标比较。

4.2 K矩阵平滑化

K矩阵的离群值问题与平滑化方法。直接量化Q和K通常会导致很大的误差。特别地,将Q和K量化为INT8会在文本到图像/视频任务中产生完全模糊的图像/视频。如图4.1所示,我们可视化了来自文本到图像模型Unidiffuser【【1】Bao et al., 2023, All are worth words: A vit backbone for diffusion models】和文本到视频模型CogvideoX【【74】Yang et al., 2024, Cogvideox: Text-to-video diffusion models with an expert transformer】的两组典型的Q、K、V。值得注意的是,K表现出明显的通道级离群值。然而,对K不能应用逐通道量化,因为量化只能在矩阵乘法$QK^\top$的外轴(token维度)上进行。此外,先前为线性层提出的平滑技术【【70】Xiao et al., 2023a, Smoothquant: Accurate and efficient post-training quantization for large language models】也不适用,因为Q也受到离群值的严重影响。幸运的是,K的通道离群值具有一种模式:每个token的键实际上是一个由所有token共享的大偏差,加上一个小的token级信号。因此,离群值并非来自token间的巨大变化,而仅仅是这个大的偏差。基于这一观察,我们提出通过一个变换$\gamma$来平滑矩阵K,该变换减去所有token的平均K:

$$\gamma(K)=K-\operatorname{mean}(K)$$

其中$\text{mean}(K) = \frac{1}{N} \sum_{t=1}^N K[t, :]$是平均键,形状为$1 \times d$。需要注意的是,$\sigma(q(K - \text{mean}(K))^\top) = \sigma(qK^\top - q \cdot \text{mean}(K)) = \sigma(qK^\top)$。最终,从全精度K到量化$\hat{K}$的变换可以写为$\phi_K(K) = \psi_K \circ \gamma$,其中$\psi_K$是一个量化器。换句话说,一个全精度的K在最终被量化之前,会减去其均值。

平滑化效果。表1展示了在各种模型上使用和不使用平滑K的不同量化方法的端到端指标。结果表明,平滑K能显著提升准确性。此外,平滑K对注意力的速度开销小于0.2%(见表10)。

表2:在真实模型的所有层中使用不同数据类型的平均准确率。
表2:在真实模型的所有层中使用不同数据类型的平均准确率。

表3:在真实模型的所有层中使用不同数据类型的最差准确率。
表3:在真实模型的所有层中使用不同数据类型的最差准确率。

4.3 Q, K, P, V的量化

Q, K的量化粒度。$\psi_Q(Q)$和$\psi_K(K)$可以设置为per-token、per-block或per-tensor的粒度。这是因为per-channel量化是不可行的,因为$QK^\top$内轴的缩放因子无法用于进行反量化【【70】Xiao et al., 2023a, Smoothquant: Accurate and efficient post-training quantization for large language models】。

Q, K的数据类型选择。我们选择INT8用于$\psi_Q(Q)$和$\psi_K(K)$,原因有二。首先,表2显示了在Llama2 (7B)【【53】Touvron et al., 2023, Llama 2: Open foundation and fine-tuned chat models】和Unidiffuser的所有层中使用不同数据类型(INT8, E4M3, E5M2)对Q, K, P, V进行量化的平均准确性。结果表明,将Q, K量化为INT8比使用E4M3和E5M2具有更高的准确性。其次,在许多常用GPU(如RTX4090和3090)上,使用INT8的矩阵乘法比使用FP8快两倍。

P, V的量化粒度选择。我们建议对$\psi_P(P^e)$使用per-block量化,对$\psi_V(V)$使用per-channel量化,原因有三。(1) 对$P^e$进行per-channel量化和对V进行per-token量化是不可行的,因为反量化需要外轴的缩放因子。(2) $P^e = \exp(S_i - \text{rowmax}(S_i))$,其中$S_i$是Q的一个块与$K^\top$的矩阵乘法结果,$P^e$每行的最大值为1。因此,我们可以为$P^e$的一个块分配一个单一的静态缩放因子$s = 1/127$,其准确性等同于per-token量化。(3) Per-channel量化可以解决V的通道级离群值问题。

P, V的数据类型选择。我们选择INT8用于$\psi_P(P^e)$和$\psi_V(V)$,因为在一些常用GPU上,使用INT8的矩阵乘法比使用FP8快两倍,并且尽管使用INT8的$\psi_P(P^e)$和$\psi_V(V)$的准确性比E4M3和E5M2差,但平均准确性是相似的(见表2)。

准确性评估指标。我们使用三个指标来评估量化注意力输出$O'$相对于全精度注意力输出$O$的准确性:首先,我们将$O'$和$O$展平成形状为$1 \times n$的向量。然后,计算余弦相似度(Cosine Sim)= $\sum OO' / \sqrt{\sum O^2 \sum O'^2}$,相对L1范数(Relative L1)= $\sum |O - O'| / \sum |O|$,以及均方根误差(RMSE)= $\sqrt{(1/n) \sum(O - O')^2}$。

表4:在真实模型的所有层中使用不同累加器的平均准确率。
表5:在真实模型的所有层中使用不同累加器的最差准确率。
表4和表5

4.4 FP16累加器:一个更准确且高效的解决方案

INT8量化P,V的精度问题。上述针对$\psi_P(P^e)$和$\psi_V(V)$的解决方案存在一个问题,即在某些模型层中,使用INT8的准确性非常差。表3显示了在Llama2和Unidiffuser的所有层中,使用不同数据类型对Q, K, P, V进行量化的最差准确性。结果表明,INT8的$\psi_P(P^e)$和$\psi_V(V)$带来了不可接受的误差。

FP16累加器方案。为此,我们提出了一个非常准确且高效的解决方案。具体来说,我们建议在矩阵乘法$PV$中使用FP16作为数据类型,并配备一个FP16累加器。

FP16累加器的优势。这种解决方案的好处是显而易见的。首先,在一些常用GPU(如RTX4090和3090)上,使用FP16累加器的FP16矩阵乘法速度比使用FP32累加器的快2倍。此外,使用FP16累加器比使用FP32累加器可以节省更多的寄存器资源,从而加速计算。其次,表3显示,对P, V使用FP16比使用所有其他8位数据类型要准确得多。更重要的是,使用FP16累加器与使用FP32累加器相比没有精度损失。具体来说,表4和表5展示了在Llama2和Unidiffuser的所有层上使用FP16或FP32累加器的平均和最差准确性,表明使用FP16累加器没有精度损失。

表6:SageAttention的四种核实现。
表6:SageAttention的四种核实现。

4.5 自适应量化

四种注意力核实现。基于第4.3节和4.4节的讨论,我们实现了四种注意力核(见表6),这些核基于两组选择:(1) 对$\psi_Q(Q)$和$\psi_K(K)$使用per-token或per-block量化。(2) 对$\psi_P(P^e)$和$\psi_V(V)$使用INT8量化,或将P, V保留在FP16并使用FP16累加器。

自适应选择策略。SAGEAttn-B对于所有模型都足够准确,并且可以实现2倍的加速(见图6和图7)。然而,SAGEAttn-vB在模型的某些层中也足够准确,并且比SAGEAttn-B稍快(约4%)。因此,我们使用各种输入来测试模型每一层SAGEAttn-vB的余弦相似度。然后,我们将为那些SAGEAttn-vB余弦相似度大于99.8%(SAGEAttn-B的最差相似度)的层选择SAGEAttn-vB,而其他层则使用SAGEAttn-B。

图5:SageAttention的工作流程。
图5:SageAttention的工作流程。

算法 1: SAGEAttn-B的实现

输入: 矩阵 Q(FP16), K(FP16), V (FP16) $\in R^{N \times d}$, 块大小 $b_q, b_{kv}$。
1. 预处理: $K = K - \text{mean}(K)$; // 减去token间的均值
2. 量化: $(\delta_Q, \hat{Q}) = \psi_Q(Q/\sqrt{d}), (\delta_K, \hat{K}) = \psi_K(K)$; // INT8 per-block量化
3. 将$\hat{Q}$划分为$T_m = N/b_q$个块${\hat{Q}i}$,将$\hat{K}, V$划分为$T_n = N/b}$个块${\hat{Ki}$和${V_i}$;
4. for i in [1, $T_m$] do ; // 外循环在SM(流处理器)中并行
5. 将$\hat{Q}_i$和$\delta_Q[i]$加载到SM中;
6. for j in [1, $T_n$] do
7. 将$\hat{K}_j, V_j$和$\delta_K[j]$加载到SM中;
8. $S
} = \text{Matmul}(\hat{Qi, \hat{K}_j^T) \times \delta_Q[i] \times \delta_K[j]$;
9. $m
)$;} = \max(m_{j-1,i}, \text{rowmax}(S_{ji})), P^e_{ji} = \exp(S_{ji} - m_{ji}), l_{ji} = e^{m_{j-1,i}-m_{ji}} + \text{rowsum}(P^e_{ji
10. $O_{ji} = \text{diag}(e^{m_{j-1,i}-m_{ji}})^{-1}O_{j-1,i} + \text{Matmul}(P^e_{ji}.\text{to}(\text{FP16}), V_j, \text{Accum type} = \text{FP16})$;
11. end for
12. $O_i = \text{diag}(l_{T_n})O_{T_n}$;
13. 写回 $O_i$;
14. end for
15. return $O = {O_i}$;

4.6 融合技巧和性能分析

融合技巧。为了减少量化的开销,我们将量化过程与注意力层之前的算子进行融合。例如,我们将量化融合到ROPE(旋转位置嵌入)【【48】Su et al., 2021, Roformer: Enhanced transformer with rotary position embedding】层中。具体来说,在ROPE的结果(A)从共享内存写入全局内存之前,我们执行$\delta_A, \hat{A} = \psi(A)$。随后,将$\delta_A, \hat{A}$写入全局内存。此外,我们还将$QK^\top$的系数$(1/\sqrt{d})$融合到量化过程中,而不是将其留在注意力层。具体来说,我们在量化Q之前,在芯片上将Q乘以$(1/\sqrt{d})$。

性能分析。我们以SAGEAttn-B为例来讨论在实际硬件上的加速效果:
1. 矩阵乘法加速。利用当前主流硬件上的INT8矩阵乘法单元可以实现2-4倍的吞吐量。虽然FP16累加器在大多数计算卡上不提供吞吐量提升,但在边缘加速器(如RTX4090)上,仍可实现比FP32累加器2倍的提升。
2. 量化开销。量化和反量化被认为是当前量化方法的主要开销【【34】Lin et al., 2024, Qserve: W4a8kv4 quantization and system co-design for efficient llm serving】。计算开销无法避免,但通过将Q、K的量化与ROPE融合,我们避免了量化的IO开销。
3. 缓存和寄存器。目前,主流加速器在计算过程中需要将数据存储在缓存(如SharedMemory)中。使用8位数据进行计算可以减少通用缓存的使用,而使用fp16累加器也可以减少累加寄存器的使用。
4. DRAM访问。使用8位数据可以将张量从DRAM传输到计算单元的开销减半。尽管量化引入了额外的FP32缩放因子,但与张量相比,这些缩放因子可以忽略不计。

A4 实验环境

  • 模型:实验涵盖了语言、图像和视频生成领域的多种代表性模型,包括:Llama2 (7B)(文本到文本)、CogvideoX(文本到视频)、Unidiffuser 和 UltraPixel(文本到图像)、TIMM(图像分类)以及Llava1.6(视觉问答)。
  • 数据集

    • Llama2: 在三个零样本任务上评估:WikiText(评估预测置信度)、LAMBADA(评估上下文理解)和MMLU(衡量跨学科知识)。
    • CogvideoX: 使用open-sora提示集进行评估。
    • UltraPixel & Unidiffuser: 在COCO annotations数据集上评估,使用(提示,图像)对。
    • TIMM: 在ImageNet、ImageNet-Sketch和ImageNet-Rendition三个图像数据集上评估。
    • Llava1.6: 在TextVQA、POPE和VQAv2三个数据集上评估。
  • 硬件配置

    • RTX 4090平台: PCIE 5.0,16核 Xeon(R) 6430 CPU,120GB DDR4 RAM。
    • RTX 3090平台: 16核 Xeon(R) 8358P CPU,80GB DDR4 RAM。
  • 软件配置

    • 操作系统: Ubuntu 22.04。
    • 核心库: PyTorch 2.4.0+cu121,Triton-nightly (20240816版),Python 3.11,(gcc, g++) 9。
    • 实现: 注意力核使用OpenAI Triton实现。
  • 注意力核超参数:Q的块大小为128,K和V的块大小为64。Num Warps和Num Stages参数详见附录表12。

A4 实验结果

5.2 注意力核的速度与准确性

  • 速度:在RTX4090和RTX3090上,针对headdim=64和headdim=128,以及有无Causal Mask的配置,对SageAttention与基线方法(Torch, xformers, FlashAttention2)的速度进行了比较。

    • 在RTX4090上(图6和图7),SageAttention的峰值性能达到341 TOPS,平均比FlashAttention2快2倍,比xformers快2.9倍。
    • 在RTX3090上(图8和图9)也表现出类似的加速性能。
  • 准确性:使用服从正态分布的(Q, K, V)数据,对四种SageAttention实现与全精度注意力的数值误差进行了评估(表9)。结果显示,所有实现的误差都非常小。特别是SAGEAttn-T和SAGEAttn-B,它们达到了100%的余弦相似度,且RMSE在$10^{-4}$级别。

图6:SageAttention与基线的速度比较(RTX4090, headdim=64)。
图6:SageAttention与基线的速度比较(RTX4090, headdim=64)。
图7:SageAttention与基线的速度比较(RTX4090, headdim=128)。
图7:SageAttention与基线的速度比较(RTX4090, headdim=128)。
图8:SageAttention与基线的速度比较(RTX3090, headdim=64)。
图8:SageAttention与基线的速度比较(RTX3090, headdim=64)。
图9:SageAttention与基线的速度比较(RTX3090, headdim=128)。
图9:SageAttention与基线的速度比较(RTX3090, headdim=128)。

表9:SageAttention核的准确性。
表9:SageAttention核的准确性。

5.3 端到端性能

  • 实际加速比:在RTX4090上测量了SageAttention与原始注意力在Unidiffuser、UltraPixel、CogvideoX、Llama2和TIMM模型上的实际运行速度。如表7所示,SageAttention在所有模型上均优于原始注意力,平均实现了2.83倍的加速。
  • 指标损失:评估了使用SageAttention与全精度注意力在各种模型上的端到端指标。如表8所示,SageAttention在所有模型上的性能与全精度注意力相当。具体来说,在Llama2、CogvideoX、UltraPixel和Unidiffuser上,平均性能下降仅为0.2%。在TIMM模型上,SageAttention的性能甚至超过了全精度注意力。

表7:SageAttention在RTX4090上的实际加速比。
表7:SageAttention在RTX4090上的实际加速比。

表8:文本、图像和视频生成模型的端到端指标损失。
表8:文本、图像和视频生成模型的端到端指标损失。

5.4 消融实验

  • 平滑K的开销:在真实模型中,平滑K操作对注意力速度的影响极小,性能下降不到0.2%(表10)。
  • 自适应量化的好处:对比了仅使用SAGEAttn-T和采用自适应策略(SageAttention)的性能差异。在CogvideoX和Llama2上,自适应策略在不损失任何指标的情况下,将注意力的速度提升了11.7%(表11)。

表10:平滑K的开销。
表10:平滑K的开销。

表11:自适应量化的好处。
表11:自适应量化的好处。

A5 结论

本文介绍了SageAttention,一种高效且精确的注意力机制INT8量化方法。通过三大核心技术,SageAttention实现了性能与精度的双重提升:
1. K矩阵平滑化:提出了一种平滑矩阵K的方法,以不到0.2%的速度开销显著提升了量化精度。
2. FP16累加器:在(P, V)的矩阵乘法中采用FP16累加器,同时提升了计算速度和准确性。
3. 自适应量化:通过自适应地选择最优的注意力核,在不牺牲准确性的前提下,进一步将OPS提升了12%。

实验结果表明,SageAttention的性能(OPS)分别比FlashAttention2和xformers高出约2.1倍和2.7倍。在包括语言、图像和视频生成在内的多种模型上的广泛测试证实,本方法在保持端到端指标的同时,实现了显著的加速,为Transformer模型的即插即用式推理加速提供了一个有效的解决方案。

A6 附录

A.3 数据集和模型细节

数据集选择细节。对于UltraPixel和Unidiffuser的图像生成任务,我们选择了COCO 2014val数据集的前256个标注作为提示集。同时,我们使用这256个提示对应的256张图像作为基准图像,来计算FID和sFID。对于CogvideoX,由于模型是在长文本上训练的,我们使用了一个open-sora提示集,其中每个提示都包含超过120个单词。

模型选择细节。在TIMM的实验中,我们使用的具体模型是vit_base_patch16_224.augreg2_in21k_ft_in1k

A.5 与其他方法的比较

与任务特定量化方法的正交性。存在一些针对特定任务的量化方法,例如用于LLM的AWQ、用于文本到图像的Q-diffusion以及用于文本到视频的ViDiT-Q。SageAttention与这些方法是正交的,因为它们主要用于量化线性层。其次,AWQ仅用于压缩LLM的参数,在计算上没有加速效果。Q-diffusion在其论文中未报告加速效果,也未在其官方代码库中提供带加速效果的代码。ViDiT-Q也未在其官方代码库中提供带加速效果的代码。

与AWQ的比较。尽管如此,我们还是将SageAttention与这些工作进行了比较。我们比较了Llama2-7B在WikiText上的困惑度以及在预填充(prefilling)阶段的加速比。结果如表13所示。

表13:SageAttention与AWQ (W4A16) 在Llama2上的比较。
表13:SageAttention与AWQ (W4A16) 在Llama2上的比较。

与Q-diffusion的比较。我们在Unidiffuser上将SageAttention与Q-diffusion (W8A8)进行了比较。结果如表14所示。

表14:SageAttention与Q-diffusion (W8A8) 在Unidiffuser上的比较。
表14:SageAttention与Q-diffusion (W8A8) 在Unidiffuser上的比较。

与VIDIT-Q的比较。我们在CogvideoX上将SageAttention与VIDIT-Q进行了比较,结果如表15所示。由于官方代码库未提供加速代码,我们估算了一个理论最大值:线性层占Cogvideo延迟的24%,W8A8为线性层提供最多4倍加速,因此理论最大端到端加速比为 $100 / (100 - 24 \times 3/4)\% = 22\%$。

表15:SageAttention与VIDIT-Q在CogvideoX上的比较。
表15:SageAttention与VIDIT-Q在CogvideoX上的比较。

A.6 一些见解

Llama2量化稳定性分析及启发。表1显示,Llama2的指标在量化后保持稳定。原因在于Llama2-7B的注意力机制中Q、K和V的分布相对均匀。因此,将Q、K和V量化为INT8或FP8并不会显著影响注意力的准确性。这一见解启发我们,更好地控制模型中的离群值激活可以带来更精确的量化结果。我们相信,SageAttention也可以有效地应用于与Transformer相关的各种应用,如MOE系统、线性层量化、RAG系统、训练优化、异构GPU系统和扩散模型等,这些工作在参考文献中均有提及。

B 基于TORCH ATTENTION的实现

基于Torch Attention的实现速度。FlashAttention是当前最先进且最常用的标准注意力实现;另一种常用的注意力实现是Torch attention【【43】PyTorch Contributors, Torch backend documentation】。我们在表16中报告了基于Torch实现的版本速度。

表16:基于Torch Attention的SageAttention。
表16:基于Torch Attention的SageAttention。

B.1 额外的精度比较

Q·K乘积的精度比较。表17展示了使用不同数据类型进行per-token量化的$Q \cdot K$与全精度$Q \cdot K$相比的精度。该实验使用来自Unidiffuser第24层的Q、K进行。结果显示,将Q、K量化为INT8比使用E4M3和E5M2具有更高的精度。

表17:使用不同类型量化的$Q \cdot K$的数值误差。
表17:使用不同类型量化的$Q \cdot K$的数值误差。

K平滑化对精度的影响。表18显示了在各种模型上,使用和不使用平滑K的不同量化方法的精度。结果表明,平滑K能显著提升精度。

表18:使用或不使用平滑K的量化注意力误差。
表18:使用或不使用平滑K的量化注意力误差。

B.2 可视化结果

可视化结果对比。图10展示了使用全精度注意力和SageAttention的UltraPixel生成的高分辨率图像(2560x1536)。可以看出,SageAttention在高质量和高细节图像上与全精度版本相匹配。图11展示了使用不同精度由Open-Sora【【86】Zheng et al., 2024c, Open-sora: Democratizing efficient video production for all】生成的视频(720x1280)。SageAttention生成的视频与全精度版本完全相同。图12、图13和图14展示了在UltraPixel、Unidiffuser和CogvideoX上更多的可视化比较结果。

图10:UltraPixel的一个图像生成示例。
图10:UltraPixel的一个图像生成示例。
图11:Open-Sora的一个视频生成示例。
图11:Open-Sora的一个视频生成示例。
图12:更多UltraPixel的图像生成示例,其中prompt1=“两只狗在浴室的马桶旁站着抬头看”,prompt2=“一辆灰色自行车锁在一些金属门上”,prompt3=“一辆汽车在高速公路上行驶的图像”,prompt4=“一只在马桶盖上看起来很不安的猫”。
图12:更多UltraPixel的图像生成示例,其中prompt1=“两只狗在浴室的马桶旁站着抬头看”,prompt2=“一辆灰色自行车锁在一些金属门上”,prompt3=“一辆汽车在高速公路上行驶的图像”,prompt4=“一只在马桶盖上看起来很不安的猫”。
图13:更多Unidiffuser的图像生成示例,其中prompt1=“喜马拉雅山的美丽景色”,prompt2=“海底的一头大象”,prompt3=“英式乡村花园设计”,prompt4=“科罗拉多州杜兰戈的一列古老的红色电气火车”。
图13:更多Unidiffuser的图像生成示例,其中prompt1=“喜马拉雅山的美丽景色”,prompt2=“海底的一头大象”,prompt3=“英式乡村花园设计”,prompt4=“科罗拉多州杜兰戈的一列古老的红色电气火车”。
图14:更多CogvideoX的图像生成示例。有关提示和完整视频的更多详细信息,请参阅 https://anonymous.4open.science/r/image_video_examples-3E44/README.md 。
图14:更多CogvideoX的图像生成示例。有关提示和完整视频的更多详细信息,请参阅 https://anonymous.4open.science/r/image_video_examples-3E44/README.md 。

B.3 RTX3090上的实际加速比

RTX3090上的实际加速比。我们进一步在RTX3090上测量了SageAttention和原始Attention在Unidiffuser、UltraPixel、CogvideoX、Llama2和TIMM上的实际速度。表19(原文中为Table 7,但根据内容应为新表)显示SageAttention在所有模型上均优于原始注意力。具体来说,与原始Attentions相比,SageAttention平均实现了2.7倍的加速。

表19:RTX3090上的实际加速比比较。
表19:RTX3090上的实际加速比比较。