PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation and Graph Compilation
PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation and Graph Compilation
作者/机构: Jason Ansel, Edward Yang, Horace He, Natalia Gimelshein, Animesh Jain, Michael Voznesensky, Bin Bao, Peter Bell, David Berard, Evgeni Burovski, Geeta Chauhan, Anjali Chourdia, Will Constable, Alban Desmaison, Zachary DeVito, Elias Ellison, Will Feng, Jiong Gong, Michael Gschwind, Brian Hirsh, Sherlock Huang, Kshiteej Kalambarkar, Laurent Kirsch, Michael Lazos, Mario Lezcano, Yanbo Liang, Jason Liang, Yinghai Lu, CK Luk, Bert Maher, Yunjie Pan, Christian Puhrsch, Matthias Reso, Mark Saroufim, Marcos Yukio Siraichi, Helen Suk, Michael Suo, Phil Tillet, Eikan Wang, Xiaodong Wang, William Wen, Shunting Zhang, Xu Zhao, Keren Zhou, Richard Zou, Ajit Mathews, Gregory Chanan, Peng Wu, and Soumith Chintala. (Meta, OpenAI, QuantCo, Intel, University of Michigan, George Mason University)
A1 主要贡献
本文介绍了对流行的 PyTorch 机器学习框架的两个扩展:TorchDynamo 和 TorchInductor,它们共同实现了 PyTorch 2 中发布的 torch.compile 功能。
-
核心问题: 现代机器学习框架分为 Eager 模式(如 PyTorch【32, Pytorch: an imperative style, high-performance deep learning library, 2019, Proceedings of the 33rd International Conference on Neural Information Processing Systems】)和 Graph 模式(如 TensorFlow【1, TensorFlow: A system for large-scale machine learning, 2016, 12th USENIX Symposium on Operating Systems Design and Implementation (OSDI 16)】)。Eager 模式因其易于理解和调试而备受用户青睐,但它使得通过编译器应用跨算子边界的图级别优化变得困难。以往在 PyTorch 中捕获计算图的尝试(如 record/replay【17, TorchScript, 2018】, 【34, Torch.fx: practical program capture and transformation for deep learning in python, 2022】, Python 解析【17, TorchScript, 2018】和延迟执行【39, LazyTensor: combining eager execution with domain-specific compilers, 2021】)都牺牲了 PyTorch 的易用性,存在不健全、功能不完整或高开销等问题。
-
研究目标与创新点:
- TorchDynamo: 提出一个 Python 级别的即时(JIT)编译器,旨在在不牺牲 Python 完整灵活性的前提下,实现 PyTorch 程序的图编译。其核心创新在于:
- 动态字节码修改: 通过挂钩 CPython 的帧求值 API【9, PEP 523: adding a frame evaluation API to CPython, 2016】,在 Python 字节码执行前动态地修改它。
- 稳健的图提取: 从字节码中提取 PyTorch 操作序列到一个 FX 图【34, Torch.fx: practical program capture and transformation for deep learning in python, 2022】中,然后使用可扩展的后端进行 JIT 编译。
- 混合执行: 设计用于生成较小的图片段,这些片段可以与原生 Python 执行混合,从而在易用性和性能之间取得平衡,能够优雅地处理图中断(graph break)。
- TorchInductor: 提出一个新的编译器后端,作为 TorchDynamo 的默认后端。其核心创新在于:
- 多目标代码生成: 将 PyTorch 程序转换为 OpenAI 的 Triton【46, Triton: an intermediate language and compiler for tiled neural network computations, 2019, (MAPL 2019)】(用于 GPU)和 C++/OpenMP【15, OpenMP: an industry standard API for shared-memory programming, 1998, Computational Science & Engineering, IEEE】(用于 CPU)。
- Define-by-Run IR: 引入一种新的“边运行边定义”(define-by-run)的循环级中间表示(IR),使得添加新的算子底层实现(lowering)变得容易。
- Python 实现: 完全用 Python 实现,便于 PyTorch 用户理解、扩展和修改。
- TorchDynamo: 提出一个 Python 级别的即时(JIT)编译器,旨在在不牺牲 Python 完整灵活性的前提下,实现 PyTorch 程序的图编译。其核心创新在于:
-
主要成果:
- 实验表明,TorchDynamo 相比以往方法能更稳健地捕获计算图,且增加的开销极小。
- 在 NVIDIA A100 GPU 上,TorchInductor 在超过 180 个真实世界模型上实现了 2.27 倍的推理和 1.41 倍的训练几何平均加速,性能优于其他六个编译器。
- 这些扩展为像 PyTorch 这样的 Eager 模式框架提供了一种通过编译器应用优化的新方法。
A3 先前在 PyTorch 中捕获图的尝试
PyTorch 图捕获的挑战。与图模式框架【1, TensorFlow: A system for large-scale machine learning, 2016, 12th USENIX Symposium on Operating Systems Design and Implementation (OSDI 16)】,【25, Caffe: Convolutional Architecture for Fast Feature Embedding, 2014, CoRR】,【5, Theano: A Python framework for fast computation of mathematical expressions, 2016, CoRR】,【37, CNTK: microsoft’s open-source deep-learning toolkit, 2016, Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD ’16)】不同,PyTorch 允许用户在模型中嵌入任意代码,包括非 PyTorch 库。这导致了 PyTorch 张量与 Python 类型之间的频繁转换、对外部库(如 numpy、logging)的使用,以及对 Python 结构(类、闭包、异常、控制流等)的使用,这些都很难映射到固定的图抽象中。因此,以往的图捕获尝试都需要对用户体验施加限制。
2.1 torch.jit.trace
torch.jit.trace 的原理与缺陷。torch.jit.trace 使用“记录/回放”机制,通过示例输入来生成一个 TorchScript【17, TorchScript, 2018】图。记录发生在 PyTorch 的 C++ 调度器层面,该层面用于将算子分派到特定设备的核函数并进行自动求导。由于记录在 C++ 中进行,torch.jit.trace 无法捕获 Python 中的任何控制流。例如,对于一个依赖于输入数据的 if 条件分支,torch.jit.trace 只会记录示例输入所走的路径,导致当输入改变时程序行为不正确。此外,任何非 PyTorch 的操作(如外部库调用、打印、日志记录、副作用等)都会被忽略。
2.2 torch.jit.script
torch.jit.script 的原理与挑战。torch.jit.script 同样构建一个 TorchScript【17, TorchScript, 2018】图,但它是通过解析 Python 的抽象语法树(AST)并进行静态分析来实现的。它能够正确捕获控制流,并且是一种健全的方法,不会产生不正确的结果。然而,torch.jit.script 面临的主要挑战是它试图将整个 Python 重实现为一个静态语言。这种方法是“全有或全无”的:一旦遇到一个未实现的 Python 组件,整个程序就无法被捕获。在实践中,torch.jit.script 只支持 Python 的一个子集。实验结果显示,它在 TorchBench 基准测试套件中的真实模型上只有大约一半的成功率,将大型模型“torchscript化”可能需要数周或数月,用户体验非常差。
2.3 延迟张量(Lazy Tensors)
Lazy Tensors 的工作机制与缺点。Lazy Tensors 技术最初在 PyTorch/XLA【42, PyTorch/XLA, 2023】,【39, LazyTensor: combining eager execution with domain-specific compilers, 2021】项目中引入,主要用于支持 Google TPU【26, In-datacenter performance analysis of a tensor processing unit, 2017, Proceedings of the 44th Annual International Symposium on Computer Architecture (ISCA ’17)】。它是一种 C++ 级别的图捕获技术,在每次迭代中延迟算子的执行以累积一个计算图,然后将该图发送给 XLA【45, XLA - Tensorflow, compiled, 2017】编译器。虽然这种方法有效且健全,但存在几个主要缺点:
- 更高的开销: 除了运行与 Eager 模式相同的 Python 代码和调度器栈外,还必须维护额外的图数据结构,增加了运行时成本。
- 引入延迟: Eager 模式在第一个算子处就发出第一个核函数,而 Lazy Tensors 直到模型代码执行完毕才发出,这增加了延迟并可能导致主机执行与 GPU/加速器利用率串行化。
- 重编译: 每当捕获到的图的哈希值发生变化时,Lazy Tensors 就必须重新编译,可能导致频繁的重编译。
PyTorch/XLA 项目【10, PyTorch/XLA 2022 Q4 dev update, 2022】已经与 TorchDynamo 集成,通过混合使用两种技术来隐藏 Lazy Tensors 的开销。
2.4 torch.fx.symbolic_trace
torch.fx.symbolic_trace 的特点与局限性。torch.fx.symbolic_trace【34, Torch.fx: practical program capture and transformation for deep learning in python, 2022】是这些系统中最年轻的,它引入了 TorchDynamo 也使用的 FX 图格式。它采用类似 torch.jit.trace 的“记录/回放”方法,但在 Python 级别进行追踪。通过使用 Proxy Python 对象来记录用户代码的行为,并利用 PyTorch 中的 __torch_function__【3, Improving subclassing Tensor by propagating subclass instances, 2020】扩展点。由于在更高的 Python 级别记录,symbolic_trace 能够捕获许多 torch.jit.trace 无法捕获的操作,并能检测到更多不正确的捕获情况。然而,它仍然是不健全的,可能产生错误结果。例如,它可能会将随机数生成的结果固化为常量,或者完全忽略与 Proxy 对象无关的副作用(如修改全局变量)。几乎所有机器学习的图格式都不支持 Python 全局变量等概念。
2.5 torch.onnx.export
ONNX 导出与图捕获的关系。ONNX【31, ONNX, 2023】导出并非一种图捕获机制,但常被误解。其内部使用 torch.jit.trace 和 torch.jit.script,因此面临着与这些系统相同的限制。此外,从 TorchScript 转换为 ONNX 格式可能会因为 ONNX 不支持所有 PyTorch 算子而失败。ONNX 团队正在开发与 TorchDynamo 的直接集成,以提高其支持的模型数量。
2.6 与 JAX 中图捕获的比较
JAX 与 PyTorch 在图捕获上的根本差异。JAX【8, JAX: composable transformations of Python+NumPy programs version 0.3.13, 2018】基本上没有面临 TorchDynamo 正在解决的这些挑战。JAX 的初始设计就与 XLA【45, XLA - Tensorflow, compiled, 2017】紧密耦合,这使得 JAX 程序必须遵循 XLA 带来的约束,例如 jax.jit 不支持数据依赖的 Python 控制流,并要求用户代码是函数式纯净的。相比之下,PyTorch 最初是一个纯 Eager 框架,没有内置任何面向编译器的约束,因此在其之上发展起来的大量模型都没有考虑过捕获和编译的难度。在实现层面,JAX 的捕获机制类似于 torch.fx.symbolic_trace【34, Torch.fx: practical program capture and transformation for deep learning in python, 2022】,但更简单,因为 JAX 程序是纯函数式的,无需担心状态问题。
A2 方法细节
3 TorchDynamo 设计与实现
TorchDynamo 采用了一种与以往 PyTorch 图捕获系统根本不同的方法。它不试图移除或替换 Python,而是通过即时(JIT)编译 Python 字节码来与 CPython 协同工作。TorchDynamo 是一个 Python 字节码到 Python 字节码的转换器,它从原始字节码中提取 PyTorch 操作,并用调用已编译的、融合了多个 PyTorch 操作的构件来替换它们。图 1 概述了 TorchDynamo 的工作方式。
3.1 使用 API
torch.compile API 介绍。本文引入的主要 API 是 torch.compile。它可以作为函数装饰器使用,也可以直接调用在一个 PyTorch 模块上。它包含以下关键字选项:
- backend: 允许用户提供一个自定义的编译函数,该函数接收一个 torch.fx.Graph 和一个示例输入列表,并返回一个 Python 可调用对象。默认是 TorchInductor,但也可以设置为内置的多个后端或用户定义的后端。
- options: 一个可选的字典,包含后端特定的配置标志。
- mode: 预定义选项集的简写字符串:“default”、“reduce-overhead” 或 “max-autotune”。
当使用 torch.compile 运行一个模块时,该模块会以图 1 所示的被修改的 CPython 行为来执行。具体来说,一个自定义的 CPython 帧求值钩子会重写每个正在执行的 Python 函数的字节码,以提取和编译 PyTorch 操作序列。
3.2 CPython 帧求值钩子
利用 CPython 帧求值 API 进行 JIT 编译。PEP 523【9, PEP 523: adding a frame evaluation API to CPython, 2016】在 CPython 解释器中引入了帧求值 API,这是 TorchDynamo 使用的主要扩展点。它允许覆盖 CPython 中用于解释单个函数调用的核心函数。TorchDynamo 修改 eval_frame 函数指针,用一个执行 Python 帧 JIT 编译的循环替换标准的 CPython 解释器循环。
自定义帧求值函数的执行流程。TorchDynamo 安装的自定义 eval_frame 函数执行以下操作:
- 检查是否跳过:首先,检查是否应跳过当前帧,原因可能包括文件名排除(如 Python 标准库和 numpy)、先前分析失败标记或超出缓存大小限制。如果跳过,则调用原始的 _PyEval_EvalFrameDefault 来执行原始字节码。
- 缓存检查与守卫执行:其次,检查该帧是否已被编译并缓存。如果是,则执行缓存中每个条目对应的守卫函数(guard function)。如果守卫函数返回 True,则运行匹配的已缓存编译字节码并返回。
- 符号分析:接着,对函数字节码进行逐指令的符号分析,以提取一个 FX 图【34, Torch.fx: practical program capture and transformation for deep learning in python, 2022】、守卫(guards)和副作用。如果遇到不支持的操作,分析可能会在函数中途停止。
- 图编译:然后,使用用户通过 backend= 参数指定的编译器函数来编译提取出的 FX 图。
- 守卫函数生成:生成并编译一个单一的 Python 函数,用于检查所有的守卫条件。如果守卫通过,它返回 True,表示可以重用现有的编译构件。
- 续体函数生成:如果分析没有到达函数末尾,则生成 resume_at_XX 续体函数(continuation functions),用于在新的帧中运行函数的剩余部分。
- 新字节码生成:生成新的 Python 字节码。这个新字节码会:1) 调用已编译的 FX 图;2) 存储和重建局部/栈状态;3) 执行原始函数应有的副作用;4) 返回或通过回退到原始字节码并调用生成的续体函数来实现图中断(graph break)。
- 缓存与执行:最后,将生成的 Python 字节码和守卫函数安装到缓存中,使用 _PyEval_EvalFrameDefault 运行生成的字节码,然后返回。
3.3 守卫(Guards)
守卫机制的作用。守卫是 TorchDynamo 用来重新检查 JIT 编译时所依赖的动态属性的机制,以确定一个缓存的编译结果是否可以被重用。TorchDynamo 为每个转换后的 PyCodeObject 生成一个守卫函数,如果重用已编译的构件是安全的,该函数返回 True。守卫和转换后的代码都使用 PEP 523【9, PEP 523: adding a frame evaluation API to CPython, 2016】中引入的 _PyCode_SetExtra 扩展点进行存储。守卫在分析过程中累积,可以指向来自全局/局部变量或嵌套在 Python 数据结构中的变量。目前有 30 种不同类型的守卫,包括检查 torch.Tensor 的多种属性、Python 类型、常量特化、属性、字典/列表/元组、nn.Module 实例以及全局 PyTorch 状态。守卫系统贯穿 TorchDynamo、AOTAutograd 和 TorchInductor,任何层都可以引入守卫来保护特化。
3.4 符号求值
符号求值器的核心功能。TorchDynamo 的一个基本组成部分是符号 Python 字节码求值器,它负责分析 Python 字节码并建模每条指令的效果。符号求值器维护着跟踪以下内容的数据结构:1) 栈状态;2) 局部变量;3) 异常上下文;4) 累积的 FX 图【34, Torch.fx: practical program capture and transformation for deep learning in python, 2022】;5) 累积的守卫;以及 6) 副作用。该算法逐个处理 Python 字节码,并为每种 Python 字节码指令类型都包含一个相应的处理函数。
符号求值的过程。在符号求值开始时,函数参数被检查并转换为符号表示 VariableTracker。如果字节码访问类属性或全局变量等数据结构,会懒加载地为这些构造添加新的符号表示。求值器从函数的第一个字节码指令开始,逐个处理。例如,LOAD_FAST 指令的处理函数会将局部变量的符号表示推送到符号栈上;BINARY_ADD 的处理函数会从栈上弹出两个符号变量,计算它们的结果,然后将结果推回栈上。如果操作涉及 PyTorch 张量,则会在 FX 图中添加一个新节点,并创建一个指向该结果节点的新符号张量。
3.5 建模 Python 数据结构
使用 VariableTracker 层次结构对 Python 数据进行建模。为了分析 Python 的行为,TorchDynamo 拥有一个类层次结构 VariableTracker,用于建模不同数据类型的常见行为。重要的 VariableTracker 子类包括:
- TensorVariable: 代表一个 torch.Tensor,它不存储实际的张量值,而是存储一个指向部分构建的 FX 图【34, Torch.fx: practical program capture and transformation for deep learning in python, 2022】的 fx.Proxy,以及一个代表张量元数据(无实际数据)的“伪”张量(fake tensor)。
- ConstDictVariable 和 DataClassVariable: 用于表示键为常量字符串的键值对。
- ListVariable 和 TupleVariable: 代表列表/元组,可以包含任何其他类型的符号变量。
- UserFunctionVariable 和 UserMethodVariable: 代表可以被内联的用户定义函数,并支持动态构建的包含闭包的函数。
- UserDefinedClassVariable 和 UserDefinedObjectVariable: 代表用户定义的类和实例。当访问它们的属性时,会进行懒加载特化,并跟踪它们的突变。
每个 VariableTracker 实例除了包含类型特定的数据外,还包含一组守卫,并在操作中通过并集传播。
3.6 内联、控制流和闭包
处理函数调用、控制流和闭包的策略。为了收集更大的计算图,TorchDynamo 会尝试内联函数调用并扁平化程序。当遇到函数调用时,它会创建一个当前符号状态的检查点,然后递归地尝试符号求值被调用的函数。如果递归分析遇到会导致图中断(graph break)的情况,TorchDynamo 会回滚到函数调用前的状态,并在该函数调用处生成一个图中断。
控制流处理。大多数 Python 字节码中的控制流通过特化(specialization)被优化掉。例如,遍历 torch.nn.Module 列表时,TorchDynamo 会添加守卫确保列表不变,并展开循环。对于基于张量类型、大小和形状的控制流,TorchDynamo 会在这些属性上添加守卫并移除控制流。对于无法移除的控制流(如基于张量值的-分支),则会生成一个图中断。
闭包处理。TorchDynamo 必须处理多种闭包情况:
- 在捕获区域外创建的单元变量(cell variables): 通过生成 LOAD_DEREF 和 STORE_DEREF 字节码来访问。内联时,则直接读写内联函数的单元内容。
- 在捕获区域内创建并销毁的单元变量: 这是最常见且最容易处理的情况,TorchDynamo 会静态地优化掉闭包。
- 在捕获区域内创建但逃逸出该帧的单元变量: 这是最难处理的情况。TorchDynamo 会优化掉捕获区域内所有对闭包的使用,然后在生成的字节码末尾创建所需的单元和 Python 函数对象以供返回。
3.7 突变和副作用
延迟处理副作用。Python 函数有时会产生副作用。TorchDynamo 通过将副作用推迟到 FX 图【34, Torch.fx: practical program capture and transformation for deep learning in python, 2022】被调用之后来处理它们,并在输出的字节码中生成代码在最后应用所有副作用。TorchDynamo 的副作用数据结构会跟踪所有原始代码会产生的副作用。如果代码尝试读取一个会被待处理的副作用所修改的值,它会直接读取那个待处理的值。图生成后,一个垃圾回收过程会移除未逃逸出分析上下文的副作用。这种处理方式使得对同一值的多次写入被合并为单次写入。支持的副作用类型包括:
- 全局变量写入: 生成 STORE_GLOBAL 字节码。
- 属性写入: 映射到 STORE_ATTR 字节码。
- 单元/闭包写入: 以多种方式进行跟踪和处理。
- 类构造: 通过创建占位符符号对象、内联 __init__ 方法并跟踪属性突变来处理。
- 字典和列表突变: 如果字典/列表是输入或从全局/属性加载的,则会产生副作用。VariableTracker 会在这些对象的初始符号状态上添加守卫,然后符号化地跟踪所有变化。
3.8 图中断和续体函数
图中断机制。当 TorchDynamo 遇到它无法处理的 Python 字节码时(例如调用外部库),它会生成一个所谓的“图中断”(graph break),将被分析的字节码分割成多个部分。任何待处理的部分 FX 图【34, Torch.fx: practical program capture and transformation for deep learning in python, 2022】都会被编译。在输出代码中,当部分图被调用后,不支持的字节码将被执行,然后 TorchDynamo 会递归地分析函数的剩余部分。
续体函数的实现。为了触发这种递归分析,TorchDynamo 会生成一个或多个续体函数(continuation functions),其形式如下:
def resume_at_X(... livevars ...):
... restore try/except/stack state ...
JUMP_ABSOLUTE X
... original function bytecode ...
这个续体函数与原始函数非常相似,但参数被更改为反映图中断处的所有活动变量,并添加了恢复栈/异常状态的前缀,以及一个 JUMP_ABSOLUTE 指令以从函数中间恢复执行。将续体构造为 Python 函数的优势在于,它会通过帧求值 API 递归地触发 TorchDynamo。
3.9 AOTAutograd
AOTAutograd 的作用。AOTAutograd 是 PyTorch 中的一个可重用组件,被许多 PyTorch 编译器后端调用以添加训练支持和使用共享的算子分解。TorchDynamo 捕获模型的前向传播,而为了支持训练,还需要生成反向传播过程。在 PyTorch Eager 模式中,反向图是使用基于磁带(tape-based)的 autograd【32, Pytorch: an imperative style, high-performance deep learning library, 2019, Proceedings of the 33rd International Conference on Neural Information Processing Systems】动态生成的。AOTAutograd 将前向图转换为一个同时包含前向和反向传播的图,并且支持部分程序图。
AOTAutograd 的工作原理。AOTAutograd 的工作方式是在伪张量(fake tensor)输入上运行 PyTorch Eager 模式的 autograd 引擎,并记录一个联合的前向和反向图。由于数据依赖的操作无法与伪张量一起工作,TorchDynamo 会在这些操作上产生图中断。然后,AOTAutograd 使用最小割(min-cut)算法【55, Transcending runtime-memory tradeoffs in checkpointing by being fusion aware, 2023, Proceedings of Machine Learning and Systems】将这个联合图分割成独立的前向和反向图,以优化内存使用。作为这个过程的一部分,它还会应用后端特定的优化来重新计算(rematerialize)反向图中一些计算成本较低的激活值。此外,AOTAutograd 还应用了其他调度器级别的转换,如算子分解(将一些 PyTorch 算子映射为更基础的算子集合)和函数化(移除执行突变的操作,并用其函数式等价物替换)。
4 TorchInductor 设计与实现
TorchInductor 是与 TorchDynamo 配对的后端编译器,它接收捕获的 FX 图【34, Torch.fx: practical program capture and transformation for deep learning in python, 2022】并生成快速代码。它被设计为一个通用的参考编译器后端。
4.1 设计原则与关键技术
TorchInductor 的设计理念。TorchInductor 的设计遵循以下原则和技术:
- PyTorch 原生 (PyTorch Native): TorchInductor 共享与 PyTorch Eager 相似的抽象,如支持可操作的 strides、别名视图(aliasing views)以及原地数据和元数据突变,以便能够支持 PyTorch 的所有功能。
- Python 优先 (Python First): TorchInductor 在 Python 中实现,使其易于 PyTorch 用户理解、修改和贡献。
- 广度优先 (Breadth First): 早期就注重支持广泛的算子、硬件和优化,使其成为一个可扩展到多种场景的通用编译器,并优先关注更具挑战性的训练场景。
- 重用最先进的语言 (Reuse State-Of-The-Art Languages): TorchInductor 生成 OpenAI Triton【46, Triton: an intermediate language and compiler for tiled neural network computations, 2019, (MAPL 2019)】(用于 GPU)和 C++/OpenMP【15, OpenMP: an industry standard API for shared-memory programming, 1998, Computational Science & Engineering, IEEE】(用于 CPU)作为输出代码,以利用这些项目的技术并生成用户可理解的代码。
4.2 分解(Decompositions)
算子分解机制。TorchInductor 不为 PyTorch 中的所有算子实现到底层 IR 的转换,而是将许多算子分解为一组更简单、更易于处理的算子。这些分解通过 AOTAutograd(见 3.9 节)进行,TorchInductor 向其传递一个包含所需分解的字典。分解被编写为使用其他 PyTorch 算子实现的 Python 函数。例如,log2 可以被分解为 log 和 mul。这个分解过程会递归地进行,直到达到一个不动点。目前,TorchInductor 使用了 191 种分解(包括重载共 387 种),其中大部分是通用的,可通过 torch._decomp 模块供其他后端使用。
4.3 底层转换和“边运行边定义”的循环级 IR
从 FX 图到底层 IR 的转换。编译的下一阶段是将 PyTorch 操作的 FX 图转换为 TorchInductor 的“边运行边定义”(define-by-run)IR。这意味着 IR 使用可执行的 Python 代码来定义循环体,从而赋予了 IR 几乎完整的 Python 功能,并允许简洁地编写底层转换(lowering)。
IR 示例与结构。图 2 展示了一个 IR 示例。inner_fn_buf0 是一个 Python 函数,它定义了如何使用 TorchInductor 的原始算子(在 ops.* 命名空间中)来计算张量 buf0 的单个元素。该函数接收代表元素符号坐标的 SymPy【28, SymPy: symbolic computing in Python, 2017, PeerJ Computer Science】符号(i0 和 i1)。TensorBox 和 StorageBox 是与 PyTorch 的 torch.Tensor 和 torch.Storage 对象相匹配的抽象,用于处理视图、别名和突变。ComputedBuffer 表示一个将通过生成代码计算的张量,而 Pointwise 表示这是一个数据并行的点操作。
def inner_fn_buf0(index):
i0, i1 = index
tmp0 = ops.load("arg0_1", i0 * s1 + i1)
tmp1 = ops.log(tmp0)
tmp2 = ops.constant(1.4426950408889634, torch.float32)
tmp3 = ops.mul(tmp1, tmp2)
return tmp3
buf0_ir = TensorBox(StorageBox(ComputedBuffer(
name='buf0',
layout=FixedLayout('cuda', torch.float32,
size=[s0, s1], stride=[s1, 1]),
data=Pointwise(inner_fn=inner_fn_buf0, ranges=[s0, s1], ...))))
IR 的优势与分析。这个 IR 的关键优势在于它易于构建,因为它拥有 Python 的全部能力。可以通过虚拟化的 ops.* 命名空间来编译和分析这个 IR。为了进行分析,可以将 ops 指向一个分析过程(pass),该过程可以记录内存访问等信息。为了生成代码,可以将 ops 指向一个写出 Triton 或 C++ 代码的东西。目前,TorchInductor 的循环级 IR 包含 54 个原始算子,包括 ops.load、ops.store、ops.reduction、ops.masked 等。
4.4 调度(Scheduling)
调度阶段的功能。TorchInductor 的调度阶段决定了哪些算子被融合、核函数的运行顺序,并进行内存规划以移除或重用缓冲区。调度首先将 IR 中的每个缓冲区转换为 BaseSchedulerNode 的子类,如代表标准核函数的 SchedulerNode,代表外部库调用的 ExternKernelSchedulerNode 等。
融合算法。调度器将每个核函数的内存读写集转换为节点间的依赖边。融合由两个关键函数控制:
- Scheduler.can_fuse(node1, node2): 判断两个节点是否可以融合。它检查依赖关系和许多其他属性以确保融合的正确性。
- Scheduler.score_fusion(node1, node2): 用于对不同的融合可能性进行排序。TorchInductor 选择得分最高的融合。
TorchInductor 在一个循环中执行一个贪心算法:1) 找到所有融合机会;2) 对它们进行评分和排序;3) 依次应用仍然合法的融合。
4.5 Triton 代码生成
将 IR 映射到 Triton。Triton 代码生成负责将 TorchInductor 的 IR 映射到输出的 Triton【46, Triton: an intermediate language and compiler for tiled neural network computations, 2019, (MAPL 2019)】核函数。图 3 展示了为上述 log2 示例生成的代码。该核函数一次处理一个 XBLOCK 元素的块。代码生成还负责公共子表达式消除(CSE)。@pointwise 装饰器封装了用于促进块大小启发式、自动调优和预编译的样板代码。对于归约(reduction)核函数,TorchInductor 有两种生成模式:对于小型归约,生成一个持久化归约(persistent reduction);对于大型归约,则生成一个循环。对于更复杂的操作(如矩阵乘法和卷积),TorchInductor 使用一个模板系统来生成混合手写和自动生成的 Triton 代码。
4.6 C++ 代码生成
CPU 后端的代码生成。对于 CPU 后端,TorchInductor 生成带有 OpenMP【15, OpenMP: an industry standard API for shared-memory programming, 1998, Computational Science & Engineering, IEEE】的 C++ 代码。它有两种变体:一种是向量化变体,它执行分块(tiling)并将大多数操作映射到 PyTorch 源代码中的 at::vec::Vectorized 类;另一种是非向量化变体,它生成标准的 C++ 代码。两种变体都使用 #pragma omp for 进行并行化。
4.7 包装器代码生成
生成调用核函数的代码。包装器代码生成负责生成调用来自 Triton、C++ 和外部源的核函数的代码。它还进行张量大小计算和内存分配/释放。有两种实现方式,一种生成 Python 代码(更灵活),另一种生成 C++ 代码(开销更低)。当启用 mode="reduce-overhead" 时,TorchInductor 使用 CUDA Graphs【20, Getting started with CUDA graphs, 2019】来完全消除包装器代码的开销。
4.8 相关的深度学习编译器
与其他编译器的关系。本文简要回顾了深度学习编译器领域的工作。选择 Triton【46, Triton: an intermediate language and compiler for tiled neural network computations, 2019, (MAPL 2019)】作为输出目标,是因为它已被证明能生成比手写库更快的核函数。许多编译器(如 TVM【12, TVM: an automated End-to-End optimizing compiler for deep learning, 2018, 13th USENIX Symposium on Operating Systems Design and Implementation (OSDI 18)】, nvFuser【36, Introducing nvFuser, a deep learning compiler for PyTorch, 2022】, NNC【60, NNC walkthrough: how PyTorch ops get fused, 2021】)的设计灵感来自 Halide【33, Halide: a language and compiler for optimizing parallelism, locality, and recomputation in image processing pipelines, 2013, Proceedings of the 34th ACM SIGPLAN Conference on Programming Language Design and Implementation (PLDI ’13)】。XLA【45, XLA - Tensorflow, compiled, 2017】是 TensorFlow 和 JAX 背后的编译器。许多新兴编译器出现在 MLIR【27, MLIR: scaling compiler infrastructure for domain specific computation, 2021, 2021 IEEE/ACM International Symposium on Code Generation and Optimization (CGO)】生态系统中。
5 动态形状(Dynamic Shapes)
支持动态形状的必要性。深度学习编译器通常只适用于静态形状,即为单一输入形状配置生成编译程序,如果输入形状改变则需重新编译。然而,在某些情况下这是不够的,例如:
- 批量大小或序列长度可能变化。
- 某些模型(如检测模型)会产生数据依赖的输出形状。
- 处理稀疏表示(如稀疏张量、图神经网络)时,数据量依赖于问题的稀疏结构。
本文支持动态形状,但不支持动态秩(rank)的程序,因为后者在现实世界中很少出现。
5.1 符号化形状守卫
动态形状下的追踪策略。在处理动态形状时,本文延续了重用现有代码的理念:系统总是选择条件分支中的一个,并在假设该追踪路径将被重用的前提下进行特化。为此,系统为每个符号化尺寸(symbolic size)维护一个尺寸提示(size hint),即首次触发 JIT 编译时的具体值。当对张量形状执行条件判断时,系统会查阅这个提示来决定走哪个分支,并添加一个守卫。
元函数(Meta functions)的作用。这种方法简化了生成的符号化形状公式,但需要一个更复杂的守卫管理系统。为了确定一个中间张量的尺寸(例如,torch.cat([x, y]) 的尺寸),系统需要将其尺寸表示为输入的函数(即 x.size(0) + y.size(0))。这是通过为 PyTorch 中的所有算子编写元函数来实现的。元函数可以在不实际执行计算的情况下,将尺寸信息传播到张量的输出。
5.2 优化动态形状的推理
降低符号形状推理开销的策略。为了减少符号形状推理的性能影响,本文采用了多种策略:
- 自动推断动态性: 默认 API 假定所有输入都可能是动态的,而模型权重是静态的,然后通过分析模型来推断真正的动态性。
- 0/1 特化: 如果一个输入尺寸是 0 或 1,系统会将其视为常量并添加相应的守卫,而不是分配一个符号变量。这对于捕捉广播语义和性能优化很重要。
- 增量简化: 随着从守卫中学习到更多事实,系统会增量地简化符号表达式。
- 减少不必要的守卫: 通过修改 PyTorch 避免不必要地重新计算张量属性(如是否连续),将预计算属性变为惰性计算,以及提供 constrain_range API 让用户指定尺寸的界限,来减少生成的守卫数量。
5.3 无提示(无后备)的符号整数
处理数据依赖的尺寸。当一个尺寸变量来自数据依赖的操作(如 .nonzero() 或 .item())时,其具体值是未知的,这被称为“无后备”(unbacked)的符号整数。对这些符号整数执行控制流是非法的,因此必须在这些操作上产生图中断。为了减少因此产生的过多图中断,本文做出了一些增强,例如在张量创建时预计算属性、使某些属性的计算变得懒惰,并提供 API 允许用户约束整数张量中数据的范围。
A4 实验环境
- 基准测试套件:
- TorchBench【14, TorchBench: a collection of open source benchmarks for PyTorch performance and usability evaluation, 2020】:包含从开源仓库中挑选的多种模型,这些项目根据 Papers with Code【35, Papers with Code 2021: a year in review, 2021】的引用排名选出。
- HuggingFace【53, Transformers: State-of-the-Art Natural Language Processing, 2020, Association for Computational Linguistics】:一个流行的 Transformer【49, Attention is all you need, 2017, Proceedings of the 31st International Conference on Neural Information Processing Systems (NIPS’17)】模型库。
- TIMM【52, PyTorch image models, 2019】:一个流行的 PyTorch 视觉模型库。
- 硬件配置:
- GPU: NVIDIA A100
- CPU: Intel Xeon 8275CL
- 软件配置:
- CUDA 11.6
- PyTorch nightly build (2023年8月30日)
- 其他编译器版本: nvFuser 2.0; NNC 2.0; Hidet 0.2.2; TVM 0.11.1; ONNX Runtime (ONNXRT) 1.14.1; PyTorch/XLA 2.1
- 实验设置:
- 实验重复 100 次,预热 3 次。
- 每个模型超时时间为 30 分钟。
- TorchInductor 运行在
max-autotune、freezing和cudagraphs启用的模式下。 - 训练实验测量单步前向和反向传播的时间,不包括优化器。
A4 实验结果
6.1 TorchDynamo 捕获图的能力
- 实验内容: 比较 TorchDynamo 和 TorchScript【17, TorchScript, 2018】在三个基准测试套件上捕获模型的能力。
- 实验结果 (表 1):
- 在 TorchBench 上,TorchDynamo 的成功率(93%)是 TorchScript(45%)的两倍多。
- 在 HuggingFace 上,TorchScript 失败于所有模型,而 TorchDynamo 全部成功。
- 在 TIMM 上,两者成功率都很高(98% vs 100%),因为 TIMM 维护者主动适配了 TorchScript。
- 分析结论 (表 1):
- TorchDynamo 相比以往的“全有或全无”系统,能够捕获部分程序和多个图。
- 在大多数情况下,TorchDynamo 能够捕获单个完整的程序图。即使有图中断,平均每个图也包含数百个算子。
- 最常见的图中断原因是:使用非 PyTorch 库(如 numpy【21, Array programming with NumPy, 2020, Nature】)、转换为 Python 类型(如
tolist())以及数据依赖的控制流。
6.2 图捕获的开销
- 实验内容: 测量 TorchDynamo 和 Lazy Tensors 引入的运行时开销。实验使用与 PyTorch Eager 相同的核函数,因此开销仅来自图捕获。
- 实验结果 (表 2):
- TorchDynamo 的开销低于 5%。
- Lazy Tensors 增加了大量开销(推理 38%,训练 90%)。
- 即使使用跨迭代流水线,Lazy Tensors 的开销仍然很高(推理 31%,训练 86%)。
- 分析结论: Lazy Tensors 的开销在不同模型间不均匀,对于小模型或操作数多的模型,其捕获速度不足以饱和 GPU。此外,代码中常见的将张量值转换为 Python 类型的操作(如
loss.item())会破坏跨迭代流水线,导致 CPU 等待 GPU,从而放大开销。
6.3 TorchInductor 加速效果
- 实验内容: 比较 TorchInductor 和其他六个 TorchDynamo 后端(nvFuser, NNC, PyTorch/XLA, ONNXRT, TVM, Hidet)相对于 PyTorch Eager 的几何平均加速比。
- 实验结果 (表 3 和 图 4):
- 在各种配置(训练/推理、CPU/GPU、float32/float16)和基准测试套件中,TorchInductor 在大多数情况下都比其他后端快。
- nvFuser【36, Introducing nvFuser, a deep learning compiler for PyTorch, 2022】和 NNC【60, NNC walkthrough: how PyTorch ops get fused, 2021】的加速比集中在 1 倍左右,因为它们主要使用 Eager 核函数。
- PyTorch/XLA【42, PyTorch/XLA, 2023】性能波动较大,有时有大加速,有时有大减速。
- ONNX Runtime【16, ONNX runtime, 2021】、TVM【12, TVM: an automated End-to-End optimizing compiler for deep learning, 2018, 13th USENIX Symposium on Operating Systems Design and Implementation (OSDI 18)】和 Hidet【18, Hidet: task-mapping programming paradigm for deep learning tensor programs, 2023, Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2 (ASPLOS 2023)】是仅推理的,且由于算子实现缺失等问题,在许多模型上运行失败。
- 分析结论: TorchInductor 提供了最稳定和最高的性能提升。
6.4 TorchInductor 加速来源
- 实验内容: 通过逐一禁用 TorchInductor 的优化来衡量其对 HuggingFace 模型几何平均加速比的影响(消融研究)。
- 实验结果 (表 4):
- 最大的加速来自将点操作(pointwise)、归约(reduction)和散布(scatter)核函数融合成更少数量的核函数,这发生在底层转换的内联(Inlining)和调度阶段的融合(Fusion)中。禁用这两项后,TorchInductor 甚至会导致减速。
- 其他重要的优化包括:Cudagraphs(减少核函数启动开销)、Matmul 模板(使用 Triton 模板替代 cuBLAS/cuDNN)、模式匹配(图级别的窥孔优化)和循环/布局重排。
- 分析结论: 算子融合和内联是 TorchInductor 性能提升的核心。这是因为 TorchInductor 的分解过程会将大的优化算子分解为许多小的原始算子,必须依靠融合将它们重新组合以恢复性能并实现超越。
A5 结论
本文介绍了 PyTorch 的两个扩展:TorchDynamo 和 TorchInductor。它们通过在 PyTorch 程序中实现图编译来提供加速,同时保留了 PyTorch 赖以成名的 Eager 编程模型的灵活性和易用性。通过在 PyTorch 程序中启用图编译,作者希望能让研究人员和从业者以更高的效率和灵活性来解决更大、更复杂的机器学习问题。
A6 附录
A.1 摘要
代码与复现。这项工作的源代码包含在 PyTorch 中,可在 https://github.com/pytorch/pytorch/ 获取。TorchDynamo 位于 torch/_dynamo 目录,TorchInductor 位于 torch/_inductor 目录。用于复现论文结果的基准测试代码位于 https://github.com/pytorch/pytorch/tree/main/benchmarks/dynamo。由于实验数量庞大,这里的说明将重点复 现 TorchInductor 在 HuggingFace 上的 GPU 结果。
A.2 构件清单(元信息)
- 二进制文件: 可在 https://pytorch.org/ 获取
- 硬件: NVIDIA A100 GPU,Intel Xeon 8275CL CPU
- 指标: 相对于 PyTorch Eager 模式的几何平均加速比
- 所需磁盘空间: 约 50 GB
- 准备工作流所需时间: 约 1 小时
- 完成实验所需时间: 大多数实验每个后端、每个配置 < 1 天
- 是否公开: 是
- 代码许可证: BSD-3
A.3 描述
访问方式。源代码、基准测试代码和 PyTorch 二进制文件的访问链接已在 A.1 和 A.3.1 中提供。
硬件依赖。为匹配论文配置,需要 NVIDIA A100 GPU 和 Intel Xeon 8275CL CPU。基准测试可在具有 SM80+ 和 40GB+ 显存的 NVIDIA GPU 上运行。
软件依赖。需要一个近期的 Linux 发行版、NVIDIA 内核驱动、兼容的 CUDA 版本、gcc/g++、Miniconda 以及一些 Python 包(pandas, scipy, psutil, tqdm)。
A.4 安装
安装步骤。首先,使用 conda 创建一个新的 Python 3.10 环境并激活。然后,安装基准测试代码所需的依赖项。接着,使用 conda 安装 PyTorch 的发布版本。最后,克隆 PyTorch 的源代码仓库以获取基准测试脚本。
A.5 实验工作流
复现命令示例。要复现 TorchInductor 在 HuggingFace 上、float16、GPU、推理场景下的加速效果,运行以下命令:
TORCHINDUCTOR_MAX_AUTOTUNE=1 ./benchmarks/dynamo/huggingface.py \
--performance --no-skip \
-dcuda --float16 --inference \
--inductor --freezing \
--output=`pwd`/results.csv
此命令会下载 HuggingFace 模型,并分别在使用和不使用 TorchDynamo 的情况下运行它们,以计算相对于 PyTorch Eager 模式的加速比。结果会写入 results.csv 文件。
A.6 评估和预期结果
结果验证。输出文件 results.csv 应包含 46 个条目,显示每个模型的加速比等指标。所有模型都应该能正常工作,所有加速比的几何平均值应与论文中报告的加速比相似。
A.7 实验定制
自定义实验参数。上述命令可以通过修改参数进行定制:
- 基准套件: 可将 huggingface.py 替换为 timm_models.py 或 torchbench.py。
- 设备: -dcuda 可替换为 -dcpu。
- 精度: --float16 可替换为 --float32 或 --amp。
- 模式: --inference 可替换为 --training。
- 后端: --inductor 可替换为其他后端,如 --backend=nvfuser, --xla 等。
- 优化: 可移除 TORCHINDUCTOR_MAX_AUTOTUNE=1 和 --freezing 来禁用这些优化。
A.8 注意事项
- 版本差异: PyTorch 新版本的性能和模型覆盖率可能优于论文中的结果。
- 环境敏感性: 性能结果对硬件和 CUDA 版本等环境设置敏感,预计会有微小差异。
- 额外步骤: TorchBench 和非 TorchInductor 后端需要额外的安装步骤。
- 性能仪表盘: 一个基于这些脚本的性能仪表盘可在 https://hud.pytorch.org/benchmark/compilers 获取。
💬 评论讨论
欢迎在这里分享您的想法和见解!