基于低置信度样本对比损失的域适应学习方法及系统转让专利

申请号 : CN202210942337.9

文献号 : CN114998602B

文献日 :

基本信息:

PDF:

法律信息:

相似专利:

发明人 : 王子磊张燚鑫贺伟男

申请人 : 中国科学技术大学

摘要 :

本发明公开了一种基于低置信度样本对比损失的域适应学习方法及系统,使用对比学习的方法,在原有的利用目标域高置信度样本的域适应方法上,充分利用目标域低置信样本,防止图像分类模型因偏向目标域中与源域相近的样本而导致的次优的领域迁移效果;而且,在对比学习中,对原始的图像特征进行重新表示,更好地编码了任务特有的语义信息;此外,对低置信样本使用了跨域混合,并使低置信样本在其中占主导,减小了领域差异,使图像分类模型更好的学习领域不变特征。总的来说,本发明利用了低置信样本,提升了无监督域适应和半监督域适应图像分类的准确率。

权利要求 :

1.一种基于低置信度样本对比损失的域适应学习方法,其特征在于,包括:

根据设定阈值从目标域图像集合中筛选出低置信度样本集合;

对于每一低置信度样本图像,使用数据增强方法获得两个不同的增强视图图像,称为第一增强视图图像与第二增强视图图像,并在源域图像集合中随机选择源域样本图像,使用数据增强方法获得两个不同的增强视图图像,称为第三增强视图图像与第四增强视图图像;

将所述第一增强视图图像与第三增强视图图像混合后作为查询图像,将所述查询图像输入至第一图像分类模型中,通过所述第一图像分类模型进行图像特征提取并进行重新表示获得第一重新表示特征;将所述第二增强视图图像与第四增强视图图像分别输入至第二图像分类模型中,通过所述第二图像分类模型分别进行图像特征提取并进行重新表示获得对应的重新表示特征;将所述第二增强视图图像与所述第四增强视图图像对应的重新表示特征混合构成混合重新表示特征,过程表示为:其中,F为所述第一图像分类模型中的特征提取器,为所述第二图像分类模型中的特征提取器, 为L2范数标准化函数, 为对图像特征进行重新表示的函数; 为查询图像, 为使用L2范数标准化函数对查询图像的特征进行标准化处理后获得的图像特征, 为第一重新表示特征; 为第i个低置信度样本图像对应的第二增强视图图像, 为使用L2范数标准化函数对 的特征进行标准化处理后获得的图像特征, 为图像特征 对应的重新表示特征; 为源域样本图像 对应的第四增强视图图像, 为使用L2范数标准化函数对 的特征进行标准化处理后获得的图像特征, 为图像特征 对应的重新表示特征; 为混合重新表示特征, 为所述第一增强视图图像与第三增强视图图像混合时使用的混合系数;

将所述第一重新表示特征作为查询特征,其余所有重新表示特征作为对比特征,利用查询特征与各个对比特征的差异构造对比损失,并结合所述第一图像分类模型的基础损失构造总损失函数对所述第一图像分类模型进行训练;其中,其余所有重新表示特征包括:第二增强视图图像与第四增强视图图像对应的重新表示特征,以及混合重新表示特征。

2.根据权利要求1所述的一种基于低置信度样本对比损失的域适应学习方法,其特征在于,将所述第一增强视图图像与第三增强视图图像混合的方式表示为:其中,为混合系数,为Beta分布的参数, 为所述第一增强视图图像与第三增强视图图像混合时使用的混合系数,它是通过max函数获得的新的混合系数; 为第i个低置信度样本图像对应的第一增强视图图像, 为源域样本图像 对应的第三增强视图图像,为混合获得的查询图像。

3.根据权利要求1所述的一种基于低置信度样本对比损失的域适应学习方法,其特征在于,对图像特征进行重新表示的函数 表示为:其中, 为第一图像分类模型中分类器C的权值, 表示softmax函数;

, 为第二图像分类模型中分类器 的权值,T为转置符号, 为重新表示

时的温度系数。

4.根据权利要求1所述的一种基于低置信度样本对比损失的域适应学习方法,其特征在于,利用查询特征与各个对比特征的差异构造对比损失表示为:其中, 为查询特征, 为混合重新表示特征, 为所述第二增强视图图像对应的重新表示特征, 为所述第四增强视图图像对应的重新表示特征, 为记忆库M中存储的通过所述第二图像分类模型获得的其他低置信度样本图像的第二增强视图图像对应的重新表示特征; 为余弦相似性函数。

5.根据权利要求1所述的一种基于低置信度样本对比损失的域适应学习方法,其特征在于,所述总损失函数表示为:其中, 为基础损失, 为对比损失, 为对比损失的权重系数, 为

数学期望符号; 为源域图像集合 与低置信度样本集合 的并集, 为

中的单个图像;

所述基础损失包括:有标注图像上的交叉熵损失 ,用于跨域对齐特征的损失,高置信度样本中的KLD正则项 ,以及使用FixMatch后高置信度样本的交叉熵损失,FixMatch表示基于伪标签技术的半监督学习算法;所述基础损失表示为:其中, 为有标注图像集合,为单个有标注图像,有标注图像集合 包含源域图像集合与目标域图像集合中所有有标注图像; 为 与目标域图像集合 的并集,为 中的单个图像, 表示高置信度样本集合,为单个高置信度样本图像,所述高置信度样本集合为所述目标域图像集合中除去低置信度样本集合后剩余图像构成的集合; 为用于跨域对齐特征的损失 的权重系数, 为低置信度样本中的KLD正则项 的权重系数。

6.根据权利要求5所述的一种基于低置信度样本对比损失的域适应学习方法,其特征在于,所述高置信度样本中的KLD正则项 ,以及使用FixMatch后高置信度样本的交叉熵损失 计算方式包括:定义 和 分别表示来自高置信度样本集合 的单个高置信度样本图像

的两个不同的增强视图图像; 输入至第二图像分类模型,通过特征提取与分类,获得第二分类结果,并构造伪标签 ; 输入至第一图像分类模型,通过特征提取与分类,获得第一分类结果,利用所述第一分类结果计算高置信度样本中的KLD正则项 ,以及,利用所述第一分类结果与对应的伪标签计算使用FixMatch后高置信度样本的交叉熵损失;

高置信度样本中的KLD正则项 ,以及使用FixMatch后高置信度样本的交叉熵损失的计算公式表示为:其中,为指示函数, 表示类别数目, 表示即增强视图图像 经

第一图像分类模型中分类器输出的类别为j的概率, 表示即增强视图图像经第一图像分类模型中分类器输出的类别为 的概率,伪标签 为第二分类结果中最大概率对应的类别标签, 表示第二图像分类模型预测的最大概率 大于阈值 。

7.一种基于低置信度样本对比损失的域适应学习系统,其特征在于,基于权利要求1 6~任一项所述的方法实现,该系统包括:

低置信度样本集合生成单元,用于根据设定阈值从目标域图像集合中筛选出低置信度样本集合;

增强视图图像生成单元,用于对于每一低置信度样本图像,使用数据增强方法获得两个不同的增强视图图像,称为第一增强视图图像与第二增强视图图像,并在源域图像集合中随机选择源域样本图像,使用数据增强方法获得两个不同的增强视图图像,称为第三增强视图图像与第四增强视图图像;

重新表示特征获取单元,用于将所述第一增强视图图像与第三增强视图图像混合后作为查询图像,将所述查询图像输入至第一图像分类模型中,通过所述第一图像分类模型进行图像特征提取并进行重新表示获得第一重新表示特征;将所述第二增强视图图像与第四增强视图图像输入至第二图像分类模型中,通过所述第二图像分类模型分别进行图像特征提取并进行重新表示获得对应的重新表示特征;将所述第一重新表示特征与所述第四增强视图图像对应的重新表示特征混合构成混合重新表示特征;

总损失函数构造与模型训练单元,用于将所述第一重新表示特征作为查询特征,其余所有重新表示特征作为对比特征,利用查询特征与各个对比特征的差异构造对比损失,并结合所述第一图像分类模型的基础损失构造总损失函数对所述第一图像分类模型进行训练。

8.一种处理设备,其特征在于,包括:一个或多个处理器;存储器,用于存储一个或多个程序;

其中,当所述一个或多个程序被所述一个或多个处理器执行时,使得所述一个或多个处理器实现如权利要求1 6任一项所述的方法。

~

9.一种可读存储介质,存储有计算机程序,其特征在于,当计算机程序被处理器执行时实现如权利要求1 6任一项所述的方法。

~

说明书 :

基于低置信度样本对比损失的域适应学习方法及系统

技术领域

[0001] 本发明涉及图像分类领域,尤其涉及一种基于低置信度样本对比损失的域适应学习方法及系统。

背景技术

[0002] 近几年,使用深度神经网络处理各类机器学习问题卓有成效,然而其优异的性能很大程度上依赖于大量的高质量有标注的数据集。高额的时间成本和人力成本却让人工标注数据集不切实际。传统深度学习方法也因领域偏移问题而无法很好地泛化到新的数据集上。对此,域适应利用在有大量有标注样本的源域上学习的知识来帮助模型在另一个与源域相关但缺乏标注的目标域上的学习,通过减小领域偏移,能够节约标注成本。域适应按目标域样本是否有标注可以分为无监督域适应和半监督域适应。
[0003] 常用的解决领域偏移的方法是让模型学习领域不变的特征。现有域适应的方法一般是基于域间差异度量,或者是基于对抗。公开号为CN113011456A的中国发明专利申请《用于图像分类的基于类别自适应模型的无监督域适应方法》中,通过自注意模块和交叉注意模块建立领域可转移编码器,实现域内对齐和域间对齐;建立类别自适应解码器,通过类别原型学习和对齐来减少域差异。公开号为CN113011523A的中国发明专利申请《一种基于分布对抗的无监督深度领域适应方法》中,通过在分类器的全连接层进行特征分布匹配,使用MK‑MMD(多核最大均值差异)衡量领域间的特征分布差异,同时在卷积层后搭建两层全连接网络作为领域判别器进行领域对抗来减小领域差异。公开号为CN113673555A的中国发明专利申请《一种基于记忆体的无监督域适应图片分类方法》中,使用神经网络模型提取数据集中图片的特征,使用聚类算法辅助记忆体逐类别地存储源域和目标域的特征,训练神经网络,以源域与目标域记忆体的分布的相似性作为条件约束神经网络。公开号为CN113283489A的中国发明专利申请《一种基于联合分布匹配的半监督域适应学习的分类方法》中,通过基于核方法的预设算法度量源对象样本数据和目标对象样本数据分布之间的差异,拉近目标域和源域的联合分布。公开号为CN113378632A的中国发明专利申请《一种基于伪标签优化的无监督域适应行人重识别算法》中,使用了辅助分类器结构,计算辅助分类器结构输出的类别预测向量与主分类器结构输出的类别预测向量之间的KL散度(相对熵)值,获得更加可靠的伪标签。公开号为CN113610105A的中国发明专利申请《基于动态加权学习和元学习的无监督域适应图像分类方法》中,通过对样本加权、动态调整域对齐损失和分类损失的权重、通过元学习计算域对齐损失和分类损失优化网络模型参数,促进域对齐任务和分类任务之间的优化一致性。
[0004] 但是,现有域适应的方法:一方面,未探索无标签目标域的固有结构;另一方面,使用了一些准则来筛选出高置信度样本,同时完全忽略了低置信样本,由于忽略了低置信样本也就无法反映真实的目标域数据的结构,使图像分类模型偏向于高置信度样本,导致域适应学习后的图像分类模型的分类准确度不佳。

发明内容

[0005] 本发明的目的是提供一种基于低置信度样本对比损失的域适应学习方法及系统,使用低置信样本进行对比学习,有利于提升域适应学习后图像分类模型的分类准确度。
[0006] 本发明的目的是通过以下技术方案实现的:
[0007] 一种基于低置信度样本对比损失的域适应学习方法,包括:
[0008] 根据设定阈值从目标域图像集合中筛选出低置信度样本集合;
[0009] 对于每一低置信度样本图像,使用数据增强方法获得两个不同的增强视图图像,称为第一增强视图图像与第二增强视图图像,并在源域图像集合中随机选择源域样本图像,使用数据增强方法获得两个不同的增强视图图像,称为第三增强视图图像与第四增强视图图像;
[0010] 将所述第一增强视图图像与第三增强视图图像混合后作为查询图像,将所述查询图像输入至第一图像分类模型中,通过所述第一图像分类模型进行图像特征提取并进行重新表示获得第一重新表示特征;将所述第二增强视图图像与第四增强视图图像分别输入至第二图像分类模型中,通过所述第二图像分类模型分别进行图像特征提取并进行重新表示获得对应的重新表示特征;将所述第二增强视图图像与所述第四增强视图图像对应的重新表示特征混合构成混合重新表示特征;
[0011] 将所述第一重新表示特征作为查询特征,其余所有重新表示特征作为对比特征,利用查询特征与各个对比特征的差异构造对比损失,并结合所述第一图像分类模型的基础损失构造总损失函数对所述第一图像分类模型进行训练。
[0012] 一种基于低置信度样本对比损失的域适应学习系统,包括:
[0013] 低置信度样本集合生成单元,用于根据设定阈值从目标域图像集合中筛选出低置信度样本集合;
[0014] 增强视图图像生成单元,用于对于每一低置信度样本图像,使用数据增强方法获得两个不同的增强视图图像,称为第一增强视图图像与第二增强视图图像,并在源域图像集合中随机选择源域样本图像,使用数据增强方法获得两个不同的增强视图图像,称为第三增强视图图像与第四增强视图图像;
[0015] 重新表示特征获取单元,用于将所述第一增强视图图像与第三增强视图图像混合后作为查询图像,将所述查询图像输入至第一图像分类模型中,通过所述第一图像分类模型进行图像特征提取并进行重新表示获得第一重新表示特征;将所述第二增强视图图像与第四增强视图图像输入至第二图像分类模型中,通过所述第二图像分类模型分别进行图像特征提取并进行重新表示获得对应的重新表示特征;将所述第一重新表示特征与所述第四增强视图图像对应的重新表示特征混合构成混合重新表示特征;
[0016] 总损失函数构造与模型训练单元,用于将所述第一重新表示特征作为查询特征,其余所有重新表示特征作为对比特征,利用查询特征与各个对比特征的差异构造对比损失,并结合所述第一图像分类模型的基础损失构造总损失函数对所述第一图像分类模型进行训练。
[0017] 一种处理设备,包括:一个或多个处理器;存储器,用于存储一个或多个程序;
[0018] 其中,当所述一个或多个程序被所述一个或多个处理器执行时,使得所述一个或多个处理器实现前述的方法。
[0019] 一种可读存储介质,存储有计算机程序,当计算机程序被处理器执行时实现前述的方法。
[0020] 由上述本发明提供的技术方案可以看出:(1)使用对比学习的方法,在原有的利用目标域高置信度样本的域适应方法上,充分利用目标域低置信样本,防止图像分类模型因偏向目标域中与源域相近的样本而导致的次优的领域迁移效果;(2)在对比学习中,对原始的图像特征进行重新表示,更好地编码了任务特有的语义信息;(3)对低置信样本使用了跨域混合,并使低置信样本在其中占主导,减小了领域差异,使图像分类模型更好地学习领域不变特征。总的来说,本发明利用了低置信样本,提升了无监督域适应和半监督域适应图像分类的准确率。

附图说明

[0021] 为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域的普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他附图。
[0022] 图1为本发明实施例提供的一种基于低置信度样本对比损失的域适应学习方法的示意图;
[0023] 图2为本发明实施例提供的同一类样本和不同类样本的平均相似性示意图;
[0024] 图3为本发明实施例提供的对比损失计算流程图;
[0025] 图4为本发明实施例提供的特征重新表示的过程示意图;
[0026] 图5为本发明实施例提供的KLD正则项与高置信度样本的交叉熵损失的计算流程图;
[0027] 图6为本发明实施例提供的一种基于低置信度样本对比损失的域适应学习系统的示意图;
[0028] 图7为本发明实施例提供的一种处理设备的示意图。

具体实施方式

[0029] 下面结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明的保护范围。
[0030] 首先,对本文中可能使用的术语进行如下说明:术语“包括”、“包含”、“含有”、“具有”或其它类似语义的描述,应被解释为非排它性的包括。例如:包括某技术特征要素(如原料、组分、成分、载体、剂型、材料、尺寸、零件、部件、机构、装置、步骤、工序、方法、反应条件、加工条件、参数、算法、信号、数据、产品或制品等),应被解释为不仅包括明确列出的某技术特征要素,还可以包括未明确列出的本领域公知的其它技术特征要素。
[0031] 本发明为了解决现有域适应图像分类方法准确率有限的问题,公开了一种利用低置信样本对比损失的域适应学习方案,可适用于无监督域适应(即目标域中的训练数据都是无标注的)和半监督域适应(即目标域中的训练数据包含少部分标注数据和大部分无标注数据)。下面对本发明所提供的一种基于低置信度样本对比损失的域适应学习方案进行详细描述。本发明实施例中未作详细描述的内容属于本领域专业技术人员公知的现有技术。本发明实施例中未注明具体条件者,按照本领域常规条件或制造商建议的条件进行。
[0032] 实施例一
[0033] 本发明实施例提供了一种基于低置信度样本对比损失的域适应学习方法,如图1所示,其主要包括如下步骤:
[0034] 步骤1、根据设定阈值从目标域图像集合中筛选出低置信度样本集合。
[0035] 本发明实施例中,使用低置信度样本图像进行对比学习,低置信度是根据样本图像输出概率的最大值是否小于给定的阈值 来判定的,如果小于 则属于低置信度样本图像,具体的,使用的是第二图像分类模型的输出概率。
[0036] 本发明通过前期实验发现,属于同一类的低置度信样本图像之间的平均相似性较低,而属于不同类的低置信度样本图像之间的平均相似性 较高,如图2所示。两类平均相似性的定义为:
[0037]
[0038]
[0039] 其中, 与 表示从目标域图像集合中筛选出的两个低置度信样本图像, 与表示两个低置度信样本图像的类别标签, 表示图像 与 属于同一类别,表示图像 与 属于不同类别,与 表示两个低置度信样本图像的图像特征;为数学期望符号,T为转置符号; 可以是 (无监督目标域)、 和 (上标表示低置信度与高置信度)。依据这个结果,仅对低置信样本使用对比损失较为合理,因为这样减弱了在对比损失中同一类样本被视为负样本的不利影响。
[0040] 步骤2、对于每一目标域的低置信度样本图像,使用数据增强方法获得两个不同的增强视图图像,称为第一增强视图图像与第二增强视图图像,并在源域图像集合中随机选择源域样本图像,使用数据增强方法获得两个不同的增强视图图像,称为第三增强视图图像与第四增强视图图像。
[0041] 本步骤主要是针对低置信度样本图像(属于目标域图像)与源域样本图像分别进行处理,获得不同的增强视图,以作为低置信度样本对比学习的基础数据;所涉及到的数据增强方法可参照常规技术,本发明不做赘述。
[0042] 步骤3、将所述第一增强视图图像与第三增强视图图像混合后作为查询图像,将所述查询图像输入至第一图像分类模型中,通过所述第一图像分类模型进行图像特征提取并进行重新表示获得第一重新表示特征;将所述第二增强视图图像与第四增强视图图像分别输入至第二图像分类模型中,通过所述第二图像分类模型分别进行图像特征提取并进行重新表示获得对应的重新表示特征;将所述第二增强视图图像与所述第四增强视图图像对应的重新表示特征混合构成混合重新表示特征。
[0043] 本步骤主要是为了获得各个图像的重新表示特征。
[0044] 由于现有对比学习过程仅仅考虑了目标域特征空间结构,忽略了领域差异,因此,本发明实施例提出了跨域混合对比学习,用以学习领域不变特征,即将第一增强视图图像与第三增强视图图像混合,作为查询图片(query image);并且,通过两个图像分类模型分别对查询图像、以及第二增强视图图像与第四增强视图图像进行处理,获得对应的重新表示特征,可以更好地编码任务特定地语义信息;此外,对第二图像分类模型中两个重新表示特征进行混合。
[0045] 步骤4、将所述第一重新表示特征作为查询特征,其余所有重新表示特征作为对比特征,利用查询特征与各个对比特征的差异构造对比损失,并结合所述第一图像分类模型的基础损失构造总损失函数对所述第一图像分类模型进行训练。
[0046] 本步骤基于前述步骤的处理结构构造跨域混合的对比损失,并结合基础损失函数进行第一图像分类模型的训练。
[0047] 本发明实施例中,所述第一图像分类模型与第二图像分类模型结构相同,均包含特征提取器、重新表示模块与分类器。本发明实施例中的模型训练主要是更新第一图像分类模型的参数,再由第一图像分类模型的参数使用指数移动平均(EMA)产生第二图像分类模型的参数。特征提取器与分类器的实现方式可参照常规技术,本发明不做赘述。
[0048] 为了更加清晰地展现出本发明所提供的技术方案及所产生的技术效果,下面以具体实施例对本发明实施例所提供的一种基于低置信度样本对比损失的域适应学习方法进行详细描述。由于域适应学习总体包含两部分损失,即前文所述的对比损失与基础损失,因此,主要对两部分损失的计算方式进行介绍,再介绍总损失函数。需要说明是以下介绍中所涉及的具体模型结构、框架形式以及具体参数数值等均为举例并非构成限制。
[0049] 一、对比损失。
[0050] 1、模型结构介绍。
[0051] 如图3所示,展示了利用低置信度样本对比学习的主要流程。左侧部分为相关图像,图像内容仅为示例;右侧部分为两个图像分类模型,上方为第一图像分类模型,下方为第二图像分类模型,本发明采用了teacher‑student(教师‑学生)架构,即,第一图像分类模型相当于学生模型,第二图像分类模型相当于教师模型。如之前所述,两个图像分类模型的结构完全相同,但是,教师模型的参数由学生模型的指数移动平均(EMA)产生,示例性的,可以设置衰退系数为0.99。此外,图像分类模型中的分类器输入为原始特征(计算方式将在后文进行介绍),输出主要用于计算基础损失,因此,图3中并未示出相关的分类器。
[0052] 2、对比学习的流程。
[0053] 本发明提出了一种融合跨域混合(Mixup)的对比学习,其出发点是目标域内的低置信样本与源域样本相似性较低,故难以正确分类。现有对比学习过程仅仅考虑了目标域特征空间结构,忽略了领域差异。为此,本发明进一步提出了跨域混合对比学习,用以学习领域不变特征。对比学习过程还可以参见图3,主要包括:
[0054] 以第i个低置信度样本图像为例,将其对应的第一增强视图图像与第二增强视图图像分别记为 与 ;将选出的源域样本图像记为 ,其对应的第三增强视图图像与第四增强视图图像分别记为 与 。
[0055] 先将第一增强视图图像 与第三增强视图图像 混合,作为查询图像。为保证低置信目标域样本在混合中占主导,本发明对混合系数 使用max函数,获得 作为新的混合系数,跨域混合表示为:
[0056]
[0057]
[0058]
[0059] 其中,为混合系数,为Beta分布(贝塔分布)的参数, 为混合获得的查询图像。
[0060] 将查询图像输入至第一图像分类模型,将第二增强视图图像 与第四增强视图图像 分别输入至第二图像分类模型。图3右侧展示了处理流程,对于单个图像分类模型,首先通过特征提取器进行特征提取,再通过L2范数标准化函数(L2 Norm)进行处理获得相应的图像特征,表示为:
[0061]
[0062]
[0063]
[0064] 其中,F为所述第一图像分类模型中的特征提取器,为所述第二图像分类模型中的特征提取器, 与 即为提取出的相应特征; 为L2范数标准化函数, 为使用L2范数标准化函数对查询图像的特征进行标准化处理后获得的图像特征; 为第i个低置信度样本图像对应的第二增强视图图像,为使用L2范数标准化函数对 的特征进行标准化处理后获得的图像特征; 为使用L2范数标准化函数对 的特征进行标准化处理后获得的图像特征。
[0065] 以上获得的图像特征为原始特征,本发明将分类器权值作为类别原型,并用它们作为一组新坐标来重新表示原始特征,其目的是为了更好地编码任务特定地语义信息,图4展示了特征重新表示的过程,表示为:
[0066]
[0067]
[0068] 其中, 为第一图像分类模型中分类器C的权值(不会被对比损失所更新),表示softmax函数;与 均为图像特征,与 均为重新表示特征, 为第二图像分类模型中分类器 的权值,T为转置符号, 为重新表示时的温度系数, 为对图像特征进行重新表示的函数。
[0069] 将 带入第一个式子作为图像特征 ,将 与 分别带入第二个式子作为图像特征 ,对应的获得重新表示特征 、 与 ,即:
[0070]
[0071]
[0072]
[0073]  其中, 为第一重新表示特征, 为图像特征 对应的重新表示特征; 为使用L2范数标准化函数对 的图像特征进行标准化处理后获得的图像特征, 为图像特征 对应的重新表示特征。
[0074] 在得到重新表示特征 与 后,将它们混合,表示为:
[0075]
[0076] 其中, 为混合重新表示特征, 为所述第一增强视图图像与第三增强视图图像混合时使用的混合系数。
[0077] 本发明实施例中,将 作为查询特征,除去 之外的其余所有重新表示特征作为对比特征,并构造跨域混合的对比损失,表示为:
[0078]
[0079] 其中, 为查询特征, 为混合重新表示特征, 为所述第二增强视图图像对应的重新表示特征, 为所述第四增强视图图像对应的重新表示特征, 为记忆库M中存储的通过所述第二图像分类模型获得的其他低置信度样本图像的第二增强视图图像对应的重新表示特征; 为余弦相似性函数,表示为:
[0080]
[0081] 其中, 表示余弦相似性函数 中的两个特征。
[0082] 以上介绍的对比学习方案本发明的核心技术点,可以从三个层面进行概括:(1)使用低置信度样本进行对比学习。(2)对比损失的输入特征需要经过重新表示。(3)在对比学习的基础上融入跨域的Mixup技术。
[0083] 并且,主要获得如下有益效果:(1)在原有的利用目标域高置信度样本的域适应方法上,充分利用目标域低置信样本,防止模型因偏向目标域中与源域相近的样本而导致的次优的领域迁移效果。(2)使用分类器权值对原始特征进行重新表示,而不是直接使用,更好地编码了任务特有的语义信息。(3)对低置信样本使用了跨域混合,并使低置信样本在其中占主导,减小了领域差异,让模型更好的学习领域不变特征。总的来说,本发明利用了低置信样本,提升了无监督域适应和半监督域适应图像分类的准确率。
[0084] 二、基础损失。
[0085] 为了使优化函数完整,下面介绍相关的基础损失。首先,是有标注样本上的交叉熵损失 和用于跨域对齐特征的损失 。在此基础上,还增加了基于伪标签技术的半监督学习算法(FixMatch)以强化高置信度样本的学习过程,从而提升预测一致性并提供可靠的伪标签,同时引入了高置信度样本中的KLD(Kullback‑Leibler divergence,KL散度)正则项 ,以及使用FixMatch后高置信度样本的交叉熵损失 。因此,基础损失表示为:
[0086]
[0087] 其中, 为有标注图像集合,为单个有标注图像,有标注图像集合 对应于无监督域适应的源域、半监督域适应的源域以及目标域有标签部分(即包含源域图像集合与目标域图像集合中所有有标注图像); 为 与目标域图像集合 的并集,为中的单个图像, 表示高置信度样本集合,为单个高置信度样本图像,所述
高置信度样本集合为所述目标域图像集合中除去低置信度样本集合后剩余图像构成的集合,具体的,目标域样本图像经过第二图像分类模型输出的最大概率大于阈值 ; 为用于跨域对齐特征的损失 的权重系数, 为低置信度样本中的KLD正则项 的
权重系数。
[0088] 其中, 可以是常见的其它域适应方法计算的损失(例如,领域差异度量损失MMD、领域对抗损失等),本发明不做具体限定。
[0089] 交叉熵损失 也是常规的损失,形式为:
[0090]
[0091] 其中, 表示给有标注图像 经第一图像分类模型中分类器输出类别为k的概率, 为有标注图像 的类别标签, 表示类别数目,表示为:
, 为分类器的温度参数(例如,设置 )。
[0092] 高置信度样本中的KLD正则项 ,以及使用FixMatch后高置信度样本的交叉熵损失 通过带有正则项的FixMatch模型进行计算,计算过程如图5所示,包括:
[0093] 定义 和 分别表示来自高置信度样本集合 的单个高置信度样本图像 的两个不同的增强视图图像(前者为弱增强视图图像,后者为强增强视图图像);
输入至第二图像分类模型(图5上半部分),通过特征提取与分类,获得第二分类结果,并构造伪标签 ; 输入至第一图像分类模型(图5下半部分),通过特征提取与分类,获得第一分类结果,利用所述第一分类结果计算高置信度样本中的KLD正则项 ,以及,利用所述第一分类结果与对应的伪标签计算使用FixMatch后高置信度样本的交叉熵损失 。
[0094] 高置信度样本中的KLD正则项 ,以及使用FixMatch后高置信度样本的交叉熵损失 的计算公式表示为:
[0095]
[0096]
[0097] 其中,为指示函数, 表示类别数目, 表示即强增强视图图像经第一图像分类模型中分类器输出的类别为j的概率, 表示即强
增强视图图像 经第一图像分类模型中分类器输出的类别为 的概率,伪标签
为第二分类结果中最大概率对应的类别标签, 表示第二图
像分类模型预测的最大概率 大于阈值 。
[0098] 三、总损失函数。
[0099] 本发明实施例中,综合前述对比损失与基础损失构造总的损失函数,表示为:
[0100]
[0101] 其中, 为基础损失, 为对比损失, 为对比损失的权重系数,为数学期望符号; 为源域图像集合 与低置信度样本集合 的并集, 为
中的单个图像。
[0102] 基于上述方案,下面提供一个整体的训练与测试流程介绍,主要步骤包括:
[0103] 步骤1、准备源域标注好的训练数据集和目标域的训练集、测试。对于源域和目标域的训练集图像,在线构造两种增强:强增强和弱增强,强增强采用随机数据增强方法(RandAugment),弱增强采用普通的随机裁剪和水机水平翻转。经过图像处理之后,图像的大小都被缩放到224×224,然后进行数值归一化处理。强增强和弱增强获得的图像也即前文提到的两个不同的增强视图图像,具体的,第一增强视图图像与第三增强视图图像使用强增强方式构造,第二增强视图图像与第四增强视图图像使用弱增强方式构造。
[0104] 步骤2、使用Pytorch深度学习框架,建立基于低置信度样本的对比学习方法。模型由老师模型和学生模型构成,二者具有相同的结构和初始化参数,学生模型通过梯度反传更新,而老师模型是学生模型参数的指数滑动平均。模型结构采用常见的图像分类模型,如ResNet34,ResNet50等,这里将分类模型的分类器改成采用基于余弦相似度的计算方式。对比学习的过程当中,使用额外的记忆库保存处理过的目标域低置信度样本所生成的特征,记忆库的容量为512,采用先进先出的更新方式,在每批次样本迭代结束之后都要进行更新。
[0105] 步骤3、输入源域图像到学生模型,输出预测概率,使用源域的标注信息进行有监督的学习,并且使用源域和目标域的训练数据进行对齐损失 的计算。
[0106] 步骤4、对于目标域图像,输入弱增强的图像到老师模型,输入强增强的图像到学生模型。输出预测概率,根据给定的阈值 ,确定该输入样本是否是高置信度样本(老师模型预测的最大概率大于阈值),如果是高置信度样本,利用FixMatch的学习方式,弱增强图像生成伪标签,监督强增强图像。
[0107] 步骤5、对于低置信度样本的不同增强图像,将它和随机采样的源域图像的增强版本进行混合,混合之后的不同增强样本分别输入学生模型和老师模型,输出的特征中间特征经过重新表示模块,生成新的特征,将第一重新表示特征作为查询特征,结合混合重新表示特征、第二增强视图图像对应的重新表示特征、第四增强视图图像对应的重新表示特征构造对比学习的正样本对,具体的,正样本对包括 与 、 与 、 与 ,然后用记忆库中存储的特征 作为负样本构造对比学习损失,更新学生模型。
[0108] 步骤6、利用目标域的低置信度样本的重新表示之后的特征,更新记忆库。
[0109] 步骤7、对上述步骤3与步骤5的损失函数进行累加,通过反向传播算法以及梯度下降策略,使得损失函数最小化,更新学生模型的权重,并通过学生模型的参数更新老师模型的参数。
[0110] 步骤8、输入测试数据集,计算学生模型分类的准确度。
[0111] 实施例二
[0112] 本发明还提供一种基于低置信度样本对比损失的域适应学习系统,其主要基于前述实施例一提供的方法实现,如图6所示,该系统主要包括:
[0113] 低置信度样本集合生成单元,用于根据设定阈值从目标域图像集合中筛选出低置信度样本集合;
[0114] 增强视图图像生成单元,用于对于每一低置信度样本图像,使用数据增强方法获得两个不同的增强视图图像,称为第一增强视图图像与第二增强视图图像,并在源域图像集合中随机选择源域样本图像,使用数据增强方法获得两个不同的增强视图图像,称为第三增强视图图像与第四增强视图图像;
[0115] 重新表示特征获取单元,用于将所述第一增强视图图像与第三增强视图图像混合后作为查询图像,将所述查询图像输入至第一图像分类模型中,通过所述第一图像分类模型进行图像特征提取并进行重新表示获得第一重新表示特征;将所述第二增强视图图像与第四增强视图图像输入至第二图像分类模型中,通过所述第二图像分类模型分别进行图像特征提取并进行重新表示获得对应的重新表示特征;将所述第一重新表示特征与所述第四增强视图图像对应的重新表示特征混合构成混合重新表示特征;
[0116] 总损失函数构造与模型训练单元,用于将所述第一重新表示特征作为查询特征,其余所有重新表示特征作为对比特征,利用查询特征与各个对比特征的差异构造对比损失,并结合所述第一图像分类模型的基础损失构造总损失函数对所述第一图像分类模型进行训练。
[0117] 所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,仅以上述各功能模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将系统的内部结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。
[0118] 实施例三
[0119] 本发明还提供一种处理设备,如图7所示,其主要包括:一个或多个处理器;存储器,用于存储一个或多个程序;其中,当所述一个或多个程序被所述一个或多个处理器执行时,使得所述一个或多个处理器实现前述实施例提供的方法。
[0120] 进一步的,所述处理设备还包括至少一个输入设备与至少一个输出设备;在所述处理设备中,处理器、存储器、输入设备、输出设备之间通过总线连接。
[0121] 本发明实施例中,所述存储器、输入设备与输出设备的具体类型不做限定;例如:
[0122] 输入设备可以为触摸屏、图像采集设备、物理按键或者鼠标等;
[0123] 输出设备可以为显示终端;
[0124] 存储器可以为随机存取存储器(Random Access Memory,RAM),也可为非不稳定的存储器(non‑volatile memory),例如磁盘存储器。
[0125] 实施例四
[0126] 本发明还提供一种可读存储介质,存储有计算机程序,当计算机程序被处理器执行时实现前述实施例提供的方法。
[0127] 本发明实施例中可读存储介质作为计算机可读存储介质,可以设置于前述处理设备中,例如,作为处理设备中的存储器。此外,所述可读存储介质也可以是U盘、移动硬盘、只读存储器(Read‑Only Memory,ROM)、磁碟或者光盘等各种可以存储程序代码的介质。
[0128] 以上所述,仅为本发明较佳的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明披露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应该以权利要求书的保护范围为准。