UNETR++: 高精度3D医学图像分割
-
论文标题:UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation
论文地址:https://arxiv.org/pdf/2212.04497.pdf
代码地址:https://tinyurl.com/2p87x5xn
- 背景介绍: 由于transformer模型的成功,最近的工作研究了其在3D医学分割任务中的适用性。与基于局部卷积的设计相比,在transformer模型中,自注意力机制是捕获长距离依赖性的主要构建块之一。
- 现存问题: 自注意力操作具有二次复杂性,这被证明是一个计算瓶颈,特别是在体素医学成像中,其中输入是具有多个切片的3D结构。
- 解决方法: 作者提出了一种名为UNETR++的3D医学图像分割方法,该方法提供了高质量的分割掩模以及参数和计算成本方面的效率。设计的核心是引入一种新的高效成对注意力(EPA)块,该块使用一对基于空间和通道注意力的相互依赖分支来有效地学习空间和通道区分特征。对模拟的空间注意力是有效的,相对于输入序列长度具有线性复杂性。为了实现空间和通道分支之间的通信,共享查询和键映射的权重,提供了互补的好处(成对关注),同时还减少了总体网络参数。
- 实验结果: 对Synapse、BTCV和ACDC这三个基准进行了广泛的评估,从效率和准确性两方面证明了方法的有效性。在Synapse数据集上,UNETR++以87.2%的Dice相似性得分创下了新的最先进水平,同时与文献中最好的现有方法相比,在参数和FLOP方面都显著降低了71%。
算法
整体架构
图2展示了UNETR++架构,包括分层编码器-解码器结构。UNETR++框架是基于UNETR模型。UNETR++没有在编码器中使用固定的特征分辨率,而是采用分层设计,其中特征分辨率在每个阶段逐渐降低两倍。在UNETR++框架内,编码器有四个阶段,其中第一阶段包括将体素输入分割成3D块的块嵌入,然后是新型有效成对注意力(EPA)块。在patch嵌入中,将每个3D输入(体素)x∈ RH×W×D转换为非重叠patches (,,)xu∈RN×(P1,P2,P3),其中(P1,P1,P3)是每个patch的分辨率。然后,将patches投影到C通道维度,生成尺寸为H/P1×W/P2×D/P3×C的特征图。对于剩余的每个编码器层级,使用非重叠卷积的下采样层来将分辨率降低两倍,然后是EPA块。
每个EPA块包括两个注意模块,以通过使用共享k,q编码空间和通道啊维度的信息来有效地学习丰富的空间通道特征表示。编码器层级通过跳过连接与解码器层级连接,以合并不同分辨率的输出。这使得能够恢复在下采样操作期间丢失的空间信息,从而预测更精确的输出。与编码器类似,解码器也包括四个阶段,其中每个解码器阶段包括使用反卷积将特征图的分辨率提高两倍的上采样层,然后是EPA块(除了最后一个解码器)。接下来,我们将详细介绍EPA模块。
EPA
所提出的EPA块执行有效的全局关注,并有效地捕获丰富的空间通道特征表示。EPA模块包括空间注意力模块和通道注意力模块。
空间注意模块将自注意力的复杂性从二次降低到线性。另一方面,通道注意力模块有效地学习通道特征图之间的相互依赖关系。EPA块基于两个注意力模块之间的共享k, q方案,以相互交流,以便生成更好和更有效的特征表示。
如图2(右)所示,通过大小为H/4×W/4×D/2×C的patch嵌入x生成的特征图被直接馈送到连续的EPA块中,随后是三个编码器层。Q和K线性层的权重在两个注意力模块之间共享,每个注意力模块使用不同的V层。两个注意力模块计算如下:
其中,X^s和X^c分别表示空间和通道注意图。SA是空间注意力模块,CA是通道注意力模块。Qshared、Kshared、Vspace和Vchannel分别是共享q、共享k、空间值层和通道值层的矩阵。
**空间注意力:**在本模块中,通过将复杂度从O(n2)降低到O(np)来有效地学习空间信息,其中n是tokens的数量,p是投影向量的维数,其中p<<n。给定HWD×C形状的归一化张量X,使用三个线性层计算Qshared、Kshared和Vspatial投影,得出Qshared=WQX,Kshared=WKX,Vspatial=WVX,维度为HWD×C,其中WQ、WK和WV分别是Qshared、Kshared和Vspatial的投影权重。
然后,执行三个步骤。
首先,将Kshared和Vspace层从HWD×C投影到形状为p×C的低维矩阵中。
其次,通过将Qshared层乘以投影Kshared的转置来计算空间注意力图,然后使用softmax来测量每个特征与其余空间特征之间的相似度。
第三,将这些相似性乘以投影的Vspace层,以生成HWD×C形状的最终空间注意力图。空间注意力的定义如下:
**通道注意力:**该模块通过在通道值层和通道注意力图之间的通道维度中应用点积运算来捕获特征频道之间的相互依赖性。使用空间注意力模块的相同Qshared和Kshared,计算通道的值层,并使用线性层学习互补特征,得到Vchannel=WVX,维度为HW D×C,其中WV是Vchannel的投影权重。通道注意力定义如下:
最后,执行融合,并通过卷积块变换两个注意力模块的输出,以获得丰富的特征表示。EPA模块的最终输出Xû如下所示:
损失函数
损失函数基于常用soft dice损失和交叉熵损失的总和,以同时利用两个互补损失函数的优势。其定义为:
实验
Baseline Comparison
State-of-the-Art Comparison