项目来自:TransUNet
论文:TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation
结合UNet和Transformer,具体模块包含ViT+ResNet50+skip connection
在编码器中使用hybrid CNN-Transformer
解码过程使用Cascaded Upsampler
既能提取CNN的局部细节信息,又能提取到Transformer的全局上下文信息。
原始输入维度是H×W×C,期望输出能够划分出每一个像素值的类别,实现分割。在UNet结构上增加了self-attention机制,通过Transformer来实现。
1.1 encoder中的CNN
ResnetV2网络提取特征,对图片做下采样处理。其中root模块卷积操作的步长为2;
block1堆叠多个bottleneck结构的卷积操作,特征层大小不变;
block2堆叠多个bottleneck结构的卷积操作,步长为2,特征层宽高减半;
block3堆叠多个bottleneck结构的卷积操作,步长为2,特征层宽高减半;
经过CNN下采样之后的特征层经过1×1卷积调整通道数
1.2 embedding
对于输入到Transformer中的序列必须是一维的,需要对输入的图片做变换处理。
-
首先将H×W原始图片划分成P×P大小的patch,patch数量为N。
$N=\frac{H×W}{p^2}$ -
patch embedding:将patch向量映射到D维空间,并增加位置编码
$Z=[X_p^1E,X_p^2E,...X_p^NE]+E_{pos}$ $E$ 表示将patch映射到D维线性空间的变换矩阵,是可训练的参数,$E_{pos}$ 表示位置编码。 线性空间变换之后矩阵维度是N×D,与位置编码相加得到向量维度是N×D。 -
Transformer结构可以由以下公式表示:
$z'l=MSA(LN(z{l-1}))+z_{l-1}$ ,
$z_l=MLP(LN(z'_l))+z'_l$ ,首先经过Layer Normalization,经过Multi-head Self Attention,再加上残差结构。
结果再经过Layer Normalization,经过MLP全连接,再加上残差结构
实现Transformer。 -
经过Transformer结构的输出向量维度是N×D,需要经过一步变换,使得上采样过程能够还原到原始的图像分辨率。对应网络结构图的这个位置:
patch数量
Trasnformer输出结果(N,D)reshape成(D,H/p,W/p),再经过1×1卷积调整通道数得到(512,H/p,W/p)。
上采样过程使用双线性插值方法,使特征层宽高变为二倍,
3×3卷积,四次上采样输入的通道数为in_channels + skip_channels,将当前特征层大小与输入相同尺度的特征层在通道维度融合,经过两次3×3卷积操作,输出通道数为(256, 128, 64, 16)。
代码实现对应:
class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
skip_channels=0,
use_batchnorm=True,
):
super().__init__()
self.conv1 = Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.conv2 = Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.up = nn.UpsamplingBilinear2d(scale_factor=2)
def forward(self, x, skip=None):
x = self.up(x)
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.conv1(x)
x = self.conv2(x)
return x最后将特征层通道数调整为分类数量:
self.segmentation_head = SegmentationHead(
in_channels=config['decoder_channels'][-1],
out_channels=config['n_classes'],
kernel_size=3,
)