Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention
Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention
作者: Haiquan Qiu, Quanming Yao
机构: Department of Electronic Engineering, Tsinghua University
A1 主要贡献
本文针对一个长期存在且未解决的Transformer训练失败案例——在低精度设置下使用Flash Attention进行训练时会导致灾难性的损失爆炸——提供了首个机理层面的解释。
-
核心问题:尽管低精度数值格式(如BF16)在提升计算效率方面潜力巨大,但在训练Transformer模型,特别是使用Flash Attention时,常常会遇到训练不稳定的问题,具体表现为损失突然爆炸。这一问题在社区中被多次报告,但其根本原因一直不明确,导致解决方案多为经验性的临时修补,缺乏理论指导。
-
研究目标:本文旨在深入剖析这一低精度训练失败的根本原因,建立从数值误差到最终损失爆炸的清晰因果链,并基于此提出一种有理论依据的解决方案。
-
创新点与发现:
-
揭示了失败的双重根源:本文通过深入分析发现,训练失败并非随机伪影,而是由两个相互交织的现象共同导致的:
- 相似低秩表示的出现:在训练过程中,不同训练步骤和不同token的注意力机制内部会涌现出结构相似的低秩表示。
- 有偏舍入误差的累积效应:低精度算术(特别是BF16)固有的舍入误差存在系统性偏差,导致误差朝一个方向持续累积。
-
阐明了误差累积机制:本文论证了有偏的舍入误差如何充当低秩表示的系数,使得这些低秩表示作为一种有偏的梯度更新累加到模型权重上。这种累积效应会异常地增大全连接层权重和激活值的谱范数,最终破坏训练动态,导致损失爆炸。
- 提出了验证性解决方案:为了验证上述分析,本文对Flash Attention算法进行了一个微小的修改,旨在减轻舍入误差的偏差。这个简单的改动成功地稳定了训练过程,不仅证实了本文分析的正确性,也为解决这个顽固问题提供了一个实用的方案。
-
下图展示了本文的分析流程,从追踪训练失败的因果链(蓝色框)逆向识别其根本原因。
A3 背景知识
2.1 低精度训练
低精度训练在现代深度学习中的核心地位。低精度训练是现代深度学习的基石,它通过使用比标准32位单精度(FP32)更少比特的数值格式来表示权重、激活和梯度,从而减少内存使用并加速计算,使得开发越来越大的模型成为可能【索引20,Deepseek-v3 technical report+2024,【索引39,Training llms with mxfp4+2025】,【索引12,Low-precision training of large language models: Methods, challenges, and opportunities+2025】。混合精度训练【索引22,Mixed precision training+2017】在实践中被广泛采用,它将FP16或bfloat16(BF16)等16位格式用于大多数计算,同时保留FP32的主权重副本以维持准确性。尽管FP16提供更高的精度,但其有限的动态范围常导致梯度下溢,需要使用损失缩放等技术。相比之下,最初为谷歌TPU开发并现已得到广泛支持的BF16,提供了与FP32相同的动态范围,使其对下溢问题更具鲁棒性,并成为训练大型语言模型的首选【索引17,A study of bfloat16 for deep learning training+2019】,【索引42,Bfloat16: The secret to high performance on cloud tpus+2019】。然而,BF16降低的精度仍可能引入数值误差,导致训练失败,这也是本文的关注焦点。
Bfloat16 (BF16) 格式及其舍入误差。bfloat16 (BF16) 格式是一种16位浮点表示法,包含1个符号位、8个指数位和7个有效数字位。它与32位单精度(FP32)的动态范围相匹配,但精度较低,使其成为在深度学习中平衡计算和数值范围的热门选择。两个BF16数相加涉及对齐它们的指数、相加有效数字、归一化总和,并将结果四舍五入以适应7位小数。这最后一步的舍入,通常是“四舍五入到最近的偶数”,是误差的主要来源之一。虽然这种舍入方法被设计为对随机数据是无偏的,但对具有特定分布的数据进行一系列操作可能会导致有偏的舍入误差。这种误差在一个方向上的累积是导致在低精度设置中观察到的训练失败的关键因素。关于BF16加法的详细信息见附录B。
2.2 FLASH ATTENTION
Flash Attention的核心思想与内存优化。Flash Attention (FA) 【索引7,FlashAttention: Fast and memory-efficient exact attention with IO-awareness+2022】,【索引6,FlashAttention-2: Faster attention with better parallelism and work partitioning+2024】,【索引37,Flashattention-3: Fast and accurate attention with asynchrony and low-precision+2024】是一种I/O感知的精确注意力算法,旨在克服标准注意力的内存瓶颈。标准注意力定义为 $O = \text{softmax}(\alpha QK^\top)V$,需要实例化 $N \times N$ 的注意力得分矩阵 $S = \alpha QK^\top$,导致相对于序列长度N的内存复杂度为 $O(N^2)$。Flash Attention通过将输入矩阵 $Q, K, V \in \mathbb{R}^{N \times d}$ 划分为块并迭代处理它们,将内存复杂度降低到 $O(N)$。这些块从高带宽内存(HBM)加载到快速的片上SRAM中,从而最小化了昂贵的内存传输。
Flash Attention 2的前向与后向传播机制。在本文中,我们重点分析Flash Attention 2。FA的前向传播使用在线softmax方法(见附录算法2)计算输出O和对数-和-指数统计量L。它通过Q的块(外循环)和K、V的块(内循环)进行迭代。对于每个查询块Qi,它维护运行统计量:最大得分$m_i$和归一化因子$l_i$。在每个内循环步骤中,它计算未归一化的注意力得分$\bar{P}_i^{(j)} = \exp(S_i^{(j)} - m_i^{(j)})$,并通过累加乘积$\bar{P}^{(j)}V$来更新一个未归一化的输出。这个累加过程是我们分析的关键焦点。在遍历所有键/值块之后,最终的输出块Oi被正确归一化。这种分块策略避免了实例化完整的 $N \times N$ 得分矩阵。反向传播(见附录算法3)利用了相同的分块策略。它首先计算一个关键的中间项 $\delta = \text{rowsum}(dO \circ O)$,这是我们研究的核心。然后,它动态地重新计算注意力得分,以计算得分的梯度,$dS_i^{(j)} = P_i^{(j)} \circ (dP_i^{(j)} - \delta_i)$,其中$dP_i^{(j)} = dO_iV_j^\top$。最终的梯度dQ、dK、dV是分块累加的。这种方法在两个传播过程中都保持了I/O效率。
A2 方法细节
3 导致FLASH ATTENTION不稳定的根本原因
我们首先在3.1节介绍失败案例。在3.2节,我们将数值误差的来源缩小到Flash Attention内部。在3.3节的进一步分析揭示,失败源于两个因素的结合:低秩表示的出现和BF16算术固有的有偏舍入误差的累积。这种从根本原因到损失爆炸的完整过程如图1所示。
3.1 低精度FLASH ATTENTION的失败案例
复现一个长期存在的低精度Flash Attention训练失败案例。我们的研究目标是一个有据可查且持续存在的失败案例:在使用Flash Attention以BF16精度训练生成式预训练Transformer 2(GPT-2)模型时发生的灾难性损失爆炸【索引25,Discussion on nanogpt issue #303: Gradient explosion when training with bfloat16+2023】,【索引26,Discussion on nanogpt issue #524: Training loss becomes nan when using bfloat16+2024】,【索引27,Discussion on nanogpt issue #554: Loss diverges when using bfloat16 with flash attention+2024】。这个被报告了两年多的失败案例,表现为在几千个训练步骤后损失突然爆炸(见图8)。虽然像切换回标准注意力或使用更高精度(FP32)这样的经验性解决方法可以稳定训练,但它们是以牺牲效率为代价的。这种不稳定性并非孤立事件;更广泛的社区在训练大型语言模型时也观察到了类似的失败【索引18,Kimi k2: Open agentic intelligence+2025】,【索引32,Qwen3 technical report+2025】。这些失败通常在经验上与权重谱范数过大、激活值过大【索引46,A spectral condition for feature learning+2023】,【索引36,Methods of improving llm training stability+2024】和注意力下沉【索引44,Efficient streaming language models with attention sinks+2023】等现象联系在一起,从而催生了一系列修复方法,包括QK归一化【索引14,Query-key normalization for transformers+2020】,【索引32,Qwen3 technical report+2025】、QK裁剪【索引18,Kimi k2: Open agentic intelligence+2025】和门控注意力【索引31,Gated attention for large language models: Non-linearity, sparsity, and attention-sink-free+2025】,【索引32,Qwen3 technical report+2025】。尽管有这些干预措施,但对根本原因的基本理解仍然难以捉摸。从数值误差到损失爆炸之间缺乏清晰的因果链,使得社区依赖于临时补丁而非原则性解决方案,这阻碍了鲁棒低精度训练的进展。本文通过复现失败、剖析其根本原因并提出一个实用、有原则的解决方案,首次提供了机理性的解释。
实验设置:复现失败案例。为了复现这一失败,我们采用了一个GPT-2架构,包含12个层、12个注意力头、768的嵌入维度和1024的上下文长度。该模型在OpenWebText数据集上进行预训练【索引10,Openwebtext corpus+2019】。为了实现确定性的可复现性,我们没有使用标准的随机数据加载器,而是记录并重用了一次导致失败的初始运行中的数据批次的确切序列。这确保了所有后续实验都以相同的顺序处理相同的数据,从而将失败与数据相关的随机性隔离开来。
训练超参数。我们使用AdamW优化器训练模型,其中$\beta_1 = 0.9$,$\beta_2 = 0.95$,权重衰减为零。学习率遵循余弦退火策略,包含2000次迭代的线性预热期,峰值达到$1 \times 10^{-3}$,然后衰减到$1 \times 10^{-5}$。我们应用了全局梯度裁剪,最大范数为1.0。训练在4块NVIDIA A100 (80GB) GPU上进行,使用了PyTorch的分布式数据并行(DDP)模块。我们使用自动混合精度,前向传播使用BF16,后向传播使用FP32。每个GPU处理的微批次大小为32,经过4步梯度累积,每个优化步骤的有效全局批次大小为524,288个token。
3.2 在FLASH ATTENTION内部隔离失败源头
定位失败源头的实验设计。为了精确定位Flash Attention内部的失败源头,我们进行了一系列有针对性的实验。我们系统地修改了该算法——禁用了分块(tiling)、选择性地用标准实现替换Flash Attention,并在高精度下执行关键计算——以缩小不稳定性的潜在原因范围。为了加速这一分析过程,我们监控了如权重谱范数等经验性指标,以快速识别失败的配置。
分块策略并非失败根源。为了确定Flash Attention中的分块处理是否是失败的原因,我们进行了一项实验,通过将块大小设置为等于序列长度来禁用分块。这迫使算法一次性处理整个矩阵。训练过程仍然失败,导致了同样的损失爆炸。这一发现排除了分块策略是问题原因的可能性。因此,在所有后续实验中,我们都使用这种非分块的设置来简化分析,并专注于核心的数值计算。
失败源于单一层。我们首先分析了所有层权重的谱范数【索引46,A spectral condition for feature learning+2023】,【索引36,Methods of improving llm training stability+2024】。分析显示,第二层注意力内部出现了异常的峰值(见图9)。我们通过两个有针对性的实验证实了这一发现:(1)仅在第2层使用Flash Attention足以复现训练失败;(2)在第2层用标准注意力替换Flash Attention,同时在所有其他层保留它,则恢复了训练的稳定性。这些结果确切地指明第2层的Flash Attention是失败的根源。因此,后续分析将集中于该模块以剖析失败机制。
失败与$\delta$的计算有关。Flash Attention的反向传播过程为了计算效率使用了$\delta = \text{rowsum}(dO \circ O) \in \mathbb{R}^N$这一项。一个数学上等价的替代公式是计算$\delta = \text{rowsum}(dP \circ P)$,其中$dP = dOV^\top$。我们发现,用这个替代公式替换高效的计算方式后,训练恢复了稳定。这个实验表明,在BF16中计算O时引入的数值误差很可能是失败的主要来源,因为避免了训练失败的替代公式等价于使用在FP32中计算的O。
O中的数值误差是失败的根源。基于$\delta$的计算至关重要的发现,我们进一步将误差源头隔离到低精度输出矩阵$O_{lp}$上,即$\delta_{lp} = \text{rowsum}(dO \circ O_{lp})$。我们进行了两个关键实验。首先,我们没有使用前向传播中的低精度O来计算$\delta$,而是在反向传播中将其重新计算为FP32下的$PV$;这一改变稳定了训练。其次,我们发现在前向传播中以高精度(FP32)计算O,即$\delta_{hp} = \text{rowsum}(dO \circ O_{hp})$,同时保持所有其他操作为BF16,也恢复了稳定性。这些证据确凿地证明,在O的BF16计算过程中引入的数值误差是导致失败的直接原因(参考声明1)。
失败定位于特定的注意力头。为了进一步缩小失败的来源,我们通过追踪查询投影矩阵($W_Q$)的谱范数【索引46,A spectral condition for feature learning+2023】,【索引36,Methods of improving llm training stability+2024】来分析单个注意力头,如图3所示。结果显示,少数几个头的谱范数异常大。我们通过选择性地为这些异常头(1, 7, 8, 9, 11和12)以高精度计算输出O来确认它们的作用,这足以恢复训练的稳定性。由于第8个头的谱范数最大,我们后续的分析将集中于这个头,以剖析精确的失败机制。
3.3 训练失败的根本原因
本节的研究揭示了训练失败的两个相互关联的根本原因。在3.3.1节中,我们展示了低精度$\delta_{lp}$如何通过有偏的权重更新导致训练失败,并发现偏差源于相似低秩表示R的出现,其系数$(\delta_{lp} - \delta_{hp})[T]$偏向于正值,导致误差累积而非抵消。在3.3.2节中,我们将这些正系数$(\delta_{lp} - \delta_{hp})[T]$的来源追溯到$P\bar{V}$乘积中BF16加法固有的有偏舍入误差。
3.3.1 原因1:相似的低秩矩阵导致权重更新产生偏差
本节将训练失败追溯到有偏的权重更新。我们首先分别分析了用$\delta_{hp}$和$\delta_{lp}$计算的高精度和低精度梯度之间的差异。然后我们发现低秩表示和有偏的$(\delta_{lp} - \delta_{hp})[T]$导致了损失爆炸。
高低精度梯度误差分析。为了理解数值误差如何传播到梯度中,我们分析了查询矩阵$Q$的高精度(hp)和低精度(lp)梯度$dQ$之间的差异。梯度$dQ$由注意力分数的梯度$dS$计算得出,即$dQ = dSK$。分数的梯度由$dS = \alpha P \circ (dP - \delta)$给出,其中根据3.2节的分析,$\delta = \text{rowsum}(dO \circ O)$是高精度和低精度反向传播之间唯一不同的项,而$\alpha$是注意力中的一个缩放因子。
高精度和低精度查询梯度之间的差异可以推导如下:
$$\begin{aligned} \begin{aligned} d\mathbf{Q}_{hp} - d\mathbf{Q}_{lp} &= (d\mathbf{S}_{hp} - d\mathbf{S}_{lp})\mathbf{K} \\ = (\alpha\mathbf{P} \circ (d\mathbf{P} - \boldsymbol{\delta}_{hp}) &- \alpha\mathbf{P} \circ (d\mathbf{P} - \boldsymbol{\delta}_{lp}))\mathbf{K} = (\alpha\mathbf{P} \circ (\boldsymbol{\delta}_{lp} - \boldsymbol{\delta}_{hp}))\mathbf{K} \\ = \alpha \cdot \text{diag}(\boldsymbol{\delta}_{lp} &- \boldsymbol{\delta}_{hp})(\mathbf{PK}). \end{aligned} \end{aligned}$$在最后一步,我们将逐行缩放操作 $P \circ (\delta_{lp} - \delta_{hp})$(遵循广播规则)表示为矩阵乘法,其中 $\text{diag}(\delta_{lp} - \delta_{hp})$ 是一个对角矩阵,其对角线元素是向量差 $\delta_{lp} - \delta_{hp}$ 的元素。这个公式表明,梯度误差与$\delta$中的误差成正比,并受到$PK$项的调节。
查询投影矩阵的梯度误差。查询投影矩阵$W_Q$的梯度$dW_Q$由输入特征$X$和查询梯度$dQ$的外积给出。$W_Q$的高精度(hp)和低精度(lp)梯度之间的差异可以表示为:
$$\begin{aligned} \begin{aligned} d\mathbf{W}_{hp}^Q - d\mathbf{W}_{lp}^Q &= (d\mathbf{Q}_{hp} - d\mathbf{Q}_{lp})^\top\mathbf{X} = \alpha(\mathbf{PK})^\top\mathrm{diag}(\boldsymbol{\delta}_{lp} - \boldsymbol{\delta}_{hp})\mathbf{X}, \\ &= \alpha\sum\nolimits_{T=1}^N(\boldsymbol{\delta}_{lp} - \boldsymbol{\delta}_{hp})[T] \cdot (\mathbf{PK})[T]^\top\mathbf{X}[T], \end{aligned} \end{aligned}$$其中$(PK)[T]$和$X[T]$是它们各自矩阵的第T行向量。这个方程表明,总梯度误差是秩为1的矩阵的加权和,权重由$\delta$中的误差给出。
相似的低秩更新导致训练失败。在图4中,$PK$的行(面板a, d)和$X$的行(面板b, e)在不同的训练步骤和token位置上表现出很强的结构相似性。这意味着最终得到的秩-1矩阵 $(PK)[T]^\top X[T]$ 彼此之间也高度相似。例如,图4(面板c, f)分别展示了在训练步骤6610和6619时,对于token 50和718的这种相似性。由于这些秩-1误差分量在结构上是一致的,我们可以将总梯度差异近似为
其中$R$表示在不同token和训练步骤中出现的共同低秩结构。
误差累积的根源:有偏的系数。公式(3)显示,低秩误差方向R的累积受标量项$\sum_{T=1}^{N}(\delta_{lp} - \delta_{hp})[T]$的控制。如果这个和偏向于非零值,那么不同训练步骤之间的误差将会累积而不是相互抵消。我们追踪了在训练失败前一系列训练步骤(6580到6680)中$\sum_{T=1}^{N}(\delta_{lp} - \delta_{hp})[T]$的累积和,如图5(a)所示。该图显示这个和始终为正,表明存在系统性偏差。这种偏差导致在低秩方向R上的误差在每个训练步骤中不断复合。由于R在不同步骤中也是相似的,这最终会破坏权重更新,增加谱范数(图9)和激活值【索引46,A spectral condition for feature learning+2023】,【索引36,Methods of improving llm training stability+2024】,并导致训练失败(参考声明2)。下一节将通过分析训练步骤6619时的权重和梯度来找出这个正偏差的根本原因,该步骤是图5(a)中确定的一个显著正贡献点。
3.3.2 原因2:有偏的舍入误差导致$(\delta_{lp} - \delta_{hp})[T]$为正值
本节研究了$(\delta_{lp} - \delta_{hp})[T]$中正偏差的来源。我们将此误差追溯到$dO$与$O_{lp} - O_{hp}$中数值差异的相互作用,而后者本身源于$P\bar{V}$计算过程中BF16加法中的有偏舍入误差。
定位$(\delta_{lp} - \delta_{hp})[T]$中的大误差。我们首先研究$\sum_{T=1}^{N}(\delta_{lp} - \delta_{hp})[T]$中正偏差的来源。如3.2节所分析,$\delta$中的误差源于上游梯度$dO$与低精度输出中数值误差$O_{lp} - O_{hp}$的乘积。为了剖析这一点,我们关注一个误差分量$(\delta_{lp} - \delta_{hp})[T]$为正的token位置$T = 718$。
误差来源:dO与输出误差的符号相关性。在图5(b)和(c)中,我们观察到梯度$dO[T, :]$与输出误差$O_{lp}[T, :] - O_{hp}[T, :]$之间存在强烈的符号相关性,特别是在特征维度20和29(在其他token中也观察到此现象)。在这些维度中,$dO$和输出误差$O_{lp} - O_{hp}$都持续为负。这种一致性确保了它们的乘积(对$\delta$中的误差有贡献)是正的。输出误差倾向于为负($O_{lp}[T, i] < O_{hp}[T, i]$)这一事实表明,O的低精度计算系统性地偏向于更负的值。因此,我们接下来的分析重点是找出这种计算偏差的来源。
偏差源于未归一化输出$\bar{O}$的计算。输出$O$是通过一个中间的未归一化输出$\bar{O}$计算得出的。该计算涉及一个安全的softmax,然后是矩阵乘法和归一化:
$$ \bar{\mathbf{P}} = \exp(\mathbf{S} - \text{rowmax}(\mathbf{S})), \qquad \bar{\mathbf{O}} = \bar{\mathbf{P}}\mathbf{V}, \qquad \mathbf{O} = \bar{\mathbf{O}} / \text{rowsum}(\bar{\mathbf{P}}). $$进一步的实验将失败的根源精确定位到未归一化输出$\bar{O} = P\bar{V}$的计算上。我们发现,仅将这个乘积以FP32计算就足以稳定训练。为了理解这个偏差的来源,我们检查了单个元素$\bar{O}[T, i]$(在我们的分析中,特征索引$i = 20$)的低精度和高精度计算之间的差异:
$$\mathbf{\bar{O}}_{lp}[T, i] - \mathbf{\bar{O}}_{hp}[T, i] = (\mathbf{\bar{P}}_{lp}[T, :]\mathbf{V}[:, i])_{lp} - (\mathbf{\bar{P}}_{hp}[T, :]\mathbf{V}[:, i])_{hp}$$这里的输入$\bar{P}$和$V$本身是先前BF16运算的结果。具体来说,下标 $(\cdot)_{lp}$ 表示使用FP32计算点积,最终结果四舍五入到BF16,而 $(\cdot)_{hp}$ 表示完全在FP32中计算。
误差累积于$\bar{P}[T, t]=1$的token位置。为了理解误差$\bar{O}_{lp}[T, i] - \bar{O}_{hp}[T, i]$如何系统性地变为负值,我们在图6(b)和(c)中绘制了随着token位置求和进展的累积误差:
$$\mathbf{\bar{O}}_{\text{error}}(t) = \left( \sum\nolimits_{t'=1}^{t} \mathbf{\bar{P}}[T, t'] \mathbf{V}[t', i] \right)_{lp} - \left( \sum\nolimits_{t'=1}^{t} \mathbf{\bar{P}}[T, t'] \mathbf{V}[t', i] \right)_{hp}$$该图显示误差以显著的负阶跃方式累积。这些阶跃发生在对应注意力概率$\bar{P}[T, t]$恰好为1的token位置t(在其他token位置也观察到此现象)。当预softmax得分$S[T, t]$是其所在行的最大值时,就会发生这种情况,导致$\exp(S[T, t] - \max(S[T, :]))$计算为$\exp(0) = 1$。
有偏舍入误差分析。此外,偏差源于这些单位值与值矩阵V的分布之间的相互作用。如图6(a)所示,对于有问题的特征维度i = 20,V[:, i]的值主要为负。当$\bar{P}[T, t] = 1$时,乘积$\bar{P}[T, t]V[t, i]$就是$V[t, i]$,一个负的BF16数。当两个这样的负BF16数相加时,就会出现系统性误差。在浮点运算中,两个同号数相加可能导致得到的有效数字溢出(例如,-1.xxxx + -1.yyyy = -10.zzzz),需要右移并增加指数以重新归一化。被移出7位BF16小数部分的位决定了舍入方向。当两个负数相加时,舍入操作(例如,四舍五入到最近值)会引入一致的偏差。
舍入偏差的机制。为了说明这种舍入偏差是如何发生的,考虑两个导致溢出的有效数字相加,需要右移进行归一化。被移出的位(舍入位)决定了舍入方向。我们展示了最后两位2位数字所有可能的加法,其中绿色位代表将被移出的位(舍入位):
$$\begin{aligned} \begin{array}{lllll} 00+00=00 & 00+01=01 & 00+10=10 & 00+11=11 & 01+01=10 \\ 01+10=11 & 01+11=100 & 10+10=100 & 10+11=101 & 11+11=110 \end{array} \end{aligned}$$由于求和是在FP32中执行的,小数的低位比特累积可以激活粘滞位(sticky bit)。因此,当后续的BF16数被加上时,这会强制向上舍入。因此,舍入位为1表示需要向上舍入。因为操作数$\bar{P}[T, t]V[t, i]$是负数且指数较大(因为$\bar{P}[T, t] = 1$并不会使$\bar{P}[T, t]V[t, i]$变小),向上舍入的误差被放大,导致一个负误差。当舍入位为0时,结果向下舍入,引入一个正误差。此外,与向上舍入产生的负误差相比,正误差更小,因为相加的值通常非常小。这种不对称性导致舍入误差由向上舍入主导,从而产生我们在分析中观察到的负舍入误差。$\bar{O}$计算中的这种系统性负偏差是训练失败的最终根源。
备注 1。对于$\bar{P}[T, t] < 1$,其与$V[t, i]$的乘积的最低16位不为零。当舍入到BF16时,这不会引入有偏的舍入误差。
具体案例分析:BF16加法中的舍入误差。为了具体说明,我们现在分析导致图6(c)中巨大负误差跳变的特定BF16数值加法。由于求和是在FP32中进行的,第一个BF16值在第二个BF16值加入之前,会与来自其他token的一些小值相加。这可以激活粘滞位,从而在添加第二个BF16值时强制向上舍入。对于这个例子,我们从FP32表示开始。第一个操作数是先前项的累积和,其中包括一个BF16值(11000000000110100000000000000000; -2.40625)加上一个激活了粘滞位的小残差(∼ 0.00087)。第二个操作数是另一个BF16值。它们的FP32表示为:
$$\begin{aligned} \begin{aligned} & 110000000000110100000111000101110 \ (-2.4071154594421387) \\ & 110000000000100110000000000000000 \ (-2.296875) \end{aligned} \end{aligned}$$由于它们的指数相同,加法在它们的有效数字上进行:
$$\begin{aligned} \begin{aligned} & (-1.00110100000111000101110) + (-1.00100110000000000000000) \\ = & -10.01011010000111000101110 \end{aligned} \end{aligned}$$结果溢出了有效数字的格式,需要进行归一化。有效数字右移一位,指数加一:
有效数字: 10.0101101 → 1.00101101 指数: 10000000 → 10000001
在FP32中的精确结果是 1 10000001 00101101000011100010111,对应于-4.703990459442139。为了将此结果存储为BF16,必须将其舍入到7个小数位。有效数字是1.00101101。7位小数部分是0010110。因为舍入位是1,并且舍入位之后还有非零位,所以“四舍五入到最近的偶数”规则会向上舍入(即在小数的最后一位加1):
最终的BF16结果是 1100000010010111,代表-4.71875。这个舍入后的值比真实和-4.703990459442139更负。这次加法引入的误差是-0.014759540557861328。当这种舍入事件在$P\bar{V}$乘积的许多加法中系统性地发生时,误差会累积,在O中产生负偏差,最终导致训练不稳定。
4 实验:减轻舍入误差中的偏差
解决方案:修改Softmax以缓解舍入偏差。我们的分析将失败追溯到$P\bar{V}$计算中的有偏舍入误差。当预softmax分数$S$的某一行中存在多个相同的最大值,导致$\bar{P}$中相应的元素变为1时,就会发生这种情况。为了验证我们的发现,我们修改了softmax,以检测这种特定情况并调整归一化,确保$\bar{P}$的所有元素都严格小于1。这可以防止有偏舍入并恢复训练的稳定性。
为了防止有偏舍入,我们对安全softmax计算引入了一个有针对性的修改。核心思想是,仅当得分矩阵S的一行包含多个相同的最大值时,才动态调整归一化因子m。这种调整确保了指数函数的参数$S - m$在这些最大值位置上变为严格负值,从而保证了$\bar{P} = \exp(S - m)$的所有元素都小于1。一个简单的方法,比如减去一个小的固定常数,是不够的,因为它会引入新的系统性舍入误差(见附录C);因此,需要一种动态最大值策略。我们的修改如下所示。
$$\begin{aligned} \begin{aligned} \mathbf{r}_m &= \text{rowmax}(\mathbf{S}), \quad \mathbf{r}_s = \text{rowsum}(\mathbf{S} \equiv \mathbf{r}_m) \\ \mathbf{m}' &= \text{where}(\mathbf{r}_m > 0 \wedge \mathbf{r}_s > 1, \beta \mathbf{r}_m, \mathbf{r}_m), \beta > 1 \\ \mathbf{m} &= \text{where}(\mathbf{r}_m < 0 \wedge \mathbf{r}_s > 1, 0, \mathbf{m}') \\ \bar{\mathbf{P}} &= \exp(\mathbf{S} - \mathbf{m}) \end{aligned} \end{aligned}$$动态最大值策略的细节。这种修改阻止了$\bar{P}$的元素变为1。如果一行的最大值$r_m$为正且重复出现,归一化因子将被调整为$m = \beta r_m$(其中$\beta > 1$)。这使得指数中的新最大值变为$-(\beta - 1)r_m$,这是一个严格的负数。如果$r_m$为负且重复出现,我们设置$m = 0$,这也确保了指数的最大值保持为负。在这两种情况下,这种调整都保证了$\max(S - m) < 0$,因此$\max(\bar{P}) < 1$,从而防止了导致有偏舍入的条件。
方法的有效性和稳定性。至关重要的是,这种修改在精确算术中与标准注意力是数学等价的,因为它利用了softmax函数的移位不变性($\text{softmax}(z) = \text{softmax}(z - c)$)。我们的方法只是选择一个不同的行常数c来确保数值稳定性。在我们的实验中,我们设置$\beta \in [2, 8]$,因为较小的值有风险使结果舍入回1,而较大的值有风险导致下溢。此修改被集成到标准的Flash Attention分块算法中(算法1中用洋红色标出的行),而无需改变反向传播过程。如图7所示,这个简单的改变($\beta = 7$)成功地稳定了训练,证实了我们的分析。更多设计细节请参见附录C。
A4 实验环境
- 数据集: 使用 OpenWebText 数据集【索引10,Openwebtext corpus+2019】进行预训练。为了确保确定性和可复现性,实验中使用了固定的数据批次序列。
-
模型架构: 采用 GPT-2 架构,具体参数如下:
- 层数:12
- 注意力头数:12
- 嵌入维度:768
- 上下文长度:1024
-
硬件配置:
- GPU: 4 x NVIDIA A100 (80GB)
- 连接: GPU之间通过分布式数据并行(DDP)模块连接。
-
软件配置:
- 代码实现: 基于 PyTorch 实现。
- 依赖库: 使用了PyTorch的自动混合精度(AMP)和分布式数据并行(DDP)模块。
- 训练框架: 实验复现了 nanoGPT 社区报告的失败案例。
- 精度设置: 前向传播使用 BF16,后向传播使用 FP32。
- 优化器: AdamW,参数为 $\beta_1 = 0.9$, $\beta_2 = 0.95$,权重衰减为0。
- 学习率: 采用余弦退火策略,峰值为 $1 \times 10^{-3}$,预热2000步,最终衰减至 $1 \times 10^{-5}$。
- 批处理: 每个GPU的微批次大小为32,梯度累积步数为4,有效全局批次大小为524,288个tokens。
A4 实验结果
-
实验一:复现低精度Flash Attention训练失败
- 实验内容:按照3.1节所述配置,使用BF16精度和Flash Attention训练GPT-2模型。
- 实验结果:训练在约1万步后出现损失突然爆炸,验证损失急剧上升并稳定在一个高水平,无法收敛。这成功复现了社区报告的失败案例(图2,红色曲线)。作为对比,使用高精度Flash Attention的训练过程则保持稳定,损失平稳下降(图2,蓝色曲线)。
-
实验二:隔离失败源头
-
实验内容:通过一系列控制变量实验来定位问题。
- 禁用分块(Tiling):将Flash Attention的块大小设为序列长度。
- 逐层替换:仅在第2层使用Flash Attention,或仅在第2层使用标准Attention。
- 修改$\delta$计算:将$\delta = \text{rowsum}(dO \circ O)$替换为数学等价的$\delta = \text{rowsum}(dP \circ P)$。
- 高精度计算关键部分:仅将$O$或$P\bar{V}$的计算提升到FP32精度。
-
实验结果:
- 禁用分块后训练仍然失败,排除了分块策略是原因。
- 谱范数分析(图9)和逐层替换实验均定位到第2层的注意力模块是失败的根源。
- 修改$\delta$的计算方式后训练恢复稳定,表明问题出在低精度$O$的计算上。
- 仅用FP32计算$O$或$P\bar{V}$就足以稳定训练,最终将问题锁定在$P\bar{V}$乘积的BF16计算上。
-
-
实验三:分析失败的根本原因
- 实验内容:深入分析第2层注意力模块中特定头(头8)在训练失败前(如第6619步)的中间变量。
- 实验结果与分析:
- 低秩结构:发现矩阵$PK$和输入$X$在不同token和训练步间存在相似的结构(图4),导致梯度误差累积在一个固定的低秩方向$R$上。
- 有偏系数:梯度误差的系数$(\delta_{lp}-\delta_{hp})$的和持续为正(图5a),导致误差无法抵消,只能累积。
- 有偏舍入:正偏差源于$O$的计算中存在负向偏差。当注意力概率$\bar{P}$的某些元素为1,且对应$V$的元素为负时,$P\bar{V}$的点积运算中BF16加法会因溢出和粘滞位(sticky bit)导致系统性的向上舍入(即向更负的方向舍入),产生负的累积误差(图6b, c)。
-
实验四:验证解决方案
- 实验内容:在Flash Attention的softmax计算中引入一个微小的动态修改。当检测到一行分数中存在多个相同的最大值时,动态调整归一化因子,以确保$\bar{P}$的所有元素都严格小于1,从而避免触发有偏舍入的条件。
- 实验结果:应用了该修改的Flash Attention(Stabilized Flash Attention)在相同的BF16设置下成功完成了训练,损失平稳下降,没有出现爆炸现象(图7,蓝色曲线)。这证实了我们对失败机理的分析是正确的。
A5 结论
本文首次对低精度Flash Attention训练中一个臭名昭著的损失爆炸问题提供了机理层面的解释。我们精确地指出,其根本原因在于 emergent low-rank representations(涌现的低秩表示)和 biased BF16 rounding errors(有偏的BF16舍入误差)之间的相互作用。通过对Flash Attention进行一个微小且有针对性的修改,我们成功恢复了训练的稳定性,从而验证了我们的分析。我们的分析工作流程为诊断其他架构、规模和低精度格式中类似的数值不稳定性问题提供了一个蓝图,为更鲁棒和高效的大规模模型训练铺平了道路。
讨论
我们的发现在多种硬件(NVIDIA A100、RTX 4090、华为昇腾910B)上都具有一致性,并从机理上解释了训练不稳定性的经验观察。权重谱范数的增长是由于梯度中低秩误差矩阵的累积。我们还阐明了注意力汇点(attention sinks)的作用:通过吸引高的注意力分数,它们更有可能产生为1的注意力概率,这会触发$P\bar{V}$计算中的有偏舍入误差。这为汇点的架构行为与破坏训练的算术不稳定性之间提供了直接的数值联系。
局限性
我们的分析集中在一个GPT-2模型的特定失败案例上。我们的发现对其他架构、更大规模或不同低精度格式(如FP8)的普适性需要进一步研究。此外,我们提出的缓解措施是针对已识别的特定舍入误差量身定制的,可能无法解决其他来源的数值不稳定性问题。
A6 附录
A 相关工作
A.1 混合精度BF16训练
混合精度训练的演进与BF16的普及。当代大型语言模型(LLM)的预训练几乎普遍采用混合精度算术。Micikevicius等人(2017)的早期工作表明,FP16训练——使用FP32主权重副本和固定的损失缩放——可以在许多模型上达到与FP32相当的准确率。然而,FP16狭窄的指数范围常导致许多梯度下溢,需要仔细调整。bfloat16(BF16)格式,拥有8位指数和7位尾数,保留了FP32的宽动态范围,同时将存储成本减半。Kalamkar等人(2019)证明,BF16在大型模型上无需特殊调整即可实现与FP32相当的收敛性。自此,BF16已成为许多大规模训练框架的默认16位格式,并在PyTorch和TensorFlow中获得原生支持【索引42,Bfloat16: The secret to high performance on cloud tpus+2019】。
BF16在里程碑式LLM中的应用。BF16混合精度使得在空前规模上训练里程碑式的LLM成为可能,包括GPT-3 (175B)【索引4,Language models are few-shot learners+2020】、Google PaLM (540B)【索引5,Palm: Scaling language modeling with pathways+2023】、DeepMind Gopher (280B)【索引34,Scaling language models: Methods, analysis & insights from training gopher+2021】、Chinchilla (70B)【索引15,Training compute-optimal large language models+2022】以及Meta的LLaMA家族 (7B-65B)【索引38,Llama: Open and ` efficient foundation language models+2023】。为了处理巨大的内存占用,像Megatron和DeepSpeed这样的并行训练框架将BF16训练与零冗余优化器(ZeRO)【索引35,Zero: Memory optimizations toward training trillion parameter models+2020】等技术相结合。
BF16训练的稳定性挑战。尽管BF16具有优势,但其训练仍可能表现出不稳定性。经验研究表明,即使有损失缩放,FP16也极不稳定,而BF16则消除了大部分与精度相关的调整和失败【索引42,Bfloat16: The secret to high performance on cloud tpus+2019】。然而,Lee等人(2024)报告称,在纯BF16下,约有10%的GPT-2预训练运行发散,而在TF32下为0%。这表明,虽然BF16显著提高了稳定性,但在大规模训练中,补充性的稳定技术仍然是必要的。
A.2 稳定低精度训练
梯度缩放。Micikevicius等人(2017)的早期工作引入了FP16混合精度训练,其中权重、激活和梯度以半精度存储,同时维护一个FP32主副本。他们还提出了损失缩放来防止FP16下溢。即使有损失缩放,在深度网络中仍可能发生一些下溢。为解决此问题,Zhao等人(2021)引入了梯度缩放,该方法动态计算每层的缩放因子以避免下溢和上溢。
超低精度(FP8/INT8)训练。为了进一步降低成本,近期的工作探索了使用FP8或INT8精度进行训练和推理。然而,直接的FP8训练容易发散。Lee等人(2024)指出,若无额外的稳定技术,将FP8直接应用于LLM训练是不稳定的。为解决此问题,Perez等人(2023)提出了对FP8矩阵乘法使用动态调整的逐张量缩放因子。使用此方案,他们成功地以完全FP8的方式训练了高达70B参数的GPT和LLaMA风格模型。类似地,Peng等人(2023)引入了FP8-LM,一个逐步将FP8应用于梯度、优化器状态和分布式通信的框架,与BF16相比,实现了39%的内存减少和75%的加速。Balanc¸a等人(2024)提出了SCALIFY,它在整个计算图中传播缩放因子,以确保稳定的FP8操作,无需手动调整。这些方法共同证明了,通过仔细的缩放管理,可以使FP8或INT8训练达到与BF16相当的性能,同时减少内存和计算需求。
优化器与梯度稳定。优化器算法在训练稳定性中扮演着关键角色。Molybog等人(2023)从理论上分析了Adam,并表明灾难性发散通常发生在更新方向与大规模模型中的真实下降方向不相关时。为了解决梯度不稳定性,Huang等人(2025)提出了SPAM(带冲量重置的尖峰感知Adam),它通过重置冲量和应用尖峰感知裁剪来检测和缓解罕见但严重的“梯度尖峰”。同时,Wortsman等人(2023)研究了视觉-语言模型中的损失尖峰,并表明AdamW在尖峰出现前常低估二阶矩。他们提出了一种混合的AdamW-AdaFactor优化器,能自适应地纠正二阶矩的低估,性能优于单独的梯度裁剪。这些方法凸显了优化器修改如何直接缓解低精度环境下的发散问题。
激活与架构技术。激活函数的选择和初始化策略也会影响稳定性。Fishman等人(2024)观察到SwiGLU激活在长时间的FP8训练中会放大了异常值。他们引入了Smooth-SwiGLU,一种改进的激活函数,能防止异常值放大,从而实现了稳定的万亿token级FP8训练。在视觉-语言领域,Wortsman等人(2023)表明,“层级缩放为零”的初始化和精心设计的低精度线性层(例如SwitchBack)能进一步提高int8训练的稳定性。
B BF16加法
bfloat16(Brain Floating-Point)格式是一种16位浮点表示法,因其在计算效率和数值范围之间的平衡而在深度学习中被广泛使用。它由1个符号位、8个指数位和7个小数(或尾数)位组成。这种结构使bfloat16具有与32位单精度格式(FP32)相同的动态范围,但精度显著降低。
两个bfloat16数(例如a和b)的加法遵循浮点运算的标准程序:
- 指数对齐:比较两个数的指数。指数较小的数,其有效数字(隐含的前导1位和分数位的组合)会向右移动,直到其指数与较大的指数匹配。每次右移会使指数加一。移出可用精度范围的位会丢失,这是误差的初始来源。
- 有效数字相加:对齐后的有效数字相加。结果的符号由操作数的符号和大小决定。
- 归一化:对结果进行归一化,以确保其符合1.xxxx... × 2^e的格式。如果加法导致溢出(例如,10.xxxx...),则有效数字右移,指数递增。如果导致抵消(例如,0.00xx...),则有效数字左移,指数递减,直到前导位为1。
- 舍入:归一化后,结果的有效数字可能超过7个小数位,必须进行舍入。标准模式是“舍入到最近,偶数优先”。这意味着如果被截断的部分大于最后一个可存储位(LSB)值的一半,则向上舍入。如果小于,则向下舍入。如果恰好是一半,则舍入到LSB为偶数的最近值。
数值误差的关键来源来自步骤1和4。在指数对齐期间,数量级较小的数会丢失精度。加法之后,结果必须舍入回7位小数,这又引入了另一个舍入误差。虽然“舍入到最近,偶数优先”被设计为对随机数据是无偏的,但对具有特定分布的数据(例如,大部分是负数相加)进行一系列加法可能会导致有偏的舍入误差,其中累积的误差会持续将结果推向一个方向。这种有偏误差的累积是低精度设置中观察到的训练失败的关键因素。
C 减轻FLASH ATTENTION中有偏舍入误差的设计考量
使用动态最大值而非固定偏移量。固定的偏移量会导致计算出的$\bar{P}$值在BF16转换过程中持续向一个方向舍入,从而引入一个固定的误差。由于V的元素通常符号相同,这个固定的舍入误差在计算$P\bar{V}$时不会平均为零。这会导致输出O中出现有偏误差,进而产生一个有偏的$\delta$项,重新引入了我们试图解决的失败问题。
有条件地应用动态最大值。我们的修改是有条件地应用的——仅当一行包含多个相同的最大值时——以避免引入新的数值不稳定性。无条件的调整并非更好的选择。例如,如果一行有一个单一且非常大的正最大值$r_m$,应用我们的规则将意味着计算$\exp(S - \beta r_m)$。指数中的最大项将变为$-(\beta - 1)r_m$,而$\exp(-(\beta - 1)r_m)$可能会下溢为零。这将导致归一化因子变为零,在计算输出O时导致除零错误。通过仅在导致有偏舍入的特定情况下应用修改,我们在所有其他场景中都保留了标准在线softmax的数值稳定性。
关于处理负数重复行最大值的解释。我们还探索了针对负数、重复行最大值($r_m < 0$)的其他稳定方法。一种方法是为某个$\gamma \in (0, 1)$设置归一化因子$m = \gamma r_m$。这使得指数中的新最大值为$(1 - \gamma)r_m$。然而,我们观察到,如果$\gamma$接近1,这个新最大值会趋近于零。在低精度算术中,$\exp((1 - \gamma)r_m)$可以舍入为1,从而重新引入我们试图防止的失败。因此,我们发现设置$\gamma = 0$(即$m = 0$)是一个稳健的选择,因为它能确保指数中的最大值保持足够负。
算法1:通过减轻有偏舍入误差来稳定Flash Attention:前向传播
要求: 矩阵 Q, K, V ∈ R^(N×d),块大小 Bc, Br, β > 1。
1: 将 Q 分成 Tr = ⌈N/Br⌉ 个块 Q1, ..., QTr,每个大小为 Br × d,并将 K, V 分成
Tc = ⌈N/Bc⌉ 个块 K1, ..., KTc 和 V1, ..., VTc,每个大小为 Bc × d。
2: 将输出 O ∈ R^(N×d) 分成 Tr 个块 O1, ..., OTr,每个大小为 Br × d,并将
logsumexp L 分成 Tr 个块 L1, ..., LTr,每个大小为 Br。
3: for 1 ≤ i ≤ Tr do
4: 初始化 O(0)i = (0)Br×d ∈ R^(Br×d), ℓ(0)i = (0)Br ∈ R^Br, m(0)i = (−∞)Br ∈ R^Br。
5: for 1 ≤ j ≤ Tc do
6: 计算 S(j)i = Qi * K_j^T ∈ R^(Br×Bc)。
7: rm = rowmax(S(j)i), rs = rowsum(S(j)i ≡ rm)
8: m(j)′i = where(rm > 0 ∧ rs > 1, βrm, rm)
9: m(j)i = where(rm < 0 ∧ rs > 1, 0, m(j)′i)
10: 计算 m(j)i = max(m(j−1)i, rm) ∈ R^Br, P¯(j)i = exp(S(j)i − m(j)i) ∈ R^(Br×Bc) (逐点), ℓ(j)i = diag(e^(m(j-1)i - m(j)i))ℓ(j-1)i + rowsum(P¯(j)i) ∈ R^Br
11: 计算 O(j)i = diag(e^(m(j-1)i - m(j)i))O(j-1)i + P¯(j)i * Vj。
12: end for
13: 计算 Oi = diag(ℓ(Tc)i)^(-1) * O(Tc)i。
14: 计算 Li = m(Tc)i + log(ℓ(Tc)i)。
15: 将 Oi 写入 O 的第 i 个块。
16: 将 Li 写入 L 的第 i 个块。
17: end for
18: 返回输出 O 和 logsumexp L。
算法2:Flash Attention:前向传播
要求: 矩阵 Q, K, V ∈ R^(N×d),块大小 Bc, Br。
1: 将 Q 分成 Tr = ⌈N/Br⌉ 个块 Q1, ..., QTr,每个大小为 Br × d,并将 K, V 分成 Tc = ⌈N/Bc⌉ 个块 K1, ..., KTc 和 V1, ..., VTc,每个大小为 Bc × d。
2: 将输出 O ∈ R^(N×d) 分成 Tr 个块 O1, ..., OTr,每个大小为 Br × d,并将 logsumexp L 分成 Tr 个块 L1, ..., LTr,每个大小为 Br。
3: for 1 ≤ i ≤ Tr do
4: 初始化 O(0)i = (0)Br×d ∈ R^(Br×d), ℓ(0)i = (0)Br ∈ R^Br, m(0)i = (−∞)Br ∈ R^Br。
5: for 1 ≤ j ≤ Tc do
6: 计算 S(j)i = Qi * K_j^T ∈ R^(Br×Bc)。
7: 计算 m(j)i = max(m(j-1)i, rowmax(S(j)i)) ∈ R^Br, P¯(j)i = exp(S(j)i − m(j)i) ∈ R^(Br×Bc) (逐点), ℓ(j)i = diag(e^(m(j-1)i - m(j)i))ℓ(j-1)i + rowsum(P¯(j)i) ∈ R^Br。
8: 计算 O(j)i = diag(e^(m(j-1)i - m(j)i))O(j-1)i + P¯(j)i * Vj。
9: end for
10: 计算 Oi = diag(ℓ(Tc)i)^(-1) * O(Tc)i。
11: 计算 Li = m(Tc)i + log(ℓ(Tc)i)。
12: 将 Oi 写入 O 的第 i 个块。
13: 将 Li 写入 L 的第 i 个块。
14: end for
15: 返回输出 O 和 logsumexp L。
算法3:Flash Attention:后向传播
要求: 矩阵 Q, K, V, O, dO ∈ R^(N×d),向量 L ∈ R^N,块大小 Bc, Br。
1: 将 Q 分成 Tr = ⌈N/Br⌉ 个块 Q1, ..., QTr,每个大小为 Br × d,并将 K, V 分成 Tc = ⌈N/Bc⌉ 个块 K1, ..., KTc 和 V1, ..., VTc,每个大小为 Bc × d。
2: 将 O 分成 Tr 个块 O1, ..., OTr,每个大小为 Br × d,将 dO 分成 Tr 个块 dO1, ..., dOTr,每个大小为 Br × d,并将 L 分成 Tr 个块 L1, ..., LTr,每个大小为 Br。
3: 初始化 dQ = (0)^(N×d) 并将其分成 Tr 个块 dQ1, ..., dQTr,每个大小为 Br × d。将 dK, dV ∈ R^(N×d) 分成 Tc 个块 dK1, ..., dKTc 和 dV1, ..., dVTc,每个大小为 Bc × d。
4: 计算 δ = rowsum(dO ◦ O) ∈ R^N (逐点相乘),并将其分成 Tr 个块 δ1, ..., δTr,每个大小为 Br。
5: for 1 ≤ j ≤ Tc do
6: 初始化 dKj = (0)^(Bc×d), dVj = (0)^(Bc×d)。
7: for 1 ≤ i ≤ Tr do
8: 计算 S(j)i = Qi * K_j^T ∈ R^(Br×Bc)。
9: 计算 P(j)i = exp(Sij − Li) ∈ R^(Br×Bc)。
10: 计算 dVj ← dVj + (P(j)i)^T * dOi ∈ R^(Bc×d)。
11: 计算 dP(j) = dOi * V_j^T ∈ R^(Br×Bc)。
12: 计算 dS(j) = P(j) ◦ (dP(j) − δi) ∈ R^(Br×Bc)。
13: 更新 dQi ← dQi + dS(j) * Kj ∈ R^(Br×d)。
14: 计算 dKj ← dKj + (dS(j)i)^T * dQi ∈ R^(Bc×d)。
15: end for
16: end for
17: 返回 dQ, dK, dV。
💬 评论讨论
欢迎在这里分享您的想法和见解!