NeurIPS 2022 | 用于医学图像分割的生成对抗Transformer模型
-
论文标题:Class-Aware Generative Adversarial Transformers for Medical Image Segmentation
论文地址:https://arxiv.org/pdf/2201.10737.pdf
论文代码:暂无
摘要
- 背景介绍: transformer在医学图像分析领域的建模长距离依赖方面取得了显着的进步。
- 现存问题: 当前基于transformer的模型具有几个缺点:(1)由于naive tokens方案,导致现有方法无法捕获重要特征; (2)模型遭受信息丢失的影响,因为它们仅考虑单尺度特征表示; (3)由于没有考虑丰富的语义环境和解剖学上下文,模型生成的分割图还不够准确。
- 解决方法: 作者提出了Castformer,这是一种新型的生成对抗transformer,用于2D医疗图像分割。首先,利用金字塔结构来构建多尺度表示并处理多尺度变化。然后,设计了一种新颖的多尺度transformer模块,以更好地学习具有语义结构的对象区域。最后,采用了一种对抗性训练策略,该策略提高了分割精度,并相应地允许基于transformer的判别器捕获高级语义特征和低级解剖学特征。
- 实验结果:在三个基准上,Castformer极大地优于先前基于最新的transformer方法,获得了2.54%-5.88%的绝对改善。进一步的定性实验提供了模型内部运作的更详细图片,并证明了迁移学习可以大大提高训练中的性能并减少医疗图像数据集的大小。
算法
我们提出的方法如图1所示。给定输入图像x∈RH×W×3,类似于Transunet架构,生成器G(称为CATformer)由四个关键组件组成:编码器模块,类感知的transformer模块,transformer编码器模块和解码器模块。如图1所示,G具有四个阶段。所有阶段共享一个类似的体系结构,其中包含一个patch嵌入层,类感知层和Li Transformer编码器层。
编码器模块。采用CNN-Transformer混合模型设计。这样的设置提供了两个优点:(1)使用卷积主干有助于transformer在下游视觉任务中表现更好;(2)它提供了具有多分辨率特征图,以帮助提高更好的表示。通过这种方式,可以为transformer构造特征金字塔,并利用多尺度特征图用于下游医疗分割任务。借助不同分辨率的特征图,能够对多分辨率的空间局部上下文进行建模。
分层特征表示。重点是提取CNN的多级特征Fi,其中i∈{1,2,3,4},通过利用高分割精度和低分辨率来实现高分割精度。更确切地说,在第一阶段,利用编码器模块获得密集的特征映射F1∈RH/2×W/2×C1。以类似的方式,可以制定以下特征图如下:()F2∈RH/2×w/2×(C1·4),()F3∈RH/4×W/4×(C1·8)和()F4∈RH/8×W/8×(C1·8)。然后,我们将F1个划分HW/162个大小为16×16×3的path P,并将扁平的patches馈入可学习的线性转换中,以获得大小HW/162×C1的patch嵌入。
类感知的transformer模块。类感知的transformer模块(CAT)旨在适应物体的有用区域(例如,基本的解剖特征和结构信息)。CAT模块具有以下特点:(1)使用4个独立的transformer编码器模块(TEM),将在下面介绍;(3)将M个 CAT模块合并到多尺度表示上,以允许解剖特征的上下文信息传播到表示中。类感知transformer模块是一个迭代优化过程。特别是,应用了类感知transformer模块来获得tokens序列IM,1∈RC×(n×n)的序列,其中(n×n)和M分别表示每个特征图上的采样数量和总迭代次数。如图2所示,给定特征映射F1,通过将其添加到最后一步的估计偏移量来迭代地更新其采样位置,可以表达为如下
其中st∈R2×(n×n)和ot∈R2×(n×n)是采样位置和预测的偏移矢量。具体而言,S1在间距采样网格上初始化。第i个采样位置s1i定义如下:
可以以以下形式定义输入特征映射上的初始tokens:()It′=Fi(st)。将采样函数设置为双线性插值。然后我们可以在每个步骤中获得输出令牌:
其中St∈RC×(n×n)是个位置嵌入。
估计的采样位置偏置为:
transformer编码器模块。transformer编码器模块(TEM)旨在通过从嵌入的输入图像patch的完整序列中汇总全局上下文信息来对远程上下文信息进行建模。在实现中,transformer编码器模块遵循VIT 的体系结构,该体系结构由多头自注意力(MSA)和MLP块组成,公式如下:
LN(·)是层归一化。()h∈R(P2·C)×D和Hpos∈RN×D表示patch嵌入投影和位置嵌入。
解码器模块。该解码器旨在基于不同分辨率的四个输出特征图生成分割掩码。在实施中,结合了轻量级的全MLP解码器,而这样简单的设计能够更有效地产生强大的表示。解码器包括以下设置:1)多尺度特征的通道维度通过MLP层统一;3)利用MLP层融合串联特征,然后从融合特征中预测多类别分割掩码Y’。
判别网络。将在ImageNet上的预训练R50+VIT-B/16混合模型作为判别器设计的初始化,在这种情况下,使用预训练的策略来有效地学习有限的尺寸目标任务数据。然后,只需应用两层多层感知器(MLP)就可以预测类感知图像的真假。
首先利用输入图像X和预测的分割掩码Y’获得了类感知的图像
(即,x和y’的像素乘积)。要注意,该构造重新使用预训练的权重,并且不引入任何其他参数。 D试图在真假输入之间进行判别。 G和D通过试图达到Minimax的平衡点相互竞争。使用此结构使鉴别器能够对远程依赖性进行建模,从而更好地评估医疗图像保真度。从本质上讲,这也赋予该模型对解剖视觉模态(分类特征)的更全面理解。
训练目标。至于损失函数和训练配置,采用了Wasserstein Gan(WGAN)中使用的设置,并使用WGAN-GP损失,并将WGAN-GP损失用于训练G。具体而言,分割损失包括Dice损失和交叉熵损失。因此,可以将CastFormer的训练过程表达为:
实验
Synapse Multi-organ
LiTS
Transfer Learning
Ablation of Model Components