論文地址:https://arxiv.org/pdf/1512.03385.pdf
1李剖、引言-深度網(wǎng)絡的退化問題
在深度神經(jīng)網(wǎng)絡訓練中尖飞,從經(jīng)驗來看琉挖,隨著網(wǎng)絡深度的增加,模型理論上可以取得更好的結(jié)果决摧。但是實驗卻發(fā)現(xiàn)亿蒸,深度神經(jīng)網(wǎng)絡中存在著退化問題(Degradation problem)≌谱可以看到边锁,在下圖中56層的網(wǎng)絡比20層網(wǎng)絡效果還要差。
上面的現(xiàn)象與過擬合不同波岛,過擬合的表現(xiàn)是訓練誤差小而測試誤差大茅坛,而上面的圖片顯示訓練誤差和測試誤差都是56層的網(wǎng)絡較大。
深度網(wǎng)絡的退化問題至少說明深度網(wǎng)絡不容易訓練则拷。我們假設(shè)這樣一種情況贡蓖,56層的網(wǎng)絡的前20層和20層網(wǎng)絡參數(shù)一模一樣,而后36層是一個恒等映射( identity mapping)煌茬,即輸入x輸出也是x斥铺,那么56層的網(wǎng)絡的效果也至少會和20層的網(wǎng)絡效果一樣,可是為什么出現(xiàn)了退化問題呢坛善?因此我們在訓練深層網(wǎng)絡時晾蜘,訓練方法肯定存在的一定的缺陷。
正是上面的這個有趣的假設(shè)浑吟,何凱明博士發(fā)明了殘差網(wǎng)絡ResNet來解決退化問題!讓我們來一探究竟耗溜!
2组力、ResNet網(wǎng)絡結(jié)構(gòu)
ResNet中最重要的是殘差學習單元:
對于一個堆積層結(jié)構(gòu)(幾層堆積而成)當輸入為x時其學習到的特征記為H(x),現(xiàn)在我們希望其可以學習到殘差F(x)=H(x)-x抖拴,這樣其實原始的學習特征是F(x)+x 燎字。當殘差為0時,此時堆積層僅僅做了恒等映射阿宅,至少網(wǎng)絡性能不會下降候衍,實際上殘差不會為0,這也會使得堆積層在輸入特征基礎(chǔ)上學習到新的特征洒放,從而擁有更好的性能蛉鹿。一個殘差單元的公式如下:
后面的x前面也需要經(jīng)過參數(shù)Ws變換,從而使得和前面部分的輸出形狀相同往湿,可以進行加法運算妖异。
在堆疊了多個殘差單元后惋戏,我們的ResNet網(wǎng)絡結(jié)構(gòu)如下圖所示:
3、ResNet代碼實戰(zhàn)
我們來實現(xiàn)一個mnist手寫數(shù)字識別的程序他膳。代碼中主要使用的是tensorflow.contrib.slim中定義的函數(shù)响逢,slim作為一種輕量級的tensorflow庫,使得模型的構(gòu)建棕孙,訓練舔亭,測試都變得更加簡單。卷積層蟀俊、池化層以及全聯(lián)接層都可以進行快速的定義钦铺,非常方便。這里為了方便使用欧漱,我們直接導入slim职抡。
import tensorflow.contrib.slim as slim
我們主要來看一下我們的網(wǎng)絡結(jié)構(gòu)。首先定義兩個殘差結(jié)構(gòu),第一個是輸入和輸出形狀一樣的殘差結(jié)構(gòu),一個是輸入和輸出形狀不一樣的殘差結(jié)構(gòu)转捕。
下面是輸入和輸出形狀相同的殘差塊祭衩,這里slim.conv2d函數(shù)的輸入有三個,分別是輸入數(shù)據(jù)缕坎、卷積核數(shù)量,卷積核的大小,默認的話padding為SAME郊丛,即卷積后形狀不變,由于輸入和輸出形狀相同瞧筛,因此我們可以在計算outputs時直接將兩部分相加厉熟。
def res_identity(input_tensor,conv_depth,kernel_shape,layer_name):
with tf.variable_scope(layer_name):
relu = tf.nn.relu(slim.conv2d(input_tensor,conv_depth,kernel_shape))
outputs = tf.nn.relu(slim.conv2d(relu,conv_depth,kernel_shape) + input_tensor)
return outputs
下面是輸入和輸出形狀不同的殘差塊,由于輸入和輸出形狀不同较幌,因此我們需要對輸入也進行一個卷積變化揍瑟,使二者形狀相同。ResNet作者建議可以用1*1的卷積層乍炉,stride=2绢片,來進行變換:
def res_change(input_tensor,conv_depth,kernel_shape,layer_name):
with tf.variable_scope(layer_name):
relu = tf.nn.relu(slim.conv2d(input_tensor,conv_depth,kernel_shape,stride=2))
input_tensor_reshape = slim.conv2d(input_tensor,conv_depth,[1,1],stride=2)
outputs = tf.nn.relu(slim.conv2d(relu,conv_depth,kernel_shape) + input_tensor_reshape)
return outputs
最后是整個網(wǎng)絡結(jié)構(gòu),對于x的輸入岛琼,我們先進行一次卷積和池化操作底循,然后接入四個殘差塊,最后接兩層全聯(lián)接層得到網(wǎng)絡的輸出槐瑞。
def inference(inputs):
x = tf.reshape(inputs,[-1,28,28,1])
conv_1 = tf.nn.relu(slim.conv2d(x,32,[3,3])) #28 * 28 * 32
pool_1 = slim.max_pool2d(conv_1,[2,2]) # 14 * 14 * 32
block_1 = res_identity(pool_1,32,[3,3],'layer_2')
block_2 = res_change(block_1,64,[3,3],'layer_3')
block_3 = res_identity(block_2,64,[3,3],'layer_4')
block_4 = res_change(block_3,32,[3,3],'layer_5')
net_flatten = slim.flatten(block_4,scope='flatten')
fc_1 = slim.fully_connected(slim.dropout(net_flatten,0.8),200,activation_fn=tf.nn.tanh,scope='fc_1')
output = slim.fully_connected(slim.dropout(fc_1,0.8),10,activation_fn=None,scope='output_layer')
return output
完整的代碼地址在:https://github.com/princewen/tensorflow_practice/tree/master/CV/ResNet
參考文獻:
1熙涤、論文:https://arxiv.org/pdf/1512.03385.pdf
2、https://blog.csdn.net/kaisa158/article/details/81096588?utm_source=blogxgwz4