RLAX: Large-Scale, Distributed Reinforcement Learning for Large Language Models on TPUs

作者/机构: Runlong Zhouw, Lefan Zhang, Shang-Chen Wu, Kelvin Zou, Hanzhi Zhou, Ke Ye, Yihao Feng, Dong Yin, Alex Guillen Garcia, Dmytro Babych, Rohit Chatterjee, Matthew Hopkins, Xiang Kong, Chang Lan, Lezhi Li, Yiping Ma, Daniele Molinari, Senyu Tong, Yanchao Sun, Thomas Voice, Jianyu Wang, Chong Wang, Simon Wang, Floris Weers, Yechen Xud, Guolin Yin, Muyang Yu, Yi Zhang, Zheng Zhou, Danyang Zhuod, Ruoming Pang, Cheng Leong; Apple, University of Washington, Duke University.

主要贡献

核心问题与研究目标: 强化学习(RL)已成为提升大型语言模型(LLM)推理能力和与人类偏好对齐的关键后训练步骤。然而,在大规模分布式集群上高效执行RL算法,同时支持灵活的训练范式(on-policy/off-policy)、随时抢占以及解决训练和推理间的数值不一致性问题,是当前面临的重大挑战。

解决方案 RLAX: 本文介绍了RLAX,一个专为在大型TPU集群上运行而设计的可扩展、分布式强化学习框架。RLAX旨在解决大规模LLM后训练中的四大关键挑战:
1. 灵活支持多种RL算法:构建一个能灵活支持各种现有和新兴RL算法的训练框架。
2. 高效处理不同训练范式:同时支持on-policy和off-policy RL,并允许用户在其间灵活选择。
3. 大规模扩展与抢占支持:将RL训练扩展到大型分布式集群,并实现无缝的、随时可进行的抢占和恢复。
4. 维持训练器与推理工作者间的数值对齐:解决因实现差异导致的数值偏差,确保RL训练的稳定性。

创新点与核心设计:
* 解耦式架构: RLAX采用参数服务器(parameter-server)架构,将训练器(trainer)、推理工作者(inference worker)和验证器(verifier)逻辑上分离。这种解耦设计使得可以独立且灵活地为各组件分配计算资源。
* 统一的训练范式接口: 通过引入两个可配置的陈旧度参数(inference reload stalenesstrainer acceptance staleness),RLAX为on-policy和off-policy RL提供了统一的接口,用户可以灵活调整陈旧度以在训练稳定性和吞吐量之间取得平衡。
* 强大的抢占与恢复机制: RLAX的设计哲学是“所有组件被动,仅控制器主动”,且控制器所有操作都是幂等的。系统状态通过模型检查点(checkpoint)自然捕获,当发生抢占时,控制器可以从检查点重建状态并恢复,无需保存和恢复庞大的内存状态(如KV缓存)。
* 数值对齐策略: 为解决训练和推理间的数值偏差问题(尤其是在低精度浮点数下),RLAX通过重新计算对数概率(log-probability)并使用与推理匹配的训练图来强制对齐计算过程,显著减少了数值漂移,保证了off-policy RL的稳定性。
* 模块化算法设计: RLAX将RL算法的目标函数分解为聚合权重、重要性采样权重、优势估计器和梯度项等模块化组件,使得实现、调试和迭代新算法变得简单高效。
* 可扩展的验证服务: 针对计算密集型的代码验证任务,设计了OUBLIETTE,一个基于AWS Lambda的远程代码执行服务,实现了高效、可扩展且安全的验证流程。

核心成果: 在对QwQ-32B模型的评估中,使用1024个v5p TPU,RLAX仅用12小时48分钟就将其在Codeforces数据集上的pass@8准确率提升了12.8%。同时,实验证明了RLAX在训练过程中能有效抵御抢占事件,保持了强大的鲁棒性。

背景知识

LLM后训练中的强化学习。强化学习(RL)已成为LLM后训练的关键环节。特别是,通过人类反馈进行强化学习(RLHF)【【索引28,Training language models to follow instructions with human feedback,2022,NeurIPS】,【索引31,Proximal policy optimization algorithms,2017】】可以更好地使模型与人类偏好对齐,例如遵循指令和生成安全内容。评分函数也可以是数学函数或程序编译器/解释器,这使得RL能够提升LLM的数学和编程能力【【索引14,Deepseek-r1: Incentivizing reasoning capability in llms via reinforcement learning,2025】】。如今,所有AI编程基准测试的领先者都是通过RL训练的LLM【【索引19,Livecodebench: Holistic and contamination free evaluation of large language models for code,2024】,【索引23,Spider 2.0: Evaluating language models on real-world enterprise text-to-sql workflows,2024】,【索引10,Chatbot arena: An open platform for evaluating llms by human preference,2024】】。一个典型的RL训练迭代过程如下:LLM根据一组用户提示(prompt)通过LLM推理工作者采样生成token序列,例如“编写一个Python程序,输出从1到100所有整数的和”。输出的token序列是一个Python程序。这些生成的Python程序随后通过一个验证器(即Python解释器)进行验证,检查结果是否正确(在本例中,正确结果是5050)。然后,训练器更新模型权重,以增加LLM生成正确程序的概率。

问题 #1: 支持各种RL算法。存在大量用于LLM的RL算法,因此,一个RL训练系统若能全面支持这些多样化且快速演进的算法,同时保持易于维护和调试,将是一个非常理想的特性。为了介绍这些算法,我们用 $q$ 表示问题(用户提示),用 ${o_i}_{i=1}^G$ 表示为 $q$ 生成的一组 $G$ 个输出,用 ${\hat{A}_i}$ 可以是先前的检查点或者是当前策略的一个分离版本。为了补偿可能出现的分布偏移,通常会使用重要性采样权重来使损失函数变为“on-policy”:$\{\hat{A}_i\}_{i=1}^G$ 表示给定 $q$ 时每个输出的优势估计。这里 $o_i$ 是使用 $\pi_{\theta_{\text{old}}}$ 生成的,$\pi_{\theta_{\text{old}}}$ 可以是先前的检查点或者是当前策略的一个分离版本。为了补偿可能出现的分布偏移,通常会使用重要性采样权重来使损失函数变为“on-policy”:

$$r_{i, t}(\theta) = \frac{\pi_{\theta}\left(o_{i, t} | q, o_{i,\lt t}\right)}{\pi_{\theta_{\mathrm{old}}}\left(o_{i, t} | q, o_{i,\lt t}\right)}.$$


我们还用 $r_i(\theta) = \prod_{t=1}^{|o_i|} r_{i,t}(\theta)$ 表示整个生成过程的重要性权重。一类算法利用 $\log \pi_{\theta}$ 作为梯度来源,这得益于对数导数技巧。这包括REINFORCE【【索引39,Simple statistical gradient-following algorithms for connectionist reinforcement learning,1992,Machine learning】,【索引36,Policy Gradient Methods for Reinforcement Learning with Function Approximation,1999,NeurIPS】】,它对以下代理目标函数执行梯度上升:

$$\begin{aligned} \mathcal{J}_{\text{REINFORCE}}(\theta) = \mathop{\mathbb{E}}_{\substack{q \sim D \\ \{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(\cdot|q)}} \left[ \frac{1}{G} \sum_{i=1}^G \text{sg}[r_i(\theta)] \hat{A}_i \log \pi_\theta(o_i|q) \right] \end{aligned}$$
这里的 $\text{sg}(\cdot)$ 是停止梯度算子。这类算法的其他著名实例包括CISPO【【索引9,Minimax-m1: Scaling test-time compute efficiently with lightning attention,2025,arXiv】】和LlamaRL【【索引40,Llamarl: A distributed asynchronous reinforcement learning framework for efficient large-scale llm training,2025】】中的单边裁剪目标函数。近端策略优化(PPO)【【索引31,Proximal policy optimization algorithms,2017】】使用重要性比率 $r_{i,t}(\theta)$ 作为梯度来源。在前面一类算法中,尽管 $r_{i,t}(\theta)$ 可能出现在目标函数中,但它的梯度被停止了。这里我们以PPO的裁剪代理目标函数为例:
$$\begin{aligned} \mathcal{J}_{\mathrm{PPO}}(\theta) = \underset{\substack{q \sim \mathcal{D} \\ \{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(\cdot|q)}}{\mathbb{E}} \left[ \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \min \left\{ r_{i,t}(\theta) \widehat{A}_i, \operatorname{clip}(r_{i,t}(\theta), 1-\epsilon, 1+\epsilon) \widehat{A}_i \right\} \right]. \end{aligned}$$
$\hat{A}_{pi}$ 是使用一个学习到的价值网络计算的。PPO中的裁剪机制约束了重要性比率,防止了可能破坏训练稳定性的过大策略更新。当优势 $\hat{A}$ 为正时,裁剪会阻止新策略在增加该动作概率的方向上偏离旧策略太远($\gt 1 + \epsilon$)。反之,当优势为负时,它会限制概率可以被降低的程度。这种保守的更新策略在探索和稳定、增量的策略改进之间取得了平衡。该类别的其他变体包括GRPO【【索引32,Deepseekmath: Pushing the limits of mathematical reasoning in open language models,2024,arXiv】】、DAPO【【索引46,Dapo: An open-source llm reinforcement learning system at scale,2025,arXiv】】、GSPO【【索引47,Group sequence policy optimization,2025】】和GFPO【【索引34,Sample more to think less: Group filtered policy optimization for concise reasoning,2025,arXiv】】。

问题 #2: 支持不同的训练范式。RL有多种分布式训练范式。我们希望能够支持所有这些范式,以便模型工程师能够探索不同的选择。最直接的训练范式是on-policy RL,其中rollout生成和训练器不重叠。推理工作者总是使用最新的模型权重,并且控制器在每个训练迭代中必须等待所有推理工作者完成它们的rollout。On-policy RL为我们的模型工程师在训练过程中提供了精确的控制和可预测性。然而,这也意味着当训练器在训练神经网络时,推理工作者是空闲的,反之亦然。为了使训练基础设施实现高利用率,off-policy RL允许训练器接受来自陈旧模型权重的训练数据(即来自推理工作者的生成rollout)。这意味着训练器和推理工作者可以并发运行。然而,陈旧度可能会对训练收敛产生不利影响。一些RL系统(例如StreamRL【【索引48,Streamrl: Scalable, heterogeneous, and elastic rl for llms with disaggregated stream generation,2025】】)专门针对给定陈旧度约束下的训练吞吐量进行了优化。

问题 #3: 扩展和支持抢占。我们利用大规模TPU基础设施来满足RL日益增长的计算需求,这些需求由模型规模增大、上下文长度扩展以及思维链(CoT)的整合所驱动。我们希望我们的RL系统是可扩展的,可以轻松配置训练集群的大小,并独立地为推理工作者和训练器分配资源。此外,由于我们在大规模共享基础设施上运行,当更高优先级的任务(如推理)需要回收资源时,我们必须支持抢占。然而,RL系统有多个组件,以黑盒方式完全地检查点和恢复所有组件的状态是低效的。例如,没有必要检查点推理工作者中的kv缓存状态。

问题 #4: 推理系统和训练器之间的数值一致性。RL中的一个关键技术是重要性采样(IS),它使用一个比率 $r(\theta)$ 来纠正策略之间的差异。这个比率定义为:(参见公式(1))

$$r(\theta) = \frac{\pi_\theta(a|s)}{\pi_{\theta_{\text{old}}}(a|s)}.$$


这里,$\pi_{\theta}$ 代表当前正在优化的策略,而 $\pi_{\theta_{\text{old}}}$ 是用于生成训练轨迹的行为策略。在off-policy设置中,根据定义,这些策略是不同的。IS比率 $r(\theta)$ 对于重新加权观察到的奖励和梯度、纠正数据与策略不匹配至关重要。在严格的on-policy设置中,行为策略和当前策略是相同的($\pi_{\theta} = \pi_{\theta_{\text{old}}}$)。因此,理论上重要性采样比率应始终为 $r(\theta) = 1.0$。然而,我们在实践中观察到,即使在on-policy实现中,$r(\theta) \neq 1.0$。这种现象在其他工作中也有报道,例如【【索引17,Defeating nondeterminism in llm inference,2025,Thinking Machines Lab: Connectionism】】和【【索引29,Defeating the training-inference mismatch via fp16,2025】】。这种差异的根本原因在于浮点数算术的非结合性,尤其是在像bfloat16这样的低精度格式中。浮点数加法不具有结合性,意味着由于中间的舍入误差,$(a + b) + c$ 可能不等于 $a + (b + c)$。当计算以不同顺序执行时,最终累加的值可能会出现分歧。这种数值不稳定性可能来自两个来源:计算非确定性或实现差异。计算非确定性是GPU等加速器上的一个常见问题。在像atomic-add这样的并发操作中,累加顺序不被保证,这可能导致即使输入完全相同,结果也不同。在GPU上的解决方法是通过强制顺序累加来避免在归约维度上的并发,从而使结果具有确定性【【索引17,Defeating nondeterminism in llm inference,2025,Thinking Machines Lab: Connectionism】】。相比之下,TPU的脉动阵列架构通过其设计在本质上解决了这个问题,因为它为数值块提供了固定的、确定性的累加顺序。由于计算非确定性在TPU上不是问题,RLAX剩下的问题是训练器和推理服务之间的实现差异。这种差异由两个主要不同点驱动:(1)训练器和推理服务中采用的不同并行策略,以及(2)编译器驱动的核融合(kernel fusion)的变化,这两者都改变了操作的最终顺序。

系统设计

整体架构。RLAX是一个可扩展、可抢占的RL框架,用于在TPU上进行LLM后训练。图1展示了其整体系统架构。RLAX采用基于参数服务器的设计,并配有一个单一的中心化控制器。它包含五个关键组件:一个推理服务(由一组推理工作者组成)、一个控制器、一个训练器、一个参数服务器和一个验证服务(包括编码和数学验证器)。控制器根据所选的RL算法来协调训练器、推理工作者和验证器的运行时间。
图1:RLAX的系统图。蓝色部分代表RLAX的核心软件组件。黄色部分代表我们使用的第三方组件。RLAX使用AXLearn作为训练器和推理工作者。

参数服务器。RLAX参数服务器构建于TensorStore【【索引16,TensorStore,2024,https://google.github.io/tensorstore/】】之上,它扩展了TensorStore的原生分片方案,以支持内存持久化和自定义版本管理。我们将参数服务器的分区与模型的原生权重分片对齐。例如,在一个数据并行度为4、专家并行度为8、模型并行度为8(跨8个TPU核心的主机)的配置中,RLAX会部署四个参数服务器副本,每个副本跨越八台主机,以匹配模型并行度的需求。推理工作者被均匀地映射到这些副本上。为了支持有界陈旧度 的off-policy RL,每个参数服务器在主机内存中最多维护N个权重快照。只有在所有副本上确认全局进度已达到步骤t后,才会对步骤t之前的检查点进行垃圾回收。此外,我们扩展了TensorStore客户端,以支持基于gRPC的通信,并丰富了额外的元数据。通过为权重传输启用RDMA,可以进一步加快加载时间。

训练器与推理服务。RLAX使用AXLearn【【索引22,AXLearn: Modular Large Model Training on Heterogeneous Infrastructure,2025】】作为训练器。在每个RL迭代中,训练器使用来自控制器的训练数据更新参数服务器中的模型权重。RL算法是通过扩展AXLearn模块来实现的,我们将其统称为AXLearn中的RL插件。推理服务由多个推理工作者组成,每个工作者包含一组能够独立运行模型推理的TPU。每个推理工作者都是一个使用其推理模式的AXLearn实例。推理工作者从参数服务器拉取最新的模型权重,并根据控制器提供的提示(prompts)生成rollout。

验证服务与检查点。控制器将这些rollout转发给验证服务。验证服务包括一个代码执行环境池,为训练数据分布中出现的每种编程语言(如Python, C++, Rust, Go)提供一个环境。这种设计支持包含多种编程语言代码编写任务的训练批次。Rollout和模型权重都会被定期检查点到持久存储。这些检查点确保了在发生抢占或系统故障时的容错性和可复现性。在本节的其余部分,我们将描述RLAX的系统设计如何支持不同的训练范式和处理抢占。

对不同训练范式的支持

通过陈旧度参数实现灵活性。RLAX旨在灵活支持广泛的训练范式。用户可以通过配置两个陈旧度参数来选择on-policy和off-policy RL:推理重载陈旧度(inference reload staleness)和训练器接受陈旧度(trainer acceptance staleness)。推理重载陈旧度 $j$ 决定了推理工作者从参数服务器获取更新模型权重的频率。换句话说,$j$ 意味着工作者只在每 $j$ 个模型版本时重新加载一次。训练器接受陈旧度 $k$ 则限制了训练器消费的rollout可以有多旧:每个rollout都用其生成时所用的模型版本进行标记,只有当rollout的模型版本与训练器当前模型版本的差距最多为 $k$ 步时,训练器才被允许使用它。

统一多种RL训练模式。这个接口为几种常用的RL训练模式提供了一个统一的表述:
* On-policy RL ($j=k=1$):训练器只消费由最新模型版本产生的rollout,确保了推理、验证和训练之间的严格同步(lockstep)调度。这对应于传统的同步PPO风格的流水线。
* 单步Off-policy RL ($j=1, k=2$):训练器可以使用落后一步的模型生成的rollout,允许rollout生成和模型更新之间有有限的重叠。这种模式与StreamRL【【索引48,Streamrl: Scalable, heterogeneous, and elastic rl for llms with disaggregated stream generation,2025】】中使用的单步off-policy设计一致。
* Off-policy RL ($j \ge 1, k = \infty$):训练器不施加陈旧度约束,可以从任意旧的rollout进行训练。这最大化了吞吐量,但牺牲了训练稳定性,可能需要额外的保障措施,如off-policy校正。

实际应用与权衡。这种统一的抽象使得RLAX能够以最少的用户代码或基础设施修改来支持一系列RL算法和系统层面的权衡。训练器总是倾向于使用更新的rollout,并丢弃未使用的。我们凭经验发现,有界陈旧度的RL,当 $j$ 和 $k$ 较小时(例如,$(j, k) = (16, 16)$ 或 $(16, 32)$),通常能提供一个有利的权衡:它显著减少了推理工作者的空闲时间,同时避免了生成过旧的、会被训练器后续丢弃的rollout。在我们的实验中,除非另有说明,我们通常设置 $(j, k) = (16, 32)$,因为这个值在提供强大经验性能的同时,也使得大量的硬件并行成为可能。

对抢占的支持

设计原则。抢占是在共享或弹性集群上运行大规模RL工作负载时的一个基本考虑因素。RLAX被设计为能够容忍抢占,而不会损失正确性或可复现性。我们的方法以控制器驱动、基于检查点的恢复机制为中心,遵循三个核心设计原则:
* 原则 #1: 以控制器为中心,组件被动。控制器之外的所有组件(即推理工作者、参数服务器副本、验证器、训练器)都作为被动执行者运行。它们仅响应来自控制器的明确指令来执行工作。这种架构确保了系统进度完全通过控制器进行序列化,消除了会使恢复复杂化的隐藏状态转换。因为没有组件会自主推进其自身状态,重启的控制器可以可靠地重建全局系统状态。
* 原则 #2: 持久化、自包含的快照。在每个权重重载边界(通常是每个训练步骤一次),控制器将其操作状态的完整快照写入持久存储,如Google Cloud Storage (GCS) 或 Amazon S3。这些快照包括模型检查点步骤、提示混合(prompt mixture)的进度、随机数生成器(RNG)流的状态,以及确定性地重新生成rollout所需的任何其他元数据。重要的是,这些快照是恢复的唯一真实来源,确保系统总能从一个明确定义的、一致的点恢复。
* 原则 #3: 控制器执行的幂等性。控制器的设计使其在任意两个快照之间的行为是确定性的。从一个保存的快照开始,它将重新加载相同的模型权重,发出相同的提示序列,并驱动系统执行相同的操作。这种幂等性保证了快照之间的失败不会在训练轨迹中引入偏差。

恢复流程。在正常操作期间,训练器和推理工作者从内存中的参数服务器读取模型权重以获得最大吞吐量。同时,RLAX持续地将模型权重检查点到持久对象存储中。这种双存储策略确保了即使训练器进程、参数服务器或其主机完全丢失,训练也总能从持久存储中恢复。所有用于提示混合采样的随机性都由可复现的RNG流控制,其状态包含在每个控制器快照中。控制器还记录了每个训练步骤消耗的提示数量。在恢复时,RLAX会恢复RNG状态和消耗计数器,从而实现比特级别的相同提示选择,并确保恢复后生成的rollout遵循与无抢占运行完全相同的轨迹。当检测到非正常关机或抢占时,控制器会启动一个确定性的恢复程序:(1)控制器从持久存储加载最新的状态快照,包括模型版本、提示混合进度和RNG状态。(2)控制器恢复其内部簿记结构,并重新连接到参数服务器,或在需要时从持久存储重新加载权重。(3)控制器精确地从记录的检查点和提示位置恢复rollout生成和训练,确保训练轨迹和日志元数据的连续性。

对RL算法的支持

模块化设计。RLAX采用模块化设计来计算给定一批rollout结果的损失函数。这种设计为算法的消融实验提供了极大的灵活性。在每次迭代中,我们首先从数据加载器中采样B个问题。对于每个问题q,我们使用$\pi_{\theta_{\text{old}}}$(在on-policy情况下,$\pi_{\theta_{\text{old}}}$与当前策略$\pi_{\theta}$分离)采样一组G个rollout $\{o_i\}_{i=1}^G$。最后,我们接收到奖励${R_i}^G$。在一个统一的视角下,现代RL算法的目标可以设计为:

$$\begin{aligned} \mathcal{J}_{\text{unify}}(\theta) = \mathop{\mathbb{E}}_{\substack{q \sim \mathcal{D} \\ \{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(\cdot|q)}} \left[ \sum_{i=1}^G \sum_{t=1}^{|o_i|} \text{sg}[\text{Agg}_{i,t}^{\mathcal{X}} \cdot \text{IS}_{i,t}^{\mathcal{X}}] \cdot \left( \text{sg}[\text{Adv}_{i,t}^{\mathcal{X}}] \cdot \text{GradTerm1}_{i,t}^{\mathcal{X}} + \text{GradTerm2}_{i,t}^{\mathcal{X}} \right) \right], \end{aligned}$$


其中$\text{sg}(\cdot)$是停止梯度算子,而$\mathcal{X} = (\pi_{\theta}, \pi_{\theta_{\text{old}}}, q, {o_i, R_i}_{i=1}^G)$是当前批次的完整信息包。我们如下解释每个组件:

  • 聚合权重 (Aggregation Weight)。$Agg$是每个token的聚合权重,以使目标函数有合理的界限。常见的选择是单个轨迹平均($\frac{1}{G|o_i|}$)、组级别平均($\frac{1}{\sum_{i'=1}^G |o_{i'}|}$)和最大长度平均($\frac{1}{GL_{\text{max}}}$,其中$L_{\text{max}}$是最大生成长度)。

  • 重要性采样权重 (Importance Sampling Weight)。$IS$用于解释当前策略$\pi_{\theta}$和采样策略$\pi_{\theta_{\text{old}}}$之间的分布偏移。从损失函数的角度看,当$\pi_{\theta_{\text{old}}}$在异步训练中变得陈旧时,该项将一个off-policy损失校正为一个on-policy损失。从数值角度看,由于推理引擎中的近似计算【【索引44,Your efficient rl framework secretly brings you off-policy rl training,2025】】,真实的$\pi_{\theta_{\text{old}}}$与实际采样分布之间存在不匹配。即使在完全同步的训练中($\pi_{\theta_{\text{old}}} = \text{sg}(\pi_{\theta})$),这个问题也会出现。更多细节请参考第5节。为了概念和工程上的严谨性,我们要求$IS$不贡献梯度。虽然类PPO算法使用重要性权重作为梯度来源,但我们将这些梯度归因于GradTerm1

  • 优势估计器 (Advantage Estimator)。$Adv$区分了“好”方向和“坏”方向的响应。在RLAX的设计中,可以轻松地插入不同的优势估计器(基本的蒙特卡洛估计器、GAE【【索引30,High-dimensional continuous control using generalized advantage estimation,2015,arXiv】】等)。

  • 梯度项 (Gradient Terms)GradTerm1是主要的梯度来源,因为它由$Adv$引导。GradTerm2用于容纳正则化。通过显式地用0屏蔽掉梯度项,我们可以恢复类PPO算法中的裁剪技巧,这也为目标函数的设计提供了更清晰的视角。

模块化设计的益处

工程方面。我们发现,用于LLM的现代RL(策略优化)算法可以很容易地通过这种模块化设计来实例化。它支持每个算法组件的插件式实现,使得算法调整像修改几行代码一样简单。对任何组件的任何错误修复或功能更改(例如,提高重要性权重的数值稳定性的技巧)都会自动在所有训练配置中生效,而无需重新检查其出现位置。另一方面,我们注意到,虽然优势估计器的模块化设计在现代RL库(verl【【索引33,Hybridflow: A flexible and efficient rlhf framework,2024,arXiv】】、AReaL【【索引15,Areal: A large-scale asynchronous reinforcement learning system for language reasoning,2025】】、slime【【索引49,slime: An llm post-training framework for rl scaling,2025,GitHub repository】】、TorchRL【【索引6,Torchrl: A data-driven decision-making library for pytorch, 2023】】、SkyRL【【索引8,Skyrl-v0: Train real-world long-horizon agents via reinforcement learning, 2025】】)中很普遍,但RL目标中的其他实现细节要么是冗余的,要么由于继承层次结构而可读性差:在这些库中,贡献者要么为特定的RL目标完全重写一个类,导致冗余,要么从现有的目标类继承并重写其方法,导致算法设计和超参数设置的可读性都很低。在AXLearn【【索引22,AXLearn: Modular Large Model Training on Heterogeneous Infrastructure,2025】】中也观察到了类似的现象,即许多训练框架由于使用子类型化而在神经网络层实现中存在冗余。

认知方面。明确区分梯度项和非梯度项,为每个组件的功能提供了更清晰的视角。例如,尽管GRPO(公式(5))和CISPO(公式(7))都对重要性采样权重进行裁剪,但它们对梯度的影响截然不同。对于GRPO,被裁剪的项没有梯度,这意味着它的裁剪实际上是一个梯度掩码。对于CISPO,被裁剪的项仍然有梯度,但权重被限制在1的邻域内,这意味着它的裁剪实际上是一个正则化器。

现代RL算法的实例化

实例化示例。我们现在展示一个使用$\mathcal{J}_{\text{unify}}(\theta)$(公式(3))实例化现代RL算法的例子。为方便表示,我们使用与【【索引9,Minimax-m1: Scaling test-time compute efficiently with lightning attention,2025,arXiv】】中公式(7)相同的定义来表示由于类PPO裁剪而产生的梯度掩码:

$$\begin{aligned} M_{i,t}=\left\{\begin{array}{ll} 0, & \text { if } \hat{A}_{i, t}\gt 0 \text { and } r_{i, t}(\theta)\gt 1+\epsilon_{\text {high }}, \\ 0, & \text { if } \hat{A}_{i, t}\lt 0 \text { and } r_{i, t}(\theta)\lt 1-\epsilon_{\text {low }}, \\ 1, & \text { otherwise. }\end{array}\right. \end{aligned}$$


对于具有对称裁剪的算法,$\epsilon_{\text{high}} = \epsilon_{\text{low}} = \epsilon$

GRPO。根据【【索引32,Deepseekmath: Pushing the limits of mathematical reasoning in open language models,2024,arXiv】】中的公式(3),原始的GRPO目标是:

$$\begin{aligned} \mathcal{J}_{\mathrm{GRPO}}(\theta)=\underset{\substack{q \sim \mathcal{D} \\ \{o_i\}_{i=1}^{G} \sim \pi_{\theta_{\text{old}}}(\cdot | q)}}{\mathbb{E}}\left[\frac{1}{G} \sum_{i=1}^{G} \frac{1}{|o_i|} \sum_{t=1}^{|o_i|}\left(\min \left\{r_{i, t}(\theta) \widehat{A}_{i}, \operatorname{clip}\left(r_{i, t}(\theta), 1-\epsilon, 1+\epsilon\right) \widehat{A}_{i}\right\}\right.\right. \\ \left.\left.-\beta \widehat{\mathrm{KL}}_{t}(\pi_{\theta} \| \pi_{\mathrm{ref}})\right)\right] \end{aligned}$$


其中

$$\hat{A}_i = \frac{R_i - \text{mean}(\{R_{i'}\}_{i'=1}^G)}{\text{std}(\{R_{i'}\}_{i'=1}^G)},$$
且$\widehat{KL}_t(\pi_{\theta}||\pi_{\text{ref}})$是 $KL(\pi_{\theta}(\cdot|q, o_{i,\lt t})||\pi_{\text{ref}}(\cdot|q, o_{i,\lt t}))$ 的K3估计器(参见【【索引32,Deepseekmath: Pushing the limits of mathematical reasoning in open language models,2024,arXiv】】的公式(4)): $$\widehat{\mathrm{KL}}_t\left(\pi_{\theta} \| \pi_{\text {ref }}\right)=r_{i, t}(\theta)-\log r_{i, t}(\theta)-1.$$ 因此,我们可以用以下方式实例化公式(3):
$$\begin{aligned} \begin{aligned} & \operatorname{Agg}_{i,t}^{\mathcal{X}}=\frac{1}{G\left|o_{i}\right|}, \\ & \operatorname{IS}_{i,t}^{\mathcal{X}}=1, \\ & \operatorname{Adv}_{i,t}^{\mathcal{X}}=\frac{R_{i}-\operatorname{mean}\left(\left\{R_{i'}\right\}_{i'=1}^{G}\right)}{\operatorname{std}\left(\left\{R_{i'}\right\}_{i'=1}^{G}\right)}, \\ & \operatorname{GradTerm1}_{i,t}^{\mathcal{X}}=M_{i,t} \cdot r_{i,t}(\theta), \\ & \operatorname{GradTerm2}_{i,t}^{\mathcal{X}}=-\beta\left(r_{i,t}(\theta)-\log r_{i,t}(\theta)-1\right). \end{aligned} \end{aligned}$$
DAPO【【索引46,Dapo: An open-source llm reinforcement learning system at scale,2025,arXiv】】可以通过设置$\text{Agg}_{i,t}^{\mathcal{X}} = \frac{1}{\sum_{i'=1}^G |o_{i'}|}$,移除GradTerm2,并开启§6.2中的动态采样开关来实例化。Dr. GRPO【【索引24,Understanding r1-zero-like training: A critical perspective,2025,arXiv】】可以通过设置$\text{Agg}_{i,t}^{\mathcal{X}} = \frac{1}{GL_{\text{max}}}$是最大生成长度),禁用批归一化,并移除GradTerm2来实例化。REINFORCE++【【索引18,Reinforce++: An efficient rlhf algorithm with robustness to both prompt and reward models,2025,arXiv】】可以通过将Adv替换为【【索引18,Reinforce++: An efficient rlhf algorithm with robustness to both prompt and reward models,2025,arXiv】】中的公式(11)和(12),并将GradTerm2替换为K2估计器来实例化。我们在代码清单1中实现了【【索引32,Deepseekmath: Pushing the limits of mathematical reasoning in open language models,2024,arXiv】】中的原始版GRPO。

loss_func = RLObjective.default_config().set(
    agg=IndividualAggregation.default_config(),
    # DAPO: GroupAggregation.default_config(),
    # Dr. GRPO: MaxLengthAggregation.default_config()
    IS=None,
    adv=AdvantageEstimator.default_config().set(
        batch_mean=True,
        batch_norm=True, # Dr. GRPO: False
    ),
    grad_1=MaskedISGrad.default_config().set(
        mask_func=PPOClipMask.default_config().set(
            eps_high=0.2, # DAPO: 0.28
            eps_low=0.2,
        ),
    ),
    grad_2=K3Grad.default_config().set( # DAPO & Dr. GRPO: None
        beta=0.04,
    ),
).instantiate()

对于其他算法,请参阅附录A。

解决推理工作者和训练器之间的数值偏差问题

导致数值偏差的来源。存在两个导致训练器和推理工作者实现不一致并可能引发数值偏差的来源。

来源 #1: 不同的并行策略。用于训练和推理的并行策略通常不同,以优化它们各自的目标。
* 推理: 通常使用张量并行(Tensor Parallelism, TP)来最大化有效内存带宽并减少跨多个设备的延迟。
* 训练: 通常采用更复杂的混合并行方案,如数据并行、FSDP和上下文并行的组合,以更好地将计算与通信重叠,并扩展到更大的模型。
这些不同的策略内在地改变了操作的顺序。例如,一个并行求和 $\sum(x_1, ..., x_n)$ 在一种设置(如TP)中可能被计算为 $(x_1, ..., x_{n/2}) + (x_{n/2+1}, ..., x_n)$,而在另一种设置中则可能作为单个顺序求和。由于bfloat16算术对操作顺序敏感,这种变化会导致可测量的数值差异。

来源 #2: JAX核融合的差异。JAX编译过程也导致了训练和推理使用不同的融合核(fused kernels)。
* 推理: 一个纯粹的推理调用只触发前向函数(primal function)。编译器有充分的自由度来积极地融合操作以获得最大速度。
* 训练: 一个训练步骤会同时触发前向和后向函数。前向传播必须输出后向传播所需的中间值(残差),这限制了编译器融合操作的能力。
一个具体的例子是RMSNorm。在推理模式下,JAX可能会将RMSNorm与后续操作(例如,一个残差加法)融合,在一个高度优化的核中计算这个组合块。然而,在训练期间,前向传播必须明确输出RMSNorm的输出以供后向传播使用。这一要求阻止了融合,导致了与推理路径不同的核分解,从而产生了不同的计算顺序。这些细微的、低级别的计算顺序差异足以在bfloat16中引起数值偏差,最终导致了$r(\theta) \neq 1.0$这一意外观察。

缓解策略

采用对数概率重计算和图对齐。获得高精度的对数概率 ($\pi_{\theta_{\text{old}}}(a|s)$) 对于RL训练的收敛至关重要,尤其是在严重依赖重要性采样(IS)的off-policy设置中。为了缓解数值偏差,我们采用了对数概率重计算(log-prob recomputation)并结合训练-推理图对齐的方法,这与【【索引44,Your efficient rl framework secretly brings you off-policy rl training,2025】】中讨论的方法类似。为了确保最大程度的对齐,我们不依赖于来自推理工作者的缓存对数概率。相反,我们维护一个旧策略的副本,该副本精确镜像了训练器的模型并行配置,并重新运行一个预填充(prefill)步骤来计算$\pi_{\theta_{\text{old}}}$。

强制计算图匹配。至关重要的是,我们进一步通过在训练器中禁用大多数激活保存(除了必要的层输出)来强制训练模型的前向路径与推理模型的计算图(例如,核融合模式)相匹配。这有效地迫使训练的前向传播行为与推理工作者的推理传播行为一致。虽然这种重度的重物质化("nothing-saveable")结合重计算可能会给训练过程带来一定的减速,但它对端到端性能的影响可以忽略不计(小于10%)。这是因为,在我们的RL工作负载中,主要的性能瓶颈是推理工作者上的rollout生成,而不是训练器。因此,我们认为这种性能权衡对于实现所需的数值一致性是必要的。

混合专家模型(MoE)的未来工作。对于混合专家(MoE)模型,我们观察到了一个额外的偏差来源:新旧策略模型激活的专家集合可能会有显著不同,这一现象在【【索引47,Group sequence policy optimization,2025】】中也有提及。我们计划在未来的工作中通过强制路由对齐来解决这个问题,确保新策略在特定的训练阶段遵循旧策略的专家路由决策。

其他系统设计与优化

简介。接下来,我们将讨论我们的验证器服务实现和数据整理方法。

验证器

代码验证器的挑战。与通常只需几十毫秒且可以在推理工作者本地单线程运行的只验证答案的数学验证器不同,代码验证器是资源密集型的。例如,一个高质量的编程竞赛问题通常要求代码通过10到100个测试用例,每个用例的时间限制在1到10秒之间。通常情况下,从大量的测试用例中进行子采样很容易导致假阳性结果。对于这类问题,本地代码验证器速度太慢。更糟糕的是,如果执行上下文没有被安全隔离,它很容易损坏系统。

OUBLIETTE:可扩展的远程代码执行服务。为了使代码验证具有可扩展性,我们设计了OUBLIETTE,一个利用AWS Lambda服务的远程代码执行服务。OUBLIETTE支持两个基本操作:执行代码(必要时进行编译)和直接执行二进制可执行文件,两者都支持标准或文件输入/输出。OUBLIETTE在验证官方Codeforces数据集方面发挥着至关重要的作用。在去重并移除没有英文描述或交互式的问题【【索引25,Interactive Problems: Guide for Participants,2025】】后,官方Codeforces数据集包含9096个问题和449797个测试用例,未压缩总大小超过500 GB。除了其极大的规模,Codeforces还因广泛使用针对每个问题的特殊检查器(special checkers)【【索引3,Checkers with testlib.h,2025】】而区别于其他编程竞赛数据集。概括地说,特殊检查器用于处理有多个可行解的问题,其本身也是一个可执行程序。

代码验证流程。我们设计了以下代码验证流程:
* 训练前:处理所有数据(包括将检查器编译成可执行二进制文件以减少服务端编译开销)并将其上传到一个持久化的OUBLIETTE存储库。我们获取每个测试用例的链接。这一步只需要做一次,只要数据本身不改变。
* 训练中:当代码需要验证时,客户端仅将代码和指向输入的链接发送给OUBLIETTE。OUBLIETTE随后使用指定链接中的输入执行代码,将输出存储到一个临时的OUBLIETTE存储库,并返回指向输出的链接。接下来,客户端将指向检查器二进制文件、输入和答案(都存储在同一存储库中)的链接,以及指向输出的链接发送给OUBLIETTE。最后,OUBLIETTE执行检查器二进制文件,并将检查器生成的裁决(verdicts,通常很短)直接发送给客户端。

效率与可扩展性。我们在此描述OUBLIETTE在我们用例中的效率和可扩展性。一个典型的OUBLIETTE请求包含2000个待并行执行的唯一 (代码, 输入) 对。每对的时间限制设置为10秒。平均验证时间,包括两轮客户端-服务器交互(代码和检查器),每个请求少于30秒。此外,由于AWS Lambda是无服务器函数服务,OUBLIETTE可以处理的代码运行最大并发数没有上限。因此,即使在第一个批次中因达到最大生成长度而产生大量rollout时,也不会出现瓶颈。由于持久化存储库存储了测试数据,网络传输开销也大大降低。假设平均代码大小为3 KB,平均链接大小为50 B,与向任何远程代码执行服务器进行完整数据传输相比,在整个Codeforces数据集上运行一个epoch的数据量从约 $449797 \times 31024 \times 1024 + 500 \approx 500$ GB 减少到 $449797 \times (3 + 3 \times 50) / (1024 \times 1024) \approx 1.35$ GB。
图2:使用RLAX在Codeforces数据集上训练QwQ-32B时的训练奖励和测试集准确率随步骤的变化。

数据整理

数据质量的重要性。数据质量是强化学习,特别是对于组相对策略优化(GRPO)算法的一个根本瓶颈。对于通常只有稀疏的0/1奖励信号的RL任务来说,数据质量极其重要。

核心挑战:避免同质化奖励。核心挑战在于避免一个组内出现同质化的奖励,因为这无法提供学习信号。当编程问题要么过于简单,要么过于困难时,从单个问题提示生成的所有样本往往会获得相同的奖励——要么全部正确(奖励=1),要么全部错误(奖励=0)。这种均匀分布消除了任何使用一组样本来估计优势的算法(包括GRPO等)计算有意义的策略梯度所需的奖励方差。理想的训练场景需要一个平衡的奖励分布,即从单个提示生成的n个样本中,一些解决方案是正确的,而另一些则包含错误,从而创造出有效学习所必需的奖励多样性。

数据整理策略。为了应对这一挑战,RLAX有一个扩展,可以选择性地添加一个数据过滤器,使用以下两种数据整理策略之一:
* 预过滤方法(Pre-filtering Approach):该方法对训练数据集进行初步评估运行,只保留模型能取得部分成功率的问题。虽然这最大化了数据集的效率,并通过在适当阶段部署不同难度的问题来实现课程学习,但它引入了手动开销,可能会减慢训练流程。
* 动态过滤方法(Dynamic Filtering Approach):该策略在训练期间实现实时过滤,只有当批次中包含来自同一问题的正确和错误样本时才转发轨迹(类似于DAPO【【索引46,Dapo: An open-source llm reinforcement learning system at scale,2025,arXiv】】)。尽管这通过自动化过滤消除了手动评估的开销,但它通过丢弃部分轨迹加剧了计算瓶颈。

实验评估

实验概述。我们首先展示主要结果,证明RLAX如何提高QwQ-32B【【索引37,Qwq-32b: Embracing the power of reinforcement learning,2025】】的准确率。然后,我们提供关于陈旧度界限选择和数值对齐方法的消融研究。最后,我们评估RLAX被抢占和恢复的能力。

主要结果

模型质量提升。RLAX能够训练并交付高质量模型,包括内部和开源变体。在这里,我们通过一个在QwQ-32B上进行RL训练以完成编程竞赛任务的案例研究,来展示RLAX的有效性。我们使用Codeforces 2013-2024年数据进行训练,2025年数据进行评估,并使用OUBLIETTE作为代码执行以对抗测试用例的奖励模型。在此实验中,我们为训练和rollout都部署了v5p-512集群。图2显示了训练奖励和测试集准确率随训练步骤的变化。使用RLAX,我们在奖励、pass@8和pass@32方面分别实现了6.7%、12.8%和12.0%的测试集准确率提升。由于QwQ-32B的知识截止日期是2024年11月28日【【索引35,Qwen3-30b-a3b vs qwq-32b model comparison,https://llm-stats.com/models/compare/qwen3-30b-a3b-vs-qwq-32b】】,这些 在Codeforces 2025上的提升代表了编码能力的真实增长。
a) 推理服务吞吐量和训练步时。
b) 不同墙上时间下的训练步数。
图3:TPU可扩展性:Rollout吞吐量和训练步时与核心数的关系。
图4:RLAX在不同陈旧度设置下的训练奖励。

系统可扩展性。我们系统的一个关键目标是随着推理能力的增加保持近线性的吞吐量扩展。为了评估这一点,我们在大小分别为128、256、512和1024的TPU v5p推理集群上进行相同的RL工作负载,每个集群都托管相同的模型并暴露相同的服务接口。我们使用相同的v5p-512集群作为训练器,以比较不同设置下的生成吞吐量和训练步时。图3a报告了每种配置的稳态推理服务吞吐量和训练步骤延迟。如图3a中的蓝线所示,推理吞吐量随着推理服务规模的增长大致呈线性增长,直到v5p-1024,与理想扩展的偏差极小。系统从使用v5p-128到v5p-1024作为推理服务规模时,实现了8.0倍的提升,这表明控制器即使在大规模下也只引入了非常低的开销。随着我们扩展推理服务的规模,整体训练性能也随之提升。图3a中的绿线显示了训练步时延迟。当推理服务的吞吐量提高8.0倍时,训练步时延迟减少了3.6倍。这是因为随着推理吞吐量的增加,系统变得受训练限制。图3b显示了在每种设置下前256步的步时与墙上时间的关系:RLAX的训练步时延迟在训练步骤中保持一致。

消融研究

陈旧度对收敛和训练速度的影响。我们研究了不同的陈旧度设置如何影响收敛和训练性能。在我们的主要运行中,我们将推理重载陈旧度设为16,训练器接受陈旧度设为32。这意味着$(j, k) = (16, 32)$。我们还测试了一种配置,其中训练器接受陈旧度等于推理重载陈旧度$(j, k) = (16, 16)$,这实际上导致训练器和推理工作者更频繁地同步,从而更经常地互相等待。我们比较了这两种设置下的训练奖励曲线。图4显示了我们实验设置(Codeforces数据集,QwQ-32B,16k上下文长度)的结果。在图中,stale-16表示$k = 16$,stale-32表示$k = 32$。两种配置每步都实现了相似的奖励增益。然而,具有更大训练器接受陈旧度的配置(训练器接受陈旧度 $\gt$ 推理重载陈旧度)将平均步时从75秒减少到45秒,从而实现了单位墙上时间内更高的奖励增益。然而,高陈旧度可能会减慢每步的奖励收敛速度,尽管在我们的主要实验中并非如此。在其他使用较小模型(3B)和较少样本(10个,过拟合场景)的实验中,高陈旧度既减慢了收敛速度,又增加了达到平台期的时间。陈旧度界限如何改变RL中的准确率和性能是依赖于具体工作负载的,需要在进行完整实验前进行验证运行。

数值对齐。为了量化训练器策略和行为策略之间的差异,我们测量了它们之间每个token对数概率的绝对差。设$\pi_{\beta}$为生成轨迹的行为策略,$\pi_{\theta}$为当前正在优化的训练器策略。对于每个动作token,我们计算$|\log \pi_{\beta}(a_t|s_t) - \log \pi_{\theta}(a_t|s_t)|$的掩码均值。图5报告了在使用QwQ-32B模型进行off-policy学习的两个受控实验的结果,比较了使用和不使用我们的数值对数概率重计算方法(§5.1)的训练运行。未经校正训练的模型表现出明显更不稳定的对数概率差异,包括与奖励学习不稳定性相关的尖锐峰值。相比之下,启用我们的方法可以持续抑制这些波动,并产生更平滑、更稳定的策略更新。在数量上,我们使用512个训练步骤中每步对数概率差异的第95百分位数来总结稳定性:没有重计算的基线运行达到了0.044348的第95百分位数,而我们的重计算方法将其降低到0.019943。
图5:训练器和Rollout之间的对数概率绝对差

对抢占的支持

鲁棒性评估。我们通过检查RLAX在真实抢占期间的行为来评估其在抢占下的鲁棒性。在我们的陈旧度消融实验(§7.2)期间,自然发生了一些与资源相关的抢占,这为我们提供了代表性的压力测试,而无需进行人工故障注入。图6展示了RLAX被抢占的一个例子。RLAX在抢占后可靠地从最近持久化的检查点恢复训练,学习曲线或系统行为没有可观察到的偏差。
图6:RLAX的训练奖励随时间变化及一次抢占事件。
这一结果证实了RLAX的检查点-恢复机制在实践中是有效的,并且系统可以容忍抢占而不会影响训练进度。

相关工作

Off-policy RL。实现大型语言模型的off-policy强化学习(RL)有多种可能的方法。RLAX目前采用一种直接的、rollout级别的off-policy设计:每个rollout完全使用单一模型版本生成,并且训练器被允许消费任何其来源模型最多落后k步的rollout。在这种方案下,推理工作者在从参数服务器刷新其权重时必须丢弃部分完成的rollout。任何续写都会产生模型来源无法清晰归因于单一版本的rollout。最近的工作已经开始探索放宽这一约束的机制,即允许单个rollout使用多个模型版本生成。例如,Magistral【【索引26,Magistral,2025】】使用更新后的模型权重继续生成,同时重用由旧权重产生的KV缓存,从而挽救了否则会被丢弃的部分计算。LlamaRL【【索引40,Llamarl: A distributed asynchronous reinforcement learning framework for efficient large-scale llm training,2025】】同样通过混合版本的rollout提高了数据效率,其中陈旧的rollout会加上一个使用最新权重计算的新鲜prefill作为前缀。这些方法提高了硬件利用率并减少了rollout的浪费,可以作为RLAX的扩展。

RL on TPUs。RLAX是专为TPU集群设计的全栈RL后训练系统,但它并非该领域唯一的近期成果。谷歌的Tunix库【【索引5,Tunix,2025,https://github.com/google/tunix】】提供了一个基于TPU的RL微调框架,专注于简单性和可扩展性。然而,Tunix目前缺乏多主机分布式训练能力,并依赖于标准的注意力实现,这使其不太适合长上下文的LLM。相比之下,RLAX专为大规模、多主机的部署而设计,并利用了如FlashAttention【【索引13 ,FlashAttention: Fast and memoryefficient exact attention with IO-awareness,2022,NeurIPS】,【索引12,FlashAttention-2: Faster attention with better parallelism and work partitioning,2024,ICLR】】等优化的核实现,以支持长上下文和高容量模型的有效训练。

ML系统中的检查点-恢复机制。检查点-恢复机制长期以来一直是大规模机器学习系统的标准实践,为分布式训练提供了容错和抢占支持。先前的工作通常关注于具有紧密同步迭代的监督式训练工作负载【【索引42,Gandiva: Introspective cluster scheduling for deep learning,2018,OSDI】,【索引38,Bamboo: Making preemptible instances resilient for affordable training of large DNNs,2023,NSDI】】。相比之下,RL后训练引入了独特的挑战:rollout依赖于提示混合采样、验证器结果和模型版本来源,所有这些都必须在重启后保持一致。简单地检查点所有状态,包括像KV缓存或中间验证器状态这样的大型临时产物,将导致不必要地大且加载缓慢的检查点。通过紧密集成确定性的控制器逻辑、可复现的RNG流以及模型权重的内存/持久化双存储,RLAX将传统的检查点-恢复技术扩展到了RL设置。这使得RLAX能够在频繁抢占下保证正确性,这一特性对于在共享或弹性计算集群上运行的大规模RL工作负载至关重要。

结论

我们介绍了RLAX,一个为大型语言模型(LLM)设计的可扩展、分布式强化学习框架,它在一个单一、灵活的系统架构中统一了广泛的RL算法和训练范式。RLAX专为在TPU上进行实际的大规模部署而设计,提供了高效的rollout生成、高吞吐量的训练以及内置的抢占和恢复支持。在大规模评估中,RLAX在1024个v5p TPU上仅用12小时48分钟就将QwQ-32B的pass@8准确率提高了12.8%,并且在整个训练过程中对多次意外抢占保持了鲁棒性。

附录

更多现代RL算法的实例化

REINFORCE。最早由【【索引39,Simple statistical gradient-following algorithms for connectionist reinforcement learning,1992,Machine learning】】引入,REINFORCE有许多变体。一个on-policy的REINFORCE使用公式(6)中的目标函数执行单步梯度更新,其中$\hat{A}{pi} = R_i - \text{mean}({R}_{i'=1}^G)$是无偏的优势估计器。因此,我们可以用以下方式实例化公式(3):

$$\begin{aligned} \text{Agg}_{i,t}^{\mathcal{X}} = \mathbf{1}[t = |o_i|] \cdot \frac{1}{G}, \\ \text{IS}_{i,t}^{\mathcal{X}} = r_t(\theta), \\ \text{Adv}_{i,t}^{\mathcal{X}} = R_i - \text{mean}(\{R_{i'}\}_{i'=1}^G), \\ \text{GradTerm1}_{i,t}^{\mathcal{X}} = \mathbf{1}[t = |o_i|] \cdot \sum_{t'=1}^{|o_i|} \log \pi_{\theta}(o_{i,t'} | q, o_{i,\lt t'}), \\ \text{GradTerm2}_{i,t}^{\mathcal{X}} = 0, \end{aligned}$$


一个token级别的REINFORCE

$$\begin{aligned} \mathcal{J}_{\text{token-REINFORCE}}(\theta) = \mathop{\mathbb{E}}_{\substack{q \sim \mathcal{D} \\ \{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(\cdot|q)}} \left[ \frac{1}{GL_{\max}} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \text{sg}[r_{i,t}(\theta)] \hat{A}_i \log \pi_\theta(o_{i,t}|q, o_{i,\lt t}) \right]. \end{aligned}$$
可以用以下方式实例化:
$$\begin{aligned} \begin{array}{l} \mathrm{Agg}_{i, t}^{\mathcal{X}}=\frac{1}{G L_{\max }}, \\ \mathrm{IS}_{i, t}^{\mathcal{X}}=r_{i, t}(\theta), \\ \mathrm{Adv}_{i, t}^{\mathcal{X}}=R_{i}-\operatorname{mean}\left(\left\{R_{i^{\prime}}\right\}_{i^{\prime}=1}^{G}\right), \\ \mathrm{GradTerm} 1_{i, t}^{\mathcal{X}}=\log \pi_{\theta}\left(o_{i, t} | q, o_{i,\lt t}\right), \\ \mathrm{GradTerm} 2_{i, t}^{\mathcal{X}}=0, \end{array} \end{aligned}$$
其中$L_{\text{max}}$是最大生成长度。留一法优势估计器【【索引20,Buy 4 REINFORCE samples, get a baseline for free!,2019】】可以用$\text{Adv}_{i,t}^{\mathcal{X}} = \frac{G-1}{G} (R_i - \frac{1}{G-1} \sum_{j \neq i} R_j)$来实例化。正则化(例如【【索引2,On the theory of policy gradient methods: Optimality, approximation, and distribution shift,2021,Journal of Machine Learning Research】】中的公式(12))可以用$\text{GradTerm2}_{i,t}^{\mathcal{X}} = \min\{r_{i,t}(\theta), \rho\}$来实例化。

CISPO。【【索引9,Minimax-m1: Scaling test-time compute efficiently with lightning attention,2025,arXiv】】中的公式(4):

$$\begin{aligned} \mathcal{J}_{\mathrm{CISPO}}(\theta) = \underset{\substack{q \sim \mathcal{D} \\ \{o_i\}_{i=1}^G \sim \pi_{\theta_{\text{old}}}(\cdot|q)}}{\mathbb{E}}\left[\frac{1}{\sum_{i'=1}^G |o_{i'}|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \operatorname{sg}[\hat{r}_{i,t}(\theta)] \hat{A}_i \log \pi_\theta(o_{i,t}|q, o_{i, \lt t})\right], \end{aligned}$$

;
其中
$$\hat{r}_{i, t}(\theta)=\operatorname{clip}\left(r_{i, t}(\theta), 1-\epsilon_{\text {low }}^{\mathrm{IS}}, 1+\epsilon_{\mathrm{high}}^{\mathrm{IS}}\right), \quad \hat{A}_{i}=\frac{R_{i}-\operatorname{mean}\left(\left\{R_{i^{\prime}}\right\}_{i^{\prime}=1}^{G}\right)}{\operatorname{std}\left(\left\{R_{i^{\prime}}\right\}_{i^{\prime}=1}^{G}\right)}.$$
因此,我们可以用以下方式实例化公式(3): $$ \mathrm{Agg}_{i,t}^{\mathcal{X}} = \frac{1}{\sum_{i'=1}^{G} |o_{i'}|}, $$ $$ \mathrm{IS}_{i,t}^{\mathcal{X}} = \mathrm{clip}(r_{i,t}(\theta), 1 - \epsilon_{\mathrm{low}}^{\mathrm{IS}}, 1 + \epsilon_{\mathrm{high}}^{\mathrm{IS}}), $$ $$ \mathrm{Adv}_{i,t}^{\mathcal{X}} = \frac{R_i - \mathrm{mean}(\{R_{i'}\}_{i'=1}^G)}{\mathrm{std}(\{R_{i'}\}_{i'=1}^G)}, $$ $$ \mathrm{GradTerm1}_{i,t}^{\mathcal{X}} = \log \pi_{\theta}(o_{i,t}|q, o_{i,\lt t}), $$ $$ \mathrm{GradTerm2}_{i,t}^{\mathcal{X}} = 0. $$