對于一個(gè)小白奕枝,從了解Batch Normalization(后面簡稱BN)到正確使用BN,可謂路漫漫兮逻锐。在此做一個(gè)記錄。
網(wǎng)上搜索關(guān)于BN最多的就是原理推導(dǎo)读第,相關(guān)論文出處。
例如:
http://blog.csdn.net/Fate_fjh/article/details/53375881
http://www.reibang.com/p/0312e04e4e83
但是這個(gè)并不能幫助我們實(shí)際的使用拥刻,對于需要迅速用起來的伙伴幫助不大怜瞒。我們工程師相信的是先用起來,再去研究原理般哼!呵呵吴汪!
有一些文章介紹的BN層的實(shí)現(xiàn),也有代碼示例蒸眠,但能順利跑起來的寥寥漾橙。因?yàn)槭褂肂N不像卷積層那樣,寫個(gè)層的實(shí)現(xiàn)就可以了楞卡。由于BN層會包含兩個(gè)可訓(xùn)練參數(shù)以及兩個(gè)不可訓(xùn)練參數(shù)霜运,所以涉及到在train代碼中如何保存的關(guān)鍵問題,以及在inference代碼中如何加載的問題蒋腮。有相關(guān)博客介紹到這一步了淘捡,很有幫助。
例如:
https://www.cnblogs.com/hrlnw/p/7227447.html
本以為別人都說這么明白了池摧,抄一抄不是很容易的事情嗎焦除。可以上的代碼是不能讓你正確完成BN功能的作彤。也不知是抄錯(cuò)了膘魄,還是別人漏掉了一些關(guān)鍵環(huán)節(jié)乌逐。總之你的moving_mean/moving_variance好像就是不太對瓣距∏粒基本上中文網(wǎng)頁很難在找到這個(gè)問題的解了。
現(xiàn)在你需要搜索的關(guān)鍵字可能要變成BN/參數(shù)保存/平均滑動等等了蹈丸。還好tensorflow的github中有了線索:
https://github.com/tensorflow/tensorflow/issues/14809
https://github.com/tensorflow/tensorflow/issues/15250
可見有很多人確實(shí)無法正確使用BN功能成黄,然而最有用的一個(gè)issues是:
https://github.com/tensorflow/tensorflow/issues/1122#issuecomment-280325584
在這里,我拼湊成了一個(gè)完整能用的BN功能代碼逻杖,解決了我好久的痛苦奋岁,讓我興奮一下。
知識來源于網(wǎng)絡(luò)荸百,奉獻(xiàn)給網(wǎng)絡(luò)闻伶。不敢獨(dú)享這一成果,再此分享給大家够话。
-----------------------------------------------------------------華麗的分割線----------------------------------------------------------------------------
整個(gè)BN功能的實(shí)現(xiàn)需要分三個(gè)部分:1.BN層實(shí)現(xiàn)蓝翰;2.訓(xùn)練時(shí)更新和完成后保存;3.預(yù)測時(shí)加載女嘲。
1.BN層實(shí)現(xiàn):
如果你接觸了一段時(shí)間后畜份,這里你至少應(yīng)該知道BN的三種實(shí)現(xiàn)方式了,但是我只成功了其中的一種欣尼,希望其他朋友能夠補(bǔ)充完善爆雹。
def bn_layer(x, scope, is_training, epsilon=0.001, decay=0.99, reuse=None):
? ? """
? ? Performs a batch normalization layer
? ? Args:
? ? ? ? x: input tensor
? ? ? ? scope: scope name
? ? ? ? is_training: python boolean value
? ? ? ? epsilon: the variance epsilon - a small float number to avoid dividing by 0
? ? ? ? decay: the moving average decay
? ? Returns:
? ? ? ? The ops of a batch normalization layer
? ? """
? ? with tf.variable_scope(scope, reuse=reuse):
? ? ? ? shape = x.get_shape().as_list()
? ? ? ? # gamma: a trainable scale factor
? ? ? ? gamma = tf.get_variable(scope+"_gamma", shape[-1], initializer=tf.constant_initializer(1.0), trainable=True)
? ? ? ? # beta: a trainable shift value
? ? ? ? beta = tf.get_variable(scope+"_beta", shape[-1], initializer=tf.constant_initializer(0.0), trainable=True)
? ? ? ? moving_avg = tf.get_variable(scope+"_moving_mean", shape[-1], initializer=tf.constant_initializer(0.0), trainable=False)
? ? ? ? moving_var = tf.get_variable(scope+"_moving_variance", shape[-1], initializer=tf.constant_initializer(1.0), trainable=False)
? ? ? ? if is_training:
? ? ? ? ? ? # tf.nn.moments == Calculate the mean and the variance of the tensor x
? ? ? ? ? ? avg, var = tf.nn.moments(x, np.arange(len(shape)-1), keep_dims=True)
? ? ? ? ? ? avg=tf.reshape(avg, [avg.shape.as_list()[-1]])
? ? ? ? ? ? var=tf.reshape(var, [var.shape.as_list()[-1]])
? ? ? ? ? ? #update_moving_avg = moving_averages.assign_moving_average(moving_avg, avg, decay)
? ? ? ? ? ? update_moving_avg=tf.assign(moving_avg, moving_avg*decay+avg*(1-decay))
? ? ? ? ? ? #update_moving_var = moving_averages.assign_moving_average(moving_var, var, decay)
? ? ? ? ? ? update_moving_var=tf.assign(moving_var, moving_var*decay+var*(1-decay))
? ? ? ? ? ? control_inputs = [update_moving_avg, update_moving_var]
? ? ? ? else:
? ? ? ? ? ? avg = moving_avg
? ? ? ? ? ? var = moving_var
? ? ? ? ? ? control_inputs = []
? ? ? ? with tf.control_dependencies(control_inputs):
? ? ? ? ? ? output = tf.nn.batch_normalization(x, avg, var, offset=beta, scale=gamma, variance_epsilon=epsilon)
? ? return output
def bn_layer_top(x, scope, is_training, epsilon=0.001, decay=0.99):
? ? """
? ? Returns a batch normalization layer that automatically switch between train and test phases based on the
? ? tensor is_training
? ? Args:
? ? ? ? x: input tensor
? ? ? ? scope: scope name
? ? ? ? is_training: boolean tensor or variable
? ? ? ? epsilon: epsilon parameter - see batch_norm_layer
? ? ? ? decay: epsilon parameter - see batch_norm_layer
? ? Returns:
? ? ? ? The correct batch normalization layer based on the value of is_training
? ? """
? ? #assert isinstance(is_training, (ops.Tensor, variables.Variable)) and is_training.dtype == tf.bool
? ? return tf.cond(
? ? ? ? is_training,
? ? ? ? lambda: bn_layer(x=x, scope=scope, epsilon=epsilon, decay=decay, is_training=True, reuse=None),
? ? ? ? lambda: bn_layer(x=x, scope=scope, epsilon=epsilon, decay=decay, is_training=False, reuse=True),
? ? )
這里的參數(shù)epsilon=0.001, decay=0.99可以自行調(diào)整。
2.訓(xùn)練時(shí)更新和完成后保存:
在訓(xùn)練的代碼中增加如下代碼:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
? ? train = tf.train.AdamOptimizer(learning_rate=lr).minimize(cost)
這個(gè)是用于更新參數(shù)的愕鼓。
var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [gfor gin g_listif 'moving_mean' in g.name]
bn_moving_vars += [gfor gin g_listif 'moving_variance' in g.name]
var_list += bn_moving_vars
train_saver = tf.train.Saver(var_list=var_list)
這個(gè)是用于保存bn不可訓(xùn)練的參數(shù)钙态。
3.預(yù)測時(shí)加載:
# get moving avg
var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [gfor gin g_listif 'moving_mean' in g.name]
bn_moving_vars += [gfor gin g_listif 'moving_variance' in g.name]
var_list += bn_moving_vars
saver = tf.train.Saver(var_list=var_list)
ckpt_path =""
saver.restore(sess, ckpt_path)
這樣就可以找到checkpoint中的參數(shù)了。
現(xiàn)在你可以開心的使用BN了菇晃!