      Medical Image Segmentation: Mixed Transformer UNet (MT-UNet)

      Paper: https://arxiv.org/pdf/2111.04734.pdf

      Code: https://github.com/Dootmaan/MT-UNet

      Mixed Transformer U-Net For Medical Image Segmentation

      Achieves SOTA performance, outperforming Swin-Unet, TransUNet, and other networks.


      Abstract

      Problem: Although U-Net has achieved great success in medical image segmentation, it lacks the ability to explicitly model long-range dependencies. Vision Transformers have recently emerged as an alternative segmentation architecture thanks to their inherent ability to capture long-range correlations through self-attention (SA).
      Problem: However, Transformers usually rely on large-scale pre-training and have high computational complexity. Moreover, SA can only model self-affinities within a single sample and ignores the potential correlations across the whole dataset.
      Approach: The paper proposes a novel Mixed Transformer Module (MTM) for simultaneous inter-affinity and intra-affinity learning. MTM first computes intra-window affinities efficiently with Local-Global Gaussian-Weighted Self-Attention (LGG-SA), and then mines the connections between data samples with External Attention (EA). With MTM, the Mixed Transformer U-Net (MT-UNet) is constructed for medical image segmentation.

      Method

      As shown in Figure 1, the network is based on an encoder-decoder structure:

      1. To reduce computational cost, MTMs are only used in the deeper layers, whose feature maps have a small spatial size.
      2. The shallow layers still use classical convolutions, because they mainly attend to local information and contain more high-resolution detail (a toy sketch of this hybrid layout follows this list).
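
      As a rough illustration only (not code from the paper or its repository), the sketch below builds a toy encoder whose shallow stages are plain convolutions and whose last, low-resolution stage applies a standard Transformer encoder layer on the flattened feature map as a stand-in for MTM; all channel sizes are invented for the example.

      import torch
      import torch.nn as nn


      class ConvStage(nn.Module):
          """Plain convolutional stage used at high resolution (shallow layers)."""
          def __init__(self, in_ch, out_ch):
              super(ConvStage, self).__init__()
              self.block = nn.Sequential(
                  nn.Conv2d(in_ch, out_ch, 3, padding=1),
                  nn.BatchNorm2d(out_ch),
                  nn.ReLU(inplace=True),
                  nn.MaxPool2d(2))  # halve the spatial size

          def forward(self, x):
              return self.block(x)


      class ToyTransformerStage(nn.Module):
          """Stand-in for an MTM stage: a vanilla Transformer layer over flattened tokens."""
          def __init__(self, dim):
              super(ToyTransformerStage, self).__init__()
              self.layer = nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True)

          def forward(self, x):  # x: (b, c, h, w)
              b, c, h, w = x.shape
              tokens = x.flatten(2).transpose(1, 2)  # (b, h*w, c)
              tokens = self.layer(tokens)
              return tokens.transpose(1, 2).view(b, c, h, w)


      # Shallow stages: convolutions only; deepest stage: Transformer-style block.
      encoder = nn.Sequential(ConvStage(1, 32), ConvStage(32, 64),
                              ConvStage(64, 128), ToyTransformerStage(128))
      print(encoder(torch.randn(1, 1, 224, 224)).shape)  # torch.Size([1, 128, 28, 28])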

      MTM

      As shown in Figure 2, MTM mainly consists of LGG-SA and EA.

      LGG-SA is used to model short- and long-range dependencies at different granularities, while EA is used to mine inter-sample correlations.

      The module is designed to replace the original Transformer encoder, improving its performance on vision tasks and reducing the time complexity; a hedged sketch of how the two parts can be chained is given below.
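
      As a sketch (the exact wiring with normalization and residual connections is my assumption, not a verbatim copy of the official implementation), one MTM block could chain the two attentions like this, using the CSAttention (LGG-SA) and MEAttention (EA) classes given later in this post:

      import torch.nn as nn


      class MTMBlock(nn.Module):
          """Sketch of one MTM block: LGG-SA followed by EA, each with a residual connection."""
          def __init__(self, dim, configs):
              super(MTMBlock, self).__init__()
              self.norm1 = nn.LayerNorm(dim)
              self.lgg_sa = CSAttention(dim, configs)  # defined below in this post
              self.norm2 = nn.LayerNorm(dim)
              self.ea = MEAttention(dim, configs)      # defined below in this post

          def forward(self, x):  # x: (b, n, c)
              x = x + self.lgg_sa(self.norm1(x))
              x = x + self.ea(self.norm2(x))
              return x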

      LGG-SA (Local-Global Gaussian-Weighted Self-Attention)

      Unlike a conventional SA module, which assigns the same level of attention to all tokens, LGG-SA uses local-global self-attention and a Gaussian mask so that it can focus more on neighboring regions. Experiments show that this improves model performance and saves computational resources. The detailed design of the module is shown in Figure 3.

      Local-global self-attention

      In computer vision, correlations between neighboring regions tend to be more important than those between distant regions, so there is no need to pay the same cost for far-away regions when computing the attention map.

      Hence, local-global self-attention is proposed:

      1. In stage 1 of the figure above, each local window contains four tokens, and local SA computes the intrinsic affinities within each window.
      2. The tokens in each window are then aggregated into a single global token that represents the main information of the window. Among the aggregation functions tried, Lightweight Dynamic Convolution (LDConv) performs best.
      3. After obtaining the downsampled feature map, global SA can be performed with much less overhead (stage 2 in the figure above).

      where $X \in \mathbb{R}^{H \times W \times C}$.

      The code for the local window self-attention in stage 1 is as follows (the snippets below assume the usual imports):

      import math

      import numpy as np
      import torch
      import torch.nn as nn


      class WinAttention(nn.Module):
          def __init__(self, configs, dim):
              super(WinAttention, self).__init__()
              self.window_size = configs["win_size"]
              self.attention = Attention(dim, configs)
      
          def forward(self, x):
              b, n, c = x.shape
              h, w = int(np.sqrt(n)), int(np.sqrt(n))
              x = x.permute(0, 2, 1).contiguous().view(b, c, h, w)
              if h % self.window_size != 0:
                  # Pad H and W up to a multiple of the window size; the extra bottom-right
                  # border is filled with the bottom-right corner of the original map.
                  right_size = h + self.window_size - h % self.window_size
                  new_x = torch.zeros((b, c, right_size, right_size),
                                      device=x.device, dtype=x.dtype)
                  new_x[:, :, 0:x.shape[2], 0:x.shape[3]] = x[:]
                  new_x[:, :, x.shape[2]:,
                        x.shape[3]:] = x[:, :, (x.shape[2] - right_size):,
                                         (x.shape[3] - right_size):]
                  x = new_x
                  b, c, h, w = x.shape
              # Partition into non-overlapping windows:
              # (b, p, p, win, c) with p = h // win_size and win = win_size * win_size.
              x = x.view(b, c, h // self.window_size, self.window_size,
                         w // self.window_size, self.window_size)
              x = x.permute(0, 2, 4, 3, 5,
                            1).contiguous().view(b, h // self.window_size,
                                                 w // self.window_size,
                                                 self.window_size * self.window_size,
                                                 c).cuda()
              x = self.attention(x)  # self-attention over the tokens inside each window: (b, p, p, win, c)
              return x
      

      The code for the aggregation function is as follows:

      class DlightConv(nn.Module):
          def __init__(self, dim, configs):
              super(DlightConv, self).__init__()
              self.linear = nn.Linear(dim, configs["win_size"] * configs["win_size"])
              self.softmax = nn.Softmax(dim=-1)
      
          def forward(self, x):  # x: (b, p, p, win, c)
              h = x
              avg_x = torch.mean(x, dim=-2)  # mean over the window tokens: (b, p, p, c)
              x_prob = self.softmax(self.linear(avg_x))  # one weight per window token: (b, p, p, win)

              # Weighted sum of the window tokens -> a single aggregated token per window.
              x = torch.mul(h,
                            x_prob.unsqueeze(-1))  # (b, p, p, win, c)
              x = torch.sum(x, dim=-2)  # (b, p, p, c)
              return x
      

      Gaussian-Weighted Axial Attention

      Unlike LSA, which uses the original SA, Gaussian-Weighted Axial Attention (GWAA) is proposed for the global stage. GWAA strengthens the attention weights of neighboring regions through a learnable Gaussian term, while reducing the time complexity thanks to axial attention.

      1. The feature at the third row and third column of the stage-2 feature map in the figure above is linearly projected to obtain the query $q_{i,j}$.
      2. All features in the same row and the same column as that feature are linearly projected to obtain the keys $K_{i,j}$ and values $V_{i,j}$.
      3. The Euclidean distances between that feature point and all of its keys and values are defined as $D_{i,j}$.

      The final Gaussian-weighted axial attention output for position $(i, j)$ can be written, in a form consistent with the implementation below, as

      $z_{i,j} = \mathrm{softmax}\big(q_{i,j} K_{i,j}^{\top} + G_{i,j}\big)\, V_{i,j}$

      and is simplified by taking the Gaussian term $G_{i,j}$ to be a learnable affine function of the squared Euclidean distances,

      $z_{i,j} = \mathrm{softmax}\big(q_{i,j} K_{i,j}^{\top} - (\sigma D_{i,j} + b)\big)\, V_{i,j}$

      where $\sigma$ and $b$ correspond to the learnable shift and bias parameters in the code.

      The code for the axial attention is as follows:

      class Attention(nn.Module):
          def __init__(self, dim, configs, axial=False):
              super(Attention, self).__init__()
              self.axial = axial
              self.dim = dim
              self.num_head = configs["head"]
              self.attention_head_size = int(self.dim / configs["head"])
              self.all_head_size = self.num_head * self.attention_head_size
      
              self.query_layer = nn.Linear(self.dim, self.all_head_size)
              self.key_layer = nn.Linear(self.dim, self.all_head_size)
              self.value_layer = nn.Linear(self.dim, self.all_head_size)
      
              self.out = nn.Linear(self.dim, self.dim)
              self.softmax = nn.Softmax(dim=-1)
      
          def transpose_for_scores(self, x):
              new_x_shape = x.size()[:-1] + (self.num_head, self.attention_head_size)
              x = x.view(*new_x_shape)
              return x
      
          def forward(self, x):
              # axial branch: row attention followed by column attention
              if self.axial:
                  # x: (b, p, p, c)
                  # row attention (single-head attention)
                  b, h, w, c = x.shape
                  mixed_query_layer = self.query_layer(x)
                  mixed_key_layer = self.key_layer(x)
                  mixed_value_layer = self.value_layer(x)
      
                  query_layer_x = mixed_query_layer.view(b * h, w, -1)
                  key_layer_x = mixed_key_layer.view(b * h, w, -1).transpose(-1, -2)  # (b*h, -1, w)
                  attention_scores_x = torch.matmul(query_layer_x,
                                                    key_layer_x)  # (b*h, w, w)
                  attention_scores_x = attention_scores_x.view(b, -1, w,
                                                               w)  # (b, h, w, w)
      
                  # col attention  (single head attention)
                  query_layer_y = mixed_query_layer.permute(0, 2, 1,
                                                            3).contiguous().view(
                                                                b * w, h, -1)
                  key_layer_y = mixed_key_layer.permute(
                      0, 2, 1, 3).contiguous().view(b * w, h, -1).transpose(-1, -2)  # (b*w, -1, h)
                  attention_scores_y = torch.matmul(query_layer_y,
                                                    key_layer_y)  # (b*w, h, h)
                  attention_scores_y = attention_scores_y.view(b, -1, h,
                                                               h)  # (b, w, h, h)
      
                  return attention_scores_x, attention_scores_y, mixed_value_layer
      
              else:
                
                  mixed_query_layer = self.query_layer(x)
                  mixed_key_layer = self.key_layer(x)
                  mixed_value_layer = self.value_layer(x)
      
                  query_layer = self.transpose_for_scores(mixed_query_layer).permute(
                      0, 1, 2, 4, 3, 5).contiguous()  # (b, p, p, head, n, c)
                  key_layer = self.transpose_for_scores(mixed_key_layer).permute(
                      0, 1, 2, 4, 3, 5).contiguous()
                  value_layer = self.transpose_for_scores(mixed_value_layer).permute(
                      0, 1, 2, 4, 3, 5).contiguous()
      
                  attention_scores = torch.matmul(query_layer,
                                                  key_layer.transpose(-1, -2))
                  attention_scores = attention_scores / math.sqrt(
                      self.attention_head_size)
                  atten_probs = self.softmax(attention_scores)
      
                  context_layer = torch.matmul(
                      atten_probs, value_layer)  # (b, p, p, head, win, h)
                  context_layer = context_layer.permute(0, 1, 2, 4, 3,
                                                        5).contiguous()
                  new_context_layer_shape = context_layer.size()[:-2] + (
                      self.all_head_size, )
                  context_layer = context_layer.view(*new_context_layer_shape)
                  attention_output = self.out(context_layer)
      
              return attention_output
      

      The code for the Gaussian weighting is as follows:

      class GaussianTrans(nn.Module):
          def __init__(self):
              super(GaussianTrans, self).__init__()
              self.bias = nn.Parameter(-torch.abs(torch.randn(1)))
              self.shift = nn.Parameter(torch.abs(torch.randn(1)))
              self.softmax = nn.Softmax(dim=-1)
      
          def forward(self, x): 
              x, atten_x_full, atten_y_full, value_full = x  #x(b, h, w, c) atten_x_full(b, h, w, w)   atten_y_full(b, w, h, h) value_full(b, h, w, c)
              new_value_full = torch.zeros_like(value_full)
      
              for r in range(x.shape[1]):  # row
                  for c in range(x.shape[2]):  # col
                      atten_x = atten_x_full[:, r, c, :]  # (b, w)
                      atten_y = atten_y_full[:, c, r, :]  # (b, h)
      
                      # squared Euclidean distances from column c and from row r
                      dis_x = torch.tensor([(h - c)**2 for h in range(x.shape[2])
                                            ]).cuda()  # (w,)
                      dis_y = torch.tensor([(w - r)**2 for w in range(x.shape[1])
                                            ]).cuda()  # (h,)
      
                      dis_x = -(self.shift * dis_x + self.bias).cuda()
                      dis_y = -(self.shift * dis_y + self.bias).cuda()
      
                      atten_x = self.softmax(dis_x + atten_x)
                      atten_y = self.softmax(dis_y + atten_y)
      
                      new_value_full[:, r, c, :] = torch.sum(
                          atten_x.unsqueeze(dim=-1) * value_full[:, r, :, :] +
                          atten_y.unsqueeze(dim=-1) * value_full[:, :, c, :],
                          dim=-2)
              return new_value_full
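
      The double loop above is easy to read but slow for larger feature maps. Below is my own vectorized sketch of the same computation (not code from the MT-UNet repository): the Gaussian bias is precomputed for every row/column pair and the weighted sums are expressed with einsum.

      import torch
      import torch.nn as nn


      class GaussianTransVectorized(nn.Module):
          """Vectorized sketch of GaussianTrans: same idea, no per-pixel Python loops."""
          def __init__(self):
              super(GaussianTransVectorized, self).__init__()
              self.bias = nn.Parameter(-torch.abs(torch.randn(1)))
              self.shift = nn.Parameter(torch.abs(torch.randn(1)))

          def forward(self, inputs):
              x, atten_x_full, atten_y_full, value_full = inputs  # shapes as in GaussianTrans above
              h, w = value_full.shape[1], value_full.shape[2]

              # Squared distances between all pairs of columns / rows.
              cols = torch.arange(w, device=value_full.device, dtype=value_full.dtype)
              rows = torch.arange(h, device=value_full.device, dtype=value_full.dtype)
              dist_x = (cols[None, :] - cols[:, None]) ** 2  # (w, w)
              dist_y = (rows[None, :] - rows[:, None]) ** 2  # (h, h)

              # Add the learnable Gaussian bias to the raw axial scores, then softmax.
              atten_x = torch.softmax(atten_x_full - (self.shift * dist_x + self.bias), dim=-1)  # (b, h, w, w)
              atten_y = torch.softmax(atten_y_full - (self.shift * dist_y + self.bias), dim=-1)  # (b, w, h, h)

              # Row-wise and column-wise weighted sums of the values.
              out_rows = torch.einsum('brck,brkd->brcd', atten_x, value_full)  # (b, h, w, c)
              out_cols = torch.einsum('bcrk,bkcd->brcd', atten_y, value_full)  # (b, h, w, c)
              return out_rows + out_cols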
      

      The complete code for local-global self-attention is as follows:

      class CSAttention(nn.Module):
          def __init__(self, dim, configs):
              super(CSAttention, self).__init__()
              self.win_atten = WinAttention(configs, dim)
              self.dlightconv = DlightConv(dim, configs)
              self.global_atten = Attention(dim, configs, axial=True)
              self.gaussiantrans = GaussianTrans()
              #self.conv = nn.Conv2d(dim, dim, 3, padding=1)
              #self.maxpool = nn.MaxPool2d(2)
              self.up = nn.UpsamplingBilinear2d(scale_factor=4)
              self.queeze = nn.Conv2d(2 * dim, dim, 1)
      
          def forward(self, x):
              '''
              :param x: tokens of size (b, n, c)
              :return: tokens of size (b, n, c)
              '''
              origin_h = origin_w = int(np.sqrt(x.shape[1]))  # assumes a square feature map
              x = self.win_atten(x)  # (b, p, p, win, c)
              b, p, p, win, c = x.shape
              h = x.view(b, p, p, int(np.sqrt(win)), int(np.sqrt(win)),
                         c).permute(0, 1, 3, 2, 4, 5).contiguous()
              h = h.view(b, p * int(np.sqrt(win)), p * int(np.sqrt(win)),
                         c).permute(0, 3, 1, 2).contiguous()  # (b, c, h, w)
      
              x = self.dlightconv(x)  # aggregate each window into one token: (b, p, p, c)
              atten_x, atten_y, mixed_value = self.global_atten(
                  x)  # (b, h, w, w), (b, w, h, h), (b, h, w, c); here h and w both equal p
              gaussian_input = (x, atten_x, atten_y, mixed_value)
              x = self.gaussiantrans(gaussian_input)  # (b, h, w, c)
              x = x.permute(0, 3, 1, 2).contiguous()  # (b, c, h, w)
      
              x = self.up(x)  # upsample the global branch back to window-level resolution
              # fuse with the full-resolution local branch h and squeeze the channels back to dim
              x = self.queeze(torch.cat((x, h), dim=1)).permute(0, 2, 3,
                                                                1).contiguous()
              x = x[:, :origin_h, :origin_w, :].contiguous()
              x = x.view(b, -1, c)
      
              return x
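
      As a quick shape check (a usage sketch with made-up hyperparameters, not from the repository; a GPU is required because the code above hard-codes .cuda()):

      import torch

      configs = {"win_size": 4, "head": 8}  # hypothetical values for illustration
      csa = CSAttention(dim=256, configs=configs).cuda()
      tokens = torch.randn(2, 28 * 28, 256).cuda()  # (b, n, c) for a 28x28 feature map
      print(csa(tokens).shape)  # torch.Size([2, 784, 256])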
      

      EA

      External Attention (EA) is introduced to address the problem that SA cannot exploit the relationships between different input data samples.

      Unlike self-attention, which computes attention scores using each sample's own linear transformations, in EA all data samples share two memory units, MK and MV (as shown in Figure 2), which describe the most important information of the whole dataset.

      The EA code is as follows:

      class MEAttention(nn.Module):
          def __init__(self, dim, configs):
              super(MEAttention, self).__init__()
              self.num_heads = configs["head"]
              self.coef = 4
              self.query_liner = nn.Linear(dim, dim * self.coef)
              self.num_heads = self.coef * self.num_heads
              self.k = 256 // self.coef
              self.linear_0 = nn.Linear(dim * self.coef // self.num_heads, self.k)
              self.linear_1 = nn.Linear(self.k, dim * self.coef // self.num_heads)
      
              self.proj = nn.Linear(dim * self.coef, dim)
      
          def forward(self, x):
              B, N, C = x.shape
              x = self.query_liner(x)  # (b, n, 4c)
              x = x.view(B, N, self.num_heads, -1).permute(0, 2, 1,
                                                           3)  #  (b, h, n, 4c/h)
      
              attn = self.linear_0(x)  # attention scores against the shared memory (M_K): (b, h, n, k)

              attn = attn.softmax(dim=-2)  # softmax over the token dimension: (b, h, n, k)
              attn = attn / (1e-9 + attn.sum(dim=-1, keepdim=True))  # double normalization: (b, h, n, k)
      
              x = self.linear_1(attn).permute(0, 2, 1, 3).reshape(B, N, -1)
      
              x = self.proj(x)
      
              return x
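
      A quick usage sketch (hyperparameters invented for illustration). Note that the shared memory units MK and MV are realized by the weights of linear_0 and linear_1, which are the same for every sample in the batch:

      import torch

      configs = {"head": 8}  # hypothetical value
      ea = MEAttention(dim=256, configs=configs)
      x = torch.randn(2, 784, 256)  # (B, N, C)
      print(ea(x).shape)  # torch.Size([2, 784, 256])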
      

      EXPERIMENTS
