1壳贪、引言
今天要介紹的是如何通過(guò)反向傳播算法(backpropagation )和梯度下降算法(gradient decent )調(diào)整神經(jīng)網(wǎng)絡(luò)中參數(shù)的取值。梯度下降算法主要用于優(yōu)化單個(gè)參數(shù)的取值蜓斧,而反向傳播算法給出了一個(gè)高效的方式在所有參數(shù)上使用梯度下降算法,從而使神經(jīng)網(wǎng)絡(luò)模型在訓(xùn)練數(shù)據(jù)上的損失函數(shù)盡可能小他膳。
反向傳播算法是訓(xùn)練神經(jīng)網(wǎng)絡(luò)的核心算法惦银,它可以根據(jù)定義好的損失函數(shù)優(yōu)化神經(jīng)網(wǎng)絡(luò)中參數(shù)的取值,從而使神經(jīng)網(wǎng)絡(luò)模型在訓(xùn)練數(shù)據(jù)集上的損失函數(shù)達(dá)到一個(gè)較小值棍厂。神經(jīng)網(wǎng)絡(luò)模型中參數(shù)的優(yōu)化過(guò)程直接決定了模型的質(zhì)量颗味,是使用神經(jīng)網(wǎng)絡(luò)時(shí)非常重要的一步。今天我們將主要介紹神經(jīng)網(wǎng)絡(luò)優(yōu)化過(guò)程的基本概念和主要思想牺弹,而略去算法的數(shù)學(xué)推導(dǎo)和證明脱衙,數(shù)學(xué)的推導(dǎo)網(wǎng)上有很好的文章講解,這里推薦一篇文章:https://zhuanlan.zhihu.com/p/25081671
2例驹、梯度下降法
假設(shè)用θ表示神經(jīng)網(wǎng)絡(luò)中的參數(shù)捐韩,J(θ)表示在給定的參數(shù)取值下,訓(xùn)練數(shù)據(jù)集上損失函數(shù)的大小鹃锈。那么整個(gè)優(yōu)化過(guò)程可以抽象為尋找一個(gè)參數(shù)θ荤胁,使得J(θ)最小。因?yàn)槟壳皼](méi)有一個(gè)通用的方法可以對(duì)任意損失函數(shù)直接求解最佳的參數(shù)取值屎债,所以在實(shí)踐中仅政,梯度下降算法是最常用的神經(jīng)網(wǎng)絡(luò)優(yōu)化方法。梯度下降算法會(huì)法代式更新參數(shù)θ盆驹,不斷沿著梯度的反方向讓參數(shù)朝著總損失更小的方向更新圆丹。下圖展示了梯度下降算法的基本原理。
圖1中x軸表示參數(shù)θ的取值躯喇,y軸表示損失函數(shù)J(θ)的值辫封。曲線表示了在參數(shù)θ取不同值時(shí),對(duì)應(yīng)損失函數(shù)J(θ)的大小廉丽。假設(shè)當(dāng)前的參數(shù)和損失值對(duì)應(yīng)圖中小圓點(diǎn)的位置倦微,那么梯度下降算法會(huì)將參數(shù)向x 軸左側(cè)移動(dòng),從而使得小圓點(diǎn)朝著箭頭的方向移動(dòng)正压。參數(shù)的梯度可以通過(guò)求偏導(dǎo)的方式計(jì)算欣福,對(duì)于參數(shù)θ,其梯度為焦履。有了梯度拓劝,還需要定義一個(gè)學(xué)習(xí)率η(learning rate)來(lái)定義每次參數(shù)更新的幅度雏逾。從直觀上理解,可以認(rèn)為學(xué)習(xí)率定義的就是每次參數(shù)移動(dòng)的幅度郑临。通過(guò)參數(shù)的梯度和學(xué)習(xí)率校套,參數(shù)更新的公式為:
可以看出,神經(jīng)網(wǎng)絡(luò)的優(yōu)化過(guò)程可以分為兩個(gè)階段牧抵,第一個(gè)階段先通過(guò)前向傳播算法計(jì)算得到預(yù)測(cè)值笛匙,井將預(yù)測(cè)值和真實(shí)值做對(duì)比得出兩者之間的差距。然后在第二個(gè)階段通過(guò)反向傳播算法計(jì)算損失函數(shù)對(duì)每一個(gè)參數(shù)的梯度犀变,再根據(jù)梯度和學(xué)習(xí)率使用梯度下降算法更新每一個(gè)參數(shù)妹孙。
3、存在的問(wèn)題
需要注意的是获枝,梯度下降算法并不能保證被優(yōu)化的函數(shù)達(dá)到全局最優(yōu)解蠢正。如下圖2所示,圖中給出的函數(shù)就有可能只能得到局部最優(yōu)解而不是全局最優(yōu)解省店。在小黑點(diǎn)處嚣崭,損失函數(shù)的偏導(dǎo)為0 ,于是參數(shù)就不會(huì)再進(jìn)一步更新懦傍。在這個(gè)樣例中雹舀,如果參數(shù)x的初始值落在右側(cè)深色的區(qū)間中,那么通過(guò)梯度下降得到的結(jié)果都會(huì)落到小黑點(diǎn)代表的局部最優(yōu)解粗俱。只有當(dāng)x的初始值落在左側(cè)淺色的區(qū)間時(shí)梯度下降才能給出全局最優(yōu)答案说榆。由此可見(jiàn)在訓(xùn)練神經(jīng)網(wǎng)絡(luò)時(shí),參數(shù)的初始值會(huì)很大程度影響最后得到的結(jié)果寸认。只有當(dāng)損失函數(shù)為凸函數(shù)時(shí)签财,梯度下降算法才能保證達(dá)到全局最優(yōu)解。
除了不一定能達(dá)到全局最優(yōu)偏塞,梯度下降算法的另外一個(gè)問(wèn)題就是計(jì)算時(shí)間太長(zhǎng)唱蒸。因?yàn)橐谌坑?xùn)練數(shù)據(jù)上最小化損失,所以損失函數(shù)是在所有訓(xùn)練數(shù)據(jù)上的損失和灸叼。這樣在每一輪迭代中都需要計(jì)算在全部訓(xùn)練數(shù)據(jù)上的損失函數(shù)神汹。在海量訓(xùn)練數(shù)據(jù)下,要計(jì)算所有訓(xùn)練數(shù)據(jù)的損失函數(shù)是非常消耗時(shí)間的怜姿。為了加速訓(xùn)練過(guò)程慎冤,可以使用隨機(jī)梯度下降的算法(stochastic gradient descent)。這個(gè)算法優(yōu)化的不是在全部訓(xùn)練數(shù)據(jù)上的損失函數(shù)沧卢,而是在每一輪法代中,隨機(jī)優(yōu)化某一條訓(xùn)練數(shù)據(jù)上的損失函數(shù)醉者。這樣每一輪參數(shù)更新的速度就大大加快了但狭。因?yàn)殡S機(jī)梯度下降算法每次優(yōu)化的只是某一條數(shù)據(jù)上的損失函數(shù)披诗,所以它的問(wèn)題也非常明顯:在某一條數(shù)據(jù)上損失函數(shù)更小并不代表在全部數(shù)據(jù)上損失函數(shù)更小,于是使用隨機(jī)梯度下降優(yōu)化得到的神經(jīng)網(wǎng)絡(luò)甚至可能無(wú)法達(dá)到局部最優(yōu)立磁。
為了綜合梯度下降算法和隨機(jī)梯度下降算法的優(yōu)缺點(diǎn)呈队,在實(shí)際應(yīng)用中一般采用這兩個(gè)算法的折中一一每次計(jì)算一小部分訓(xùn)練數(shù)據(jù)的損失函數(shù)。這一小部分?jǐn)?shù)據(jù)被稱(chēng)之為一個(gè)batch 唱歧。通過(guò)矩陣運(yùn)算宪摧,每次在一個(gè)batch 上優(yōu)化神經(jīng)網(wǎng)絡(luò)的參數(shù)并不會(huì)比單個(gè)數(shù)據(jù)慢太多。另一方面颅崩,每次使用一個(gè)batch 可以大大減小收斂所需要的法代次數(shù)几于,同時(shí)可以使收斂到的結(jié)果更加接近梯度下降的效果。
4沿后、一個(gè)例子
# 在此僅展示神經(jīng)網(wǎng)絡(luò)訓(xùn)練遵循的大致過(guò)程:
import tensorflow as tf
batch_size = n
# 每次讀取一小部分?jǐn)?shù)據(jù)作為當(dāng)前的訓(xùn)練數(shù)據(jù)來(lái)執(zhí)行反向傳播算法
x = tf.placeholder(tf.float32, shape=(batch_size, 2), name='x-input')
y_ = tf.placeholder(tf.float32, shape=(batch_size, 1), name='y-input')
# 定義神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)和優(yōu)化算法
loss = ...
train_step = tf.train.AdamOptimizer(0.001).minimize(loss)
# 訓(xùn)練神經(jīng)網(wǎng)絡(luò)
with tf.Session() as sess:
# 參數(shù)初始化
...
# 迭代地更新參數(shù)
for i in range(steps):
# 準(zhǔn)備batch_size個(gè)訓(xùn)練數(shù)據(jù)沿彭。一般講所有數(shù)據(jù)隨機(jī)打亂后再選取可以得到更好的優(yōu)化效果
current_X =, current_Y = ...
sess.run(train_step, feed_dict={x: current_X, y: current_Y})
參考文獻(xiàn)
書(shū)籍:Tensorflow:實(shí)戰(zhàn)Google深度學(xué)習(xí)框架(第二版)