Navigation

    Gpushare.com

    • Register
    • Login
    • Search
    • Popular
    • Categories
    • Recent
    • Tags

    Virtual Data Augmentation: 虚拟数据扩增技术【EMNLP 2021】

    语音识别与语义处理领域
    2
    2
    75
    Loading More Posts
    • Oldest to Newest
    • Newest to Oldest
    • Most Votes
    Reply
    • Reply as topic
    Log in to reply
    This topic has been deleted. Only users with topic management privileges can see it.
    • 155****7220
      155****7220 last edited by Alice_恒源云

      听说过数据扩增(Data Augmentation),也听说过虚拟对抗训练(Virtual Adversarial Traning),但是我没想到会有人将其结合,谓之虚拟数据扩增(Virtual Data Augmentation)。这篇文章主要讲解EMNLP2021上的一篇论文Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models,该论文提出了一种鲁棒且通用的数据扩增方法,论文源码在https://github.com/RUCAIBox/VDA

      论文开篇提到目前数据扩增存在的主要问题:产生数据多样性的同时如何保证其仍然在同一个语义空间中?简单地说,增强数据扩增的多样性很容易,核心就一个字:“乱”,例如许多数据扩增方法会随机打乱一个句子中token的位置,或者是随机删除某些token,随机插入某些token。这样虽然增强了样本的多样性,但是语义可能也会产生非常大的变化,甚至不再与原样本的语义相同。保持语义不变,或者说保证扩增后的样本和原样本在同一个语义空间中很容易,核心就是:“不要太乱”,例如通过同义词替换等,这种方法可以做到几乎不改变语义,但是数据多样性却不够,因为本质上还是同一句话

      这两个需求实际上是矛盾的,我们所能做的只是尽力达到某种平衡。具体来说,作者所提出的方法包含两个重要部分:Embedding Augmentation以及Regularized Training

      Embedding Augmentation

      假设现在我们有句子「Time is enough for test」,对于每个位置的token,我们都可以将其替换为[MASK],然后通过MLM预测Vocabulary中所有token在该位置的概率,例如

      [MASK] is enough for test

      [MASK]位置输出的token及其概率为

      Time  p=0.5
      Day   p=0.3
      Hours p=0.15
      ...
      

      再比如

      Times is enough for [MASK]

      [MASK]位置输出的token及其概率为

      test       p=0.5
      evaluation p=0.3
      experiment p=0.1
      ...
      

      看到这里大家脑海中可能已经有了一个数据扩增的想法,就是利用MLM任务对句子中每个位置的token进行预测,然后根据预测概率随机挑选出一个token进行替换,例如上面的句子可能就会被替换为「Hours is enough for evaluation」。这确实是一种还不错的数据扩增方法,但是论文作者却并不是这么做的

      为了描述简单,我们仅讨论对于给定句子SSS中的一个token w~\tilde{w}w~进行扩增的情况(实际上句子SSS中的所有token都会进行该操作),通过MLM任务我们可以预测出Vocabulary中所有单词在w~\tilde{w}w~位置的概率

      771537e0-9790-4d03-8a38-d88dbbdf36a1-image.png

      其中,VVV是Vocabulary中的token数量

      为了增强数据扩增的多样性,或者说引入某些噪声以增强抗干扰性,我们从高斯分布中随机采样出一个向量

      2b258954-4bfb-4069-800c-280de08e08b2-image.png

      将该向量与公式(1)的概率分布进行混合,我们可以得到一个新的概率分布

      6c75c4fc-593f-4375-a33f-2c08dc0aec14-image.png

      然后对于每个即将被替换的token w~\tilde{w}w~,我们根据概率p’(w^i∣S)p’(\hat{w}_i\mid S)p’(w^i​∣S)加权融合所有token w^i\hat{w}_iw^i​的Embedding向量

      817ad86e-4c1c-4702-9008-a9b5f31e65ee-image.png

      其中,pw~=p’(w^i∣S)i=1V\mathbf{p}_{\tilde{w}}={p’(\hat{w}_i\mid S)}_{i=1}^Vpw~​=p’(w^i​∣S)i=1V​,ME∈RV×d\mathbf{M}_E\in \mathbb{R}^{V\times d}ME​∈RV×d是MLM模型的词向量矩阵

      举个简单的例子解释一下,为了方便,同样还是以替换一个token为例,并且整个Vocabulary只有4个token,词向量的维度为2。首先我们有一句话「She is a good student」,将「good」进行MASK,然后通过MLM模型,预测出概率分布为

      1ddad683-66b0-4550-b2d9-b5ccc72eba30-image.png

      从左到右分别是good, perfect, excellent, smart的概率,根据高斯分布N(0,σ2)\mathcal{N}(0, \sigma^2)N(0,σ2)随机产生的向量为

      baf88dc6-e1a3-4e2a-82ea-e943a47bda89-image.png

      这里我并没有具体指明方差σ2\sigma^2σ2到底是多少,因为我懒得算

      将p(w^i∣S)p(\hat{w}_i\mid S)p(w^i​∣S)与ϵ\epsilonϵ混合后进行Softmax得到新的概率分布为

      e0001421-13b1-42d6-9e3a-d6677c1c9555-image.png

      假设Embedding矩阵为

      1c6451e0-6914-4341-86b7-6c7e9642661c-image.png

      那么最终「good」这个位置对应的embedding为

      c026d116-6285-4777-879c-a5420f256e62-image.png

      到此为止,不知道大家有没有体会到什么叫「Virtual Data Augmentation」,Virtual本质上就是不用一个真实的token去替换,而是使用一个embedding去替换,而如果你用这个embedding去反查ME\mathbf{M}_EME​矩阵一般是找不到对应的索引的,也就是说我们生成的这个embedding并不对应一个实际存在的token

      Regularized Traning

      标题起的很有故事,但本质上就是多引入了一个损失函数,具体来说,现在我们的优化目标为

      caff2c2d-fe20-4123-89fa-0f9bf6ce532c-image.png

      其中fff表示含有参数θ\thetaθ的预训练模型,nnn为样本个数,kkk表示由一条句子扩增出了kkk条句子。具体来说,如果是分类任务,则

      b2e0510e-8294-4480-a8a9-f884fef41a28-image.png

      其中,CE(⋅,⋅)\text{CE}(\cdot ,\cdot)CE(⋅,⋅)是Cross-Entropy Loss,可以根据具体任务替换的,Ei\mathbf{E}_iEi​表示第iii条句子通过Word2Vec之后生成的向量,其维度为[seq_len, emd_dim]

      为了防止扩增后的样本与原始样本间的语义产生巨大差距,换句话说,我们希望扩增后的样本与原样本间的分布是接近的,因此论文引入了KL散度作为第二项损失

      01ef3012-d822-4ef1-829d-1ca065362e06-image.png

      其中,kkk指的是原样本扩增出了kkk个样本,DsKLD_{sKL}DsKL​是对称的KL散度,具体来说

      0146f2d1-85fb-4c2f-843b-6b8e265aa7b7-image.png

      实际上这种方法可以看作是多任务,我们希望模型参数训练到一种境界,这种境界是,不论模型对原样本进行下游任务,还是让模型判断原样本与扩增样本的差距,模型都能做的很好。最后给出论文中的一张图结束这部分(图中一个样本扩增了3条样本)

      Results

      如果单看原始的准确率对比,似乎提升并不是很大,感觉我随便引入一些trick都能达到甚至超过Virtual Data Augmentation的效果。关键在于第二列「Att Acc」,这代表模型受到攻击时的结果,这部分的提升特别大,表明VDA这种方法确实有很强的抗干扰性,或者说鲁棒性很强

      个人总结

      实际上前面已经把这篇论文讲的很清楚了,这里没有什么好总结的,但我倒是有一点个人拙见想和大家讨论一下,因为他做MLM任务时,将整个Vocabulary都作为候选集,这样无论是对计算速度还是显存占用都不是很友好,我觉得可以将其改为取出概率最大的前Top k个token,这个k可以取的稍微大一点,例如200, 300等,这样可以保证取到后面一些语义上不那么相近的token的同时,避免对整个Vocabulary进行运算,至少不会生成几万几十万那么夸张的概率分布

      1 Reply Last reply Reply Quote 3
      • 183****8515
        183****8515 last edited by

        真心的给大佬献上膝盖~

        1 Reply Last reply Reply Quote 1
        • Referenced by  Alice_恒源云 Alice_恒源云 
        • Referenced by  Alice_恒源云 Alice_恒源云 
        • Referenced by  Alice_恒源云 Alice_恒源云 
        • First post
          Last post