基于知识蒸馏的信息检索方法转让专利

申请号 : CN202110534072.4

文献号 : CN113312548B

文献日 :

基本信息:

PDF:

法律信息:

相似专利:

发明人 : 鲁伟明朱堂灿庄越挺

申请人 : 浙江大学

摘要 :

本发明公开了一种基于知识蒸馏的信息检索方法。首先,利用训练集T基于交叉熵损失函数来训练教师模型。再利用教师模型,对训练集Told段落重排序得到Tnew。之后,利用训练集T基于交叉熵损失函数和利用训练集Tnew基于列表置换损失函数,将两者损失函数的加权和作为最终损失函数训练学生模型。最后,利用学生模型进行信息检索。本发明方法无需人工干预,检索准确率较高,具有良好的可扩展性。

权利要求 :

1.一种基于知识蒸馏的信息检索方法,其特征在于,包括以下步骤:

1)训练教师模型:基于交叉熵损失函数,利用训练集T来训练教师模型;具体步骤为,训练集T为 其中Qi表示查询,pi和ni为正负例,N为总的查询数量;

首先,选择教师模型为BERT‑CAT模型,则教师模型计算查询Q与段落d之间相关性的评分公式为:

Teacher(Q,d)=BERT‑CAT(Q,d)=BERT([CLS;Q;SEP;d])1*W其中,BERT是一种基于Transformer的双向编码表示语言模型,CLS和SEP表示BERT中的特殊词条,“;”表示拼接操作,下标1表示取CLS词条,W表示一个权重矩阵;

之后,对训练集T中每个查询及其所对应正例和负例的三元组,使用该教师模型计算正例得分Pi以及负例得分Ni:

Pi=Teacher(Qi,pi)Ni=Teacher(Qi,ni)再通过正负例得分计算相应的交叉熵损失:最后通过最小化交叉熵损失来优化教师模型,训练得到最终的教师模型;

2)训练集段落重排序:使用步骤1)训练后的教师模型,对训练集Told中每个查询所对应的段落集进行相关性重排序,得到排序πT,并用重排序后的段落集构建新训练集Tnew;具体步骤为,

利用教师模型对训练集Told进行重排序;

基于步骤1)所训练的教师模型Teacher,对于训练集Told中每个查询Q所对应的一个段落集D={d1,d2,...,dl},使用模型Teacher对所有段落进行相对于查询Q的打分:S=Teacher(Q,D)={s1,s2,...,sl}其中,si=Teacher(Q,di),之后根据每个段落得分的高低对所有段落进行重排序,得到一个新的有序的段落集Dr={dr1,dr2,...,drl},其中sr1>sr2>…>srl,所有查询对应的有序段落集构成新训练集Tnew;

3)训练学生模型:利用训练集T,计算学生模型的交叉熵损失L1;然后,利用学生模型,对训练集Tnew中每个查询所对应的段落集进行相关性重排序,得到排序πS,再利用列表置换损失函数计算πT与πS之间的差异损失L2;最后用L1和L2的加权和作为学生模型的最终损失L,并通过最小化L来训练学生模型;具体步骤为;

首先,选择BERT‑DOT模型和ColBERT模型作为学生模型Student;

BERT‑DOT模型计算查询Q与段落d之间相关性的评分公式为:rq=BERT([CLS;Q])1*Wrd=BERT([CLS;d])1*WBERT‑DOT(Q,d)=rq·rd其中,BERT是一种基于Transformer的双向编码表示语言模型,CLS表示特殊词条,“;”表示拼接操作,下标1表示取CLS词条,W表示一个权重矩阵,·表示内积运算;

ColBERT模型计算查询Q与段落d之间相关性的评分公式为:rq=BERT([CLS;Q;rep(MASK)])1*Wrd=BERT([CLS;d])1*W其中,BERT是一种基于Transformer的双向编码表示语言模型,CLS表示特殊词条,“;”表示拼接操作,rep(MASK)表示多个MASK词条拼接而成的词条集,下标1表示取CLS词条,W表示一个权重矩阵,·表示内积运算;

之后,对训练集T中每个查询及其所对应正例和负例的三元组,使用学生模型计算正例得分Pi以及负例得分Ni:

Pi=Student(Q,pi)Ni=Student(Q,ni)其中,Student代表BERT‑DOT模型和ColBERT模型;

之后通过正负例得分计算相应的交叉熵损失:接着计算重排序序列的列表置换损失函数;

根据步骤2)所得的重排序段落训练集Tnew,对于每个查询Q,有对应的重排序后的段落集Dr={dr1,dr2,...,drl},段落di相对于查询Q使用教师模型所得到的分数si,满足sr1>sr2>…>srl;

使用学生模型重新计算所有段落相对于查询Q的得分,得到一个新的分数列表:S′=Student(Q,Dr)={s′r1,s′r2,...,s′rl}根据该列表,得到查询置换的概率:之后最大化每个查询置换概率的对数似然,即最小化列表置换损失函数:最后,将两部分损失加权求和作为模型的损失:Loss=Loss1+αLoss2其中,α为权重参数;

4)利用学生模型进行信息检索:利用学生模型计算用户查询所对应的段落的评分,将评分最高的段落作为查询答案。

2.根据权利要求1所述的一种基于知识蒸馏的信息检索方法,其特征在于,所述步骤4)具体为:

利用学生模型进行信息检索;

在步骤3)训练得到学生模型后,使用该学生模型对测试集中相应查询所对应的段落集进行重排序,获取排行最高的段落作为查询答案,以此来测试模型的效果;

对于用户给定的问题,在语料库中初步筛选出相应段落,再用学生模型计算段落相对于问题的得分,根据得分的高低将相应用户所需要的答案量的答案提供给用户。

说明书 :

基于知识蒸馏的信息检索方法

技术领域

[0001] 本发明属于信息检索领域,尤其涉及一种基于知识蒸馏的信息检索方法。

背景技术

[0002] 随着互联网的普及和发展,人们可以接触到非常丰富的资源。对于某些人们想要了解的领域以及相关知识,人们可以选择信息检索来获取相关的知识。为了提升检索效率,
优化检索效果,可以利用人工智能技术辅助信息检索,以帮助人们更快更好地获取想要知
道的相关知识。
[0003] 然而,现存的许多模型和方法存在着精度与速度两者不可兼得的问题。精度较高的模型往往有着大量的参数需要进行计算,导致检索延时大幅度提高,而速度较快的检索
模型由于其更加看重速度,其精度也会有着一定的损失。
[0004] 鉴于此,我们基于知识蒸馏(KD,knowledge distillation)的方法,利用精度较高的教师模型进行辅助训练精度稍低但是速度却快许多的学生模型,以期望得到速度变化不
大但精度却提升较多的一个新的学生模型,以此来达到检索效果提升的目的。

发明内容

[0005] 本发明的目的在于提供一种基于知识蒸馏的信息检索方法,从而方便人们更加高效地进行信息检索。
[0006] 本发明解决其技术问题采用的技术方案如下:一种基于知识蒸馏的信息检索方法,包括以下步骤:
[0007] 1)训练教师模型:基于交叉熵损失函数,利用训练集T来训练教师模型。
[0008] 2)训练集段落重排序:使用步骤1)训练后的教师模型,对训练集Told中每个查询所对应的段落集进行相关性重排序,得到排序πT,并用重排序后的段落集构建新训练集Tnew。
[0009] 3)训练学生模型:利用训练集T,计算学生模型的交叉熵损失L1;然后,利用学生模型,对训练集Tnew中每个查询所对应的段落集进行相关性重排序,得到排序πS,再利用列表
置换损失函数计算πT与πS之间的差异损失L2;最后用L1和L2的加权和作为学生模型的最终损
失L,并通过最小化L来训练学生模型。
[0010] 4)利用学生模型进行信息检索:利用学生模型计算用户查询所对应的段落的评分,将评分最高的段落作为查询答案。
[0011] 进一步地,步骤1)所述的教师模型的训练,具体为:
[0012] 训练集T为 其中Qi表示查询,pi和ni为正负例,N为总的查询数量。首先,选择教师模型为BERT‑CAT模型,则教师模型计算查询Q与段落d之间相关性的评
分公式为:
[0013] Teacher(Q,d)=BERT‑CAT(Q,d)=BERT([CLS;Q;SEP;d])1*W
[0014] 其中,BERT是一种基于Transformer的双向编码表示语言模型,CLS和SEP表示BERT中的特殊词条,“;”表示拼接操作,下标1表示取CLS词条,W表示一个权重矩阵。
[0015] 之后,对训练集T中每个查询及其所对应正例和负例的三元组,使用该教师模型计算正例得分Pi以及负例得分Ni:
[0016] Pi=Teacher(Qi,pi)
[0017] Ni=Teacher(Qi,ni)
[0018] 再通过正负例得分计算相应的交叉熵损失:
[0019]
[0020] 最后通过最小化交叉熵损失来优化教师模型,训练得到最终的教师模型。
[0021] 进一步地,步骤2)所述的训练集段落重排序,具体为:
[0022] 利用教师模型对训练集Told进行重排序。
[0023] 基于步骤1)所训练的教师模型Teacher,对于训练集Told中每个查询Q所对应的一个段落集D={d1,d2,...,dl},使用模型Teacher对所有段落进行相对于查询Q的打分:
[0024] S=Teacher(Q,D)={s1,s2,...,sl}
[0025] 其中,si=Teacher(Q,di),之后根据每个段落得分的高低对所有段落进行重排序,得到一个新的有序的段落集Dr={dr1,dr2,...,drl},其中sr1>sr2>…>srl,所有查询对应
的有序段落集构成新训练集Tnew。
[0026] 进一步地,步骤3)所述的学生模型的训练,具体为:
[0027] 首先,选择BERT‑DOT模型和ColBERT模型作为学生模型Student。
[0028] BERT‑DOT模型是BERT‑CAT模型的简化,将拼接操作改成了内积计算,其计算查询Q与段落d之间相关性的评分公式为:
[0029] rq=BERT([CLS;Q])1*W
[0030] rd=BERT([CLS;d])1*W
[0031] BERT‑DOT(Q,d)=ra·rd
[0032] 其中,BERT是一种基于Transformer的双向编码表示语言模型,CLS表示特殊词条,“;”表示拼接操作,下标1表示取CLS词条,W表示一个权重矩阵,·表示内积运算。
[0033] 该BERT‑DOT模型检索效果相比于教师模型会稍差,但计算速度会大幅度提升。
[0034] ColBERT是BERT‑DOT的一种变体,其在顶层多加了一层最大池化的计算,显著提高了检索效果,其计算查询Q与段落d之间相关性的评分公式为:
[0035] rq=BERT([CLS;Q;rep(MASK)])1*W
[0036] rd=BERT([CLS;d])1*W
[0037]
[0038] 其中,BERT是一种基于Transformer的双向编码表示语言模型,CLS表示特殊词条,“;”表示拼接操作,rep(MASK)表示多个MASK词条拼接而成的词条集,下标1表示取CLS词条,
W表示一个权重矩阵,·表示内积运算。
[0039] 该ColBERT模型相对于BERT‑DOT模型能提升检索效果,计算速度与BERT‑DOT模型类似。
[0040] 之后,对训练集T中每个查询及其所对应正例和负例的三元组,使用学生模型计算正例得分Pi以及负例得分Ni:
[0041] Pi=Student(Q,pi)
[0042] Ni=Student(Q,ni)
[0043] 其中,Student代表BERT‑DOT模型和ColBERT模型。
[0044] 之后通过正负例得分计算相应的交叉熵损失:
[0045]
[0046] 接着是计算重排序序列的列表置换损失函数。
[0047] 根据步骤2)所得的重排序段落训练集Tnew,对于每个查询Q,有对应的重排序后的段落集Dr={dr1,dr2,...,drl},段落di相对于查询Q使用教师模型所得到的分数si,满足sr1>
sr2>…>srl。
[0048] 使用学生模型重新计算所有段落相对于查询Q的得分,得到一个新的分数列表:
[0049] S′=Student(Q,Dr)={s′r1,s′r2,...,s′rl}
[0050] 根据该列表,得到查询置换的概率:
[0051]
[0052] 之后最大化每个查询置换概率的对数似然,即最小化列表置换损失函数:
[0053]
[0054] 最小化损失Loss2能够使学生模型计算出的同一个段落集对查询的相关性的排序结果更加接近教师模型,从而提升学生模型检索效果。
[0055] 最后,将两部分损失加权求和作为模型的损失:
[0056] Loss=Loss1+αLoss2
[0057] 其中,α为权重参数。
[0058] 进一步地,所述步骤4)具体为:
[0059] 利用学生模型进行信息检索。
[0060] 在步骤3)训练得到学生模型后,使用该学生模型对测试集中相应查询所对应的段落集进行重排序,获取排行最高的段落作为查询答案,以此来测试模型的效果。
[0061] 对于用户给定的问题,在语料库中初步筛选出相应段落,再用学生模型计算段落相对于问题的得分,根据得分的高低将相应用户所需要的答案量的答案提供给用户。
[0062] 本发明方法与现有技术相比具有的有益效果:
[0063] 1.本方法依靠人工智能方法进行信息检索,减少人工工作,更加系统、科学。
[0064] 2.本方法的流程可以依靠机器学习自动完成,无需人工干预,减轻用户负担。
[0065] 3.本方法在神经网络中引入知识蒸馏方法,可以充分利用教师模型的检索效果,以此优化了学生模型的检索效果。
[0066] 4.本方法预测准确率较高,能够较准确检索出用户想要检索出来的结果。
[0067] 5.本方法具有良好的可扩展性,针对不同领域,可以选用不同领域的检索数据进行训练,在不同领域的检索效果都可以得到响应的提升并且不会造成太多检索延时增加。

附图说明

[0068] 图1是本发明方法总体流程图;
[0069] 图2是本发明实施例提供的学生模型训练过程中知识蒸馏模型结构。

具体实施方式

[0070] 下面结合附图和具体实施例对本发明作进一步详细说明。
[0071] 如图1所示,本发明提供一种基于知识蒸馏的信息检索方法,包括以下步骤:
[0072] 1)训练教师模型:基于交叉熵损失函数,利用训练集T来训练教师模型。
[0073] 2)训练集段落重排序:使用步骤1)训练后的教师模型,对训练集Told中每个查询所对应的段落集进行相关性重排序,得到排序πT,并用重排序后的段落集构建新训练集Tnew。
[0074] 3)训练学生模型:利用训练集T,计算学生模型的交叉熵损失L1;然后,利用学生模型,对训练集Tnew中每个查询所对应的段落集进行相关性重排序,得到排序πS,再利用列表
置换损失函数计算πT与πS之间的差异损失L2;最后用L1和L2的加权和作为学生模型的最终损
失L,并通过最小化L来训练学生模型。
[0075] 4)利用学生模型进行信息检索:利用学生模型计算用户查询所对应的段落的评分,将评分最高的段落作为查询答案。
[0076] 进一步地,步骤1)所述的教师模型的训练,具体为:
[0077] 训练集T为 其中Qi表示查询,pi和ni为正负例,N为总的查询数量。首先,选择教师模型为BERT‑CAT模型,则教师模型计算查询Q与段落d之间相关性的评分
公式为:
[0078] Teacher(Q,d)=BERT‑CAT(Q,d)=BERT([CLS;Q;SEP;d])1*W
[0079] 其中,BERT是一种基于Transformer的双向编码表示语言模型,CLS和SEP表示BERT中的特殊词条,“;”表示拼接操作,下标1表示取CLS词条,W表示一个权重矩阵。
[0080] 之后,对训练集T中每个查询及其所对应正例和负例的三元组,使用该教师模型计算正例得分Pi以及负例得分Ni:
[0081] Pi=Teacher(Qi,pi)
[0082] Ni=Teacher(Qi,ni)
[0083] 再通过正负例得分计算相应的交叉熵损失:
[0084]
[0085] 最后通过最小化交叉熵损失来优化教师模型,训练得到最终的教师模型。
[0086] 进一步地,步骤2)所述的训练集段落重排序,具体为:
[0087] 利用教师模型对训练集Told进行重排序。
[0088] 基于步骤1)所训练的教师模型Teacher,对于训练集Told中每个查询Q所对应的一个段落集D={d1,d2,...,dl},使用模型Teacher对所有段落进行相对于查询Q的打分:
[0089] S=Teacher(Q,D)={s1,s2,...,sl}
[0090] 其中,si=Teacher(Q,di),之后根据每个段落得分的高低对所有段落进行重排序,得到一个新的有序的段落集Dr={dr1,dr2,...,drl},其中sr1>sr2>…>srl,所有查询对应
的有序段落集构成新训练集Tnew。
[0091] 进一步地,步骤3)所述的学生模型的训练,如图2所示,具体为:
[0092] 首先,选择BERT‑DOT模型和ColBERT模型作为学生模型Student。
[0093] BERT‑DOT模型是BERT‑CAT模型的简化,将拼接操作改成了内积计算,其计算查询Q与段落d之间相关性的评分公式为:
[0094] rq=BERT([CLS;Q])1*W
[0095] rd=BERT([CLS;d])1*W
[0096] BERT‑DOT(Q,d)=ra·rd
[0097] 其中,BERT是一种基于Transformer的双向编码表示语言模型,CLS表示特殊词条,“;”表示拼接操作,下标1表示取CLS词条,W表示一个权重矩阵,·表示内积运算。
[0098] 该BERT‑DOT模型检索效果相比于教师模型会稍差,但计算速度会大幅度提升。
[0099] ColBERT是BERT‑DOT的一种变体,其在顶层多加了一层最大池化的计算,显著提高了检索效果,其计算查询Q与段落d之间相关性的评分公式为:
[0100] rq=BERT([CLS;Q;rep(MASK)])1*W
[0101] rd=BERT([CLS;d])1*W
[0102]
[0103] 其中,BERT是一种基于Transformer的双向编码表示语言模型,CLS表示特殊词条,“;”表示拼接操作,rep(MASK)表示多个MASK词条拼接而成的词条集,下标1表示取CLS词条,
W表示一个权重矩阵,·表示内积运算。
[0104] 该ColBERT模型相对于BERT‑DOT模型能提升检索效果,计算速度与BERT‑DOT模型类似。
[0105] 之后,对训练集T中每个查询及其所对应正例和负例的三元组,使用学生模型计算正例得分Pi以及负例得分Ni:
[0106] Pi=Student(Q,pi)
[0107] Ni=Student(Q,ni)
[0108] 其中,Student代表BERT‑DOT模型和ColBERT模型。
[0109] 之后通过正负例得分计算相应的交叉熵损失:
[0110]
[0111] 接着是计算重排序序列的列表置换损失函数。
[0112] 根据步骤2)所得的重排序段落训练集Tnew,对于每个查询Q,有对应的重排序后的段落集Dr={dr1,dr2,...,drl},段落di相对于查询Q使用教师模型所得到的分数si,满足sr1>
sr2>…>srl。
[0113] 使用学生模型重新计算所有段落相对于查询Q的得分,得到一个新的分数列表:
[0114] S′=Student(Q,Dr)={s′r1,s′r2,...,s′rl}
[0115] 根据该列表,得到查询置换的概率:
[0116]
[0117] 之后最大化每个查询置换概率的对数似然,即最小化列表置换损失函数:
[0118]
[0119] 最小化损失Loss2能够使学生模型计算出的同一个段落集对查询的相关性的排序结果更加接近教师模型,从而提升学生模型检索效果。
[0120] 最后,将两部分损失加权求和作为模型的损失:
[0121] Loss=Loss1+αLoss2
[0122] 其中,α为权重参数。
[0123] 进一步地,所述步骤4)具体为:
[0124] 利用学生模型进行信息检索。
[0125] 在步骤3)训练得到学生模型后,使用该学生模型对测试集中相应查询所对应的段落集进行重排序,获取排行最高的段落作为查询答案,以此来测试模型的效果。
[0126] 对于用户给定的问题,在语料库中初步筛选出相应段落,再用学生模型计算段落相对于问题的得分,根据得分的高低将相应用户所需要的答案量的答案提供给用户。
[0127] 实施例
[0128] 下面结合本发明的方法详细说明本实施例实施的具体步骤,如下:
[0129] 在本实施例中,将本发明的方法应用于MS MARCO数据集,对其中的查询的相关段落进行检索。
[0130] 1)训练集包含了640000项数据,其中包含320000的查询以及相对应的320000的正例和320000的负例。
[0131] 2)段落训练集包含800000项数据,其中包含80000的查询以及相对应的80000的正例和720000的负例。
[0132] 3)测试集总共有6669195项数据,其中包含6980个查询,每个查询对应平均1000个段落,大部分查询对应的段落都包含至少一个正例段落。
[0133] 将这些1)和2)的数据集按照本方法进行训练,其中α取值为0.1,在3)的测试集上进行测试,计算每个方法的mrr@10,Recall@50,Recall@200,Recall@1000这四个值,其结果
如表1所示。
[0134] 表1预测结果评估
[0135] 教师模型 mrr@10 Recall@50 Recall@200 Recall@1000BERT‑CAT 0.340 0.747 0.802 0.814
学生模型 mrr@10 Recall@50 Recall@200 Recall@1000
BERT‑DOT 0.239 0.657 0.772 0.814
BERT‑DOT+KD 0.253 0.668 0.776 0.814
ColBERT 0.313 0.726 0.798 0.814
ColBERT+KD 0.323 0.735 0.799 0.814
[0136] 对于知识蒸馏模型,不同α值的mrr@10结果如表2所示。
[0137] 表2模型对比结果(不同α)
[0138]
[0139]
[0140] 以上所述仅是本发明的优选实施方式,虽然本发明已以较佳实施例披露如上,然而并非用以限定本发明。任何熟悉本领域的技术人员,在不脱离本发明技术方案范围情况
下,都可利用上述揭示的方法和技术内容对本发明技术方案做出许多可能的变动和修饰,
或修改为等同变化的等效实施例。因此,凡是未脱离本发明技术方案的内容,依据本发明的
技术实质对以上实施例所做的任何的简单修改、等同变化及修饰,均仍属于本发明技术方
案保护的范围内。