StreamBP: Memory-Efficient Exact Backpropagation for Long Sequence Training of LLMs

作者/机构: Qijun Luo, Mengqi Li, Lei Zhao, Xiao Li; 香港中文大学(深圳),上海交通大学

A1 主要贡献

核心问题与研究目标
大型语言模型(LLMs)在处理如长链推理等复杂任务时,需要训练极长的输入序列。然而,随着序列长度的增加,在反向传播(BP)过程中存储激活值的内存开销变得巨大,即便使用了梯度检查点(gradient checkpointing)技术也难以解决。这种巨大的内存消耗主要源于在BP过程中需要存储中间激活值。本文旨在为这一内存问题提供一个高效的解决方案。

主要贡献
本文提出了一种名为StreamBP的内存高效且精确的反向传播算法,旨在解决在超长序列上训练LLMs(特别是推理模型)时遇到的内存瓶颈问题。其主要贡献如下:
1. (C.1) 提出StreamBP算法
* 核心思想:StreamBP基于链式法则的线性分解,以逐层(layerwise)的方式沿序列维度进行计算,从而显著减少了存储中间激活值和logits的内存开销。
* 广泛适用性:该方法与常见的训练目标兼容,包括监督微调(SFT)、组相对策略优化(GRPO)和直接偏好优化(DPO)。
* 计算与时间效率:通过利用语言模型的因果结构(causal structure),StreamBP相比标准BP减少了计算浮点运算次数(FLOPs),并实现了比梯度检查点基线更快的BP速度。
* 分布式支持:为了支持多GPU训练,本文还开发了一种通信高效的分布式StreamBP实现,特别优化了梯度和参数的通信,显著提高了训练效率并拓宽了其适用性。

  1. (C.2) 实验验证
    • 显著提升序列长度:在单GPU和分布式训练设置下,与梯度检查点基线相比,StreamBP将BP过程的最大序列长度提升了2.8至5.5倍,而BP时间相当甚至更短。
    • 可扩展性:StreamBP的内存成本与序列长度呈线性关系,因此其在序列长度上的扩展能力可以直接转化为批次大小(batch size)的扩展,从而加速训练。
    • 多目标与分布式验证:实验证明,StreamBP在SFT、GRPO和DPO等多种目标下均能有效增加最大序列长度。在Deepspeed ZeRO-2分布式训练方案下,StreamBP同样实现了比梯度检查点大5至5.6倍的序列长度。

A3 背景知识

2 反向传播过程中的峰值内存成本

本文使用PyTorch CUDA内存快照工具,记录了Qwen 3-4B模型在序列长度为8192并应用梯度检查点技术的情况下,进行2次前向和后向传播过程的详细GPU内存使用情况。结果如图1所示。分析中排除了优化器状态的内存成本,仅关注BP过程。

峰值内存成本。峰值内存成本出现在第二次前向传播结束时。除了分配给BF16参数和梯度的16GB外,大约14GB用于存储中间计算结果。这包括BF16格式的检查点层输入、BF16格式的logits、FP32格式的logits以及FP32 logits的梯度。由于logits的维度是序列长度(T)× 词汇表大小(C),其内存成本与模型大小无关,而取决于模型类别的词汇表大小C。

图1:使用PyTorch内存分析器可视化的Qwen 3-4B在梯度检查点下的反向传播内存剖面。序列长度设置为8192。由于我们关注BP过程,优化器状态被排除。
图1:使用PyTorch内存分析器可视化的Qwen 3-4B在梯度检查点下的反向传播内存剖面。序列长度设置为8192。由于我们关注BP过程,优化器状态被排除。

第二峰值内存成本。第二个内存峰值出现在第二次反向传播开始时,此时模型会重新对最后一个transformer层的检查点输入进行前向传播,以计算该层的激活值。这些激活值将被临时存储,用于计算该层参数和输入的梯度。

我们注意到,随着模型规模的增大,第二个内存峰值会增加并逐渐接近峰值内存。关于内存剖面的更详细解释,请参见附录C。

A2 方法细节

3 内存高效的精确流式反向传播

本节将介绍我们提出的算法设计,该算法在BP过程中以更低的内存成本和浮点运算次数(FLOPs)计算精确梯度。

3.1 主要思想

模型变换与梯度计算。考虑模型前向传播过程中的一个变换:

$$f_W(Z_{\text{in}}) = Z_{\text{out}},$$


其中 $W$ 是与该变换相关的权重,$f_W(\cdot)$ 可以是模型内部的任何映射,例如transformer层或语言建模头。根据链式法则,权重的梯度由 $\frac{\partial L}{\partial \text{vec}(W)} = (\frac{\partial \text{vec}(Z_{\text{out}})}{\partial \text{vec}(W)})^\top \frac{\partial L}{\partial \text{vec}(Z_{\text{out}})}$ 给出,其中 $L$ 是损失,$\text{vec}(\cdot)$ 是向量化算子,$\frac{\partial \text{vec}(Z_{\text{out}})}{\partial \text{vec}(W)}$ 表示雅可比矩阵。在BP过程中,当梯度 $\frac{\partial L}{\partial \text{vec}(Z_{\text{out}})}$ 准备就绪时,梯度检查点方法会通过 $f_W(\cdot)$ 重新前向传播 $Z_{\text{in}}$,然后计算并存储所有计算 $\frac{\partial \text{vec}(Z_{\text{out}})}{\partial \text{vec}(W)}$ 和 $\frac{\partial \text{vec}(Z_{\text{out}})}{\partial \text{vec}(Z_{\text{in}})}$ 所需的中间激活值,其中后者将用于计算前一层的梯度 $\frac{\partial L}{\partial \text{vec}(Z_{\text{in}})}$。如图1所示,存储这些中间激活值的内存成本可能非常巨大。

StreamBP的核心分解思想。为了减少这些中间值的内存成本,我们引入了流式反向传播(StreamBP)。设 $\text{vec}(Z_{\text{out}}) = [\text{vec}(Z_{\text{out}}^{(1)}), \text{vec}(Z_{\text{out}}^{(2)}), \dots, \text{vec}(Z_{\text{out}}^{(D)})]$ 是 $\text{vec}(Z_{\text{out}})$ 的任意分区。StreamBP基于以下线性分解:

$$\frac{\partial L}{\partial \operatorname{vec}(W)}=\left(\frac{\partial \operatorname{vec}\left(Z_{\text {out }}\right)}{\partial \operatorname{vec}(W)}\right)^{\top} \frac{\partial L}{\partial \operatorname{vec}\left(Z_{\text {out }}\right)}=\sum_{i=1}^D\left(\frac{\partial \operatorname{vec}\left(Z_{\text {out }}^{(i)}\right)}{\partial \operatorname{vec}(W)}\right)^{\top} \frac{\partial L}{\partial \operatorname{vec}\left(Z_{\text {out }}^{(i)}\right)} .$$


通过策略性地划分 $\text{vec}(Z_{\text{out}})$,存储计算 $\frac{\partial \text{vec}(Z_{\text{out}}^{(i)})}{\partial \text{vec}(W)}$ 所需的中间激活值的成本可以远低于存储计算 $\frac{\partial \text{vec}(Z_{\text{out}})}{\partial \text{vec}(W)}$ 所需的成本,并且通常与分区块的大小成正比。受此启发,StreamBP按顺序计算公式(1)中所有分区的分解部分,并将它们累加到一个运行总和中,从而得到精确的梯度。关于StreamBP如何通过一个线性变换示例显著降低内存成本的快速理解,请参阅附录A.1。

接下来,我们将详细阐述如何将StreamBP应用于一个具体的transformer大语言模型,这需要进行非凡且复杂的开发。

3.2 用于Transformer LLMs的StreamBP

本节中,我们将应用StreamBP来显著降低语言建模头和transformer层在BP过程中消耗的内存。此外,我们将讨论StreamBP相比于带梯度检查点的标准BP需要更少的计算FLOPs,从而可能比标准方法实现时间上的加速。

3.2.1 用于语言建模头的StreamBP:SFT、GRPO和DPO

语言建模头的线性变换。语言建模头执行以下线性变换:

$$HW_{\text{lm\_head}} = \text{logits.}$$


这里,$W_{\text{lm_head}} \in \mathbb{R}^{d \times C}$ 是语言建模头的权重,$H \in \mathbb{R}^{T \times d}$ 是最后一个transformer层输出的隐藏状态。logits $\in \mathbb{R}^{T \times C}$ 将用于计算目标函数。$C, d, T$ 分别是词汇表大小、隐藏维度和序列长度。如图1所示,由于词汇表大小和序列长度巨大,logits及其梯度会产生巨大的内存消耗。接下来,我们将逐一分析在SFT、GRPO和DPO机制下,如何使用StreamBP显著减少logits的内存。

监督微调(SFT)。SFT的(未归一化)目标函数如下:

$$L_{\text{SFT}}(\text{logits}, Y) := \sum\nolimits_{t=1}^{T-1} -\log \text{softmax}(\text{logits}_{t,:})_{Y_{t}},$$


其中 $Y \in \mathbb{R}^{T-1}$ 是标签向量。重要的是,每个位置的logits对目标的贡献是独立的。为了执行StreamBP,我们将logits和标签沿序列维度均匀划分为D个块,即 $\{(\text{logits}^{(i)}, Y^{(i)})|i=1,\dots,D\}$,其中 $\text{logits}^{(i)} \in \mathbb{R}^{((T-1)/D)\times C}$。然后,我们对所有分区 $i=1,\dots,D$ 顺序地累加梯度,如下所示:

$$g_{\mathrm{lm\_head}} += \frac{\partial L_{\mathrm{SFT}}(\text{logits}^{(i)}, Y^{(i)})}{\partial W_{\mathrm{lm\_head}}}, \quad g_H += \frac{\partial L_{\mathrm{SFT}}(\text{logits}^{(i)}, Y^{(i)})}{\partial H},$$
其中 $g_{\text{lm_head}}$ 和 $g_H$ 从零初始化。运算符 += 表示原地求和。$\text{logits}^{(i)}$ 及其梯度在公式(3)中使用后将从内存中清除。在所有分区累加完成后,$g_{\text{lm_head}}$ 和 $g_H$ 将分别是 $W_{\text{lm_head}}$ 和 $H$ 的精确梯度。在此计算过程中,StreamBP仅顺序存储每个 $i$ 的 $\text{logits}^{(i)}$ 及其梯度,与原始方法相比,内存成本仅为1/D。

组相对策略优化(GRPO)。GRPO的目标函数如下:
LGRPO(logits) := −E[q∼Dq,{oj}Gj=1∼πold(·|q)] [ 1G 1To G∑j=1 To∑t=1{min (πθ(j, t)πold(j, t) Aˆj,t, clip (πθ(j, t)πold(j, t) , 1 − ϵ, 1 + ϵ) Aˆj,t)− β log πθ(j, t)πref(j, t)} ] .
为简化表示,我们在不失一般性的情况下省略了GRPO的KL散度中的补偿项。符号的详细定义见附录A.2。这里,$\text{logits} \triangleq \{\text{logits}_{\pi_\theta}, \text{logits}_{\pi_{\text{old}}}, \text{logits}_{\pi_{\text{ref}}}\}$ 包含由目标策略、旧策略和参考策略生成的logits,其中 $\text{logits}_\pi \in \mathbb{R}^{G \times T_o \times C}$。策略的输出由logits决定,即:

$$\pi(j, t):=\pi\left(o_{j, t} \mid q, o_{j,<t}\right)=\operatorname{softmax}\left(\operatorname{logits}_{\pi, j, t,:}\right)_{o_{j, t}}$$ <p>
注意,公式(4)中的每个 $f(j,t)$ 对目标的贡献是独立的,并且只依赖于 $\text{logits}_{\pi,j,t,:}$,这使我们能够沿序列维度执行StreamBP。具体来说,我们将logits沿序列维度划分为 $\{\text{logits}^{(i)} := \{\text{logits}^{(i)}_{\pi_\theta}, \text{logits}^{(i)}_{\pi_{\text{old}}}, \text{logits}^{(i)}_{\pi_{\text{ref}}}\}|i=1,\dots,D\}$,其中 $\text{logits}^{(i)}_\pi \in \mathbb{R}^{G \times (T_o/D) \times C}$。定义序列分区的目标函数为:
L(i)GRPO(logits(i)) := −E[q∼Dq,{oj}Gj=1∼πold(·|q)] [ 1G 1To G∑j=1 ∑t∈Ti f(j, t) ] ,
其中 $\mathcal{T}_i := \{(i-1)T_o < t \le iT_o | t \in \mathbb{Z}\}$ 表示分区的序列范围。我们有 $L_{\text{GRPO}}(\text{logits}) = \sum_{i=1}^D L^{(i)}_{\text{GRPO}}(\text{logits}^{(i)})$。然后,与SFT类似,StreamBP对 $i=1, \dots, D$ 顺序执行以下累加:
$$g_{\text{lm\_head}} += \frac{\partial L_{\text{GRPO}}^{(i)}(\text{logits}^{(i)})}{\partial W_{\text{lm\_head}}}, \quad g_H += \frac{\partial L_{\text{GRPO}}^{(i)}(\text{logits}^{(i)})}{\partial H}.$$

直接偏好优化(DPO)。DPO的目标函数如下:
LDPO(logits) := −E [ log σ ( β T∑t=1 ( log πθ(yw,t|x, yw,<t)πref(yw,t|x, yw,<t) − log πθ(yl,t|x, yl,<t)πref(yl,t|x, yl,<t) )) ] .
这里,$\text{logits} := \{\text{logits}_{\pi_\theta}^w, \text{logits}_{\pi_\theta}^l\}$,其中 $\text{logits}_\pi \in \mathbb{R}^{T \times C}$,$\sigma(\cdot)$ 是sigmoid函数。T是响应y的序列长度。为简化表达,我们不失一般性地假设 $y_w$ 和 $y_l$ 具有相同的长度。策略的输出由logits决定:

$$\pi(y_{t}|x, y_{<t}) = \text{softmax}(\text{logits}_{t,:})_{y_{t}}$$ <p>
与SFT和GRPO不同,由于存在非线性的log-sigmoid变换,DPO的目标函数不能分解为各个分区损失的总和。幸运的是,其梯度保留了一个可分离的结构:
∂LDPO∂W = −E(x,yw,yl)∼D [ (1 − σ (β T∑t=1 ℓ(t))) β T∑t=1 ∂ℓ(t)∂W ] .
StreamBP将logits沿序列维度划分为 $\{\text{logits}^{(i)} := \{\text{logits}^{(i)w}_{\pi_\theta}, \text{logits}^{(i)l}_{\pi_\theta}\}|i=1,\dots,D\}$,其中 $\text{logits}^{(i)}_\pi \in \mathbb{R}^{(T/D) \times C}$。基于此分区,它对 $i=1,\dots,D$ 执行以下累加:
$$\ell += \sum_{t \in \mathcal{T}_i} \ell(t), \quad g_{\text{lm\_head}} += \beta \sum_{t \in \mathcal{T}_i} \frac{\partial \ell(t)}{\partial W_{\text{lm\_head}}}, \quad g_H += \beta \sum_{t \in \mathcal{T}_i} \frac{\partial \ell(t)}{\partial H},$$
其中 $\mathcal{T}_i := \{(i-1)T < t \le iT | t \in \mathbb{Z}\}$。完成上述累加后,StreamBP执行以下原地校正以计算精确梯度:
$$g_{\mathrm{lm\_head}} \leftarrow (\sigma(\beta \ell) - 1) g_{\mathrm{lm\_head}}, \quad g_H \leftarrow (\sigma(\beta \ell) - 1) g_H$$

3.2.2 用于Transformer层的StreamBP:Attention和MLP

Transformer层的结构。现在我们将StreamBP应用于transformer层。为简化表述,我们不失一般性地忽略了如归一化层、多头机制和残差连接等组件。一个transformer层由Attention和MLP两个连续的变换组成:

$$O = f_{\text{attn}}(H_{\text{in}}), \quad H_{\text{out}} = f_{\text{MLP}}(O),$$


其中 $H_{\text{in}} \in \mathbb{R}^{T \times d}$ 和 $H_{\text{out}} \in \mathbb{R}^{T \times d}$ 分别是transformer层的输入和输出。

图2:Transformer层的StreamBP(彩色视图最佳)。StreamBP存储的激活值以橙色突出显示。与梯度检查点方法重新前向传播Hin以计算Hout反向传播所需的所有激活值不同,StreamBP一次只为一个H(i)out分区计算激活值,从而大幅降低了内存成本。
图2:Transformer层的StreamBP(彩色视图最佳)。StreamBP存储的激活值以橙色突出显示。与梯度检查点方法重新前向传播Hin以计算Hout反向传播所需的所有激活值不同,StreamBP一次只为一个H(i)out分区计算激活值,从而大幅降低了内存成本。

标准BP的激活存储。计算 $\partial H_{\text{out}} / \partial H_{\text{in}}$ 需要存储 $Q, K, V, M, O, H_{\text{up}}, H_{\text{gate}}, H_{\text{out}}$,其中 $M \in \mathbb{R}^{T \times T}$ 是注意力掩码,在使用小批量训练或滑动窗口注意力等技术时需要。使用Flash Attention的切片方法可以避免存储 $S$ 和 $P$。

StreamBP的分区计算。定义量 $H \in \mathbb{R}^{T \times d}$ 沿序列维度的分区为 $\{H^{(i)} | i=1,\dots,D\}$,其中 $H^{(i)} \in \mathbb{R}^{(T/D) \times d}$。令 $H^{(:i)} \in \mathbb{R}^{(iT/D) \times d}$ 是 $\{H^{(j)}\}_{j=1}^i$ 沿序列维度的拼接。用于transformer层的StreamBP建立在以下观察之上:

  • 性质 3.1:计算 $\partial H^{(i)}_{\text{out}}/\partial W$ 仅依赖于 $O^{(i)}, Q^{(i)}, K^{(:i)}$ 和 $V^{(:i)}$。

分区计算流程。为证明上述性质,StreamBP对每个块 $i$ 顺序执行以下分区化的Attention和MLP计算:
* 分区Attention:
$Q^{(i)} = H^{(i)}_{\text{in}} W_q, K^{(:i)} = H^{(:i)}_{\text{in}} W_k, V^{(:i)} = H^{(:i)}_{\text{in}} W_v$
$S^{(i)} = Q^{(i)} K^{(:i)\top}, P^{(i)} = \text{softmax}(S^{(i)} \odot M^{(i)}), O^{(i)} = P^{(i)} V^{(:i)} W_o \in \mathbb{R}^{(T/D) \times d}$
* 分区MLP:
$H^{(i)}_{\text{up}} = O^{(i)} W_{\text{up}}, H^{(i)}_{\text{gate}} = O^{(i)} W_{\text{gate}}$
$H^{(i)}_{\text{out}} = \sigma(H^{(i)}_{\text{gate}}) \circ H^{(i)}_{\text{up}} W_{\text{down}} \in \mathbb{R}^{(T/D) \times d}$
当分区重前向传播完成后,计算 $\partial \text{vec}(H^{(i)}_{\text{out}})/\partial \text{vec}(W)$ 所需的激活值被存储在内存中。然后,我们对 $i=1,\dots,D$ 累加梯度:

$$ \text{vec}(g_W) += \frac{\partial\text{vec}(H_{\text{out}}^{(i)})}{\partial\text{vec}(W)}^\top \frac{\partial L}{\partial\text{vec}(H_{\text{out}}^{(i)})}, \quad \text{vec}(g_{H_{\text{in}}}) += \frac{\partial\text{vec}(H_{\text{out}}^{(i)})}{\partial\text{vec}(H_{\text{in}})}^\top \frac{\partial L}{\partial\text{vec}(H_{\text{out}}^{(i)})} . $$


当累加完成后,$g_W$ 和 $g_{H_{\text{in}}}$ 分别成为权重矩阵 $W := \{W_q, W_k, W_v, W_{\text{up}}, W_{\text{gate}}, W_{\text{down}}\}$ 和 $H_{\text{in}}$ 的精确梯度。注意,每个分区前向传播都需要 $K^{(:i)}$ 和 $V^{(:i)}$。因此,我们在当前层的StreamBP过程中一次性计算 $K$ 和 $V$ 并将其缓存。我们强调,分区Attention与Flash Attention是兼容的。我们在图2中展示了用于transformer层的StreamBP。

StreamBP的内存效率。StreamBP仅需存储以下激活值:$Q^{(i)}, K, V, M^{(i)}, O^{(i)}, H^{(i)}_{\text{up}}, H^{(i)}_{\text{gate}}$ 和 $H^{(i)}_{\text{out}}$。注意,当使用分组查询注意力(GQA)【2】且组大小为G时,K和V的内存仅为Q的1/G。因此,与标准BP相比,StreamBP在激活值上的内存成本大约为1/D。

StreamBP的计算效率。对于长序列训练,transformer层中计算最昂贵的操作是计算预注意力分数 $S$。StreamBP将该操作的FLOPs减少了大约一半。具体来说,标准实现 $S = QK^\top$ 的FLOPs为 $2T^2d^2$,而StreamBP在D个分区上计算 $S^{(i)} = Q^{(i)}K^{(:i)\top}$ 的总FLOPs为 $\frac{(1+D)T^2d^2}{D}$。FLOPs的减少是因为StreamBP利用了语言模型的因果结构,在计算 $S^{(i)}$ 时使用的是 $K^{(:i)}$ 而不是整个 $K$。对于所有其他操作,StreamBP在所有分区上的总FLOPs与带检查点的标准BP相同。注意,$K$ 和 $V$ 只计算一次并被缓存。

StreamBP的HBM开销。公式(1)中的每个分区梯度计算都需要将模型权重 $W$ 从高带宽内存(HBM)加载到寄存器中进行计算,这与标准BP相比会引入额外的开销。同时,StreamBP将注意力掩码的HBM吞吐量减少了大约一半。这个开销直接取决于分区的数量 $D$。在第4.4节中,我们将通过实验研究 $D$ 如何影响BP时间和内存成本。

3.3 分布式StreamBP

尽管StreamBP天然与现代分布式训练技术如分布式数据并行(DDP)和Deepspeed ZeRO【25】兼容,但直接将这些技术应用于StreamBP效率不高,因为存在冗余的梯度和参数通信。为此,我们还开发了分布式的StreamBP。主要设计包括梯度通信和参数通信。分布式StreamBP的详细通信设计和效率分析放在附录A.3中,并在第4.3节中通过实验研究其效率。

A4 实验环境

  • 模型:实验基于Qwen 3模型系列进行评估,包括4B、8B、14B和32B。结果适用于任何因果语言模型,如Llama、Mistral和Gemma。
  • 硬件配置

    • 所有实验均使用A800-80GB GPU进行。
    • 分布式训练在一个单节点服务器上进行,该服务器配备8个通过NvLink连接的A800-80GB GPU。
    • 附录中还提到了在单个RTX3090-24GB GPU上的实验。
  • 软件配置

    • 算法使用Hugging Face Transformers库【30】实现。
    • SFT、GRPO、DPO的实现基于Hugging Face TRL库【26】。
    • 分布式训练使用Deepspeed ZeRO-2方案。
  • 数据类型与精度

    • 模型权重和梯度存储采用BF16数据类型。
    • LoRA模型训练采用混合精度方案。
  • 实验设置细节

    • 反向传播成本测量:批次大小设为1。对于32B模型,由于单张A800 GPU无法同时存储模型和完整梯度,因此注入了rank-32的LoRA适配器,并只计算适配器的梯度。
    • 训练成本测量:批次大小设为1,但GRPO的组大小设为8。LoRA的rank设为32。在使用LoRA模式时,DPO和GRPO无需存储参考模型,因为可以通过禁用训练模型中的适配器来恢复参考模型。
    • StreamBP参数:语言建模头的分区大小在所有实验中均设为100。Transformer层的分区大小在测量最大序列长度时设为500,在测量时间时设为 T/3。

A5 实验结果

4.1 反向传播成本测量

  • 内存成本 (图3):在80GB内存限制下,与带梯度检查点(ckptr)的基线相比,StreamBP能将最大序列长度提升2.8-5.5倍;与不带梯度检查点的基线相比,提升了23.4-36.3倍。重要的是,StreamBP的内存成本与序列长度呈线性关系,这意味着其序列长度的扩展能力可以直接转化为批次大小的扩展以加速训练。
    图3:在不同序列长度下,Qwen 3-8B、14B和32B模型的峰值BP内存成本测量。在80GB内存限制下,与梯度检查点相比,StreamBP将最大序列长度扩大了2.8-5.5倍。
  • 时间成本 (表1):在各种序列长度下,StreamBP的BP时间始终比梯度检查点基线更快,并且比长序列训练基线MsT【20】快得更显著。随着序列长度增加,加速效果更加明显,这与3.2.2节中关于StreamBP减少注意力分数计算量的分析相符。
    表1:不同序列长度下的BP时间成本(秒)。结果基于Qwen 3-4B模型,并对50次独立试验取平均。随着序列长度的增长,StreamBP的加速效果愈发明显,证实了我们在3.2.2节的分析。
  • 批次大小扩展 (表2):StreamBP的内存效率允许使用更大的批次大小,从而进一步加速每个样本的BP速度。
    表2:在不同批次大小下,Qwen 3-4B模型的单样本BP时间成本。序列长度为9000。StreamBP通过使用显著更大的批次大小实现了进一步的加速。

4.2 训练成本测量

  • 序列长度扩展 (表3):对于SFT、GRPO和DPO等所有目标,StreamBP都显著增加了单张A800-80GB GPU上的最大序列长度。例如,对于8B模型的SFT,StreamBP可使用的批次大小是梯度检查点基线的约4.5倍(72/15.7)。
    表3:在单张A800-80GB GPU上的最大序列长度(单位:千)。
  • 与长序列训练基线的比较 (图4, 表1):与长序列SFT基线MsT相比,在训练Qwen 3-8B LoRA模型时,StreamBP实现了大约1.7倍的最大序列长度。在时间效率方面,StreamBP的BP时间显著少于MsT,且随着序列长度增加,加速效果更明显。
    图4:与MsT的内存成本比较。

4.3 分布式训练下的效率

  • 实验设置:在Deepspeed ZeRO-2 SFT方案下测量最大序列长度和BP时间,该方案中梯度和优化器状态被分区到不同GPU上。
  • 结果 (表4, 表5):与带梯度检查点的基线相比,分布式StreamBP将最大序列长度扩大了约5-5.6倍,并实现了明显更快的BP速度。
    表4:在ZeRO-2训练方案下,Qwen 3-8B的最大序列长度(单位:千)。
    表5:在ZeRO-2训练方案下,Qwen 3-8B的单样本BP时间成本(秒)。

4.4 消融研究与额外实验

  • 分区大小的影响 (图5):实验固定transformer层的分区大小为1k、2k、5k和T/3。结果显示,当序列长度较短(<10k)时,不同分区大小的BP时间相近。但随着序列长度增加,使用过小的分区大小会因额外的HBM权重加载和重复的内核启动而引入显著开销。幸运的是,较大的分区大小仅引入边际的额外内存成本。因此,对于长序列训练,可以使用相对较大的分区大小来最大化训练效率。
    图5:Qwen 3-8B上,不同分区大小(ps)下StreamBP的时间和内存成本。
  • 额外实验总结
    1. 梯度正确性验证:实验证明StreamBP与标准BP在数学上是严格等价的(详见附录B.1)。
    2. RTX3090-24GB上的序列扩展:结果显示StreamBP能将最大序列长度扩展到15k,约为梯度检查点的4.4倍(详见附录B.2)。
    3. StreamBP内存剖面:图8中的内存剖面展示了在层重前向传播过程中,logits和激活值的内存被大幅减少,这与第3节的分析一致(详见附录C.2)。

A6 结论

本文开发了一种名为StreamBP的内存高效且精确的反向传播方法。与梯度检查点基线相比,StreamBP通过利用LLM的因果结构,需要显著更少的内存成本,并享有更快的BP时间。StreamBP可用于任何transformer LLM的长序列训练,在训练推理模型等领域有广泛应用。我们还开发了一种通信高效的分布式StreamBP以支持多GPU训练。

局限性。目前,StreamBP不支持MoE或多模态模型。尽管如此,这些可以通过简单的实现扩展来解决,因为其基本原理保持不变。此外,StreamBP的分区大小对BP时间有明显影响。这种开销可以通过采用融合的反向传播算子来减少HBM吞吐量,从而得到缓解。我们将这些方向留作未来的工作。

A7 附录

A StreamBP的附加细节

A.1 StreamBP的线性示例

线性变换。给定输入 $X = [X_1, X_2, \dots, X_D]^\top \in \mathbb{R}^{D \times m}$ 和权重矩阵 $W_1 \in \mathbb{R}^{m \times n}, W_2 \in \mathbb{R}^{n \times k}$。考虑以下两个线性变换:
Y = [Y1, . . . , YD]⊤ := XW1 ∈ RD×n Z = [Z1, . . . , ZD]⊤ := YW2 ∈ RD×k.
为方便表达,定义 $dM := \frac{\partial L}{\partial M}$ 为量 $M$ 相对于目标 $L$ 的梯度。$Y$ 和权重矩阵的梯度由下式给出:

$$d Y=(d Z) W_{2}^{\top}, \quad d W_{1}=X^{\top}(d Y), \quad d W_{2}=Y^{\top}(d Z).$$

StreamBP分解。计算 $dW_1$ 和 $dW_2$ 的标准方法需要存储中间值 $Y$ 和 $dY$。基于公式(1),上述表达式可以写成:

$$d Y_{i}=\left(d Z_{i}\right) W_{2}^{\top}, \quad d W_{1}=\sum_{i=1}^{D} X_{i}^{\top}\left(d Y_{i}\right), \quad d W_{2}=\sum_{i=1}^{D} Y_{i}^{\top}\left(d Z_{i}\right).$$


因此,可以顺序计算 $Y_i$ 和 $dY_i$,累加 $X_i^\top dY_i$ 和 $Y_i^\top dZ_i$,然后丢弃 $Y_i$ 和 $dY_i$。与标准梯度计算相比,这种方法有效地将中间值的内存成本降低到1/D,而没有引入额外的计算成本。在实践中,为了更好地利用并行计算和减少HBM负载,可以按块(chunk-wise)而不是按样本(sample-wise)处理 $X$。

线性示例实验。我们通过具体的线性示例,凭经验测量了StreamBP的内存成本,并与标准BP进行了比较。具体来说,设 $X \in \mathbb{R}^{10^6 \times 16384}, W_1, W_2 \in \mathbb{R}^{16384 \times 16384}, L = \sum_{j,k} Z_{j,k}$。StreamBP将 $X$ 划分为 $\{X^{(i)} | i=1, \dots, D\}$,其中 $X^{(i)} \in \mathbb{R}^{(10^6/D) \times 16384}$。我们在表6中展示了标准BP和StreamBP的内存和时间成本。与标准BP相比,当D=20时,StreamBP将内存成本降低了63.2%,且几乎没有时间开销。随着D的增加,中间内存成本进一步降低。
表6:不同分区数下标准BP和StreamBP的内存和时间成本。

A.2 GRPO的符号表示

GRPO目标函数。我们在此重述GRPO的目标函数:
LGRPO(logits) := E[q∼Dq,{oj}Gj=1∼πold(·|q)] [ 1G 1To G∑j=1 To∑t=1{min (πθ(j, t)πold(j, t) Aˆj,t, clip (πθ(j, t)πold(j, t) , 1 − ϵ, 1 + ϵ) Aˆj,t)− β log πθ(j, t)πref(j, t)} ] .
符号定义。为简化表述,在上述方程中,我们不失一般性地省略了GRPO的KL散度项中的补偿项。StreamBP直接适用于GRPO的KL散度项。这里,$q$ 是提示序列,$D_q$ 是提示的数据集,$o_j$ 是相对于 $q$ 的第 $j$ 个响应序列。$G$ 和 $T_o$ 分别是组大小和响应的长度。为方便表达,我们不失一般性地假设所有响应长度相同。$\hat{A}_{j,t}$ 是估计的优势函数。$\pi_\theta, \pi_{\text{old}}$ 和 $\pi_{\text{ref}}$ 分别是目标策略、旧策略和参考策略。

图6:分布式StreamBP的设计。当对一个transformer层进行反向传播时,其参数会预先被收集(gather)。在该层的StreamBP期间,不触发任何梯度或参数通信。当该层的BP完成后,梯度将在各层间进行reduce操作,参数将被分片到各个GPU上。
图6:分布式StreamBP的设计。当对一个transformer层进行反向传播时,其参数会预先被收集(gather)。在该层的StreamBP期间,不触发任何梯度或参数通信。当该层的BP完成后,梯度将在各层间进行reduce操作,参数将被分片到各个GPU上。

A.3 分布式StreamBP

设计原则。StreamBP天然与Deepspeed ZeRO等分布式训练技术兼容。然而,为确保效率,需要定制梯度通信和参数通信。分布式StreamBP的设计如图6所示。

梯度通信设计。在反向传播过程中,当所有进程的本地梯度缓冲区准备就绪后,会执行梯度平均操作(如all-reduce或reduce-scatter)。如公式(1)所示,StreamBP中一个参数的梯度直到累加操作完成后才准备就绪。为避免冗余的梯度通信,分布式StreamBP的梯度平均操作在累加之后执行。这样,分布式StreamBP的梯度通信成本与标准BP相同。

参数通信设计。当使用模型并行时需要参数通信。例如,ZeRO-3将每个参数均匀地划分到不同GPU上,并在调用与其相关的算子时收集该参数。根据公式(1),StreamBP在反向传播期间需要D次访问未分区的参数。因此,朴素地为StreamBP实现模型并行会产生(D-1)次额外的all-gather操作。为降低通信成本,分布式StreamBP在本地缓存权重,直到其梯度累加完成,从而避免了与标准BP相比额外的参数通信成本。此设计会为每个GPU引入存储单个transformer层的额外内存成本。

分布式StreamBP的效率。批次大小的选择显著影响通信成本。例如,在使用Deepspeed ZeRO-2时,使用批次大小为1和梯度累积步数为K,其反向传播通信成本大约是使用批次大小为K和梯度累积步数为1的K倍。StreamBP能够使用比标准BP大得多的批次大小,从而显著降低了通信成本。我们注意到ZeRO-2通过重叠通信和计算来减少上述开销。

B 额外实验

B.1 StreamBP的正确性验证

浮点数精度问题。我们凭经验测量了StreamBP和基线方法的梯度差异,以验证我们实现的正确性。重要的是,浮点运算存在以下由数值精度引起的结合律问题:

$$(a \oplus b) \oplus c \neq a \oplus (b \oplus c)$$


其中 $\oplus$ 表示浮点加法。因此,考虑到计算顺序不同,梯度差异不可能为零。

验证方法。为了验证正确性,我们使用标准BP的纯FP32计算出的梯度作为基准(ground truth),记为base32。然后,我们分别计算它与StreamBP和标准BP的纯BF16计算得到的梯度之间的差异。此外,我们还进行了一次FP32精度的StreamBP运行,记为stream32,以验证其与base32在数值上的等价性。具体来说,我们定义绝对误差和相对误差为:

$$\text{Er}_{\text{abs}}(g) = \frac{1}{n} \sum\nolimits_i |g_i^{\text{base32}} - g_i|, \quad \text{Er}_{\text{rel}}(g) = \frac{1}{n} \sum\nolimits_i \frac{|g_i^{\text{base32}} - g_i|}{|g_i^{\text{base32}} + 10^{-10}|}.$$

结果分析 (表7)。表7揭示了两个关键发现:1) stream32的微小偏差($Er_{\text{abs}} \sim 10^{-9}$)验证了其与base32的数学等价性;2) stream16与base16相比,相对偏差≤0.04%,表明StreamBP与标准的bfloat16计算相比没有精度损失。
表7:梯度精度分析。加粗值证实了StreamBP在float32下的数值正确性。几乎相同的base16/stream16对表明StreamBP在bfloat16下保持了精度。Transformer层聚合了自注意力、前馈和归一化层的梯度。

B.2 在单个RTX3090-24GB上的序列扩展

实验结果 (图7)。我们在图7中展示了训练Qwen 3-8B LoRA模型的内存成本。在24GB的内存预算下,StreamBP允许序列长度扩展到15k,这比梯度检查点基线大了4.4倍。
图7:在单个RTX3090-24GB GPU上,Qwen 3-8B与LoRA(rank=32)的峰值BP内存成本测量。

C 内存剖面分析

C.1 梯度检查点的内存剖面解释

图1的详细解释。首先,“模型参数”部分存储了BF16模型参数,持续占用近8GB内存。然后,在第一次前向传播过程中,梯度检查点逐渐存储检查点层的输入,因此内存消耗逐渐增加。在第一次前向传播结束和第一次反向传播开始时,计算了FP32 logits及其梯度,由于Qwen 3的词汇表非常大,这会突然消耗大量内存。如图中棕色矩形所示,在计算完lm_head(与embedding层绑定)的梯度后,这部分巨大的内存占用被释放。在第一次反向传播期间,BP计算并存储所有参数的梯度,内存消耗持续增加。一个有趣的现象是,黄色三角形内的检查点层激活值的内存使用量在第一次反向传播过程中减少。这是因为一旦当前权重的梯度被计算出来,存储的相应激活值就会被删除,因为它们不再需要。另一个有趣的观察是,第一次反向传播期间存在一些三角形的凸起,我们在第二次反向传播过程中用一个子图来解释。我们明确展示了梯度检查点的重前向传播过程,并显示了重前向传播的层激活值在相应梯度计算后被删除,从而形成一个三角形的凸起。一个关于图1的补充说明是,当前“Transformers”包的实现在第二次反向传播过程中存储了一个额外的“BF16 logits copy”和一个额外的绑定embedding和“lm_head”的梯度,这些可能可以被优化,因为它们没有被使用。

C.2 StreamBP的内存剖面

StreamBP内存剖面 (图8)。StreamBP的内存剖面如图8所示。与图1中梯度检查点的剖面相比,StreamBP通过沿序列维度划分logits,极大地降低了峰值内存。在重前向传播激活值方面的第二个内存峰值也大大降低,现在仅略高于存储键(key)和值(value)状态的内存。
图8:在与图1相同设置下的StreamBP内存剖面。logits和transformer层的分区大小分别设置为100和500。

D 实验设置

实现。我们使用Hugging Face Transformers库【30】实现我们的算法。详细的实验设置如下。

反向传播成本。我们解耦了优化过程,纯粹关注BP过程中的时间和内存成本。该结果可作为在给定序列长度下训练语言模型的最低要求。具体来说,在批次大小为1的情况下,我们使用单个A800-80GB GPU测量了Qwen 3-8B、14B和32B模型的BP内存成本。对于32B模型,我们注入了rank-32的LoRA适配器,并只计算适配器的梯度,因为单个A800 GPU无法同时存储模型和完整的梯度。我们采用BF16数据类型来存储模型权重和梯度。语言建模头的分区大小在所有实验中均设为100。transformer层的分区大小在测量最大序列长度时设为500,在测量时间时设为T/3。

训练成本。我们分别使用SFT、GRPO和DPO的目标,测量了训练4B、8B、14B和32B模型的最大序列长度。我们的实现基于Hugging Face TRL库【26】。我们在全参数训练中采用纯BF16数据类型,在LoRA模型训练中采用混合精度方案。LoRA的rank设为32。重要的是,当使用LoRA模式时,DPO和GRPO中无需存储参考模型,因为可以通过禁用训练模型中的适配器来恢复参考模型。批次大小设为1,GRPO的组大小设为8。我们还比较了StreamBP与长序列训练基线方法MsT的内存成本,以展示StreamBP的有效性。语言建模头的分区大小在所有实验中均设为100。transformer层的分区大小在测量最大序列长度时设为500,在测量时间时设为T/3。所有实验均在单个A800-80GB GPU上进行。

分布式训练。在附录A.3的通信高效设计下,我们测量了分布式StreamBP在Deepspeed ZeRO-2训练范式下的最大序列长度和时间成本,其中梯度和优化器状态被划分到各个GPU上。我们的评估在一个单节点服务器上进行,该服务器配备8个通过NvLink连接的A800-80GB GPU。