Navigation

    Gpushare.com

    • Register
    • Login
    • Search
    • Popular
    • Categories
    • Recent
    • Tags

    Siamese Network & Triplet NetWork

    语音识别与语义处理领域
    1
    1
    25
    Loading More Posts
    • Oldest to Newest
    • Newest to Oldest
    • Most Votes
    Reply
    • Reply as topic
    Log in to reply
    This topic has been deleted. Only users with topic management privileges can see it.
    • 155****7220
      155****7220 last edited by

      Siamese Network(孪生网络)

      简单来说,孪生网络就是共享参数的两个神经网络

      在孪生网络中,我们把一张图片X1X_1X1​作为输入,得到该图片的编码GW(X1)G_W(X_1)GW​(X1​)。然后,我们在不对网络参数进行任何更新的情况下,输入另一张图片X2X_2X2​,并得到改图片的编码GW(X2)G_W(X_2)GW​(X2​)。由于相似的图片应该具有相似的特征(编码),利用这一点,我们就可以比较并判断两张图片的相似性

      孪生网络的损失函数

      传统的Siamese Network使用Contrastive Loss(对比损失函数)
      $$
      \mathcal{L} = (1-Y)\frac{1}{2}(D_W)^2+(Y)\frac{1}{2}{max(0, m-D_W)}^2
      $$
      其中DWD_WDW​被定义为孪生网络两个输入之间的欧氏距离,即
      $$
      D_W = \sqrt{{G_W(X_1)-G_W(X_2)}^2}
      $$

      • YYY值为0或1,如果X1,X2X_1,X_2X1​,X2​这对样本属于同一类,则Y=0Y=0Y=0,反之Y=1Y=1Y=1
      • mmm是边际价值(margin value),即当Y=1Y=1Y=1,如果X1X_1X1​与X2X_2X2​之间距离大于mmm,则不做优化(省时省力);如果X1X_1X1​与X2X_2X2​之间的距离小于mmm,则调整参数使其距离增大到mmm
      Contrastive Loss代码
      import torch
      import numpy as np
      import torch.nn.functional as F
      
      class ContrastiveLoss(torch.nn.Module):
          "Contrastive loss function"
          def __init__(self, m=2.0):
              super(ContrastiveLoss, self).__init__()
              self.m = m
                  
          def forward(self, output1, output2, label):
              d_w = F.pairwise_distance(output1, output2)
              contrastive_loss = torch.mean((1-label) * 0.5 * torch.pow(d_w, 2) +
                                            (label) * 0.5 * torch.pow(torch.clamp(self.m - d_w, min=0.0), 2))
      
              return contrastive_loss
      

      其中,F.pairwise_distance(x1, x2, p=2)函数公式如下
      $$
      (\sum_{i=1}^n(|x_1-x_2|^p))^{\frac{1}{p}}\
      x_1,x_2 \in \mathbb{R}^{b\times n}
      $$

      pairwise_distance(x1, x2, p) Computes the batchwise pairwise distance between vectors x1x_1x1​, x2x_2x2​ using the p-norm

      孪生网络的用途

      简单来说,孪生网络的直接用途就是衡量两个输入的差异程度(或者说相似程度)。将两个输入分别送入两个神经网络,得到其在新空间的representation,然后通过Loss Function来计算它们的差异程度(或相似程度)

      • 词汇语义相似度分析,QA中question和answer的匹配
      • 手写体识别也可以用Siamese Network
      • Kaggle上Quora的Question Pair比赛,即判断两个提问是否为同一个问题
      Pseudo-Siamese Network(伪孪生网络)

      对于伪孪生网络来说,两边可以是不同的神经网络(如一个是lstm,一个是cnn),并且如果是相同的神经网络,是不共享参数的

      孪生网络和伪孪生网络分别适用的场景
      • 孪生网络适用于处理两个输入比较类似的情况
      • 伪孪生网络适用于处理两个输入有一定差别的情况

      例如,计算两个句子或者词汇的语义相似度,使用Siamese Network比较合适;验证标题与正文的描述是否一致(标题和正文长度差别很大),或者文字是否描述了一幅图片(一个是图片,一个是文字)就应该使用Pseudo-Siamese Network

      Triplet Network(三胞胎网络)

      如果说Siamese Network是双胞胎,那Triplet Network就是三胞胎。它的输入是三个:一个正例+两个负例,或一个负例+两个正例。训练的目标仍然是让相同类别间的距离尽可能小,不同类别间的距离尽可能大。Triplet Network在CIFAR,MNIST数据集上效果均超过了Siamese Network

      损失函数定义如下:
      $$
      \mathcal{L}=max(d(a,p)-d(a,n)+margin, 0)
      $$

      • aaa表示anchor图像
      • ppp表示positive图像
      • nnn表示negative图像

      我们希望aaa与ppp的距离应该小于aaa与nnn的距离。marginmarginmargin是个超参数,它表示d(a,p)d(a,p)d(a,p)与d(a,n)d(a,n)d(a,n)之间应该相差多少,例如,假设margin=0.2margin=0.2margin=0.2,并且d(a,p)=0.5d(a,p)=0.5d(a,p)=0.5,那么d(a,n)d(a,n)d(a,n)应该大于等于0.70.70.7

      Reference

      • 多种类型的神经网络(孪生网络)
      • Siamese network 孪生神经网络–一个简单神奇的结构
      • Siamese Network & Triplet Loss
      • A friendly introduction to Siamese Networks
      • Contrastive Loss
      1 Reply Last reply Reply Quote 1
      • First post
        Last post