ECCV2022 UniMiSS | 混合2D-3D医学自监督框架
-
论文标题:UniMiSS: Universal Medical Self-Supervised Learning via Breaking Dimensionality Barrier
论文地址:https://arxiv.org/abs/2112.09356
代码地址:https://github.com/YtongXie/UniMiSS-code
摘要
- 背景介绍: 自监督学习(SSL)为医学图像分析提供了巨大的机会,而医学图像分析因缺乏注释而闻名。
- 现存问题: 然而,由于其高成像成本和隐私限制,聚集大量(未标记)三维医学图像(如计算机断层扫描(CT))仍然具有挑战性。
- 解决方法: 作者主张利用大量的二维图像(如胸部X射线)来弥补三维数据的不足,旨在建立一个通用的医学自我监督表达学习框架,称为UniMiSS。下面的问题是如何打破维度障碍,即使2D和3D图像都可以执行SSL?为了实现这一点,设计了一个金字塔U型医学Transformer(MiT)。它由可切换patch嵌入(SPE)模块和transformer组成。SPE模块根据输入维度自适应地切换到2D或3Dpatch嵌入。嵌入的patch被转换为序列,而不管其原始尺寸如何。transformer对序列中的长期依赖性建模,从而使UniMiSS能够从2D和3D图像中学习表示。以MiT为主干,以自蒸馏方式执行UniMiSS。
- 实验结果: 在六个3D/2D医学图像分析任务上进行了丰富的实验,包括分割和分类。结果表明,所提出的UniMiSS在各种下游任务上都取得了令人满意的性能,大大优于ImageNet预训练和其他高级SSL对手。
算法
UniMiSS是一个通用的医学SSL框架,它精于学习具有大规模混合2D和3D未标记医学图像的通用图像表示。图2说明了UniMiSS框架。让我们用{,D2D,D3D}表示混合2D和3D数据池。为了使UniMiSS能够处理2D和3D医学图像,作者构建了MiT作为其主干,它主要由维度自适应SPE模块和transformer层组成。
以自蒸馏方式执行SSL过程,并利用标准交叉熵损失最大化学生和教师输出之间的一致性。此外,为了最大限度地利用三维体素信息,引入了体素切片一致性约束,这鼓励UniMiSS对跨维度的一致性进行建模。它直观地有助于从体素图像中学习强特征表示。现在深入研究这个框架的细节
MiT: A Dimension-free Architecture
尽管视觉transformer在计算机视觉方面取得了巨大成功,但由于高计算成本和内存需求,其在处理高分辨率3D图像方面仍然具有挑战性。作者设计了具有金字塔结构的MiT,以有效处理2D和3D图像。为了打破维度障碍,提出了一种简单而有效的SPE模块,根据输入类型自适应地选择2D或3Dpatch嵌入。MiT有一个编码器-解码器体系结构。现在MiT的每个部分。
SPE。如图2所示,SPE模块在获得维度特定嵌入方面发挥了重要作用,即对2D输入使用2Dpatch嵌入操作,对3D输入使用3Dpatch嵌入操作。请注意,编码器和解码器中SPE的实现是不同的。编码器中的SPE是指具有步长为2的2D/3D卷积块,降低特征分辨率。相反,解码器中的SPE是2D/3D转置卷积块,这增加了特征分辨率。
编码器和解码器。MiT编码器遵循渐进多层次金字塔transformer,它由四个阶段组成,每个阶段由一个SPE模块和几个堆叠transformer组成。在每个阶段,SPE模块对输入特征进行下采样,并生成特定于维度的嵌入序列。值得注意的是,在patch嵌入序列中附加了一个额外的可学习SSL token。SSL token类似于ViT中的[CLS]token,它能够通过自注意力聚合来自整个patch嵌入token的信息。输出序列与可学习位置嵌入相结合,输入到后续transformer中,用于长期依赖性建模。每个transformer层包括自注意力模块和具有两个隐藏层的前馈网络(FFN)。为了使MiT能够处理高分辨率图像,遵循空间缩减注意(SRA)层[45]。给定查询q、键k和值v作为输入,SRA首先降低k和v的空间分辨率,然后将q、约化k和约化v馈送到多头自注意力(MSA)层以产生精细特征。这个过程可以正式表达如下:
MiT有一个由三级组成的对称解码器结构。在每个阶段,输入特征图首先由SPE模型进行上采样,然后由堆叠transformer层进行细化。此外,还添加了编码器和解码器之间的跳跃连接,以保留更多低级别但高分辨率的信息。
Objective of UniMiSS
提出的UniMiSS框架基于学生-教师范式。每个路径包括一个MiT网络Fθ(·)和一个投影器Pθ(·)。Pθ(·)是一个n层多层感知器(MLP)头,θ表示该路径的参数集。SPE层在前馈计算期间切换到执行2Dpatch嵌入或3Dpatch嵌入,分别表示为Fθ(·;2D)和Fθ(·;3D)。在SSL过程中,仅从Fθ(·;2D/3D)的输出中提取SSL token作为投影器的输入。由于transformer将维度设置是自由的,所以UniMiSS能够从2D和3D未标记医学图像中学习图像表示。
二维数据的目标。以一小批2D数据x为例,首先使用数据采集模块T创建两个增强视图x1和x2,然后将它们输入学生和教师网络。将获得的SSL token输入投影器以产生输出向量,表示为f1=Pθ(Fθ(x1;2D)),f2=Pμ(Fμ(x2;2D))。UniMiSS的目标是最大化学生和教师网络的输出向量之间的一致性,公式如下:
3D数据的目标。在医学领域,3D体素可以被视为2D图像与层间维度的叠加。体素数据与其切片具有固有的一致性,这激励SSL建模体素-切片一致性。给定从3D医学数据集采样的3D数据x,将其两个增强视图表示为x1和x2,每个视图包含m个2D切片。以3D模式计算学生和教师网络的全局体素表示,即f1=Pθ(Fθ(x1;3D)),f2=Pμ(Fμ(x2;3D))。同时,在一批中堆叠每个增强视图的m个切片,并将它们用作2D输入,以计算2D模式下的切片表示,然后将所有切片的平均输出视为整体切片表示。之后,构建以下目标函数
实验
Results on 3D downstream tasks
Results on 2D downstream tasks
Effectiveness of volume-slice consistency
MiT with different Transformer scales
Transferability on unseen modality data
Visualization of Segmentation Results