Enable Tensor Core Programming in Python with CUTLASS 4.0

Albert Di, Vincent Zhang | 2025-05-30

目录

  1. Python DSL 概述

    • 为何转向 Python
    • DSL 基础架构
  2. Hopper GEMM 实现

    • GEMM 概述
    • MMA/Copy 原子操作
    • Hopper GEMM 代码演示
    • 性能比较
  3. 结论

Python DSL 概述

为何转向 Python

使用 C++ 的主要痛点

C++ 模板及其带来的不良后果:

  • C++ 模板编译时间过长

    • 前端对于我们的目的来说过于通用
    • 阻碍了快速迭代
    • 禁止大规模的即时编译(JIT-ting)和暴力自动调优
  • C++ 模板使用不便

    • 编写编译时逻辑时需要额外的脑力负担
    • 错误信息比小说还长
  • 深度学习领域完全拥抱 Python 生态系统

    • 每个人都讨厌编写绑定代码
    • 依赖于 nvcc
  • LLM 可能更擅长生成 Python 程序

Page 4
Page 4

推出 CUTLASS 4.0

在 Python 中进行张量核心(Tensor core)编程。

Page 5
Page 5

CUTLASS in Python

初始版本包含 CuTe。

CUTLASS in Python 架构 (Page 6)
CUTLASS in Python 架构 (Page 6)

架构特点:

  • 层级结构类似 CUTLASS C++

    • 旨在成为 CUTLASS C++ 的 1:1 对应版本。
    • 为生产力和峰值性能而设计。
    • 通过简洁的编程模型实现底层控制。
    • 支持从 Ampere 开始的所有架构。
  • JIT 编译

    • 在运行时使用 MLIR 基础架构编译成 CUDA 设备代码。
  • 未来发展

    • 将支持更高层次的抽象。

为什么选择 Python DSL?

CuTe DSL 的主要优势。

使用 Python DSL 的优势 (Page 7)
使用 Python DSL 的优势 (Page 7)
  • 与 CuTe C++ 在接口和概念上保持一致

    • 可以完全自由地设计新颖的内核。
  • 告别 C++ 模板!

    • 极快的编译速度。
    • 更好的调试信息。
    • 更快的原型设计循环。
    • 更广阔的自动调优空间。
  • 显著简化与 Python 框架的集成

    • 使用缓存进行 JIT 编译,减少了开销。
    • 无 NVCC 或 CUDA 工具包依赖。
    • 支持 DLPack 和 Torch 张量格式。
    • 在框架内进行自动调优和基准测试。
    • 测试代码也可以直接用 PyTorch 编写!

从上图可以看出,Python 的编译时间(241 ms)远低于 C++(27997 ms),提升超过100倍。

CUTLASS in Python 入门

如何开始使用:

CUTLASS in Python 入门示例 (Page 8)
CUTLASS in Python 入门示例 (Page 8)
  1. 通过 pip install nvidia-cutlass-dsl 安装。
  2. 编写 Python 内核代码,使用 @cute.kernel@cute.jit 装饰器。
  3. 通过 python3 hello_world.py 运行,代码将被即时编译(JIT)并执行。
import cutlass
import cutlass.cute as cute

@cute.kernel
def kernel():
    tidx, _, _ = cute.arch.thread_idx()
    if tidx == 0:
        cute.printf("Hello world")

@cute.jit
def hello_world():
    cutlass.cuda.initialize_cuda_context()

    # Launch kernel
    kernel().launch(
        grid=(1, 1, 1),
        block=(32, 1, 1)
    )

# Just-In-Time (JIT) compilation
print("Running hello_world()...")
hello_world()

DSL 基础架构

CUTLASS Python 架构

CUTLASS Python 架构 (Page 10)
CUTLASS Python 架构 (Page 10)

上图展示了从 Python 端代码到最终在 GPU 上执行的 CUBIN 文件的编译流程:
1. 用户使用 @cute.kernel 编写 Python 内核。
2. 代码通过 CUTLASS DSL 栈(CuTe DSL, DSL 编译器)转换成中间表示(IR program)。
3. IR 程序进入 CUDA 编译器栈,依次通过 NVVM/LLVM、PTX 和 SASS 编译。
4. 最终生成的 CUBIN 由 JIT Executor 加载并执行。

在 Python 中编写内核

通过 @cute.jit@cute.kernel 装饰器,可以将 C++ 中复杂的模板元编程内核,用更简洁的 Python 代码来表达。

在 Python 中编写内核 (Page 11)
在 Python 中编写内核 (Page 11)

左侧是 C++ CUTLASS 内核的定义,使用了大量模板参数。右侧是等效的 Python 实现,通过函数参数和类型注解来定义,更加清晰直观。

与 PyTorch 的轻松集成

支持 DLPack 协议

可以直接将 torch.tensor 作为输入传递给 JIT 编译的 CUTLASS 内核函数,无需手动数据转换。

import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
import torch

@cute.kernel
def jit_kernel(A: cute.Tensor):
    ...

@cute.jit
def jit_func(A: cute.Tensor):
    jit_kernel(
        A, config=cutlass.LaunchConfig(grid=[1, 1, 1], block=[1, 1, 1])
    )

# Create a torch tensor
A_tensor = torch.tensor([0, 0], dtype=torch.int32).cuda()

# Call the JIT function with the torch tensor
jit_func(A_tensor)

# Or explicitly convert using from_dlpack
# Or jit_func(from_dlpack(A_tensor).mark_layout_dynamic())

简化测试流程

与 Python 生态(如 PyTorch)的深度集成,使得验证和测试变得非常简单。可以直接使用 PyTorch 实现一个参考内核,并用 torch.testing.assert_close 来验证 CUTLASS 内核的正确性。

与 PyTorch 集成进行测试 (Page 13)
与 PyTorch 集成进行测试 (Page 13)

上图左侧是 C++ 中复杂的测试流程,包括设备内存分配、内核启动、同步、数据拷贝回主机以及手动比较。右侧则展示了等效的 Python 测试流程,仅需两行代码即可完成计算和验证。

从静态布局到动态布局

静态布局 (Static Layout):

当 JIT 函数接收到不同形状(layout)的张量时,会为每种形状编译一个专门的内核。

静态布局处理 (Page 14)
静态布局处理 (Page 14)

如上图所示,当 jit_func 分别接收形状为 (3:1)A_tensor 和形状为 (5:1)B_tensor 时,会编译两个独立的 JIT 函数,每个函数对应一种静态布局。

  • cute.Tensor 类型会有一个布局 (3:1),与输入 A_tensor 的大小一致。
  • cute.Tensor 类型会有一个布局 (5:1),与输入 B_tensor 的大小一致。
  • 编译了两个 JIT 函数集。
  • 静态布局在不同的代码中产生结果。

动态布局 (Dynamic Layout):

通过使用 mark_layout_dynamic(),可以生成一个通用的内核,处理动态布局的张量,从而避免为每种形状都重新编译。

动态布局处理 (Page 15)
动态布局处理 (Page 15)

如上图所示,通过 mark_layout_dynamic(mode=0),即使输入不同形状的张量,也只会编译一套 JIT 函数。

  • cute.Tensor 类型会有一个动态布局 (?:1)
  • 只为这两种情况编译一个 JIT 函数集。
  • 使用动态布局来生成通用的代码。

与 Pytorch 的轻松集成:LLaMA 8b 集成示例

  • 连接门/上/下投影层与自定义线性模块
    • 在标准的 LLaMA 模型实现中,MLP(多层感知机)部分包含 gate_proj, up_proj, 和 down_proj 等线性层,它们通常使用 nn.Linear
    • 为了集成自定义的 CUTLASS 内核,可以将这些 nn.Linear 层替换为自定义的线性模块,例如 MyCutlassLinear
    • 这种替换是直接的,只需在模型初始化时将标准线性层实例化替换为自定义层的实例化即可。
Page 16: LLaMA MLP模块代码修改示例
Page 16: LLaMA MLP模块代码修改示例
  • 自定义线性模块的实现
    • 在自定义的线性模块(MyCutlassLinear)内部,可以配置并调用由 CUTLASS Python API 实现的定制化核函数。
    • 如代码所示,MyGemmKernel 是一个通过 CUTLASS Python API 定义的 GEMM(通用矩阵乘法)核函数。它在模块初始化时被实例化。
    • 在前向传播(forward)过程中,调用该核函数。
    • 利用隐式的 from_dlpack 转换,可以无缝地在 PyTorch 张量和 CUTLASS 所需的数据结构之间传递数据,无需显式的数据拷贝或格式转换。
Page 17: 自定义CUTLASS线性模块实现代码
Page 17: 自定义CUTLASS线性模块实现代码

更好的代码表达性和可读性

通过 CuTe,可以实现更简洁和可读性更高的代码,尤其是在进行操作融合(如 GEMM 与激活函数融合)时。

  • 激活函数融合的演进:

    • 传统方式: 需要显式地编写循环来遍历张量元素,并在循环体内应用激活函数(如 if a > 0 实现 ReLU)。
    • CuTe 方式: 可以通过 epilogue_op 参数传递一个 lambda 函数 lambda x: x 或更复杂的表达式,如 cute.where(x > 0, x, cute.full_like(x, 0)),直接在 gemm 操作的末尾阶段(epilogue)完成激活函数的计算,从而实现融合。
  • TensorSSA:

    • 核心理念: 为 CuTe 张量提供线程局部数据建模,具有值语义和不可变性。
    • 实现方式:
      • 基于嵌套的 CuTe 形状支持。
      • 将张量元素作为向量加载/存储到张量中。
      • 为向量化操作提供运算符重载。
Page 18: 代码可读性对比与TensorSSA介绍
Page 18: 代码可读性对比与TensorSSA介绍

自定义类 C 结构体的数据类型

@cute.struct 装饰器允许开发者像在 C 语言中一样定义具有精确内存布局的数据结构。

  • @cute.struct 装饰器功能:

    • 将一个 Python 类转换为内存映射的结构体,可以精确控制内存布局、对齐和偏移量。
    • 支持标量(scalar)、MemRange 或嵌套结构体作为其成员。
    • 允许自定义数据对齐,这对于实现高性能至关重要。
  • 示例:

    • 下图展示了如何定义一个复数结构体 complex 和一个包含嵌套结构体及内存范围的自定义存储结构 MyStorage
    • 表格清晰地展示了每个成员变量的偏移量(offset)和对齐(alignment)方式,包括数据类型本身的自然对齐和用户指定的对齐。
Page 19: @cute.struct 示例与内存布局
Page 19: @cute.struct 示例与内存布局

以面向对象编程(OOP)方式编写核函数

@cute.struct 同样有助于以更结构化和面向对象的方式组织核函数代码,特别是共享内存的管理。

  • 传统方式:

    • 手动计算共享内存的大小和偏移量。
    • 通过指针和偏移量直接访问共享内存的不同部分。
    • 代码可读性较差,容易出错。
  • OOP 方式:

    • 使用 @cute.struct 定义一个 SharedStorage 结构体来封装所有共享内存的分配。
    • 通过结构体成员(如 storage.ab_full_mbar_ptr.data_ptr())来访问共享内存,代码更清晰、更易于维护。
    • 在启动核函数时,通过 cutlass.LaunchConfig 传递结构化配置,如共享内存大小。
Page 20: 传统方式与OOP方式编写核函数的对比
Page 20: 传统方式与OOP方式编写核函数的对比

通过缓存降低核函数启动延迟

在不使用缓存的情况下,即时编译(JIT)会带来显著的开销。

  • 无缓存的问题:
    • 在每次前向传播(forward pass)中,都需要调用核函数的 JIT 编译。
    • 这个编译过程会引入显著的运行时开销,尤其是在迭代次数很多的情况下。
    • 如代码所示,self.gemm(...)forward 方法中被直接调用,每次调用都可能触发编译。
Page 21: 无缓存时JIT编译带来的开销
Page 21: 无缓存时JIT编译带来的开销

通过缓存降低核函数启动延迟:零编译方案

通过缓存已编译的 CUBIN(CUDA Binary),可以实现 "零编译"(Zero Compile),从而消除 JIT 开销。

  • 实现机制:
    • 自定义键进行核函数缓存: 使用输入张量的形状(input.shape)等作为键,在字典(self.cached_kernels)中查找或存储已编译的核函数。
    • 带 CUBIN 缓存的 JIT 执行器: 如果缓存中不存在对应键的核函数,则调用 cute.compile() 进行编译并存入缓存。后续调用将直接使用缓存中的 CUBIN,避免了重复编译。
    • 序列化与反序列化: JIT 执行器支持将编译结果序列化到文件,并在需要时反序列化,实现了持久化缓存。
Page 22: 使用缓存避免重复编译的代码实现
Page 22: 使用缓存避免重复编译的代码实现

生成、编译和启动开销对比

下图展示了在 Blackwell B100 GPU 上运行 FP16 GEMM (M=N=K=8K) 时,不同缓存策略下的开销对比。

  • 无缓存 (Without cache):

    • 完整地执行从 IR 生成、编译到核函数启动的整个流程。
    • 总开销巨大,包括 IR 生成(73,000 us)和编译(241,500 us)。
  • 稳定缓存 (Stable cache):

    • 默认开启。
    • 总是生成代码以检查功能正确性,但如果生成的代码与缓存中的一致,则跳过编译步骤。
    • 显著减少了开销,只剩下 IR 生成的成本(73,000 us)。
  • 零编译 (Zero Compile):

    • 通过用户管理的 JIT 执行器直接启动核函数。
    • 无需 IR 生成和编译,将这部分开销降至零。
    • 实现了最低的核函数启动延迟。
Page 23: 不同缓存策略的开销对比图
Page 23: 不同缓存策略的开销对比图

如果缺少功能怎么办?无缝集成原生 Op 构建器

当上层 API 缺少特定功能时,可以直接使用底层的原生操作(Op)构建器。

  • 功能:

    • 暴露 LLVM + NVVM Op 构建器: 允许开发者直接访问和使用 LLVM 和 NVVM 的操作。
    • 直接访问 NVVM 操作: 可以调用底层的 PTX 指令,实现对硬件的精细控制。
    • 直接内联 PTX: 支持将手写的 PTX 汇编代码直接内联到核函数中。
    • 不受上游破坏性变更影响: 由于直接与底层接口交互,因此可以减少对上层库版本更新的依赖。
  • 代码示例:

    • thread_idx 函数直接读取 PTX 特殊寄存器 %tid.x
    • exp2 函数直接内联 PTX 指令 ex2.approx.ftz.f32 来实现快速的指数运算。
Page 24: 集成原生Op构建器的代码示例
Page 24: 集成原生Op构建器的代码示例

Hopper GEMM 实现

GEMM 概述

  • GEMM 分块 (Tiling)
    • 在矩阵乘法中,每个 CTA(Cooperative Thread Array,线程协作阵列)被分配一个输出矩阵 C 的分块(tile)进行计算和累加。
    • 如代码 C(m,n) += A(m,k) * B(n,k) 所示,计算过程沿着 K 维度进行迭代,将 A 和 B 的相应分块相乘并累加到 C 的分块中。
Page 26: GEMM分块示意图
Page 26: GEMM分块示意图

Hopper GEMM 流水线:单个分块的视角

对于 C 矩阵的每个分块,GEMM 的计算遵循一个三阶段流水线:

  1. 序言 (Prologue): 使用 TMA (Tensor Memory Accelerator) Load 指令,将 A 和 B 矩阵的第一个分块从全局内存(GMEM)预取到共享内存(SMEM)中。
  2. 主循环 (Mainloop): 使用 Tensor Core 计算 C = A*B + C,同时使用 TMA Load 指令加载下一个 K 维度的 A/B 矩阵分块,实现了计算与数据加载的重叠。
  3. 尾声 (Epilogue): 使用 TMA Store 指令,将最终计算完成的 C 矩阵分块从共享内存(SMEM)写回到全局内存(GMEM)。
Page 27: Hopper GEMM 单分块流水线示意图
Page 27: Hopper GEMM 单分块流水线示意图

MMA/Copy 原子操作

CuTe 的设计理念

对于已经使用 CUTLASS-C++ 的开发者来说,CuTe 的概念会非常熟悉。

  • CuTe 布局 (Layouts) 和张量 (Tensors):

    • 提供了完整的代数功能。
    • 完全支持混合静态/动态布局。
  • 稳健的张量编程模型: CuTe 的编程模型与 CUTLASS-C++ 的用户所熟悉的模型非常相似。

  • 抽象层次:

    • CuTe 布局和代数 (CuTe Layouts and Algebra): 构成了整个库的基础。
    • MMA/拷贝原子操作 (MMA/Copy Atoms): 使用布局来封装带有元数据的 PTX 指令。
    • 分块 MMA/拷贝 (Tiled MMA/Copy): 提供稳健的通用分区接口。
    • cute.{gemm|copy} 算法: 自动分派到请求的 PTX 指令。
Page 29: CuTe的抽象层次和设计理念
Page 29: CuTe的抽象层次和设计理念

cute.gemm

cute.gemm 是执行矩阵乘法的高级接口。

  • MMA Op:

    • MMA (Matrix Multiply-Accumulate) 操作是 PTX 指令的抽象。
    • SMEM 描述符定义了 A 和 B 操作数。
    • 指令描述符定义了形状和数据类型。
  • 使用流程:

    1. 用户选择一个 MMA Op (例如 MmaBF168F16Op) 来定义计算的核心操作。
    2. 使用 cute.make_tiled_mma 和选择的 MMA Op 来创建一个 TiledMMA 对象,它定义了如何在线程块级别上进行分块计算。
    3. 调用 cute.gemm 函数,并传入 TiledMMA 对象以及作为操作数的累加器、A 张量和 B 张量,以执行分块矩阵乘法。
Page 30: cute.gemm 使用流程和图示
Page 30: cute.gemm 使用流程和图示

下图展示了使用 cute.copycute.gemm 的通用矩阵乘法(GEMM)的数据流。

  • 使用 cute.copy 将全局内存中的 gAgB 瓦片(tiles)复制到共享内存 sAsB 中。
  • cute.gemm 操作利用共享内存中的数据 sAsB,在寄存器 rC 中执行矩阵乘法计算。
Page 32
Page 32

cute.copy

CuTe TMA Atoms 为 TMA (Tensor Memory Accelerator) 提供 PTX 和元数据。

  • TMA Op -- PTX for TMA

    • CopyBulkTensorTileG2S0p
    • CopyBulkTensorTileG2SMulticast0p
    • CopyBulkTensorTileS2G0p
  • CopyAtom

    • 使用元数据封装 PTX。

TMA Copy 代码示例:

op = cute.nvgpu.cpasync.CopyBulkTensorTileG2S0p()

tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tma_tile_atom(
    op,
    gmem_tensor,
    smem_layout,
    cta_tiler,
)

cute.copy(
    tma_atom_a,
    gmem_tensor,
    smem_tensor,
    tma_bar_ptr,
    mcast_mask,
)
Page 31
Page 31

Hopper GEMM 代码演示

Hopper 中的流水线 (Pipeline)

采用非 Warp 专用核(Non Warp Specialized Kernel)风格。

生产者:TMA 加载

  • PipelineTmaAsync 操作: producer_acquire, producer_commit, consumer_wait, consumer_release
  • 流程:
    1. 等待 smem_free 屏障。
    2. 发出带有 smem_ready 屏障的 TMA 指令。
    3. 更新传输字节数并在 smem_ready 屏障上抵达(arrive)。

消费者:数学计算

  • 流程:
    1. 等待 smem_ready 屏障。
    2. 发出 GMMA (Global Memory Matrix-multiply Accumulate) 指令,等待工作完成。
    3. smem_free 屏障上抵达。
Page 34
Page 34

加载 A/B (从 GMEM 复制到 SMEM)

  • 等待空闲屏障(empty barrier)就绪。
  • 使用 TMA LOAD 将 A/B 瓦片从 GMEM 复制到 SMEM。
  • 更新屏障状态。
加载 A/B 的数据流和代码示例
加载 A/B 的数据流和代码示例

代码逻辑:
在 K 维度上循环,每个循环:
1. mainloop_pipeline.producer_acquire: 等待一个空的共享内存缓冲区。
2. cute.copy: 使用 TMA atom 将数据从全局内存复制到共享内存。
3. mainloop_pipeline.producer_commit: 提交写入操作。
4. mainloop_pipeline.producer_state.advance(): 更新流水线状态。

计算 C = A * B (GEMM)

  • 等待 TMA 复制操作完成。
  • 使用 Hopper Tensor Core 计算 C = A * B。
  • 更新屏障状态。
计算 C = A * B 的数据流和代码示例
计算 C = A * B 的数据流和代码示例

代码逻辑:
在 K 维度上循环,每个循环:
1. mainloop_pipeline.consumer_wait: 等待 TMA 加载数据完成。
2. cute.nvgpu.wgmma.fence(): WGMMA (Warp Group MMA) 栅栏同步。
3. 在 k_blocks 上循环,调用 cute.gemm() 执行矩阵乘法。
4. cute.nvgpu.wgmma.commit_group()wait_group(): 提交并等待 WGMMA 操作完成。
5. mainloop_pipeline.consumer_release: 释放共享内存缓冲区。
6. mainloop_pipeline.consumer_read_state.advance(): 更新流水线状态。

存储 C (从 SMEM 复制到 GMEM)

  • 复制累加器(Accumulator)中的数据。
  • 进行类型转换。
  • 将 C 瓦片从 RMEM (Register Memory) 复制到 SMEM。
  • 使用 TMA Store 将 C 瓦片从 SMEM 复制到 GMEM。
存储 C 的数据流和代码示例
存储 C 的数据流和代码示例

代码逻辑:
在 Epilogue 阶段:
1. 将累加器中的数据复制到寄存器 tRS_rD
2. cute.make_fragment_like: 进行类型转换。
3. cute.copy: 将寄存器数据复制到共享内存 tRS_sD
4. cute.arch.barrier: 同步。
5. cute.copy: 使用 TMA atom 将共享内存中的数据异步写回全局内存。
6. cute.arch.cp_async_bulk_commit_group()wait_group(): 提交并等待异步复制操作完成。

性能比较

Hopper 性能:Python vs. C++ (FP16 I/O GEMM,M=N=8192)

下图比较了在 M=N=8192 的 GEMM-K 尺寸下,Python 和 C++ 实现的数学计算吞吐量(Math SOL%)。结果显示,C++ 的性能略高于 Python。

  • 性能差距 (Perf gap)
    • Python 示例不包含持久化(persistent)优化,因此在小 GEMM-K 情况下,CTA (Cooperative Thread Array) wave 之间的开销更大。

测试规格: H100* 80GB HBM3, 132SM GPC-1500MHz/DRAM 2619MHz 700W, 128x256x64 cooperative size, Swizzle size = 8.

Page 39
Page 39

Hopper 性能:Python vs. C++ (FP16 I/O GEMM,M=N=2048)

下图比较了在 M=N=2048 的 GEMM-K 尺寸下,Python 和 C++ 实现的性能。

  • 性能差距 (Perf gap)
    • 对于小的 GEMM-K 情况,这是一个单 wave (single wave) 场景。
    • C++ 因其 warp 专用的持久化调度器(warp-specialized persistent scheduler)在瓦片调度计算上产生了更多开销,而 Python 示例中未使用该调度器。

测试规格: H100* 80GB HBM3, 132SM GPC-1500MHz/DRAM 2619MHz 700W, 128x256x64 cooperative size, Swizzle size = 8.

Page 40
Page 40

结论

CuTe DSL:一种基于 Python 的编程语言,使用 CuTe 语义对 Tensor Core 进行编程,以实现最佳性能。

  • 在 Python 中编写核函数:而不仅仅是调用 CUTLASS 已有的核。
  • 易于集成:可与 Pytorch 等流行的 Python 框架轻松集成。
  • 精确的硬件建模:可完全控制性能,实现与 C++ 相媲美的性能。
  • 基于 MLIR 框架:利用 MLIR 生态系统的强大功能。
  • 极快的编译速度
Page 41
Page 41