TORCH.FX: PRACTICAL PROGRAM CAPTURE AND TRANSFORMATION FOR DEEP LEARNING IN PYTHON

文章标题:TORCH.FX:Python中深度学习的实用程序捕获与转换
作者/机构:James K Reed, Zachary DeVito, Horace He, Ansley Ussery, Jason Ansel (均为 Facebook AI)

A1 主要贡献

本文旨在解决现代深度学习框架中的一个核心矛盾:即时执行(eager execution)模式虽然提升了开发效率和用户体验,但却牺牲了对程序结构的访问能力,而这种能力对于性能优化、可视化、分析和硬件集成等高级转换至关重要。为了在即时模式框架(如PyTorch)中重新获得这种能力,需要一种程序捕获机制。然而,现有的系统(如TorchScript)为了完全忠实地模拟Python的复杂语义(包括可变状态、控制流、复杂数据类型),其捕获技术和生成的中间表示(IR)都变得异常复杂,给转换(transform)的编写带来了巨大困难。

本文提出,可以通过专注于深度学习的典型用例(大多数神经网络模型的高层有向无环图(DAG)结构)而非长尾的复杂情况,来设计一个更简单、更高效的程序捕获与转换框架。基于这一理念,本文介绍了torch.fx,一个完全用Python编写的、为PyTorch设计的程序捕获与转换库,其核心目标是为机器学习从业者提供极高的开发生产力。

本文的主要贡献如下:

  1. 实用性分析:对深度学习程序中重要的程序捕获与转换特性进行了实用性分析。
  2. 纯Python程序捕获库:实现了一个纯Python的程序捕获库,该库可被定制以捕获不同层次的程序细节。
  3. 简单的6指令IR:提出了一种仅包含6个指令的简单中间表示(IR),其设计重点在于易于理解和进行静态分析。
  4. 代码生成系统:构建了一个代码生成系统,能够将转换后的代码无缝地返回到宿主语言(Python)的生态系统中。
  5. 案例研究:展示了torch.fx在实践中如何被用于性能优化、程序分析、设备适配(device lowering)等场景,实现了PyTorch生态系统中以前难以完成的工作流。

A3 背景知识与设计原则

背景知识

程序捕获、特化与IR设计的权衡。无论是即时模式还是图模式框架,在捕获和转换程序时都必须在程序结构的捕获、程序的特化(specialization)以及中间表示(IR)的设计之间做出选择。这些选择共同决定了框架能表示的程序范围、编写转换的难易程度以及转换后程序的性能。通常,为了支持更多程序并实现高性能,需要更复杂的捕获框架和IR,这反过来又使得转换的编写更加困难。

2.1 捕获程序结构

  • 执行追踪(Tracing)。从Python程序中捕获结构有多种方法。最简单的是基于示例输入的追踪,即给定一些示例输入,执行模型并记录发生的操作,PyTorch的jit.trace (DeVito et al.【TorchScript, 2018】) 采用了此方法。一种稍微复杂的方式是符号追踪,即使用抽象值而非具体示例输入进行追踪。MXNet的Gluon (Chen et al.【Mxnet: A flexible and efficient machine learning library for heterogeneous distributed systems, 2015】) 和TensorFlow的tf.function (Moldovan et al.【AutoGraph: Imperative-style Coding with Graphbased Performance, 2018】) 实现了这种方法。符号追踪的优点是不需要用户提供示例输入,并且能暴露输入值依赖的Python控制流位置。追踪通常只记录张量和少数其他数据结构(如张量列表)上的操作,因此其可见性有限,但对于深度学习计算(通常是扁平的张量操作序列,即基本块程序)来说已经足够。
  • 扩展追踪能力。一些追踪系统通过重写标准Python代码的执行行为来捕获更多程序结构,例如控制流,但这增加了复杂性。例如,tf.function通过一个轻量级模块化分阶段(Lightweight Modular Staging)系统 (Rompf & Odersky【Lightweight modular staging: a pragmatic approach to runtime code generation and compiled DSLs, 2010】) 来增强符号追踪,该系统使用Python AST变换将命令式控制流结构转换为可被追踪的高阶Python函数。
  • 嵌入式编程语言。另一种方法是让用户直接在Python内部的嵌入式编程语言中编写模型。最简单的是提供类似TensorFlow的图构建API。为了更好地表示控制流,PyTorch的TorchScript (DeVito et al.【TorchScript, 2018】) 使用传统的词法-语法-编译器工具链直接从Python源码中提取程序。TorchScript能够完全保真地检查源码语法,理解结构化控制流、集合类型等。与可能静默失败的追踪相比,嵌入式语言方法可以在编译时报告不支持的构造。但其实现非常复杂,且在实践中仍无法支持完整的Python语言。
  • 原生宿主语言集成。Zygote.jl (Innes【Don’t Unroll Adjoint: Differentiating SSA-Form Programs, 2018】) 和Julia生态中的TPU集成 (Fischer & Saba【Automatic Full Compilation of Julia Programs and ML Models to Cloud TPUs, 2018】) 以及Swift for TensorFlow (Saeta et al.【Swift for TensorFlow: A portable, flexible platform for deep learning, 2021】) 通过与非Python宿主语言集成来提供程序转换接口。主要缺点是要求用户离开拥有庞大库和社区动力的Python生态系统。

2.2 程序特化

  • 特化的必要性与程度。Python表达式(如a + b)非常抽象。ML框架在捕获程序时,通常会对其进行特化,使其仅对特定类型或张量形状有效。特化程度越高,适用输入范围越窄。不同方法在特化程度、时机(预先/即时)和安全性上有所不同。
  • 基于示例的特化。PyTorch的torch.jit.trace (DeVito et al.【TorchScript, 2018】) 会根据示例输入的形状进行特化。这种捕获方式是非侵入式的,但可能导致追踪到的表示形式是形状特化的,即仅对追踪时使用的值形状有效,对其他形状可能失败。
  • 即时(JIT)特化。为避免特化失败,DyNet (Neubig et al.【Dynet: The dynamic neural network toolkit, 2017】) 和LazyTensor (Suhan et al.【LazyTensor: combining eager execution with domain-specific compilers, 2021】) 等系统采用即时追踪,为每次调用捕获特化的程序表示。它们在运行时推迟张量操作的执行,累积程序轨迹,当需要具体值时才应用转换并执行。这会带来额外开销,但LazyTensor通过缓存系统来缓解。
  • JIT特化的改进与挑战。JAX的jit组合子 (Frostig et al.【Compiling machine learning programs via high-level tracing, 2018】) 通过要求输入是纯函数来改进JIT特化,这强制了非张量计算(如形状表达式)的引用透明性,从而可以根据输入形状判断是否需要重新捕获。JIT特化的缺点是代码执行更难推理(例如printpdb语句只在重新追踪时执行),且重新追踪和转换可能导致难以预测的性能抖动。

2.3 中间表示(IR)设计

  • 语言。许多框架使用跨语言格式定义IR,如Caffe和TensorFlow使用Protocol Buffers (Xiao et al.【Protocolbuffers/Protobuf: Protocol buffers - google’s data interchange format, 2008】)。PyTorch的JIT和MXNet使用C++数据结构。原生表示性能更好,但给Python程序员带来了学习曲线。
  • 控制流。大多数神经网络可表示为无控制流的扁平张量操作序列,即基本块程序(basic block program),通常用有向无环图(DAG)表示。MLP、CNN (He et al.【Deep Residual Learning for Image Recognition, 2015】)、推荐模型 (Naumov et al.【Deep Learning Recommendation Model for Personalization and Recommendation Systems, 2019】) 和Transformer (Vaswani et al.【Attention Is All You Need, 2017】) 都可以这样表示。RNN (Elman【Finding structure in time, 1990】; Hochreiter & Schmidhuber【Long short-term memory, 1997】; Cho et al.【Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation, 2014】) 虽然有循环结构,但通常被封装为单个张量操作,因此整个网络应用也常表现为基本块程序。尽管如此,许多框架(如TorchScript、JAX、MLIR (Lattner et al.【MLIR: A Compiler Infrastructure for the End of Moore’s Law, 2020】))仍在IR中支持控制流,但这增加了IR和转换实现的复杂性。
  • 状态。深度学习模型包含可训练权重等状态。PyTorch允许值被修改,张量可以是彼此的视图(view),例如x[i] = y会修改x。由于PyTorch支持这些别名和突变语义,程序修改必须依赖于安全性分析(如别名分析 (Andersen【Program analysis and specialization for the C programming language, 1994】))。TorchScript实现了别名分析,但代价高昂:所有操作都需标注别名和突变行为,保守的假设会阻碍优化。而JAX的函数式方法将状态管理的负担转移到框架之外,模型必须是纯函数,参数作为输入传递。这使得需要同时修改状态和代码的转换(如批量归一化折叠)变得更加复杂。

设计原则

现有框架的设计大多倾向于支持更广泛的深度学习程序,但牺牲了实现的简洁性。当捕获的程序是运行的唯一方式时,高保真度至关重要。但PyTorch主要作为即时执行框架使用,程序捕获仅用于特定转换,无需对整个程序都有效。此外,目标用户是机器学习从业者,他们更习惯使用Python而非编译器设计。

通过为典型的深度学习模型而非长尾用例进行设计,可以创建一个更易于使用和实现的框架。torch.fx的设计原则体现了这一理念:

  • 优先考虑典型模型的易用性:宁可牺牲对所有可能程序的支持,也要让典型模型的程序捕获和转换变得简单。避免为支持长尾、晦涩的用例而增加复杂性。
  • 使用熟悉的工具和概念:与机器学习从业者已经熟悉的工具和概念(如Python数据结构和PyTorch公开的算子)协同工作。
  • 高度可配置的捕获过程:使用户能够为长尾用例实现自己的解决方案。允许用户进行一次性配置比处理通用情况更简单。

A2 方法细节

TORCH.FX 概述

torch.fx采用符号追踪来捕获程序,使用一个简单的包含6个指令且基于Python的IR来表示它们,并从IR重新生成Python代码来执行。为避免JIT特化带来的重捕获复杂性,torch.fx本身不尝试特化程序,而是依赖于转换过程来决定需要执行何种特化。符号追踪过程是可配置的,用户可以定制以处理更特殊的用例。

  • 代码捕获与转换示例。图1展示了使用torch.fx捕获代码的示例。symbolic_trace函数接收一个函数或torch.nn.Module,并将其结构捕获到一个Graph对象中。该Graph对象与模块参数结合成一个GraphModule,这是一个torch.nn.Module的子类,其forward方法运行被捕获的图。打印图的节点可以看到捕获的IR:placeholder节点代表输入,output节点代表结果,call_function节点直接引用要调用的Python函数,call_method节点调用其第一个参数的方法。图2展示了一个简单的变换示例,该变换将代码中所有的relu激活函数替换为gelu
from torch.fx import Graph
def replace_activation(g: Graph, old, new):
    for n in g.nodes:
        if n.op == 'call_function' and n.target == old:
            # create IR to call new activate
            with g.inserting_after(n):
                new_n = g.call_function(new, n.args)
            n.replace_all_uses_with(new_n)
            g.erase_node(n)
            # or for this simplified case: 'n.target = new'

replace_activation(traced.graph, torch.relu, torch.nn.functional.gelu)
traced.recompile()
图1. torch.fx使用符号追踪将程序捕获到一个简单的IR中,并从该IR生成Python代码。图2. 变换,如此处替换激活函数的变换,是直接用Python编写的。

4.1 程序捕获

  • 符号追踪机制torch.fx的符号追踪机制使用一个Proxy数据结构来记录流经程序的值上的操作。Proxy是一个鸭子类型(duck-typed)的Python类,它记录对其的属性访问和方法调用,充当具体程序值的抽象替代品。Proxy利用__torch_function__协议 (Abbasi et al.【Improving subclassing Tensor by propagating subclass instances, 2020】) 来拦截并记录PyTorch算子(它们是自由函数)的派发。此外,torch.fx重写了PyTorch的Module抽象,以记录对使用代理值(proxied values)的Module的调用。整个符号追踪过程可以通过一个Tracer类进行配置,用户可以重写其方法来控制哪些值应保持为Proxy,哪些值在追踪期间应被部分求值。

4.2 中间表示

  • IR结构torch.fx在一个基于DAG的IR中表示程序,这适用于深度学习中常见的基本块程序。程序表示为一个Graph对象,其中包含一系列线性的Node对象,每个Node代表一个操作。Node具有以下属性:
    • opcode (字符串):描述节点代表的操作类型(具体语义见附录A.1)。
    • target:对于调用节点(call_modulecall_functioncall_method),这是调用的目标。
    • argskwargs:共同表示在追踪期间观察到的Python调用约定中的参数(各操作码的具体语义见附录A.2)。
  • 数据依赖与简洁性。节点之间的数据依赖通过argskwargs中对其他节点的引用来表示。为简化IR,torch.fx的IR没有用于建模数据结构构造或突变的原始操作。然而,argskwargs支持立即值(immediate values):Python内置类型(如intfloat)和递归集合类型(如tuplelist)可以作为节点参数出现,而无需单独的对象构造节点。这使得IR非常干净,节点与张量操作近似一一对应。
  • 状态管理torch.fx将程序的状态存储在GraphModule类中。GraphModule是转换后程序的容器,它暴露了转换后生成的代码,并提供了nn.Module中熟悉的参数管理API。GraphModule可以像普通的nn.Module一样在任何地方使用,确保了转换后的代码与PyTorch生态系统的其他部分具有互操作性。torch.fx的IR提供了两个操作码来访问模块层次结构中的状态:call_module(调用子模块的forward方法)和get_attr(从模块中获取参数)。这在可变参数和与之交互的函数式Graph之间提供了自然的分离,同时将它们保留在单个对象中以便于同时对两者进行转换。

4.3 源码到源码的转换

  • 代码生成torch.fx转换流程的最后阶段是代码生成。torch.fx不退出Python生态系统进入一个定制的运行时,而是从转换后的IR生成有效的Python源代码。这些转换后的代码随后被加载到Python中,生成一个可调用的Python对象,并被安装为GraphModule实例的forward方法。使用代码生成允许torch.fx转换的结果被安装在模型中,并用于进一步的转换。如图3所示,可以将一个程序的追踪结果安装为一个新模块的激活函数,然后对结果进行符号追踪以进行进一步的转换。

    图3. torch.fx生成Python代码作为其输出,因此可以被重用于进一步的捕捕获和转换步骤。

设计决策

torch.fx融合并扩展了先前工作中的方法,提供了一个易于使用、实现简单且可配置的库。

5.1 符号追踪

  • 选择符号追踪的原因torch.fx使用带有Proxy对象的符号追踪,而不是嵌入式语言技术,因为前者使用Python灵活的对象模型更容易直接在Python中实现。其实现足够简单,用户在追踪行为异常时可以阅读和单步调试源代码。此外,追踪有助于消除模型中不依赖于输入的控制流,例如torch.nn.Sequential中对顺序模块的循环,这对于穿透各种抽象以获取实际运行的算子至关重要。符号追踪对常见模型效果很好,代价是无法捕获那些真正包含输入依赖控制流的长尾模型,但这一限制通过使追踪过程可定制来弥补。

5.2 可配置的程序捕获

  • 定制追踪行为torch.fx的符号追踪是可定制的。Tracer类控制着fx.symbolic_trace的行为,其方法可以被重写以改变追踪过程。
    • is_leaf_module方法可以被重写,以指定哪些PyTorch Module实例在追踪期间应被视为不透明调用。默认情况下,torch.fx会保留PyTorch内置模块(如nn.Conv2d)的完整性,同时追踪用户自定义模块,以创建由标准、可理解的原语组成的轨迹。定制此行为可以屏蔽模型中包含不支持语言特性的部分,或修改用于转换的表示级别。
    • create_proxy方法可以被重写,以自定义在图中创建节点及关联的运行时Proxy值的行为。例如,这可以用于在节点上安装自定义元数据以用于转换,或支持将自定义数据结构作为可追踪值。

5.3 预先(AoT)捕获而不进行特化

  • 预先捕获的优势。虽然预先(ahead-of-time)追踪限制了可以捕获的程序空间(例如,不支持任意控制流),但它提供了一个更可预测、更可观察的捕获、转换和代码生成过程,这与PyTorch的开发者体验相符,并在实践中效果良好。与基于示例的追踪不同,符号追踪不会偶然地特化程序流程,因为在追踪时不存在做出数据依赖控制流决策所需的信息。控制决策中常用的张量属性(如shapendim)在符号追踪期间作为Proxy值返回,对这些值的操作可以被记录下来。当这些Proxy对象被用于不可追踪的操作(如转换为Python内置类型intbool)时,用户会收到一个错误消息和指向问题位置的堆栈跟踪。

5.4 基于Python的IR和变换

  • Python生态系统集成torch.fx的IR完全在Python中表示和实现,而不是使用如Protocol Buffers之类的跨语言格式。用户可以轻松地调用、阅读或重写它,无需理解C++或Protocol Buffers。变换也用Python编写。此外,变换的结果也是Python代码,这使得它易于检查正确性、用pdb调试、提供给其他库,并传递给进一步的变换。转换后的代码被封装在一个GraphModule中,可以像其他nn.Module一样在PyTorch中使用,例如,用户可以将其用TorchScript编译以进行部署,或在PyTorch的DistributedDataParallel库中使用。这种方式将torch.fx进一步整合到Python生态系统中,而不是将转换后的代码隔离到一个定制的、更难使用的运行时中。

5.5 IR内部无控制流

  • 简化分析。随着Transformer (Vaswani et al.【Attention Is All You Need, 2017】) 逐渐取代序列递归神经网络,深度学习中宿主语言控制流的使用变得越来越少。许多模型可以在没有控制流的情况下表达。然而,无论特定模型是否使用它,IR中存在控制流都会显著增加复杂性。大多数对IR的分析必须表示为程序上的不动点数据流分析 (Kildall【Global expression optimization during compilation, 1972】),而不是简单的正向传播。这要求作者定义分析属性的格(lattice)、转移函数和连接函数,并证明其单调性和有限性,这对于不熟悉编译器的ML变换编写者来说容易出错。
  • 不动点分析的复杂性示例。以形状传播为例:形状可以轻易地通过基本块程序向前传播。但是,当加入控制流时,形状传播不满足有限性属性——跨循环迭代传递的值可能会呈现无限多种形状,如图4所示。在这种情况下,分析通常会得到一个“动态”值,从而阻碍需要具体形状信息的后续转换。此外,一些ML社区提出的转换(如第6.2.1节中描述的量化转换)在存在控制流的情况下没有明确定义。
def loop_shapes(x, itr):
  # x is an input tensor of size [1, N]
  for _ in range(itr):
    x = torch.cat((x, x), dim=0)

  # Depending on the number of loop iterations, x may have an
  # arbitrary leading dimension i.e. x \in [*dynamic*, N]
  return x

IR本身不包含控制流,并不妨碍变换在更大模型中的基本块子图上工作;如何组合这些子图的细节留给变换的编写者或用户来决定。

5.6 函数式图与有状态模块

  • 避免可变性分析。语言中的别名和可变性语义可能需要复杂的分析来证明程序转换的合法性。torch.fx省略了此类分析,而是将可变操作定义为未定义行为,并可在追踪期间捕获时选择性地引发错误。在IR中避免可变性极大地简化了深度学习程序的分析和转换。大多数模型不受此限制,因为大多数可变性都局限于模型的参数。
  • 保留模块层次结构torch.fx仍然保留了PyTorch的分层nn.Module结构,并可以表示从此结构中的模块调用和属性获取。像torch.nn.Conv2d这样的模块对用户来说是易于理解的,有详细记录的参数,并且将参数的有状态使用隐藏在模块内部,因此保留这些对象使编写转换更容易。

A4 实验环境

  • 模型架构
    • IR复杂度对比:ResNet50
    • 量化:DeepRecommender模型 (Kuchaiev & Ginsburg【Training deep autoencoders for collaborative filtering, 2017】)
    • 算子融合:PyTorch ResNet50模型
    • 设备导出:PyTorch ResNet50模型 和 LearningToPaint模型 (Huang et al.【Learning to paint with model-based deep reinforcement learning, 2019】)
  • 硬件配置
    • CPU:服务器级 Intel Xeon Gold 6138 @ 2.00GHz (用于量化和融合实验)
    • GPU:NVIDIA Tesla V100-SXM2 16GB (用于融合和TensorRT导出实验)
  • 软件配置
    • 核心库:PyTorch, torch.fx
    • 依赖库
      • 量化实验使用 FBGEMM (Khudia et al.【FBGEMM: Enabling HighPerformance Low-Precision Deep Learning Inference, 2021】) 作为量化算子后端。
      • 融合和TensorRT实验使用 CUDA 11.0。
      • 设备导出实验使用 NVIDIA TensorRT (NVIDIA【Abstract】)。
  • 数据集
    • 论文未明确指定用于评估的数据集。在量化实验的校准阶段,描述为“用户将多批数据送入网络”,但未命名具体数据集。

A5 实验结果

6.1 IR 复杂度

  • 实验内容:比较了torch.fxtorch.jit.scripttorch.jit.trace为 canonical ResNet50 模型生成的IR的复杂性(以操作数量衡量)。
  • 实验结果torch.fx IR包含445个操作,而torch.jit.trace为860个,torch.jit.script为2614个。
  • 分析结论torch.fx通过追踪并展开与输入无关的控制流,并将简单的常量和数据结构内联为节点参数,显著简化了典型模型的IR,使其比torch.jit.trace小近一半,比torch.jit.script小一个数量级。简化的IR降低了编写和维护程序转换的复杂性。

    图5. torch.fx会追踪不变化的控制流,并可将常量作为参数嵌入其节点中。这大大简化了典型模型的IR。对于一个标准的ResNet50模型,torch.fx IR包含445个操作,而torch.jit.script为2614个,torch.jit.trace为860个。

6.2 性能优化

  • 6.2.1 量化 (Quantization)
    • 实验内容:对DeepRecommender模型应用训练后量化(Post-Training Quantization),并在CPU上评估其推理性能。
    • 实验结果:与浮点模型相比,基于torch.fx的量化带来了高达3.3倍的运行时性能提升,且性能方差很小,显示了预先转换带来的可预测性。
    • 分析结论torch.fx不仅提供了预期的性能提升,其开发效率也比基于TorchScript的实现高出一个数量级。这得益于其简化的表示、Python API以及与原生PyTorch生态的融合。

      图6. 基于torch.fx的量化的归一化推理运行时间(越低越好)。
  • 6.2.2 融合优化 (Fusion Optimizations)
    • 实验内容:对PyTorch ResNet50模型应用卷积-批量归一化(Convolution-BatchNorm)融合,并在GPU和CPU(多线程/单线程)上评估推理延迟。
    • 实验结果:在GPU上延迟降低约6%,在默认并行度的CPU上降低40%,在禁用并行度的CPU上降低18%
    • 分析结论torch.fx的预先、基于图的特性为这种非局部程序转换提供了必要的上下文和状态修改能力。整个转换和测试工具仅用不到150行Python代码实现,展示了其API在实现简洁、快速开发的程序转换方面的强大能力。

      图7. 基于torch.fx的卷积/批量归一化融合的归一化推理运行时间(越低越好)。
  • 6.2.3 程序调度 (Program Scheduling)
    • 实验内容:对于包含从参数服务器获取值的阻塞式远程过程调用的模型,使用torch.fx将阻塞调用替换为非阻塞调用和单独的等待调用,并将非阻塞调用尽可能提前。
    • 实验结果:在大型分布式训练任务中,此优化可将QPS(每秒查询次数)提高多达9%
    • 分析结论torch.fx可用于实现复杂的程序调度优化,以重叠网络调用和本地计算。

6.3 程序分析

torch.fx已被应用于多种程序分析场景:

  • 大规模推理模拟:用于实现一个深度学习推理模拟框架,通过估算FLOPs、内存带宽使用和数据大小来预测程序在各种硬件设备上的运行时间和内存消耗。
  • 形状分析fx.passes.shape_prop包提供了一个通过解释图来记录形状的朴素实现。其他更高级的形状传播系统(如通过符号表达式或渐进类型语义)也正在开发中。
  • 可视化fx.graph_drawer包使用户能够使用Graphviz可视化torch.fx图,提供了一种直观理解深度学习程序DAG结构的方式。

6.4 设备和运行时导出/编译

  • 实验内容:使用一个实验性的torch.fx-to-TensorRT转换系统,将PyTorch ResNet50模型和LearningToPaint模型下沉(lower)到NVIDIA TensorRT,并在V100 GPU上进行评估。
  • 实验结果:与基线PyTorch相比,ResNet50获得了3.7倍的运行时加速,LearningToPaint获得了1.54倍的加速。
  • 分析结论torch.fx为编译器栈(如TensorRT)与PyTorch的集成提供了一条高效路径。该项目的开发效率很高,利用torch.fx的Python API可以快速构建转换层、模型自动分割等功能,最终为用户提供了易于使用、检查和调试的API。

    图8. 基于torch.fx的TensorRT下沉的归一化推理运行时间(越低越好)。

A6 结论

本文介绍了torch.fx,一个纯Python的系统,用于捕获和转换PyTorch程序。通过分析相关系统(如控制流、可变性、数据模型)的复杂性来源,本文展示了torch.fx如何通过专注于常见用例和提供可定制性来避免这些复杂性。通过对优化、分析和设备下沉等多个用例的研究,本文证明了torch.fx的API设计如何成功地实现了这些功能。

A7 附录

A. TORCH.FX 节点语义

A.1 操作码(Opcode)含义

下表描述了torch.fx中每个Nodeopcode的含义。

A.2 args/kwargs 行为

下表描述了不同opcodeargskwargs字段的预期行为。

B 量化评估数值数据

下表为第6.2.1节量化实验的详细运行时间数据(单位:秒)。

C 融合评估数值数据

下表为第6.2.2节融合实验的详细运行时间数据(单位:秒)。

D TensorRT评估数值数据

下表为第6.4节TensorRT实验的详细运行时间数据(单位:秒)。