SampleMix: A Sample-wise Pre-training Data Mixing Strategey by Coordinating Data Quality and Diversity
文章标题:SampleMix:一种通过协调数据质量和多样性的样本级预训练数据混合策略
作者/机构:Xiangyu Xi, Deyang Kong, Jian Yang, JiaWei Yang, Zhengyu Chen, Wei Wang, Jingang Wang, Xunliang Cai (美团), Shikun Zhang, Wei Ye (北京大学)
A1 主要贡献
本文研究了大型语言模型(LLM)预训练数据的混合策略问题。现有的数据混合方法通常遵循一种“领域为单位”(domain-wise)的自顶向下方法,即先确定各个数据领域的权重,然后从每个领域中均匀采样数据。作者指出,这种方法存在两个关键问题:
- 忽略了领域间的重叠和共性:当前预训练数据集的“领域”主要基于数据来源划分,而非文本内在属性。作者通过对 SlimPajama 数据集的聚类分析发现,不同领域之间存在显著重叠(如图1所示)。例如,关于“爱因斯坦和相对论”的主题,样本可能同时出现在 ArXiv、C4、CommonCrawl 等多个领域中(如图7所示)。领域为单位的方法忽略了这种共性,无法有效控制整个训练数据集的全局多样性。
- 领域内样本分布次优:在每个领域内进行均匀采样会导致训练样本的分布不理想。直观上,质量更高、多样性更大的样本应该有更高的被选中概率,而低质量样本虽不应完全丢弃,但也需控制其比例。现有方法缺乏对样本选择的细粒度控制。
图1:我们对SlimPajama数据集进行了数据聚类分析。对于每个领域(行),每个单元格显示其聚类中也包含来自其他领域(列)样本的百分比。例如,76.60%的ArXiv聚类包含了维基百科的样本(第一行,第六列)。结果揭示了领域之间存在大量重叠。
为了解决这些局限性,本文提出了一种新颖的、自底向上的“样本为单位”(sample-wise)数据混合方法 SampleMix。该方法不再预先设定领域比例,而是首先基于每个样本的质量和多样性进行全局采样,从而动态地确定领域分布。具体实现上,SampleMix 对每个样本进行独立的质量和多样性评估,并根据评估结果分配相应的采样权重。在给定的目标token预算下,根据每个样本的权重进行采样,构建出最优的训练数据集。此外,该方法还能动态适应不同的token预算,为特定预算确定最佳数据比例,而大多数现有工作依赖于静态的数据比例。
本文的主要贡献如下:
1. 研究了样本级预训练数据混合问题,该问题可以缓解现有领域级混合工作忽略领域间重叠和领域内样本分布次优的限制。
2. 提出了一种样本级预训练数据混合策略(SampleMix),该策略在每个样本的基础上协调数据质量和多样性,有效捕捉了领域间的共性和最优样本分布。
3. 在下游任务和困惑度评估上的大量实验证明了该方法的优势。值得注意的是,SampleMix 达到基线平均准确率所需的训练步数减少了1.9倍,显示了其高效性。
A2 方法细节
2.1 问题提出
问题定义与方法对比。假设源数据集 $D_{src}$ 由 k 个不同领域(如 CommonCrawl、Wikipedia、BookCorpus 等)组成。对于每个领域 i,令 $D_i$ 表示该领域内的文档集合。整个源数据集定义为 $D_{src} \triangleq \{D_1, \dots, D_k\}$,总 token 数为 $T_{src}$。我们的目标是构建一个用于预训练的目标训练集 $D_{tgt}$,并遵循特定的 token 预算 $T_{tgt}$(例如 100B token)。如图2所示,传统方法在不明确考虑总 token 预算的情况下确定领域权重,并通过从每个领域根据这些权重进行均匀采样来构建 $D_{tgt}$。相比之下,我们提出的方法 SampleMix 通过评估每个文档的质量(§ 2.2)和多样性(§ 2.3)来增强这一过程。利用这些双重标准,SampleMix 为每个文档分配唯一的采样权重。为确保符合 token 预算 $T_{tgt}$,我们随后通过根据分配的权重采样文档来构建一个最优的训练数据集(§ 2.4)。
图2:(a) 传统方法确定领域权重,并通过从每个领域均匀采样来构建训练数据集。(b) SampleMix 采用样本级混合策略,通过评估样本质量和多样性,分配适当的权重,并基于这些权重构建最优数据集。相同颜色的点代表来自同一领域的数据。
2.2 数据质量评估
现有数据质量评估方法的局限性。训练数据的质量对大型语言模型至关重要。然而,大多数现有研究通常依赖于简单的启发式规则【索引33,Data selection for language models via importance resampling+2023+NeurIPS; 索引17,Textbooks are all you need ii: phi-1.5 technical report+2023+arXiv; 索引23,How to train data-efficient llms+2024+arXiv】。Wettig 等人(2024)【索引30,Qurating: Selecting high-quality data for training language models+2024+arXiv】引入了四个指标,并使用成对比较来训练一个评估器模型。然而,这些指标在数据选择中是分开应用的,并且成对训练可能会忽略决定样本质量的客观因素。
2.2.1 质量标准
定义多维度质量评估标准。为了全面捕捉文本的基本语言属性以及更深层次的信息和分析质量,我们主张高质量数据应遵循以下原则:语言的精确性和清晰度、结构的连贯性和完整性、内容的可靠性和适当性、信息和教育价值,以及重要性和原创性。为了有效评估这些方面,我们基于上述原则提出了7个质量维度及相应的分数,如表1所示。值得注意的是,对于知识丰富度和逻辑与分析深度,我们使用了更大的评分范围{0, 1, 2},以应对这些特征中更宽的范围和更大的复杂性。通过汇总所有维度的分数,我们获得了每个样本的总体质量评估,范围从0到10。
表1:质量维度和分数。
2.2.2 质量评估器
构建基于序数回归的质量评估器。为了开发一个有效且高效的质量评估器,我们利用 GPT-4o 根据预定义的质量标准(提示见图10)来评估训练数据。具体来说,我们从 SlimPajama 数据集中均匀采样了42万个文档,分别分配了41万和1万个文档用于训练集和测试集。我们使用 gte-en-mlm-base 模型【索引35,Mgte: generalized long-context text representation and reranking models for multilingual text retrieval+2024+arXiv】作为骨干网络来训练质量评估器。我们采用序数回归而非文本分类任务,以利用质量分数的内在顺序。遵循 Niu 等人(2016)【索引21,Ordinal regression with multiple output cnn for age estimation+2016+CVPR】的方法,我们将序数回归转化为一系列二元分类问题,每个问题指示输入数据是否超过特定的质量阈值。然后通过减去二元输出序列来得出总体质量分数(代码见附录F)。
评估器性能评估。我们在测试集上评估了训练好的质量评估器,如表2所示。我们不仅依赖准确率(ACC),还考虑了均方误差(MSE)和平均绝对误差(MAE),这些指标能更准确地反映真实质量分数与预测结果之间的偏差程度。虽然文本分类和序数回归方法都达到了相似的准确率,但序数回归方法在MSE和MAE方面表现更优。我们注意到准确率低于预期;详细分析表明,大多数错误预测都落在真实质量分数的±1范围内。为了解决这个问题,我们引入了邻近准确率(Close Accuracy, CACC),这是一个宽松的度量标准,如果预测值在真实质量分数的±1范围内,则认为预测正确。CACC结果表明,我们的模型对不同质量的样本具有令人满意的区分能力。
表2:文本分类和序数回归模型在测试集上的性能比较。
2.2.3 质量分布分析
SlimPajama数据集的质量分布分析。使用训练好的质量评估器,我们对 SlimPajama 数据集进行了标注,得到的质量分布如图3a所示,从中我们可以发现:(1) Arxiv 和 Book 数据源表现出更高的质量,这与预期相符。(2) Wikipedia 通常被认为是高质量来源;然而,有相当一部分质量较低。我们的人工检查表明,这些低质量样本通常包含简短内容、解析错误、不完整内容等问题。(3) 总体而言,CommonCrawl 数据集的质量优于 C4(平均质量得分:5.65 vs 4.20)。
图3:SlimPajama数据集分析。虚线标记了平均值。
2.3 数据多样性评估
基于聚类的数据多样性评估思路。受 Shao 等人(2024a)【索引26,Balanced data sampling for language model training with clustering+2024+Findings of the Association for Computational Linguistics: ACL 2024】和 Abbas 等人(2024)【索引1,Effective pruning of web-scale datasets based on complexity of concept clusters+2024+arXiv】的启发,我们采用数据聚类来捕捉训练数据集中的文本分布。通过对聚类样本的详细分析,我们观察到与 Abbas 等人(2024)关于图像数据的研究一致的模式,具体来说:(1) 更密集的聚类其构成样本之间具有更高的相似性;(2) 与其他聚类邻近的聚类更可能包含与邻近聚类中样本相似的样本。为了量化数据多样性,我们使用多样性评估器为每个样本估计一个多样性度量。
2.3.1 多样性评估器
数据聚类。我们首先为每个样本生成嵌入向量,随后通过 K-Means 将它们组织成簇,从而有效地根据文本相似性构建数据结构。数据聚类的细节可以在附录G中找到。
簇紧凑度(Cluster Compactness)。我们通过计算簇内成员到质心的平均距离来评估簇的密度,称之为簇紧凑度。较小的平均距离意味着簇更紧凑,表明其构成样本之间的相似性更高。该指标有效地揭示了簇的密集特性。
簇分离度(Cluster Separation)。我们通过测量每个簇的质心与其他簇质心之间的距离来评估其独特性,称之为簇分离度。较大的距离意味着更大的分离度,表明该簇与其他簇的区别更明显,突显了其在全局范围内的独特性。
数据多样性计算。最后,通过整合样本所在簇的分离度和紧凑度来估计每个样本 $x_i$ 的多样性,如下所示:
其中 $x_i$ 属于第 j 个簇, $d_{compactness,j}$ 和 $d_{separation,j}$ 分别代表第 j 个簇的簇紧凑度和簇分离度。这种复合多样性度量有效地封装了簇内的同质性和簇间的独特性,提供了对数据多样性的全面评估。
2.3.2 多样性分布分析
SlimPajama数据集的多样性分布分析。我们研究了 SlimPajama 数据集内的多样性分布,如图3b所示。我们可以发现:(1) 在单个领域内,样本的多样性可能有显著差异。例如,C4 的多样性分布近似于正态分布,表明该领域内存在一致的可变性。(2) 在 SlimPajama 数据集中,不同领域的多样性差异显著。具体来说,C4、CommonCrawl 和 Book 领域表现出最高的多样性水平,这与预期相符。相比之下,StackExchange 领域在所考察的领域中表现出最低的多样性。
2.4 数据采样
2.4.1 采样权重计算
计算样本的采样权重。给定每个文档的质量和多样性评估,我们首先对这两个度量进行最小-最大归一化,以确保它们位于区间 [0, 1] 内,然后按如下方式计算采样权重:
其中 $q(x_i)$ 和 $d(x_i)$ 分别表示文档 $x_i$ 的质量和多样性度量,而 $\alpha \in [0, 1]$ 是一个加权因子,用于平衡多样性相对于质量的贡献。
2.4.2 确定采样频率
确定每个文档的采样频率。给定包含 $|D_{src}|$ 个文档和 $T_{src}$ 个 token 的源数据集 $D_{src}$,我们首先按如下方式估计 $D_{tgt}$ 的目标文档数:
然后,我们使用基于 softmax 的分布来计算每个文档的采样频率 $c(x_i)$,将采样权重转化为概率:
其中 $\tau$ 是调节 softmax 分布的温度参数,控制采样概率的集中程度。
2.4.3 构建训练数据集
将非整数采样频率转换为整数计数。由于 $c(x_i)$ 通常为非整数值,我们通过以下两步过程将这些频率转换为整数计数:
- 整数部分:始终对文档采样 $\lfloor c(x_i) \rfloor$ 次。例如,如果 $c(x_i) = 2.3$,则该文档被采样2次。
- 小数部分:剩余的小数部分 $(c(x_i) - \lfloor c(x_i) \rfloor)$ 用于概率性地确定是否额外采样一次。继续该示例,当 $c(x_i) = 2.3$ 时,有30%的概率 $x_i$ 会被第三次采样,这是通过将小数部分与一个随机生成的数进行比较来确定的。
方法优势总结。通过汇总每个文档 $x_i$ 的采样计数,我们组装出最终的训练数据集 $D_{tgt}$,该数据集紧密匹配目标 token 预算 $T_{tgt}$。我们的方法提供了关键优势:(1) 优先考虑质量和多样性:通过将质量和多样性指标纳入采样权重,SampleMix 确保优先选择高质量和多样化的文档,从而提高训练数据集的整体有效性。(2) 适应训练预算:采样机制动态适应不同的 token 预算 $T_{tgt}$,在无需手动调整的情况下保持质量和多样性之间的最佳平衡。(3) 灵活的领域表示:通过允许在同一领域内采用不同的采样率,该方法支持对不同领域进行更细致的表示。
A4 实验环境
- 数据集:实验使用 SlimPajama 数据集,该数据集是 RedPajama 的一个经过严格增强的版本,包括NFC归一化、长度过滤和全局去重。实验中排除了 GitHub 领域,并对基线方法的权重进行了重新归一化。
- 模型架构:从头开始训练了 10 亿(1B)和 80 亿(8B)参数的 LLaMA 模型,训练数据量为 1000 亿(100B)token。详细的模型架构和训练超参数分别在附录的表6和表7中提供。
- 硬件配置:论文未明确提供具体的硬件配置信息。
- 软件配置:
- 基线方法:与多个基线方法进行比较,包括 Vanilla、DoReMi、CE、BiMIX-OPT、DoGE 和 DML。
- 评估工具:使用 LM-eval Harness 进行下游任务的 5-shot 准确率评估。
- SampleMix 设置:对于 1B 模型实验,权重因子 $\alpha$ 设置为0.8,温度参数 $\tau$ 设置为0.2。
A4 实验结果
4.1 主要结果
下游任务与困惑度评估:如表3所示,SampleMix 在各项评估中表现出色:
1. 在8个下游任务中,SampleMix 取得了最高的平均准确率(47.77%),在其中5个任务上排名第一。
2. 在困惑度评估中,SampleMix 在 Pile 数据集(25.63)和 xP3 数据集(46.38)上均取得了最低的困惑度分数,显示了其在语言建模任务上的优势。
表3:不同数据混合方法在各种下游任务和困惑度评估中的比较。每个指标的最佳性能方法以粗体突出显示,次佳方法以下划线标出。
训练效率:如图4所示,SampleMix 在收敛速度上显著优于基线方法。它达到基线平均准确率所需的训练步数比基线方法少1.4倍至2.1倍,平均快1.9倍(在100k步时达到)。
图4:训练效率比较。SampleMix在10万训练步数时达到基线平均准确率——比基线平均快1.9倍。
在更大模型上的有效性:为了验证方法在更大模型上的效果,研究团队使用排名前三的基线方法和 SampleMix 训练了8B参数模型。如表4所示,SampleMix 显著优于基线方法,其优势与在1B模型上观察到的一致。
表4:与8B模型的性能比较。
这些结果共同证明,SampleMix 不仅提升了模型的整体性能,还提高了训练效率,是一种强大有效的数据混合优化方法。
4.2 质量和多样性的有效性
权重因子 $\alpha$ 的影响:为了探究质量和多样性评估的有效性,研究团队通过网格搜索系统地改变了权重因子 $\alpha$(从0.0到1.0)。如图5所示,实验结果表明:
1. 多样性的重要性:当 $\alpha=0.0$(仅考虑质量)时,模型准确率最低(45.53%)。随着 $\alpha$ 从0.0增加到0.8,准确率稳步提升至峰值47.77%,这突显了多样性在实现均衡数据混合和全面数据覆盖中的关键作用。
2. 质量的必要性:当 $\alpha=1.0$(仅考虑多样性)时,准确率略微下降至47.58%,表明虽然多样性至关重要,但结合质量度量可以进一步提升性能。
3. 最佳权重:在 $\alpha=0.8$ 时性能最佳,说明优先考虑多样性同时兼顾质量能带来最有效的模型性能。研究团队认为这可能因为多样性度量的绝对值本身较小,且SlimPajama数据集已经过严格的质量过滤。
4. 使用建议:用户应根据其数据集的特点调整 $\alpha$。对于质量较低的数据集,应优先考虑质量;对于高质量数据集,则建议增大 $\alpha$ 以确保多样性。
图5:不同权重因子α下下游任务的平均性能。
4.3 对不同 Token 预算的适应性
动态适应预算的优势:大多数现有方法提供固定的数据比例,难以适应不同训练阶段(如预训练、持续预训练)所需的不同token预算。为了评估SampleMix的动态适应能力,研究团队将源数据集缩减至原始大小的1/5(约100B token),使得训练预算与源数据大小相等($T_{tgt} = T_{src}$),而之前的实验是 $T_{tgt} = \frac{1}{5} T_{src}$。
如表5所示,实验结果表明:
1. SampleMix 仍取得了最高的平均准确率(47.46%),证明其能有效适应不同的token预算。
2. 基线方法的表现在token预算变化时表现不一致。例如,之前表现最好的DoReMi在此设置下不如Vanilla和CE,表明基线方法难以有效适应不同的token预算。
表5:以100B数据作为候选池时,不同数据混合方法的性能比较。
适应机制分析:图6分析了不同预算下SampleMix的采样计数。
- 当源数据充足时($T_{tgt} = \frac{1}{5} T_{src}$),SampleMix能精确选择高权重样本,很少需要上采样(采样次数>1)。
- 当源数据有限时($T_{tgt} = T_{src}$),高权重样本不足以满足预算,SampleMix会纳入部分低权重样本,但仍能有效识别并丢弃价值最低的数据(占数据集的18.245%),同时对高权重数据进行更频繁的上采样。这也验证了SampleMix能有效利用采样空间并适应不同的token预算。
(b) 不同采样计数的采样权重(即p(x))。
图6:不同采样计数的分析。
A5 结论
本文提出了 SampleMix,一种通过协调数据质量和多样性的样本级预训练数据混合策略。大量的实验证明,SampleMix 的性能优于现有的领域级方法,并且在达到相当准确率的情况下,训练步数减少了1.9倍。未来,研究团队有兴趣将模型视角的自动评估指标纳入考量,以补充当前手动设计的度量标准,并探索代码数据的混合方法。
A6 附录
A 领域重叠
手动检查聚类样本。我们手动检查了同一聚类中但来自不同领域的样本。这些样本通常在主题上相关,并且在结构、语义和上下文方面相似。如图7所示,这些样本都讨论了关于爱因斯坦和相对论的话题。
Arxiv
广义相对论在能量和动量的局域守恒定律方面存在一些问题。这是1915年至1918年间希尔伯特、克莱因、诺特和爱因斯坦之间一次关键讨论的主题……
C4
术语mc2早在他9月26日介绍狭义相对论的论文中就已出现。11月21日的论文表明E=mc2适用于静止物体。[今日物理]……
CommonCrawl
广义相对论(GRT)的诞生,除其他原因外,源于能够使用任意坐标系来描述自然法则的需求。根据协变性原理,自然法则的形式不应决定性地依赖于特定坐标系的选择……
StackExchange
除了爱因斯坦,没人能确定他究竟是如何得出广义相对论的。从阅读各种当时的历史资料来看,在我看来,一旦爱因斯坦提出了等效原理,他便开始寻找体现这一原理的理论……
Wikipedia
《相对论的意义:1921年5月在普林斯顿大学的四次讲座》是普林斯顿大学出版社于1922年出版的一本书,汇编了阿尔伯特·爱因斯坦于1921年在普林斯顿大学发表的斯塔福德·利特尔系列讲座……
图7:来自不同领域的样本,均描述了与爱因斯坦和相对论相关的信息。
B Slimpajama CommonCrawl 中的样本
手动检查样本质量。我们手动检查了 Slimpajama CommonCrawl 中的低质量和高质量样本。如图8所示,CommonCrawl 的数据质量差异很大。低质量样本的特点是信息零散、组织混乱,主要由与体育新闻相关的零星标题和链接组成。另一方面,高质量样本提供了一段关于天体物理研究的连贯且信息丰富的摘录,展示了清晰且结构化的叙述。
低质量样本
新帖子 精选 搜索论坛
体育简报 (纽约时报)
主题发起人 articlebot
自行车新闻头条
articlebot
赛车。
http://us.rd.yahoo.com/dailynews/rss/search/cycling+racing/SIG=120pnaegk/*http
自行车新闻头条 2007年7月31日
自行车新闻头条 2007年6月2日
...
体育简报:棒球、高尔夫、赛马和自行车 (纽约时报)
自行车新闻头条 2006年6月26日
高质量样本
数十年的研究表明,大多数大质量星系的中心都有一个超大质量黑洞,黑洞的质量是周围恒星球体总质量的十分之一。来自天体物理中心 | 哈佛和史密森尼的两位天体物理学家提出了一种方法来观测可能是地球第二近的超大质量黑洞。
图8:来自Slimpajama CommonCrawl的低质量和高质量样本示例。
C 不同方法的领域权重
领域权重分布。图9展示了不同方法的领域权重。
图9:不同方法的领域权重。
D 训练模型的超参数
模型配置。1B 和 8B 参数模型的实验均遵循标准的 Transformer 架构,并使用了精心优化的超参数。表6和表7分别介绍了两种模型规模的架构配置和训练规格。
表6:实验中使用的1B模型的超参数。
表7:实验中使用的8B模型的超参数。
E 质量评估提示
GPT-4o 评估提示。用于 GPT-4o 评估训练数据质量的提示如图10所示。
图10:GPT-4o评估文本数据质量的提示。
F 质量评估器代码
序数回归模型实现。表8展示了用于实现质量评分任务的序数回归模型的Python代码,包括模型定义、损失函数计算和推理过程。完整代码可在补充材料中找到。
OrdinalRegressionModel类。OrdinalRegressionModel
类初始化了预训练的基础模型和一系列序数层。每个序数层输出质量分数大于特定阈值的概率。例如,第一个序数层(索引0)计算质量分数大于0的概率,即分数至少为1的概率。类似地,第二个序数层(索引1)计算质量分数大于1的概率,即分数至少为2的概率,依此类推。最后一个序数层(索引9)计算分数大于9的概率,这相当于分数恰好为10的概率。因此,该模型总共有10个序数层,每个对应于一个阈值。
损失函数。损失函数通过对预测概率和目标值之间的二元交叉熵损失求和来计算序数损失。对于每个序数层,会创建一个二元目标,指示真实分数是否大于该层对应的阈值。具体来说,预测分数与真实分数之间的偏差越大,损失就越高,这有助于模型在训练过程中专注于减少这些偏差。
predict函数。predict
函数使用训练好的序数回归模型进行推理。它首先计算每个类别的预测概率,然后通过选择概率最大的类别来计算最终的预测分数。该函数还计算了所有可能分数上的概率分布,为预测分数提供了置信度度量。
1 class OrdinalRegressionModel(nn.Module):
2 def __init__(self, base_model, num_classes=10):
3 super(OrdinalRegressionModel, self).__init__()
4 self.base_model = base_model
5 self.ordinal_layers = nn.ModuleList([
6 nn.Linear(base_model.config.hidden_size, 1) for _ in range(num_classes)
7 ])
8
9 def forward(self, input_ids, attention_mask):
10 outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
11 pooled_output = outputs.pooler_output
12 ordinal_outputs = [layer(pooled_output) for layer in self.ordinal_layers]
13 return torch.cat(ordinal_outputs, dim=1)
14
15 def ordinal_loss(logits, targets, num_classes=10):
16 targets = targets.long()
17 ordinal_targets = torch.zeros_like(logits, device=logits.device)
18 for i in range(num_classes):
19 ordinal_targets[:, i] = (targets > i).float()
20 return F.binary_cross_entropy_with_logits(logits, ordinal_targets, reduction='mean')
21
22 def predict(model, dataloader, device):
23 model.eval()
24 all_preds = []
25 all_probs = []
26 with torch.no_grad():
27 for batch in dataloader:
28 input_ids = batch['input_ids'].to(device)
29 attention_mask = batch['attention_mask'].to(device)
30 logits = model(input_ids, attention_mask)
31 probs = torch.sigmoid(logits)
32 cum_probs = torch.cumprod(probs, dim=1)
33 preds = torch.sum(cum_probs > 0.5, dim=1)
34 all_preds.extend(preds.cpu().numpy())
35 all_probs.extend(probs.cpu().numpy())
36 return all_preds, all_probs
表8:用于实现质量评分任务的序数回归模型的Python代码。
G K-means 聚类细节
聚类实现细节。对于§ 2.3中的数据聚类,我们为每个样本生成768维的嵌入。此外,我们将嵌入归一化,使其L2范数为1.0,并使用 faiss【索引15,Billion-scale similarity search with gpus+2019+IEEE Transactions on Big Data】执行K-means聚类。遵循 Tirumala 等人(2023)【索引29,D4: Improving llm pretraining via document de-duplication and diversification+2023+NeurIPS】和 Abbas 等人(2024)【索引1,Effective pruning of web-scale datasets based on complexity of concept clusters+2024+arXiv】的方法,我们将聚类数量设置为待聚类总点数的平方根。数据聚类的核心代码如表9所示。完整代码可在补充材料中找到。
1 # Calculate the number of clusters
2 n_centroids = int(math.sqrt(all_embeddings.shape[0]))
3 # define the parameters
4 kmeans = faiss.Kmeans(
5 d=768,
6 k=n_centroids,
7 niter=50, # 50 iterations
8 gpu=True,
9 seed=1024,
10 spherical=True,
11 min_points_per_centroid=1,
12 max_points_per_centroid=all_embeddings.shape[0]
13)
14 # perform data clustering
15 kmeans.train(all_embeddings)
表9:实现K-Means聚类的Python代码。
H 所有方法的收敛速度
收敛速度对比。图11显示了 SampleMix 和所有基线的完整比较。SampleMix 使用比基线少1.4倍至2.1倍的训练步骤即可达到其准确率。
图11:所有基线和SampleMix的收敛速度。SampleMix实现了最佳的训练效率。
I 采样计数分布分析
采样计数分布。图12a展示了每个领域的采样计数分布。尽管我们的目标训练预算 $T_{tgt}$ 约等于候选池的大小 $T_{src}$,我们的方法通过将质量和多样性最低的文档分配为零采样计数,从而策略性地丢弃了它们。这种方法与传统方法在所有文档上进行均匀采样形成对比。在图12b中,我们展示了与采样计数相对应的采样权重。结果表明,我们的方法为具有较大采样权重的样本分配了更高的采样计数,这与我们的预期一致。此外,采样计数的分布在不同领域之间表现出显著差异。这种可变性突显了我们的方法在捕捉不同领域间的细微变化和共性方面的有效性,确保了更细致和高效的采样过程。
(a) Ttgt = Tsrc 时不同采样计数的比例
(b) Ttgt = Tsrc 时不同采样计数的采样权重(即p(x))
图12:采样计数分析。
💬 评论讨论
欢迎在这里分享您的想法和见解!