面向非独立同分布数据的联邦知识蒸馏方法及装置转让专利

申请号 : CN202311714820.2

文献号 : CN117408330B

文献日 :

基本信息:

PDF:

法律信息:

相似专利:

发明人 : 田辉王欢郭玉刚张志翔

申请人 : 合肥高维数据技术有限公司

摘要 :

本申请涉及一种面向非独立同分布数据的联邦知识蒸馏方法及装置,其包括根据公共数据集进行随机采样,获取辅助数据集;基于预设的优化函数以及辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;将生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入生成网络模型,得到生成网络数据;控制客户端基于预设的数据融合算法、生成网络数据以及预设的本地数据进行数据融合,获取融合数据;控制客户端根据预设的局部模型蒸馏算法以及融合数据对深度学习模型进行优化训练,得到全局模型,本申请通过生成网络模型和局部模型蒸馏算法对客户端的深度学习模型进行优化,减少深度学习模型的优化目标与全局优化目标的偏差。

权利要求 :

1.一种面向非独立同分布数据的联邦知识蒸馏方法,其特征在于,所述方法包括:根据预设的公共数据集进行随机采样,获取辅助数据集;

基于预设的优化函数以及所述辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;

将所述生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入所述生成网络模型,得到生成网络数据;

控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取融合数据;

控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型;

其中,所述优化函数至少包括以下一个算法或多个算法相加的组合:对抗目标损失函数、互信息平滑损失函数和相似度惩罚损失函数;

所述对抗目标损失函数的计算公式为:

其中, 为所述辅助数据集中的数据样本,为所述噪声向量,为所述生成网络,和 分别代表所述生成网络 和所述鉴别网络 的模型参数;

所述互信息平滑损失函数的计算公式为:;

其中, 代表一次批处理过程中所述噪声向量 的数量;

所述相似度惩罚损失函数的计算公式为:;

其中,和 代表重复采样过程中不同的噪声向量;

其中,所述控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取所述融合数据,包括:基于所述生成网络模型 生成的所述生成网络数据 和客户端的所述本地数据通过所述数据融合算法进行融合,得到所述融合数据 ;

其中,所述数据融合算法的计算公式为:;

其中, 为基于随迭代次数从最小值0增加到最大值0.5的动量参数, 为样本的伪标签, 和 为合成后的数据样本和标签;

其中,所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型,包括:计算所述生成网络数据与所述本地数据之间的数量比例;

控制客户端基于所述局部模型蒸馏算法、所述数量比例以及所述融合数据对所述深度学习模型进行优化训练,得到所述全局模型;

其中,所述局部模型蒸馏算法的计算公式为:;

其中,其中 为所述本地数据的样本数量, 为所述生成网络数据的样本数量,是代表客户端本地的深度学习模型 在所述生成网络数据 和所述融合数据 之间Kullback‑Leibler距离,为用于调整知识蒸馏强度的参数, 为所述生成网络数据中标签为 的样本数量, 则代表归一化指数函数;

其中,在所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型之后,还包括:接收全体客户端深度学习模型的模型参数;

基于每个客户端的所述模型参数通过可学习参数进行加权处理,得到集成模型;

基于所述生成网络模型批量生成的所述生成网络数据,得到虚拟数据集;

基于全局聚合蒸馏算法和集成模型,通过解耦所述生成网络数据中的类别信息对全局模型进行微调,得到全局微调模型;

将所述全局微调模型重新分发给各个客户端,控制每个客户端根据所述局部模型蒸馏算法以及所述融合数据、所述全局聚合蒸馏算法和所述集成模型对所述全局微调模型进行优化训练,直至所述全局微调模型收敛或者达到指定精度;

其中,所述集成模型的计算公式为:

其中, 是一个可学习参数并处于0到1之间, 则是用于控制权重参数正则化的程度, 代表客户端上的所述模型参数;

所述全局聚合蒸馏算法 的定义如下:

其中 代表所述全局模型, 代表所述集成模型, 为所述虚拟数据集中的数据样本。

2.根据权利要求1所述的方法,其特征在于,所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型,包括:若存在多个客户端,则控制每个客户端通过所述局部模型蒸馏算法、所述数据融合算法对所述深度学习模型进行迭代优化,获取全部客户端的优化模型;

接收所有客户端的所述优化模型,并根据所述优化模型进行平均加权处理,得到所述全局模型。

3.一种面向非独立同分布数据的联邦知识蒸馏装置,其特征在于,所述装置包括:数据采样模块,用于根据预设的公共数据集进行随机采样,获取辅助数据集;

生成网络模块,用于基于预设的优化函数以及所述辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;

数据生成模块,用于将所述生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入所述生成网络模型,得到生成网络数据;

数据融合模块,用于控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取融合数据;

模型优化模块,用于控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对所述生成网络模型进行优化训练,得到全局模型;

其中,所述优化函数至少包括以下一个算法或多个算法相加的组合:对抗目标损失函数、互信息平滑损失函数和相似度惩罚损失函数;

所述对抗目标损失函数的计算公式为:

其中, 为所述辅助数据集中的数据样本,为所述噪声向量,为所述生成网络,和 分别代表所述生成网络 和所述鉴别网络 的模型参数;

所述互信息平滑损失函数的计算公式为:;

其中, 代表一次批处理过程中所述噪声向量 的数量;

所述相似度惩罚损失函数的计算公式为:;

其中,和 代表重复采样过程中不同的噪声向量;

其中,所述控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取所述融合数据,包括:基于所述生成网络模型 生成的所述生成网络数据 和客户端的所述本地数据通过所述数据融合算法进行融合,得到所述融合数据 ;

其中,所述数据融合算法的计算公式为:;

其中, 为基于随迭代次数从最小值0增加到最大值0.5的动量参数, 为样本的伪标签, 和 为合成后的数据样本和标签;

其中,所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型,包括:计算所述生成网络数据与所述本地数据之间的数量比例;

控制客户端基于所述局部模型蒸馏算法、所述数量比例以及所述融合数据对所述深度学习模型进行优化训练,得到所述全局模型;

其中,所述局部模型蒸馏算法的计算公式为:;

其中,其中 为所述本地数据的样本数量, 为所述生成网络数据的样本数量,是代表客户端本地的深度学习模型 在所述生成网络数据 和所述融合数据 之间Kullback‑Leibler距离,为用于调整知识蒸馏强度的参数, 为所述生成网络数据中标签为 的样本数量, 则代表归一化指数函数;

其中,在所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型之后,还包括:接收全体客户端深度学习模型的模型参数;

基于每个客户端的所述模型参数通过可学习参数进行加权处理,得到集成模型;

基于所述生成网络模型批量生成的所述生成网络数据,得到虚拟数据集;

基于全局聚合蒸馏算法和集成模型,通过解耦所述生成网络数据中的类别信息对全局模型进行微调,得到全局微调模型;

将所述全局微调模型重新分发给各个客户端,控制每个客户端根据所述局部模型蒸馏算法以及所述融合数据、所述全局聚合蒸馏算法和所述集成模型对所述全局微调模型进行优化训练,直至所述全局微调模型收敛或者达到指定精度;

其中,所述集成模型的计算公式为:

其中, 是一个可学习参数并处于0到1之间, 则是用于控制权重参数正则化的程度, 代表客户端上的所述模型参数;

所述全局聚合蒸馏算法 的定义如下:

其中 代表所述全局模型, 代表所述集成模型, 为所述虚拟数据集中的数据样本。

说明书 :

面向非独立同分布数据的联邦知识蒸馏方法及装置

技术领域

[0001] 本申请涉及数据安全技术领域,尤其是涉及一种面向非独立同分布数据的联邦知识蒸馏方法及装置。

背景技术

[0002] 随着互联网、物联网、云计算和大数据等各种技术的快速发展,企业面临海量的数据处理与分析,数据的搜集、共享、发布和分析过程中可能导致用户隐私信息的泄露,给用户带来巨大损失。同时,全球数据保护法规越来越严格,企业在使用数据过程中面临隐私泄露和数据违规风险。因此,隐私计算技术变得越发重要。
[0003] 联邦学习是一种新兴的人工智能技术,最初由谷歌在2016年提出,旨在解决个人数据在安卓手机端的隐私问题。该技术的设计动机是保护手机或平板计算机中用户的隐私数据,因此提出了一种数据不动模型动的新型分布式机器学习范式。联邦学习可以看成是一种分布式机器学习框架,与传统的分布式机器学习框架不同,其使用了加密技术,并且各方数据保存在本地。在联邦学习中,各个参与方(例如手机、平板计算机等设备)将本地数据进行计算和更新,然后将结果发送回中央服务器进行聚合。联邦学习体现了集中数据收集和最小化的原则,可以减轻传统集中式机器学习和数据挖掘方法带来的系统和统计层面上的隐私风险和通信效率开销。
[0004] 针对上述中的相关技术,由于联邦学习系统中各个客户端通过不同的硬件或软件设备收集并处理数据,因此客户端之间的数据分布往往是差异极其大的,并进一步导致各客户端深度学习模型的参数不一致。各客户端深度学习模型的优化目标与全局优化目标存在偏差,在模型训练时会远离最优点,从而导致模型在效率、效果、隐私保护层面上都不能达到一个很好的效果。

发明内容

[0005] 为了改善各客户端深度学习模型的优化目标与全局优化目标存在偏差,在模型训练时会远离最优点,从而导致模型在效率、效果、隐私保护层面上都不能达到一个很好的效果的问题,本申请提供一种面向非独立同分布数据的联邦知识蒸馏方法及装置。
[0006] 第一方面,本申请提供的一种面向非独立同分布数据的联邦知识蒸馏方法,采用如下的技术方案:包括:
[0007] 根据预设的公共数据集进行随机采样,获取辅助数据集;
[0008] 基于预设的优化函数以及所述辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;
[0009] 将所述生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入所述生成网络模型,得到生成网络数据;
[0010] 控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取融合数据;
[0011] 控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户的深度学习模型进行优化训练,得到全局模型。
[0012] 可选的,所述优化函数至少包括以下一个算法或多个算法相加的组合:对抗目标损失函数、互信息平滑损失函数和相似度惩罚损失函数。
[0013] 可选的,所述对抗目标损失函数的计算公式为:
[0014] ;
[0015] 其中, 为所述辅助数据集中的数据样本,为所述噪声向量,为所述生成网络,和 则分别代表所述生成网络 和所述鉴别网络 的模型参数。
[0016] 可选的,所述互信息平滑损失函数的计算公式为:
[0017] ;
[0018] 其中, 代表一次批处理过程中所述噪声向量 的数量。
[0019] 可选的,所述相似度惩罚损失函数的计算公式为:
[0020] ;
[0021] 其中,和 代表重复采样过程中不同的噪声向量。
[0022] 可选的,所述控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取所述融合数据,包括:
[0023] 基于所述生成网络模型 生成的所述生成网络数据 和客户端的所述本地数据 通过所述数据融合算法进行融合,得到所述融合数据 ;
[0024] 其中,所述数据融合算法的计算公式为:
[0025] ;
[0026] ;
[0027] ;
[0028] 其中, 为基于随迭代次数从最小值0增加到最大值0.5的动量参数, 为样本 的伪标签, 和 为合成后的数据样本和标签。
[0029] 可选的,所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对所述生成网络模型进行优化训练,得到全局模型,包括:
[0030] 计算所述生成网络数据与所述本地数据之间的数量比例;
[0031] 控制客户端基于所述局部模型蒸馏算法、所述数量比例以及所述融合数据对生成网络进行优化训练,得到所述全局模型;
[0032] 其中,所述局部模型蒸馏算法的计算公式为:
[0033] ;
[0034] 其中,其中 为所述本地数据的样本数量, 为所述生成网络数据的样本数量,是代表客户端本地的深度学习模型 在所述生成网络数据 和所述融合数据 之间Kullback‑Leibler距离,为用于调整知识蒸馏强度的参数, 为所述生成网络数据中标签为 的样本数量, 则代表归一化指数函数。
[0035] 可选的,在所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对所述生成网络模型进行优化训练,得到全局模型之后,还包括:
[0036] 若存在多个客户端,则控制每个客户端通过所述局部模型蒸馏算法、所述数据融合算法对所述全局模型进行迭代优化,获取全部客户端的优化模型;
[0037] 接收所有客户端的所述优化模型,并根据所述优化模型进行平均加权处理,得到所述全局模型。
[0038] 可选的,在所述控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对所述生成网络模型进行优化训练,得到全局模型之后,还包括:
[0039] 接收全体客户端深度学习模型的模型参数;
[0040] 基于每个客户端的所述模型参数通过可学习参数进行加权处理,得到集成模型;
[0041] 基于所述生成网络模型批量生成的所述生成网络数据,得到虚拟数据集;
[0042] 基于全局聚合蒸馏算法和集成模型,通过解耦所述生成网络数据中的类别信息对全局模型进行微调,得到全局微调模型;
[0043] 将所述全局微调模型重新分发给各个客户端,控制每个客户端根据所述局部模型蒸馏算法以及所述融合数据、所述全局聚合蒸馏算法和所述集成模型对所述全局微调模型进行优化训练,直至所述全局微调模型收敛或者达到指定精度;
[0044] 其中,所述集成模型的计算公式为:
[0045] ;
[0046] 其中, 是一个可学习参数并处于0到1之间, 则是用于控制权重参数正则化的程度, 代表客户端上的所述模型参数;
[0047] 所述全局聚合蒸馏算法 的定义如下:
[0048] ;
[0049] 其中 代表所述全局模型, 代表所述集成模型, 为所述虚拟数据集中的数据样本。
[0050] 第二方面,本申请还提供一种面向非独立同分布数据的联邦知识蒸馏装置,采用如下技术方案,包括:
[0051] 数据采样模块,用于根据预设的公共数据集进行随机采样,获取辅助数据集;
[0052] 生成网络模块,用于基于预设的优化函数以及所述辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;
[0053] 数据生成模块,用于将所述生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入所述生成网络模型,得到生成网络数据;
[0054] 数据融合模块,用于控制客户端基于预设的数据融合算法、所述生成网络数据以及预设的本地数据进行数据融合,获取所述融合数据;
[0055] 模型优化模块,用于控制客户端根据预设的局部模型蒸馏算法以及所述融合数据对客户端的深度学习模型进行优化训练,得到全局模型。
[0056] 综上所述,本申请通过采用上述技术方案,服务器根据公共数据集进行随机采样,并根据辅助数据集和优化函数对生成网络进行预训练,获取生成网络模型,服务器再将生成网络模型发送至客户端,客户端根据噪声向量输出对应的生成网络数据,客户端根据将本地数据和生成网络数据通过数据融合算法进行动量融合,并根据局部蒸馏算法以及融合数据对生深度学习模型进行优化训练,直至所有客户端依次对全局模型进行优化迭代后,客户端将全局模型发送至服务器,服务器再将全局模型进行平均加权处理后下发至所有客户端,从而减少深度学习模型训练出现偏差的问题,以减少各客户端的深度学习模型的优化目标与全局优化目标的偏差,大幅提升了深度学习模型图像分类任务的准确率。

附图说明

[0057] 图1是本申请实施例中一种面向非独立同分布数据的联邦知识蒸馏方法的流程示意图。
[0058] 图2是本申请实施例中一种面向非独立同分布数据的联邦知识蒸馏装置的结构框图。
[0059] 附图标记说明:310、数据采样模块;320、生成网络模块;330、数据生成模块;340、数据融合模块;350、模型优化模块。

具体实施方式

[0060] 以下结合附图1‑2对本申请作进一步详细说明。
[0061] 本申请实施例公开一种面向非独立同分布数据的联邦知识蒸馏方法,知识蒸馏是获取高效小规模网络的一种新兴方法,其主要思想是将学习能力强的模型中的信息迁移到简单的模型中去,可以有效提取出数据中的潜在信息。
[0062] 本申请主要通过优化函数对生成网络进行预训练,得到生成网络模型,所有客户端基于生成网络模型以及本地数据对深度学习模型进行优化,得到全局模型,最后服务器将全局模型再下发至所有客户端,减少各客户端的深度学习模型的优化目标与全局优化目标的偏差,大大提高了深度学习模型图像分类任务的准确率。
[0063] 其中,客户端的深度学习模型可以是ResNet深度神经网络模型,ResNet深度神经网络模型是指:论文“Deep Residual Learning for Image Recognition”中提出的基于ResNet深度神经网络模型进行图像识别的方法,简称ResNet深度神经网络模型。
[0064] 参照图1,本申请实施例至少包括步骤S10至步骤S50。
[0065] S10,根据预设的公共数据集进行随机采样,获取辅助数据集。
[0066] 其中,本申请实施例中所采用的公共数据集为CIFAR‑10和CIFAR‑100数据集,也可以使用其他数据集。
[0067] 应当理解的是,由于参与模型训练的公共数据集都是符合独立同分布的数据集,但是这并不满足联邦学习系统中跨客户端本地数据之间非独立同分布的假设。因此,本申请基于狄利克雷分布来划分公共数据集,以满足跨客户端本地数据之间非独立同分布的要求。并且,由于是从公共数据集中随机采样,因此不会泄露各参与客户端的私有数据信息。
[0068] 本申请实施例在CIFAR‑10数据集上测试基于狄利克雷分布的非独立同分布数据划分算法,并进行可视化呈现,其中规定客户端数量 ,狄利克雷分布的参数向量满足 ,其中 。
[0069] S20,基于预设的优化函数以及辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型。
[0070] 其中,生成网络和鉴别网络是生成对抗网络的组成部分,生成对抗网络由Ian Goodfellow 等人在2014年提出,它是一种深度神经网络架构,由一个生成网络和一个鉴别网络组成。生成网络产生『假』数据,并试图欺骗鉴别网络;鉴别网络对生成数据进行真伪鉴别,试图正确识别所有假数据。
[0071] S30,将生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入生成网络模型,得到生成网络数据。
[0072] S40,控制客户端基于预设的数据融合算法、生成网络数据以及预设的本地数据进行数据融合,获取融合数据。
[0073] 其中,本申请实施例中的本地数据是基于狄利克雷分布对公共数据集进行划分后平均分配至每个客户端的,每个客户端的本地数据数量一致,但内容和类别不一致。
[0074] S50,控制客户端根据预设的局部模型蒸馏算法以及融合数据对深度学习模型进行优化训练,得到全局模型。
[0075] 具体来说,服务器根据公共数据集进行随机采样,并根据辅助数据集和优化函数对生成网络进行预训练,获取生成网络模型,服务器再将生成网络模型发送至客户端,客户端根据噪声向量输出对应的生成网络数据,客户端根据将本地数据和生成网络数据通过数据融合算法进行动量融合,并根据局部蒸馏算法以及融合数据对生成网络模型进行优化训练,直至所有客户端依次对全局模型进行优化迭代后,客户端将全局模型发送至服务器,服务器再将平均加权处理后全局模型下发至所有客户端,从而所基于生成网络模型对本地的深度学习模型进行迭代优化,进而减少深度学习模型训练出现偏差的问题,以减少各客户端的深度学习模型的优化目标与全局优化目标的偏差,大幅提升了深度学习模型图像分类任务的准确率。
[0076] 实际来说,对于联邦学习中客户端 的深度学习模型而言,其定义为 (其模型参数为 ),对于辅助数据集 而言,其中每个样本 是从初始的公共数据集 中随机采样得到。需要注意的是,客户端 上的本地数据集 是符合非独立同分布的,同时客户端总数量为 且全局模型被定义为 (其模型参数为 )。
[0077] 在一些实施例中,针对中央服务器而言,基于辅助数据集 使用数据样本 和基于高斯噪声初始化的噪声向量 ,通过对抗目标损失函数 来训练出一个轻量级生成网络模型。对抗目标损失函数 的计算公式为:
[0078] ;
[0079] 其中 和 分别代表生成网络 和鉴别网络 的模型参数,注意在训练生成器模型的过程中输入样本 可以是真实数据 或者是之前生成器所生成的旧数据 。
[0080] 在一些实施例中,考虑到随机抽样的辅助数据集 的样本数量过少,为了减少在生成网络模型的训练过程中出现模式崩溃等问题,本申请实施例从互信息的角度出发,将鉴别网络 视为一个分类模型,然后通过互信息平滑损失函数 来最大化生成网络数据的平均信息熵,以此达到平衡生成网络模型类分布的目的。互信息平滑损失函数 的计算公式为:
[0081] ;
[0082] 其中 代表一次批处理过程中噪声向量 的数量,通过互信息平滑损失函数可以使得基于生成器 生成的数据的类别信息更加平衡。
[0083] 在一些实施例中,为了进一步增强生成网络模型生成的生成网络数据的多样性,本申请实施例从重采样的角度提出了相似度惩罚损失函数 ,即考虑到不同的噪声向量和 ,基于相似度惩罚损失函数 在生成相似类别的同时扩大 和 之间的距离。的相似度惩罚损失函数 的计算公式为:
[0084] ;
[0085] 通过相似度惩罚损失函数可以使得生成器 有效地生成同一类别的不同样本。
[0086] 进一步的,基于对抗目标损失函数 、互信息平滑损失函数 、以及相似度惩罚损失函数 可以得到生成网络的优化函数 ,基于此优化目标使得生成网络可以生成更多样化且更清晰的数据样本。优化函数 的计算公式为:
[0087] ;
[0088] 通过优化函数 基于辅助数据集 训练生成网络,从而得到生成网络模型。
[0089] 在一些实施例中,服务器将预训练的生成网络模型发送给参与训练的各客户端,对于客户端 而言,基于生成网络模型生成的生成网络数据 和客户端的本地数据 通过动量数据融合算法进行融合,可以得到融合数据 。动量数据融合算法的计算公式为:
[0090] ;
[0091] ;
[0092] 其中, 为基于随迭代次数从最小值0增加到最大值0.5的动量参数,为样本 的伪标签,和 为合成后的数据样本和标签,其有效保留了生成网络数据 和本地数据 的类别信息。
[0093] 接着,客户端计算生成网络数据与本地数据在融合数据中所占的比例,并在客户端模型局部训练中应用该比例对损失进行加权计算。然后,客户端将合成数据 和 视为一种先验信息,基于局部模型蒸馏算法并设计优化目标,从而对客户端 的本地模型 进行优化。局部模型蒸馏算法的计算公式 为:
[0094] ;
[0095] 其中,其中 为本地数据的样本数量, 为生成网络数据的样本数量, 是代表客户端本地的深度学习模型 在生成网络数据 和融合数据 之间Kullback‑Leibler距离,为调整知识蒸馏强度的参数, 为生成网络数据中标签为 的样本数量, 代表归一化指数函数。例如:生成网络数据有20个样本,本地数据有80个样本,那么计算损失的时候, 目标函数需要乘上80/(20+80)=0.8。
[0096] 通过数据融合算法和局部模型蒸馏算法对深度学习模型进行优化,大大增加了深度学习模型对本地数据的拟合度。
[0097] 进一步的,若存在多个客户端,则控制每个客户端通过局部模型蒸馏算法、数据融合算法对深度学习模型进行迭代优化,获取全部客户端的优化模型;接收所有客户端的优化模型,并根据优化模型进行平均加权处理,得到全局模型,从而减少深度学习模型训练出现偏差的问题,以减少各客户端的深度学习模型的优化目标与全局优化目标的偏差,大幅提升了深度学习模型图像分类任务的准确率。
[0098] 在一些实施例中,服务器接收全体客户端深度学习模型的模型参数,基于每个客户端的模型参数通过可学习参数进行加权处理,得到集成模型 ,集成模型 的定义如下:
[0099] ;
[0100] 其中, 是一个可学习的参数并处于0到1之间,则是用于控制权重参数正则化的程度, 代表客户端 上的模型参数。
[0101] 接着,服务器基于生成网络模型批量生成的生成网络数据,获取一个虚拟数据集,并基于全局聚合蒸馏算法和集成模型,通过解耦数据中的类别信息从而对全局模型进行微调。全局聚合蒸馏算法 的定义如下:
[0102] ;
[0103] 其中 代表全局模型, 代表客户端的集成模型。
[0104] 最后,基于虚拟数据集 ,通过全局聚合蒸馏算法微调全局模型 ,重复上述的步骤,控制每个客户端根据局部模型蒸馏算法以及融合数据、全局聚合蒸馏算法和集成模型对全局微调模型进行优化训练,直至所述全局微调模型收敛或者达到指定精度,可以有效消除由于全局更新引入的模型聚合漂移问题。
[0105] 本申请实施例一种面向非独立同分布数据的联邦知识蒸馏方法的实施原理为:服务器根据公共数据集进行随机采样,并根据辅助数据集和优化函数对生成网络进行预训练,获取生成网络模型,服务器再将生成网络模型发送至客户端,客户端根据噪声向量输出对应的生成网络数据,客户端根据将本地数据和生成网络数据通过数据融合算法进行动量融合,并根据局部蒸馏算法以及融合数据对深度学习模型进行优化训练,同时客户端通过全局聚合蒸馏算法对全局模型进行微调,直至所有客户端依次对深度学习模型进行优化迭代后,获取优化模型,客户端将优化模型发送至服务器,服务器再将全部的优化模型进行平均加权处理,得到全局模型,最后将全局模型下发至所有客户端,从而便于所有的客户端基于全局模型对本地的深度学习模型进行迭代优化,进而减少深度学习模型训练出现偏差的问题,以减少各客户端的深度学习模型的优化目标与全局优化目标的偏差,大幅提升了深度学习模型图像分类任务的准确率。
[0106] 下面结合仿真实验对本申请的效果做进一步的说明:
[0107] 仿真实验条件:
[0108] 本申请仿真实验的硬件平台为:一个中心服务器计算机,处理器为Intel至强E3‑1231V3,主频为3.6GHz,内存64GB,英伟达GeForce RTX 3090显卡。三台客户端计算机,处理器为Intel(R) Core(TM) i7‑9700F,主频为3.0GHz,内存16GB,英伟达GeForce RTX 2060显卡。
[0109] 本申请仿真实验的软件平台为:Ubuntu 16.04 LTS,64位操作系统、Python 3.8、PyTorch深度学习框架(版本1.11.0)以及PyCharm代码编写软件。
[0110] 仿真实验内容及其结果分析:
[0111] 本申请仿真实验是采用本申请和一个现有技术(ResNet神经网络)分别对两种常见的图像分类数据集(CIFAR‑10数据集和CIFAR‑100数据集)进行图像预测任务,并获得分类预测结果。其中,在本申请实验中,划分的训练集和测试集的比例为7:3。
[0112] 为了验证本申请实验的效果,采用全局模型在测试数据集上的预测分类准确率作为定量评价指标,对经过本方法和其他方法训练的模型进行评价。
[0113] 在本方法的仿真实验中,其他方法分别为联邦平均聚合算法(FedAvg)、联邦优化算法(FedProx)、联邦归一化平均算法(FedNova)、联邦终身学习算法(FedCurv)、联邦融合集成算法(FedDF)和联邦无数据知识蒸馏算法(FedGEN)。
[0114] 在本方法的仿真实验中,代表基于狄利克雷分布划分后的数据集的非独立同分布程度的大小,其中 如果越小,则数据非独立同分布程度越大。
[0115] 从表1中可以看出,本申请方法与其他方法相比,通过本方法训练后的模型在不同的数据集和数据不平衡程度上实现了更高的分类预测准确率,特别是在CIFAR‑100数据集上,虽然其训练数据复杂且严重不平衡,但通过本申请方法训练后的全局模型仍然取得了优异的预测精度。
[0116]
[0117] 以上仿真实验表明:本申请提出了一种面向非独立同分布数据的联邦知识蒸馏方法,分别在本地客户端和中央服务器上通过局部模型蒸馏和全局聚合蒸馏算法,解决了现有技术中处理非独立同分布数据时可能存在的模型训练偏差问题,以及中央服务器上存在的模型聚合漂移问题。
[0118] 图1为一个实施例中面向非独立同分布数据的联邦知识蒸馏方法的流程示意图。应该理解的是,虽然图1的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行;除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,这些步骤可以以其它的顺序执行;并且图1中的至少一部分步骤可以包括多个子步骤或者多个阶段,这些子步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,这些子步骤或者阶段的执行顺序也不必然是依次进行,而是可以与其它步骤或者其它步骤的子步骤或者阶段的至少一部分轮流或者交替地执行。
[0119] 基于相同的技术构思,参照图2,本申请实例还提供了一种面向非独立同分布数据的联邦知识蒸馏装置,采用如下技术方案,该装置包括:
[0120] 数据采样模块310,用于根据预设的公共数据集进行随机采样,获取辅助数据集;
[0121] 生成网络模块320,用于基于预设的优化函数以及辅助数据集对预设的生成网络和鉴别网络进行预训练,获取生成网络模型;
[0122] 数据生成模块330,用于将生成网络模型发送至客户端,并控制客户端将预设的噪声向量输入生成网络模型,得到生成网络数据;
[0123] 数据融合模块340,用于控制客户端基于预设的数据融合算法、生成网络数据以及预设的本地数据进行数据融合,获取融合数据;
[0124] 模型优化模块350,用于控制客户端根据预设的局部模型蒸馏算法以及融合数据对客户端的深度学习模型进行优化训练,得到全局模型。
[0125] 在一些实施例中,优化函数至少包括以下一个算法或多个算法相加的组合:对抗目标损失函数、互信息平滑损失函数和相似度惩罚损失函数。
[0126] 在一些实施例中,对抗目标损失函数的计算公式为:
[0127] ;
[0128] 其中, 为辅助数据集中的数据样本,为噪声向量,为生成网络,和 分别代表生成网络 和鉴别网络 的模型参数。
[0129] 在一些实施例中,互信息平滑损失函数的计算公式为:
[0130] ;
[0131] 其中, 代表一次批处理过程中噪声向量 的数量。
[0132] 在一些实施例中,相似度惩罚损失函数的计算公式为:
[0133] ;
[0134] 其中,和 代表重复采样过程中不同的噪声向量。
[0135] 在一些实施例中,数据融合模块340具体用于基于生成网络模型 生成的生成网络数据 和客户端的本地数据 通过数据融合算法进行融合,得到融合数据 ;
[0136] 其中,数据融合算法的计算公式为:
[0137] ;
[0138] ;
[0139] ;
[0140] 其中,为基于随迭代次数从最小值0增加到最大值0.5的动量参数, 为样本的伪标签,和 为合成后的数据样本和标签。
[0141] 在一些实施例中,数据融合模块340还用于计算生成网络数据与本地数据之间的数量比例;
[0142] 控制客户端基于局部模型蒸馏算法、数量比例以及融合数据对生成网络进行优化训练,得到全局模型;
[0143] 其中,局部模型蒸馏算法的计算公式为:
[0144] ;
[0145] 其中,其中 为本地数据的样本数量, 为生成网络数据的样本数量, 是代表客户端本地的深度学习模型 在生成网络数据 和融合数据 之间Kullback‑Leibler距离,为用于调整知识蒸馏强度的参数, 为生成网络数据中标签为 的样本数量,代表归一化指数函数。
[0146] 在一些实施例中,模型优化模块350还用于若存在多个客户端,则控制每个客户端通过所述局部模型蒸馏算法、所述数据融合算法对所述深度学习模型进行迭代优化,获取全部客户端的优化模型;
[0147] 接收所有客户端的所述优化模型,并根据所述优化模型进行平均加权处理,得到所述全局模型。
[0148] 在一些实施例中,模型优化模块350还用于接收全体客户端深度学习模型的模型参数;
[0149] 基于每个客户端的模型参数通过可学习参数进行加权处理,得到集成模型;
[0150] 基于生成网络模型批量生成的生成网络数据,得到虚拟数据集;
[0151] 基于全局聚合蒸馏算法,通过解耦生成网络数据中的类别信息对全局模型进行微调,得到全局微调模型;
[0152] 将全局微调模型重新分发给各个客户端依次进行迭代优化,直至全局微调模型收敛或者达到指定精度;
[0153] 其中,集成模型的计算公式为:
[0154] ;
[0155] 其中, 是一个可学习参数并处于0到1之间,则是用于控制权重参数正则化的程度, 代表客户端上的模型参数;
[0156] 全局聚合蒸馏算法 的定义如下:
[0157] ;
[0158] 其中 代表全局模型, 代表集成模型, 为虚拟数据集中的数据样本。
[0159] 本申请实例还公开一种控制设备。
[0160] 具体来说,该控制设备包括存储器和处理器,存储器上存储有能够被处理器加载并执行上述面向非独立同分布数据的联邦知识蒸馏方法的计算机程序。
[0161] 本申请实例还公开一种计算机可读存储介质。
[0162] 具体来说,该计算机可读存储介质,其存储有能够被处理器加载并执行如上述面向非独立同分布数据的联邦知识蒸馏方法的计算机程序,该计算机可读存储介质例如包括:U盘、移动硬盘、只读存储器(Read‑OnlyMemory,ROM)、随机存取存储器(RandomAccessMemory,RAM)、磁碟或者光盘等各种可以存储程序代码的介质。
[0163] 以上均为本申请的较佳实施例,并非依此限制本申请的保护范围,故:凡依本申请的结构、形状、原理所做的等效变化,均应涵盖于本申请的保护范围之内。