Mirage: A Multi-Level Superoptimizer for Tensor Programs
作者/机构: Mengdi Wu, Xinhao Cheng, Shengyu Liu†, Chunan Shi†, Jianan Ji, Man Kit Ao, Praveen Velliengiri‡, Xupeng Miao♯, Oded Padon⋄, Zhihao Jia (卡内基梅隆大学, 北京大学†, 宾夕法尼亚州立大学‡, 普渡大学♯, 魏茨曼科学研究所⋄)
A1 主要贡献
本文介绍了一种名为Mirage的多级张量程序超级优化器。当前用于深度神经网络(DNN)的框架,如PyTorch和TensorFlow,通常依赖手动设计的规则将张量程序映射到专家编写的GPU核函数,这不仅需要大量的工程努力,还可能错失优化机会。为解决此问题,自动化方法应运而生,主要分为两类:第一类是基于调度(schedule)的优化器(如Halide, TVM, Ansor),它们在固定算法下搜索最优的执行方案,但无法探索数学上等价的不同算法;第二类是基于代数变换(algebraic transformation)的优化器(如TASO, PET),它们利用数学等价性在算法层面重组或融合算子,但受限于预先提供的核函数实现性能。
现有方法都要求程序员手动指定一组核函数,并在代数或调度变换的空间内进行探索。然而,一些高级性能优化(如FlashAttention)需要跨越GPU计算层次结构(核、线程块、线程级别)进行协同变换,并涉及生成全新的自定义核函数,这些优化超出了现有自动化方法的搜索范围。
Mirage是首个能够自动发现并验证此类复杂优化的多级超级优化器,其核心贡献包括:
1. µGraphs (微图):提出了一种分层的图表示法,能够统一描述GPU的核、线程块和线程三个计算层级的张量程序。这种统一表示使得Mirage能够同时探索代数变换、调度变换,并能通过优化µGraph来发现新的自定义核函数。
2. 表达式引导的µGraph生成器:为了应对µGraphs带来的巨大搜索空间,Mirage引入了一种基于抽象表达式的剪枝技术。该技术通过将µGraph抽象为逻辑表达式,有效减少了需要考虑的µGraph数量,同时为发现最优解提供了一定的理论保证。
3. 概率性等价验证器:为了确保优化后的µGraph与输入程序功能等价,Mirage引入了一种概率性验证程序。该程序利用了多项式恒等式检验(PIT)的理论,将其推广到包含线性、除法和有限指数运算的LAX(Linear, Division, and an Exponentiation)程序片段,通过在有限域上进行随机测试来提供强大的正确性保证。
4. µGraph优化器:对通过验证的µGraph进行进一步优化,包括张量布局、算子执行顺序调度和内存分配规划。
评估结果显示,Mirage能够自动发现FlashAttention等手动设计的优化,并能找到性能更优的方案(在某些场景下高达2.2倍)。对于广泛使用且经过高度优化的DNN,Mirage的性能比现有方法高出3.3倍。
A2 方法细节
多级图表示 (Multi-Level Graph Representation)
µGraph表示法。Mirage使用µGraph来指定张量程序在GPU上的执行。一个µGraph包含多个层级的层次化图,分别代表核(kernel)、线程块(block)和线程(thread)级别的计算。本节将首先描述GPU的层次结构,并以图3为例介绍µGraph的关键组成部分。
GPU层次结构。图2展示了现代GPU的层次结构。GPU上的计算被组织成核函数(kernels),每个核函数以单程序多数据(SPMD)的方式在多个GPU核心上同时执行。一个核函数包含一个线程块网格(grid of thread blocks),每个线程块在一个GPU流式多处理器(SM)上执行,并包含多个线程来对单个数据元素进行计算。每个线程都有独立的寄存器文件,同一线程块内的所有线程可以访问共享内存以实现集体操作。最终,一个核函数的所有输入和输出都存储在GPU设备内存中。
核图 (Kernel graph)。每个张量程序对应一个核图,其中每个节点代表在整个GPU上运行的一个核函数,每条边是核函数之间共享的张量。核图中的所有张量都存储在GPU设备内存中,因为不同的核函数不能在寄存器文件或共享内存中共享数据。核图中的节点可以是预定义的核算子,例如由cuDNN【索引15,cudnn: Efficient primitives for deep learning,2014】支持的卷积和由cuBLAS【索引16,Dense Linear Algebra on GPUs,2016】支持的矩阵乘法。此外,为了实现细粒度的核间优化(如核融合),节点也可以是图定义核算子,其语义和行为由一个更低层级(即块级)的图定义。例如,图3b中的核算子就是一个由块图指定的图定义算子。
块图 (Block graph)。块图指定了与一个线程块相关的计算,其中每个节点表示块内的计算(块算子),每条边(图3b中的蓝色箭头)是块算子之间共享的张量。Mirage将块图内的所有中间张量存储在GPU共享内存中,主要有两个考虑:首先,GPU共享内存提供了比设备内存高得多的带宽,这种设计允许Mirage通过最大化地在共享内存中保存中间结果来减少设备内存访问。其次,对于尺寸超过共享内存容量必须存储在设备内存中的张量,Mirage使用这些张量将计算分割成多个块图,每个块图只包含共享内存中的张量,这种分离不会引入额外的设备内存访问。
块图的执行属性。每个块图还关联了一些指定其执行方式的属性,具体如下。
网格维度 (Grid dimensions)。一个核函数内的所有线程块被组织成一个最多3维的网格,维度标识为x、y和z。块图关联了最多三个网格维度,指定了沿x、y、z三个维度上的线程块数量。图3b中的块图启动了128个线程块。
输入张量分区 (imap)。首先,对于图定义核算子的每个输入张量(例如图3b核图中的X, G, W),其关联的块图包含一个imap
,用于指定输入张量如何被分区为供单个线程块使用的子张量。对于每个网格维度(即x, y, z),imap
将其映射到(1) 输入张量的一个数据维度 或 (2) 一个特殊的复制维度φ。对于(1),被映射的数据维度会沿着该网格维度在线程块间被均等划分。对于(2),输入张量会在这些线程块间被复制。例如,图3b中的块图有三个输入X, G, W。对于W,其imap={x ↔d}
表示张量W的d维度被划分为128个大小相等的块。因此,每个线程块接收到的W形状为[h=1024, d=32]。
输出张量拼接 (omap)。其次,对于块图的每个输出张量(例如图3b中的Z),块图包含一个omap
,用于指定所有线程块的输出如何被拼接以构建核算子的最终输出。在omap
中,每个网格维度必须映射到输出张的一个数据维度,因为不同的线程块必须在设备内存中存储不相交的张量。对于图3b中形状为[b=16, d=32]的Z,其omap={x ↔d}
表示具有相同x索引的块的输出沿d维度拼接,最终形成形状为[b=16, d=4096]的张量Z。
For循环体 (For-loop body)。为了将大型输入张量装入共享内存并使设备内存数据加载与计算重叠,块图可以包含一个for循环体,该循环体被多次执行以完成一个核函数。通常,核函数中的for循环之后会进行一些后处理。例如,在计算平均值时,for循环会执行n个值的求和,而后处理会除以n。Mirage使用输入迭代器、for循环累加器以及它们之间的所有算子来指定块图的for循环体,如图3b中的橙色框所示。块图的每个输入张量首先通过一个输入迭代器,该迭代器将部分张量(如X, G, W)从设备内存加载到共享内存。每个输入迭代器关联一个fmap
,以指定在每次迭代中加载输入张量的哪一部分。fmap
将每个for循环维度映射到(1) 输入张量的一个数据维度 或 (2) 复制维度φ。与imap
类似,对于(1),张量沿该维度被均等划分;对于(2),则进行复制。图4展示了如何使用不同的imap
和fmap
在一个输入矩阵上对线程块和for循环迭代进行分区。
For循环累加器。每个块图还关联一个for循环维度,它决定了为完成核函数而执行for循环体的迭代次数。此外,Mirage使用for循环累加器(例如图3b中的两个Accum算子)来累加每次迭代计算的中间结果(使用标准累加器,如求和和最大值),并将累加结果存储在共享内存中。一旦for循环体完成,Mirage会直接在累加结果上执行循环体外的剩余算子。最后,一个输出保存器将最终结果从共享内存保存回设备内存。
线程图 (Thread graph)。线程图进一步将计算范围从一个线程块缩小到单个线程。与块图类似,每个线程图也关联了块维度(指定块内线程的组织)和for循环维度(定义完成计算所需的总迭代次数)。每个线程图包含输入迭代器,它们将输入张量(如图3b中的A和B)从共享内存加载到寄存器文件;以及输出保存器,它们将输出张量(如C)从寄存器文件存回共享内存。线程图是µGraph中最低层级的图,只包含预定义的线程算子。
张量布局 (Tensor layout)。核图、块图或线程图中的每个张量都关联一个张量布局(为简化在图3中省略),它指定了张量在内存中如何线性化。值得注意的是,张量布局只影响µGraph的性能,对其输出的正确性没有影响。
µGraph有效性定义 (Definition 2.1)。一个µGraph G是有效的,当且仅当:(1) 对于G中的每个核、块和线程算子o,其输入和输出张量与o的规范匹配;(2) 每个核图、块图和线程图中的所有张量可以分别驻留在GPU设备内存、共享内存和寄存器文件中;(3) 对于每个具有for循环体的块图和线程图,任何从输入到输出的路径都恰好经过一个输入迭代器、一个for循环累加器和一个输出保存器。
与先前工作的比较。先前的工作分别考虑代数变换【索引25,TASO: Optimizing deep learning computation with automatic generation of graph substitutions,2019,SOSP;索引46,PET: Optimizing tensor programs with partially equivalent transformations and automated corrections,2021,OSDI】或调度变换【索引13,TVM: end-to-end optimization stack for deep learning,2018,CoRR;索引31,Automatically scheduling halide image processing pipelines,2016,ACM Trans. Graph.;索引35,Halide: A language and compiler for optimizing parallelism, locality, and recomputation in image processing pipelines,2013,PLDI】,而µGraphs能以统一的方式表示这两种变换。具体来说,网格和for循环维度及其到张量维度的相应映射(即imap、omap和fmap)构成了图定义算子所有可能调度的全面搜索空间。跨越核、块和线程级别的层次化图允许Mirage在这些级别上探索代数变换。
表达式引导的µGraph生成器
本节介绍Mirage µGraph生成器,它能为输入张量程序自动发现潜在的µGraph。为了生成能在核、块和线程级别上捕捉优化的µGraph,Mirage必须探索比现有仅考虑核级别优化的超级优化器大得多的搜索空间。Mirage采用两种关键技术来应对这一挑战。首先,基于核级别和块级别的优化对性能的贡献远大于线程级别(因为访问设备和共享内存比访问寄存器文件昂贵几个数量级)的观察,Mirage的µGraph生成器采用混合方法:它在核和块级别上详尽地考虑所有达到一定大小的可能图,并使用基于规则的策略来构建线程级别的图。这种方法在保留大多数性能关键优化的同时减少了搜索空间。其次,为了进一步剪枝搜索空间,Mirage引入了一种基于µGraph抽象(称为抽象表达式)的剪枝技术,该技术在大幅减少Mirage必须考虑的µGraph数量的同时,为发现的µGraph的最优性提供了某种理论保证。我们在§4.1和§4.2介绍混合µGraph生成算法,在§4.3介绍表达式引导的剪枝技术。
核图与块图的生成
增量式图生成。Mirage以增量方式生成核图和块图,并利用多种剪枝技术来减少搜索空间,如图5第二部分所示。具体来说,Mirage维护一个有效µGraph的前缀,并迭代地用新算子扩展它。对于一个图G = (V, E)
,如果G' = (V', E')
是G
的子图,并且对于所有u ∈ V'
,所有(v, u) ∈ E
都有v ∈ V'
,那么我们称G'
是G
的前缀。
算法流程。为了在核图中生成下一个算子,Mirage会枚举核算子类型t
和输入张量集I
。如果t
代表图定义算子类型,Mirage会通过以下方式生成定义其核计算的关联块图:(1) 枚举网格和for循环维度(在§2中介绍),这使得Mirage能够计算块图的输入张量形状;(2) 执行一个与核级别类似的嵌套生成过程,但不考虑图定义算子。算法1的第6-16行和第17-24行分别展示了Mirage如何生成核算子和块算子。在添加一个算子之前,Mirage会检查张量形状(第28行)和内存使用情况(第29行),以确保生成有效的前缀。
规范形式以避免重复。为确保相同的µGraph只生成一次,Mirage定义了µGraph的规范形式。给定一个µGraph G,其算子按拓扑顺序为o1, ..., on,oi的第j个输出的索引定义为一个元组(i, j)。G中的每个算子oi被赋予一个秩(inputi, typei),其中inputi是oi的输入张量索引列表,typei是算子类型。如果一个µGraph的算子按秩的升序排列,则它处于规范形式。Mirage通过要求按秩的升序添加算子(第9行和22行)来只生成规范形式的µGraph。这种方法不会剪掉任何有效解,因为每个µGraph都可以通过重排算子转换为规范形式。
抽象表达式剪枝。此外,Mirage利用抽象表达式技术来剪掉不满足特定约束的前缀,这将在§4.3中介绍。
线程图的构建
基于变换的构建方法。虽然类似的嵌套生成策略可以应用于线程图,但Mirage转而使用基于变换的方法来构建它们(见图5的第三部分和算法1的第4-5行),以减少搜索空间。Mirage在构建线程图时应用算子融合,通过尽可能在寄存器文件中重用张量来减少对共享内存的访问。例如,Mirage将图3b中的三个逐元素算子(Mul, Sqrt, Div)融合成一个线程图,避免了将中间结果保存到共享内存,并将这些算子的整个计算过程保留在寄存器文件中。虽然我们当前的实现侧重于算子融合,但也可以使用其他基于规则的变换来构建线程图。
通过抽象表达式进行剪枝
剪枝的直觉。在搜索可能的µGraph空间时,我们的目标是避免那些中间结果无法对期望计算做出贡献的µGraph前缀。例如,对于输入程序X · Z + Y · Z
,我们可以剪掉一个计算X · Y
的前缀,但不应剪掉计算X + Y
的前缀,因为(X + Y) · Z
等价于输入程序。但是,我们如何在搜索某个计算的同时,判断一个前缀是否能对这个期望的计算做出贡献呢?下面,我们通过抽象来解决这个“鸡生蛋,蛋生鸡”的问题,并发展出一种由这个直觉驱动的剪枝技术。我们首先介绍这种抽象——抽象表达式,然后解释如何用它来剪枝。最后,我们提供一个理论保证,在某些条件下,这种剪枝不会排除最优的µGraph。
抽象表达式的定义。回想一下,µGraph中的一条边对应一个以输入张量为变量的张量值函数。直观地,抽象表达式通过忽略同一输入张量中元素之间的差异来抽象这些函数。形式上,抽象表达式是整数理论和未解释函数上的一阶逻辑项。在µGraph中,每条边的抽象表达式,记为E(·),定义在表1中。计算µGraph的抽象表达式时,所有图定义算子都会被“内联”。具体来说,为图定义算子的输入计算出的表达式会传递到其低层图中,该低层图产生的输出表达式成为图定义算子的输出表达式。图6展示了注意力机制一个子图的抽象表达式。
表格1:Mirage支持的算子。第二列显示了支持每个算子的图级别(K、B和T分别代表核、块和线程图)。最后一列定义了每个算子输出的抽象表达式,其中E将张量映射到它们的抽象表达式。
抽象表达式的作用。虽然抽象表达式捕捉了每条边上计算的函数的一些信息,但它们也抽象掉了许多细节。例如,如果X是一个k × k的矩阵,对行求和与对列求和都会得到相同的抽象表达式——sum(k, E(X))。但是,将k作为抽象表达式的一部分对于有效的剪枝至关重要。
抽象子表达式与剪枝。我们使用抽象表达式来剪枝µGraph的搜索空间,方法是形式化抽象表达式上的两个关系:等价性和抽象子表达式。具体来说,我们剪掉任何其抽象表达式不是某个与输入程序抽象表达式等价的抽象表达式的子表达式的µGraph前缀。我们将抽象表达式形式化为整数算术和未解释函数理论上的一阶逻辑中的未解释函数,并使用一个SMT求解器基于表2中的两组公理:Aeq和Asub来对它们进行推理。
公理系统Aeq和Asub。首先,Aeq公理化了抽象表达式之间的等价关系。需要明确的是,这些公理不必是健全的——不要求具有等价抽象表达式的µGraph在功能上等价,因为不等价的µGraph可以有相同的抽象表达式。其次,Asub公理化了抽象表达式之间的子表达式关系。Asub的一个关键性质是,只要一个µGraph G1是G2的前缀——意味着G2可以通过在G1上添加额外算子来构建——E(G1)就是E(G2)的抽象子表达式;形式上,Asub |= subexpr(E(G1), E(G2))
,其中|=
表示在整数算术和未解释函数理论下的蕴含关系。
剪枝过程。在搜索过程中,算法1首先计算输入LAX程序的抽象表达式,记为EO
,并剪掉任何µGraph前缀G,如果Aeq ∪ Asub ̸|= subexpr(E(G), EO)
。也就是说,如果一个图的抽象表达式不是EO
的子表达式,那么它就会被剪掉。这个检查是使用一个SMT求解器(Z3【索引18,Z3: An efficient smt solver,2008,TACAS'08/ETAPS'08】)来完成的。作为一个优化,这些检查的结果会被缓存和重用,因为Mirage在搜索过程中可能会遇到多个具有相同抽象表达式的µGraph。
理论保证与剪枝-最优性权衡。直观地说,我们的剪枝会保留任何可能导向一个其抽象表达式与输入LAX程序的抽象表达式(根据Aeq)等价的µGraph的前缀。形式上:
定理1 (通过抽象表达式进行剪枝)。对于一个输入µGraph G0,以及一个与G0等价的µGraph G,如果Aeq |= E(G0)= E(G)
,那么G将被算法1生成。
证明思路。根据表1和表2,我们证明对于任何算子op
,如果Y =op(X1, . . . , Xn)
,那么Asub |= subexpr(E(Xi), E(Y))
对于1 ≤ i ≤ n
成立。也就是说,op
的每个输入的抽象表达式总是其输出的子表达式。鉴于Asub包含自反性和传递性公理,因此对于任何G的前缀G',Asub |= subexpr(E(G'), E(G))
成立。再加上Aeq |= E(G0) = E(G)
的假设,我们有Aeq ∪ Asub |= subexpr(E(G'), E(G0))
。因此,G的任何前缀都不会被剪枝,Mirage将会生成G。
表格2:用于剪枝的抽象表达式公理化。Mirage通过查询SMT求解器来检查抽象表达式E1是否是E2的子表达式,即检查subexpr(E1, E2)
是否可以由这些公理推出。这些公理中的所有变量都是全称量化的。
剪枝与最优性的权衡。该定理强调了抽象表达式在解决上述“鸡生蛋,蛋生鸡”问题中的作用。为了判断一个前缀µGraph是否有用,我们在抽象层面上推理它是否是一个有用计算的前缀。抽象的选择和公理Aeq
代表了最优性和剪枝效果之间的权衡。如定理1所示,我们只保证能找到其抽象表达式在Aeq
下与输入程序等价的最优µGraph。更强的公理会扩大定理覆盖的µGraph集合,但会降低剪枝效果,因为更多的前缀会通过子表达式测试。特别地,请注意Aeq
不包括抵消规则(例如,div(mul(x, y), y)=y
)。因此,Mirage可能会错过一些等价的µGraph。然而,包含这样的公理会使任何东西都成为任何东西的子表达式,从而使期望的剪枝失效。如我们的评估所示,所选的Aeq
在剪枝和最优性之间取得了良好的平衡。
概率性等价验证器
核心思想。Mirage的概率性等价验证器用于检查一个候选µGraph是否与期望的LAX程序等价。其关键思想是在两个有限域上用随机输入来评估两者。使用有限域代替浮点数不仅避免了浮点误差,还提供了一个强大的理论保证:接受一个不等价µGraph的概率可以被做得任意低。
LAX程序的优势。对于通用程序,随机测试几乎无法提供任何正确性保证。然而,我们证明了对于LAX程序(定义如下),随机测试提供了概率性的正确性保证,并且重复测试可以将错误概率降低到任意小的阈值。先前的工作【索引46,PET: Optimizing tensor programs with partially equivalent transformations and automated corrections,2021,OSDI】已将类似技术应用于只包含线性算子(如矩阵乘法、卷积)的张量程序等价性检查。我们开发了一种随机测试技术,它还支持除法和指数运算,这对于许多DNN优化(例如§3中的RMSNorm示例)是必需的。
LAX µGraph的定义。Mirage在LAX µGraph(线性、除法和指数运算)之间验证等价性,定义如下。我们在§5.1介绍主要的理论结果,并在§5.2展示Mirage的验证方法。
定义5.1 (LAX µGraph)。一个µGraph G是一个LAX µGraph,如果(1) G只包含多线性算子、除法和指数运算,并且(2) G中从输入到输出的每条路径最多包含一个指数运算。
理论基础
LAX µGraph的数学形式。不失一般性,我们假设一个LAX µGraph G接受n个输入张量并产生一个输出张量。我们的理论结果可以直接推广到具有多个输出的LAX µGraph。由于每个LAX µGraph包含线性算子、除法和每条路径最多一个指数运算,输出张量中每个元素的计算可以表示为以下形式(通过使用标准恒等式,如a/b * c/d = ac/bd, a/b + c/d = (ad+bc)/bd, e^x * e^y = e^(x+y)):
其中fi, gi, hi, f'j, g'j
和h'j
(1 ≤ i ≤ k, 1 ≤ j ≤ k')是关于输入张量元素的各项的多项式。
核心理论定理。我们随机化等价验证的主要理论基础是以下定理,它将有限域上的多项式恒等式检验(PIT)【索引37,Fast probabilistic algorithms for verification of polynomial identities,1980,J. ACM;索引54,Probabilistic algorithm for testing polynomial identities,1979,Springer】推广到LAX µGraph。注意,两个LAX µGraph的差也具有方程(2)的形式。因此,两个LAX µGraph的恒等性检验简化为测试一个该形式的表达式是否为零。由于指数运算的存在,我们使用两个有限域而不是一个。
表格3:用于随机测试的算术运算。Mirage选择两个素数p和q,使得q整除p-1。xp和xq分别是来自有限域Zp和Zq的值。符号x^-1和√x分别表示在相应有限域中的乘法逆元和平方根。具体来说,xx^-1 mod p = 1 且√x * √x mod p = x。
定理2。设P
是方程(2)描述形式的函数,其中fi, gi, hi, f', g', h'
是次数最多为d的非零多项式,其整数系数在[-w, w]
之间。设p, q
是素数,使得q | p-1
且q > 2w
。设G
是Zp
中q
次单位根的集合。如果P
不是一个零函数,那么【索引27,Identity testing for circuits with exponentiation gates,2025】:
在有限域上的随机测试
验证流程。Mirage利用定理2,通过在定理2定义的有限域Zp和Zq上进行随机测试,来概率性地验证两个µGraph的等价性。为了检查两个µGraph的等价性,Mirage首先生成输入张量,其中每个元素从Zp × Zq中均匀采样。Mirage还从Zp中的q次单位根集合中均匀采样ω,用于指数运算。然后,Mirage使用表3中定义的运算在这些输入上评估这两个µGraph。
运算细节。如§5.1所述,Zp和Zq分别用于指数运算外部和内部的计算。除指数运算外的所有运算都通过在Zp和Zq中独立进行模算术来实现。对于指数运算,Mirage使用来自Zq的值xq,并计算ω^xq mod p
以获得一个在Zp中的结果。
重复测试与错误率。注意,在LAX µGraph中,指数运算在每条路径上最多执行一次。最后,Mirage检查两个µGraph是否产生相同的输出。这个过程会重复多次,如果两个µGraph通过了所有随机测试,则认为它们是等价的。以下由定理2得出的定理表明,这个过程可以产生任意低的错误率。
定理3。等价的µGraph总能通过µGraph验证。对于两个不等价的µGraph和一个给定的概率阈值0 < δ ≤ 1,这两个µGraph通过所有Ω( (k^2/ln(q)) * ln(1/δ) )
次随机测试的概率最多为δ。
数值稳定性。虽然该定理联系了有限域和实数计算,但实数计算和浮点运算之间可能出现差异,特别是在涉及因中间值过大而导致的溢出或下溢时。Mirage采用浮点测试来过滤掉具有显著数值误差的µGraph。
µGraph优化器
后验证优化。对于每个通过验证的µGraph,Mirage的µGraph优化器通过进一步执行布局优化、算子调度和内存规划来最大化其性能,如图1所示。Mirage将这些µGraph优化推迟到验证之后进行,原因有二。首先,这些优化不影响生成的µGraph的正确性;在生成µGraph时忽略它们可以减少Mirage必须考虑的搜索空间,因为具有相同图拓扑但张量布局、算子顺序或内存分配方案不同的µGraph被µGraph生成器视为相同的。其次,在验证后应用这些优化也减少了这些优化的搜索空间,因为µGraph优化器只需要优化功能上与输入等价的µGraph。
张量布局。µGraph优化器探索核、块和线程级别上所有中间张量的可能数据布局,并选择最佳组合以最大化性能。我们将布局选择形式化为一个约束优化问题,并使用整数线性规划(ILP)算法来最优地解决它。具体来说,对于每个张量t和t的每个可能布局l,我们引入一个布尔变量Bt,l来表示张量t是否使用布局l。核、块和线程级别的算子可能会对张量布局施加各种约束。例如,要使用cuBLAS库【索引16,Dense Linear Algebra on GPUs,2016】进行矩阵乘法,两个输入张量的最内层维度必须是最后两个维度之一。这些限制被转换为关于Bt,l的一系列线性约束。不同的张量布局可能导致不同的性能。例如,某些输入张量布局支持从设备内存到共享内存的批量复制,而其他则不支持。Mirage引入一个成本函数来模拟每个算子在不同布局选择下的性能。Mirage使用一个现成的ILP求解器(即Z3【索引18,Z3: An efficient smt solver,2008,TACAS'08/ETAPS'08】)来找到一个满足所有布局约束同时最小化成本的最优布局策略。
算子调度。在µGraph中,有多种拓扑顺序来执行算子,不同的顺序可能会产生不同的性能。对于给定的输入µGraph,µGraph优化器通过最小化每个线程块内的线程级同步(即CUDA中的__syncthreads()
)来确定一个高效的算子调度策略。为实现这一目标,Mirage用一个深度来标记每个节点,深度定义为从任何输入算子到该节点的最长路径的长度。Mirage使用动态规划算法计算每个节点的深度,并按深度的升序调度所有算子。这种方法最小化了生成的CUDA核中所需的线程级同步次数,因为Mirage只需要在具有不同深度的算子之间插入同步点。
内存规划。第三类后验证优化是内存规划,它为核、块和线程级别的所有中间张量确定内存偏移。Mirage将内存规划形式化为一个动态存储分配问题,并详尽地枚举所有可能的分配方案以发现一个最优策略。
A3 背景知识与设计原则
案例研究: RMSNorm
本节我们使用均方根层归一化(Root Mean Square Layer Normalization, RMSNorm)【索引50,Root mean square layer normalization,2019】作为案例研究,来展示µGraph表示和Mirage超级优化方法的优势。RMSNorm是近期大型语言模型中广泛使用的一种归一化技术【索引41,The llama 3 herd of models,2024】。形式上,RMSNorm接收两个张量X和G作为输入,并根据均方根对其逐元素乘积进行归一化:
其中d是X的隐藏维度大小。
传统实现方式。RMSNorm通常后接一个矩阵乘法(MatMul)。图3a展示了一个RMSNorm后接一个MatMul算子的计算图,其中X是输入张量,G和W表示两个权重张量。现有的机器学习编译器通常为RMSNorm和MatMul的计算启动两个独立的核函数,因为这两个操作内部都对一个输入维度进行归约,这使得将它们的计算融合到单个核函数中变得具有挑战性。这种方法需要将中间结果(即Y)存储在设备内存中,因为不同的核函数无法在共享内存或寄存器文件中共享数据。
Mirage发现的优化µGraph。图3b展示了Mirage为在单个核函数中计算RMSNorm和MatMul而自动发现的最佳µGraph。该计算被融合在一个单一的图定义核算子中,以避免将中间结果(即Y)保存到设备内存并减少核函数启动开销。
Mirage的关键优化。我们强调Mirage发现的µGraph与原始µGraph之间的关键差异。这些差异涉及发现新的自定义核函数以及结合代数和调度变换,使得单独考虑代数和调度变换无法发现最终的µGraph。首先,Mirage利用矩阵乘法和逐元素除法的交换律重新排序了MatMul和RMSNorm的除法(代数变换)。其次,Mirage并行执行了均方根中的累加(即 Ai = ∑j X^2ij
)和矩阵乘法中的累加(即 Bik = ∑j XijGjWjk
)(调度变换),避免了将累加结果写入设备内存。接下来,Mirage实例化一个线程图来执行一系列逐元素算子,同时将所有中间结果保留在寄存器文件中(调度变换)。最后,发现的最佳µGraph使用一个新的自定义核函数来融合RMSNorm和MatMul的计算,减少了设备内存访问和核启动开销。这个µGraph在NVIDIA A100和H100 GPU上分别比现有系统中的手写核函数快1.5倍和1.9倍。
A4 实验
实验环境
- 硬件配置:实验在NVIDIA A100和H100 GPU上进行,每块GPU均配备40GB内存。对于GQA(用于LLaMA-2-70B),由于其规模较大,采用张量模型并行【索引39,Megatron-lm: Training multi-billion parameter language models using model parallelism,2019】策略,在四块GPU上并行执行。
- 软件配置:Mirage由3万行C++, CUDA和Python代码实现。核算子使用cuDNN【索引15,cudnn: Efficient primitives for deep learning,2014】和cuBLAS【索引16,Dense Linear Algebra on GPUs,2016】库实现,块和线程算子使用cuTLASS【索引2,Nvidia/cutlass: Cuda templates for linear algebra subroutines,2019】和CUDA PTX实现。SMT和ILP求解器使用Z3 4.12.6【索引18,Z3: An efficient smt solver,2008,TACAS'08/ETAPS'08】。Mirage支持即时(JIT)编译和部署,生成的核函数可直接集成到PyTorch程序中。
- 基准测试:
- 微基准:评估集中在6个常见的DNN LAX程序上,包括大型语言模型(LLM)的主要构建块GQA、RMSNorm、GatedMLP,以及用于增强模型收敛性的QKNorm【索引40,Chameleon: Mixed-modal early-fusion foundation models,2024】和用于微调的LoRA。具体配置见下表4。
- 端到端模型:评估了Mirage生成的核函数对四个完整DNN模型的端到端性能提升,包括Chameleon【索引40,Chameleon: Mixed-modal early-fusion foundation models,2024】,nGPT【索引28,nGPT: Normalized transformer with representation learning on the hypersphere,2024】,LLaMA-3【索引41,The llama 3 herd of models,2024】和LoRA【索引22,Lora: Low-rank adaptation of large language models,2021】。
- 对比系统:包括TASO/PET【索引25,TASO: Optimizing deep learning computation with automatic generation of graph substitutions,2019,SOSP;索引46,PET: Optimizing tensor programs with partially equivalent transformations and automated corrections,2021,OSDI】(代数优化器)、PyTorch【索引34,Tensors and Dynamic neural networks in Python with strong GPU acceleration,2017】(启用torch.compile和FlashAttention)、TensorRT/TensorRT-LLM【索引42,NVIDIA TensorRT: Programmable inference accelerator,2017】(手动优化核)、FlashAttention/FlashDecoding【索引17,Flash-decoding for long-context inference,2023;索引21,Flashdecoding++: Faster large language model inference on gpus,2024】(手动优化核)和Triton【索引43,Triton: an intermediate language and compiler for tiled neural network computations,2019,MAPL】(调度优化器)。所有基线都使用CUDA Graphs以最小化核函数启动开销。
表格4:我们评估中使用的DNN基准。
实验结果
微基准性能 (图7)
与现有最佳方法相比,Mirage通过结合代数变换、调度变换以及生成新的自定义核函数,将这些基准的性能提升了高达3.3倍。
- GQA (分组查询注意力): Mirage不仅发现了专家设计的FlashAttention/FlashDecoding核,还找到了性能更优的µGraph,最高提速2.2倍。这得益于两个额外优化:1) 自动搜索最佳网格维度以充分利用所有SM,而现有方法依赖固定的次优启发式策略。2) 自动选择最高效的并行化策略,根据不同注意力场景定制µGraph,将设备内存访问减少高达7倍。
- QKNorm (查询-键归一化): 如图8所示,Mirage自动发现了一个将QKNorm和注意力计算融合到单个自定义核中的µGraph,避免了中间结果写入设备内存,将核执行时间减少了高达1.4倍。
- LoRA (低秩适应): 如图9所示,Mirage利用代数变换W×X + B×A×X = (W∥B)×(X∥(A×X))
,将三个矩阵乘法和一个加法融合到单个核中,通过在共享内存中更新张量偏移来执行拼接,将LoRA的执行成本降低了1.1-2.4倍。
- GatedMLP (门控多层感知机): 如图10所示,Mirage发现的µGraph在同一块图中并行执行两个矩阵乘法,并将剩余的SiLU和Mul操作作为后处理步骤进行融合。这种方法在A100上实现了1.5倍的加速,在H100上实现了2.7-3.3倍的加速。
- nTrans (归一化Transformer): Mirage将计算融合成单个核,但性能略低于TensorRT。原因在于对于计算量小的核,Mirage将数据从全局内存加载到共享内存再写回的开销较大。未来计划通过支持绕过共享内存来缓解此问题。
端到端性能 (图11)
将Mirage生成的核函数集成到PyTorch程序中,只需少量代码更改即可将四个常用DNN模型的端到端延迟降低0.9-1.9倍。
搜索时间 (表5)
Mirage优化一个LAX程序的耗时最长为4小时,这是一次性的部署前成本。对RMSNorm的消融研究显示,多线程显著减少了搜索时间,而基于抽象表达式的剪枝对于Mirage的可扩展性至关重要,它使Mirage能探索包含多达11个算子的块图,而禁用该技术在10小时内最多只能处理6个算子的块图。
表格5:关于Mirage加速µGraph生成技术的消融研究。我们评估了多线程和抽象表达式对RMSNorm搜索时间的影响。
优化消融研究 (图12)
在A100上对GQA(批大小为1)进行的消融研究表明,独立禁用Mirage的任何一项优化(线程图构建、布局优化、算子调度、内存规划)都会导致5%到70%的性能下降,证明了每个优化组件的重要性。
A5 结论
本文提出了Mirage,首个针对张量程序的多级超级优化器。Mirage引入了一种层次化图表示法,用于描述GPU执行层次结构中的核、线程块和线程级别的张量程序,并采用一种基于抽象的新颖剪枝技术,在显著减少搜索空间的同时,为找到最优解提供了一定的保证。实验证明,即使对于被广泛使用和深度优化的DNN,Mirage的性能也比现有的张量程序优化器高出3.3倍。
A6 方法细节中的引用汇总
- [2] Nvidia/cutlass: Cuda templates for linear algebra subroutines, 2019, https://github.com/NVIDIA/cutlass: 在“实现”章节中提到,用于实现块和线程算子。
- [13] Tianqi Chen et al., TVM: end-to-end optimization stack for deep learning, 2018, CoRR: 在“引言”和“多级图表示”章节中被引用,作为现有调度优化器的一个例子,与Mirage的统一优化方法进行对比。
- [15] Sharan Chetlur et al., cudnn: Efficient primitives for deep learning, 2014, CoRR: 在“多级图表示”和“实现”章节中提到,作为提供预定义核算子(如卷积)的库。
- [16] Dense Linear Algebra on GPUs, 2016, https://developer.nvidia.com/cublas: 在“多级图表示”、“µGraph优化器”和“实现”章节中提到,作为提供预定义核算子(如矩阵乘法)的库,并指出其对张量布局有特定约束。
- [18] Leonardo De Moura and Nikolaj Bjørner, Z3: An efficient smt solver, 2008, TACAS’08/ETAPS’08: 在“通过抽象表达式进行剪枝”、“µGraph优化器”和“实现”章节中被引用,作为Mirage用于抽象表达式推理和整数线性规划的SMT/ILP求解器。
- [25] Zhihao Jia et al., Taso: Optimizing deep learning computation with automatic generation of graph substitutions, 2019, SOSP: 在“引言”和“多级图表示”章节中被引用,作为现有仅考虑代数变换的超级优化器的一个例子。
- [27] Jiatu Li and Mengdi Wu, Identity testing for circuits with exponentiation gates, 2025: 在“理论基础”章节中被引用,为定理2提供了理论依据,该定理将多项式恒等式检验推广到LAX µGraphs。
- [31] Ravi Teja Mullapudi et al., Automatically scheduling halide image processing pipelines, 2016, ACM Trans. Graph.: 在“多级图表示”章节中被引用,作为调度变换方法的参考文献。
- [35] Jonathan Ragan-Kelley et al., Halide: A language and compiler for optimizing parallelism, locality, and recomputation in image processing pipelines, 2013, PLDI: 在“引言”和“多级图表示”章节中被引用,作为引入算法与调度分离思想的代表性工作。
- [37] J. T. Schwartz, Fast probabilistic algorithms for verification of polynomial identities, 1980, J. ACM: 在“引言”和“理论基础”章节中被引用,作为Mirage概率性等价验证器所基于的多项式恒等式检验(PIT)的开创性工作。
- [46] Haojie Wang et al., PET: Optimizing tensor programs with partially equivalent transformations and automated corrections, 2021, OSDI: 在“引言”、“多级图表示”和“概率性等价验证器”章节中被引用,作为代数变换优化器的例子,并且其随机测试技术被Mirage扩展以支持更多算子。
- [54] Richard Zippel, Probabilistic algorithm for testing polynomial identities, 1979, Springer: 在“引言”和“理论基础”章节中被引用,与[37]一同作为多项式恒等式检验(PIT)的基础理论来源。
💬 评论讨论
欢迎在这里分享您的想法和见解!