The Evolution and Applications of CuTeDSL
The Evolution and Applications of CuTeDSL
Jie Fang (方杰) NVIDIA GPU加速计算专家团队、高级工程师
Yuxi Chi (池宇希) NVIDIA GPU加速计算专家团队、高级工程师
议程 (Agenda)
-
CuTeDSL
- 为何选择 CuTeDSL
- 特性 (Features)
- 更新 (Updates)
- CuTeDSL 的应用
- 路线图 (Roadmap)
-
CuTeDSL 作为 Tilelang 的代码生成目标 (CodeGen Target)
- 基于 Tile 的 DSL 概述
- Tilelang/TIR 作为原生代码生成器
- CuTeDSL 作为代码生成目标
为何选择 CuTeDSL
CUTLASS C++ 概述
CUTLASS (CUDA Templates for Linear Algebra Subroutines) 是一个高性能的 CUDA C++ 模板库,它提供了一套用于生产力和性能的抽象,适用于所有范围和规模。该项目是开源的,可在 https://github.com/NVIDIA/cutlass 获取。
其架构分层,从提供更多预调优方案(recipes)的上层到底层提供更多控制(control)的原子操作。
使用 C++ 的主要痛点
-
C++ 模板编译时间过长
- 前端对于我们的目的来说过于通用。
- 阻碍了快速迭代。
- 禁止了即时编译(JIT-ting)和大规模的暴力自动调优。
-
C++ 模板不方便
- 编写编译时逻辑会增加额外的脑力负担。
- 错误信息比小说还长。
-
深度学习领域全面拥抱 Python 生态系统
- 每个人都讨厌编写绑定代码。
- 依赖于 nvcc。
-
大语言模型(LLM)可能更擅长生成 Python 程序。
这些问题引出了一个核心疑问:我们是否必须忍受 C++ 的这些不便?
CuTe DSL 的关键优势
-
在接口和概念上与 CuTe C++ 保持一致
- 可以完全自由地设计新颖的核函数(kernels)。
-
告别 C++ 模板!
- 极快的编译速度。
- 更好的调试信息。
- 更快的原型设计循环。
- 更广阔的自动调优空间。
-
显著简化与 Python 框架的集成
- 通过缓存进行 JIT 编译,减少了开销。
- 无 nvcc 或 CUDA toolkit 依赖。
- 支持 DLPack 和 Torch 张量格式。
- 自动调优和基准测试代码也可以直接用 PyTorch 编写。
显著更快的编译速度
与 C++ 相比,CuTeDSL 的编译速度显著提升了30至100倍。下图展示了在 Blackwell 架构上,Cutlass C++ 与 CuTe DSL 在不同应用场景下的编译时间对比。
- FP16 I/O GEMM: C++ 耗时 28.00秒,而 DSL 仅需 0.24秒。
- FP16 I/O FA (无掩码, hd=64): C++ 耗时 41.19秒,DSL 仅需 0.73秒。
- FP16 I/O FA (有因果掩码, hd=64): C++ 耗时 40.68秒,DSL 仅需 1.23秒。
与 C++ 性能持平
尽管编译速度大幅提升,CuTeDSL 在数学计算效率上仍能与 C++ 实现相媲美。
下图比较了在 Blackwell 架构上,不同问题规模下 C++ 和 DSL 的数学计算效率(Math Efficiency %)。数据显示,在稠密GEMM、分组GEMM以及Flash Attention等多种场景下,DSL 的性能与 C++ 基本持平,甚至在某些情况下略有优势。
CuTeDSL 在不同工具中的定位
CuTeDSL 在生产力(productivity)和性能(performance)之间取得了独特的平衡。与 Torch 和 Triton 等更高层次的抽象相比,CuTeDSL 提供了接近 CUDA/PTX 的极致性能,同时保持了比直接编写底层代码更高的生产力。
下表总结了不同工具的性能和上手时间:
| 工具 | 性能 (内存密集型) | 性能 (计算密集型) | 上手时间 |
|---|---|---|---|
| Torch compile | ~90% | ~70-80% | 小时-天 (hours-days) |
| Triton | ~90% | ~80-90% | 天-周 (days-weeks) |
| Cute-DSL | 100% | 100% | 周-月 (weeks-months) |
该图表引用自 tri dao 的幻灯片。
特性 (Features)
支持 DLPack 协议
CuTeDSL 通过支持 DLPack 协议,能够与 PyTorch 等深度学习框架无缝集成。它支持从 torch tensor 的隐式转换,并提供了 mark_layout_dynamic 和 mark_compact_shape_dynamic 等接口来处理动态的布局和形状。
即时编译缓存 (JIT Caching)
CuTeDSL 支持 JIT 编译结果的缓存,允许用户定义自己的缓存机制,从而避免重复编译,进一步加快开发和执行速度。
打印与调试 (Print)
CuTeDSL 提供了强大的打印功能,用于调试。与 Python 的 print 函数在编译时打印静态值不同,cute.printf 可以在运行时同时显示静态和动态值,这对于调试 GPU kernel 至关重要。
软件流水线 (Software pipeline)
CuTeDSL 的编译器能够生成高效的软件流水线代码,其性能与专家手写的版本相当。下图比较了在不同问题规模和分块配置下,手写代码与编译器生成的代码在数学计算效率上的表现,两者几乎没有差异。
更新 (Updates)
示例的演进
从最初发布的12个示例,CuTeDSL 已经扩展到最新的31个示例。示例库的增长反映了其功能的丰富和对更多新架构(如 Blackwell, Hopper)和新模型(如 Mamba2, Flash Attention)的支持。
最新更新摘要
- 增加了对 aarch64 架构的支持。
- 支持更多的 Python 版本: Python 3.10, 3.11, 3.12, 和 3.13。
- 可调试性改进:
- 支持对 DSL API 的源代码位置跟踪。
- 支持导出 PTX 和 CUBIN 代码。
CuTeDSL 的应用
用户案例
CuTeDSL 已被业界广泛采用:
- QuACK: 通过 CuTe DSL 构建内存密集型操作(由 TogetherAI 的首席科学家 Tri Dao 创建)。
- Flash Attention v4: Tri Dao 在 HotChips 会议上宣布了该技术。
- 大量基于 CuTe DSL 的 FP4 稠密 GEMM 优化(PRs #1, PR#2)已合并到 TRTLLM GH。
- 多个 DSL 示例(FP4 稠密 GEMM, GEMM+AR)已集成到
flashinfer中。 - Sglang、Meta、Google、xAI、Torch、Megatron、Baseten、Cohere 和 MiPerf 团队正在探索使用 CuTe DSL 示例。
flash-attention 4
下图展示了 FlashAttention 4 在 Blackwell 架构上的性能。在 B200 GPU 上,对于前向传播(causal=True, headdim=128),FA4 的 TFLOPS 性能显著优于 cuDNN,尤其是在序列长度较长时。
- 环境: cuDNN 9.11.0, CTK 13.0
- 来源: 该图表来自 Tri Dao 的幻灯片。
- 代码: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/cute/flash_fwd_sm100.py
QuACK 内存密集型案例
CuTeDSL 使客户和框架能够轻松采用和自定义其他内核。下图展示了 QuACK 在内存密集型操作(RMSNorm, Softmax, Cross-Entropy)上的性能,并与 torch.compile、Liger Kernel 和 cuDNN 进行了比较。QuACK (基于 CuTe DSL 构建) 在 RMSNorm 和 Softmax 上表现出优越的性能。
- 测试平台: 1 NVIDIA H100 80GB HBM3, Intel Xeon Platinum 8468 CPU
- 软件版本: Pytorch 2.7.1; Liger Kernel v0.5.10; cuDNN v9.10.1
- 来源: 该图来自 QuACK。
路线图 (Roadmap)
未来的发展计划包括:
- 更多生产就绪的示例:
- Blackwell GEMM + COMM via NVSHMEM API
- Blackwell FP16/FP8 稀疏 GEMM
-
DSL 特性:
- 功能完备的预编译(Ahead-of-time, AOT)支持
- 支持 64 位步长以适应超大问题规模
-
框架集成:
- JAX
-
CUDA Dialect
- TVM-FFI
- 更多教学教程
CuTeDSL 作为 Tilelang 的代码生成目标 (CodeGen Target)
基于 Tile 的 DSL 概述
编程模型抽象
基于 Tile 的 DSL 采用分层编程模型来分解计算任务:
- Grid level (网格层): 代表整个工作负载。用户仅控制网格级别的操作。
- Block level (块层): 网格被划分为多个块(Tiles)。用户需要控制任务的划分以及对 Tiles 的操作。
- Thread level (线程层): 块被映射到线程。用户需要控制任务划分、线程映射以及线程级别的操作。
OpenAI Triton 与 Tilelang 对比
下表比较了 OpenAI Triton、Tilelang 和 CuTeDSL 在不同抽象层次上的控制粒度。CuTeDSL 提供了最显式和手动的底层控制,而 Triton 提供了最高层次的自动化。
| 特性 | OpenAI Triton | Tilelang | CuTeDSL |
|---|---|---|---|
| 内存分配 | 自动 | 自动 T.alloc_xxx |
GMEM, SMEM, TMEM, RF |
| 内存布局 | 自动 | 自动 T.annotate_layout |
cute.Layout |
| 内存移动 | 自动 | T.Copy |
cute.CopyAtom |
| 计算策略 | 自动 | T.Gemm |
cute.MmaAtom |
| 流水线 | stages | stages | 手动 |
Tilelang/TIR 作为原生代码生成器
Tilelang 作为原生代码生成器
Tilelang 是一个硬件无关的张量中间表示(Tensor Intermediate Representation, TIR),具有以下特点:
- 显式控制: 允许对内存分配和内存移动进行显式控制。
- Tile 索引: 提供在 Tile 内进行索引的能力。
- 代码注入: 支持 C++ 和 PTX 指令注入。
- 原生代码输出:
- 适用于 NVIDIA GPUs (Volta, Ampere, Hopper 等) 的 CUDA C++
- 适用于 AMD GPUs (MI300, MI250 等) 的 HIP C++
- 适用于 CPUs 的 C++/llvm
- 适用于其他加速器的原生代码
下图展示了从初学者到专家的编程抽象层次,从高层的 Tile 程序到底层针对特定硬件(NVIDIA/AMD GPUs)的可执行文件。
Tilelang 到 TIR 的降级 (Lowering)
Tilelang 提供了高层原语(如 T.alloc_shared, T.gemm),通过 tilelang.compile 将高级 Tilelang 代码降级为更详细的 TIR 代码。TIR 包含了如 T.vectorized, T.unroll 等更底层的原语。
以下是针对不同硬件架构的 TIR 代码生成示例:
- sm_80 (Ampere): 编译器会自动选择适用于 Ampere 架构的流水线(Pipeline)、
CopyAtom和MmaAtom,并生成相应的 TIR 代码。
- sm_90 (Hopper): 同样,对于 Hopper 架构,编译器会选择并生成利用其特有硬件功能(如 TMA 和 mbarrier)的 TIR 代码。
CUDA 代码生成: 从 TIR 到 CUDA
TIR 原语最终会被重写为 CUDA 代码。
- 内存分配和 Mbarrier 初始化: TIR 中的内存分配和屏障创建原语会映射到 CUDA 中的 __shared__ 内存声明和 mbarrier 初始化函数。
- CUDA API 包装器: https://github.com/tile-ai/tilelang/tree/main/src/tir/templates/cuda
- Warp 特化与计算: TIR 中的控制流和计算原语(如 TMA 操作、GEMM)会转换为 CUDA 中对应的硬件指令和内建函数(intrinsics),例如
__tma_load,cute::gemm等。
CuTeDSL 作为代码生成目标
Tilelang CuTeDSL 后端
Tilelang 使用 CuTeDSL 作为后端,以生成高效的设备代码。该后端的功能分解如下:
- Python 代码生成 (Host code):
- 生成主机端代码,包括 JIT 包装器、内核启动配置(grid, block, smem)和内核调用。
- 管理 TMA 描述符。
- 设备代码 (Device code):
- 利用 Ampere 特性(
cp.async,ld/st.matrix)和 Hopper 特性(TMA ld/st,mbarrier,Warp Specialization,wgmma.async)进行矢量化操作。
- 利用 Ampere 特性(
下图展示了 Python 代码如何调用后端,将 PyTorch 张量转换为 CuTe 张量,并启动编译后的 CUDA 内核。
示例:逐元素加法
下面是一个将 Tilelang 代码转换为 CuTeDSL 代码的具体示例:
- Tilelang 代码: 高度抽象,简洁地描述了逐元素加法操作。
- 生成的 CuTeDSL 代码: 更加具体和底层,显式处理线程/块索引,并使用 CuTe 的张量和布局来执行操作。
- 设备代码: CuTeDSL 使用 Tensor SSA 来实现矢量化的加载/存储和二元操作。
示例:Ampere GEMM
此页面展示了如何使用 TileLang 编写一个通用的矩阵乘法(GEMM)内核,并将其编译为针对 NVIDIA Ampere 架构优化的 CuTeDSL 代码。
-
左侧:TileLang 代码
- 定义了一个
gemm函数,接收张量 A, B, C 作为输入。 T.celldiv用于在线程块之间划分工作。- 使用
T.alloc_shared为 A 和 B 的分片分配共享内存 (A_shared,B_shared)。 - 使用
T.alloc_fragment为 C 的分片分配寄存器文件内存 (C_local)。 T.piplelined循环是实现软件流水线的主体,用于重叠计算和数据加载。T.copy(..., async=True): 从全局内存异步复制数据到共享内存(cp_async G2S)。T.gemm: 在共享内存中的数据上执行矩阵乘法(mma GEMM)。T.copy: 将计算结果从寄存器文件写回全局内存(STG)。
- 定义了一个
-
右侧:生成的 CuTeDSL 代码
- 这是由 TileLang 编译器为专用的 CuTeDSL 后端生成的 Python 代码。
- 导入了
cutlass.cute和cutlass.core库。 - 定义了一个
gemm_kernel内核。 - 代码中使用了 CuTeDSL 的底层抽象,如
tl.Thread,tl.make_fragment,tl.copy_async,tl.gemm等,精确地映射了 Ampere 架构的硬件特性。
-
Ampere 风格流水线 (Ampere style Pipeline)
- 序幕 (Prologue): 为前 #stages-1 个分片发出异步内存加载指令。
-
主循环 (Mainloop): k-Loop 从 0 到 #tiles - (#stages-1)。
- 发出第 #k+stages-1 个分片的异步内存加载。
- 等待第 #k 个分片的异步内存加载完成。
- 处理第 #k 个分片。
-
尾声 (Epilogue): 等待并处理最后 #stages-1 个分片。
示例:Hopper fp8 GEMM
此页面展示了针对 NVIDIA Hopper 架构的 FP8 GEMM 的一个更复杂的示例。
-
左侧:TileLang 代码
- 定义了一个
main函数,其参数更复杂,以支持 Hopper 架构的特性。 T.Prefetch: 使用张量内存加速器(TMA)进行数据预取。T.wgmma: 使用 Hopper 架构的 Warp Group Matrix-Multiply-Accumulate 指令进行计算。- 通过
T.Pipelined实现了双缓冲流水线,一个 warp 组负责数据加载(TMA warp),另一个 warp 组负责计算(GEMM warp)。
- 定义了一个
-
右侧:生成的 CuTeDSL 代码
- 同样是为 CuTeDSL 后端生成的代码。
- 使用了针对 Hopper 的特定 CuTeDSL API,如
tl.mbarrier_init(多线程屏障初始化),tl.prefetch_tma_descriptor(TMA 预取描述符),tl.dot(wgmma 的 CuTeDSL 封装)。 - 代码结构反映了 warp-specialized pipeline 的逻辑。
-
Hopper 风格的 Warp-Specialized 流水线 (Hopper style Warp-specialized Pipeline)
-
TMA warp: 循环处理 #stages 缓冲区。
- 等待 mbarrier 为空。
- 发出 TMA 加载指令。
- 在 mbarrier 上标记缓冲区已满。
-
GEMM warp: 循环处理 #stages 缓冲区。
- 等待 mbarrier 为满。
- 发出 wgmma GEMM 指令。
- 在 mbarrier 上标记缓冲区已空。
-
CuTeDSL 后端 vs CUDA 后端
此页面通过实验数据对比了 CuTeDSL 后端和传统的 CUDA 后端在编译时间和内核运行时间上的表现。
性能对比表
| 示例 | 后端 | 首次编译时间 | 内核运行时间 |
|---|---|---|---|
| Elementwise Add | CUDA | 11.81s | 9.622 ± 0.038 µs |
| CuTeDSL | 0.68s | 9.974 ± 0.040 µs | |
| Ampere Gemm Float16 M=N=K=8192 |
CUDA | 15.43s | 6.042 ± 0.234 ms |
| CuTeDSL | 0.59s | 5.498 ± 0.260 ms | |
| Hopper Gemm Float16 M=N=K=8192 |
CUDA | 15.38s | 3.006 ± 0.035 ms |
| CuTeDSL | 0.66s | 3.034 ± 0.032 ms |
-
关键结论:
- 编译时间: CuTeDSL 后端的首次编译时间远快于 CUDA 后端(例如,Ampere Gemm 从 15.43s 降至 0.59s)。这对于快速迭代和自动调优至关重要。
- 运行时间: CuTeDSL 后端生成的内核运行时间与 CUDA 后端相当,甚至在某些情况下(如 Ampere Gemm)表现更优。
-
备注:
- 首次编译时间包括 Tilelang lowering 和后端编译。
- Ampere Gemm 在 A100 GPU 上运行。
- Hopper Gemm 在 H100 GPU 上运行。
CuTeDSL 作为代码生成目标:总结
本页总结了将 CuTeDSL 作为 TileLang 代码生成目标的优势。
-
使用 Tilelang 作为前端 DSL
-
使用 CuTeDSL 后端
- 享受极快的编译速度
- 实现快速验证
- 支持自动调优
-
集成
- 可直接集成到 Python 框架中
- 能够像 CUDA 后端一样转储 cubin 文件
-
进一步加速
- 可以修改生成的 CuTeDSL 代码进行深度优化
- 可以切换到原生的 CuTeDSL 实现以获得最大灵活性