在學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)中看到了一個(gè)平均滑動(dòng)模型赢笨,該方法可以使模型在測(cè)試數(shù)據(jù)上表現(xiàn)的更加健壯赚哗,而TensorFlow中提供的實(shí)現(xiàn)方法為
tf.train.ExponentialMovingAverage
展父,起初并不理解該方法為啥能使模型在測(cè)試數(shù)據(jù)上更健壯举农,多放查找資料之后幽告,記錄在此。
思想
在初始化ExponentialMovingAverage
時(shí)谬返,需要提供一個(gè)衰減率(decay)來(lái)空值模型跟新的書讀。ExponentialMovingAverage
會(huì)對(duì)TensorFlow中每一個(gè)變量會(huì)維護(hù)一個(gè)影子變量(shadow_variable)日杈,影子變量的初始值為變量的初始值遣铝,每次迭代時(shí),變量進(jìn)行更新之后莉擒,影子變量的值也會(huì)同步更新:
從上式中可以看到酿炸,decay決定模型更新的速度,decay越大涨冀,模型越穩(wěn)定填硕。在實(shí)際應(yīng)用中,decay一般是接近1的數(shù)(0.99鹿鳖,0.999等)扁眯。
當(dāng)decay設(shè)置較大時(shí),模型訓(xùn)練比較慢翅帜,為了使模型在前期能夠更新更快姻檀,ExponentialMovingAverage
還提供了num_updates參數(shù)來(lái)動(dòng)態(tài)設(shè)置decay大小。而此時(shí)的衰減率為:
在使用梯度下降算法進(jìn)行模型訓(xùn)練時(shí)涝滴,每次更新參數(shù)權(quán)重時(shí)绣版,該權(quán)重的影子變量也會(huì)隨著模型的訓(xùn)練而更新周荐,最終穩(wěn)定在一個(gè)接近真實(shí)權(quán)重值的附近。在測(cè)試集上使用影子變量替換原來(lái)的變量進(jìn)行預(yù)測(cè)時(shí)僵娃,可以得到一個(gè)更好的結(jié)果概作。
即,滑動(dòng)平均的使用步驟為:
- 訓(xùn)練階段:為每個(gè)可訓(xùn)練的權(quán)重維護(hù)影子變量默怨,并隨著迭代的進(jìn)行更新讯榕;
- 預(yù)測(cè)階段:使用影子變量替代真實(shí)變量值,進(jìn)行預(yù)測(cè)匙睹。
滑動(dòng)平均為什么在測(cè)試過(guò)程中被使用
訓(xùn)練中一直使用原來(lái)不帶滑動(dòng)的參數(shù)愚屁,可以得到新的參數(shù),如此就可以更新該參數(shù)的影子變量shadow_variable痕檬■保基于上面的式子可以看到,shadow_variable的更新比較平滑梦谜,對(duì)于隨機(jī)梯度下降算法而言丘跌,更平滑的更新效果較好。
代碼示例
import tensorflow as tf
v1 = tf.Variable(0, dtype=tf.float32)
step = tf.Variable(0, trainable=False)
ema = tf.train.ExponentialMovingAverage(0.99, step)
# 定義一個(gè)平滑平均偏亮的操作唁桩,每次執(zhí)行時(shí)闭树,會(huì)更新列表中的變量
maintain_averages_op = ema.apply([v1])
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
# 通過(guò)ema.average(v1)獲取滑動(dòng)平均之后變量的取值
print(sess.run([v1, ema.average(v1)]))
# 更新變量v1的值到5
sess.run(tf.assign(v1, 5))
# 更新v1的滑動(dòng)平均值。decay = min{0.99, (1+step)/(10+step)} = 0.1
# v1 的滑動(dòng)平均更新為 0.1 * 0 + 0.9 * 5 = 4.5
sess.run(maintain_averages_op)
print(sess.run([v1, ema.average(v1)]))
# 更新step為10000
sess.run(tf.assign(step, 10000))
# 更新v1的值為10
sess.run(tf.assign(v1, 10))
# 更新v1的滑動(dòng)平均值荒澡,decay = min{0.99, (1+step)/(10+step)} = 0.99
# v1的滑動(dòng)平均更新為 0.99 * 4.5 + 0.01 * 10 = 4.555
sess.run(maintain_averages_op)
print(sess.run([v1, ema.average(v1)]))
# 再次更新滑動(dòng)平均值 0.99 * 4.555 + 0.001 * 10 = 4.60945
sess.run(maintain_averages_op)
print(sess.run([v1, ema.average(v1)]))