Efficient Low Rank Attention for Long-Context Inference in Large Language Models
Efficient Low Rank Attention for Long-Context Inference in Large Language Models
文章标题: 在大型语言模型中通过高效的低秩注意力实现长上下文推理
作者: Guoxu Zhou, Tenghui Li, Yuning Qiu, Qibin Zhao, Xuyang Zhao
机构: 广东工业大学, RIKEN AIP, RIKEN iTHEMS, RIKEN IMS, 千叶大学
A1 主要贡献
随着输入文本长度的增加,大型语言模型(LLM)中的键值(KV)缓存带来了过高的GPU内存成本,限制了在资源受限设备上的长上下文推理。现有的方法,如KV量化和剪枝,虽然减少了内存使用,但会遭受数值精度损失或键值对保留不佳的问题。
核心问题:LLM在处理长上下文时,KV缓存会线性增长,导致内存消耗过大,成为系统瓶颈。现有方法在解决此问题时,存在数值精度损失、关键信息丢失或高延迟等弊端。
研究目标:提出一种能够平衡内存效率、数值精度和计算延迟的理想解决方案,以高效处理长上下文推理。
创新点:本文引入了低秩查询和键注意力(Low Rank Query and Key attention, LRQK),这是一个两阶段框架。
1. 联合低秩近似:与ShadowKV【16, ShadowKV: KV cache in shadows for high-throughput long-context LLM inference, 2025, ICML, https://openreview.net/forum?id=oa7MYAO6h6】对RoPE旋转位置编码前的键矩阵进行昂贵的SVD操作不同,LRQK在预填充(prefill)阶段将全精度的查询(Query)和键(Key)矩阵联合分解为紧凑的秩-r因子。这种方法共同优化了Q和K矩阵的低秩近似,降低了计算复杂性,同时保持了表示的准确性 。
2. 精度保持的注意力计算:在每个解码(decode)步骤中,使用这些低维投影以$O(lr)$的时间复杂度计算代理(proxy)注意力分数。通过仅选择得分最高的top-k个token和一小部分固定的近期token,LRQK利用这些低秩近似的K和V作为计算代理,用于轻量级的注意力分数估计。关键的是,后续的注意力层操作使用原始的查询、键和值向量,没有任何近似或重构,从而保证了数学上的保真度和模型性能。
3. 混合缓存管理:实现了一个复杂的GPU-CPU混合存储系统,该系统具有活跃token保留机制和命中/未命中(hit/miss)缓冲架构,最大限度地减少了跨设备的数据传输。此外,一个专门的缓冲区维护最近访问的键和值,经验分析证实这些键和值始终获得高注意力分数,从而进一步优化了操作条件下的缓存命中率。
通过这种方式,LRQK仅传输缺失的全精度KV对,从而在保留精确注意力输出的同时,减少了CPU-GPU的数据移动。
A3 背景知识与关键观察
3.1 查询和键的低秩性
查询与键矩阵的低秩结构。正如先前工作【16, ShadowKV: KV cache in shadows for high-throughput long-context LLM inference, 2025, ICML, https://openreview.net/forum?id=oa7MYAO6h6】和 【14, InfiniGen: Efficient generative inference of large language models with dynamic KV cache management, 2024, USENIX Symposium on Operating Systems Design and Implementation (OSDI 24)】所证明的,解码器-仅(decoder-only)Transformer中的键矩阵表现出显著的低秩结构。基于这一观察,本文将全精度查询和键矩阵联合分解为秩-r的组件,这遵循了矩阵秩的不等式:
这个性质意味着,如果矩阵K的有效秩为r,那么它与任何查询矩阵Q的交互作用也接近一个r维子空间。因此,可以通过低秩分解$A_QA_K^T$来近似$QK^T$,且近似误差很小。
奇异值衰减的实验证据。图3展示了这一现象的一个示例,该图绘制了在Wikitext-2-v1测试集【25, Pointer sentinel mixture models, 2016, CoRR, http://arxiv.org/abs/1609.07843】上,Qwen2.5-7B 【2, Qwen2.5: A party of foundation models, 2024, https://qwenlm.github.io/blog/qwen2.5/】和LLaMA-3-8B-1M 【3, The Llama 3 herd of models, 2024, CoRR, https://doi.org/10.48550/arXiv.2407.21783】 【24, Llama 3 gradient: A series of long context models, 2024, https://gradient.ai/blog/scaling-rotational-embeddings-for-long-context-language-models】模型中每个注意力头的查询和键矩阵K的平均奇异值谱。在这两种情况下,奇异值在超过一个很小的秩之后迅速衰减,证实了K矩阵可以用很小的损失进行低秩近似 。
3.2 邻近当前Token的注意力分数
注意力权重的局部性。为了量化解码过程中自注意力的局部性,令$q_t$为步骤t的查询,ki为位置i的键,则注意力权重由以下公式给出:
本文使用Qwen2.5-7B和LLaMA-3-8B-1M在Wikitext-2-v1测试集上评估了每个头的平均分数。
近期偏置的经验性发现。如图4所示,在所有情况下,曲线在当前token及其近邻处都表现出显著更高的注意力分数,证实了强烈的近期偏置(recency bias)。这一经验性发现也被StreamingLLM【11, Efficient streaming language models with attention sinks, 2024, International Conference on Representation Learning】观察到,它启发了本文在GPU缓存中包含一个紧凑的近期缓冲区(recency buffer),该缓冲区持续保留最近的token,从而以最小的开销最大化缓存命中率。
A2 方法细节
LLM的推理过程可分为两个阶段:预填充(prefill)和解码(decode)。当用户提供提示时,LLM首先对长提示进行一次预填充,然后以自回归方式解码token。这两个阶段需要特定的缓存机制,因此将在4.2节和4.3节分别介绍用于预填充和解码的LRQK方法。
符号说明。首先澄清方法中使用的初步知识和公式。在实际实现中,Transformer模型通常处理形状为(批量大小,头数,序列长度,隐藏大小)的输入张量。然而,当使用GQA【26, GQA: Training generalized multi-query transformer models from multi-head checkpoints, 2023, Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing】等方法时,查询张量的注意力头数通常与键和值张量的不同。因此,本文使用重复键/值(repeat KV)机制,确保查询和键/值张量的头数对齐。此外,为简化符号,本文也省略了‘批量大小’和‘头数’维度,因为不同批次和头之间的计算是独立的。因此,查询、键和值矩阵表示为$Q, K, V \in \mathbb{R}^{l \times d}$,其中l表示序列长度,d表示每个头的隐藏大小。此外,由于预填充和解码过程在所有Transformer层中都是相同的,为了清晰起见,本文在方法中省略了层索引。
4.1 混合缓存
混合缓存策略与命中-未命中机制。为了解决内存限制并优化运行时效率,我们采用了一种混合缓存策略,利用了速度快但容量有限的GPU KV缓存和容量大但速度较慢的CPU缓存。为了进一步减少GPU和CPU之间的数据传输开销,我们引入了一种命中-未命中(hit-and-miss)缓存机制。在此方案中,如果所需的token已存在于GPU缓存中(即缓存命中),则直接访问它们,不会产生额外的传输成本。如果所需的缓存不在GPU中,则只从CPU缓存中获取缺失的token,从而避免了冗余的数据移动,提高了整体系统效率。
GPU缓存结构。认识到最近的token更有可能获得高注意力分数(如图4所示),我们进一步整合了一个专用缓冲区来存储最近的token。GPU KV缓存的结构为[活跃token | 精简token],其中‘活跃token’对应于由我们的查询和键矩阵的低秩近似选择出的、具有最高近似注意力分数的top-k个token集合$\Omega_k$;而‘精简token’(lite tokens)则对应于一个固定数量的相邻近期token集合$\Omega_l$。因此,GPU KV缓存的总窗口大小是活跃token和精简token数量之和,从而在缓存构成中实现了相关性和近期性的平衡。
4.2 预填充
输入与矩阵表示。在此阶段,输入是一个长提示,表示为矩阵$X \in \mathbb{R}^{l \times d}$,其中l是序列长度,d是隐藏维度。通过对X应用线性投影,我们得到查询、键和值矩阵:$Q, K, V \in \mathbb{R}^{l \times d}$。
联合低秩分解。根据公式(1),注意力分数矩阵通常是低秩的,因此可以使用联合低秩分解来近似:
其中$A_Q, A_K \in \mathbb{R}^{l \times r}$且$r \ll d$表示矩阵的秩。
查询与键矩阵的近似。除了联合低秩近似,为了提高内存效率,我们还对查询和键矩阵进行近似:
其中$B_Q, B_K \in \mathbb{R}^{r \times d}$。注意,注意力分数$QK^T$和$Q, K$的近似共享由$A_Q$和$A_K$张成的相同列/行空间。
优化问题。对于求解公式(2)和(3),我们可以将其表述为:
拉格朗日松弛。为了求解约束优化问题(4),我们通过推导其拉格朗日形式来放松它:
其中$\lambda_{pQ}$和$\lambda_{pK}$是两个缩放因子,分别表示查询和键矩阵低秩近似的重要性。
求解过程。通过求解拉格朗日函数(5),将偏导数设为零,我们可以得到查询和键矩阵的解(详见附录A.1):
计算成本分析。从公式(6)中可以看出,需要进行矩阵求逆操作。矩阵求逆的开销是$O(r^3)$。然而,由于秩r是一个小数,并且$A_K^TA_K$, $A_Q^TA_Q$, $B_KB_K^T$, $B_QB_Q^T$都是形状为(r, r)的小矩阵,因此矩阵求逆的计算成本相对较低。此外,在求解$A_Q$和$A_K$时,可以通过改变乘法顺序来进一步降低内存成本。例如,将操作顺序从$(QK^T)A_K$更改为$Q(K^TA_K)$,我们将计算复杂度从$O(l^2d)$降低到$O(rld)$。
迭代过程。算法1中的迭代过程可以解释为块坐标下降(Block Coordinate Descent, BCD)方法,其中每一步都针对变量的一个子集优化拉格朗日函数,同时保持其他变量固定。
算法1 LRQK的预填充阶段,用于AQ, AK, BQ, BK的交替更新
需要: $Q, K$, 秩 $r$, $\lambda_{pQ}$, $\lambda_{pK}$
1: 初始化 $A_Q, A_K \sim N(0, 1)$
2: while $i < \text{max\_iteration}$ 或未收敛 do
3: 更新 $B_Q \leftarrow (A_Q^T A_Q)^{-1} A_Q^T Q$
4: 更新 $B_K \leftarrow (A_K^T A_K)^{-1} A_K^T K$
5: 更新 $A_K \leftarrow K(Q^T A_Q + \lambda_{pK} B_K^T) (A_Q^T A_Q + \lambda_{pK} B_K B_K^T)^{-1}$
6: 更新 $A_Q \leftarrow Q(K^T A_K + \lambda_{pQ} B_Q^T) (A_K^T A_K + \lambda_{pQ} B_Q B_Q^T)^{-1}$
7: end while
8: return $A_Q, A_K, B_Q, B_K$
4.3 解码
解码阶段的目标。在预填充阶段之后,语言模型按顺序生成剩余的token,遵循自回归解码过程,其中每个token的预测都基于先前生成的上下文。假设当前阶段是t,输入是一个行向量$x_t \in \mathbb{R}^{1 \times d}$,其投影为$q_t \in \mathbb{R}^{1 \times d}$和$k_t \in \mathbb{R}^{1 \times d}$。在解码阶段,目标是将$q_t$和$k_t$压缩为更小的$\hat{q}_t \in \mathbb{R}^{1 \times r}$和$\hat{k}_t \in \mathbb{R}^{1 \times r}$,从而大大减少KV缓存的内存占用。
解码阶段的优化问题。从前一阶段,我们获得了查询和键矩阵的低秩近似,形式为左右因子:$A_Q, A_K \in \mathbb{R}^{l \times r}$,和$B_Q, B_K \in \mathbb{R}^{r \times d}$。为了在新token的建模设置中保持一致性,我们为估计近似的查询和键向量$\hat{q}_t$和$\hat{k}_t$制定了以下优化问题:
矩阵$A_{K,\Omega,t-1}$表示$A_K$的一个子矩阵,仅包含由$\Omega$索引的行。类似地,$K_{\Omega,t-1}$是由选择$\Omega$索引的行形成的$K$的相应子矩阵。这里,$\Omega$表示在时间步$t-1$时保留在GPU内存中的KV缓存的索引集。
拉格朗日公式。优化问题(7)的拉格朗日函数写为:
其中$\lambda_{d1}$和$\lambda_{d2}$是正则化参数。更多细节请参考附录A.2。
更新规则。通过求解拉格朗日函数(8),我们获得了$\hat{q}_t$和$\hat{k}_t$的更新规则:
与预填充阶段类似,由于需要求逆的矩阵尺寸为(r, r),这是一个小矩阵,所以计算成本相对较小。
B_Q,t和B_K,t的更新。对于$B_{Q,t}$和$B_{K,t}$,应用一步梯度下降来优化目标函数。更新规则由下式给出:
其中$\eta_Q$和$\eta_K$分别是$B_{Q,t-1}$和$B_{K,t-1}$的学习率。
学习率的计算。两个学习率$\eta_Q^*$和$\eta_K^*$的计算公式如下:
解码算法总结。算法2总结了$\hat{q}_t, \hat{k}_t, B_{Q,t}$和$B_{K,t}$的更新规则,代理注意力的计算,以及所选键和值的更新。LRQK的完整解码过程的图形表示如图1所示。为了避免随机初始化,我们通过设置$\lambda_{d1}=0$来使用$\hat{k}_t$的初始猜测。算法迭代直到收敛或达到最大迭代次数。注意力是用选定的键和值计算的,$Attention(q_t, K_{\Omega,t}, V_{\Omega,t})$。
算法2
需要: $q_t, k_t, v_t, B_{Q,t-1}, B_{K,t-1}, \lambda_{d1}, \lambda_{d2}, A_{K,\Omega,t-1}, K_{\Omega,t-1}, V_{\Omega,t-1}$
1: 初始猜测 $\hat{k}_t \leftarrow k_t B_{K,t-1}^T (B_{K,t-1} B_{K,t-1}^T)^{-1}$
2: while $i < \text{max\_iteration}$ 或未收敛 do
3: 更新 $\hat{q}_t \leftarrow m_{lq,t} M_{rq,t}^{-1}$
4: 更新 $\hat{k}_t \leftarrow (k_t B_{K,t-1}^T + \lambda_{d1} q_t k_t^T \hat{q}_t) (B_{K,t-1} B_{K,t-1}^T + \lambda_{d1} \hat{q}_t^T \hat{q}_t)^{-1}$
5: end while
6: $B_{Q,t} \leftarrow B_{Q,t-1} - \eta_Q^* \nabla B_{Q,t-1}$,其中$\eta_Q^*$如公式(11)
7: $B_{K,t} \leftarrow B_{K,t-1} - \eta_K^* \nabla B_{K,t-1}$,其中$\eta_K^*$如公式(12)
8: $A_{K,t} \leftarrow [A_{K,t-1}; \hat{k}_t]$ ▷ 拼接
9: top-k索引 $\Omega_k \leftarrow \text{top}(A_{K,t} \hat{q}_t^T, k)$ ▷ 计算代理注意力
10: 根据$\Omega_k$从CPU K, V中获取缺失的$\{k_i\}, \{v_i\}$
11: $K'_{\Omega,t-1}, V'_{\Omega,t-1} \leftarrow$ 将获取的$\{k_i\}, \{v_i\}$与GPU $K_{\Omega,t-1}, V_{\Omega,t-1}$合并
12: $K_{\Omega,t} \leftarrow [K'_{\Omega,t-1}; k_t]$, $V_{\Omega,t} \leftarrow [V'_{\Omega,t-1}; v_t]$
13: 异步将$k_t, v_t$传输到CPU
14: return $K_{\Omega,t}, V_{\Omega,t}$.
A4 实验环境与结果
实验环境
-
硬件配置:
- GPU: NVIDIA A100 GPU (80GB 和 40GB 内存版本), NVIDIA A6000 (48GB), NVIDIA GeForce RTX 3090 (24GB)。
- CPU: AMD EPYC 7742 64核处理器。
-
软件配置:
- 评估框架: OpenCompass 【27, Opencompass: A universal evaluation platform for foundation models, 2023, https://github.com/open-compass/opencompass】 。
- 推理框架: Hugging Face transformers 库。
-
模型:
- LLaMA-3-8B-1M 【24, Llama 3 gradient: A series of long context models, 2024, https://gradient.ai/blog/scaling-rotational-embeddings-for-long-context-language-models 】
- Qwen2.5-7B-Instruct 【2, Qwen2.5: A party of foundation models, 2024, https://qwenlm.github.io/blog/qwen2.5/ 】
- Mistral-7B-Instruct-v0.3
- Phi-3-mini-128k-instruct
- Qwen2.5-14B-Instruct
- Qwen2.5-32B-Instruct
- Llama-3.1-8B-Instruct
-
数据集:
- RULER 【28, Ruler: What’s the real context size of your long-context language models?, 2024, arXiv preprint arXiv:2404.06654】: 用于评估长上下文(4K到128K tokens)的准确性。
- LongBench: 用于评估长上下文任务的准确性。
- Wikitext-2-v1 【25, Pointer sentinel mixture models, 2016, CoRR, http://arxiv.org/abs/1609.07843 】: 用于评估缓存未命中率的摘要任务。
-
实验参数:
- 算法迭代: 最大迭代次数为2,容忍度为0.01。
- 缩放参数: $\lambda_{pQ} = \lambda_{pK} = \lambda_{d1} = \lambda_{d2} = 1$。
- LRQK参数: 秩r、top-k和lite token数量根据不同实验进行设置。
实验结果
5.1 准确性评估
长上下文基准测试。在RULER 128K和LongBench数据集上,将LRQK与四种最新的动态稀疏注意力基线方法(Loki, InfiniGen, Quest, ShadowKV)进行了比较。如表1所示,LRQK在多个任务上表现出竞争力。在RULER基准上,LRQK在S2任务上达到100%准确率,在QA-1和QA-2任务上超越所有基线,甚至在MQ和MV任务上超过了原始的LLaMA-3-8B-1M模型。在LongBench上,LRQK在PRetr和LCC任务上取得了最佳性能,显示了其处理长距离依赖关系的能力。
表1:不同模型和方法在RULER(左)和LongBench(右)上的比较。
跨模型性能。在RULER-4K子集上,对LLaMA-3-8B-1M和Qwen2.5-7B-Instruct两个模型进行了评估(rank r=16, top-k=256, lite tokens=16)。如表2所示,LRQK在两个模型上都保持了良好的性能。
表2:两个模型在RULER-4K子集上的结果
5.2 秩r和Top-k选择的影响
超参数敏感性分析。通过在RULER-4K子集上进行实验,探究了秩r和top-k数量对模型性能的影响。
- 秩(Rank)的影响:如表3所示,增加秩r通常能提高QA-1和VT任务的性能,但QA-2在较低的秩r=16时表现最佳,这表明某些任务可能受益于更紧凑的表示。
- Top-k的影响:如表4所示,更大的top-k值能持续提升所有任务的准确率。k=1024的配置得分最高,但更大的k会增加计算和内存成本,需要在性能和资源之间进行权衡。
表3:不同秩值下的准确率(top-k=256,lite tokens=16)
表4:不同top-k值下的准确率(r=8,lite tokens=16)
5.3 缓存未命中率
内存管理效率评估。通过计算未命中率(miss rate),即必须从CPU传输到GPU的KV缓存行数与所有选定索引总数的比率($c_{miss}/c_{total}$),来评估内存管理策略的效率。实验在wikitext-2-v1测试集的摘要任务上进行。如图5所示,未命中率的分布近似于高斯分布,平均未命中率约为0.40。这意味着,所提出的命中-未命中机制平均减少了约60%的CPU到GPU的数据传输量。
A5 结论
本文提出了一种名为低秩查询和键(LRQK)注意力的两阶段推理算法,该算法通过将查询和键矩阵的联合低秩分解与混合GPU-CPU缓存管理相结合,实现了长上下文处理。通过在紧凑的秩-r因子上计算代理注意力分数,并有选择地仅获取top-k活跃token和最近token的全精度键值对,LRQK在保留精确注意力输出的同时,减少了CPU-GPU的数据传输。
在RULER(最高128K tokens)和LongBench基准上的广泛评估表明,LRQK在各种长上下文任务中达到或超过了当前最先进的稀疏注意力方法。该方法实现了显著的内存节省,使得处理那些否则会导致内存不足错误的上下文成为可能,且没有明显的准确性下降。消融研究证实了每个组件的有效性:低秩近似、活跃token选择和近期token保留都对该方法的稳健性能做出了贡献。
局限性:尽管LRQK显著减少了数据传输开销,但进一步观察发现,CPU端的索引操作构成了主要的性能瓶颈,而非PCIe带宽限制,希望能在未来的工作中解决。此外,LRQK的超参数需要针对特定任务进行调整。虽然我们提供了实用的指导方针和默认配置,但最佳设置仍需在新领域进行经验验证。
社会影响:LRQK降低的推理成本可以使更广泛的人群接触到长上下文语言模型,从而有益于研究和教育。然而,这种效率也可能降低部署有害应用(如大规模虚假信息宣传)的门槛。我们相信,在有适当的保障措施和负责任的部署实践的情况下,普及先进AI能力将带来净效益。
A6 附录
A. 预填充和解码的导数细节
A.1 预填充阶段的导数
拉格朗日函数。如正文公式(5)所示,预填充阶段的拉格朗日函数为:
偏导数计算。计算$L_{pre}$关于$A_Q, A_K, B_Q, B_K$的偏导数:
求解。将偏导数设为零:
求解这些方程,可以得到$A_Q, A_K, B_Q, B_K$的解析解:
A.2 解码阶段的导数
拉格朗日函数。如正文公式(8)所示,解码阶段的拉格朗日函数为:
偏导数计算。计算$L_{dec}$关于$\hat{q}_t$和$\hat{k}_t$的偏导数:
求解。将偏导数设为零,可以推导出$\hat{k}_t$的更新公式:
BQ,t-1 和 BK,t-1 的更新。首先计算关于$B_{Q,t-1}$和$B_{K,t-1}$的偏导数:
由于行向量的外积$\hat{q}_t^T\hat{q}_t \in \mathbb{R}^{r \times r}$是一个秩为1的矩阵,不可逆,因此$B_{Q,t-1}$和$B_{K,t-1}$没有闭式解。
梯度下降更新。我们使用梯度下降来更新$B_{Q,t-1}$和$B_{K,t-1}$。$L_{dec}$关于它们的梯度为:
梯度下降的更新规则为:
B. 附加实验结果
运行细节。由于PyTorch中的求逆计算以float32进行,因此在执行算法1和算法2之前,参数会临时转换为float32。模型在推理时使用bfloat16。CPU为AMD EPYC 7742 64核处理器。RULER 128K实验在单张NVIDIA A100 80GB GPU上运行,LongBench实验在单张A100 40GB GPU上运行,批处理大小为1。更多结果见表5。
表5:不同模型和方法在RULER 128K(左;top-2048, lite tokens=64)和LongBench(右;rank r = 16, lite tokens=64)上的比较。
在Llama和Qwen上的更多结果。在Llama-3-8B-1M和Qwen2.5-7B两个不同模型上,使用LRQK(rank=16, top-256 active tokens, 16 lite tokens)的更多结果如表6所示。实验在单张A100 40G GPU上运行。
表6:两个模型在RULER-4K/8K/16K子集上的更多结果。
在Mistral和Phi-3上的结果。在另外两个模型'mistralai/Mistral-7B-Instruct-v0.3'(Mistral)和'microsoft/Phi-3-mini-128k-instruct'(Phi-3)上,使用默认LRQK超参数(rank=32, top-2048, 64 lite tokens)进行了评估。由于NVIDIA A6000 GPU的48GB内存限制,Mistral在RULER 32K上评估,Phi-3-mini在RULER 16K上评估。表7总结了结果,表明LRQK能有效推广到不同架构的模型。
表7:Mistral和Phi-3在不同上下文长度的RULER基准上的评估结果。
在更大型号上的结果。为了评估LRQK对更大型号的可扩展性,在'Qwen/Qwen2.5-14B-Instruct'(Qwen 14B,64K上下文)和'Qwen/Qwen2.5-32B-Instruct'(Qwen 32B,16K上下文)上进行了额外实验,均在A100 80G GPU上进行。所有实验使用默认LRQK超参数。表8显示,对于Qwen 14B,LRQK在检索任务上有显著提升;对于Qwen 32B,基线模型已接近完美,LRQK保持了相当的结果,表明该方法不会降低具有强大原生能力的模型的性能。
表8:在RULER基准上,Qwen2.5-14B-Instruct(64K上下文)和Qwen2.5-32B-Instruct(16K上下文)使用和不使用LRQK的性能比较。'+LRQK'列表示使用LRQK的模型性能。
与KVQuant的结合结果。为了研究LRQK是否与基于量化的压缩方法兼容,我们评估了LRQK与KVQuant【6, Kvquant: Towards 10 million context length LLM inference with KVcache quantization, 2024, Advances in Neural Information Processing Systems】的组合。实验在Llama-3.1-8B和Qwen2.5-7B上进行,使用RULER 32K基准。LRQK超参数设置为r=32, top-k=2048, 64 lite tokens。表9显示,LRQK可以与量化方法有效结合,而不会导致灾难性的性能损失,表明两者在KV缓存压缩方面是互补的。
表9:在RULER 32K上,单独应用KVQuant与结合LRQK的性能比较。'+KVQuant'列表示仅应用KVQuant,'++LRQK'列表示同时应用KVQuant和LRQK。
初始化策略的影响。本文研究了低秩因子$A_Q \in \mathbb{R}^{l \times r}$和$A_K \in \mathbb{R}^{l \times r}$的三种初始化策略。
1. 随机高斯初始化 (randn): $A_Q, A_K \sim N(0,1)$。
2. 独立Top-r选择 (top): 独立计算Q和K中每个维度的L1范数作为重要性分数,并选择最重要的r个维度来初始化$A_Q$和$A_K$。
3. 联合Top-r选择 (topcol): 基于Q和K的联合重要性选择维度,确保$A_Q$和$A_K$共享相同的列索引。
实验结果。在RULER 16K基准上对这三种策略进行了评估。如表10所示,所有三种初始化策略在大多数任务上都取得了几乎相同的性能,表明算法1对低秩因子的初始化具有鲁棒性。鉴于此,推荐使用计算效率更高的randn初始化作为默认选项。
表10:AQ和AK初始化策略在RULER 16K上的性能比较。尽管设计理念不同,但所有三种方法都取得了相当的结果。
C. 运行时性能分析
C.1 吞吐量结果
性能对比实验。在Llama-3.1-8B-Instruct-1M上,将LRQK与标准GPU-only和CPU offloading方法进行运行时性能比较。实验在单张NVIDIA A100 40GB GPU上进行,批处理大小为1,上下文长度从4K到64K。如表11所示,GPU-only方法在短上下文(4K-32K)中吞吐量最高,但在64K时出现内存不足(OOM)错误。CPU offload虽然能处理更长上下文,但在解码阶段吞吐量严重下降。相比之下,LRQK default在所有上下文长度上保持了相对稳定的解码吞吐量,这得益于其选择性传输机制。LRQK no hit/miss的性能下降则凸显了hit/miss缓冲优化的重要性。
表11:'meta-llama/Llama-3.1-8B-Instruct-1M'在RULER QA-2上不同上下文长度的吞吐量比较(tokens/s)。P:预填充阶段;D:解码阶段。LRQK default在64K tokens时保持了稳定的解码性能并避免了OOM。
C.2 与基线方法的比较
与ShadowKV的对比。在一个文本摘要任务上(LongBench 32K上下文),将LRQK与vanilla attention和ShadowKV[16]进行了比较。实验在NVIDIA GeForce RTX 3090 (24 GB) GPU上进行。如表12所示,Vanilla attention吞吐量最高但需要2个GPU。LRQK变体的吞吐量在489到646 tokens/s之间,实现了内存效率和速度的权衡。LRQK的GPU功耗低于Vanilla和ShadowKV,表明该方法未充分利用GPU资源,因为需要等待CPU的数据索引,这是该方法的一个局限性。
表12:在NVIDIA GeForce RTX 3090(24 GB,250W)上进行文本摘要(LongBench,32K上下文)的性能比较。Vanilla attention使用2个GPU;其他方法使用1个GPU。Tokens/s通过(所有token数/总时间)计算。
D. 超参数选择指南
默认配置。我们推荐从以下默认配置开始,该配置在多种设置下均表现出色:秩$r=32$,活跃token $top-k=2048$,lite tokens 64,迭代次数2,容忍度$10^{-2}$。
秩 r。秩r控制每个注意力头的低秩因子$A_Q, A_K$的维度。在实践中,我们发现$r \in \{8, 16, 32, 48\}$可以在近似质量和计算成本之间取得良好平衡。对于大多数应用,$r=32$是一个有效的折衷。
Top-k 活跃 Tokens。top-k参数决定了保留在GPU内存中的高注意力token数量。最优的k随上下文长度扩展:
- 短上下文 (≤ 4K): $k=256$
- 中等上下文 (8K-16K): $k=512$ 至 $1024$
- 长上下文 (≥ 32K): $k=2048$
Lite Tokens。该参数指定了始终保留在GPU内存中的近期token数量,以捕获局部上下文。推荐选择$\{16, 32, 64\}$,对整体计算成本影响很小。我们推荐64作为保守的默认值。
收敛参数。
- 迭代次数: 通常2或3次迭代即可收敛。
- 容忍度: 收敛阈值,定义为连续低秩矩阵之间的均方误差。$10^{-2}$或$10^{-3}$在实践中效果良好。
调优策略。
1. 从默认配置开始。
2. 如果内存有限,按上下文长度比例减少k或降低r。
3. 如果准确性不足,增加k(直至硬件限制)或r。
实现与集成开销。相对于标准注意力,LRQK引入了适度的计算开销:
- 预填充阶段: 计算低秩投影的额外成本。
- 解码阶段: 低秩矩阵的增量更新增加了计算,但对更少token的注意力计算降低了总成本。
- 缓存管理: CPU-GPU传输和缓存查找操作增加了延迟,但通过仅传输k个token而非整个KV缓存来缓解。
💬 评论讨论
欢迎在这里分享您的想法和见解!