I won't go over the benefits of Batch Normalization here; see the paper for details. I actually read the paper a long time ago, but between procrastination (and the fact that I was using Keras) I never really understood how BN is implemented in code, in particular the difference between the training and test phases, so here are my notes.
In brief: the batch statistics computed during training only describe the current mini-batch (at the end of training they are just the last batch's mean and variance), so at test time we instead use a moving average of the means and variances accumulated over all training batches. That is really the only difference; for how it looks in practice, see the examples below.
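To make the difference concrete, here is a tiny NumPy sketch of how the moving statistics are accumulated during training and then reused at test time (the decay value 0.99 and the fake mini-batches are just illustrative assumptions):

import numpy as np

decay = 0.99                            # assumed momentum for the moving averages
moving_mean, moving_var = 0.0, 1.0      # typical initial values

for _ in range(100):
    batch = np.random.randn(32) * 2.0 + 5.0          # a fake mini-batch of activations
    batch_mean, batch_var = batch.mean(), batch.var()
    # during training, normalization uses batch_mean / batch_var,
    # while the moving statistics are updated on the side:
    moving_mean = decay * moving_mean + (1 - decay) * batch_mean
    moving_var = decay * moving_var + (1 - decay) * batch_var

# at test time, normalization uses moving_mean / moving_var instead of the batch statistics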
Reference code:
https://github.com/soloice/mnist-bn (the author uses TF-Slim)
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python.ops import control_flow_ops

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
train_step = slim.learning.create_train_op(cross_entropy, optimizer, global_step=step)
###########################################################
# The list of values in the collection with the given name, or an empty list if
# no value has been added to that collection. The list contains the values in
# the order under which they were collected.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# The key step.
# We need to run several ops together (the training op and the moving-average updates
# for the BN statistics). The moving-average update itself is already wrapped up for us;
# to run several ops at once we use the two mechanisms tf.control_dependencies and
# tf.group to create the dependency between them (see my other note for details).
# I tried the following three forms; all of them work.

# Form 1: attach the update ops to the loss tensor
if update_ops:
    updates = tf.group(*update_ops)
    cross_entropy = control_flow_ops.with_dependencies([updates], cross_entropy)

# Form 2: attach the update ops to the training op
if update_ops:
    updates = tf.group(*update_ops)
    train_step = control_flow_ops.with_dependencies([updates], train_step)

# Form 3: build the training op inside a control_dependencies block
with tf.control_dependencies([tf.group(*update_ops)]):
    train_step = slim.learning.create_train_op(cross_entropy, optimizer, global_step=step)
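Whichever of the three forms you pick, the moving-average updates then run automatically as a side effect of the training op. A minimal training-loop sketch under assumed names (x, y_, is_training, accuracy and the mnist feed are illustrative, not taken from the reference repo):

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(64)    # assumed MNIST input pipeline
        # running train_step also fires the BN update ops attached via the dependency
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys, is_training: True})
    # at test time, switch off is_training so the moving statistics are used
    print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                        y_: mnist.test.labels,
                                        is_training: False}))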
How to use tf.nn.batch_normalization(): this is one of the lower-level APIs, so you manage the statistics yourself. (The create_bn_var helper below is not defined in the original post; in this sketch it is assumed to be a thin wrapper around tf.get_variable.)
from tensorflow.python.training import moving_averages

def create_bn_var(name, shape, initializer, trainable=True):
    # Helper not shown in the original post; assumed to be a thin wrapper around tf.get_variable.
    return tf.get_variable(name, shape=shape, initializer=initializer, trainable=trainable)

def batchnorm(inputs, scope, epsilon=1e-05, momentum=0.99, is_training=True):
    inputs_shape = inputs.get_shape().as_list()
    params_shape = inputs_shape[-1:]
    axis = list(range(len(inputs_shape) - 1))
    with tf.variable_scope(scope):
        # learnable scale (gamma) and offset (beta)
        beta = create_bn_var("beta", params_shape,
                             initializer=tf.zeros_initializer())
        gamma = create_bn_var("gamma", params_shape,
                              initializer=tf.ones_initializer())
        # for inference
        moving_mean = create_bn_var("moving_mean", params_shape,
                                    initializer=tf.zeros_initializer(), trainable=False)
        moving_variance = create_bn_var("moving_variance", params_shape,
                                        initializer=tf.ones_initializer(), trainable=False)
        if is_training:
            # training: normalize with the current batch statistics and update the moving averages
            mean, variance = tf.nn.moments(inputs, axes=axis)
            update_move_mean = moving_averages.assign_moving_average(moving_mean,
                                                                     mean, decay=momentum)
            update_move_variance = moving_averages.assign_moving_average(moving_variance,
                                                                         variance, decay=momentum)
            tf.add_to_collection("_update_ops_", update_move_mean)
            tf.add_to_collection("_update_ops_", update_move_variance)
        else:
            # inference: use the accumulated moving statistics
            mean, variance = moving_mean, moving_variance
        return tf.nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon)
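Note that this function registers its update ops under the custom collection key "_update_ops_" rather than tf.GraphKeys.UPDATE_OPS, so the training code has to fetch that key explicitly. A small usage sketch (the placeholder, dummy loss and optimizer are assumptions, purely for illustration):

x = tf.placeholder(tf.float32, [None, 28, 28, 16])     # assumed input tensor
h = batchnorm(x, scope="bn1", is_training=True)
loss = tf.reduce_mean(tf.square(h))                    # dummy loss, just for illustration
optimizer = tf.train.GradientDescentOptimizer(0.1)
# fetch the custom collection -- it is NOT tf.GraphKeys.UPDATE_OPS here
update_ops = tf.get_collection("_update_ops_")
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)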
I have tried this out myself, also on MNIST; see the link.
tf.layers.batch_normalization() is a more high-level, fully wrapped API.
# Example, from the official docs
x_norm = tf.layers.batch_normalization(x, training=training)
# ...
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
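At inference time you run the same layer with training=False so it uses its accumulated moving_mean / moving_variance. A common pattern is to drive the flag with a boolean placeholder; a sketch, where the placeholder name is an assumption:

training = tf.placeholder_with_default(False, shape=(), name="is_training")
x_norm = tf.layers.batch_normalization(x, training=training)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

# training step: feed training=True so batch statistics are used and the moving averages update
#   sess.run(train_op, feed_dict={x: batch_x, training: True})
# evaluation: omit the feed (defaults to False) so the moving statistics are used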
Generally speaking, these three are enough.