最近在寫多卡的TensorFlow版I3D的代碼派草,其中遇到batch norm的坑搀缠,記錄一波铛楣。
I3D使用的是snt.BatchNorm近迁,當(dāng)is_training = True時(shí),意味著創(chuàng)建Update ops簸州,利用當(dāng)前batch的均值和方差去更新moving averages(即某層累計(jì)的平均均值和方差)鉴竭。這里提供兩種方式創(chuàng)建update_ops,
一是自己顯式的創(chuàng)建update_ops岸浑,手動(dòng)更新搏存。update_ops默認(rèn)放置在tf.GraphKeys.UPDATE_OPS中,因此這里在執(zhí)行train_ops的同時(shí)更新均值方差即可矢洲,對(duì)于單卡來說很容易理解璧眠,對(duì)于多卡來說,相當(dāng)于collection所有卡的batch的均值方差后統(tǒng)一更新读虏,也可以只collection第一塊卡的均值方差(理論上需要積累其他卡责静,但是由于這操作積累得很快,所以只取第一塊卡也不影響性能盖桥,在TensorFlow高階API的樣例代碼cifar10_main.py中如是說)灾螃。代碼如下:
????update_ops?=?tf.get_collection(tf.GraphKeys.UPDATE_OPS)
????with?tf.control_dependencies(update_ops):
??????train_op?=?optimizer.minimize(loss)
或者
update_ops = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))
? ? ? train_op = tf.group(train_op, update_ops)
二是自動(dòng)的更新,?只需在初始化前 bn = BatchNorm(update_ops_collection=None)即可揩徊。不過這種方式下腰鬼,會(huì)在完成更新前阻塞網(wǎng)絡(luò)的forward嵌赠,因此會(huì)帶來時(shí)間上的成本。具體而言熄赡,這時(shí)bn的參數(shù)mean,var是立即更新的姜挺,也是計(jì)算完當(dāng)前l(fā)ayer的mean,var就更新,然后進(jìn)行下一個(gè)layer的操作本谜。這在單卡下沒有問題的初家, 但是多卡情況下就會(huì)寫等讀的沖突,因?yàn)榭赡艽嬖贕PU0更新(寫)mean但此時(shí)GPU1還沒有計(jì)算到該層乌助,所以GPU0就要等GPU1讀完mean才能寫溜在。