Skrull: Towards Efficient Long Context Fine-tuning through Dynamic Data Scheduling
Skrull: Towards Efficient Long Context Fine-tuning through Dynamic Data Scheduling
作者/机构: Hongtao Xu1,2,3, Wenting Shen3, Yuanxin Wei4, Ang Wang3, Guo Runfan2, Tianxing Wang3, Yong Li3, Mingzhen Li2†, Weile Jia2†
1School of Advanced Interdisciplinary Sciences, University of Chinese Academy of Sciences 2State Key Lab of Processors, Institute of Computing Technology, CAS 3Alibaba Group 4Sun Yat-sen University † Corresponding authors
A1 主要贡献
本文提出了一种从数据调度的新视角来解决长上下文微调(Long-SFT)中异构数据分布带来的挑战。长上下文微调通常需要在包含长短序列的混合数据集上进行训练,以平滑地使大语言模型(LLM)适应长上下文场景。然而,现有的训练系统无法同时为长序列和短序列实现高训练效率,导致端到端系统性能不佳。
核心问题与研究目标:
长上下文微调中的异构序列长度分布给现有分布式LLM训练框架带来了巨大挑战。长序列因其巨大的内存需求,需要上下文并行和其他内存削减策略,但这些策略会因通信开销和GPU利用率不足而损害短序列的训练效率。此外,序列长度的广泛分布加剧了Attention模块计算特性的不匹配(计算复杂度为二次方,内存消耗为线性),导致负载均衡问题。本文的目标是通过动态数据调度来平衡长短序列的计算需求,从而提高整体训练效率。
创新点与主要贡献:
本文提出了Skrull,一个专为高效长上下文微调设计的动态数据调度器。其主要贡献如下:
* 新的数据调度视角:提出了一种新的数据调度视角来解决异构序列长度分布的问题。
* 分布式感知上下文并行(DACP):提出了一种新的上下文并行技术DACP,它基于细粒度的数据调度,既能保持处理长序列的能力,又能保证短序列的处理效率,从而实现了在异构数据分布下的高效训练。
* 全局数据调度(GDS):实现了粗粒度的全局数据调度(GDS),并将GDS和DACP的调度过程共同建模为一个联合优化问题。
* 轻量级启发式算法:设计了一种轻量级的启发式算法,在真实世界数据集上实现了平均3.76倍(峰值7.54倍)的性能提升。
A3 背景知识与关键观察
背景知识
数据并行(Data Parallelism, DP):数据并行【15, Pytorch distributed: experiences on accelerating data parallel training, 2020, VLDB Endow.】【24, Pytorch fsdp: Experiences on scaling fully sharded data parallel, 2023, arXiv】【20, Zero: memory optimizations toward training trillion parameter models, 2020, SC】将训练样本划分给多个工作节点(worker),每个工作节点维护一个完整的模型权重副本。在每次迭代中,所有工作节点独立处理全局批次的一个子集,然后同步所有DP rank之间的梯度。然而,由于DP中固有的同步语义,负载均衡成为一个显著问题,尤其是在长上下文场景中。
上下文并行(Context Parallelism, CP):上下文并行【12, System optimizations for enabling training of extreme long sequence transformer models, 2024, PODC】【16, Ringattention with blockwise transformers for near-infinite context, 2024, ICLR】【10, The llama 3 herd of models, 2024, arXiv】沿着序列长度维度对输入张量进行分区,并将其分配给多个工作节点。在处理长上下文时,CP正成为一种不可避免的并行策略。在Transformer架构中,CP的主要挑战来自于Attention模块的并行化,因为每个token都需要关注序列中的其他token。因此,CP中的通信是不可避免的。值得注意的是,本文提出的DACP利用数据调度来最小化由CP引起的开销,并且与具体的CP实现是正交的。
关键观察
异构序列长度分布
真实世界Long-SFT数据集的序列长度分布存在显著差异。本文观察到在真实世界的Long-SFT数据集中,序列长度分布存在显著的方差,这些数据集包括Wikipedia 【2, pleisto/wikipedia-cn-20230720-filtered · datasets at hugging face】、LMsysChat1M 【25, LMSYS-chat-1m: A large-scale real-world LLM conversation dataset, 2024, ICLR】和ChatQA2-Long-SFT 【1, nvidia/ChatQA2-long-SFT-data · datasets at hugging face】。其中,ChatQA2-long-SFT的序列长度分布呈现双峰模式,长短序列的比例几乎相等:约40%的序列短于8K tokens,其余60%超过此阈值。相比之下,长尾分布是Long-SFT数据集中的另一种典型模式。在Llama3内部收集的Long-SFT数据集中【10, The llama 3 herd of models, 2024, arXiv】,99.89%的序列平均长度低于1K tokens,而其余0.11%的序列平均长度约为37K tokens,显示出极度倾斜的长尾分布。由于数据可访问性限制,本文绘制了Wikipedia和LMsysChat1M的序列长度分布图,它们与Llama3的Long-SFT数据集具有相同的特征。表1列出了这三个数据集中不同长度阈值下的序列比例,凸显了Long-SFT中异构的序列长度分布。
图1:不同数据集上的序列长度分布及其对应的性能影响。
表1:真实世界数据集中序列长度的百分比。
短序列的性能下降
为长序列配置的训练设置会导致短序列性能下降和GPU利用率不足。在训练过程中,上下文并行度和其他内存削减策略(如梯度累积)通常会根据数据集中最长的序列来设置,以避免内存溢出(OOMs)。然而,这些训练设置为占Long-SFT数据集中大多数的较短序列带来了性能下降。如图1b所示,本文测试了Attention模块【6, FlashAttention-2: Faster attention with better parallelism and work partitioning, 2024, ICLR】在不同CP度下的性能。结果表明,特别是对于短序列,更高的CP度会加剧核函数执行效率的恶化。此外,上下文并行还给短序列带来了不必要的通信开销。同时,为长序列量身定制的内存削减策略导致在大部分时间内GPU内存利用率低下。
A2 方法细节
本节介绍Skrull的设计及其在在线Long-SFT训练中的高效实现。图2展示了Skrull的工作流程。从数据调度的角度来看,Skrull包含两个部分:(i)全局数据调度(GDS):在每次迭代中,Skrull将全局批次作为输入,并采用粗粒度调度为每个DP rank生成最优的微批次(micro-batches)。(ii)分布式感知上下文并行(DACP):接收GDS产生的微批次,Skrull进一步采用更细粒度的调度,选择性地分发序列并将其分配给不同的CP工作节点。为了便于公式化表述,本文将依次在4.1节介绍DACP,4.2节介绍GDS,4.3节介绍高效实现。
图2:Skrull的工作流程。(a) 离线分析:给定模型和训练设置,为数据调度提供性能估计。(b) GDS:为DACP生成最优的批处理策略。(c) DACP:动态地将数据调度到特定硬件上,以实现负载均衡和最小化开销。(d) DACP的性能增益:展示了减少的通信量和重叠如何提高性能。
分布式感知上下文并行 (Distributed-aware Context Parallelism)
DACP通过动态分类序列为分布式或本地来平衡长短序列的处理效率。为了同时为所有序列实现高效率,本文提出了分布式感知上下文并行(DACP)。如图2(c)所示,DACP动态决定是否分发序列以避免不必要的开销。一方面,DACP保留原始的上下文并行设置以维持处理长序列的能力。另一方面,DACP选择性地将短序列完全调度在单个设备内以最小化性能下降。因此,基于不同的计算特性,DACP将序列分为两类:(i)需要上下文并行的分布式序列,和(ii)需要高效处理并旨在完全驻留在单个设备内的本地序列。值得注意的是,这些序列仍然在一个共享的CP组内处理,而无需增加用于训练的GPU数量。此外,如图2(d)所示,由于分布式序列和本地序列之间固有的独立性,DACP在Attention模块中为重叠分布式序列的通信和本地序列的计算带来了额外的机会。
调度过程面临内存和负载均衡挑战,需通过性能建模来指导。然而,调度过程带来了重大挑战。首先,不恰当的序列分类可能导致内存溢出(OOMs)。其次,本地序列长度各异,给CP rank之间带来了负载不均衡问题。为了充分探究调度计划与性能增益之间的关系,本文首先在附录C中分析了计算和内存特性。通过离线分析,本文通过FLOPs函数和延迟估计函数$T_{comp}$对计算进行建模(详见附录C.2)。此外,本文将序列长度映射到内存消耗,并推导出BucketSize $C$,该值表示每个rank的总序列长度容量。BucketSize $C$在衡量Skrull调度过程中的内存约束方面起着至关重要的作用。更多细节列于附录C.1。类似地,本文对通信量函数$Volume$和延迟函数$T_{comm}$进行了建模,详见附录C.3。最后,本文将调度过程形式化为一个优化问题。常用符号列于表2。
DACP 公式化。我们首先定义序列分类数组$D \in {0, 1}^K$(0代表本地序列,1代表分布式序列)和本地序列分配矩阵$P \in {0, 1}^{K \times N}$(1代表分配,0代表不分配)。例如,$D_k = 1$表示长度为$S_k$的第$k$个序列被调度以分布式方式计算。类似地,$P_{kj} = 1$表示第$k$个序列被分配给设备$j$,这意味着$D_k = 0$。给定一个包含$K$个序列的微批次,其长度为$S_k$ ($k = (0, \dots, K-1)$),BucketSize为$C$,CP度为$N$,DACP的调度过程可以公式化如下:
表2:本文使用的符号。
约束条件为:
$$\text{Time}_j = \max \left( T_{comm} \left( \text{V} \right), T_{comp} \left( \text{Local}_j \right) \right) + T_{comp} \left( \text{Dist} \right), \quad \forall j$$ $$\text{Local}_{j}=\sum_{k}\text{FLOPs}\left(P_{kj}\cdot S_{k}\right),\quad\forall j$$ $$\text{Dist}=\frac{1}{\mathrm{N}}\sum_{k}\text{FLOPs}\left(D_{k}\cdot S_{k}\right)$$ $$V = \text{Volume}(\sum_k D_k \cdot S_k)$$ $$\sum_{j} P_{k j}+D_{k}=1, \quad \forall k$$ $$\sum_{k} S_{k} \cdot P_{k j}+\frac{D_{k} \cdot S_{k}}{\mathrm{N}} \leq C, \quad \forall j$$优化目标是找到最优的D和P以最小化$T_{DACP}$。在这里,我们的优化目标是找到最优的$D$和$P$来最小化$T_{DACP}$,它代表一个微批次的总时长。如公式1所示,$T_{DACP}$由所有CP rank $j$中的最大执行时间$Time_j$决定。具体来说,如公式2所述,$Time_j$由两部分组成:(1)重叠项,定义为通信时间$T_{comm}(V)$和本地序列计算时间$T_{comp}(Local_j)$的最大值,以及(2)分布式序列的计算时间$T_{comp}(Dist)$。这里,$T_{comm}$取决于通信量$V$,如公式5建模。类似地,$T_{comp}$利用了公式3和4的结果,这两个公式分别计算了CP rank $j$上本地序列的FLOPs和分布式序列的FLOPs。最后,公式6确保了数据调度的完整性,而公式7则强制执行了内存约束。
全局数据调度 (Global Data Scheduling)
仅靠DACP内的调度不足以应对Long-SFT中的负载不均衡问题。4.1节讨论了微批次范围内的的数据调度。然而,仅仅依赖DACP中的调度是不够的。原因如下。
异构序列长度分布导致微批次间负载不均衡,且DACP需要精心设计的微批次策略。首先,异构的序列长度分布也导致了不同微批次之间严重的负载不均衡,从而在Long-SFT场景中导致次优的训练效率。其次,为了在DACP中实现最大的性能增益,精细的微批次策略是必不可少的。例如,将长序列和短序列与适当的内存压力配对,可以扩大DACP的有效调度空间。具体来说,总序列长度较大的微批次会增加OOMs的风险,并限制DACP中的优化,如选择性分片。相反,总序列长度较小的微批次会引入GPU利用率不足,降低端到端性能。因此,如图2(b)所示,Skrull采用全局数据调度(GDS),从全局批次中推导出最优的微批次策略。我们将调度范围限制在全局批次,因为这代表了能为主流优化器如Adam 【13, Adam: A method for stochastic optimization, 2017, arXiv】和AdamW 【17, Decoupled weight decay regularization, 2019, ICLR】保持数学等价性的最大范围。
联合公式化。我们将调度过程重新表述为一个集成了DACP和GDS的联合优化问题。我们首先定义批处理矩阵$B_{kij} \in {0, 1}^{K \times N}$,它指示第k个序列是否被调度到DP rank i的第j个微批次中。给定一个由K个长度为$S_k$的序列组成的全局批次B,我们重新表述调度过程如下:
$$\text{\textbf{Objective}} \quad \min_{arg} \max_{i} (\sum_{j} Time_{ij})$$约束条件为:
$$\sum_{ij} B_{kij} = 1, \quad \forall k$$ $$\sum_{ki} B_{kij} * S_k \le C * N, \quad \forall j$$ $$\text{Time}_{ij} = \mathbf{TDACP} \left( B_{k i j} * S_{k} \right), \quad \forall k$$联合优化目标是通过最优调度计划最小化每次迭代的总执行时间。在这里,$T_{DACP}$代表公式1中的函数。公式9确保所有序列都被精确分配一次。公式10中的内存约束防止了OOMs,而公式11中显示的$Time_{ij}$则使用公式1中的DACP公式为每个微批次提供成本估算。如公式8所示,每次迭代的总执行时间由其微批次累积执行时间最长的DP rank决定。
总的来说,优化目标是通过推导最优调度计划来最小化每次迭代的总执行时间,该计划由$B_{kij}$、$D_k$和$P_{kj}$的组合表示。
高效的在线调度
为实现在线调度,本文设计了轻量级启发式算法。尽管一些求解器如【4, The scip optimization suite 9.0, 2024, arXiv】可以推导出最优调度计划,但其较长的求解时间使其不适用于运行时调度。为了在Long-SFT期间实现在线调度,我们转而设计轻量级的启发式调度算法。值得注意的是,我们的调度算法集成在DataLoader中,并为训练过程引入了近乎零的开销。
内存与计算的权衡分析
内存和计算是影响性能的关键因素,二者之间存在权衡。内存和计算是与性能相关的关键因素,如调度公式所示。我们应在不违反内存约束的情况下实现最佳性能,这带来了一个权衡。因此,我们首先分析在推导调度策略时计算与内存之间的权衡,突出了设计调度算法时的考量。
序列分类(推导数组D)涉及内存与性能的权衡。我们分析序列分类(4.1节中的数组D)。从计算的角度来看,D影响通信量和分片序列的计算(公式5和4)。更多的分片序列将导致更多的性能下降,这既来自通信开销也来自核函数执行(参考3.2节)。然而,从内存的角度来看,更多的分布式序列将带来更均衡的内存消耗(公式7),这可以降低OOM的风险,因为剩余的不同长度的本地序列很难被均匀分配。此外,尽管DACP中的重叠可以在一定程度上缓解性能下降问题(公式2),但决定最优的分类数组D仍然非同小可。
本地序列分配(推导矩阵P)同样面临内存与计算的权衡。然后,我们分析由P表示的本地序列分配。从计算的角度来看,P影响公式3,这意味着每个CP ranks中的计算工作量,从而影响负载均衡。理想情况是平衡本地序列以实现计算平衡。然而,从内存的角度来看,平衡计算的调度会导致内存消耗的不平衡,从而增加OOM的风险。
计算与内存在Attention模块中因复杂度不同而难以同时平衡。不幸的是,我们无法同时平衡计算和内存。原因在于,应用FlashAttention【7, FlashAttention: Fast and memory-efficient exact attention with IO-awareness, 2022, NeurIPS】【6, FlashAttention-2: Faster attention with better parallelism and work partitioning, 2024, ICLR】后,计算复杂度与序列长度(n)的相关性是$O(n^2)$,而内存与序列长度的相关性是$O(n)$。此外,随着序列长度的增加,Attention模块的部分逐渐主导计算负载,使得平衡计算和内存变得更加困难。更糟糕的是,模型配置(例如,KV头、隐藏层大小)也会产生影响。由于篇幅有限,我们在附录C中列出了详细信息。
因此,需要精心设计的启发式算法来处理内存和计算的平衡。因此,我们需要仔细处理内存占用平衡和计算复杂度平衡,并设计了以下启发式方法。
启发式算法
DACP的调度算法遵循避免分片、优先计算和回滚机制三大原则。我们首先总结DACP算法设计的三个原则。(i)避免分片:我们力求避免序列分片,并假设所有序列首先都会被本地处理。(ii)优先计算:我们优先平衡计算而非内存,以实现更好的性能。(iii)回滚机制:我们持续监控估计的内存消耗,并在必要时撤销决策。回滚机制保证了公式7和公式10中概述的内存约束,同时允许基于(i)和(ii)进行更激进的调度尝试。我们的DACP启发式算法列在算法1中。给定一个包含K个序列(长度为S[K])的微批次和一个预定义的BucketSize C,该算法以数组ret的形式输出序列分类和分配结果。在该数组中,第i个位置的值为-1表示第i个序列将被分片,而值v = (0, ..., ws-1)表示第i个序列被完全分配给CP rank v。为了在确保内存约束的同时更好地平衡计算,我们在DACP调度期间维护两个数组:RemainBucket RB和Loads L,分别代表当前的内存预算和计算负载。我们首先按升序对序列进行排序。对于每个序列,我们依次将其分配给L最小的桶(以及CP rank),以避免分片并优先平衡计算(第6-8行)。如果该桶无法容纳该序列,我们尝试将其分配给RB最大的桶,以避免分片(第10-12行)。如果两次尝试都失败,我们将该序列分类为分布式序列并尝试对其进行分片(第14-16行)。然而,如果RB最小的桶在分片后无法处理子序列,这表明之前的过程错误地将不合适的序列分类为该桶内的本地序列。为了解决这个问题,我们采用回滚机制(第18行和附录B.1)。该机制识别桶中的一个本地序列,将其分片以减少内存压力,并恢复分配过程。如果由于桶中没有本地序列而导致回滚失败,我们返回一个DACP调度错误。在这种情况下,GDS也将恢复批处理计划(见第4.3.2节)。值得注意的是,每次分配都会通过预定义的函数UpdateLocal和UpdateAll更新RB和L。这些函数(包括RollBack)的详细信息在附录B.1中进一步阐述。
GDS的调度算法遵循优先计算、长短序列配对和提高内存利用率三大原则。算法2展示了GDS的启发式调度算法。给定一个包含K个序列(长度为S[K])的全局批次、DP世界大小ws和DP rank dp_rank,该算法返回调度结果mbs,它由多个微批次组成,作为算法1的输入。我们总结了算法设计的三个原则。(i)优先计算:我们优先平衡DP工作节点间的计算。为实现此目的,我们估算FLOPs(附录C.2)并采用箱柜打包算法在粗粒度上平衡计算工作负载(第1行)。(ii)配对长短序列:我们在每个DP rank内对序列进行排序,并以交错方式将它们批处理(第7行)。这种方法确保长序列更均匀地分配到各个微批次中。此外,每个微批次包含多个短序列,从而增强了任务重叠和负载均衡。(iii)提高内存利用率:我们估算总内存需求,并尝试以较少的微批次数量提高并发性。得益于回滚机制(第8行),此方法在不增加OOM风险的情况下最大化了内存利用率。如第5行所示,如果调度失败并需要回滚,我们会逐渐增加微批次的数量。
A4 实验环境与结果
实验环境
- 硬件配置:实验在一个由4个节点组成的测试平台上进行,节点间通过高性能InfiniBand网络互连。每个节点配备8个Nvidia H100 GPU,通过900GB/s的NVLink连接。
- 软件配置与基线:Skrull在DeepSpeed(一个先进的分布式LLM训练系统)之上实现,并启用了Zero-2优化作为基线。此外,还实现了LongAlign 【3, LongAlign: A recipe for long context alignment of large language models, 2024, EMNLP 2024】中的排序批处理方法进行更多比较,该方法按序列长度对数据集排序,并为每个批次选择随机的连续组以提高长SFT训练效率。
- 模型架构:在Qwen2.5-0.5B和Qwen2.5-7B上评估了优化效果。
- 数据集:使用了三个真实世界数据集:Wikipedia【2, pleisto/wikipedia-cn-20230720-filtered · datasets at hugging face】、LMsysChat1M【25, LMSYS-chat-1m: A large-scale real-world LLM conversation dataset, 2024, ICLR】和ChatQA2-long-SFT【1, nvidia/ChatQA2-long-SFT-data · datasets at hugging face】。前两者具有长尾分布特性,后者具有双峰分布特性。
- 训练配置:通过离线分析,为Qwen2.5-0.5B和Qwen2.5-7B分别配置了26K和13K的BucketSize。所有实验共享相同的训练设置<DP=4, CP=8, BatchSize=64>,启用zero-2和选择性重计算策略,但在使用ChatQA2-long-SFT数据集训练Qwen-2.5-7B时,由于内存需求增加,并行设置调整为<DP=2, CP=16, BatchSize=40>。全局批大小等于DP大小乘以BatchSize。
实验结果
DACP启发式调度算法
# 算法 1 DACP的启发式调度算法
Require: 序列数 K, 序列长度 S[K], 桶大小 C, CP度 N
Ensure: 调度结果 ret[K]
1: Sort(SeqLens, ascending=True)
2: for i = 0 to N − 1 do
3: RB[i] ← C, L[i] ← 0 ▷ 初始化
4: end for
5: for i = 0 to K − 1 do
6: t ← argmin(L) ▷ 找到负载最小的rank t
7: if RB[t] ≥ S[i] then
8: ret[i] ← t, UpdateLocal(i, t)
9: else
10: t ← argmax(RB)
11: if RB[t] ≥ S[i] then
12: ret[i] ← t, UpdateLocal(i, t)
13: else
14: t ← argmin(RB)
15: if RB[t] ≥ S[i]/N then
16: ret[i] ← −1, UpdateAll(i) ▷ 分发序列
17: else
18: Assert RollBack(t, RB, L)
19: i ← i − 1 ▷ 回滚以避免OOM
20: continue
21: end if
22: end if
23: end if
24: end for
25: return ret
GDS启发式调度算法
# 算法 2 GDS的启发式调度算法
Require: 序列数 K, 序列长度 S[K], 桶大小 C, CP度 N, DP世界大小 ws, DP_Rank dp_rank
Ensure: 微批次 mbs
1: Bin[ws] ← Binpack(ws, FLOPs(S[K])) ▷ 粗粒度平衡
2: Subset ← Bin[dp_rank], init ← ⌈Sum(Subset)/C × N⌉ − 1
3: Sort(Subset, ascending=True)
4: while init ≤ K + 1 do
5: init ← init + 1, mbs ← []
6: for j ← 0 to init do
7: mbs.append(Subset[j :: init]) ▷ 配对长短序列
8: if Sum(mbs[−1]) ≥ C × N or not scheduling_in_DACP(mbs[−1]) then
9: Continue ▷ 如果过载或DACP调度失败则回滚
10: end if
11: end for
12: end while
13: return mbs
总体性能
Skrull在不同模型和数据集上均取得显著性能提升。如图3所示,Skrull在平均迭代时间方面显著优于排序批处理策略。实验结果表明,与DeepSpeed和排序批处理方法相比,Skrull的性能平均提升了3.76倍,峰值提升达到7.54倍。Qwen-0.5B和Qwen-7B的平均加速比分别为5.50倍和2.03倍。这种差异归因于BucketSize的变化,它直接影响了有效的数据调度空间。此外,Skrull比排序批处理方法平均快3.45倍,峰值提升6.85倍。从数据集角度看,由于数据分布相似(都呈长尾特征),在Wikipedia和LMsysChat1M上的性能相似。在这种分布中,短序列占主导,因此展现出更大的优化潜力。相比之下,ChatQA2-Long-SFT数据集中长序列也占多数(呈双峰分布),导致优化空间相对较小。特别是在用此数据集训练Qwen-7B时,主要序列长度超过了BucketSize,导致加速有限。通过结合更多优化技术,如参数高效微调(PEFT)【11, LoRA: Low-Rank Adaptation of Large Language Models, 2022, ICLR】【5, LongLoRA: Efficient fine-tuning of long-context large language models, 2023, arXiv】,可以进一步扩展BucketSize。
图3:总体性能和分步评估。设置分别代表DP度、CP度和批大小。
分步评估
DACP和GDS两个组件均有效且能协同工作。此外,本文使用上述相同的训练设置进行了分步评估。如图3所示,我们相继启用了DACP和GDS来测试Skrull中每个部分的有效性。实验结果表明,这两个组件都是有效的,并且可以很好地协作以进一步提高Long-SFT中的端到端系统性能。
BatchSize和BucketSize的性能影响
更大的BatchSize和BucketSize可以提升性能,但后者有OOM风险。为了研究BatchSize和BucketSize的性能影响,我们在ChatQA2-long-SFT上使用Qwen2.5-0.5B进行了实验,默认设置为<DP=1, CP=8, BatchSize=64, BucketSize=26K>。如图4(a)所示,我们将BatchSize从8调整到54,端到端加速比也随之提高。我们将此性能增益归因于更大批次大小所提供的扩展调度范围。然而,随着BatchSize进一步增加,抽样批次逐渐收敛到数据集的序列长度分布,导致性能增益稳定在一个合理范围内。此外,我们还评估了BucketSize的效果。图4(a)显示,将BucketSize从8K增加到32K(括号中的值)可以提高加速比,直到发生内存溢出(OOM)错误。这表明虽然更大的BucketSize能提升性能,但也会增加OOM错误的风险。因此,设置合适的BucketSize非常重要,这凸显了Skrull中性能建模模块的重要性。
图4:左(a)显示了不同BatchSize和BucketSize下的性能影响。右(b)显示了Skrull在更大模型中与Lora的有效性和兼容性。
案例研究
Skrull的启发式算法在平衡计算和内存方面优于简单策略,且回滚机制至关重要。本节对Skrull的训练过程进行量化分析。我们使用Qwen2.5-0.5B模型和ChatQA2-Long-SFT数据集,在<DP=1, CP=8, BatchSize=64>配置下进行实验。为评估启发式实现的有效性,我们将其与轮询(RR)调度策略进行比较,后者以简单的轮询方式分配序列(详见附录B.2)。此外,我们测试了两种调度算法在有无回滚机制下的表现,以进一步评估Skrall中的权衡设计。如表3所示,Skrull显著优于RR调度。我们通过检查前四次迭代的量化数据来分析此结果。表4(以元组格式呈现)报告了每次迭代期间所有GPU的最小和最大峰值内存使用量(以GB为单位)。与基线相比,Skrull和RR调度都在本地分配了更多序列,代价是增加了内存不平衡。虽然只要不超过内存容量,这种不平衡是可以接受的,但它增加了OOM的风险。因此,如表3所示,没有回滚机制,两种调度策略都会导致OOM,这凸显了此保障措施的重要性。与RR调度相比,Skrull在遵守内存约束的同时实现了更好的计算平衡(由加速比表示),证明了其在Skrull中权衡设计的有效性。
表3:调度策略比较
表4:每次迭代的延迟比较
A7 补充细节
相关工作
从数据工程的角度看,这些工作【23, LongSkywork: A training recipe for efficiently extending context length in large language models, 2024, arXiv】【3, LongAlign: A recipe for long context alignment of large language models, 2024, EMNLP 2024】【19, Qwen2.5 technical report, 2024, arXiv】【10, The llama 3 herd of models, 2024, arXiv】涉及为长上下文微调精心收集训练数据集。从训练系统的角度看,LongAlign【3, LongAlign: A recipe for long context alignment of large language models, 2024, EMNLP 2024】采用排序批处理策略来优化长上下文微调阶段的系统效率。Chunkflow【22, Efficient long context fine-tuning with chunk flow, 2025, ICLR】将训练数据组织成固定大小的块,实现了可控的峰值内存消耗和减少的流水线气泡。此外,一些工作采用动态并行设置【8, Enabling parallelism hot switching for efficient training of large language models, 2024, SOSP】来处理不同长度的序列,这与long-SFT类似。相比之下,Skrull采用固定的并行设置,并与这些方法正交。另一类工作是参数高效微调(PEFT)【11, LoRA: Low-Rank Adaptation of Large Language Models, 2022, ICLR】【5, LongLoRA: Efficient fine-tuning of long-context large language models, 2023, arXiv】,Skrull对这类方法也同样有效。
A5 结论
本文提供了一种新的数据调度视角,以提高Long-SFT场景下的训练效率。Long-SFT中异构的数据分布给现有训练系统在配置并行策略和确保负载均衡方面带来了困境。为了应对这些挑战,我们提出了Skrull,一个专为Long-SFT设计的动态数据调度器。通过动态数据调度,Skrull在长序列和短序列上都实现了高效训练。此外,我们将调度过程公式化为一个联合优化问题,并采用了一种轻量级调度算法。实验结果表明,在真实世界的long-SFT场景中,Skrull的性能平均比DeepSpeed高出3.76倍(最高可达7.54倍)。此外,我们相信Skrull可以作为其他场景下的有效解决方案,特别是在处理长短训练数据混合的情况下,例如基于人类反馈的强化学习(RLHF)。
A6 附录
A 精确度验证
Skrull的数据调度不影响模型收敛。为了评估损失等价性,我们在LMsysChat1M上训练Qwen2.5-0.5B时,比较了Skrull和标准训练方法的损失曲线。Skrull中的数据调度改变了累积顺序,由于浮点运算的非结合性,我们可以观察到轻微的数值差异。然而,Skrull不改变每个全局批次中的任何内容和顺序,优化轨迹保持等价。因此,如图5所示,Skrull不影响收敛。
图5:Skrull与标准训练方法的损失比较。
B 启发式调度算法
B.1 启发式算法的函数定义
B.2 轮询调度算法
# 算法 3 DACP调度算法中的函数定义
Require: 序列数 K, 序列长度 S[K], 桶大小 C, CP度 N, 负载 L[N], 剩余桶容量 RB[N], DACP调度结果 ret
1: function UPDATELOCAL(idx, rank)
2: RB[rank] ← RB[rank] − S[idx] ▷ 更新剩余桶容量
3: L[rank] ← L[rank] + FLOPs(S[idx]) ▷ 更新当前负载
4: end function
5: function UPDATEALL(idx)
6: for i = 0 to N − 1 do
7: RB[i] ← RB[i] − S[idx]/N ▷ 跨所有桶分发
8: L[i] ← L[i] + FLOPs(S[idx], N) ▷ 更新所有负载
9: end for
10: end function
11: function ROLLBACK(rank,RB,L)
12: for i = 0 to K − 1 do
13: if ret[i] == rank then
14: ret[i] ← −1 ▷ 分发序列
15: RB[rank] ← RB[rank] − S[i] + S[i]/N
16: if L is not None then
17: L[rank] ← L[rank] − FLOPs(S[i]) + FLOPs(S[i], N)
18: end if
19: return True ▷ 回滚成功
20: end if
21: end for
22: return False ▷ 回滚失败
23: end function
1: Input: 序列数 N, 序列长度 S, 桶 C, 世界大小 ws
2: Output: 分区ID P
3: for i = 1 to m − 1 do
4: t ← FindMaxBucketsIds()
5: if C[t] ≥ S[i] then
6: P[i] ← t ▷ 放入最大的桶
7: else
8: j ← FindMinBucketsIds()
9: if C[j] ≥ S[i]/ws then
10: P[i] ← −1 ▷ 分区序列
11: else
12: Assert RollBack(j, C) ▷ 带回滚
13: i ← i − 1
14: continue
15: end if
16: end if
17: end for
C 性能建模
C.1 内存估计
通过离线分析将内存容量映射为序列长度容量BucketSize C。由于篇幅有限,本节讨论Skrull的内存估计方法。本节的关键是确定BucketSize $C$,它将内存容量映射到序列token长度。
内存消耗分为静态和动态部分,动态部分与序列长度成线性关系。我们首先分析LLM训练期间的内存消耗。内存消耗大致可分为两部分:静态内存和动态内存。静态内存通常包括模型参数和优化器状态,在给定的模型配置和并行策略下,在整个训练过程中大致保持不变。相比之下,动态内存或激活内存随输入工作负载而变化。在Transformer架构中,激活内存与序列长度成正比。例如,Linear模块、LayerNorm和Attention模块(使用FlashAttention【7, FlashAttention: Fast and memory-efficient exact attention with IO-awareness, 2022, NeurIPS】【6, FlashAttention-2: Faster attention with better parallelism and work partitioning, 2024, ICLR】)表现出与序列长度的线性关系。因此,我们可以使用以下方程估算给定序列长度S的激活内存:
$$Memory(S) = \alpha S + \beta$$系数α和β通过离线分析确定,并可推导出BucketSize C。在这里,系数$\alpha$和常数$\beta$是在离线分析时确定的。值得注意的是,一些内存减少策略,如梯度检查点,只影响$\alpha$和$\beta$。我们仍然可以应用离线分析方法来估计激活内存。在我们的实现中,我们发现$\beta$通常可以忽略不计。此外,我们采用序列打包来消除填充并提高性能,使我们能够直接使用总序列长度进行内存估计。因此,通过离线分析,我们可以在各种设置下推导出BucketSize $C$。
C.2 计算估计
计算成本$T_{comp}$通过估算FLOPs来建模。本节描述用于估计计算成本$T_{comp}$的方法。
Transformer层的FLOPs是序列长度S的线性和二次方依赖的混合体。准确地将计算成本建模为序列长度S的函数并非易事。简单地假设与序列长度呈线性或二次关系是不够的,因为TransformerLayer的计算FLOPs主要由Linear和Attention模块主导,表现出对S的线性和二次方依赖的混合。这些组件的相对贡献因具体模型配置而异。因此,我们构建了一个FLOPs函数,以便在给定特定模型配置和序列长度S的情况下,提供大致的计算成本估算。
FLOPs估算公式考虑了模型配置。给定隐藏维度h、键/值隐藏维度$h_{kv}$和训练批大小b(当采用序列打包时通常为1)的模型配置,FLOPs按公式13估算。
$$FLOPs(S_{k}) = 20 * b * h^{2} * S_{k} + 4 * b * h * h_{kv} * S_{k} + 4 * b * h * S_{k}^{2}$$计算时间$T_{comp}$可由FLOPs线性估算。对于每个序列,$T_{comp}$可以估算为:
$$T_{comp} = \alpha FLOPs + \beta$$其中所有的$\alpha$和$\beta$都在离线分析时确定。
长序列的计算量增长远快于内存消耗,导致计算与内存平衡的难题。此外,如图6所示,我们绘制了Qwen-2.5-0.5B和Qwen-2.5-7B的FLOPs与序列长度之间的关系图。结果突出了长短序列的鲜明特征。对于短序列,计算工作量和激活内存消耗大致与序列长度成线性关系。然而,对于长序列,由于二次项的主导,计算工作量迅速增长,而内存消耗仍保持线性,这导致了平衡计算和内存的权衡问题,这在4.3节中详细讨论,我们在此节中提出了启发式算法设计的见解。
二次项主导的转折点随模型配置而变。此外,二次项开始占主导地位的转折点因模型配置而异。如图6所示,具有更大隐藏维度h的Qwen-2.5-7B的FLOPs增长比Qwen-2.5-0.5B更迅速。虽然Qwen-2.5-0.5B的FLOPs增长较慢,我们以它为例进一步讨论长短序列之间的鲜明特征。在Qwen-2.5-0.5B中,二次项仅在序列长度S超过约4K时才开始占主导地位,在短序列中表现出大致的线性关系。然而,当S = 32K时,总计算工作量是S = 4K时的30倍,而内存消耗仅增加4倍。这些估算进一步阐明了长短序列的鲜明特征。
图6:Qwen-2.5 0.5B和7B的FLOPs与序列长度对比
C.3 通信估计
通信时间$T_{comm}$通过离线分析建模。对于$T_{comm}$,我们可以简单地通过离线方式进行分析。具体来说,当通信量小于某个阈值时,通信的固定开销主导了延迟。然而,随着大小的增加,固定开销变得可以忽略不计,延迟近似与通信量成正比。我们可以通过简单的分析推导出阈值、固定开销和估计函数。如表5所示,我们绘制了通信性能分析结果。因此,我们可以根据不同硬件环境中的通信量V来拟合公式16。然后,我们可以根据不同模型配置下的序列长度S,如公式15所示,推导出通信量,其中$hidden_{kv}$和b分别表示键/值的隐藏维度和批大小。
$$Volume(S) = b * S_k * hidden_{kv}$$ $$T_{comm} = (\alpha V + T_{fixed})$$表5:集体通信延迟分析。
💬 评论讨论
欢迎在这里分享您的想法和见解!