U-Net最早用作生物图像的分割,后来在目标检测、图像转换,以及Tone Mapping、Reverse Tone Mapping等很多地方都有应用。它的一个特点是早期的卷积层结果和最后几层的结果采用级联的形式作为新的神经网络层。我觉得它的过程很类似图像金字塔和图像重建的过程:前面的下采样提取出信息,后面进行重建。区别在于这里不是像拉普拉斯金字塔重建那样将图像复原,而是生成了具有“新特性”的图像。之所以叫它U-Net,是因为它看起来像个U形;如果考虑中间层次的级联,更像一把琴。
U-Net
从图中可以看出,U-Net主要由Conv+ReLU、maxpool、up-conv、conv 1x1 几个部分构成,那么我们首先在TensorFlow里面将这几个部分函数化。
- conv+ReLU
def conv_relu_layer(net, numfilters, name):
    """Apply a 3x3 convolution with a fused ReLU activation.

    The convolution uses 'Valid' padding, i.e. no padding is added to the
    input before convolving.

    Args:
        net: Input tensor passed to ``tf.layers.conv2d``.
        numfilters: Number of output filters of the convolution.
        name: Prefix for the layer name; the layer is registered as
            ``"<name>_conv_relu"``.

    Returns:
        The output tensor of the convolution layer.
    """
    layer_name = "{}_conv_relu".format(name)
    return tf.layers.conv2d(
        net,
        filters=numfilters,
        kernel_size=(3, 3),
        padding='Valid',
        activation=tf.nn.relu,
        name=layer_name,
    )
- maxpool
def maxpool(net, name):
    """Downsample ``net`` with a 2x2 max-pooling window moved with stride 2.

    Args:
        net: Input tensor passed to ``tf.layers.max_pooling2d``.
        name: Prefix for the layer name; the layer is registered as
            ``"<name>_maxpool"``.

    Returns:
        The pooled tensor.
    """
    # Window and stride are the same size, so pooling windows do not overlap.
    window = (2, 2)
    layer_name = "{}_maxpool".format(name)
    return tf.layers.max_pooling2d(
        net,
        pool_size=window,
        strides=window,
        padding='valid',
        name=layer_name,
    )
- up_conv
def up_conv(net, numfilters, name):
    """Upsample ``net`` with a 2x2 transposed convolution (stride 2) and ReLU.

    Args:
        net: Input tensor passed to ``tf.layers.conv2d_transpose``.
        numfilters: Number of output filters of the transposed convolution.
        name: Prefix for the layer name; the layer is registered as
            ``"<name>_up_conv"``.

    Returns:
        The output tensor of the transposed-convolution layer.
    """
    # Kernel size equals the stride, mirroring the 2x2/stride-2 pooling.
    kernel_and_stride = (2, 2)
    layer_name = "{}_up_conv".format(name)
    return tf.layers.conv2d_transpose(
        net,
        filters=numfilters,
        kernel_size=kernel_and_stride,
        strides=kernel_and_stride,
        padding='valid',
        activation=tf.nn.relu,
        name=layer_name,
    )
- copy_crop
def copy_crop(skip_connect, net):
    """Center-crop ``skip_connect`` to the size of ``net`` and concatenate them.

    The crop is taken from the middle of the skip feature map rather than
    from its top-left corner, so the cropped region stays spatially aligned
    with ``net`` when borders have been trimmed symmetrically on the way down.

    Args:
        skip_connect: 4-D tensor from the downsampling path; axes 1 and 2 are
            cropped and axis 3 is the concatenation axis.
        net: 4-D tensor from the upsampling path whose static sizes along
            axes 1 and 2 define the crop size.

    Returns:
        ``skip_connect`` (cropped) and ``net`` concatenated along axis 3,
        with the cropped skip features first.

    Raises:
        ValueError: If the sizes of axes 1 and 2 are not statically known,
            or if ``skip_connect`` is smaller than ``net`` along either axis.
    """
    skip_h, skip_w = skip_connect.get_shape().as_list()[1:3]
    net_h, net_w = net.get_shape().as_list()[1:3]
    if None in (skip_h, skip_w, net_h, net_w):
        raise ValueError(
            "copy_crop needs statically known sizes for axes 1 and 2")
    if skip_h < net_h or skip_w < net_w:
        raise ValueError(
            "skip connection ({}x{}) is smaller than the tensor it is "
            "cropped to ({}x{})".format(skip_h, skip_w, net_h, net_w))

    # Split the surplus evenly between both sides so the crop is centered.
    offset_h = (skip_h - net_h) // 2
    offset_w = (skip_w - net_w) // 2

    # -1 in the size keeps every remaining element along that axis.
    skip_connect_crop = tf.slice(
        skip_connect,
        [0, offset_h, offset_w, 0],
        [-1, net_h, net_w, -1],
    )
    return tf.concat([skip_connect_crop, net], axis=3)
- conv1x1
def conv1x1(net, numfilters, name):
    """Apply a 1x1 convolution with stride 1 and no activation argument.

    Args:
        net: Input tensor passed to ``tf.layers.conv2d``.
        numfilters: Number of output filters of the convolution.
        name: Prefix for the layer name; the layer is registered as
            ``"<name>_conv1x1"``.

    Returns:
        The output tensor of the convolution layer.
    """
    layer_name = "{}_conv1x1".format(name)
    return tf.layers.conv2d(
        net,
        filters=numfilters,
        kernel_size=(1, 1),
        strides=(1, 1),
        padding='SAME',
        name=layer_name,
    )
# Define the input data: a fixed batch of 64 three-channel 572x572 images.
# NOTE: `input` shadows the Python builtin of the same name; the name is kept
# so that any code referring to this placeholder keeps working.
input = tf.placeholder(dtype=tf.float32, shape=(64, 572, 572, 3))

# Define the downsampling path. Every level follows the same naming scheme:
# layer1 is the max-pool (absent on level 1), followed by two conv+ReLU
# layers. The second conv output of levels 1-4 is kept for the skip
# connections used on the way up.
network = conv_relu_layer(input, numfilters=64, name='lev1_layer1')
skip_con1 = conv_relu_layer(network, numfilters=64, name='lev1_layer2')

network = maxpool(skip_con1, name='lev2_layer1')
network = conv_relu_layer(network, numfilters=128, name='lev2_layer2')
skip_con2 = conv_relu_layer(network, numfilters=128, name='lev2_layer3')

# Level 3 previously used 'lev3_layer1' for both the max-pool and the first
# convolution; it is now numbered like every other level.
network = maxpool(skip_con2, name='lev3_layer1')
network = conv_relu_layer(network, numfilters=256, name='lev3_layer2')
skip_con3 = conv_relu_layer(network, numfilters=256, name='lev3_layer3')

network = maxpool(skip_con3, name='lev4_layer1')
network = conv_relu_layer(network, numfilters=512, name='lev4_layer2')
skip_con4 = conv_relu_layer(network, numfilters=512, name='lev4_layer3')

# Bottom of the U: no skip connection is taken from this level.
network = maxpool(skip_con4, name='lev5_layer1')
network = conv_relu_layer(network, numfilters=1024, name='lev5_layer2')
network = conv_relu_layer(network, numfilters=1024, name='lev5_layer3')

# Define the upsampling path. Each level upsamples, concatenates the cropped
# skip connection of the matching downsampling level, then applies two
# conv+ReLU layers.
network = up_conv(network, numfilters=512, name='lev6_layer1')
network = copy_crop(skip_con4, network)
network = conv_relu_layer(network, numfilters=512, name='lev6_layer2')
network = conv_relu_layer(network, numfilters=512, name='lev6_layer3')

network = up_conv(network, numfilters=256, name='lev7_layer1')
network = copy_crop(skip_con3, network)
network = conv_relu_layer(network, numfilters=256, name='lev7_layer2')
network = conv_relu_layer(network, numfilters=256, name='lev7_layer3')

network = up_conv(network, numfilters=128, name='lev8_layer1')
network = copy_crop(skip_con2, network)
network = conv_relu_layer(network, numfilters=128, name='lev8_layer2')
network = conv_relu_layer(network, numfilters=128, name='lev8_layer3')

network = up_conv(network, numfilters=64, name='lev9_layer1')
network = copy_crop(skip_con1, network)
network = conv_relu_layer(network, numfilters=64, name='lev9_layer2')
network = conv_relu_layer(network, numfilters=64, name='lev9_layer3')

# Final 1x1 convolution maps the 64 feature channels to 2 output channels.
network = conv1x1(network, numfilters=2, name='lev9_layer4')
利用TensorBoard可以得到如下的网络架构图。
U-net