DeFormer:分解预先训练好的Transformers以提高问题回答速度
-
DeFormer: Decomposing Pre-trained Transformers for Faster Question Answering
DeFormer:分解预先训练好的Transformers以提高问题回答速度
Abstract
BERTQA中应用了大量的自注意力,导致了模型训练的速度很慢并且占用大量内存。本文提出一个Deformer模型用于分解transformer,具体来说用较低层的question-wide和passage-wide的self-attention(分别计算self-attention)替代question-passage的全局注意力,由于Deformer与原始模型很相似,因此用原始的transformer的预训练权重初始化Deformer。速度提升了4倍,通过简单的知识蒸馏,准确度仅下降1%。
1 Introduction
基于Transformer的模型中,大部分的计算开销都是来自每一层的自注意力计算。在MRC式的QA中,算力主要是消耗在问题和passage的自注意力计算,虽然自注意力有助于模型创建高效的问题上下文表示,但是构建context表示需要更多的时间,因为context的长度总是比question长的多,如果context可以独立于问题进行处理,那么最难计算的context表示就减少了一部分和question的attention计算,可以加速QA过程。有研究表明:transformer较低层编码倾向于关注一些局部特征,如词形、语法等;较高层才逐渐编码与下游任务相关的全局语义信息(远距离信息)。也就是说:在较低层passage编码对question的依赖不高,因此本文采用的在较低层对question和context分别编码,在较高层联合处理(形成question-context联合表征进行交互编码),如图1所示:
假设n层模型中的前K个较低层独立的处理question和context,Deformer将两个第K层的表示作为输入馈送到第K+1层,这种方法很显然能减少运算量和内存。Deformer的上层应该生成与transformer相应层相同类型信息的表示,因此本文增加了两个蒸馏式损失,目的是用于最小化分解模型和原始模型之间的高层表征和分类层logits。
在三个QA数据集上进行评估模型,分别基于BERT和XLNet。速度提升了2.7 to 3.4倍,内存减少了65.8% to 72.9%,性能减少了0.6 to 1.8。BERT-large比BERT-base速度更快,精确度更高。
2 The Approach
基于transformer模型的MRC框架是计算question-context上的self-attention。这种方式产生了输入对的高效表示,因为从文本中提取什么信息通常取决于问题。想要降低复杂性,可以牺牲一些代表性能力来换取脱机处理文本的能力(脱机文件:一般指保存的网页,即不联网也能浏览网页的内容)。本文也测量了文本表示在与不同问题配对时的变化(计算了上下文与不同问题配对时的段落表征方差),得出结论:==在较低层中文本表示的变化不像在较高层中的那么大,这表明在较低层中忽略question-context的注意力计算影响不会太大==。先前的研究也表明:较低层倾向于对局部现象(词性、句法类别等)建模,较高层倾向于对依赖任务(实体共指)的更多语义现象进行建模。
2.1 DeFormer
定义两段文本表示Ta、TbT_a、T_bTa、Tb的配对任务的transformer计算。
TaTaTa嵌入的表示是:A=[a1;a2;…;aq]\mathrm{A}=[a_1;a_2;…;a_q]A=[a1;a2;…;aq]
TbT_bTb嵌入的表示是:B=[b1;b2;…;bp]\mathrm{B}=[b_1;b_2;…;b_p]B=[b1;b2;…;bp]
完整的输入序列X表示为:X=[A;B]\mathrm{X}=[\mathrm{A};\mathrm{B}]X=[A;B]
Transformer如果有n层,第i层表示为LiL_iLi,这些层按顺序转换表示为:
将第i层到第j层的层叠表示为:Li:jL_{i:j}Li:j完整transformer、An\mathrm{A}^nAn、Bn\mathrm{B}^nBn的输出为:
模型图3所示:简单的去除Ta\mathrm{T}_aTa和Tb\mathrm{T}_bTb表示之间的交叉交互,分解较低层计算(到第K层)。分解后的输出为:
基于transformer的问答系统通过一组自我关注层将输入问题和上下文一起处理。因此,将此分解应用于Transformer for QA允许我们独立处理问题和上下文文本,这反过来又允许我们离线计算较低层的上下文文本表示。在较低层的时间复杂度从O((p+q)2)O((p+q)^2)O((p+q)2)到O(q2+c)O(q^2+c)O(q2+c),其中ccc表示加载缓存表示的成本。
2.2 Auxiliary Supervision for DeFormer辅助监督
使用transformer预训练的权重参数训练DeFormer。由于section2.1在较低层将question和passage进行独立编码,因此会丢失一些信息,本文通过对上层微调弥补这一缺点,并且还添加了辅助损失,使DeFormer的预测及其上层表示==更接近于==transformer的预测和相应的层表示。
2.2.1 Knowledge Distillation Loss知识蒸馏损失
将分解transformer的预测分布PAP_APA和transformer预测分布PBP_BPB之间的KL散度最小化:
2.2.2 Layerwise Representation Similarity Loss
通过最小化分解transformer上层的token表示与transformer之间的欧几里得距离使得DeFormer的上层表示更接近于transformer的token表示
vij\mathrm{v}_i^jvij是transformer中第i层的第j个token的表示 uij\mathrm{u}_i^juij是Deformer中对应的表示2.2.3 最终损失
本文将知识蒸馏损失Lkd\mathcal{L_{kd}}Lkd和分层表示相似性损失Llrs\mathcal{L_{lrs}}Llrs与特定任务的监督损失Lts\mathcal{L_{ts}}Lts相加,通过超参数来调整它们的相对重要性,超参数使用贝叶斯优化来调整的,这种方法减少了寻找最优化参数组合所需要的步骤数。
3 Experiment
3.1 Datasets
SQuAD:众包工作者在维基百科上生成的超过10万个问答对
RACE:从英语考试中收集到的阅读理解数据集,旨在评估初中生的阅读和理解能力。超过28K个段落和100K个问题
BooIQ:15942个由是/否的问题组成的,这些问题出现在无提示、不受约束的环境中。
MNLI:是一个由43.3万个句子对组成的众包语料库,用文本蕴涵信息进行标注
QQP:由超过40W个来自Quora的潜在重复问题对组成
3.2 Implementation Details
作者基于原始的Bert和XLNet代码库实现了所有模型(TF1.15)。启用了bfloat16格式的TPU v3-8 node(8核,128 GB内存)上执行所有实验。
对于DeFormer-Bert和DeFormer-XLNet,通过离线计算其中一个输入段的表示并缓存它。对于问答,缓存段落;对于自然语言推理,缓存前提;对于问题相似度,缓存第一个问题
3.3 Result
表1显示了使用9个下层和3个上层时BERT-BASE和DeFormer-BERT-BASE的性能、推理速度和内存需求的主要比较结果
在所有数据集中观察到显著的加速比和显著的内存减少,同时保持了原始模型的大部分有效性,XLNet在同一表格中的结果证明了不同预先训练的transformer架构的分解有效性。
表2显示,在采用成对输入序列的QQP和MNLI数据集上,分解带来了2倍的推理加速和超过一半的内存减少。
3.3.1 Small Distilled or Large Decomposed?小蒸馏还是大蒸馏?
表3比较了BERT-BASE、BERT-LARGE和DeFormer-BERTLarge的性能、速度和内存。
DeFormer-Bert-Large比较小的Bert-Base模型快1.6倍。事实证明,分解较大的模型也比使用较小的基础模型(+2.3点)更有效。这表明,通过分解,大型transformer可以比尺寸为其一半的小型transformer运行得更快,同时也更精确。
3.4 Divergence of DeFormer and original BERT representations
最初的Bert和DeFormer-Bert之间的主要区别是在较低的层没有交叉注意。本文分析了这两个模型在所有层的表示之间的差异。
为此本文从SQuAD开发人员数据集中==随机选择了100个段落==,并<u>随机选择了与每个段落相关的数据集中已经存在的5个不同的问题</u>。对于每一篇文章,我们使用精调的原始BERT-BASE模型和DeFormerBERT-BERT模型对所有5个question-passage对序列进行编码,并==计算它们在每一层的向量表示的距离==。
图5显示了问题和文章在不同层次上的平均距离。
两个模型的pasage和question的低层表征保持相似,但上层表征有显著差异,缺乏交叉注意对低层的影响小于高层得到了证明。此外,使用上层的辅助监督可以强制DeFormer生成更接近原始模型的表示,从而达到预期效果。对于问题表征,这种影响不那么明显。
4 启示
- 损失函数超参数的调整能够减少算力开销
- 可以通过RNN进行替代,先分开编码再联合编码。先用RNN分开编码,再用BERT学习联合特征。
-
呱唧呱唧呱唧呱唧
-
@155-7220 给大佬鼓掌