實(shí)戰(zhàn)tensorflow使用BN------一坑的淚

對于一個(gè)小白奕枝,從了解Batch Normalization(后面簡稱BN)到正確使用BN,可謂路漫漫兮逻锐。在此做一個(gè)記錄。

網(wǎng)上搜索關(guān)于BN最多的就是原理推導(dǎo)读第,相關(guān)論文出處。

例如:

http://blog.csdn.net/Fate_fjh/article/details/53375881

http://www.reibang.com/p/0312e04e4e83

但是這個(gè)并不能幫助我們實(shí)際的使用拥刻,對于需要迅速用起來的伙伴幫助不大怜瞒。我們工程師相信的是先用起來,再去研究原理般哼!呵呵吴汪!

有一些文章介紹的BN層的實(shí)現(xiàn),也有代碼示例蒸眠,但能順利跑起來的寥寥漾橙。因?yàn)槭褂肂N不像卷積層那樣,寫個(gè)層的實(shí)現(xiàn)就可以了楞卡。由于BN層會包含兩個(gè)可訓(xùn)練參數(shù)以及兩個(gè)不可訓(xùn)練參數(shù)霜运,所以涉及到在train代碼中如何保存的關(guān)鍵問題,以及在inference代碼中如何加載的問題蒋腮。有相關(guān)博客介紹到這一步了淘捡,很有幫助。

例如:

https://www.cnblogs.com/hrlnw/p/7227447.html

本以為別人都說這么明白了池摧,抄一抄不是很容易的事情嗎焦除。可以上的代碼是不能讓你正確完成BN功能的作彤。也不知是抄錯(cuò)了膘魄,還是別人漏掉了一些關(guān)鍵環(huán)節(jié)乌逐。總之你的moving_mean/moving_variance好像就是不太對瓣距∏粒基本上中文網(wǎng)頁很難在找到這個(gè)問題的解了。

現(xiàn)在你需要搜索的關(guān)鍵字可能要變成BN/參數(shù)保存/平均滑動等等了蹈丸。還好tensorflow的github中有了線索:

https://github.com/tensorflow/tensorflow/issues/14809

https://github.com/tensorflow/tensorflow/issues/15250

可見有很多人確實(shí)無法正確使用BN功能成黄,然而最有用的一個(gè)issues是:

https://github.com/tensorflow/tensorflow/issues/1122#issuecomment-280325584

在這里,我拼湊成了一個(gè)完整能用的BN功能代碼逻杖,解決了我好久的痛苦奋岁,讓我興奮一下。

知識來源于網(wǎng)絡(luò)荸百,奉獻(xiàn)給網(wǎng)絡(luò)闻伶。不敢獨(dú)享這一成果,再此分享給大家够话。

-----------------------------------------------------------------華麗的分割線----------------------------------------------------------------------------

整個(gè)BN功能的實(shí)現(xiàn)需要分三個(gè)部分:1.BN層實(shí)現(xiàn)蓝翰;2.訓(xùn)練時(shí)更新和完成后保存;3.預(yù)測時(shí)加載女嘲。

1.BN層實(shí)現(xiàn):

如果你接觸了一段時(shí)間后畜份,這里你至少應(yīng)該知道BN的三種實(shí)現(xiàn)方式了,但是我只成功了其中的一種欣尼,希望其他朋友能夠補(bǔ)充完善爆雹。

def bn_layer(x, scope, is_training, epsilon=0.001, decay=0.99, reuse=None):

? ? """

? ? Performs a batch normalization layer

? ? Args:

? ? ? ? x: input tensor

? ? ? ? scope: scope name

? ? ? ? is_training: python boolean value

? ? ? ? epsilon: the variance epsilon - a small float number to avoid dividing by 0

? ? ? ? decay: the moving average decay

? ? Returns:

? ? ? ? The ops of a batch normalization layer

? ? """

? ? with tf.variable_scope(scope, reuse=reuse):

? ? ? ? shape = x.get_shape().as_list()

? ? ? ? # gamma: a trainable scale factor

? ? ? ? gamma = tf.get_variable(scope+"_gamma", shape[-1], initializer=tf.constant_initializer(1.0), trainable=True)

? ? ? ? # beta: a trainable shift value

? ? ? ? beta = tf.get_variable(scope+"_beta", shape[-1], initializer=tf.constant_initializer(0.0), trainable=True)

? ? ? ? moving_avg = tf.get_variable(scope+"_moving_mean", shape[-1], initializer=tf.constant_initializer(0.0), trainable=False)

? ? ? ? moving_var = tf.get_variable(scope+"_moving_variance", shape[-1], initializer=tf.constant_initializer(1.0), trainable=False)

? ? ? ? if is_training:

? ? ? ? ? ? # tf.nn.moments == Calculate the mean and the variance of the tensor x

? ? ? ? ? ? avg, var = tf.nn.moments(x, np.arange(len(shape)-1), keep_dims=True)

? ? ? ? ? ? avg=tf.reshape(avg, [avg.shape.as_list()[-1]])

? ? ? ? ? ? var=tf.reshape(var, [var.shape.as_list()[-1]])

? ? ? ? ? ? #update_moving_avg = moving_averages.assign_moving_average(moving_avg, avg, decay)

? ? ? ? ? ? update_moving_avg=tf.assign(moving_avg, moving_avg*decay+avg*(1-decay))

? ? ? ? ? ? #update_moving_var = moving_averages.assign_moving_average(moving_var, var, decay)

? ? ? ? ? ? update_moving_var=tf.assign(moving_var, moving_var*decay+var*(1-decay))

? ? ? ? ? ? control_inputs = [update_moving_avg, update_moving_var]

? ? ? ? else:

? ? ? ? ? ? avg = moving_avg

? ? ? ? ? ? var = moving_var

? ? ? ? ? ? control_inputs = []

? ? ? ? with tf.control_dependencies(control_inputs):

? ? ? ? ? ? output = tf.nn.batch_normalization(x, avg, var, offset=beta, scale=gamma, variance_epsilon=epsilon)

? ? return output

def bn_layer_top(x, scope, is_training, epsilon=0.001, decay=0.99):

? ? """

? ? Returns a batch normalization layer that automatically switch between train and test phases based on the

? ? tensor is_training

? ? Args:

? ? ? ? x: input tensor

? ? ? ? scope: scope name

? ? ? ? is_training: boolean tensor or variable

? ? ? ? epsilon: epsilon parameter - see batch_norm_layer

? ? ? ? decay: epsilon parameter - see batch_norm_layer

? ? Returns:

? ? ? ? The correct batch normalization layer based on the value of is_training

? ? """

? ? #assert isinstance(is_training, (ops.Tensor, variables.Variable)) and is_training.dtype == tf.bool

? ? return tf.cond(

? ? ? ? is_training,

? ? ? ? lambda: bn_layer(x=x, scope=scope, epsilon=epsilon, decay=decay, is_training=True, reuse=None),

? ? ? ? lambda: bn_layer(x=x, scope=scope, epsilon=epsilon, decay=decay, is_training=False, reuse=True),

? ? )

這里的參數(shù)epsilon=0.001, decay=0.99可以自行調(diào)整。


2.訓(xùn)練時(shí)更新和完成后保存:

在訓(xùn)練的代碼中增加如下代碼:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

with tf.control_dependencies(update_ops):

? ? train = tf.train.AdamOptimizer(learning_rate=lr).minimize(cost)

這個(gè)是用于更新參數(shù)的愕鼓。

var_list = tf.trainable_variables()

g_list = tf.global_variables()

bn_moving_vars = [gfor gin g_listif 'moving_mean' in g.name]

bn_moving_vars += [gfor gin g_listif 'moving_variance' in g.name]

var_list += bn_moving_vars

train_saver = tf.train.Saver(var_list=var_list)

這個(gè)是用于保存bn不可訓(xùn)練的參數(shù)钙态。

3.預(yù)測時(shí)加載:

# get moving avg

var_list = tf.trainable_variables()

g_list = tf.global_variables()

bn_moving_vars = [gfor gin g_listif 'moving_mean' in g.name]

bn_moving_vars += [gfor gin g_listif 'moving_variance' in g.name]

var_list += bn_moving_vars

saver = tf.train.Saver(var_list=var_list)

ckpt_path =""

saver.restore(sess, ckpt_path)

這樣就可以找到checkpoint中的參數(shù)了。



現(xiàn)在你可以開心的使用BN了菇晃!

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末册倒,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子磺送,更是在濱河造成了極大的恐慌剩失,老刑警劉巖,帶你破解...
    沈念sama閱讀 211,948評論 6 492
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件册着,死亡現(xiàn)場離奇詭異拴孤,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)甲捏,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,371評論 3 385
  • 文/潘曉璐 我一進(jìn)店門演熟,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事芒粹⌒址模” “怎么了?”我有些...
    開封第一講書人閱讀 157,490評論 0 348
  • 文/不壞的土叔 我叫張陵化漆,是天一觀的道長估脆。 經(jīng)常有香客問我,道長座云,這世上最難降的妖魔是什么疙赠? 我笑而不...
    開封第一講書人閱讀 56,521評論 1 284
  • 正文 為了忘掉前任,我火速辦了婚禮朦拖,結(jié)果婚禮上圃阳,老公的妹妹穿的比我還像新娘。我一直安慰自己璧帝,他們只是感情好捍岳,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,627評論 6 386
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著睬隶,像睡著了一般锣夹。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上苏潜,一...
    開封第一講書人閱讀 49,842評論 1 290
  • 那天银萍,我揣著相機(jī)與錄音,去河邊找鬼窖贤。 笑死砖顷,一個(gè)胖子當(dāng)著我的面吹牛贰锁,可吹牛的內(nèi)容都是我干的赃梧。 我是一名探鬼主播,決...
    沈念sama閱讀 38,997評論 3 408
  • 文/蒼蘭香墨 我猛地睜開眼豌熄,長吁一口氣:“原來是場噩夢啊……” “哼授嘀!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起锣险,我...
    開封第一講書人閱讀 37,741評論 0 268
  • 序言:老撾萬榮一對情侶失蹤蹄皱,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后芯肤,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體巷折,經(jīng)...
    沈念sama閱讀 44,203評論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,534評論 2 327
  • 正文 我和宋清朗相戀三年崖咨,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了锻拘。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 38,673評論 1 341
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖署拟,靈堂內(nèi)的尸體忽然破棺而出婉宰,到底是詐尸還是另有隱情,我是刑警寧澤推穷,帶...
    沈念sama閱讀 34,339評論 4 330
  • 正文 年R本政府宣布心包,位于F島的核電站,受9級特大地震影響馒铃,放射性物質(zhì)發(fā)生泄漏蟹腾。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,955評論 3 313
  • 文/蒙蒙 一骗露、第九天 我趴在偏房一處隱蔽的房頂上張望岭佳。 院中可真熱鬧,春花似錦萧锉、人聲如沸珊随。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,770評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽叶洞。三九已至,卻和暖如春禀崖,著一層夾襖步出監(jiān)牢的瞬間衩辟,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 32,000評論 1 266
  • 我被黑心中介騙來泰國打工波附, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留艺晴,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 46,394評論 2 360
  • 正文 我出身青樓掸屡,卻偏偏與公主長得像封寞,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個(gè)殘疾皇子仅财,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,562評論 2 349

推薦閱讀更多精彩內(nèi)容