一种基于元学习的对抗网络的零样本图像分类方法转让专利

申请号 : CN202011147848.9

文献号 : CN112364894B

文献日 :

基本信息:

PDF:

法律信息:

相似专利:

发明人 : 冀中崔碧莹

申请人 : 天津大学

摘要 :

本发明属于图像分类的技术领域,具体涉及一种基于元学习的对抗网络的零样本图像分类方法,将元学习的训练方式用于零样本分类任务中,通过将视觉特征和语义特征先后输入网络,在训练阶段模拟了对零样本图像分类的学习任务,不仅完成了视觉特征的生成过程,而且保证了不同分类器的对齐关系,同时每个episode的任务获得的知识得到充分利用,使语义分类器在视觉分类器的监督下更好地训练出来,从而合成更趋近于真实分布的视觉特征和语义特征,设计出适合于现实情况的零样本图像分类技术。本发明能够使广义零样本图像分类能力更加突出,提高模型的泛化能力,缓解零样本学习普遍存在的领域偏移问题。

权利要求 :

1.一种基于元学习的对抗网络的零样本图像分类方法,其特征在于,包括如下步骤:

1)从可见类中随机选择M个类别作为一个episode的训练集,可见类中剩余类别作为这个episode的测试类,从而训练集 可知 其中ntr为每个episode中训练集样本的数目,xi为第i个训练样本的视觉特征,yi为第i个训练样本的相应类别标签,ai∈Atr为第i个训练样本的类别语义原型,同时ate∈Ate为一个episode中测试类的语义原型,定义两个记忆模块m1、m2;

2)将训练样本的视觉特征xi随机选取设定批量的数据x,输入到一个由编码器E1和解码器D1组成的变分自编码器中,使生成和真实视觉样本相似的伪视觉特征 重构约束如下:其中, 为2范数表示;

3)经过变分自编码器后,计算变分自编码器损失函数LVAE;

4)将生成的伪视觉特征经过一个降维矩阵W后输入到softmax分类器中,获得一个one‑hot分类结果表示每个类别的概率,根据其真实的标签计算分类损失如下:其中f表示softmax分类器,W为分类器参数,作用是将生成特征的维度降到M维再和真实标签y做对比,把W定义为视觉模态的分类器;

5)将训练样本的视觉特征x、生成的伪视觉特征 输入到判别器D中,对抗损失为LD:

6)计算这个episode视觉模态训练过程的蒸馏损失Lkd‑w和Lkd‑v;

7)设定目标函数为上述损失函数相加,多次迭代训练视觉模态的变分自编码器:其中,λ1、λ2为特征重构损失和变分自编码器损失的权重系数,将训练好的变分自编码器的编码器E1和解码器D1的参数分别存储到两个记忆模块中;

8)将训练类中的类别语义原型atr作为一个自编码器的输入,生成相应的视觉原型同时把 定义为语义模态的分类器,利用 对重构特征进行分类,计算分类损失Lcls2:;

9)用视觉模态的分类器W约束语义模态的分类器 从而得到视觉对语义的蒸馏约束,计算蒸馏损失Lkd2;

10)训练语义模态的自编码器的目标函数如下:

La=Lcls2+λ3Lsup+λ4Lkd2

其中,Lsup为视觉模态的解码器对语义模态的解码器的监督,λ3和λ4分别为监督损失和蒸馏损失的权重系数;

11)该episode的测试过程:将测试集的语义原型ate输入到训练好的编码器E2和解码器D2中,获得相应的视觉原型

12)将 和 拼接在一起,得到所有可见类的分类器 此时再用分类器CS对所有可见类样本进行分类,计算分类损失,对之前学习到的参数进行微调:;

13)将可见类和未见类的测试样本的语义特征at输入到语义编码器和解码器中,将生成的视觉特征原型和xt做对比,其中,xt为测试样本的视觉特征,利用最近邻的方法得到分类结果;

14)重复步骤1)~步骤13),完成多个episode的元训练过程,直到得到最优的分类性能。

2.如权利要求1所述的一种基于元学习的对抗网络的零样本图像分类方法,其特征在于,所述步骤2)的生成伪视觉特征 以及步骤3)计算LVAE的工作过程包括:(2.1)将训练样本的视觉特征xi随机选取设定批量的数据x,输入到编码器E1中,得到潜在变量z,z的概率分布如下表示:p(z|x)=N(μ,Σ)

其中,p(z|x)表示潜在变量z的分布,μ、Σ分别表示潜在变量z的均值和方差,N表示正态分布;

(2.2)将z输入到解码器D1中,生成伪视觉特征其中,w1、v1分别为编码器E1和解码器D1的参数;

(2.3)计算变分自编码器损失函数LVAE:

其中,LVAE表示变分自编码器损失函数, 表示在潜在变量z的分布上计算期望,p(x|z)表示通过潜在变量z生成视觉特征的分布,q(z|x)表示潜在变量z的条件分布,p(z)表示潜在变量z的先验分布,设定为正态分布,log为取对数运算,DKL为KL散度计算。

3.如权利要求1所述的一种基于元学习的对抗网络的零样本图像分类方法,其特征在于,所述步骤6)的计算蒸馏损失Lkd‑w和Lkd‑v的工作过程包括:利用记忆模块中存储的编码器E1和解码器D1参数计算蒸馏损失:其中,w1‑before和v1‑before分别表示在两个记忆模块中前一个episode存储的编码器E1的参数和解码器D1的参数,当episode=1时,w1‑before=v1‑before=0。

4.如权利要求1所述的一种基于元学习的对抗网络的零样本图像分类方法,其特征在于,所述步骤8)的生成视觉原型 的工作过程包括:(4.1)训练类中的类别语义原型atr作为编码器E2的输入,将atr映射到和z同维度的隐藏空间中,得到za:za=E2(atr,w2)

其中,w2为编码器E2的参数;

(4.2)将za输入到解码器D2中,生成相应的视觉原型 且 与真实视觉特征xi维度相同:其中,v2为解码器D2的参数。

5.如权利要求1所述的一种基于元学习的对抗网络的零样本图像分类方法,其特征在于,所述步骤10)的计算Lsup的工作过程包括:其中,v1、v2分别为解码器D1和D2的参数,用2范数的算法使语义模态的解码器和视觉模态的解码器相近,从而使得生成的视觉原型更接近真实的视觉原型。

说明书 :

一种基于元学习的对抗网络的零样本图像分类方法

技术领域

[0001] 本发明属于图像分类的技术领域,具体涉及一种基于元学习的对抗网络的零样本图像分类方法。

背景技术

[0002] 近年来,机器学习在自然语言处理、计算机视觉、语音识别等领域都得到了广泛应用,而在计算机视觉领域,图像分类任务是最受关注且应用最广的任务之一,各种分类技术层出不穷,性能不断提升。在机器学习任务中,通过大量人工标注的图像而实现分类的监督学习方法是图像分类的传统方法,在现实生活中得到了很好的应用。然而,实际中为每个类别的图像收集足够的样本并且进行标注并不容易,会消耗大量的劳动力。不难理解,自然界的物种分布呈现长尾效应,只有少数类别的物种具有足够的图像样本可供监督学习训练分类模型,而很多类别的物种样本少而标签标注困难,这就使监督学习带来巨大挑战。因此,为解决样本标签缺失的问题,零样本学习应运而生。
[0003] 零样本图像分类是零样本学习的一个重要方向,用来解决图像标注困难的分类问题,在传统的零样本图像分类设定中,利用可见类图像样本及其标签训练模型,利用未见类图像样本测试模型,可见这种设定下测试图像的类别于训练图像的类别不相交;而在广义的零样本图像分类设定中,测试图像样本既包括可见类的图像又包括未见类的图像。本专利指的零样本学习包括如上两种设定情况。目前零样本图像分类的主要研究方法可大致分为两种:一是基于映射的方法,通过视觉特征空间和语义特征空间之间的映射或者二者到公共空间的映射来对其视觉特征和语义特征,从而获得较好的分类结果;二是基于生成的方法,利用生成对抗网络、变分自编码器等生成模型来生成测试样本的伪特征,通过比较生成的伪特征与真实特征之间的相似度来确定所属类别。
[0004] 为了完成对测试样本类别的预测,零样本图像分类技术通过利用可见类和未见类的语义信息以达到知识迁移的作用。实验设置如下所示:在训练阶段,给定可见类的带标签样本 其中n为可见类的样本数目, 为第i个样本的视觉特征,表示其相应的类别标签,此外, 表示其对应的类级语义原型。传统的零样本图像分类是给定未见类的语义特征AU,将测试样本xt分到未见类YU中,且 广义的零样
本图像分类是根据可见类和未见类的语义特征,将测试样本xt分到可见类和未见类中。总之,零样本图像分类就是利用可见类样本的相关特征训练模型,利用这个模型预测测试样本的类别标签yt。
[0005] 通过学习视觉空间和语义空间之间的简单映射关系会导致特征表征的不完整,同时会产生低维枢纽点问题。通过学习从高维视觉空间到低维语义空间的简单映射会引发高维中不同类的样本压缩到低维中同一类语义的枢纽点现象,而从低维空间到高维空间的简单映射同样会产生类似的问题。近年来,生成对抗网络获得了科研人员的关注,将其与零样本学习结合起来,通过生成大量伪特征提高分类的准确度。但是生成对抗网络的本质缺点就是训练过程不稳定,容易引发模式崩溃的问题。还有一种基于生成的方法引入了变分自编码器(VAE),以语义信息为条件输入VAE生成伪视觉特征。但VAE由于变分下界的引入使生成的视觉特征容易失真。

发明内容

[0006] 本发明的目的在于:针对现有技术的不足,提供一种基于元学习的对抗网络的零样本图像分类方法,能够提高零样本图像分类准确率。
[0007] 为了实现上述目的,本发明采用如下技术方案:
[0008] 一种基于元学习的对抗网络的零样本图像分类方法,包括如下步骤:
[0009] 1)从可见类中随机选择M个类别作为一个episode的训练集,可见类中剩余类别作为这个episode的测试类,从而训练集 可知 其中ntr为每个episode中训练集样本的数目,xi为第i个训练样本的视觉特征,yi为第i个训练样本的相应类别标签,ai∈Atr为第i个训练样本的类别语义原型,同时ate∈Ate为一个episode中测试类的语义原型,定义两个记忆模块m1、m2;
[0010] 2)将训练样本的视觉特征xi随机选取设定批量的数据x,输入到一个由编码器E1和解码器D1组成的变分自编码器中,使生成和真实视觉样本相似的伪视觉特征 重构约束如下:
[0011]
[0012] 其中, 为2范数表示;
[0013] 3)经过变分自编码器后,计算变分自编码器损失函数LVAE;
[0014] 4)将生成的伪视觉特征经过一个降维矩阵W后输入到softmax分类器中,获得一个one‑hot分类结果表示每个类别的概率,根据其真实的标签计算分类损失如下:
[0015]
[0016] 其中f表示softmax分类器,W为分类器参数,作用是将生成特征的维度降到M维再和真实标签y做对比,把W定义为视觉模态的分类器;
[0017] 5)将训练样本的视觉特征x、生成的伪视觉特征 输入到判别器D中,对抗损失为LD:
[0018] ;
[0019] 6)计算这个episode视觉模态训练过程的蒸馏损失Lkd‑w和Lkd‑v;
[0020] 7)设定目标函数为上述损失函数相加,多次迭代训练视觉模态的变分自编码器:
[0021]
[0022] 其中,λ1、λ2为特征重构损失和变分自编码器损失的权重系数,将训练好的变分自编码器的编码器E1和解码器D1的参数分别存储到两个记忆模块中;
[0023] 8)将训练类中的类别语义原型atr作为一个自编码器的输入,生成相应的视觉原型同时把 定义为语义模态的分类器,利用 对重构特征进行分类,计算分类损失Lcls2:
[0024] ;
[0025] 9)用视觉模态的分类器W约束语义模态的分类器 从而得到视觉对语义的蒸馏约束,计算蒸馏损失Lkd2
[0026] ;
[0027] 10)训练语义模态的自编码器的目标函数如下:
[0028] La=Lcls2+λ3Lsup+λ4Lkd2
[0029] 其中,Lsup为视觉模态的解码器对语义模态的解码器的监督,λ3和λ4分别为监督损失和蒸馏损失的权重系数;
[0030] 11)该episode的测试过程:将测试集的语义原型ate输入到训练好的编码器E2和解码器D2中,获得相应的视觉原型
[0031] 12)将 和 拼接在一起,得到所有可见类的分类器 此时再用分类器CS对所有可见类样本进行分类,计算分类损失,对之前学习到的参数进行微调:
[0032] ;
[0033] 13)将可见类和未见类的测试样本的语义特征at输入到语义编码器和解码器中,将生成的视觉特征原型和xt做对比,利用最近邻的方法得到分类结果;
[0034] 14)重复步骤1)~步骤13),完成多个episode的元训练过程,直到得到最优的分类性能。
[0035] 作为本发明所述的一种基于元学习的对抗网络的零样本图像分类方法的一种改进,所述步骤2)的生成伪视觉特征 、以及步骤3)计算LVAE的工作过程包括:
[0036] (2.1)将训练样本的视觉特征xi随机选取设定批量的数据x,输入到编码器E1中,得到潜在变量z,z的概率分布如下表示:
[0037] p(z|x)=N(μ,Σ)
[0038] 其中,p(z|x)表示潜在变量z的分布,μ、Σ分别表示潜在变量z的均值和方差,N表示正态分布;
[0039] (2.2)将z输入到解码器D1中,生成伪视觉特征
[0040]
[0041] 其中,w1、v1分别为编码器E1和解码器D1的参数;
[0042] (2.3)计算变分自编码器损失函数LVAE:
[0043]
[0044] 其中,LVAE表示变分自编码器损失函数, 表示在潜在变量z的分布上计算期望,p(x|z)表示通过潜在变量z生成视觉特征的分布,q(z|x)表示潜在变量z的条件分布,p(z)表示潜在变量z的先验分布,设定为正态分布,log为取对数运算,DKL为KL散度计算。
[0045] 作为本发明所述的一种基于元学习的对抗网络的零样本图像分类方法的一种改进,所述步骤6)的计算蒸馏损失Lkd‑w和Lkd‑v的工作过程包括:
[0046] 利用记忆模块中存储的编码器E1和解码器D1参数计算蒸馏损失:
[0047]
[0048]
[0049] 其中,w1‑before和v1‑before分别表示在两个记忆模块中前一个episode存储的编码器E1的参数和解码器D1的参数,当episode=1时,w1‑before=v1‑before=0。
[0050] 作为本发明所述的一种基于元学习的对抗网络的零样本图像分类方法的一种改进,所述步骤8)的生成视觉原型 的工作过程包括:
[0051] (4.1)训练类中的类别语义原型atr作为编码器E2的输入,将atr映射到和z同维度的隐藏空间中,得到za:
[0052] za=E2(atr,w2)
[0053] 其中,w2为编码器E2的参数;
[0054] (4.2)将za输入到解码器D2中,生成相应的视觉原型 且 与真实视觉特征xi维度相同:
[0055]
[0056] 其中,v2为解码器D2的参数。
[0057] 作为本发明所述的一种基于元学习的对抗网络的零样本图像分类方法的一种改进,所述步骤10)的计算Lsup的工作过程包括:
[0058]
[0059] 其中,v1、v2分别为解码器D1和D2的参数,用2范数的算法使语义模态的解码器和视觉模态的解码器相近,从而使得生成的视觉原型更接近真实的视觉原型。
[0060] 本发明的有益效果在于,本发明利用双路生成网络的方法完成一个episode的元训练过程,使语义分类器学习视觉分类器,利用生成器和判别器的对抗以及前后episode间特征的知识蒸馏,更直观高效地提升零样本学习的性能。将元学习的训练方式用于零样本分类任务中,通过将视觉特征和语义特征先后输入网络,在训练阶段模拟了对零样本图像分类的学习任务,不仅完成了视觉特征的生成过程,而且保证了不同分类器的对齐关系,同时每个episode的任务获得的知识得到充分利用,使语义分类器在视觉分类器的监督下更好地训练出来,从而合成更趋近于真实分布的视觉特征和语义特征,设计出适合于现实情况的零样本图像分类技术。由此,本发明能够使广义零样本图像分类能力更加突出,提高模型的泛化能力,缓解零样本学习普遍存在的领域偏移问题,从而可以实现在更加真实的场景中的分类任务,有利于推动零样本学习应用于生产生活实际,加速深度学习算法向实用发展。

附图说明

[0061] 下面将参考附图来描述本发明示例性实施方式的特征、优点和技术效果。
[0062] 图1为本发明中元学习的结构示意图。

具体实施方式

[0063] 如在说明书及权利要求当中使用了某些词汇来指称特定组件。本领域技术人员应可理解,硬件制造商可能会用不同名词来称呼同一个组件。本说明书及权利要求并不以名称的差异来作为区分组件的方式,而是以组件在功能上的差异来作为区分的准则。如在通篇说明书及权利要求当中所提及的“包含”为一开放式用语,故应解释成“包含但不限定于”。“大致”是指在可接受的误差范围内,本领域技术人员能够在一定误差范围内解决技术问题,基本达到技术效果。
[0064] 此外,术语“第一”、“第二”等仅用于描述目的,而不能理解为指示或暗示相对重要性。
[0065] 在发明中,除非另有明确的规定和限定,术语“安装”、“相连”、“连接”、“固定”等术语应做广义理解,例如,可以是固定连接,也可以是可拆卸连接,或一体地连接;可以是机械连接,也可以是电连接;可以是直接相连,也可以通过中间媒介间接相连,可以是两个元件内部的连通。对于本领域的普通技术人员而言,可以根据具体情况理解上述术语在本发明中的具体含义。
[0066] 以下结合附图1对本发明作进一步详细说明,但不作为对本发明的限定。
[0067] 本发明的一种基于元学习的对抗网络的零样本图像分类方法,其基本思想是用每个子任务来模仿整个广义零样本图像分类的过程,并且在任务间采用知识蒸馏的方法来增强模型的记忆和泛化能力。假设在每个episode任务中,从所有可见类中随机选取若干个类别作为每个任务中的可见类,用来模拟广义零样本学习,利用变分自编码器学习到一个视觉分类器之后,利用该视觉分类器对语义进行引导,学习到一个语义分类器。在每个episode的学习过程中,将相关的参数存储到记忆模块中,用来监督下一个episode相关参数的学习,从而达到知识蒸馏的作用。同时,视觉分类器对语义分类器的监督也可以看作是知识蒸馏的操作。在每个episode训练后的测试过程中,都使用最近邻来对测试样本进行分类,实现零样本图像分类技术。
[0068] 在零样本图像分类中,目前普遍的训练方式是用可见类多次迭代的单轮训练模型,然后再预测测试样本的类别,这里测试样本包括可见类样本也包括未见类样本。近年来,元学习在少样本学习上应用广泛,并且得到了很好的性能。元学习的训练方式中,基于集(episode)的元训练方式应用非常广泛。这种训练方式是在训练过程中,每一个episode都是利用不同的训练数据对模型进行更新,这样就会充分利用以往的知识经验来指导新任务的学习。
[0069] 本发明的一种基于元学习的对抗网络的零样本图像分类方法,首先将图像数据集分为可见类和未见类,再从可见类中随机选择M个类别作为一个episode的训练集,可见类中剩余类别作为这个episode的测试类。给定训练集 可知 其中ntr为每个episode中训练集样本的数目,xi为第i个训练样本的视觉特征,yi为第i个训练样本的相应类别标签,ai∈Atr为第i个训练样本的类别语义原型,同时ate∈Ate为每个episode中测试类的语义原型。给定xt为测试样本的视觉特征,at为测试样本的类别语义特征。如图1所示,进行如下步骤:
[0070] 1)从可见类中随机选择M个类别作为一个episode的训练集,可见类中剩余类别作为这个episode的测试类。分别初始化视觉模态变分自编码器中的编码器E1和解码器D1、语义模态自编码器中的编码器E2和解码器D2以及判别器D的参数w1、v1、w2、v2和r,定义用来存储参数w1、v1的两个记忆模块为m1、m2;
[0071] 2)在这个episode中,将训练样本的视觉特征xi随机选取设定批量的数据x,作为编码器E1的输入;
[0072] 3)根据如下生成伪视觉特征公式,得到生成的伪视觉特征
[0073]
[0074] 其中,编码器E1的输出为潜在变量,用z表示,z的概率分布如下表示:
[0075] p(z|x)=N(μ,Σ)                             (2)
[0076] 其中,p(z|x)表示潜在变量z的分布,μ、Σ分别表示潜在变量z的均值和方差,N表示正态分布;
[0077] 4)经过变分自编码器后,希望生成的伪视觉特征接近真实的特征,分别计算特征重构损失函数和变分自编码器损失函数:
[0078]
[0079]
[0080] 其中,Lrec1表示重构损失函数, 为2范数表示,LVAE表示变分自编码器损失函数,EPE(zx)表示在潜在变量z的分布上计算期望,p(x|z)表示通过潜在变量z生成视觉特征的分布,q(z|x)表示潜在变量z的条件分布,p(z)表示潜在变量z的先验分布,设定为正态分布,log为取对数运算,DKL为KL散度计算;
[0081] 5)将生成的伪视觉特征经过一个降维矩阵W后输入到softmax分类器中,获得一个one‑hot分类结果表示每个类别的概率,根据其真实的标签计算分类损失如下:
[0082]
[0083] 其中f表示softmax分类器,W为分类器参数,作用是将生成特征的维度降到M维再和真实标签y做对比。这里把W定义为视觉模态的分类器。
[0084] 6)将训练样本的视觉特征x、生成的伪视觉特征 输入到判别器D中,用对抗损失函数公式训练判别器D,保留使其性能最好的参数r,对抗损失函数公式如下:
[0085]
[0086] 其中,LD为判别器D的对抗损失函数,Ex为在训练样本的视觉特征x的分布上计算期望, 为在生成的伪视觉特征 的分布上计算期望;
[0087] 7)计算这个episode的蒸馏损失如下:
[0088]
[0089]
[0090] 其中,w1‑before和v1‑before分别表示在两个记忆模块中前一个episode存储的编码器E1的参数和解码器D1的参数,当episode=1时,w1‑before=v1‑before=0;
[0091] 8)将公式(3)~(8)的损失函数相加,训练视觉变分自编码器中的E1和D1,更新记忆模块;
[0092]
[0093] 其中,λ1、λ2为特征重构损失和变分自编码器损失的权重系数。
[0094] 9)在这个episode中,再将训练类中的类别语义原型atr作为一个自编码器的输入,其中编码器E2将类别语义原型映射到和z同维度的隐藏空间中,再通过解码器D2将隐藏空间的特征重构到视觉空间当中,生成相应的视觉原型,这里解码器用D1监督约束:
[0095]
[0096]
[0097] 其中, 是由类别语义原型生成的视觉原型,将其定义为语义模态的分类器,Lsup表示解码器D1对D2的2范数约束;
[0098] 10)同时,视觉原型特征同样需要用降维矩阵W约束,即视觉模态的分类器约束语义模态的分类器,从而得到视觉对语义的蒸馏约束,蒸馏损失Lkd2如下:
[0099]
[0100] 11)利用 对特征进行分类,计算分类损失:
[0101]
[0102] 12)将公式(11)~(13)中的损失函数相加,训练编码器E2和解码器D2:
[0103] La=Lcls2+λ3Lsup+λ4Lkd2                              (14)
[0104] 其中,λ3和λ4分别为监督损失和蒸馏损失的权重系数;
[0105] 13)将此episode的测试集的语义原型ate输入到训练好的编码器E2和解码器D2中,获得相应的视觉原型:
[0106]
[0107] 14)利用得到的 和 拼接在一起,得到所有可见类的分类器 此时再用分类器CS对所有可见类样本进行分类,计算分类损失,对参数w1、v1、w2、v2和r进行微调:
[0108]
[0109] 15)将可见类和未见类的测试样本的语义特征at输入到语义编码器和解码器中,将生成的视觉特征原型和xt做对比,利用最近邻的方法得到分类结果。
[0110] 16)重复步骤1)~15),完成多个episode的元训练过程,直到得到最优的分类性能。
[0111] 根据上述说明书的揭示和教导,本发明所属领域的技术人员还能够对上述实施方式进行变更和修改。因此,本发明并不局限于上述的具体实施方式,凡是本领域技术人员在本发明的基础上所作出的任何显而易见的改进、替换或变型均属于本发明的保护范围。此外,尽管本说明书中使用了一些特定的术语,但这些术语只是为了方便说明,并不对本发明构成任何限制。