一种基于多方3D打印数据库联合训练的方法转让专利

申请号 : CN202210284844.8

文献号 : CN114386336B

文献日 :

基本信息:

PDF:

法律信息:

相似专利:

发明人 : 荣鹏高鹏高川云杜娟

申请人 : 成都飞机工业(集团)有限责任公司

摘要 :

本发明公开了一种基于多方3D打印数据库联合训练的方法,在第j次训练迭代过程中,得到训练成员i的梯度矩阵Gi;训练成员i对梯度矩阵Gi中的元素按照绝对值大小进行排序,并选择前m个元素得到对应是稀疏矩阵,填充元素为0;计算各训练成员自身数据质量Pi与服务器所有样本对应的数据质量Q之间的相关度,并基于相关度进行排序,获取得到参与当前迭代的训练成员;最后使用参与当前迭代的训练成员稀疏矩阵更新服务器的模型,并对应更新训练成员的模型Wi。本发明实现了多个数据库之间的联合训练,且基于相关度确定当前轮迭代的训练成员,降低了联合训练过程中数据的传输量,降低了数据传输带宽的需求和投入成本,具有较好的实用性。

权利要求 :

1.一种基于多方3D打印数据库联合训练的方法,其特征在于,包括多个训练成员以及服务器,所述训练成员的模型为Wi,每个训练成员的数据为Xi, ,标签为yi;所述服务器的模型为W,且服务器的模型W与训练成员的模型Wi的网络结构一致;包括以下步骤:步骤S100:在第j次训练迭代过程中,训练成员i读取Xi中一个batch的数据bi,并进行模型Wi的前向传播,得到预测标签 ,进而根据实际标签yi,计算得到模型Wi的损失函数,进而利用反向传播算法得到梯度矩阵Gi;

步骤S200:训练成员i对梯度矩阵Gi中的元素按照绝对值大小进行从大至小排序,并选择前m个元素得到对应是稀疏矩阵 ,填充元素为0;

步骤S300:计算各训练成员自身数据质量Pi与服务器所有样本对应的数据质量Q之间的相关度,并基于相关度进行排序,获取得到参与当前迭代的训练成员;

步骤S400:使用参与当前迭代的训练成员稀疏矩阵更新服务器的模型,并对应更新训练成员的模型Wi;

在迭代训练之前,进行模型初始化:服务器对模型W进行初始化,并将初始化结果下发至所有的训练成员,对模型Wi进行初始化,确定梯度上传比例系数α、衰减系数ρ、学习率γ;

所述步骤S200中,统计得到模型Wi的中元素总个数为M,计算得到本次需要上传的梯度元素个数 ;

所述步骤S300中相关度计算如下:

其中:

其中:DKL为KL散度,

P表示各训练成员的自身数据质量,

Q表示服务器所有样本的数据质量;

所述步骤S300中,训练成员i计算得到加权模型参数 ,并利用秘密共享算法对 进行加密得到 ,并上传至服务器;

其中:

为加权模型参数;

为加密的加权模型参数;

所述步骤S400中,服务器更新模型 ,服务器将更新的模型下发至本地,并更新训练成员的模型Wi;

其中:t为模型更新次数,

γ为学习率,

K为上传数据的数训练成员数。

说明书 :

一种基于多方3D打印数据库联合训练的方法

技术领域

[0001] 本发明属于打印数据联合处理的技术领域,具体涉及一种基于多方3D打印数据库联合训练的方法。

背景技术

[0002] 上世纪八十年代,3D打印技术诞生了,3D打印并不仅限于传统的“去除”加工方法,而且3D打印是一种自下而上的制造方式,也称为增材制造技术,其实现了数学模型的建立。3D打印技术自诞生之日起就受到人们的广泛关注,因此获得了快速发展。近几十年来,3D打印技术已成为人们关注的焦点。工业设计,建筑,汽车,航空航天,牙科,教育领域等都被应用,但是其应用和开发仍然受到因素的限制。
[0003] 在3D打印实施过程中,由于3D打印相关参数太多,在实验过程中无法穷尽所有3D打印参数,并判断这些参数是否能够成型合适的零件,因此需要一种3D打印参数学习和预测的方式实现3D打印参数的预测。
[0004] 由于3D打印实验成本高昂,由一家企业或单位完成所有实验无太大可能,可以基于多个数据库共同训练得到更加精准的模型参数,这里,就涉及到多个数据库之间的保密问题。例如,A公司拥有n个数据,B公司拥有m个数据,双方均不想让对方知道自己的工艺参数,但又希望联合进行模型训练。因此,需要一种基于多方3D打印数据库联合训练的方法。

发明内容

[0005] 本发明的目的在于提供一种基于多方3D打印数据库联合训练的方法,旨在解决上述问题。
[0006] 本发明主要通过以下技术方案实现:
[0007] 一种基于多方3D打印数据库联合训练的方法,包括多个训练成员以及服务器,所述训练成员的模型为Wi,每个训练成员的数据为Xi, ,标签为yi;所述服务器的模型为W,且服务器的模型W与训练成员的模型Wi的网络结构一致;包括以下步骤:
[0008] 步骤S100:在第j次训练迭代过程中,训练成员i读取Xi中一个batch的数据bi,并进行模型Wi的前向传播,得到预测标签 ,进而根据实际标签yi,计算得到模型Wi的损失函数,进而利用反向传播算法得到梯度矩阵Gi;
[0009] 步骤S200:训练成员i对梯度矩阵Gi中的元素按照绝对值大小进行从大至小排序,并选择前m个元素得到对应是稀疏矩阵 ,填充元素为0;
[0010] 步骤S300:计算各训练成员自身数据质量Pi与服务器所有样本对应的数据质量Q之间的相关度,并基于相关度进行排序,获取得到参与当前迭代的训练成员;
[0011] 步骤S400:使用参与当前迭代的训练成员稀疏矩阵更新服务器的模型,并对应更新训练成员的模型Wi。
[0012] 为了更好地实现本发明,进一步地,在迭代训练之前,进行模型初始化:服务器对模型W进行初始化,并将初始化结果下发至所有的训练成员,对模型Wi进行初始化,确定梯度上传比例系数α、衰减系数ρ、学习率γ。
[0013] 为了更好地实现本发明,进一步地,所述步骤S200中,统计得到模型Wi的中元素总个数为M,计算得到本次需要上传的梯度元素个数 。
[0014] 为了更好地实现本发明,进一步地,所述步骤S300中相关度计算如下:
[0015]
[0016] 其中:
[0017]
[0018] 其中:DKL为KL散度,
[0019] P表示各训练成员的自身数据质量,
[0020] Q表示服务器所有样本的数据质量。
[0021] 为了更好地实现本发明,进一步地,所述步骤S300中,训练成员i计算得到加权模型参数 ,并利用秘密共享算法对 进行加密得到 ,并上传至服务器;
[0022] 其中:
[0023] 为加权模型参数;
[0024] 为加密的加权模型参数。
[0025] 为了更好地实现本发明,进一步地,所述步骤S400中,服务器更新模型,服务器将更新的模型下发至本地,并更新训练成员的模型Wi;
[0026] 其中:t为模型更新次数,
[0027] γ为学习率,
[0028] K为上传数据的数训练成员数。
[0029] 本发明的有益效果:
[0030] 1、本发明可以应用于在保证各方数据安全的情况下,各方协同训练机器学习模型供多方使用的场景。在这个场景中,多个数据方拥有自己的数据,他们想共同使用彼此的数据来统一建模(例如,分类模型、线性回归模型、逻辑回归模型等),并通过梯度稀疏矩阵的方式保证各自的数据不被泄露,具有较好的实用性;
[0031] 2、本发明还可以基于相关度确定当前轮迭代的训练成员,从而实现在训练过程中仅有部分训练成员需要进行数据上传,降低了联合训练过程中数据的传输量,降低了数据传输带宽的需求和投入成本,具有较好的实用性;
[0032] 3、本发明通过加权模型参数矩阵的设计使得不同数据质量的训练样本具有不同的权重,这样的设置使得更高质量的训练样本可以对模型的训练方向起到更大的作用,从而使得整个多轮训练过程更容易收敛,提升了联合训练的效率,减小了总体训练的轮数。

具体实施方式

[0033] 实施例1:
[0034] 一种基于多方3D打印数据库联合训练的方法,包括多个训练成员以及服务器,所述训练成员的模型为Wi,每个训练成员的数据为Xi, ,标签为yi;所述服务器的模型为W,且服务器的模型W与训练成员的模型Wi的网络结构一致;包括以下步骤:
[0035] 步骤S100:在第j次训练迭代过程中,训练成员i读取Xi中一个batch的数据bi,并进行模型Wi的前向传播,得到预测标签 ,进而根据实际标签yi,计算得到模型Wi的损失函数,进而利用反向传播算法得到梯度矩阵Gi;
[0036] 步骤S200:训练成员i对梯度矩阵Gi中的元素按照绝对值大小进行从大至小排序,并选择前m个元素得到对应是稀疏矩阵 ,填充元素为0;
[0037] 步骤S300:计算各训练成员自身数据质量Pi与服务器所有样本对应的数据质量Q之间的相关度,并基于相关度进行排序,获取得到参与当前迭代的训练成员;
[0038] 步骤S400:使用参与当前迭代的训练成员稀疏矩阵更新服务器的模型,并对应更新训练成员的模型Wi。
[0039] 实施例2:
[0040] 本实施例是在实施例1的基础上进行优化,在迭代训练之前,进行模型初始化:服务器对模型W进行初始化,并将初始化结果下发至所有的训练成员,对模型Wi进行初始化,确定梯度上传比例系数α、衰减系数ρ、学习率γ。
[0041] 进一步地,所述步骤S200中,统计得到模型Wi的中元素总个数为M,计算得到本次需要上传的梯度元素个数 。
[0042] 本实施例的其他部分与实施例1相同,故不再赘述。
[0043] 实施例3:
[0044] 本实施例是在实施例1或2的基础上进行优化,所述步骤S300中相关度计算如下:
[0045]
[0046] 其中:
[0047]
[0048] 其中:DKL为KL散度,
[0049] P表示各训练成员的自身数据质量,
[0050] Q表示服务器所有样本的数据质量。
[0051] 本实施例的其他部分与上述实施例1或2相同,故不再赘述。
[0052] 实施例4:
[0053] 一种基于多方3D打印数据库联合训练的方法,以水平切分的分类任务为例,假设共有k个训练成员,每个训练成员的数据集为Xi, ,标签为yi,训练成员的模型为Wi,训练过程中对应的模型梯度为Gi,服务器的模型W与训练成员模型的网络结构保持一致。包括以下步骤:
[0054] 步骤1,模型初始化:
[0055] 服务器对模型W进行初始化,并将初始化结果下发至所有训练成员,对Wi进行统一的初始化。确定梯度上传比例系数α。衰减系数ρ,学习率γ。
[0056] 步骤2,训练成员得到稀疏模型参数矩阵:
[0057] (1)在第j次训练迭代过程中(j=0,1,...,N),训练成员i读取Xi中一个batch的数据bi,batch大小为ni,进行模型Wi的前向传播,得到预测标签 ,进而根据实际标签yi,计算得到损失函数Li,进而利用反向传播算法得到梯度矩阵Gi。
[0058] (2)统计得到模型Wi中的元素总个数为M,计算本次需要上传的梯度元素个数。
[0059] (3)训练成员i对Gi中的元素按照绝对值大小进行从大到小排序,并选择前m个元素,得到Wi对应的稀疏矩阵 ,填充元素为0。
[0060] 步骤3,模型加密:
[0061] (1)训练成员i获取自身数据对应的数据质量Pi;
[0062] (2)服务器所有样本对应的数据质量Q
[0063] (3)那么对于训练成员i
[0064]
[0065] 其中:
[0066]
[0067] 其中:DKL为KL散度,
[0068] P表示各训练成员的自身数据质量,
[0069] Q表示服务器所有样本的数据质量。
[0070] (4)训练成员计算得到加权模型参数 ,利用秘密共享算法对 进行加密得到 ,并上传至服务器;
[0071] 其中:
[0072] 为加权模型参数;
[0073] 为加密的加权模型参数。
[0074] 步骤4,模型更新:
[0075] (1)服务器更新模型 ;
[0076] 其中:t为模型更新次数,
[0077] γ为学习率,
[0078] K为上传数据的数训练成员数。
[0079] (2)服务器将更新模型下发至本地,更新训练成员的模型Wi。
[0080] 步骤5,循环训练。
[0081] 进一步地,在数据质量中,模型性能参数可以包括以下中的一种或多种的组合:错误率、精度、查准率、查全率、AUC、ROC等。可以使用平均性能 ,平均性能 可以是多个性能指标F的综合表征。例如,平均性能 可以是错误率、精度、查准率、查全率、AUC、ROC中任意两种或两种以上的参数的综合计算结果。综合计算结果可以是以任意算式或者函数进行运算,包括但不限于求和、求平均、加权平均、方差等方式。采用多个参数共同表征平均性能可以选出综合性能最高的模型,而非选出某个参数最优的模型。
[0082] 以下简单介绍错误率、精度、查准率、查全率、AUC、ROC的计算方式。
[0083] 设样本集T={(X1,Y1),···,(Xn,Yn)},其中Xi为该样本i的输入特征,Yi为样本的真实标签。
[0084] T的预测结果: ,其中PYi表示模型对T中第i个样本的预测结果。
[0085] 则错误率
[0086]
[0087] 精度
[0088]
[0089] 查准率Precision,反映有多少结果是预测准确的,是基于混淆矩阵得到:
[0090]
[0091] 查全率Recall为:
[0092]
[0093] 其中:TP为真阳性,
[0094] FP为伪阳性,
[0095] FN为伪阴性,
[0096] TN为真阴性。
[0097] ROC(Receiver Operating Characteristic),常用来评价一个二值数据质量的优劣。在逻辑回归中通常会设置一个阈值,超过阈值则预测为正类,小于阈值则为负类。如果调小该值预测为正类的数量就会增加,同时这里面会包含一些本是负类的样本被识别为正类。ROC可以直观的表达该现象。ROC曲线就是以TPR(真阳性率)为y轴,FPR(伪阳性率)为x轴根据分类结果得到的一条曲线。如果曲线比较平滑的话一般不会出现过拟合问题。
[0098] 其中:
[0099]
[0100] AUC(Area Under Curve):是ROC曲线下方的面积,面积越大意味着数据质量越好。
[0101] 以上所述,仅是本发明的较佳实施例,并非对本发明做任何形式上的限制,凡是依据本发明的技术实质对以上实施例所作的任何简单修改、等同变化,均落入本发明的保护范围之内。