DuoAttention: Efficient Long-Context LLM Inference with Retrieval and Streaming Heads
DuoAttention: Efficient Long-Context LLM Inference with Retrieval and Streaming Heads
作者/机构: Guangxuan Xiao (MIT), Jiaming Tang (MIT), Jingwei Zuo (Tsinghua University), Junxian Guo (MIT, SJTU), Shang Yang (MIT), Haotian Tang (MIT), Yao Fu (University of Edinburgh), Song Han (MIT, NVIDIA)
A1 主要贡献
部署长上下文大型语言模型(LLM)至关重要,但也带来了巨大的计算和内存挑战。在所有注意力头中缓存所有键(Key)和值(Value)状态(KV缓存)会消耗大量内存。现有的KV缓存剪枝方法要么损害了LLM的长上下文能力,要么效率提升有限。
核心问题与研究目标:
长上下文推理面临的主要问题是,随着序列长度的增加,全注意力机制导致解码延迟呈线性增长,预填充延迟呈二次方增长,而KV缓存导致的内存使用也呈线性增长。例如,对于Llama-3-8B模型,使用FP16 KV缓存处理100万个词元需要至少137GB内存,超出了单个80GB GPU的容量。现有方法,如架构修改(GQA)、线性注意力、近似注意力(H2O, StreamingLLM)和系统级优化(FlashAttention),都存在各自的局限性,无法同时解决内存、计算和长上下文准确性问题。
核心观察与创新点:
本文提出了一个关键观察:LLM中的注意力头可以分为两种不同类型,即检索头(Retrieval Heads)和流式头(Streaming Heads)。
1. 检索头:只占一小部分,对于处理长上下文至关重要,需要在所有词元上进行完全注意力计算。
2. 流式头:占大多数,主要关注最近的词元和“注意力池”(attention sinks),不需要完全注意力,可以使用一个精简的KV缓存。
基于这一观察,本文提出了DuoAttention框架,该框架的核心创新点如下:
1. 差异化KV缓存策略:仅对检索头使用完整的KV缓存,而对流式头使用一个轻量级的、固定长度的KV缓存(仅包含注意力池和最近的词元)。这显著减少了LLM在解码和预填充阶段的内存占用和延迟,同时不损害其长上下文处理能力。
2. 基于优化的检索头识别:提出了一种轻量级的、基于优化的算法,使用合成数据来准确识别检索头。与依赖注意力模式分析的现有方法不同,DuoAttention直接测量因词元丢弃引起的输出偏差,从而实现了更高的压缩率和部署效率。
3. 高效率与兼容性:DuoAttention的设计简洁高效,易于集成。每个Transformer层有两个KV缓存:一个用于检索头的完整缓存和一个用于流式头的恒定大小缓存。该方法与GQA和量化等重要优化技术完全兼容。
主要成果:
- 内存减少:对于多头注意力(MHA)模型最高减少2.55倍,对于分组查询注意力(GQA)模型最高减少1.67倍。
- 解码加速:MHA模型最高加速2.18倍,GQA模型最高加速1.50倍。
- 预填充加速:MHA模型最高加速1.73倍,GQA模型最高加速1.63倍。
- 超长上下文能力:与8位权重和4位KV缓存量化结合后,DuoAttention能使Llama-3-8B模型在单张A100 GPU上处理高达330万的上下文长度,容量比标准全注意力FP16部署增加了6.4倍。
图 1: Llama-2-7B模型中句子“The best fruit is orange. What is the best fruit? Orange.”的注意力图可视化,展示了检索头(如第15层,头12)和流式头(如第10层,头1)的鲜明角色。左侧,检索头捕获上下文相关词元,如“best”、“fruit”和“orange”,这对于处理长上下文信息至关重要,因此需要完整的KV缓存。中间,流式头主要关注初始和最近的词元,不强调过去的上下文相关性。右侧,展示了将注意力限制在注意力池和最近词元对长上下文密钥检索准确率的影响:修改检索头会严重损害性能,而限制流式头的影响则微乎其微。
图 2: DuoAttention概览:(1)在检索头识别阶段,我们为每个注意力头分配一个可训练的门控值α,它混合了全注意力和流式注意力的输出。训练目标是优化这些值,以最小化与全注意力模型输出的偏差,同时施加正则化损失以鼓励较低的门控值。这个训练阶段非常高效,只需要门控值是可训练的——所有其他模型参数都保持冻结——因此可以在一个8 GPU节点上几小时内完成。(2)在部署期间,这些门控值被二值化,以根据阈值τ将头分类为检索头或流式头。门控值高于阈值的检索头使用全注意力,缓存所有词元的KV对。相比之下,流式头只缓存最近词元和注意力池的KV对。
A3 背景知识/关键Observation/设计原则
2.1 检索头与流式头
检索头。在基于Transformer的LLM中,注意力头表现出独特且一致的模式,反映了它们的专门功能【索引10, What does BERT look at? an analysis of BERT’s attention, 2019, ACL】、【索引76, Efficient streaming language models with attention sinks, 2023, arXiv】、【索引75, Retrieval head mechanistically explains long-context factuality, 2024】。图1使用句子“The best fruit is orange. What is the best fruit? Orange.”可视化了Llama-2-7B-32K-Instruct模型中的两种注意力头。左图展示了一个在解码时强调相关词元的注意力头;例如,在解码第二个“best fruit”时会突出第一个“best fruit”,在推断第二个“orange”时会突出第一个“orange”。我们将这些注意力头称为检索头(Retrieval Heads),它们对于上下文处理至关重要,因为它们能捕获上下文相关的词元。压缩检索头的KV缓存会导致关键上下文信息的丢失,因此它们需要在所有词元上进行全注意力计算。
流式头。与检索头相反,图1中间面板所描绘的注意力头主要关注最近的词元和注意力池(attention sinks)【索引76, Efficient streaming language models with attention sinks, 2023, arXiv】,而不会突出显示上下文中早期的相关词元。我们将这些头称为流式头(Streaming Heads)。对流式头的KV缓存进行压缩是可行的,因为丢弃未被关注的中间词元不会显著改变注意力输出。因此,可以通过仅保留注意力池和最近词元的KV状态来优化流式头,而不会影响模型处理长上下文的能力。
词元剪枝对检索头和流式头的影响。图1的右侧面板展示了一个初步的密钥检索实验,结果表明,当检索头的KV缓存中的中间词元被剪枝(即替换为流式注意力)时,模型的性能会显著下降。相比之下,为流式头移除中间词元对密钥检索的准确性没有显著影响。这一观察表明,我们可以提高计算效率而不牺牲模型的长上下文能力:通过为流式头丢弃中间词元,同时为检索头保留全注意力,我们将流式头的内存需求降低到$O(1)$,从而提高了处理长上下文的效率。
A2 方法细节
2.2 基于优化的检索头识别
检索头的定义。第2.1节定性地定义了检索头和流式头,但为了精确识别,我们需要一个具体和量化的定义。在本文中,我们将“检索头”定义为那些:当被限制只关注最近词元和注意力池时,会显著改变模型输出的注意力头。我们使用此标准来区分检索头和流式头。请注意,这个定义与现有工作【索引15, Model tells you what to discard: Adaptive KV cache compression for LLMs, 2024, ICLR】、【索引75, Retrieval head mechanistically explains long-context factuality, 2024】、【索引69, Razorattention: Efficient kv cache compression through retrieval heads, 2024a, https: //http://arxiv.org/abs/2407.15891。
基于优化的识别。我们采用一种基于优化的方法来识别检索头,灵感来自先前在CNN滤波器剪枝方面的工作【索引53, Learning efficient convolutional networks through network slimming, 2017, ICCV】,如图2所示。首先,我们为LLM中的每个键值(KV)头分配一个门控值$α_{i,j}$。这个值直观地表示第i层中第j个KV头在处理长上下文信息方面的重要性。请注意,在使用GQA的模型中,一个KV头可以与多个注意力头相关联,我们的方法考虑了对整组注意力头进行KV缓存压缩。
优化过程。我们基于优化的识别方法直接评估了对每个KV头使用仅包含注意力池和最近词元的KV缓存进行压缩的影响。我们首先将每个头的门控值$α_{i,j} \in [0, 1]$初始化为1,假设所有头最初都作为检索头。然后优化这些门控值,而LLM的参数保持固定,将可训练参数的数量限制在$N \times H$,从而防止对模型原有能力的影响。
前向传播计算。在前向传播过程中,我们结合了每个KV头的全注意力和流式注意力(仅关注注意力池和最近词元)的输出,使用门控值作为混合权重:
$$ \text{attn}_{i,j} = \alpha_{i,j} \cdot \text{full\_attn} + (1 - \alpha_{i,j}) \cdot \text{streaming\_attn} $$其中,注意力的计算定义如下:
$$\begin{aligned} \begin{aligned} \text{full\_attn} &= \text{softmax}(\boldsymbol{QK}^{T} \odot \boldsymbol{M}_{\text{causal}})\boldsymbol{V}, \\ \text{streaming\_attn} &= \text{softmax}(\boldsymbol{QK}^{T} \odot \boldsymbol{M}_{\text{streaming}})\boldsymbol{V}. \end{aligned} \end{aligned}$$其中,$M_{causal}$是因果注意力掩码(一个下三角矩阵),而$M_{streaming}$代表一个Λ形掩码【索引22, LM-Infinite: Simple on-the-fly length generalization for large language models, 2023】、【索引76, Efficient streaming language models with attention sinks, 2023, arXiv】,它只关注最近和初始的词元。
用于识别检索头的合成数据集。然而,仅仅依赖自然语言建模目标不足以识别检索头,因为自然文本中需要跨越长距离进行推断的监督信号是稀疏的,而且大多数词元可以使用局部上下文来推断。为了解决这个问题,我们设计了一个专门旨在增强模型长上下文检索能力的合成数据集,使我们能够有效地识别哪些KV头可以在不损害模型性能的情况下被压缩。如图3所示,我们通过在一个非常长的上下文中十个随机位置嵌入十个随机生成的s个词元的密钥序列(实验中s=32)来创建一个密钥检索数据集。然后,模型被要求在上下文末尾回忆这十个序列。
图 3: 用于识别检索头的合成数据集示例。我们在长文本中嵌入十个32词的密钥,并要求模型回忆这些密钥。蒸馏损失仅在这些密钥上计算。
图 4: 四个LLM的优化后门控值。Llama-2-7B使用MHA,每层32个头,而Mistral和Llama-3模型使用GQA,每层8个头。检索头得分较高。与GQA模型相比,MHA模型中检索头的比例较低。
训练与损失函数。我们优化蒸馏损失,即全注意力模型的最后隐藏状态($H_{full}$)与使用DuoAttention模型的隐藏状态($H_{mixed}$)之间的L2差异,仅关注整个输入(共T个词元)中最后的l个密钥词元:
$$\mathcal{L}_{\text {distill }}=\frac{1}{N} \sum_{i=1}^{N} \sum_{j=T-l+1}^{T}\left(\boldsymbol{H}_{\text {full }}^{(i)}[j]-\boldsymbol{H}_{\text {mixed }}^{(i)}[j]\right)^{2}$$正则化。我们的合成数据集确保了每个监督信号都与最终的压缩策略相关,使得该过程在信息检索准确性方面是无损的。事实证明,这比单独使用自然语言建模更有效(见第13节的消融研究)。我们使用L1正则化项(即Lasso【索引71, Regression shrinkage and selection via the lasso, 1996, Journal of the Royal Statistical Society (Series B)】)来鼓励门控值的稀疏性:
$$\mathcal{L}_{\mathrm{reg}}=\sum_{i=1}^{L} \sum_{j=1}^{H}\left|\alpha_{i, j}\right|.$$最终损失与训练。最终的训练损失是蒸馏损失和正则化损失的组合,由一个超参数λ加权,我们在实验中将其设置为0.05:
$$\mathcal{L}=\mathcal{L}_{\text {distill }}+\lambda \mathcal{L}_{\text {reg. }}$$由于可训练参数的总数只有数千个浮点数,这个优化过程相当快,只需要2000步。我们论文中的所有训练实验都可以在8台NVIDIA A100 GPU服务器上进行。
2.3 使用DUOATTENTION部署LLM
二值化注意力实现。在推理时,我们仅对指定的检索头应用全注意力,这些检索头是使用训练阶段优化的门控值(如图4所示)识别的。我们根据一个阈值τ对每个头的注意力策略进行二值化,该阈值由指定的稀疏性分位数确定,以区分检索头和流式头:
$$\begin{aligned} \operatorname{attn}_{i,j}=\begin{cases}\texttt{full\_attn} & \text{if } \alpha_{i,j}>\tau \\ \texttt{streaming\_attn} & \text{otherwise}\end{cases} \end{aligned}$$重排注意力头。在部署之前,我们通过根据注意力头的分配重新排序查询(Query)、键(Key)和值(Value)投影权重的输出通道来预处理模型。这种重排将检索头和流式头分组到两个不同的、连续的簇中,从而在管理层内这两种类型的头的KV缓存时,能够进行高效的切片和连接操作,而不是依赖于分散和收集操作。
解码。如图5所示,在解码期间,我们为LLM中的每一层分配两个KV缓存:一个用于检索头,存储所有过去的键和值;另一个用于流式头,仅存储注意力池和最近的词元,保持恒定大小。当处理一个新词元时,其查询、键和值向量会沿着头维度被分割,以计算检索头的全注意力和流式头的流式注意力。然后,结果会沿着头维度被拼接起来,用于输出投影。
分块预填充。我们使用FlashAttention-2【索引11, FlashAttention-2: Faster attention with better parallelism and work partitioning, 2023】来预填充检索头和流式头的KV缓存。在长上下文LLM中,分块预填充是一种常见做法【索引2, Sarathi: Efficient llm inference by piggybacking decodes with chunked prefills, 2023, https://arxiv.org/abs/2308.16369】、【索 引45, Efficient memory management for large language model serving with pagedattention, 2023】,它将提示(prompt)划分为固定长度的块来预填充KV缓存。该技术通过将线性层中的峰值中间激活大小从序列长度降低到块大小,显著减少了峰值内存使用(见表10)。DuoAttention与分块预填充完全兼容,并且DuoAttention中流式头的预填充可以实现线性的时间和恒定的内存复杂度,无需专门的内核。如图5所示,一旦计算完一层的KV,流式头的KV缓存会立即被修剪,只保留注意力池和最近的词元。下一块输入的词元在预填充期间将只关注恒定数量的上下文词元。设L为序列长度,K为块大小。流式头的预填充时间复杂度从$O(L^2)$优化到$O(LK)$,内存复杂度从$O(L)$降低到$O(K)$。
图 5: DuoAttention中的解码(左)和分块预填充(右)过程:(1)检索头的KV缓存存储所有词元,而流式头的KV缓存仅保留最近的词元和注意力池,确保了恒定的内存使用。(2)DuoAttention的流式头在一个16词元序列上的分块预填充过程,其中有1个注意力池,2个最近词元,块大小为4。DuoAttention的流式头在长序列预填充期间具有线性的时间和恒定的内存复杂度。
批处理操作的适用性。值得注意的是,DuoAttention的设计非常适合批处理操作,这可以在具有大批量大小的服务场景中进一步提高LLM的效率。
A4 实验环境
-
模型:
- Llama-2-7B-chat 【索引74, Llama 2: Open foundation and fine-tuned chat models, 2023b, arXiv】
- Llama-2-7B-32K-Instruct 【索引72, Llama-2-7b-32k-instruct — and fine-tuning for llama-2 models with together api, 2023, https://together.ai/blog/llama-2-7b-32k-instruct 】
- Llama-3-8B-Instruct & Llama-3-70B-Instruct 【索引13, The llama 3 herd of models, 2024, https://arxiv.org/abs/2407.21783 】
- Llama-3-8B-Instruct-Gradient-1048k
- Mistral-7B-v0.2-Instruct 【索引30, Mistral 7b, 2023】
-
数据集:
- 长上下文评估: Needle-in-a-Haystack (NIAH) 【索引32, Llmtest_needleinahaystack: Doing simple retrieval from llm models at various context lengths to measure accuracy, 2024, https://github.com/gkamradt/LLMTest_ NeedleInAHaystack】 和 LongBench 【索引5, Longbench: A bilingual, multitask benchmark for long context understanding, 2023, arXiv】。
- 短上下文评估: MMLU 【索引23, Measuring massive multitask language understanding, 2021, ICLR】、MBPP 【索引4, Program synthesis with large language models, 2021, arXiv】 和 MT-Bench 【索引80, Judging llm-as-a-judge with mt-bench and chatbot arena, 2023】。
- 检索头识别: 使用BookSum 【索引44, Booksum: A collection of datasets for long-form narrative summarization, 2021】数据集构建的合成密钥检索数据集。
-
硬件配置:
- NVIDIA A100 GPU。效率测试在单卡上进行,训练和长上下文基准测试在8卡服务器上进行。
-
软件配置:
- 实现: PyTorch 【索引59, PyTorch: An imperative style, high-performance deep learning library, 2019, NeurIPS】。
- 依赖库: 使用FlashInfer 【索引77, Cascade inference: Memory bandwidth efficient shared prefix batch decoding, 2024, https: //http://flashinfer.ai/2024/01/08/cascade-inference.html。
- 优化器: AdamW 【索引33, Adam: A method for stochastic optimization, 2015, ICLR】。
-
基线方法:
- H2O 【索引79, H2o: Heavyhitter oracle for efficient generative inference of large language models, 2023b】
- TOVA 【索引58, Transformers are multi-state rnns, 2024】
- FastGen 【索引15, Model tells you what to discard: Adaptive KV cache compression for LLMs, 2024, ICLR】
- StreamingLLM 【索引76, Efficient streaming language models with attention sinks, 2023b, arXiv】
A4 实验结果
长上下文基准测试 (Long-Context Benchmarks)
- 实验内容: 在Needle-in-a-Haystack (NIAH) 和 LongBench 基准上评估DuoAttention及基线方法的性能。DuoAttention为Llama-2-7B-32K设置25%的检索头比例,为Llama-3-8B-1048k设置50%的比例。
- 实验结果:
- NIAH (图6): DuoAttention在所有序列深度上均表现出色,能够有效处理长达1048K的词元。相比之下,H2O、TOVA、StreamingLLM和FastGen等基线方法由于丢弃了包含关键信息的KV缓存,无法在长序列中检索到正确答案。
- LongBench (图7, 附录表3, 4, 5, 6): DuoAttention在KV缓存预算和准确性之间展现了更优的权衡。在多数任务上,DuoAttention使用25% (MHA) 或 50% (GQA) 的KV缓存预算即可达到与全注意力相当的性能,并持续优于其他基线方法。
图 6: DuoAttention在Needle-in-a-Haystack基准测试上提供了与全注意力相当的准确性,MHA模型使用25%的全注意力比例,GQA模型使用50%的全注意力比例。
图 7: DuoAttention在LongBench基准测试上提供了更好的KV预算与准确性权衡。
短上下文基准测试 (Short-Context Benchmarks)
- 实验内容: 在MMLU、MBPP和MT-Bench上评估DuoAttention,以确保其不损害模型在短上下文任务上的性能。
- 实验结果 (图8, 表1): 在相同的KV缓存预算下,DuoAttention在Llama-2-7B, Llama-3-8B, 和 Llama-3-70B等多个模型上始终优于所有基线方法。使用50%的KV缓存预算时,DuoAttention在大多数基准测试上实现了接近无损的性能,表明它保留了模型的原始能力。
图 8: 短上下文基准测试结果。
表 1: Llama-3-70B在短上下文基准上的结果。
效率测试 (Efficiency Results)
- 实验内容: 在单张NVIDIA A100 GPU上评估DuoAttention在Llama-2-7B(25%检索头)和Llama-3-8B(50%检索头)上的解码和预填充延迟与内存使用。
- 实验结果:
- 解码效率 (图9, 11): DuoAttention的解码速度和内存占用随上下文长度增长的斜率远低于全注意力。随着上下文长度增加,内存减少和速度提升接近检索头比例的倒数。在A100上,MHA模型实现了最高2.55倍的内存减少和2.18倍的延迟降低,GQA模型则为1.67倍和1.50倍。
- 预填充效率 (图10): DuoAttention显著降低了预填充的延迟和内存,且预填充块越小,节省效果越明显。MHA模型最高实现1.73倍的延迟降低和2.38倍的内存减少,GQA模型则为1.63倍和1.53倍。
- 与量化结合 (图12): 结合8位权重和4位KV缓存量化,DuoAttention能够让Llama-3-8B模型在单张A100-80G GPU上容纳高达330万个词元,容量相比朴素的BF16全注意力部署提升了6.4倍。
图 9: DuoAttention与全注意力在不同上下文大小下的每词元解码延迟和内存使用对比。DuoAttention对Llama-2-7B (MHA) 使用25%的检索头比例,对Llama-3-8B (GQA) 使用50%。DuoAttention在MHA模型上实现了高达2.45倍的内存减少,在GQA模型上为1.65倍,同时MHA模型的延迟降低高达2.13倍,GQA模型为1.5倍。这些降低随着上下文长度的增加而趋近于检索头比例的倒数。内存溢出(OOM)的结果是通过测量数据线性外推的。
图 10: DuoAttention与全注意力在不同预填充块大小下的预填充延迟和内存使用对比。DuoAttention对Llama-2-7B (MHA) 使用25%的检索头比例,预填充100K词元的上下文;对Llama-3-8B (GQA) 使用50%的比例,预填充320K词元的上下文。随着预填充块大小的减小,DuoAttention在MHA模型上实现了高达1.73倍的延迟降低,在GQA模型上为1.63倍,内存减少在MHA模型上高达2.38倍,在GQA模型上高达1.53倍。
图 11: DuoAttention在固定上下文长度下,解码内存和延迟与KV预算的关系。当检索头的比例降低时,内存和延迟呈线性减少。DuoAttention在MHA模型上实现了高达2.55倍的内存减少,在GQA模型上为1.67倍,同时MHA模型的延迟降低高达2.18倍,GQA模型为1.50倍。
图 12: 结合8位权重和4位KV缓存量化,DuoAttention可以使Llama-3-8B模型在单张A100-80G GPU上容纳330万个词元。
消融研究 (Ablation Studies)
- 实验内容: 在Mistral-7B-Instruct-v0.2上进行消融研究,使用密钥检索和MMLU数据集。
- 实验结果 (图13):
- 识别方法对比: 本文提出的基于优化的方法(使用合成数据)在识别检索头方面显著优于基于注意力分析的方法和基于语言建模的方法。
- 优化配置: 在优化阶段,结合注意力池和最近词元注意力对于有效识别检索头至关重要,单独使用任一者均效果不佳。
- 部署配置: 在部署阶段,流式头的性能在16个注意力池词元和64个最近词元时达到饱和,继续增加带来的收益微乎其微。
图 13: 消融研究:(1)检索头识别方法的比较,显示了我们基于优化的合成数据方法相对于注意力分析和语言建模的优越性。(2)对起始和最近词元大小的分析表明,结合注意力池和最近词元注意力能最优化地识别检索头。(3)部署性能表明,16个注意力池和64个最近词元是最佳配置,超过这些值后收益甚微。
A7 补充细节
4 相关工作
引言。已有多种方法被开发出来以扩展LLM并提高其处理长上下文的效率。这些方法可分为四类:优化模型架构、使用近似注意力机制、应用KV缓存量化和系统级优化。
模型架构。多查询注意力(MQA)【索引64, Fast transformer decoding: One write-head is all you need, 2019】和分组查询注意力(GQA)【索引3, Gqa: Training generalized multi-query transformer models from multi-head checkpoints, 2023】通过在查询头之间共享KV头来减小键值(KV)缓存的大小。然而,这些方法需要使用特定架构进行预训练,并且不会降低计算成本。线性注意力Transformer【索引17, Mamba: Linear-time sequence modeling with selective state spaces, 2023】减少了内存使用,但在需要长上下文处理的任务上往往表现不佳。
近似注意力。诸如稀疏Transformer【索引9, Generating long sequences with sparse transformers, 2019】和LongFormer【索引6, Longformer: The long-document transformer, 2020, arXiv:2004.05150】之类的方法使用局部或块状注意力模式来降低计算复杂性。BigBird【索引78, Big Bird: Transformers for longer sequences, 2020, NeurIPS】通过结合局部和全局注意力实现了线性复杂度,但许多这些方法需要自定义GPU内核或重新训练,限制了它们的实用性。H2O【索引79, H2o: Heavyhitter oracle for efficient generative inference of large language models, 2023b】和TOVA【索引58, Transformers are multi-state rnns, 2024】通过根据查询模式丢弃词元来简化注意力。StreamingLLM【索引76, Efficient streaming language models with attention sinks, 2023b, arXiv】识别出“注意力池”,并提出始终保留初始和最近的词元,以维持恒定的解码延迟和内存使用,使模型能够处理比预训练序列长度多得多的输入词元。FastGen【索引15, Model tells you what to discard: Adaptive KV cache compression for LLMs, 2024, ICLR】通过分析注意力头来在解码过程中丢弃词元。然而,我们的实验表明,这些方法会降低LLM的长上下文能力。此外,这些方法无法降低长上下文LLM的预填充成本。
KV缓存量化。诸如8位和4位量化【索引54, Kivi: A tuning-free asymmetric 2bit quantization for kv cache, 2024, arXiv】、【索引26, Kvquant: Towards 10 million context length llm inference with kv cache quantization, 2024】、【索引50, Qserve: W4a8kv4 quantization and system co-design for efficient llm serving, 2024, arXiv】等技术减小了KV缓存的大小,但它们没有解决注意力内核的计算开销问题。这些方法与DuoAttention是互补的,可以一起使用以进一步减少内存使用。
系统优化。vLLM【索引45, Efficient memory management for large language model serving with pagedattention, 2023】和FlashAttention【索引12, FlashAttention: Fast and memory-efficient exact attention with IO-awareness, 2022, arXiv:2205.14135】、【索引11, FlashAttention-2: Faster attention with better parallelism and work partitioning, 2023】通过优化批处理和利用GPU内存层次结构来提高注意力计算效率。FlashDecoding【索引25, Flashdecoding++: Faster large language model inference on gpus, 2024】和RingAttention【索引51, Ring attention with blockwise transformers for nearinfinite context, 2023a】在解码速度和序列级并行方面引入了进一步的改进。虽然这些方法增强了计算性能,但它们没有解决KV缓存大小的缩减问题,使其成为DuoAttention的补充,以实现额外的速度和内存优化。
近期工作。最近的一些工作与DuoAttention有相似的想法。Wu等人【索引75, Retrieval head mechanistically explains long-context factuality, 2024】引入了检索头的概念来解释LLM的长上下文能力。然而,他们的方法并未压缩非检索头的KV缓存,仅关注准确性。MInference【索引31, Minference 1.0: Accelerating pre-filling for long-context llms via dynamic sparse attention, 2024, arXiv】通过使用稀疏注意力模式加速了长上下文LLM的预填充,但没有优化解码过程中的KV缓存存储或延迟。RazorAttention【索引69, Razorattention: Efficient kv cache compression through retrieval heads, 2024a, https: //http://arxiv.org/abs/2407.15891。
A5 结论
我们引入了DuoAttention,这是一个通过区分检索头和流式头来优化LLM内存和计算资源的框架。通过仅对检索头应用完整的KV缓存,DuoAttention在长上下文应用中显著减少了解码和预填充的内存使用和延迟。它在MHA模型上实现了高达2.55倍的内存减少,GQA模型上为1.67倍;解码速度在MHA模型上提升高达2.18倍,GQA模型上为1.50倍;预填充在MHA模型上加速高达1.73倍,GQA模型上为1.63倍,而与全注意力相比,准确性损失极小。当与量化结合时,DuoAttention进一步提升了KV缓存容量,在单张A100 GPU上支持高达330万的上下文词元。DuoAttention为LLM处理百万级词元上下文铺平了道路。
A6 附录
A.1 实验细节
训练配置。我们使用PyTorch中的FSDP2进行模型训练,并利用DeepSpeed Ulysses【索引27, Deepspeed ulysses: System optimizations for enabling training of extreme long sequence transformer models, 2023, https://arxiv.org/abs/2309. 14509】的序列并行来支持长序列。在训练过程中,我们使用了一种高效的Λ形注意力的块稀疏近似来进行流式注意力计算,该实现来自Guo等人【索引18, Block Sparse Attention, 2024, https://github.com/mit-han-lab/Block-Sparse-Attention】,如图14所示。不同模型的最大序列长度详见表2 。
表 2: 训练超参数。
图 14: Λ形注意力的块稀疏近似。
A.3 H2O和TOVA在长上下文基准上的实现
实现细节。H2O和TOVA算法的原始设计在预填充阶段与FlashAttention不兼容,因为它们依赖于注意力分数来执行词元驱逐。由于FlashAttention中的注意力分数从未被物化,这些算法无法用于预填充,这是它们的主要缺陷之一。因此,无法在像“大海捞针”和LongBench这样的长上下文设置中评估这些算法,因为它们会在上下文预填充期间导致内存溢出(OOM)。
修改后的评估方案。为了与这些策略进行比较,我们修改了算法:在预填充阶段,我们使用FlashAttention进行精确计算。在解码阶段,我们根据生成词元对上下文词元的注意力分数来执行词元驱逐。这种修改比原始设计提高了性能,因为预填充是精确的,词元驱逐仅在解码期间发生。在极端情况下,如果答案中只有一个生成的词元(例如多项选择任务),我们实现的H2O和TOVA将与全注意力完全相同,这与它们的真实准确性不同。为了接近它们的真实性能,我们在长输入基准测试(“大海捞针”和LongBench)中模拟最后50个词元作为生成词元,以便足够长时间地执行它们的词元驱逐策略,我们的算法也同样处理。这种实验设置也被Tang等人【索引70, Quest: Query-aware sparsity for efficient long-context llm inference, 2024b】使用。实验结果表明我们的方法可以通过这个压力测试,而H2O和TOVA不能。
表 3: Llama-3-8B-Instruct-1048K的完整LongBench结果。DuoAttention在大多数数据集上以50%的KV缓存预算取得了最佳性能。
A.4 MISTRAL模型上的NIAH结果
图 15: Mistral-7B-Instructv0.2模型上的NIAH结果。
图 16: Mistral-7B-Instructv0.3模型上的NIAH结果。
A.5 FASTGEN在长上下文基准上的实现
实现细节。由于缺乏FastGen【索引15, Model tells you what to discard: Adaptive KV cache compression for LLMs, 2024, ICLR】算法的官方实现,我们使用一个社区代码库【索引1, Cold compress: A toolkit for benchmarking kv cache compression approaches, 2024, https://www.answer.ai/ posts/2024-08-01-cold-compress.html】来复现它,该代码库被FastGen的官方仓库引用。在FastGen算法中,剪枝率不能直接配置;而是使用恢复率T来控制稀疏性,如FastGen论文所述。为了量化稀疏性,我们计算了所有测试用例的平均KV缓存使用率作为稀疏性的总体度量。对于Llama-2-7B模型,我们将恢复率设置为0.7,确保平均KV缓存预算超过完整KV缓存的25%。同样,对于Llama-3-8B模型,我们将恢复率设置为0.87,确保平均KV缓存预算超过完整KV缓存的50%。
长上下文限制。此外,由于FastGen使用用户提供提示的完整注意力图来分析不同头的类型,这导致了$O(n^2)$的注意力图复杂度。因此,我们无法在长上下文中测试其性能。对于长上下文基准测试,我们使用了8个A100-80G GPU,对于Llama-2-7B模型,序列长度最高达到24k词元,对于Llama-3-8B模型,最高达到32k词元。
性能对比。除了图6中显示的“大海捞针”基准测试外,我们还在LongBench上对两种模型评估了FastGen。然而,由于FastGen的二次内存消耗,我们只报告了在8个A100-80G GPU上使用FastGen能够运行的数据集的结果。如表5和表6所示,DuoAttention在LongBench数据集上能够持续优于FastGen。
表 4: Llama-2-7B-Instruct-32K的完整LongBench结果。DuoAttention在大多数数据集上以25%的KV缓存预算取得了最佳性能。
表 5: 在Llama-3-8B-Instruct-1048K模型上,FastGen与DuoAttention在LongBench子集上的比较。
表 6: 在Llama-2-7B-32K-Instruct模型上,FastGen与DuoAttention在LongBench子集上的比较。
💬 评论讨论
欢迎在这里分享您的想法和见解!