PROMPT CACHE: MODULAR ATTENTION REUSE FOR LOW-LATENCY INFERENCE
作者/机构: In Gim, Guojun Chen, Seung-seob Lee, Nikhil Sarda, Anurag Khandelwal, Lin Zhong
A1 主要贡献
核心问题: 大量大型语言模型(LLM)的提示(prompt)包含频繁重用的文本片段,例如系统消息、提示模板以及用于上下文的文档。这种重叠导致在推理过程中存在大量的冗余计算。
研究目标: 本文旨在通过重用不同LLM提示之间的注意力状态来加速LLM的推理过程,特别是降低生成第一个令牌的时间(Time-to-First-Token, TTFT)延迟。
核心思想与创新点:
1. 提出Prompt Cache: 核心思想是预计算并存储频繁出现的文本片段的注意力状态(键值对),并在这些片段出现在后续用户提示中时直接重用,从而避免重复计算。
2. 模块化注意力状态重用: Prompt Cache将注意力状态的重用从单个提示内部扩展到多个提示之间,实现了模块化的重用。
3. 引入Prompt Markup Language (PML): 为了解决注意力状态的位置依赖性(即文本片段在不同提示中位置不同)和如何有效识别可重用片段的挑战,本文设计了PML。PML通过一个模式(schema)来显式地定义可重用的文本片段(称为“提示模块”),并为它们分配唯一的位置ID。
4. 利用不连续位置ID: 本文的实证研究发现,LLM可以在具有不连续位置ID的注意力状态上正常工作。这一发现使得可以将不同提示模块的注意力状态进行拼接,从而允许用户根据需求灵活选择、组合甚至在运行时更新提示模块。
5. 显著的性能提升: 原型实现表明,Prompt Cache在GPU推理上可将TTFT延迟降低8倍,在CPU推理上可降低60倍,同时保持了输出的准确性,且无需修改模型参数。
A2 背景知识
2.1 自回归令牌生成
自回归生成过程。大型语言模型(LLM)以自回归的方式生成输出令牌。它从一个初始输入(通常称为提示)开始,并基于该提示生成下一个令牌。然后,模型将生成的令牌附加到提示后,并使用这个新的序列来生成再下一个令牌。这个生成过程会持续进行,直到满足某个停止条件,例如生成了预定数量的令牌、生成了一个特殊的序列结束令牌,或者生成的序列达到了满意的连贯性或完整性水平。重要的是,在每一步中,模型都将整个提示和到目前为止已生成的所有令牌作为输入,并重复此过程。
2.2 键值缓存(Key-Value Cache)
KV缓存机制。上述的自回归令牌生成过程由于自注意力机制需要在每一步都对整个输入进行操作,因此会产生巨大的计算开销。为了缓解这个问题,键值(KV)缓存机制【28,Efficiently scaling transformer inference,2022】被广泛使用。该技术在整个自回归令牌生成过程中,只为每个令牌计算一次键(key)和值(value)嵌入。具体来说,假设用户提示是一个由n个令牌组成的序列:$s_1, \dots, s_n$,随后生成的k个令牌为$s_{n+1}, \dots, s_{n+k}$。在朴素的自回归生成中,注意力状态 $\{(k_1, v_1), \dots, (k_{n+k}, v_{n+k})\}$ 在每一步都会被完全重新计算。相比之下,KV缓存首先为输入计算注意力状态,表示为$S_0 = \{(k_i, v_i)|i \le n\}$,并将它们缓存在内存中。这个步骤通常被称为预填充(prefill)阶段。对于随后的每一步$j \le k$,模型会重用缓存的值$S_j = \{(k_i, v_i)|i < n + j\}$来计算新令牌$s_{n+j}$的注意力状态$(k_{n+j}, v_{n+j})$。这种方法显著减少了自注意力所需的计算量。具体而言,每一步的计算量(以矩阵运算的FLOPs衡量)减少了$1/n$倍。操作数从大约$6nd^2 + 4n^2d$减少到$6d^2 + 4nd$,其中d是隐藏维度的大小。每一步之后,新计算的$(k_{n+j}, v_{n+j})$注意力状态会被附加到缓存中以供后续使用。在因果语言模型(占大多数LLM)中,使用KV缓存不会影响模型的准确性,因为位置i的注意力仅基于位于第i个令牌之前的位置上的令牌来计算。
基于KV缓存的后续研究。KV缓存催化了对LLM加速的进一步探索。后续研究要么专注于为KV缓存优化内存管理,如分页注意力(paged attention)【20,Efficient memory management for large language model serving with pagedattention,2023,arXiv preprint】所示,要么致力于修剪多余的KV缓存数据【44,H2o: Heavy-hitter oracle for efficient generative inference of large language models,2023】,或者对其进行压缩【23,Cachegen: Fast context loading for language model applications,2023】。也有一些初步工作探索了跨不同请求重用KV缓存。例如,【10,Attmemo : Accelerating transformers with memoization on big memory systems,2023】根据嵌入相似性度量来重用记忆化的注意力状态。分页注意力也展示了简单的前缀共享,即具有相同前缀的不同提示共享KV缓存。然而,现有方法仅适用于特定场景,而本文研究的是通用LLM提示的注意力重用问题。
2.3 其他低延迟LLM推理方法
正交的优化策略。Prompt Cache引入了一种正交的优化策略,可以增强现有的致力于高效LLM推理的系统。这包括使用多个GPU进行推理的系统【1,Deepspeed- inference: Enabling efficient inference of transformer models at unprecedented scale,2022,SC22】和那些拥有用于softmax注意力分数计算的高性能GPU内核的系统【6,Flashat-´ tention: Fast and memory-efficient exact attention with io-awareness,2022,Advances in Neural Information Processing Systems】。尽管我们目前的重点是在LLM中实现低延迟推理,但Prompt Cache也可以通过减少计算量来使旨在实现高吞吐量的系统受益【36,Flexgen: ´ High-throughput generative inference of large language models with a single GPU,2023,ICML】。
A3 方法细节
3 设计
KV缓存的有效性引出了下一个问题:注意力状态能否在多个推理请求之间重用?我们观察到不同的提示常常有重叠的文本片段。例如,相同的“系统消息”或元提示(metaprompts)经常被插入到提示的开头,以引导LLM产生期望的响应。再比如,在许多法律和医疗领域的LLM应用中【5,Chatlaw: Open-source legal large language model with integrated external knowledge bases,2023; 37,Language models are an effective representation learning technique for electronic health record data,2021,J. Biomed. Informatics; 32,Medbert: pretrained contextualized embeddings on large-scale structured electronic health records for disease prediction,2021,npj Digit. Medicine】,同一组文档常常作为上下文提供给不同的提示。最后,可重用的提示格式,即提示模板,在机器人和工具学习的LLM应用中被普遍使用【7,Palm-e: An embodied multimodal language model,2023; 30,Tool learning with foundation models,2023】,因为大多数任务只是一些常见任务的变体。在本节中,我们描述了我们称为Prompt Cache的方法,它肯定地回答了上述问题。Prompt Cache通过以结构化的方式利用共享片段,通过请求间的注意力状态重用提高计算效率。
3.1 概述
位置依赖性挑战。一个文本片段的注意力状态只有在该片段出现在LLM输入的相同位置时才能被重用。这是因为Transformer架构将唯一的位置嵌入集成到(k, v)注意力状态中。对于使用KV缓存服务单个提示来说,这不是问题,因为在所有步骤中,相同的提示文本都位于相同的位置,即输入的开头。
跨提示重用的两个问题。另一方面,共享的文本片段可能出现在不同提示的不同位置。为了跨提示重用它们的注意力状态,缓存系统必须解决两个问题。首先,它必须允许在文本片段出现在不同提示的不同位置时进行重用。其次,当系统接收到一个新提示时,系统必须能够有效地识别其注意力状态可能已被缓存的文本片段,以便进行重用。
Prompt Cache的解决方案。为了解决这两个问题,我们结合了两种思想。第一种是使用提示标记语言(Prompt Markup Language, PML)来明确提示的结构。如图2所示,PML将可重用的文本片段明确为模块,即提示模块。这不仅解决了上述第二个问题,还为解决第一个问题打开了大门,因为可以为每个提示模块分配唯一的位置ID。我们的第二个思想是我们的实证发现,即LLM可以在具有不连续位置ID的注意力状态上运行。只要令牌的相对位置得以保留,输出质量就不受影响。这意味着我们可以提取不同片段的注意力状态并将它们拼接起来,以构成新的含义。我们利用这一点,使用户能够根据自己的需求选择提示模块,甚至在运行时替换某些含义。
工作流程。Prompt Cache将这两种思想结合起来,工作流程如下。LLM用户使用PML编写他们的提示,意图基于提示模块重用注意力状态。重要的是,他们必须从一个模式(schema)派生出提示,该模式也用PML编写。图2展示了一个基于示例模式的示例提示。当Prompt Cache收到一个提示时,它首先处理其模式并为其提示模块计算注意力状态。它会为该提示中的提示模块以及从同一模式派生的其他提示重用这些状态。
方法细节章节安排。我们在§3.2中详细介绍了PML的设计,重点关注最大化重用机会的技术。我们在§3.3中解释了Prompt Cache如何计算模式中提示模块的注意力状态,以及这可能如何影响输出质量。我们在§3.4中解释了Prompt Cache如何为一个提示的服务重用来自模式的注意力状态。
与局部掩码注意力的关系。Prompt Cache中的模块化KV缓存构建与局部掩码注意力【3,Longformer: The long-document transformer,2020,CoRR; 39,Efficient transformers: A survey,2023,ACM Comput. Surv.】中观察到的近似有相似之处,后者通过设置有限的窗口进行注意力分数计算来优化计算,而不是将其注意力扩展到输入序列中的每个令牌。考虑Prompt Cache中的一个场景,其中每个提示模块都是独立编码的。鉴于注意力状态严格在提示模块的范围内计算,这与一个注意力掩码的设置非常相似,该掩码会屏蔽掉提示模块外部的序列。因此,Prompt Cache所做的近似是将注意力窗口限制在每个提示模块内。我们注意到,使用这样的注意力掩码不一定会降低输出质量,我们将在§5中讨论。在某些情况下,这些掩码甚至可能通过有效过滤掉不相关信息而引入有益的归纳偏置。
3.2 提示标记语言 (PML)
PML的核心特性。接下来我们描述PML的关键特性,它用于定义模式和从模式派生的提示。
3.2.1 模式与提示
定义与关系。模式(schema)是一个定义提示模块并描绘其相对位置和层次结构的文档。每个模式都有一个唯一的标识符(通过name
属性),并使用<module>
标签指定提示模块。未被<module>
标签包围或未指定标识符的文本被视为匿名提示模块,并且总是包含在从该模式构建的提示中。对于LLM用户来说,模式是创建和重用提示模块注意力状态的接口。用户可以从一个模式构建一个提示,使用<prompt>
标签。该标签通过schema
属性指定要使用的模式,列出要导入的提示模块,并添加任何额外的(非缓存的)指令。例如,要从图2中的模式导入模块miami
,可以表示为<miami/>
。Prompt Cache将只计算模式中未指定的文本的注意力状态,例如图2中的“Highlights the surf spots”,并重用导入模块的注意力状态,例如trip-plan
和miami
,从而减少延迟。
3.2.2 通过参数最大化重用
参数化模块。PML允许对提示模块进行参数化,以最大化重用机会。参数是一个带有指定长度的命名占位符,可以出现在模式中提示模块的任何位置。它使用<param>
标签定义,name
和len
属性分别指示其名称和参数的最大令牌数。当一个提示导入该提示模块时,它可以为该参数提供一个值。图2展示了一个参数化提示模块(trip-plan
)的示例,以及一个提示如何包含该提示模块并为其参数(duration
)提供一个值(3 days
)。参数值是不被缓存的。参数化提示模块有两个重要用途。首先,一个提示模块与另一个模块通常只在一些明确定义的地方有所不同。参数允许用户在运行时提供特定的参数来自定义模块,同时仍能从重用中受益。图2通过trip-plan
说明了这种用例。这对于模板化的提示特别有用。其次,参数可用于在模式中提示模块的开头或结尾创建一个“缓冲区”。这个缓冲区允许用户在提示中添加任意文本片段,只要该片段不长于它所替换的参数令牌长度。
3.2.3 其他特性
联合模块(Union modules)。某些提示模块表现出互斥关系。也就是说,在一组模块中,只应选择一个。例如,考虑一个提示,要求LLM根据读者的个人资料(由一个提示模块描述)推荐一本书。可能有多个提示模块,每个模块描述一个读者资料,但提示只能包含其中一个。为了适应这些互斥关系,我们引入了联合(union)的概念。模块的联合使用<union>
标签表示。嵌套在同一个联合中的提示模块共享相同的起始位置ID。联合不仅简化了布局的组织,还节省了用于编码提示模块的位置ID。此外,系统可以利用这种结构进行优化,例如预取。虽然参数化模块和联合看起来相似,但它们在两个方面有所不同。首先,正如我们将在§3.3中展示的,参数和联合模块的编码方式不同。其次,它们服务于不同的目的:参数用于内联修改以最大化模块的重用,而联合模块旨在实现更好的提示结构和更有效的位置ID利用。
嵌套模块(Nested modules)。PML还支持嵌套模块来表达层次化的提示模块。也就是说,一个提示模块可以包含提示模块或联合作为其组件。在提示中,嵌套模块作为模块内的模块被导入,如图8所示。
与LLM特定模板的兼容性。经过指令调优的LLM通常遵循特定的模板来格式化对话。例如,在Llama2中,用户和助手之间的一次交互遵循模板:<s>[INST] user message [/INST] assistant message </s>
。为了减少为不同LLM手动格式化提示模式以匹配此类模板所需的工作量,我们引入了三个专用标签:<system>
用于系统级提示,<user>
用于用户生成的提示,以及<assistant>
用于LLM生成的示例响应。Prompt Cache会动态翻译和编译这些专用标签,以与所使用的LLM的指定提示模板对齐。
3.2.4 从提示程序派生PML
自动化转换。为了简化PML的编写,Prompt Cache可以自动将来自Python等语言的提示程序(prompt programs)【4,Prompting is programming: A query language for large language models,2023,Proceedings of the ACM on Programming Languages; 12,A guidance language for controlling large language models,2023】转换为PML,从而无需手动编写模式。这主要是通过一个Python API实现的,该API将Python函数转换为相应的PML模式。转换过程很直接:if
语句在PML中变成<module>
结构,将条件提示封装在内。当条件评估为真时,相应的模块被激活。if-else
或switch
等选择一的语句被映射到<union>
标签。函数调用被转换为嵌套的提示模块。此外,我们实现了一个装饰器来管理参数,特别是限制最大参数长度。这对应于<param>
中的len
属性。这种Python到PML的编译向用户隐藏了PML的复杂性,并提供了更好的提示可维护性。
3.3 编码模式
编码过程。当第一次需要一个提示模块的注意力状态时,必须计算它们并存储在设备内存中,我们称之为提示模块编码。首先,Prompt Cache从模式中提取提示模块的令牌序列。然后它为每个令牌分配位置ID。起始位置ID由提示模块在模式中的绝对位置决定。例如,如果前面两个提示模块的令牌序列大小分别为50和60,那么该提示模块将被分配一个起始位置ID 110。联合模块是一个例外。由于联合内的提示模块从相同的位置开始,它们的令牌序列大小被视为最大子模块的大小。从提示模块的令牌序列和相应的位置ID,这些信息被传递给LLM来计算(k, v)注意力状态。我们注意到,分配的位置ID不从零开始。这在语义上是可接受的,因为空格不会改变预计算文本的含义。然而,许多现有的transformer位置编码实现,如RoPE,通常需要进行调整以适应不连续的位置ID,我们将在(§ 4.2)中讨论。
参数化模块的编码。对于编码参数化提示模块,我们利用了提示中有空格不影响其语义的思想。参数被替换为预定数量的<unk>
令牌,数量等同于其len
属性值。与这些<unk>
令牌对应的位置ID被记录下来以备将来替换。当这个模块被集成到用户提示中并与相关参数配对时,所提供参数的令牌序列将采用先前与<unk>
令牌关联的位置ID。由此产生的(k, v)注意力状态随后会替换最初为<unk>
令牌分配的注意力状态。我们注意到,新提供的令牌长度可以小于指定的参数长度,因为尾部的空格不会改变语义。
注意力掩码效应。Prompt Cache将注意力分数的计算限制在每个提示模块的范围内,从而掩盖了模块间的注意力状态。这种掩码效应可能会根据模块的语义独立性增强或降低输出质量。对于语义独立的模块,掩码可以减少噪声并提高质量。然而,对于语义相关的模块,它可能会产生相反的效果。因此,每个提示模块应该是自包含的并且在语义上独立于其他模块。一种消除掩码效应的方法是使用我们称之为脚手架(scaffolding)的方法。以额外的内存为代价,我们允许用户指定“脚手架”,这是一组一起编码以共享注意力范围的提示模块,此外还有它们各自的注意力状态。当一个脚手架中的所有提示模块都在一个提示中被导入时,脚手架的注意力状态将覆盖单个的注意力状态。脚手架用额外的内存换取输出的一致性,这对于需要确定性结果的应用可能很有用。
3.4 缓存推理
推理过程。当一个提示被提供给Prompt Cache时,Prompt Cache会解析它以确保与声明的模式对齐。它会验证导入模块的有效性。然后,如图2所示,Prompt Cache从缓存中检索导入的提示模块的(k, v)注意力状态 ( 2 ),为新的文本片段计算这些状态 ( 3 和 4 ),并将它们拼接起来生成整个提示的注意力状态 ( 5 ),从而取代了预填充操作。具体来说,Prompt Cache首先拼接与提示中每个导入的提示模块相对应的KV状态张量。例如,当用户提示使用模块A和B时,拼接后的KV张量被构造为:$(k_C, v_C) = (\text{concat}(k_A, k_B), \text{concat}(v_A, v_B))$。值得注意的是,由于transformer的置换不变性【8,Position infor-¨ mation in transformers: An overview,2022,Computational Linguistics】,拼接的顺序无关紧要。这一步仅需要内存拷贝。然后,Prompt Cache为提示中未缓存的部分计算注意力状态,具体来说,是模式中未定义的令牌序列和参数化提示模块的参数。Prompt Cache首先根据未缓存文本相对于其他已使用提示模块的位置来确定其位置ID。例如,如果文本位于模块A和B之间,它将被分配从A的结束位置开始的位置ID,假设A和B的位置之间存在间隙。参数化提示模块的参数被分配给<unk>
令牌的位置ID。随后,令牌序列和位置ID被聚合,并使用$(k_C, v_C)$作为KV缓存传递给LLM,以计算整个提示的注意力状态。重要的是要注意,生成后续令牌的计算复杂度与KV缓存的计算复杂度保持一致,因为提示模块在初始令牌之后不再使用。实质上,Prompt Cache减少了生成第一个令牌(即time-to-first-token, TTFT)所涉及的延迟。
批量推理中的内存优化。为了更好地利用GPU,提示通常以批处理方式进行服务。源自同一模式的不同提示可能包含相同的提示模块,例如系统提示。这通过减少批处理中KV缓存的冗余,开辟了额外的优化机会。分页注意力【20,Efficient memory management for large language model serving with pagedattention,2023,arXiv preprint】可以通过在不同提示间共享指向同一提示模块的指针来解决这个问题,而不是复制注意力状态。在这里,使用Prompt Cache可以通过允许更多提示并行处理来间接提高系统吞吐量。
4 实现
我们使用HuggingFace transformers库【42,Huggingface’s transformers: State-of-the-art natural language processing,2020】在PyTorch中构建了一个Prompt Cache原型,包含3000行Python代码。我们的目标是与现有的LLM代码库无缝集成并重用其权重。我们实现了Prompt Cache以使用CPU和GPU内存来容纳提示模块,并在这两个平台上对其进行了评估。
4.1 在内存中存储提示模块
存储方案。我们将编码后的提示模块存储在两种类型的内存中:CPU内存(主机DRAM)和GPU内存(HBM)。为了管理跨两种内存类型的张量,我们使用了PyTorch【26,Pytorch: An imperative style, high-performance deep learning library,2019,Advances in Neural Information Processing Systems】的内存分配器。除了简单地将CPU与CPU内存中的提示模块配对,GPU与GPU内存配对之外,我们还使GPU能够访问存储在CPU内存中的提示模块。这是通过在需要时将提示模块从主机复制到设备来实现的。这个过程会产生主机到设备的内存复制开销。尽管如此,它允许GPU利用丰富的CPU内存,这些内存可以扩展到TB级别。正如我们将在§5中展示的,Prompt Cache带来的计算节省远远超过了由内存复制操作引起的延迟。使用GPU暴露了内存容量和延迟之间的权衡:GPU内存速度更快但容量有限,而CPU内存可以轻松扩展但会产生额外的内存复制开销。考虑一个利用CPU和GPU内存的缓存机制似乎是可行的。我们将开发一个包含缓存替换和预取策略的系统留给未来的研究。
4.2 适配Transformer架构
实现调整。实现Prompt Cache需要支持不连续的位置ID(§3.2)。虽然Transformers库目前不提供这些功能,但它们可以通过少量修改来集成。例如,每个LLM需要大约20行额外的代码。我们概述了所需的调整:
* 嵌入表:像BERT【40,Attention is all you need,2023】和GPT-2【31,Improving language understanding by generative pre-training,2018】这样的早期模型使用查找表将位置ID映射到学习到的嵌入或固定的偏置,无需改动。
* RoPE:像Llama2【39,Llama 2: Open foundation and fine-tuned chat models,2023】和Falcon【27,The refinedweb dataset for falcon llm: outperforming curated corpora with web data, and web data only,2023,arXiv preprint】这样的LLM采用RoPE【38,Roformer: Enhanced transformer with rotary position embedding,2021,CoRR】,它在注意力计算中使用旋转矩阵进行位置编码。我们为每个旋转矩阵创建一个查找表,从而能够根据位置ID进行检索。
* ALiBi:在MPT【24,Introducing mpt-7b: A new standard for opensource, commercially usable llms,2023】和Bloom【34,BLOOM: A 176b-parameter open-access multilingual language model,2022,CoRR】等模型中使用,ALiBi【29,Train short, test long: Attention with linear biases enables input length extrapolation,2022,ICLR】在softmax分数计算期间集成一个静态偏置。与RoPE类似,我们设计一个查找表来根据提供的位置ID调整偏置矩阵。
内存分配优化。我们还重写了PyTorch的拼接操作符以实现更高效的内存分配。PyTorch只支持连续的张量,因此,两个张量的拼接总是会导致新的内存分配。Prompt Cache需要拼接提示模块的注意力状态,默认行为会导致冗余的内存分配。我们实现了一个带缓冲的拼接操作符,在拼接张量时重用内存。这一优化改善了Prompt Cache的内存占用并减少了内存分配的开销。
A4 实验环境与结果
实验环境
- 硬件配置:
- CPU平台:
- Intel i9-13900K, 128 GB DDR5 RAM @ 5600 MT/s
- AMD Ryzen 9 7950X, 128 GB DDR4 RAM @ 3600 MT/s
- GPU平台:
- NVIDIA RTX 4090 (与Intel i9-13900K配对)
- NVIDIA A40 和 A100 (NCSA Delta上的虚拟节点,每个节点配备16核AMD EPYC 7763和224 GB RAM)
- CPU平台:
- 模型架构:
- 使用了多种开源LLM,包括Llama2, CodeLlama, MPT, 和 Falcon。
- 选择的模型都适合单个GPU的内存容量(40 GB)。
- 数据集:
- LongBench套件 【2,Longbench: A bilingual, multitask benchmark for long context understanding,2023,CoRR】: 用于评估TTFT改进和输出质量变化。
- 数据集包含4K到10K上下文长度的长数据样本,涵盖6大类21个数据集。
- 任务类型包括多文档问答、摘要和代码补全等。
- 在实验中,LongBench数据集中的文档(如维基页面、新闻文章)被定义为提示模块,而特定于任务的指令则作为非缓存的用户文本。
- 软件配置:
- 代码实现基于Python和PyTorch。
- 依赖于HuggingFace
transformers
库【42,Huggingface’s transformers: State-of-the-art natural language processing,2020】。
实验结果
GPU和CPU上的延迟改进
- 实验内容: 使用Llama 7B模型,在LongBench的8个基准数据集上测量了GPU和CPU上的TTFT(首个令牌生成时间)延迟。
- 实验结果:
- GPU推理延迟 (图3): 在所有数据集和三种NVIDIA GPU(RTX 4090, A40, A100)上,TTFT延迟均显著降低。当提示模块存储在CPU内存中时,延迟降低1.5倍至3倍;当存储在GPU内存中时,延迟降低5倍至10倍。
- CPU推理延迟 (图4): 在Intel CPU上延迟降低高达70倍,在AMD CPU上降低20倍。作者推测这种差异是由于系统内存带宽不同造成的(DDR5 vs DDR4)。CPU推理比GPU推理从Prompt Cache中获益更多,因为CPU的注意力计算延迟要高得多。
- 分析结论: Prompt Cache能显著降低GPU和CPU上的推理延迟,尤其是在资源受限的环境(如边缘设备或GPU资源有限的云服务器)中效果更佳。
Prompt Cache的准确性
- 实验内容: 使用LongBench套件,在三种不同Transformer架构的LLM(Llama2, MPT, Falcon)上评估Prompt Cache对输出质量的影响(未使用scaffolding)。
- 实验结果 (表1): 在所有数据集上,使用Prompt Cache的输出准确性与基线(常规KV缓存)相当。实验采用确定性采样以确保结果的可比性。
- 分析结论: Prompt Cache在显著提升性能的同时,能够保持模型的输出精度。
延迟改进的深入理解
- 实验内容: 使用合成数据集,在不同序列长度下测试Prompt Cache的TTFT延迟,并与常规KV缓存进行比较,以验证其理论上的二次方延迟降低。
- 实验结果 (图5):
- 二次方改进: 实验证实,KV缓存的延迟随序列长度二次方增长,而Prompt Cache的内存拷贝开销仅线性增长。因此,Prompt Cache的延迟优势(两条曲线之间的差距)随序列长度呈二次方扩大。
- 模型大小的影响: 从7B模型切换到13B模型时,KV缓存的延迟增加了220毫秒,而Prompt Cache仅增加了30毫秒。这表明Prompt Cache的优势也随模型大小(即隐藏维度)呈二次方增长。
- 端到端延迟: Prompt Cache主要减少TTFT,后续令牌生成时间(TTST)不变。然而,更快的响应时间改善了用户体验和整体端到端延迟。例如,在RTX 4090上,节省的TTFT时间相当于额外生成25个令牌。此外,通过与paged attention等技术结合,Prompt Cache可以减少内存占用,从而允许更大的批量大小,提高整体吞吐量。
- 分析结论: Prompt Cache的性能优势随着输入序列长度和模型规模的增加而显著增强,符合理论预期。
内存开销
- 实验内容: 分析Prompt Cache的内存开销,并根据LLM模型和16位浮点精度计算了缓存单个令牌的内存需求。
- 实验结果 (表2): 内存开销与缓存的总令牌数成正比。例如,对于Falcon 1B模型,缓存一个1K令牌的文档需要约180MB内存。对于Llama 70B这样的大模型,同样长度的模块需要2.5GB内存,这使得CPU内存成为存储提示模块的唯一选择。
- 分析结论: 内存开销是Prompt Cache的一个重要考虑因素,特别是对于大型模型。注意力状态的压缩技术是未来值得研究的方向。
Prompt Cache的应用案例
- 实验内容: 展示PML在三个复杂用例中的表达能力和性能优势:代码生成、个性化和参数化提示。
- 实验结果:
- 代码生成 (图6): 将每个源代码文件视为一个提示模块。在CodeLlama 7B上,GPU的TTFT延迟提高了4倍,同时输出完全相同。
- 个性化 (图7): 将用户的特征(如年级、熟练度)表示为提示模块,并使用
<union>
标签处理互斥特征。实验展示了延迟优势和高质量的输出。 - 参数化提示 (图8): 在旅行计划用例中,使用参数指定旅行时长,并用联合模块选择目的地。这展示了Prompt Cache在提供灵活性和保持缓存效率方面的能力。
- 分析结论: Prompt Cache及其PML设计足够灵活和强大,能够支持复杂的、结构化的提示应用,并在这些场景中提供显著的性能提升和高质量的输出。
A5 结论
本文介绍了Prompt Cache,这是一种基于“注意力状态可以在LLM提示之间重用”这一洞见而设计的加速技术。Prompt Cache利用一个提示模式(prompt schema)来描绘这些可重用的文本片段,将它们组织成一种称为“提示模块”的模块化且位置连贯的结构。这使得LLM用户能够将这些模块无缝地集成到他们的提示中,从而以可忽略的延迟代价利用它们作为上下文。我们的基准数据集评估显示,在GPU上TTFT延迟降低高达8倍,在CPU上高达60倍。
对于未来的工作,我们计划将Prompt Cache作为未来LLM服务系统的基础构建块。这样的系统可以配备优化的GPU缓存替换策略,以实现Prompt Cache所能达到的延迟下限。减少主机到设备内存开销的不同策略也可能是有益的,例如集成KV缓存中的压缩技术,或利用分组查询注意力。另一个有希望的探索方向是用于跨并发请求共享注意力状态的GPU原语,正如我们在§3.4中简要讨论的那样。这不仅可以减少TTFT延迟,还可以通过将更多请求打包到单个批次中来减少每个输出令牌的时间(TPOT)延迟。最后,Prompt Cache可以直接加速上下文内检索增强生成(RAG)方法,其中信息检索系统基本上充当提示模块的数据库。Prompt Cache对于实时问答和对话系统中的延迟敏感型RAG应用尤其有用。
💬 评论讨论
欢迎在这里分享您的想法和见解!