Balancing the Compute Throughput & Latency in Async Programming
Balancing the Compute Throughput & Latency in Async Programming
Petrick Liu, Jiang Shao, NVIDIA DevTech Team | AI Open Day / 2025.05.30
目录
- 快速回顾:Hopper之前的异步编程
- 从Ampere到Hopper架构的演进
- Mbarrier:异步编程的关键促成因素
- Hopper Warp Specialized GEMM
- Hopper Tensor Core 基础概念
- 最大化Hopper Tensor Core吞吐量
- Warp 组执行排序
- 从 Hopper 到 Blackwell 的演进
- 案例研究:W4A8 Hopper Grouped Gemm Kernel
快速回顾:Hopper之前的异步编程
在Hopper架构之前,典型的分块通用矩阵乘法(Blocked GEMM)通过将数据在不同层级内存(全局内存、共享内存、寄存器文件)之间移动,并利用CUDA核心/张量核心(Tensor Cores)进行计算。其基本流程是通过多级分块(Thread Block Tile, Warp Tile, Thread Tile)来管理数据。
如上图所示,基本的计算循环包含:
1. 从共享内存加载数据A (load_A_tile) 和数据B (load_B_tile) 到寄存器文件。
2. 线程同步 (__syncthreads())。
3. 从寄存器文件加载片段 (load_A_frag, load_B_frag)。
4. 执行矩阵乘加运算 (mma)。
5. 再次线程同步。
这种模式中,数据加载(LDGSTS,从全局内存到共享内存)和计算(MMA)交替进行,导致张量核心(TC Active)存在空闲时间,因为计算必须等待数据加载完成。
为了建立一个稳定的计算流水线,需要一个“序言”(Prologue)阶段来预加载数据。如下图所示,在计算开始前,需要执行一系列的全局内存加载指令(LDGSTS Ktile0 至 Ktile4)来填充数据缓冲区。这个过程会引入启动延迟(Gmem Latency),在此期间张量核心处于非活动状态。
为了隐藏和重叠异步操作的延迟,需要精心设计的流水线。关键的优化思路是:
- 采用高度流水线化的设计来最大化吞吐量。
- "序言"阶段是实现高吞吐量所需付出的代价。
- 在主循环(Mainloop)中,通过RF双缓冲(RF double buffer)等技术,实现全局内存加载(Gmem Loading)、共享内存加载(Smem Loading)与张量核心计算(TC computing)的完全重叠。
从Ampere到Hopper架构的演进
不同架构在处理计算任务时采用了不同的调度和资源分配策略,这直接影响了效率和延迟隐藏。
-
Ampere架构:
- 当每个SM(Streaming Multiprocessor)上只有一个CTA(Cooperative Thread Array)时,序言(prolog)和尾声(epilog)的开销会暴露出来,无法被有效隐藏。
- 增加SM上的CTA数量(例如2 CTA/SM)可以帮助隐藏一部分延迟,但如果一个CTA正在运行而共享内存(SMEM)利用率较低,会导致效率下降。
-
Hopper架构:
- Hopper引入了Warp专业化(warp specialization),将不同的任务分配给不同类型的Warp。
- DMA Warps 专门负责数据预取(prolog),将数据从全局内存搬运至共享内存。
- Math Warps 专门负责计算。
- 这种设计允许在一个Warp组进行计算的同时,另一个Warp组的尾声(epilog)可以与数学计算重叠。例如,为第二个数据块(tile)启动数据预取(prolog)的同时,第一个数据块正在进行计算。
- Hopper通过持久化GEMM(persistent gemm)和Warp专业化,实现了仅用1个CTA/SM的占用率就能高效隐藏延迟,显著提升了计算单元的利用率。
Mbarrier:异步编程的关键促成因素
Hopper架构引入了Mbarrier,这是一种用于Warp间通信和同步的强大机制,是实现高效异步编程的关键。它支持创建生产者-消费者(Producer-Consumer)模型,特别是在Warp专业化的GEMM中。
Hopper Warp专业化GEMM流程:
-
生产者Warp (TMA Warps):
- 负责数据加载,使用
CollectiveMma::load_a等接口。 - 等待
Smem_empty屏障,确保消费者已使用完上一批数据。 - 发出TMA(Tensor Memory Accelerator)指令,将矩阵A和B的数据块从全局内存异步加载到共享内存,并更新
Smem_full屏障。 - 到达
Smem_full屏障,通知消费者数据已准备好。
- 负责数据加载,使用
-
消费者Warp (TC Warps):
- 负责计算,使用
CollectiveMma::mma等接口,并且是持久化的。 - 等待
Smem_full屏障,直到生产者准备好数据。 - 发出WGMM (Warp Group Matrix Multiply) 指令进行计算。
- 计算完成后,到达
Smem_empty屏障,通知生产者可以加载新数据。 - 将计算结果写回寄存器文件,并最终通过TMA指令写回全局内存。
- 负责计算,使用
共享内存作为数据缓冲区,通过Mbarrier对象(包含Smem_empty和Smem_full状态)进行同步,实现了数据加载和计算的高度流水化。
Mbarrier与TMA编程模式详解
以下通过一个逐步示例解释Mbarrier的工作机制。
1. 初始化
- TMA Warp和TC WarpGroups共享一个
Mbarrier对象。 - 调用
Init_mbarrier(&bar, 1)进行初始化,设置预期到达数量为1。 - TC WarpGroups进入
while (try_wait(&bar, phase)) {}循环,等待phase翻转。 - 初始状态:
Phase = 0,Expect Arrv_Cnt = 1,Actual Arrv_Cnt = 0,Expect Trans_Bytes = 0,Actual Trans_Bytes = 0。
2. TMA发出加载和到达指令
- TMA Warp中的一个线程 (if(tma_thread)) 发出TMA_bulk_load指令,请求加载16KB数据。
- 接着发出mbarrier_arrive_expect(&bar, 16KB),通知Mbarrier本次事务预期传输16KB数据。
- Mbarrier状态更新:Expect Trans_Bytes变为16KB。Actual Arrv Cnt根据执行线程数更新。
3. TMA数据传输
- TMA开始将数据从全局内存加载到共享内存。
- Mbarrier会追踪实际已传输的字节数。图中显示已传输1KB,Actual Trans_Bytes更新为1KB。TC Warp仍然在try_wait处阻塞。
- 数据继续传输,
Actual Trans_Bytes更新为4KB。
4. 事务完成
- 当全部16KB数据传输完成时,Mbarrier记录整个事务完成。Actual Trans_Bytes变为16KB。
5. Phase翻转与消费者唤醒
- 由于事务完成,Mbarrier的Phase从0翻转到1。
- TC Warp的try_wait(&bar, phase)条件满足,跳出循环("Pass here!")。
- TC Warp开始使用共享内存中的数据进行WGMMA计算。
- Mbarrier状态被重置,为下一轮数据传输做准备(例如,Expect Trans_Bytes重置为0)。
Mbarrier 与 TMA 编程模式
Mbarrier 状态转换
该编程模式涉及一个TMA Warp(生产者)和一个或多个TC WarpGroup(消费者),通过两个mbarrier对象(bar_full 和 bar_empty)进行同步。
初始状态与第一阶段 (Page 16)
- 初始化:
- Init_mbarrier(&bar_full, 1): 初始化bar_full,期望到达数为1(来自TMA Warp)。
- Init_mbarrier(&bar_empty, 128): 初始化bar_empty,期望到达数为128(来自TC WarpGroups)。
- mbarrier_fence(): 确保初始化完成。
-
TMA Warp (生产者):
issue_TMA_bulk_load(...): 发出异步数据加载指令,目标是SMEM。mbarrier_arrive_expect(...): 到达bar_full屏障,表示数据正在传输中。
-
TC WarpGroups (消费者):
while (try_wait(&bar_full, phase)) {}: 等待bar_full的阶段(phase)翻转。此时会阻塞,因为TMA的数据加载尚未完成。
-
Mbarrier状态:
Mbarrier Smem_Full: 阶段为1,期望到达数为1,实际到达数为0。Mbarrier Smem_Empty: 阶段为0,期望到达数为128,实际到达数为0。
数据就绪与消费 (Page 17)
- TMA Warp: TMA加载完成,
bar_full的实际到达数变为1,满足期望值,于是bar_full的phase翻转。 -
TC WarpGroups: 检测到
bar_full的phase翻转,try_wait成功,循环退出。WGMMA(...): 执行矩阵乘累加计算,消费SMEM中的数据。WAIT_WGMMAs(): 等待WGMMA计算完成。mbarrier_arrive(&bar_empty): 所有消费者Warp完成计算后,到达bar_empty屏障。
-
Mbarrier状态:
Mbarrier Smem_Full:phase翻转,实际到达数与期望数均为1。Mbarrier Smem_Empty: TC WarpGroups到达后,实际到达数变为128。
SMEM释放与循环 (Page 18)
- TC WarpGroups: 当所有128个线程到达
bar_empty后,满足其期望到达数,bar_empty的phase翻转。 -
TMA Warp:
while (try_wait(&bar_empty, empty_phase)) {}: 等待bar_empty的phase翻转。- 当
bar_empty的phase翻转后,try_wait成功,TMA Warp被唤醒。这意味着SMEM中的数据已被消费,可以安全地加载新数据。 - 循环回到开头,准备发出下一次TMA加载指令。
-
Mbarrier状态:
Mbarrier Smem_Empty:phase翻转,实际到达数与期望数均为128。
第二次迭代 (Page 19)
- 这是循环的第二次迭代。
- TMA Warp和TC WarpGroups中的
phase变量递增(例如,int empty_phase = 1;)。 - 整个过程重复:TMA加载数据 ->
bar_full翻转 -> TC消费数据 ->bar_empty翻转。
Hopper Warp Specialized GEMM
TMA-TC 生产者-消费者模型
这是一种利用Hopper架构特性(如TMA、WGMMA、Mbarrier)实现的高效GEMM计算模型,其核心思想是生产者和消费者的解耦。
-
生产者 (TMA Warps):
- 等待
Smeme_empty屏障,确保有空的缓冲区。 - 发出TMA指令加载A和B矩阵块,并更新
Smeme_full屏障。 - 更新传输字节数,并到达
Smeme_full屏障。 - 循环处理K维度上的所有分块。
- 等待
-
消费者 (TC Warps):
- 等待
Smeme_full屏障,确保数据已加载到SMEM。 - 发出WGMMA指令进行计算,等待上一批计算完成。
- 计算完成后,到达
Smeme_empty屏障,表示该缓冲区已可重用。 - 循环处理K维度。
- 将RF中的计算结果写回到SMEM,再由TMA写回全局内存。
- 等待
-
关键特性:
- 生产者和消费者完全解耦。
- 生产者可以更新任何一个空的阶段(缓冲区)。
- 完成状态在线程块(block)级别可见,而不仅仅是线程内。
- 可以异步、非阻塞地检查完成状态。
-
思考题: 如何初始化屏障,使得生产者可以跳过第一次对
empty缓冲区的检查?- 这通常通过初始化
empty屏障的阶段来实现,使其看起来好像消费者已经“完成”了第一轮消费,从而立即释放第一个缓冲区给生产者。
- 这通常通过初始化
CUTLASS中的流水线实现
CUTLASS库提供了用于实现这种复杂流水线模型的原生组件。
cutlass/pipeline/sm90_pipeline.hpp:PipelineState结构体管理流水线的状态,包括当前阶段索引index_、阶段标志phase_和计数器count_。- 通过重载
++运算符来推进流水线阶段,当索引绕回时翻转phase_。这与前面mbarrier的phase机制相对应。
- TMA-TC 生产者-消费者模型原语:
- CUTLASS为生产者和消费者提供了清晰的API,用于管理流水线状态。
- 生产者API:
producer_try_acquire,producer_acquire(获取空闲缓冲区),producer_commit(提交数据)。 - 消费者API:
consumer_try_wait,consumer_wait(等待数据就绪),consumer_release(释放已用缓冲区)。 - 这些高级抽象封装了底层的
mbarrier操作,简化了异步流水线的编程。
Hopper Tensor Core 基础概念
Hopper WGMMA 指令
Hopper架构引入了Warp Group MMA (WGMMA),以 warp group(128个线程)为单位进行矩阵乘法。
- Hopper Tensor Core 主要特性:
- Warp Group协作: 128个线程(一个Warp Group)协作执行矩阵乘法。
- 指令形态:
64xNx256bit,其中N在[8, 256]之间,步长为8。4个Warp分布在M维度,每个Warp执行16xNx256bit的计算。 - 操作数: 操作数B在SMEM中,由4个warp共享。操作数A可以来自寄存器文件(RF)或SMEM。
- Smem描述符: 用于描述SMEM中操作数的布局。支持
NO_SWIZZLE,SWIZZLE_32B,SWIZZLE_64B,SWIZZLE_128B等多种内存排布模式,且与TMA的排布模式兼容。 - 异步执行: WGMMA是异步指令,使用
Group Commit和Wait来跟踪完成情况。
典型的WGMMA指令序列
一个典型的WGMMA计算流程如下:
-
wgmma.fence.sync.aligned;- 栅栏指令,确保所有线程的SMEM和RFs都已就绪。
-
wgmma.mma_async.aligned.m64n128k16.f32.f16.f16;...- 发出若干条异步的
wgmma指令。这些指令被分组执行。
- 发出若干条异步的
-
wgmma.commit_group.sync.aligned;- 提交上述指令组,使其开始执行。
-
wgmma.wait_group.sync.aligned 0;- 等待组号为0的指令组完成计算。
最大化Hopper Tensor Core吞吐量
- Hopper Tensor Core的WGMMA是异步指令,使用
wgmma.wait_group来同步和跟踪完成状态。 - 对于计算密集型工作负载,为了充分利用Tensor Core,必须保持其持续繁忙。
延迟问题
- 如果不使用流水线,Tensor Core在等待数据从全局内存加载到SMEM期间会处于空闲状态,这导致了性能瓶颈。
- 下图展示了Warp调度器在等待SMEM数据到达(
Wait Smem0 Arrv)时,Tensor Core单元处于空闲(Idle)状态。
使用流水线隐藏延迟
- 为了解决空闲问题,可以采用流水线技术,例如软件流水线(或双缓冲)。
- 当Tensor Core正在处理当前阶段(Stage N)的数据时,TMA可以异步地将下一阶段(Stage N+1)的数据加载到另一块SMEM缓冲区中。
- 这样,数据加载的延迟就被计算的执行时间所隐藏,从而使Tensor Core保持繁忙。
- 下图展示了一个二阶流水线(Stage2 example)的例子,其中对
Smem1的等待(Wait Smem1 Arrv)与前一个WGMMA执行重叠,消除了空闲周期。
CUTLASS中的WGMMA流水线实现
CUTLASS提供了WGMMA多级流水线的实现。代码逻辑分为两个主要部分:
-
MMA多级流水线序言 (Prologue):
- 在主循环开始前,预取多个阶段的数据,以填满流水线。
- 这部分代码会等待数据可用 (
pipeline.consumer_wait),然后执行WGMMA计算。
-
MMA多级流水线主循环 (Mainloop):
- 在主循环的每次迭代中,对当前k分块进行计算,同时异步获取下一个k分块的数据。
- 核心逻辑是:等待下一份数据 (
consumer_wait) -> 对当前数据进行计算 (tiled_mma) -> 释放之前用过的缓冲区 (consumer_release) -> 推进流水线 (smem_pipe.release)。 - 通过这种方式,只在流水线的末尾释放SMEM,使数据在需要时一直保留在SMEM中(
Keep MMAs in flight)。
- 思考题: 为什么WGMMA的prologue代码被分成两部分?
- 这种分割通常是为了处理流水线的启动阶段。第一部分可能用于初始化累加器(例如设置为0),而第二部分则在循环中执行实际的预取和计算,以填满流水线。
Warp 组执行排序
为了错开两个Warp组的执行,可以使用 OrderedSequenceBarrier。以下代码片段展示了如何命令两个数学Warp组(Math WG)的MMA(矩阵乘法累加)操作,这有助于隐藏尾声(epilogue)的开销。
代码逻辑如下:
- 主循环:当有工作瓦片(work tile)时,计算M、N、K坐标。
- 内存分配:为(M,N)块形状的累加器分配内存。
- 排序MMA:使用
math_wg_order.barrier.wait()来命令两个数学Warp组的MMA操作。 - 主循环流水线:执行
collective_mainloop.mma,处理主循环流水线、消费者状态、累加器等。 - 启动下一个MMA:使用
math_wg_order.barrier.arrive()为下一个数学Warp组的MMA做准备。 - 同步:确保所有指令完成且自由缓冲区清空,然后进入主循环。
- 更新流水线:更新主循环流水线的消费者状态。
- 尾声与存储:在循环外,执行尾声(epilogue)并存储结果。这包括加载流水线、消费者状态、存储流水线、问题形状、块坐标、累加器等。
- 更新加载/存储状态:为下一个瓦片更新加载/存储流水线的状态。
- 等待:等待所有TMA存储完成。
- 调度:获取下一个工作瓦片并调度。
从 Hopper 到 Blackwell 的演进
新特性/关键特性
SM100 (Hopper)
-
Tensor Core:
tcgen05.mma家族- 1个线程执行CTA/Cluster范围的TC
- 新数据类型支持: MXFP8/6/4, NVFP4
- 支持块缩放 (Block scaling)
- 引入Tensor Memory,累加器在Tmem中
- 纯异步执行 (Pure Asynchronous execution)
-
TMA:
- 新的 scatter,
im2col::w模式加载/存储 - 新的 cluster arrive 模式,为
tcgen05mma 设计
- 新的 scatter,
-
Persistent:
- 新的 Cluster launch control 功能
-
所有以上功能均基于 mBarrier 编程
SM120 (Blackwell)
-
Tensor Core: Ampere 风格的 mma 家族
- 32线程协作执行
- 新数据类型支持: MXFP8/6/4, NVFP4
- 支持块缩放 (Block scaling)
- 累加器仍在RF(寄存器文件)中
- 纯同步执行 (Pure Synchronous execution)
-
TMA:
- 新的 scatter,
im2col::w模式加载/存储
- 新的 scatter,
-
Persistent:
- 新的 Cluster launch control 功能
-
除Tensor Cores外,所有以上功能均基于 mBarrier 编程
轻松最大化 Blackwell Tensor Core 吞吐量
- Blackwell Tensor Core WGMMA 是一条异步指令。
- 我们使用
tcgen05.commit指令将 Tensor Core 的完成状态与mBarrier连接起来。 - 将完成状态从线程级别(thread level)移至CTA级别(CTA level)。
- 将累加器从线程级别移至CTA级别。
下图展示了Warp调度器如何处理 tcgen05.mma 指令。硬件(HW)会异步跟踪TC(Tensor Core)的完成情况,并通过 mBarrier 进行更新。Wait Smem0 表示等待共享内存数据,Commit 表示提交任务,Issue 表示分发MMA指令。
静态分块调度 (Static Tile Scheduling)
-
CUTLASS Hopper Persistent Scheduling:
- 持久化CTA的数量等于SM的数量。
- 输出分块(tile)到SM的映射是固定的,在内核启动时决定。
-
静态调度的问题:
- 如果一个SM的上下文切换到另一个内核,其输出分块不会被重新映射,导致执行尾部拖得很长(long execution tail)。
如下图所示,当SM 2被另一个网格(Other Grid)占用时,原先分配给它的分块(如Tile 102, 202, 302)必须等待SM 2空闲后才能被处理,这造成了整体执行时间的延长。
动态分块调度 (Dynamic Tile Scheduling)
- Blackwell允许用户通过
clusterlaunchcontrolPTX指令以编程方式在SM上获取线程块/集群(Thread Block/Clusters)。 - CUTLASS Dynamic Persistent Scheduler:
- 利用
clusterlaunchcontrol实现动态持久化调度器。 - 输出分块到SM的映射完全基于SM的进度动态进行。
- 动态持久化调度器在CUTLASS on Blackwell上默认开启。
- 利用
如下图所示,当SM 2被占用时,动态调度器会将原本分配给它的任务(如Tile 301)重新分配给其他空闲的SM(如SM 1),从而避免了执行延迟,优化了资源利用率。
Warp专用持久化内核:Hopper vs Blackwell 对比
Hopper Ping-pong Warp Specialization
- 累加器位于RMEM(寄存器文件)中,尾声(Epilogue)和MMA在相同的线程上执行。
- 通过在两个MMA组之间进行“乒乓操作”(ping-pong)来重叠尾声开销。
- 如图所示,数据加载(TMA Tile 1)、MMA计算(Tensor Core Tile 1)和尾声(Epilogue Tile 1)是交错进行的,需要两个MMA+epilogue warp组(#1和#2)来隐藏延迟。
Blackwell Warp Specialization
- 累加器位于TMEM(Tensor Memory)中,尾声和MMA在不同的线程上执行。
- 因此,不再需要两个MMA线程来重叠尾声。
- 如图所示,调度(Scheduling warp)、数据加载(Data loading warps)、MMA计算(MMA warp)和尾声(Epilogue warps)由专门的warp组负责,流水线更加清晰高效。
Blackwell 持久化内核高层视图
这是一个SM100内核的视图,展示了不同类型的Warp(线程束)如何协作:
- Sch Warp (WarpId = 3):调度Warp,负责管理工作负载(Workld)。当工作负载为空(Workld_Empty)时,需要从外部获取输入偏移(input offset);当工作负载满(Workld_Full)时,通知工作负载流水线消费者。它也负责在需要时停止其他Warp。
- TC Warp (WarpId = 2):Tensor Core Warp,执行核心计算任务。它是TMA主循环流水线的消费者和尾声流水线的生产者。
- TMA Warps (WarpId = 0, 1):负责数据加载。它们是TMA主循环流水线的生产者。
- EpilogueWarps (WarpId = 4,5,6,7):负责尾声处理。它们是尾声流水线的消费者。
这些Warp通过共享内存(SMEM)和张量内存(TMEM)的状态(Full/Empty)以及工作负载队列进行同步和通信,形成一个高效的生产者-消费者流水线模型。
从Hopper到Blackwell(SM_120)的性能演进
此图比较了不同架构和调度策略下的执行时间线,其中不同的Warp被分配了不同的任务(prolog, mainloop, epilog)。
- Ampere: 采用1个CTA/SM的占用率(共6个CTA),显示了基准性能。
- Blackwell Cooperative Dynamic Persistent: 使用2个数学Warp组,通过协作和动态持久化调度,相比Ampere实现了显著的加速(Speedup)。
- Blackwell Pingpong Dynamic Persistent: 同样使用2个数学Warp组,采用乒乓机制和动态持久化调度,也获得了显著的加速(Speedup)。
Blackwell架构通过动态调度和专门的Warp任务分配,有效缩短了计算时间。
案例研究:W4A8 Hopper Grouped Gemm Kernel
混合数据类型分组GEMM (Mixed Data Type Grouped GEMM)
- 应用场景: 混合专家模型(Mixture-of-Experts, MoE)语言模型,如DeepSeek-V3/R1, Qwen3, Mixtral。
-
GEMM 输入:
- 组(Experts): 多个独立的GEMM运算。
- 矩阵 A (Activations): FP8, 张量级对称量化, K-major布局。
- A的缩放因子: FP32。
- 矩阵 B (Weights): INT4, 沿K维度组级对称量化, K-major布局。
- B的缩放因子: FP16。
-
计算流程:
- 方法一(效率低): 由于CUDA核心不支持FP8操作,需要将INT4矩阵B转换为FP16,然后进行缩放,再转换为FP8。这个过程涉及多次转换和缩放。
- 方法二(效率高): 直接将INT4矩阵B转换为FP8,与FP8矩阵A进行MMA运算,得到FP32中间累加结果。在累加后,进行
scaling_B & add操作,最后再进行scaling_A & convert操作得到FP16的输出矩阵C。这种方法显著减少了缩放操作的次数。
混合数据类型分组GEMM - CUTLASS 实现
-
CUTLASS 示例 69: Hopper Mixed Dtype Grouped GEMM
- 示例代码路径:
https://github.com/NVIDIA/cutlass/tree/v3.9.2/examples/69_hopper_mixed_dtype_grouped_gemm
- 示例代码路径:
-
GEMM 输入 (与前页相同):
- 组 (Experts)
- 矩阵 A (Activations): FP8
- 矩阵 B (Weights): INT4
- A的缩放因子: FP32
- B的缩放因子: FP8 (通过
lookup_table_convert获得)
-
核心代码:
cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hppCUTLASS_DEVICE void mma(...)
下图展示了将INT4的矩阵B先convert为FP16,再进行scaling_B,最后convert为FP8,然后与FP8的矩阵A进行MMA运算。
混合数据类型分组GEMM - 优化实现
此页展示了对前一页流程的优化。通过使用查找表转换(lookup_table_convert),将INT4矩阵B的转换和缩放操作合并为一步(convert & scaling_B),直接生成FP8格式的数据。这减少了中间步骤和数据移动,提高了效率。
混合数据类型分组GEMM - 工作分配
- 工作被分配到CUDA线程块(thread blocks)中,通过使用
TileShape对每个组的输出矩阵C进行分区。
下图展示了如何将两个组(Group 0和Group 1)的计算任务进行划分。每个组的输出矩阵C被划分为多个Block(Block 0-3等),每个Block由一个CUDA线程块负责计算。输入矩阵A和B也相应地按TileShape进行划分。
混合数据类型分组GEMM - 块内执行
- 每个线程块沿K维度进行遍历。
- 执行流程包含一个外层循环(
for (stage=0; ...))用于多级流水线(multistage),以及一个内层循环(for (k_block=0; ...))。 - 内层循环执行SMEM加载和WGMMA(Warp Group MMA)计算。
下图左侧展示了线程块如何分阶段(Stage 0, Stage 1)处理K维度的不同块(k_block 0, k_block 1)。右侧图则展示了用于计算的TileShape。
实现细节与优化
流水线设计
为了高效执行混合精度矩阵乘法,设计了一个流水线(Pipeline)机制。其核心思想是重叠数据加载、转换和计算操作。
- 文件路径:
cutlass/include/cutlass/gemm/collective/sm90_mma_array_tma_gemma_rs_warpspecialized_mixed_input.hpp - 输入数据: 假定输入数据已在共享内存(SMEM)中准备就绪。MMA(Matrix Multiply-Accumulate)操作在寄存器中执行并存储结果。
- 特定配置: 以
TileShape::K = 128和GMMA_K = 32为例,每个K Tile被划分为4个kblocks。
下图展示了一个流水线执行流程。矩阵B(INT4)通过lookup_table_convert进行转换和缩放,变为FP8格式。然后与FP8格式的矩阵A进行MMA操作,累加结果为FP32。最后,对累加结果进行缩放和转换,得到FP16格式的输出矩阵C。整个过程在多个kblock上以流水线方式执行,并使用双缓冲(double buffer)来隐藏数据传输延迟。
该流程分为两个主要阶段(Stage 0 和 Stage 1),并由copy_tensors_MK、dequantize_A_kblock、cute::gemm等核函数以及warpgroup_wait同步原语协调。
Hopper Tensor Core MMA 指令
Hopper架构引入了特定的Tensor Core指令 wgmma.mma_async,用于实现异步的矩阵乘法与累加操作。
- 参考链接: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions
- 核心特性:
- warpgroup级别: 操作由一个warpgroup(一组4个连续的warp)执行。
- 异步执行:
mma_async允许计算与数据移动重叠。 - 操作数位置: 操作数A可以来自寄存器文件(RF)或共享内存(SMEM)。操作数B在SMEM中,并由4个warp共享。
- 数据类型: 支持多精度数据类型,如FP8(e4m3, e5m2)和稀疏计算。
- 指令形状 (Shape): 一个典型的例子是
.m64n16k32。
在混合专家(MoE)模型推理中,有两个关键点:
- INT4数据的转换需要在寄存器文件(RF)中进行。
- 每个专家处理的token数量通常小于64。
下图右侧展示了.m64n16k32指令如何将128个线程(T0-T127)组织成4个warp来执行计算。
使用 CuTe 进行数据布局
CuTe是一个基于C++的库,用于描述和操作张量在GPU内存中的布局,是CUTLASS 3.x的核心组件。
- CuTe入门: https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/cute
- 调试工具:
cute::print函数可以打印出几乎所有CuTe类型的布局,包括指针、整数、步长(Strides)、布局(Layouts)和张量(Tensors)。 -
示例定义:
TileShape->(_:128, _:16, _:128)- GMMA 指令形状 ->
.m64n16k32
-
共享内存布局:
-
SmemLayoutA(Tensor sA):- Shape:
(BLK_M, BLK_K, STAGE)->((_8, _16), (_128, _1), (_19)) - Stride:
((_128, _1024), (_1, _0), (_0, _16384))
- Shape:
-
SmemLayoutB(Tensor sB):- Shape:
(BLK_N, BLK_K, STAGE)->((_8, _2), (_128, _1), (_19)) - Stride:
((_128, _1024), (_1, _0), (_0, _2048))
- Shape:
-
下图直观地展示了A和B矩阵在逻辑上的Tile划分。
分区策略 (Partitioning Strategy)
使用CuTe,可以将全局的Tile划分为每个线程和warpgroup负责处理的数据分片。
- 代码实现:
// 获取当前线程/warpgroup在MMA操作中的数据分片
auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx);
auto mma_warpgroup_slice = tiled_mma.get_warpgroup_slice(warp_group_idx);
// 为A和B分配张量片段和描述符
Tensor tCrA_mma = mma_thread_slice.partition_fragment_A(sA(_,_,Int<0>{})); // MMA_M, MMA_N, MMA_K, PIPE
Tensor tCsA = mma_thread_slice.partition_A(sA);
...
Tensor tCrB_mma = mma_warpgroup_slice.make_fragment_B(tCsB); // MMA_M, MMA_N, MMA_K, PIPE
Tensor tCsB = mma_warpgroup_slice.partition_B(sB);
- GMMA RS 特性:
- 操作数A来自RF(线程本地)。
- 操作数B位于SMEM中,由一个warpgroup(4个warp)共享。
- 命名约定解读:
tCsA:tC代表“分区模式”,sA表示该模式应用于张量sA。sA,sB-> SMEM中的张量。rA,rF,rB-> 只是张量sB的一个视图(GmmatDescriptor)。
数据布局详解
-
tCsA(SMEM, 块局部):- 这是一个
mma_thread_slice。 - MMA Shape (32 bits, M, K):
((_4, _2, _2), _1, (_2, _2), _1, _19) - Stride:
((_1, _1024, _16), _0, (32, 64), _0, _16384))
- 这是一个
-
tCrA_mma(RF, 线程局部):- Shape:
((_4, _2, _2), _1, (_2, _2)) - Stride:
((_1, _4, _8), _0, (_16, _32))
- Shape:
-
tCrA_load(加载形状):- Shape:
((_4, _2, _2), _1, (_2, _2)) - Stride:
((_1, _4, _8), _0, (_16, _32))
- Shape:
下图展示了数据如何从32位格式(包含4个FP8值)加载到Warp0中各个线程的寄存器中。
下图更直观地展示了数据从SMEM中的A Tile到线程寄存器(RF)的映射过程。编号的箭头表示数据加载和处理的逻辑步骤。
累加器 (Accumulator)
GEMM操作的核心是cute::gemm,它将A和B的分片相乘并累加到累加器(accum)中。
cute::gemm(tiled_mma, tCrA_mma, tCrB_(_:_,_:k_block), accum);
tCrA_mma: 位于RF中,是线程本地的。accum: 同样位于RF中,是线程本地的,精度为FP32。- MMA Shape (64 bits, M, N):
((_2, _2, _2), _1, _1) - Shape:
((_2, _2, _2), _1) - Stride:
((_1, _2, _4), _0)
- MMA Shape (64 bits, M, N):
下图展示了Warp0的累加器寄存器布局,其中每个64位寄存器存储2个FP32值。
MoE 模型中的应用
- 实现方案:
- 朴素方案: 由于CUDA核心不直接支持FP8操作,一种方法是将INT4的矩阵B先转换为FP16,再转换为FP8,然后与FP8的矩阵A进行MMA。这种方法涉及多次转换和缩放。
- 优化方案: 更优的方法是先将INT4的矩阵B转换为FP8,与FP8的矩阵A进行MMA得到FP32中间结果。然后,将B的缩放因子应用到这个FP32累加结果上(
scaling_B & add)。这种方法显著减少了缩放操作的次数。
- 权重缩放:
Weight Scaling Group Size设置为128,这等于4 x GMMA_K。- 这意味着沿K维度的4次GMMA操作的结果会累加到同一个累加器上。
- 如下图所示,不同的权重缩放因子(
scale 0,scale 1)可以应用于累加器中的不同区域。
优化的流水线
基于上述优化方案,流水线被重新设计。
- 新流程:
- 矩阵B (INT4) 转换为 FP8。
- 与矩阵A (FP8) 进行MMA,得到FP32中间累加值。
- 新增
Scaling阶段: 在流水线中显式加入一个阶段,对FP32中间结果应用B的缩放因子并累加。 - 对累加结果应用A的缩放因子并转换为FP16输出。
这个新的流水线通过在计算过程中间引入缩放步骤,提高了效率。warpgroup_wait的依赖关系也相应调整为3 -> 0。