FlashTensor: Optimizing Tensor Programs by Leveraging Fine-grained Tensor Property
- 文章标题: FlashTensor: 通过利用细粒度张量属性优化张量程序
- 作者: Runxin Zhong, Yuyang Jin, Chen Zhang, Kinman Lei, Shuangyu Li, Jidong Zhai
- 机构: 清华大学, 清程.AI
A1 主要贡献
核心问题: 深度神经网络(DNNs)在处理长上下文任务时会产生巨大的中间张量,导致严重的内存开销。现有的DNN优化方法由于对张量属性的感知不足,无法有效进行内存优化,并可能导致计算效率低下。具体来说,当一个巨大的中间张量被多个具有不同归约(reduction)维度的操作符使用时,现有的融合策略会因复杂的依赖关系而导致并行度降低,性能严重下降。
研究目标: 本文旨在提出一个名为FlashTensor的DNN优化系统,通过利用细粒度的张量属性来减少内存开销并提高推理性能。
创新点与主要贡献:
* 总结了四种关键张量属性: 本文总结了四种对优化至关重要的张量属性,分别是归约依赖(reduce dependency)、广播性(broadcastability)、尺寸(size)和值(value)。
* 提出了张量属性识别器 (Tensor Property Identifier): 该模块系统地分析整个计算图,并捕获每个张量的细粒度属性。
* 提出了张量属性感知优化 (Tensor Property-Aware Optimization): 该模块基于属性感知的转换规则和核映射策略来搜索最优的计算核(kernels),以实现高计算效率和低内存访问开销。
* 设计并实现了FlashTensor系统: 这是一个利用细粒度张量属性优化张量程序的系统。实验表明,与八个最先进的工作相比,FlashTensor在H100上平均取得了1.50倍的端到端加速和3.24倍的核心模块性能加速(在A100上分别为1.86倍和3.70倍)。
A3 背景知识与关键观察
动机
近期的一个研究趋势是使用新的注意力变体来减少Vanilla Attention的巨大计算量。例如,像H2O【45,H2O: heavy-hitter oracle for efficient generative inference of large language models,2024,NIPS】,RoCo【28,On the Efficacy of Eviction Policy for Key-Value Constrained Generative Language Model Inference,2024,arXiv】和Keyformer【1,Keyformer: KV Cache reduction through key tokens selection for Efficient Generative Inference,2024,MLSys】这样的模型,可以通过丢弃一些不重要的令牌来减少总计算量。以图2中展示的H2O为例,它在SoftMax(由Exp、Reduce 0和Div组成)之后使用了一个额外的Reduce算子(Reduce 1)来计算令牌的重要性。随后,一个TopK算子和一个Gather算子一起选择并缓存最重要的令牌,同时丢弃其余的。
在H2O的推理过程中,包含预填充(prefill)和解码(decode)两个阶段。预填充阶段逐块处理输入提示,选择并缓存重要令牌,尤其是在输入超过缓存容量时。解码阶段顺序生成输出令牌,并进一步更新缓存。因此,增加的算子会影响预填充和解码阶段的性能。然而,对于长上下文的文档摘要等任务,预填充阶段是主要的性能瓶颈。例如,在InfiniteBench【44,∞Bench: Extending Long Context Evaluation Beyond 100K Tokens,2024,ACL】中,提示可以达到442K个令牌,而生成的令牌只有0.7K个,相差631倍,导致H2O中预填充与解码的执行时间比为4.51。主要瓶颈来自于预填充阶段核心模块中创建的巨大张量,如图2所示。这个尺寸为$O(seqlen^2)$的张量由MatMul 0产生,带来了巨大的内存访问开销。即使经过TensorRT【36,NVIDIA TensorRT,2017,https://developer.nvidia.com/tensorrt】优化,H2O的核心模块仍占总预填充时间的约57.62%,但仅达到10.77 TFLOP/s(A100 F16 TensorCore峰值性能的3.45%)。
关键挑战在于两个归约操作,MatMul 1和Reduce 1,它们对Div的输出张量(称为DivOut)具有不同的归约维度。为了提高效率,DivOut的每一行必须分配给同一个并行单元以进行MatMul 1的计算,而DivOut的每一列也必须驻留在单个并行单元内。因此,除了使用单个并行单元外,没有其他可行的分区策略能够满足这些约束。结果是,像TensorRT这样的现有方法必须通过缓慢的全局内存跨多个核来处理这个大张量。此外,这两个具有不同归约维度的归约算子是H2O和Vanilla Attention之间的主要区别,这使得FlashAttention【10,FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning,2023,arXiv】【11,FLASHATTENTION: fast and memory-efficient exact attention with IO-awareness,2024,NIPS】【30,FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision,2024,arXiv】由于这些结构差异而难以优化。这种局限性源于它们缺乏对细粒度张量属性(如每个维度上的归约依赖)的感知,从而错失了将MatMul 1与前面的算子融合以避免访问大张量的机会。
关键观察
FlashTensor通过利用细粒度的张量属性来减少内存开销。在介绍我们的系统之前,我们首先介绍对关键张量属性的观察。这里我们列出一些细粒度属性,并使用H2O来说明它们为何重要。
- 归约依赖(Reduce dependency)。归约依赖指的是张量的某些维度是否被聚合或归约。分析算子的潜在并行机会非常重要。如前所述,在图2中,Div的输出张量具有来自两个方向的归约依赖。Reduce 1禁止在其行维度上进行SM间的并行化,而MatMul 1禁止沿其列维度进行并行化,使得整个计算图难以并行化。识别这种归约依赖对于避免次优的并行化模式和寻找最优策略至关重要。
- 广播(Broadcast)。广播指的是张量维度被扩展或广播以匹配另一个张量的形状。例如,图3中的Mul算子隐式地广播其右侧操作数的列维度。何时执行广播对总计算量很敏感。当重新排序后,示例中的乘法量从$S \times S' \times h$减少到$S \times S' \times h$。然而,这种重新排序并不总是能得到正确的结果。验证此类转换需要深入分析归约轴和广播轴之间的关系,这将在4.1节中进一步讨论。
图3. H2O中的广播示例。重新排序MatMul 0和Mul可以减小中间张量的尺寸。 - 尺寸(Size)。尺寸是张量的内在属性,表示其包含的元素总数。它与内存访问量密切相关,减少内存访问开销的关键策略之一就是最小化它。
- 值(Value)。值属性可以帮助减少不必要的计算和内存访问。例如,当意识到解码器-Transformer中通常使用的三角掩码,并结合其上算子的含义时,我们可以跳过大多数张量的一半计算。
系统概览
基于以上观察,我们提出了具备这些细粒度张量属性感知的FlashTensor。FlashTensor专注于优化预填充阶段,这是长上下文任务的瓶颈,如动机部分所讨论。FlashTensor的概览如图4所示,它由两个主要模块组成:张量属性识别器(Tensor Property Identifier)和张量属性感知优化(Tensor Property-Aware Optimization)。
1. 首先,张量属性识别器模块将一个以计算图(节点表示算子,边表示张量)形式表示的张量程序作为输入,并捕获计算图中每个张量的所有细粒度属性,包括归约依赖、广播、尺寸和值(第4节)。
2. 然后,张量属性感知优化模块基于图转换规则和核映射策略搜索最优计划。图转换规则旨在通过特定的限制来减小中间张量的尺寸,而核映射策略则通过考虑内存访问、计算强度以及并行性来选择高效的候选核。最终生成优化的张量程序,准备作为高效代码执行(第5节)。
A2 方法细节
张量属性识别器
在本节中,我们首先正式定义我们观察到的对后续分析和优化至关重要的属性,然后介绍FlashTensor如何在计算图中识别这些属性。
属性定义
如表1所示,张量属性被分为两大类:1) 逐维度属性 (Per-Dimension Property) 关注张量单个维度的特定属性,提供对归约依赖和广播性等方面的洞察。2) 整体张量属性 (Entire Tensor Property) 包含描述张量整体的属性,如其总尺寸和可能的常量值。
-
归约依赖 (Reduce Dependency)。这是一个多值属性,描述了张量维度与归约操作之间的相互依赖关系,对优化计算效率至关重要。它有三个值:
- NonPara:指由于归约操作的数据依赖性而无法进行分区并行执行的维度。这些维度需要由单个单元处理,从而阻止了并行化。例如,在图5(a)中,输入和输出的归约维度都被分类为NonPara,因为它们需要在整个维度上进行数据聚合。
- Reuse:指可以并行化同时通过分块(tiling)实现数据复用的维度。例如,在图5(b)中,MatMul中的行维度可以并行化,通过分块可以重用右侧操作数,从而提高内存效率。这种方法也适用于支持广播的操作,如Add和Mul,其中分块后并行化可以平衡计算效率和内存访问。
- Batch:代表可以完全并行化而无需数据复用的维度,例如图5(b)中批处理矩阵乘法(Batched Matrix Multiplication)的批处理维度,以及图5(a)中Reduce算子除归约维度之外的维度。
图5. 归约依赖属性
-
广播 (Broadcast)。这是一个属性,指示张量维度是否会被后续算子广播。如果张量是某个广播支持算子在该维度上的操作数,则该维度被认为是Broadcasted。图6(a)(b)展示了维度如何被扩展以满足算子形状要求的示例。
图6. 广播属性 -
尺寸 (Size)。它表示张量中的元素总数,计算为所有维度的乘积,如图7所示。
图7. 尺寸和值属性 -
值 (Value)。它指示张量在计算过程中是否保持一个常量值,如图7所示。与传统编译器将常量视为标量不同,FlashTensor在张量级别处理它,从而实现了更广泛的优化机会。具有常量值的张量可以被预计算或高效存储,从而最大限度地减少执行期间的动态更新。在FlashTensor中,常量值信息表示为一个枚举而不是实际值,以平衡信息的有效性和开销。
基于数据流的属性识别
一些基本属性,如尺寸,可以很容易地从张量中提取。然而,像归约依赖这样的复杂属性需要复杂的分析才能准确标注。例如,在单个算子中,输入和输出张量的归约依赖可以根据其计算语义进行初步标注,即使没有完全实现该算子,如图5所示。然而,在具有多个算子的计算图中,张量的归约依赖属性受到后续和先前算子的双重影响,这使得确定其最终值具有挑战性。
为了解决这个问题,我们提出了一种基于数据流的两阶段属性识别算法。
1. 属性传播 (Property Propagation):根据每个算子的计算语义,张量属性会进行前向和后向传播。
2. 属性聚合 (Property Aggregation):从不同算子传播来的属性会为每个张量进行聚合,以确定最终的属性值。
该算法利用了固有的单调优先级(例如,归约依赖类型:NonPara > Reuse > Batch)。当不同类型汇集到同一个张量维度上时,会保留优先级更高的值,确保在整个图上属性表示的准确性。这两个阶段会迭代执行,直到属性稳定为止。
图8提供了归约依赖识别的示例以供进一步说明。图8(b)说明了MatMul的输出如何通过后向传播将右侧操作数在维度N'上的归约依赖从Reuse更新为NonPara。图8(a)展示了属性如何在MatMul和Reduce之间进行聚合,其中中间张量从两个算子接收到不同的归约依赖值。
张量属性感知优化
FlashTensor基于已识别的张量属性进一步优化张量程序。优化基于两条规则:代数等价图变换和非凸核映射,以及一种轻量级计划搜索方法。代数等价图变换提供了一系列变换规则,允许在等价变换中改变中间张量的大小。非凸核映射提出了一种新的候选核类型以实现高效执行。轻量级计划搜索通过属性约束剪枝和低成本性能模型实现高效的计划生成。
代数等价图变换
中间张量的大小主导了内存访问开销。因此,我们提出了一种变换方法,主要通过在计算图上执行代数等价变换来关注中间张量大小的变化。这使得后续步骤能够在整个搜索空间中搜索具有最小中间张量大小的计算图,从而减少内存开销。广播是带来中间张量大小变化的本质。
-
广播属性感知的变换规则。我们提出了一系列不仅考虑张量大小,还考虑每个维度上广播属性的变换规则。具体来说,表2列出了所有的变换规则,可分为两大部分:
- 始终有效的变换。这些变换无论其操作数的广播属性如何,总是有效的。它们的有效性由不同算子及其组合的交换律、结合律和分配律这三个数学定律保证。
- 广播属性约束的变换。某些变换的有效性依赖于广播。在这些变换中,广播被视为确保正确性的基本属性。例如,如图9(a)所示,在Div的右操作数在维度N上广播的场景中,重排是无效的,因为MatMul中的归约会破坏计算语义并导致不正确的结果。然而,如图9(b)所示,如果Div的右侧操作数沿与MatMul的归约维度对齐的维度进行广播,则重排变得有效。
表2. 基于广播属性的变换规则。A, B, C是张量。+, ⊙, /是逐元素的加、乘、除。A@B表示A和B的矩阵乘法。除矩阵乘法外的所有算子都支持广播。
图9. MatMul和Div,其重排的有效性是有条件的 -
值属性感知的变换规则。我们利用值属性通过消除不必要的迭代来优化循环。具体来说,它专注于跳过For循环中的某些迭代而不影响最终输出。如图10所示,某些迭代可能会从Mask算子产生常量值,如-∞。由于这些常量值在For循环中传播时保持不变,我们可以识别它们并安全地跳过这些迭代。这不仅保持了输出的正确性,还减少了计算量。
图10. 值属性示例
非凸核映射
核映射涉及将计算图中的算子分配给GPU核,以便在现代硬件平台上执行【15,Optimal Kernel Orchestration for Tensor Programs with Korch,2024,ASPLOS】。对于给定的计算图,简单地将所有算子融合成一个单一的核通常会导致性能次优,因为存在复杂的归约依赖和有限的并行性。识别合适的核以实现高效率至关重要。
我们首先给出最先进工作【15】中核的正式定义。
定义1 (核)。对于一个计算图 $G = (V, E)$,一个节点集合 $S \subseteq V$ 形成一个核,如果不存在节点 $v_1, v_2 \in S$ 和另一个节点 $u \in V \setminus S$ 使得 $v_1 \rightarrow u$ 且 $u \rightarrow v_2$,其中 $x \rightarrow y$ 表示在G中存在从x到y的路径。
这个定义将计算图中的一个凸(Convex)算子子集视为一个核。凸意味着这种核不能包含通过外部算子依赖于自身的算子,如图11(a)所示。Korch【15】认为这种核由于循环依赖(即Div依赖于Reduce,而Reduce又依赖于Exp)而无法执行。然而,这种循环依赖可以通过其他核来解决。具体来说,一个核不必生成其所有输入;只要有另一个核提供这些输入就足够了。例如,为了解决图11(a)中的循环依赖,另一个包含Exp和Reduce的核可以提供Reduce的输出。
通过放宽凸的要求,即支持非凸(Non-Convex)的核,一个核不再需要包含内部算子路径上的任何算子,这带来了以下两个好处:
1. 更少的计算:与凸核相比,非凸核由于放宽了要求,能够包含更少的算子,直接避免了相应的计算。
2. 潜在的更宽松的归约依赖:排除一些算子可能会使子图的依赖关系变得更简单。如图11(a)和(b)所示,如果我们排除Reduce来形成一个包含Exp和Div的非凸核,其归约依赖要简单得多,因为该核中只剩下逐元素的算子。
下面我们给出FlashTensor中考虑的非凸核的正式定义。
定义2 (非凸核)。对于一个计算图 $G = (V, E)$,一个节点集合 $S \subseteq V$ 形成一个非凸核,如果存在节点 $v_1, v_2 \in S$ 和另一个节点 $u \in V \setminus S$ 使得 $v_1 \rightarrow u$ 且 $u \rightarrow v_2$。
基于此定义,我们可以讨论核的输入和输出张量。
1. 核输入。核的输入张量是根据核内所有算子的依赖关系确定的。具体来说,如果一个张量不是核内任何算子的输出,它就被分类为该核的输入。
2. 核输出。相反,识别核的输出张量不像推断输入那样直接。输出张量不能仅通过检查数据依赖来确定,因为核内的中间张量可能被指定为输出来解决其他核的循环依赖。因此,有必要明确指定哪些中间张量应被视为核的输出。
轻量级计划搜索
基于提出的规则,下一步是在巨大的搜索空间中高效地搜索一个高性能计划。我们主要在变换和核映射阶段采用两阶段搜索。下面我们分别介绍每个阶段的快速搜索方法。
阶段1:变换。此阶段的搜索目标是尽可能减小中间张量的大小,以减少内存访问开销。然而,由于搜索空间巨大,解决这个问题具有挑战性。一个直接的贪心搜索方法是不够的,因为它倾向于收敛到局部最优解。例如,一些变换可能不会立即影响输入和输出张量的大小,但可以为未来大幅减小张量大小铺平道路。这个限制导致贪心算法错失全局最优解。为了应对这一挑战,我们提出了一种基于模拟退火(Simulated Annealing)【3,Simulated annealing,1993,Statistical science】的张量大小最小化搜索算法。我们概率性地探索各种变换方案,偶尔接受较大或相同大小的张量以跳出局部最优,并随着算法的进行逐渐降低这种接受的概率。虽然该算法不保证找到最优计划,但它提供了对执行时间的控制,使其在实际应用中可行。
阶段2:核映射。由于FlashTensor引入了非凸核,核的搜索空间进一步增大。我们使用一个基于屋顶线(roofline)模型的性能模型来预测每个计划的性能,以便从剪枝后的空间中找到最优计划。为了快速识别高性能核,我们还基于属性约束进行了大量的剪枝,以消除性能差的候选核。
- 基于屋顶线模型的性能模型。为了识别所有候选者,包括凸核和非凸核,我们采用暴力搜索方法来探索计算图中所有可能的连通算子子集。一个连通子集可能对应多个候选者,这取决于子集内的中间张量是否被标记为输出。为了预测候选者的性能,我们提出了一个基于屋顶线模型的性能模型,主要考虑两个指标:
- 算术强度 (Arithmetic Intensity)。这个指标常用于屋顶线分析【42,Roofline: an insightful visual performance model for multicore architectures,2009,Commun. ACM】中评估单个算子,定义为计算量与内存访问量之比。对于包含多个算子的候选核,我们将其扩展如下:
算术强度公式
这里,$Comp(op)$表示核内算子op的计算量,$Size(t)$表示作为核输入或输出的张量t的大小。 - 并行度 (Parallelism)。该指标量化了可用于核并行化的并行执行单元数量,该数量从输入张量的归约依赖属性中推导得出。对于Batch维度,所有尺寸都贡献于并行度。对于Reuse维度,我们使用启发式方法根据维度大小确定最佳并行度。具体来说,如果维度太小,无法通过分块实现有效的数据重用,则所有尺寸都计入总并行度。反之,如果维度较大,我们应用预定义块大小的分块策略,并将分块数量计入并行度。
- 算术强度 (Arithmetic Intensity)。这个指标常用于屋顶线分析【42,Roofline: an insightful visual performance model for multicore architectures,2009,Commun. ACM】中评估单个算子,定义为计算量与内存访问量之比。对于包含多个算子的候选核,我们将其扩展如下:
- 属性约束剪枝。鉴于候选者数量极其庞大,我们提出了一种属性约束剪枝策略来减少搜索空间并专注于有希望的候选者。在迭代过程中,如果核的并行度低于可用的并行单元数量,或者其算术强度低于预定义的阈值,则该核将被剪枝。这有助于消除不满足性能预期的低效核,并提高搜索效率。
整个搜索算法如算法1所示。它首先确定可用的并行执行单元数量,如GPU流式多处理器(SMs)。然后,我们迭代计算图的连通子集,剪枝那些并行度或算术强度低于性能阈值的核。一旦候选核被剪枝,我们遵循先前的工作【15】将最优候选搜索形式化为一个二元线性规划(BLP)任务并求解。
A4 实验
FlashTensor基于MLIR【18,MLIR: A Compiler Infrastructure for the End of Moore’s Law,2020,CoRR】和Triton【37,Triton: an intermediate language and compiler for tiled neural network computations,2019,MAPL】实现,包含1万行C++代码和2千行Python代码。FlashTensor接受ONNX格式的张量程序,并将其转换为有效的MLIR代码。我们实现了两个MLIR Dialect:FT和FTTriton。FT定义了张量算子及相应的MLIR pass实现的变换。FTTriton作为从FT到Triton DSL的桥梁。所有优化应用后,最终的MLIR代码将被转换为有效的Triton DSL。
实验环境
- 硬件平台:
- 一块NVIDIA A100-PCIE-40GB GPU,配备两颗AMD EPYC 7742 64核CPU。
- 一块H100-PCIE-80GB GPU,配备一颗AMD EPYC 7453 28核CPU。
- 软件配置: Python 3.10, CUDA 12.1, GCC 12.2。
- 工作负载:
- 模型: 评估了七个支持长序列输入的模型,包括H2O【45】,RoCo【28】,Keyformer【1】,SnapKV【20】,Corm【9】,Vanilla Attention (V.A.)【40】,Gemma2【34】。
- 模型参数: 为公平比较,所有模型均使用Llama-2-7b【39】作为基础模型,权重为FP16,头数为32,头维度为128,唯一的区别在于它们的注意力模块。
表3. 评估模型的基本信息。
- 基线系统:
- 深度学习编译器: PyTorch 2.2.2【24】 (及其TorchInductor【2】), TensorRT 10.0.1【36】, TVM 0.16.0【5】 (使用MetaScheduler【6, 46】), Korch【15】, EinNet【47】。
- 算子库: FlashAttention 2.6.2【10, 11】, FlashInfer 0.1.0【12】。
实验结果
端到端与核心模块性能
如图12(a)所示,FlashTensor在A100上最高实现了2.22倍的加速,在H100上最高实现了1.62倍的加速,优于所有基线。性能提升主要源于对核心模块的高度优化。图12(b)展示了核心模块(如注意力变体)的性能,FlashTensor在A100上最高实现了4.52倍的加速,在H100上最高实现了5.43倍的加速。我们观察到,一些结构相似的模型在使用FlashTensor优化后性能大致相当,但经过SOTA编译器优化后性能差异巨大。例如,Gemma2对Vanilla Attention做了微小修改,在A100上,TensorRT处理Gemma2和V.A.的推理时间分别为5.13ms和2.25ms,因为TensorRT的预写规则(类FlashAttention)匹配了V.A.但未能匹配Gemma2。而FlashTensor处理两者的时间分别为1.14ms和1.20ms,性能相当且均优于TensorRT,因为它能自动搜索最优方案并利用值感知优化消除因果掩码带来的冗余计算。
不同序列长度下的可扩展性
图13展示了H2O在不同序列长度下的计算效率和内存占用。
* 计算效率: 随着序列长度增加,FlashTensor的FLOP/s更高,因为其内存访问开销更低。TorchInductor和TensorRT需要在其CUDA核中从全局内存读写大小为$O(seqlen^2)$的中间张量,导致高内存开销。
* 内存占用: FlashTensor的全局内存占用显著小于PyTorch和TensorRT,尤其是在序列长度增加时。PyTorch和TensorRT在慢速全局内存上分配了随序列长度二次方增长($O(seqlen^2)$)的大型中间张量,而FlashTensor旨在最小化这类张量的分配。
与算子库的比较
我们将FlashTensor和TensorRT与手动调优的算子库FlashAttention和FlashInfer进行了比较,这两者仅支持Vanilla Attention和Gemma2。Gemma2【34】在Vanilla Attention中引入了逻辑软上限(logic soft capping),如图14(a)所示。在图14(b)中,FlashTensor的性能与FlashInfer和FlashAttention相当,后两者均采用了领域专家手动调优的算子。虽然FlashTensor没有超越这些算子库,但它提供了无需手动调优即可在各种场景中泛化和优化的优势。
分解分析
我们进行了分解分析以展示FlashTensor各组件的性能影响,如图15所示。基线是PyTorch(加速比为1)。
1. Fission (算子分裂): 将Softmax等复杂算子分解为Exp、Reduce、Div等基本算子。这引入了更多小算子,增加了内存访问量,导致性能下降。
2. B.T. (广播属性感知变换): 最小化中间张量的大小,提升性能并为后续融合满足硬件约束创造条件。
3. K.M. w/o Fusion (无融合的核映射): 识别具有高并行效率的凸核和非凸核。由于识别出的核需要重计算一些算子来解决非凸核的循环依赖,在没有融合的情况下性能仍然会下降。
4. Fusion (融合): 带来最显著的性能提升,通过减少内存访问开销和利用已识别核的高并行效率,最高实现了14.2倍的加速。
5. V.T. (值属性感知变换): 利用因果掩码消除冗余计算,在所有模型上进一步实现了近2倍的额外加速。
- 分阶段分析:
- 变换阶段: 如表4所示,广播属性感知变换规则显著减少了中间张量的总大小约3GiB,这与图13中观察到的内存使用量减少和内存访问开销降低一致。该阶段的搜索成本非常小,在几百毫秒左右。
表4. 序列长度为4096时,基于广播的变换所改变的大小和执行时间 - 核映射阶段: FlashTensor因引入非凸核而面临巨大的搜索空间。如表5所示,Keyformer的注意力模块仅有24个算子,但整个搜索空间大小达到1.69亿。Korch和TVM需要数小时的调优时间仍只能找到次优策略,因为它们需要对每个候选核进行性能分析。相比之下,FlashTensor虽然候选核更多,但通过考虑张量属性有效剪枝了性能可能较差的候选核,显著缩小了候选集,从而能在秒级快速选出最优核。
表5. 搜索性能。#Op表示核心模块中的算子数量。#Candidate表示剪枝后的候选者数量,括号中是可能的候选者总数。
- 变换阶段: 如表4所示,广播属性感知变换规则显著减少了中间张量的总大小约3GiB,这与图13中观察到的内存使用量减少和内存访问开销降低一致。该阶段的搜索成本非常小,在几百毫秒左右。
案例研究: H2O
我们以H2O为例,在A100 GPU上展示FlashTensor如何利用多种关键属性进行优化。
- 归约依赖和尺寸属性: FlashTensor利用非凸核识别来最小化内存访问开销同时最大化并行性。如图16(a)和(b)所示,FlashTensor仅用2个核处理H2O,而TensorRT用了4个核,实现了7.7倍的加速。TensorRT的凸核虽然保证了高并行性,但需要巨大的中间张量,导致显著的DRAM访问(K1, K2, K3核均有数GB)。相比之下,FlashTensor识别的非凸核K5解决了涉及Reduce 0算子的循环依赖,允许在不牺牲并行效率的情况下融合MatMul 0和Reduce 1。K5的输入Q和K的归约依赖属性揭示了在K的S'维度上的并行化机会(Reuse),实现了高并行性,且输入输出张量大小仅为$O(seqlen)$,如表6中DRAM访问量减少所示。
- 广播和尺寸属性: 在K4中,FlashTensor重排了Div和MatMul 1(图16(c)),减小了中间张量的大小。这种重排之所以有效,是因为Div的广播维度与MatMul 1的归约维度对齐。通过此优化,中间张量大小减少了96.86%。
- 值属性: FlashTensor利用值属性,通过提前停止和延迟开始迭代来消除不必要的计算。如图16(d)所示,因果掩码在上三角部分引入了-∞,FlashTensor检测到这些迭代不影响最终结果,从而消除了近一半的计算。
A7 补充细节
相关工作
- 计算图优化。TensorRT【36】和XLA【43】通过预定义规则转换计算图。TASO【16】、PET【41】和EinNet【47】可以搜索新的转换规则。Korch将算子分解为细粒度算子并搜索最佳组合。然而,由于缺乏对张量属性的感知,它们只能生成像EinNet中eOperator这样的简单结构自定义算子,而将其他算子留给手动调优的库来优化,错失了融合复杂算子(如矩阵乘法)的机会。
- 属性感知的代码生成。基于融合的工作,包括Astitch【48】、DNNFusion【21】、SmartMem【22】和Welder【32】,通过考虑每个算子输入和输出元素之间的关系来实现融合。然而,由于没有利用每个维度上的细粒度张量属性,这些方法仅限于对齐良好的工作负载,并且在复杂情况下性能次优,例如涉及多个归约方向的操作。例如,DNNFusion【21】和SmartMem【22】将GEMM和Reduce等操作归类为高度复杂的操作(如“多对多”或“输入布局依赖且可变”)并避免融合它们,假设这种融合可能对寄存器和缓存效率产生负面影响。然而,在长上下文场景中,如果输入张量轴足够长以提供并行性,并且数据依赖关系得到仔细管理以避免沿归约维度进行并行化,那么这种融合既是可行的也是有益的。这需要对细粒度张量属性(如尺寸和归约依赖)进行分析,正如FlashTensor所提出的那样。
- 算子级代码生成。cuDNN【7】和cuBLAS【8】为GEMM和卷积等传统算子提供了高效的实现。FlashAttention【10, 11, 30】和Flashinfer【12】为注意力提供了快速的核,将所有算子融合成一个单一的核以最小化内存访问开销,这是FlashTensor所涵盖的一个特例。然而,对于像H2O这样更复杂的场景,简单地融合所有算子,如类FlashAttention方法,会因数据依赖导致的并行性差和硬件利用率不足而导致性能次优。FlashTensor通过采用非凸核映射和重计算克服了这些限制,通过其编译器驱动的方法实现了高并行性和减少的内存访问开销。对于像TVM【5】、Triton【37】、FreeTensor【33】和FlexAttention【13】这样的领域特定语言,它们支持自定义代码生成,但需要专门的编程知识。而FlashTensor可以自动支持新模型。
A5 结论
本文为解决长上下文带来的显著内存开销问题,提出了一个名为FlashTensor的DNN优化系统,该系统利用细粒度的张量属性来优化整体性能。我们总结了四个关键的张量属性,包括归约依赖、广播、尺寸和值,并从计算图中识别它们。FlashTensor进一步采用属性感知的变换和核映射来实现最优性能。实验结果表明,FlashTensor在典型的GPU平台上优于最先进的工作。
A6 附录
FlashTensor是一个开源项目,可公开访问:https://github.com/monellz/FlashTensor。复现论文中报告结果的说明可在以下地址找到:https://github.com/monellz/FlashTensor-AE 和 https://zenodo.org/records/14220175。
💬 评论讨论
欢迎在这里分享您的想法和见解!