Life sucks, but you're gonna love it.

0%

论文阅读 |U-Net: Convolutional Networks for Biomedical Image Segmentation

摘要

本文主要介绍了一种能够最大程度利用数据增强进行训练的模型,并且在数据分割任务中表现出色。这个网络结构包含了一个降维采样(contracting path)的过程来提取图片中的内容信息,以及一个对称的上采样(expanding path)可以准确提供分割的位置。文章展示了该网络可以通过比较少的图片进行训练,然后得到比之前的方法更好的分割数据。

简介

这部分就是在说近两年,神经网络在计算机视觉的任务解决上发挥了很大的作用,包括图片的分类定位和分割。然后提出了在2012年由 Ciresan提出的对于图像中target进行识别和定位的算法。该算法是通过训练一个神经网络,通过在图片上滑动的窗口进行像素级别的分类。这个算法有两个弱点,首先,这个过程会非常的慢,因为神经网络需要应用在每个滑动窗口的patch上,其次,在定位准确性和图片内容使用上有一个trade-off。更大的patch需要更多的池化层,而池化层或降低localization的准确度。但是如果输入的patch较小,又会使得摄取的图片内容过少。

这篇文章中,作者设计了一个更为精妙的模型,也就是全卷积神经网络(fully convolutional network)。作者稍微调整了这个结构,使得该网络结构用少量的训练图片可以得到更加准确的分割。主要思路是利用上采样层取代池化层来弥补之前的压缩过程(usual contract)。因此,这些层可以增加输出的的分辨率。为了定位准确,压缩路径中的高分辨率的特征和上采样的结果相结合。接下来的卷积层(a successive convolution layer)可以基于此得到一个更准确的结果。

img

网络结构

网络结构如上图所示,包含左侧的下采样路径(contracting path)和右侧的上采样(expanding path)路径。

  • Contracting path:

    传统的卷积网络结构,包含两个 $3 \times 3$ 的卷积(unpadded),连接一个ReLU,之后是一个池化层。

    在每个下采样过程中,都会将feature channel的数量加倍。

    展示一个downsampling 的单元

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    class down(nn.Module):
    def __init__(self, in_ch, out_ch):
    super(down, self).__init__()
    self.mpconv = nn.Sequential(
    nn.Conv2d(in_ch, out_ch, 3),
    nn.BatchNorm2d(out_ch),
    nn.ReLU(inplace=True),
    nn.Conv2d(out_ch, out_ch, 3),
    nn.BatchNorm2d(out_ch),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
    )

    def forward(self, x):
    x = self.mpconv(x)
    return x
  • Expanding path: 这部分有些些复杂。主要分为两个部分,一部分是从下采样过程中直接引入feature map;另一部分是上采样过程中的feature map。

    在上采样的过程中,feature map会经过一个 $2 \times 2$ 的卷积,用来拓展feature map的大小,收缩channel的数量使其减半。然后直接和左侧下采样过程中的相同feature channel数的层相拼接(也就是U型中间横线的指向)。 然后拼接的网络通过两次 $3 \times 3$ 卷积(unpadded)+ReLU的结构。

    展示一个上采样单元:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
    super(up, self).__init__()

    if bilinear:
    self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
    else:
    self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

    self.conv = nn.Sequence(
    nn.Conv2d(in_ch, out_ch, 3),
    nn.BatchNorm2d(out_ch),
    nn.ReLU(inplace=True),
    nn.Conv2d(out_ch, out_ch, 3),
    nn.BatchNorm2d(out_ch),
    nn.ReLU(inplace=True))

    def forward(self, x1, x2):
    x1 = self.up(x1)

    # input is CHW
    diffY = x2.size()[2] - x1.size()[2]
    diffX = x2.size()[3] - x1.size()[3]

    x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
    diffY // 2, diffY - diffY//2))

    x = torch.cat([x2, x1], dim=1)
    x = self.conv(x)
    return x
  • 最后一层的 $1 \times 1$ 卷积层是用来将64层channel变为所需的类的个数。

网络总共包含23个卷积层。

为了和segmentation map由无缝连接,我们需要选取合适输入的图片的尺寸,让每个 $2 \times 2$的卷积层在上面作用时都得到偶数的尺寸。【这点还挺重要的,在实际应用中,我有接触用不合适的尺寸来做下采样,导致在上采样过程中很难和之前的结构concat。当两部分不能直接concat的时候,可以尝试去掉一列。】

上采样 upsampling

上采样的过程其实就是拓展你的feature map的过程。我们一般在做分类任务的卷积过程时,得到的feature map可能它的channel数量越来越多,但是 它的 h,w值是越来越小的,直到最后一层再进行分类的时候会通过全连接神经网络变成一个1024或者更小的向量。这个数值是远小于一开始输入的图片大小的。

但是在端到端分割网络结构中,我们的输出是一个和输入图片大小一样的图片,并且在像素级地分辨前景和背景。所以就存在一个上采样的过程,使原本减小了的feature map再逐渐变大。所以就存在了一个上采样的过程。一般扩大feature map的操作分为三种: 反卷积(deconvolution/ convolution transpose),上池化(unpooling),上采样(unsampling)。【题目中提到的unsampling是一个泛泛的扩大feature map的概念】

  • 反卷积 deconvolution

    如果在你知道了它是如何操作之后,这里其实叫转置卷积更为合适。

  • 上池化 unpooling

    这里就是将池化的过程反过来做。

  • 上采样 unsmapling

    一般指的是用传统的方法进行插值采样。

上池化和上采样不需要训练参数,但是反卷积因为有卷积的存在,所以还是存在训练参数的。

训练过程

  • 最终的feature map上的每一个pixel都会经过soft-max层来判断它究竟属于哪一类。

  • 提到了一个pre-computed weight map,有助于分辨相邻的有边界的范围。也就是在训练的时候将边界处的loss weight加大,强迫机器学习边界处的内容。具体公式如下:

    这里的 $w_c({\bf x})$ 是balance class frequencies的map,$d_1$ 和 $d_2$ 分别表示距离第一近和第二近细胞的边界的距离。$w_0$ 和 $\sigma$ 都是人为设置的。可以看出来像素距离边界越近得到的权重越高。

    最终的cross entropy loss在计算的过程中会考虑权重

  • weight matrix一开始的初始值是以 $\sqrt{2/N}$ 为deviation的高斯分布取值的,N的值为卷积的总个数。

  • 因为数据本身不太够,所以需要进行数据增广。

试验结果

总之就是在三种任务上测试并且表现都很好。