TileLang: A Composable Tiled Programming Model for AI Systems
文章标题:TileLang: 面向AI系统的可组合分块编程模型
作者/机构:LEI WANG§ (北京大学), YU CHENG§ (北京大学), YINING SHI§ (北京大学), ZHENGJU TANG (北京大学), ZHIWEN MO (帝国理工学院), WENHAO XIE (北京大学), LINGXIAO MA (微软亚洲研究院), YUQING XIA (微软亚洲研究院), JILONG XUE (微软亚洲研究院), FAN YANG (微软亚洲研究院), ZHI YANG (北京大学)
A1 主要贡献
现代AI工作负载严重依赖优化的计算核函数(kernel)进行训练和推理。尽管这些核函数遵循明确的数据流模式(如在DRAM和SRAM之间移动数据块并进行计算),但编写高性能核函数仍然是一项复杂的任务,需要进行细致的、以硬件为中心的优化。现有的领域特定编译器(DSL)虽然简化了这一过程,但在易用性和表达能力上仍有不足。例如,Triton虽然提供了直观的块级原语,但隐藏了线程行为和内存布局等细节,限制了专家开发者追求极致性能。
为了解决这些限制,本文提出了TileLang,一个通用的分块编程模型,旨在提高AI核函数编程的效率。TileLang的核心思想是将数据流与调度(线程绑定、内存布局、张量化和流水线)解耦。用户只需使用可组合的块(tile)操作符来定义数据流,而编译器负责探索和应用调度策略。当编译器的默认优化不足时,用户可以通过一组定制化的注解和原语在前端进行精确控制。
本文的主要贡献如下:
- 块级编程语言 (Tile-Level Programming Language):设计了一种块级编程语言,允许用户显式声明缓冲区在硬件内存层次结构中的位置。通过利用布局推断(Layout Inference)机制,系统在抽象化缓冲区操作并行化复杂性的同时,暴露了线程级的控制接口,使专家能够精确管理每个线程与缓冲区的交互。
- 带自动优化的编译器 (Compiler with Automated Optimization):为TileLang提供了一个配套的编译器,包含一系列自动化编译过程。这些过程包括通过布局推断机制实现的自动并行化、核函数库的动态参数简化、自动流水线推导以及针对动态形状的循环尾部拆分优化。
- 达到业界顶尖的性能 (State-of-the-Art Performance):在NVIDIA和AMD GPU上的真实AI核函数实验评估表明,TileLang的性能与专门的供应商库和其他基于DSL的方法(如Triton)相当,甚至在某些情况下超越了它们,证明了其统一的块-线程范式和透明的调度能力满足了现代AI系统开发的性能和灵活性需求。
A3 一个TileLang示例
现有的将调度与计算分离的机器学习编译器(如TVM)要求用户明确区分计算和调度,并且需要手动注册新的张量指令和指定缓冲区布局以获得最佳性能,但编写和理解调度程序仍然具有挑战性。现代框架如Triton虽然允许用户专注于块级编程,但其数据流表示不够清晰,并且需要使用掩码条件加载等变通方法或依赖硬件特定功能(如TMA)。而像ThunderKitten这样的框架,尽管将程序抽象为块粒度的操作组合,其数据流仍然不够透明,限制了用户的进一步优化。此外,随着基于Python的深度学习框架【3, PyTorch, https://pytorch.org/】【22, Huggingface’s transformers: State-of-the-art natural language processing, 2019, arXiv preprint arXiv:1910.03771】的广泛采用,手动将模型转换为C++进行优化是不切实际的。因此,在设计TileLang时,我们强调三个关键原则:(1) Pythonic设计,无缝集成Python生态系统,提供熟悉的编码体验并降低学习曲线;(2) 以数据流为中心,使用户能够主要关注数据流,同时抽象出底层的调度复杂性,将线程绑定、内存布局、张量化和流水线等调度方面与数据流解耦,并将其封装为一组可定制的注解和原语,以增强可编程性和可维护性;(3) 可组合性,确保核函数、原语和调度策略可以无缝组合以构建复杂的设计。
通用矩阵乘法(GEMM)示例。下面,我们用TileLang实现一个通用矩阵乘法(GEMM)核函数,以说明其基本语法并展示它如何提高生产力。如图1(a)所示,实现首先定义了GEMM核函数的输入和输出(第8行),指定它们的形状和数据类型。随后,我们初始化核函数上下文(第9-11行),它决定了网格大小和总线程数,接着是核函数主体(第12-27行),包括片上内存分配和数据流管理。由于TileLang是一种嵌入Python的编程语言,它支持Python的所有命令式结构(如if-else、for和while),关键区别在于用户必须为函数参数和变量声明提供显式类型注解。这是因为Python的动态类型不适合设备代码生成(如CUDA/HIP),其中静态数据类型对于确定精确的数据位宽至关重要。在TileLang中,类型注解明确定义了元素类型和张量形状,确保了正确性和高效的代码生成。此外,TileLang允许显式内存分配,从而更好地控制数据布局和访问模式。在给定的实现中,TileLang使用T.alloc_shared
将A和B的子矩阵存储在共享内存中,而使用T.alloc_fragments
在块级别为累加器分配寄存器文件。此外,使用流水线执行(T.Pipelined
)可以使内存传输与计算重叠,有效隐藏内存延迟并提高整体吞吐量。T.gemm
操作利用NVIDIA CUTLASS或手动编写的HIP代码来高效地执行块级矩阵计算。通过自动化底层调度和同步,TileLang允许开发人员专注于算法设计而非硬件特定的优化,从而在保持计算效率的同时提高生产力。
编译过程。最后,我们调用tilelang.compile
(第31行),将tilelang程序降级为中间表示(IR),如图1(b)所示。这个IR随后被进一步编译成可执行文件,生成最终的优化代码,如图1(c)所示。
A2 方法细节
3. Tile语言
本节我们介绍基于块(tile)的编程模型的基础,解释TileLang如何系统地高效管理AI核函数开发,并概述TileLang将数据流与其他调度空间分离的设计理念。
五阶段编译流水线。图2展示了TileLang的五阶段编译流水线。首先,开发者使用TileLang编写高级程序来描述计算逻辑和数据访问模式。在解析器(Parser)阶段,TileLang程序被解析为Python AST,随后转换为TileLang AST。接下来,IR构建器(IR Builder)将AST转换为TVM中间表示(IR),使我们能够利用TVM的语法树和相关基础设施。之后,优化(Optimization)阶段执行一系列图优化和调度变换以提高执行效率。最后,代码生成(Codegen)阶段将优化后的IR转换为后端代码,如LLVM IR、CUDA C/C++或HIP C/C++,以支持各种硬件平台。
数据流操作符与调度原语。表1展示了TileLang提供的数据流操作符和调度原语的一个代表性子集。Tile语言采用以数据为中心的编程范式,其中核心计算语义通过块级操作符(如T.copy
、T.gemm
和T.reduce
)来表达。作为这些操作符的补充,TileLang暴露了一组调度原语,允许开发者微调性能关键方面,如并行性、流水线和内存布局。我们将在以下各节中解释这两个组件的设计。
表1. TileLang支持的部分数据流操作符和调度原语。
3.1 基于块的编程模型
图1展示了一个TileLang中的简洁矩阵乘法(GEMM)示例,说明了开发者如何使用高级结构(如块、内存布局、流水线和操作符调用)来精细控制数据移动和计算。特别是,该代码片段展示了多级分块如何利用不同的内存层次(全局内存、共享内存和寄存器)来优化带宽利用率并减少延迟。总体而言,图3(原文此处引用Figure 11,但根据上下文应为Figure 3的示意图和Figure 1的代码)展示了TileLang的类Python语法如何让开发者在一个用户友好的编程模型中对性能关键的优化进行推理。
块声明。我们方法的核心是将块(tile)视为编程模型中的一等公民。一个块代表一部分成形的数据,可以由一个warp、线程块或等效的并行单元拥有和操作。在Matmul示例中,A和B缓冲区在核函数循环内以分块的形式(由block_M
, block_N
, block_K
决定)被读取。通过T.Kernel
,TileLang定义了执行上下文,包括线程块索引(bx
和by
)和线程数量。这些上下文可以帮助我们计算每个线程块的索引,并使TileLang更容易自动推断和优化内存访问与计算。此外,这些上下文也允许用户手动控制线程块内每个独立线程的行为。
显式硬件内存分配。TileLang的一个标志性特点是能够将这些块缓冲区显式地放置在硬件内存层次结构中。TileLang没有将其留给编译器不透明的优化过程,而是暴露了直接映射到物理内存空间或加速器特定结构的用户级内在函数。具体包括:
* T.alloc_shared: 在快速的片上存储空间中分配内存,对应于NVIDIA GPU上的共享内存。共享内存非常适合在计算过程中缓存中间数据,因为它比全局内存快得多,并允许同一线程块中的线程高效地共享数据。
* T.alloc_fragment: 在fragment内存中分配累加器,对应于NVIDIA GPU上的寄存器文件。通过将输入和部分和保留在寄存器或硬件级缓存中,可以进一步最小化延迟。值得注意的是,在这个块程序中,每个块都分配了与共享内存相同的局部缓冲区,这可能看起来违反直觉,因为寄存器文件有限。这是因为这里的分配指的是整个线程块的寄存器文件。TileLang在编译期间使用布局推断过程(Layout Inference Pass)来推导出一个T.Fragment
布局对象,该对象决定了如何为每个线程分配相应的寄存器文件。
数据传输与初始化。全局内存和硬件特定内存之间的数据传输可以使用T.copy
进行管理。此外,硬件特定缓冲区可以使用T.clear
或T.fill
进行初始化。对于数据赋值,操作也可以使用T.Parallel
并行执行,如图8所示。
3.2 以数据流为中心的块操作符
TileLang抽象了一组块操作符(Tile Operators),使开发者能够专注于数据流逻辑,而无需管理每个块操作的底层实现细节。图4展示了一个块操作符的接口以及几个代表性示例,包括GEMM、Copy和Parallel。每个块操作符都需要实现两个关键接口:Lower
和InferLayout
。Lower
接口定义了如何将高级的块操作符降级为更低级的IR,例如线程绑定或向量化的内存访问。例如,Copy可以被降级为一个带有显式线程绑定和向量化加载/存储的循环。InferLayout
接口负责确定与块操作符相关的内存和循环布局。这包括推断缓冲区布局(如swizzled内存)或循环级布局(如线程绑定)。例如,T.gemm
对其共享内存输入应用swizzled布局,并使用特定于矩阵的布局来写回MMA fragments。同样,T.Parallel
中的并行循环结构可以使用线程级绑定和向量化访问模式来表示,这两者都是通过布局推断得出的。
内置操作符。表1列出了TileLang操作符的一个子集,以简化基于块编程中的常见操作。这些内置操作符抽象了硬件内存访问和计算的底层细节,允许开发者从数据流的角度专注于高级算法设计,同时保持对性能关键方面的细粒度控制。下面描述几个关键操作符:
* copy: copy
操作是T.Parallel
与内存拷贝的语法糖,允许从/向不同作用域(fragment
对应寄存器,shared
对应静态共享内存,shared.dyn
对应动态共享内存,global
对应全局内存)进行拷贝。
* gemm: 内置的T.gemm
操作符是针对通用矩阵乘法的高度优化实现,支持多种内存访问模式(ss, sr, rs, rr),其中r
表示寄存器内存,s
表示共享内存。该操作符会根据核函数配置自动选择最优实现。对于CUDA后端,T.gemm
利用Nvidia的CUTLASS库【11, Cutlass: Cuda templates for linear algebra subroutines, 2024, NVIDIA Corporation】来高效利用Tensor Cores或CUDA Cores;对于AMD GPU,它同时使用可组合核函数和手写的HIP代码进行性能优化。用户还可以通过在Python中注册自定义原语来扩展T.gemm
。
* reduce: T.reduce
操作符为跨维度聚合数据提供了一个灵活高效的归约机制。它支持多种归约操作,如求和、求最小/最大值和乘积等。T.reduce
被实现为利用warp级和block级的并行性,以在CUDA和AMD后端上获得最佳性能。
* atomic: T.atomic
操作符为在并行上下文中安全更新共享或全局内存提供原子操作。支持常见的原子操作如add
, min
, max
。它利用NVIDIA和AMD GPU上的原生硬件原子指令,确保在并行执行中保持正确性的同时实现高性能。
3.3 调度注解和原语
为了对执行模式进行更细粒度的控制,TileLang提供了一套全面的调度原语,使开发者能够精确调整其应用程序的性能关键方面,如表1所示:
* Pipelined: T.Pipelined
原语允许对循环进行高效的流水线执行,通过重叠计算和内存操作来提高性能。在图1中,对k维度(归约维度)的循环使用了num_stages=3
进行流水线化,创建了一个3级流水线。
* Parallel: T.Parallel
原语通过将迭代映射到线程来实现循环的自动并行化。在图8中,将数据复制到A_shared
的操作使用T.Parallel(8, 32)
在8和32两个维度上进行并行化。
* annotate_layout: T.annotate_layout
原语允许您使用用户定义的内存布局为共享或全局内存指定内存布局优化。默认情况下,TileLang采用一种旨在最小化Nvidia和AMD GPU上bank冲突的优化内存布局。
* use_swizzle: T.use_swizzle
原语通过启用swizzled内存访问来提高L2缓存局部性,从而改善光栅化操作的数据重用。
4. 调度设计与自动化
本节我们讨论除数据流之外的四种调度空间及其在TileLang中的自动化设计。其中一些相对独立(如流水线和张量化),而另一些则耦合更紧密,如线程绑定和内存布局设计。
4.1 内存布局组合
Layout抽象。在TileLang中,我们支持使用如A[i, k]
这样的高级接口对多维数组进行索引。这种高级索引最终通过一系列软硬件抽象层转换为物理内存地址。为了对这个索引转换过程建模,我们引入了关键的抽象Layout
,它描述了数据在内存中的组织和映射方式。在物理地址层面,一个布局可以表示为一个线性地址表达式$`\sum_{d} i_d \times s_d`$,其中$`i_d`$是第d维的索引,$`s_d`$是该维度对总线性内存地址的贡献步长。给定一个布局$`L = \text{shape} : \text{strides} = (d_0, d_1, ..., d_{n-1}) : (s_0, s_1, ..., s_{n-1})`$,TileLang采用了一种受TVM【8, TVM: An automated end-to-end optimizing compiler for deep learning, 2018, 13th USENIX Symposium on Operating Systems Design and Implementation (OSDI 18)】启发的设计,引入了基于IterVar
的可组合、可堆叠的布局函数抽象。由于IterVar
可以封装步长信息,布局表达式可以简化为IterVar
上的代数形式。因此,布局函数可以正式表示为一个映射$`f: \mathbb{K}^m \rightarrow \mathbb{K}^n`$,其中$`f`$编码了从高级索引到内存地址的转换。
Layout函数定义与组合。图5(a)展示了TileLang中Layout
的定义。其核心组件包括iter_vars
(可选择携带范围信息)和一组forward_index
表达式,这些表达式根据迭代变量计算内存位置。这些表达式共同定义了一个代数函数$`f: \mathbb{K}^m \rightarrow \mathbb{K}^n`$。如图5(b)所示,这允许表达一个2D到1D的布局转换。TileLang还支持非双射的布局转换,例如图5(c)展示了如何使用布局对缓冲区访问应用填充。这些布局转换是可组合的,TileLang包含几种内置布局策略,如布局swizzling,常用于减轻GPU上的共享内存bank冲突。
Fragment Layout扩展。此外,TileLang引入了Layout
抽象的扩展,称为Fragment
。与标准布局不同,Fragment
布局总是产生形式为$`f: \mathbb{K}^m \rightarrow \mathbb{K}^2`$的输出,其中两个输出维度分别代表线程在寄存器文件中的位置和在本地寄存器文件中的索引。例如,在图1中,核函数在块级别分配了一个寄存器文件C_local
。由于GPU寄存器文件必须在块内线程之间划分,Fragment
布局为这种划分方案提供了精确的描述。
Fragment Layout的构建。图6(a)展示了Fragment
布局的定义,TileLang提供了四种基本操作来帮助用户扩展现有的Fragment
布局。图6(b)展示了如何使用这些原语从用于m16k16
矩阵碎片的mma_ldmatrix
指令的基础布局中派生出一个完整的块级布局。在这里,base_layout
表示单个warp消耗一个m16k16
矩阵的布局。该布局通过repeat
原语扩展为warp_layout
,允许单个warp消耗一个m32k16
矩阵。图6(c)可视化了这一转换。然后,warp_layout
使用repeat_on_thread
和replicate
等原语进一步扩展,生成一个block_layout
,表示四个warp共同消耗一个m128k16
矩阵。
4.2 线程绑定
线程绑定问题与关键观察。基于Fragment
布局的抽象,一个关键挑战是如何在执行期间将这些布局映射到线程上。这引出了线程绑定(Thread Binding)问题,即如何确定在各个线程之间分配块级寄存器文件以及如何推断合适的fragment布局。此外,它还要求确定循环应如何正确并行化以匹配布局约束。尽管Fragment
布局有助于简化此过程,但为任意计算表达式确定所有缓冲区的合适fragment布局仍然很困难。我们有两个关键观察来指导这个过程:首先,由于多个块操作符经常共享相同的缓冲区,它们各自的布局和线程绑定策略是相互依赖的。其次,不同操作符对布局和线程绑定要求的严格程度不同。例如,在GPU上,GEMM操作符(利用Tensor Cores)对布局和线程绑定都有严格的约束,而逐元素操作符通常允许更大的灵活性。
基于优先级的推断方案。基于这些观察,我们提出了一种基于Layout
和Fragment
对象的推断方案来优化缓冲区布局和线程绑定。为了系统地管理缓冲区布局,我们维护一个LayoutMap
来记录所有缓冲区的布局信息。我们为块操作符布局定义了一个分层优先级系统,其中更高的优先级表示更严格的布局要求和更大的性能影响。TileLang以自顶向下的方式处理布局推断,从最高到最低优先级顺序推断布局。在每个优先级级别,TileLang尝试为所有未确定的缓冲区推断布局,直到无法取得进一步进展,然后才进入下一个较低的优先级。
推断示例。如图7所示,考虑一个场景,其中矩阵C是GEMM操作的结果,对应一个Fragment
对象,需要在GEMM计算后加上偏置D。鉴于GEMM在推断过程中具有最高优先级,其线程绑定配置是预先确定的,而D的线程绑定策略待定。输出矩阵C的维度为4×4,分布在8个线程上,每个线程负责2个元素。因此,偏置缓冲区D的布局必须与此配置对齐。由于张量C的每一行由2个线程处理,两个线程都需要访问D中相同的元素进行加法运算。因此,D必须被复制以确保每个线程都能访问相应的元素。D的布局可以用相同的方法推断出来。
多阶段自动线程绑定推断。图8展示了线程绑定推断过程的一个例子。具体来说,图8(a)展示了一个用于复制数据的简单代码片段,描述了子块从全局内存传输到共享内存的数据流。正确的线程绑定和向量化访问可以充分利用GPU的并行性并利用高性能内存访问指令。在图8(b)中,T.copy
操作被展开为多个循环轴。在应用布局推断过程后,如图8(c)所示,程序经历了自动向量化和并行化。最后,在图8(d)所示的阶段,应用了布局Swizzling。
4.3 利用高性能硬件指令
硬件指令的多样性。现代硬件架构通常支持多种指令路径来实现相同的计算操作。例如,在NVIDIA GPU上,一个8位乘法累加操作可以通过几种类型的指令实现。IMAD
指令执行标量融合乘加操作;DP4A
指令支持向量化点积操作;MMA
指令利用Tensor Cores执行更高吞吐量的矩阵计算。在NVIDIA RTX 3090 GPU上,这些指令的吞吐量分别约为17.8 TOPS、71.2 TOPS和284 TOPS。此外,MMA
指令在相同精度设置下支持多种形状。
在TileLang中调用硬件指令。如图10(a)和(b)所示,在TileLang中有两种方法调用硬件张量指令。第一种方法(图10(a))是使用C++源码注入,其中像dp4a
这样的指令通过C++模板手动包装,并通过T.import_source
和T.call_extern
注入到核函数中。第二种方法(图10(b)),TileLang提供了一个内置的T.ptx
原语,允许在核函数内部直接发出内联PTX指令(例如,mma.m16n8k32.row.col.s32.s8.s8.s32
)。
(注:原文中此部分引用Figure 9,但内容与Figure 10对应,此处修正)
与Tile库集成。为了简化指令选择过程,TileLang还支持与Tile库集成,如图10(c)所示。Tile库——如NVIDIA的cute
或AMD的composable kernel (ck)
——为GEMM等操作提供了高级、标准化的基于块的API(例如,tl::gemm_ss
)。这些库抽象了硬件特定细节,并允许底层实现为给定的输入配置自动选择最高效的指令。在TileLang中,开发者可以使用T.call_extern
以直接和一致的方式调用这些库。
两种方法的权衡。总的来说,TileLang提供了两种互补的方法来利用高性能指令。第一种是利用Tile库,它简化了集成并受益于供应商优化的性能,但高级抽象可能限制底层控制。例如,cute::gemm_ss
接口在共享内存输入上执行GEMM操作,但从共享内存到寄存器的数据流由cute
模板内部管理,无法从外部注解或覆盖内部布局,从而降低了灵活性。此外,由于大量使用模板,编译速度会显著变慢。第二种方法是TileLang允许通过T.gemm
直接使用tilelang本身实现指令。这避免了布局注解的限制并减少了编译时间,但要求用户在tilelang内为每个目标硬件指令实现完整的指令集。目前,TileLang支持这两种方法,默认使用基于Tile库的方法以促进对新硬件指令的快速支持。
4.4 软件定义流水线
自动化流水线推断。TileLang采用自动化的软件流水线推断机制,分析计算块(例如Copy和GEMM)之间的依赖关系,并生成一个结构化的流水线调度,以最大化并行性同时保持正确的执行顺序。该机制将Copy任务与其它计算密集型操作交错执行以减少空闲时间,并在检测到异步处理机会时,自动将这些任务映射到可用的硬件资源上进行并发执行。因此,TileLang可以只向用户暴露一个num_stages
接口,从而显著简化了流程。当然,如果需要,我们也允许用户显式提供关于顺序和阶段的信息。
Ampere架构支持。对于Ampere架构,TileLang使用cp.async
指令支持异步内存复制操作。TileLang通过分析循环结构并为符合条件的内存传输自动插入cp.async
指令来集成此功能。此外,TileLang确保正确使用cp.async.commit
和cp.async.wait
指令来处理同步,保证数据正确性。
Hopper架构支持。在Hopper架构中,引入了两个新特性:一是新的TMA单元,作为专门负责全局内存和共享内存之间数据复制的硬件单元;二是PTX指令集引入了新的wgmma
指令,它使一个warpgroup(由四个warp组成)能够执行矩阵乘法(MMA)操作以提高TensorCore利用率,并且wgmma.mma_async
指令是异步的。此外,Hopper架构的核函数优化通常采用warp specialization,即将线程分为生产者和消费者。生产者线程使用TMA移动数据,而消费者线程负责计算。在TileLang中,我们在降级过程中自动执行warp specialization优化。TileLang分析所有语句的缓冲区使用情况,确定其角色(生产者或消费者),并根据threadIdx
将它们划分到不同的执行路径。为确保计算正确性,TileLang利用Live Variable Analysis来确定适当的同步点并相应地插入内存屏障(mbarriers
)。
AMD CDNA架构支持。AMD CDNA架构中也提供了异步复制指令和DMA支持,TileLang通过HIP封装的Copy原语来利用这些功能。具体来说,TileLang利用s_waitcnt lgkmcnt
和buffer_load_dword lds
等指令来高效管理内存传输。
A4 实验环境
- 硬件平台:
- NVIDIA H100 (80 GB)
- NVIDIA A100 (80 GB)
- AMD Instinct MI300X (192 GB)
- 软件配置:
- NVIDIA H100: CUDA 12.4
- AMD MI300X: ROCm 6.1.0
- 操作系统: Ubuntu 20.04
- 算子工作负载:
- NVIDIA H100: 多头注意力 (MHA), 线性注意力, 通用矩阵乘法 (GEMM)。
- NVIDIA A100: 去量化GEMM核函数。
- AMD MI300X: GEMM 和 MHA。
- 基准 (Baselines):
- FlashAttention-3: 针对多头注意力优化的手工CUDA实现。
- Triton: 开源的GPU核函数高效生成框架。
- cuBLAS: NVIDIA的高性能稠密线性代数库。
- rocBLAS: AMD的BLAS库。
- PyTorch: 包含手写优化核函数(如GEMM和FlashAttention-2)。
- BitsandBytes: 专为
NF4
等量化格式提供高效核函数的库。 - Marlin: 针对
INT4
-FP16
计算的高度优化核函数库。
A4 实验结果
Flash Attention性能 (H100)。与FlashAttention-3、Triton和PyTorch相比,TileLang分别实现了1.36倍、1.41倍和1.70倍的加速。FlashAttention-3是一种手工方法,无法高效适应不同大小的工作负载,其固定的块大小导致在较短序列长度上性能不佳。对于较长序列(如8k),TileLang的性能接近FlashAttention-3。TileLang能够自动利用cp.async.bulk
和wgmma.mma_async
等指令,并自动应用warp specialization等优化。
线性注意力性能 (H100)。在线性注意力实验中,我们使用了Mamba-2中的chunk-scan
和chunk-state
函数。与Triton相比,TileLang分别取得了1.77倍和2.10倍的平均加速。
多头潜在注意力(MLA)性能 (H100 & MI300X)。如图14所示,在H100上,TileLang相比Torch实现了1075.9倍的加速,显著优于Triton和FlashInfer,并达到了手写优化FlashMLA性能的98%。同时,TileLang仅需约70行Python代码,易用性远超其他基准。在MI300X上,TileLang相比Torch实现了129.2倍的加速,在性能和代码紧凑性上均超越Triton,并达到了手写库AITER性能的95%。
矩阵乘法(Matmul)性能。如图13所示,在RTX 4090、A100、H100和MI300X上,TileLang相比供应商库分别实现了1.10倍、0.97倍、1.00倍和1.04倍的加速。与Triton相比,TileLang在相同GPU上实现了1.08倍、1.03倍、1.13倍和1.25倍的加速。TileLang使用简单的语法即可匹敌供应商优化库的性能,并通过Layout Swizzling确保在所有测试设备上无bank冲突。
去量化矩阵乘法性能 (A100)。我们将高性能混合精度计算库BitBLAS的后端替换为TileLang。与cuBLAS-FP16
-FP16
相比,BitBLAS-TileLang-INT2
-INT8
配置最高实现了7.65倍的加速。对于INT4
-FP16
格式,我们的方法比Marlin平均快1.04倍;对于NF4
-FP16
格式,比BitsandBytes平均快1.62倍。TileLang通过暴露线程级编程接口和对数据布局、流水线的控制,提供了更细粒度的优化能力,例如使用基于PTX的快速数值精度转换指令和利用Ladder实现更平滑的内存访问,这些优化在Triton中难以实现。
A5 结论
本文介绍了TileLang,一种类Python的领域特定语言(DSL),它通过将数据流与调度解耦,在易用性和细粒度控制之间取得了平衡。与Triton不同,TileLang允许用户在前端显式声明硬件内存层次中的缓冲区,并利用布局推断机制高效地并行化缓冲区操作。用户只需描述缓冲区的计算逻辑,而无需担心并行化的实现。同时,TileLang为专家提供了明确指定单个线程在缓冲区上确切行为的灵活性。相比ThunderKittens【4, ThunderKittens, https://github.com/HazyResearch/ThunderKittens】,TileLang简化了编程过程,默认抽象了流水线等优化细节,并为动态参数、动态形状等高级功能提供了强大支持。
未来工作展望:
1. 自托管Tile库:构建基于TileLang的自托管Tile库,消除对CUTLASS和手动包装CUDA/HIP代码的依赖。
2. 分布式支持:通过引入块级通信原语和调度策略,将TileLang扩展到分布式场景。
3. 成本模型设计:利用TileLang明确的线程映射、内存访问和计算行为,设计更有效的成本模型。
4. 动态形状优化:探索针对动态变化维度的程序选择最合适块配置的优化。
5. 多后端支持:探索通用设计方法,将支持扩展到CPU、NPU等多种硬件平台。
A6 附录
A. 基准测试中的算子形状
表2. 基准测试中的矩阵形状。
表3. 基准测试中的FlashAttention形状。
表4. 基准测试中的线性注意力形状。
B. 核函数实现
B.1 矩阵乘法 (Matmul)
B.2 去量化矩阵乘法
B.3 FlashMLA 实现
💬 评论讨论
欢迎在这里分享您的想法和见解!