U-Net: Convolutional Networks for Biomedical Image Segmentation - 2015
U-Net 最早被提出用于医学图像分割. 由于其网络结构简单易懂,已被广泛应用于许多语义分割场景,例如:GitHub 项目 - Kaggle 车辆边界识别之 UNet.
U-Net 网络结构
PyTorch 网络定义
GitHub 项目 - Kaggle 车辆边界识别之 UNet 给出了一种 U-Net 网络结构的定义与使用示例.
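这里参考另一种定义 - pytorch-semseg/ptsemseg/models/unet.py,其支持通过参数设定上采样方式(反卷积 deconvolution,或双线性插值)以及是否使用批归一化(batchnorm). 注意该实现与原论文一致,卷积不做填充(padding=0),因此输出分割图的尺寸小于输入(例如输入 572×572 时输出为 388×388).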
import torch
import torch.nn as nn
import torch.nn.functional as F
class unetConv2(nn.Module):
    """Two stacked 3x3 unpadded convolutions, each followed by a ReLU.

    Both convolutions use stride 1 and padding 0, so each one trims 2 pixels
    from every spatial dimension (4 pixels in total for the block).

    Args:
        in_size: number of input channels.
        out_size: number of output channels of both convolutions.
        is_batchnorm: when true, a BatchNorm2d is inserted between each
            convolution and its ReLU.
    """

    def __init__(self, in_size, out_size, is_batchnorm):
        super().__init__()

        def make_stage(channels_in):
            # Submodule order inside the Sequential is
            # Conv2d -> [BatchNorm2d] -> ReLU, so state_dict keys are
            # conv*.0 (conv), conv*.1 (BN or ReLU), conv*.2 (ReLU).
            stage = [nn.Conv2d(channels_in, out_size, 3, 1, 0)]
            if is_batchnorm:
                stage.append(nn.BatchNorm2d(out_size))
            stage.append(nn.ReLU())
            return nn.Sequential(*stage)

        # conv1 is built before conv2 so parameter initialisation consumes
        # the RNG in the same order as a two-branch definition would.
        self.conv1 = make_stage(in_size)
        self.conv2 = make_stage(out_size)

    def forward(self, inputs):
        """Apply both convolution stages in sequence."""
        hidden = self.conv1(inputs)
        return self.conv2(hidden)
class unetUp(nn.Module):
    """Decoder stage: upsample, merge with the cropped skip tensor, convolve.

    Args:
        in_size: channels of the coarser decoder input (``inputs2``).
        out_size: channels of the skip connection (``inputs1``) and of the
            stage output.
        is_deconv: when true, upsample with a 2x2 stride-2 ConvTranspose2d
            that maps ``in_size`` -> ``out_size`` channels; otherwise use
            parameter-free bilinear upsampling, which keeps ``in_size``
            channels.
    """

    def __init__(self, in_size, out_size, is_deconv):
        super(unetUp, self).__init__()
        # Channels reaching self.conv = skip channels (out_size) + channels of
        # the upsampled tensor. Bilinear upsampling keeps in_size channels, so
        # sizing the conv by in_size alone breaks the non-deconv path.
        up_channels = out_size if is_deconv else in_size
        # Built before self.up to keep the original registration / RNG order.
        self.conv = unetConv2(out_size + up_channels, out_size, False)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, inputs1, inputs2):
        """Upsample ``inputs2``, fit ``inputs1`` to it, concatenate, convolve.

        Args:
            inputs1: skip-connection tensor from the encoder, laid out as
                (N, C, H, W).
            inputs2: coarser decoder tensor to be upsampled, (N, C, H, W).

        Returns:
            The output of ``self.conv`` on the channel-wise concatenation of
            the resized ``inputs1`` and the upsampled ``inputs2``.
        """
        outputs2 = self.up(inputs2)
        # Per-dimension size differences. They are negative when the skip
        # tensor is larger (the usual case with unpadded convolutions); F.pad
        # with negative amounts then center-crops inputs1.
        diff_h = outputs2.size(2) - inputs1.size(2)
        diff_w = outputs2.size(3) - inputs1.size(3)
        # F.pad orders the last two dims as (left, right, top, bottom).
        # Splitting as d // 2 and d - d // 2 keeps the total exact for odd d.
        padding = [
            diff_w // 2, diff_w - diff_w // 2,
            diff_h // 2, diff_h - diff_h // 2,
        ]
        outputs1 = F.pad(inputs1, padding)
        return self.conv(torch.cat([outputs1, outputs2], 1))
class unet(nn.Module):
    """U-Net with unpadded 3x3 convolutions (4 down / 4 up stages).

    Args:
        feature_scale: divisor applied to the base channel widths
            [64, 128, 256, 512, 1024]; 1 keeps the original widths.
        n_classes: number of output channels of the final 1x1 convolution.
        is_deconv: decoder upsampling mode passed to every ``unetUp``
            (ConvTranspose2d when true, bilinear otherwise).
        in_channels: number of channels of the input image.
        is_batchnorm: add BatchNorm2d to the encoder and center conv blocks
            (the decoder blocks are always built without it).

    Raises:
        ValueError: if ``feature_scale`` is not positive or reduces any
            channel width below 1.
    """

    def __init__(self,
                 feature_scale=1,
                 n_classes=2,
                 is_deconv=True,
                 in_channels=3,
                 is_batchnorm=True,
                 ):
        super(unet, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        if feature_scale <= 0:
            raise ValueError("feature_scale must be positive, got %r" % (feature_scale,))
        # feature_scale was previously stored but ignored; apply it as a width
        # divisor (feature_scale=1 reproduces the original widths exactly).
        filters = [int(x / self.feature_scale) for x in [64, 128, 256, 512, 1024]]
        if min(filters) < 1:
            raise ValueError(
                "feature_scale=%r leaves a layer with no channels: %r"
                % (feature_scale, filters)
            )

        # downsampling
        self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)
        self.center = unetConv2(filters[3], filters[4], self.is_batchnorm)

        # upsampling
        self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv)
        self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv)
        self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv)
        self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv)

        # final conv (without any concat): 1x1 projection to class scores
        self.final = nn.Conv2d(filters[0], n_classes, 1)

    def forward(self, inputs):
        """Run the encoder, center block and decoder; return per-pixel logits.

        Each encoder output (conv1..conv4) is kept and fed to the matching
        decoder stage as its skip connection.
        """
        conv1 = self.conv1(inputs)
        maxpool1 = self.maxpool1(conv1)
        conv2 = self.conv2(maxpool1)
        maxpool2 = self.maxpool2(conv2)
        conv3 = self.conv3(maxpool2)
        maxpool3 = self.maxpool3(conv3)
        conv4 = self.conv4(maxpool3)
        maxpool4 = self.maxpool4(conv4)
        center = self.center(maxpool4)

        up4 = self.up_concat4(conv4, center)
        up3 = self.up_concat3(conv3, up4)
        up2 = self.up_concat2(conv2, up3)
        up1 = self.up_concat1(conv1, up2)

        final = self.final(up1)
        return final
采用 torchsummary 库可以打印所定义 U-Net 各层的输出尺寸与参数量.
import torch
import torch.nn.functional as F
from torchsummary import summary
# Use the GPU when one is available, otherwise fall back to the CPU
# (torch.device was introduced in PyTorch v0.4.0).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = unet().to(device)
# Pass the model's device type ("cuda" / "cpu") explicitly so torchsummary
# builds its dummy input on the same device as the model instead of relying
# on its CUDA default, which varies between torchsummary versions.
summary(model, (3, 572, 572), device=device.type)
输出如下:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 570, 570] 1,792
BatchNorm2d-2 [-1, 64, 570, 570] 128
ReLU-3 [-1, 64, 570, 570] 0
Conv2d-4 [-1, 64, 568, 568] 36,928
BatchNorm2d-5 [-1, 64, 568, 568] 128
ReLU-6 [-1, 64, 568, 568] 0
unetConv2-7 [-1, 64, 568, 568] 0
MaxPool2d-8 [-1, 64, 284, 284] 0
Conv2d-9 [-1, 128, 282, 282] 73,856
BatchNorm2d-10 [-1, 128, 282, 282] 256
ReLU-11 [-1, 128, 282, 282] 0
Conv2d-12 [-1, 128, 280, 280] 147,584
BatchNorm2d-13 [-1, 128, 280, 280] 256
ReLU-14 [-1, 128, 280, 280] 0
unetConv2-15 [-1, 128, 280, 280] 0
MaxPool2d-16 [-1, 128, 140, 140] 0
Conv2d-17 [-1, 256, 138, 138] 295,168
BatchNorm2d-18 [-1, 256, 138, 138] 512
ReLU-19 [-1, 256, 138, 138] 0
Conv2d-20 [-1, 256, 136, 136] 590,080
BatchNorm2d-21 [-1, 256, 136, 136] 512
ReLU-22 [-1, 256, 136, 136] 0
unetConv2-23 [-1, 256, 136, 136] 0
MaxPool2d-24 [-1, 256, 68, 68] 0
Conv2d-25 [-1, 512, 66, 66] 1,180,160
BatchNorm2d-26 [-1, 512, 66, 66] 1,024
ReLU-27 [-1, 512, 66, 66] 0
Conv2d-28 [-1, 512, 64, 64] 2,359,808
BatchNorm2d-29 [-1, 512, 64, 64] 1,024
ReLU-30 [-1, 512, 64, 64] 0
unetConv2-31 [-1, 512, 64, 64] 0
MaxPool2d-32 [-1, 512, 32, 32] 0
Conv2d-33 [-1, 1024, 30, 30] 4,719,616
BatchNorm2d-34 [-1, 1024, 30, 30] 2,048
ReLU-35 [-1, 1024, 30, 30] 0
Conv2d-36 [-1, 1024, 28, 28] 9,438,208
BatchNorm2d-37 [-1, 1024, 28, 28] 2,048
ReLU-38 [-1, 1024, 28, 28] 0
unetConv2-39 [-1, 1024, 28, 28] 0
ConvTranspose2d-40 [-1, 512, 56, 56] 2,097,664
Conv2d-41 [-1, 512, 54, 54] 4,719,104
ReLU-42 [-1, 512, 54, 54] 0
Conv2d-43 [-1, 512, 52, 52] 2,359,808
ReLU-44 [-1, 512, 52, 52] 0
unetConv2-45 [-1, 512, 52, 52] 0
unetUp-46 [-1, 512, 52, 52] 0
ConvTranspose2d-47 [-1, 256, 104, 104] 524,544
Conv2d-48 [-1, 256, 102, 102] 1,179,904
ReLU-49 [-1, 256, 102, 102] 0
Conv2d-50 [-1, 256, 100, 100] 590,080
ReLU-51 [-1, 256, 100, 100] 0
unetConv2-52 [-1, 256, 100, 100] 0
unetUp-53 [-1, 256, 100, 100] 0
ConvTranspose2d-54 [-1, 128, 200, 200] 131,200
Conv2d-55 [-1, 128, 198, 198] 295,040
ReLU-56 [-1, 128, 198, 198] 0
Conv2d-57 [-1, 128, 196, 196] 147,584
ReLU-58 [-1, 128, 196, 196] 0
unetConv2-59 [-1, 128, 196, 196] 0
unetUp-60 [-1, 128, 196, 196] 0
ConvTranspose2d-61 [-1, 64, 392, 392] 32,832
Conv2d-62 [-1, 64, 390, 390] 73,792
ReLU-63 [-1, 64, 390, 390] 0
Conv2d-64 [-1, 64, 388, 388] 36,928
ReLU-65 [-1, 64, 388, 388] 0
unetConv2-66 [-1, 64, 388, 388] 0
unetUp-67 [-1, 64, 388, 388] 0
Conv2d-68 [-1, 2, 388, 388] 130
================================================================
Total params: 31,039,746
Trainable params: 31,039,746
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.74
Forward/backward pass size (MB): 3136.33
Params size (MB): 118.41
Estimated Total Size (MB): 3258.48
----------------------------------------------------------------