往期回顧
在上一篇文章中啸罢,我們介紹了循環(huán)神經(jīng)網(wǎng)絡(luò)以及它的訓(xùn)練算法。我們也介紹了循環(huán)神經(jīng)網(wǎng)絡(luò)很難訓(xùn)練的原因胎食,這導(dǎo)致了它在實(shí)際應(yīng)用中扰才,很難處理長(zhǎng)距離的依賴。在本文中厕怜,我們將介紹一種改進(jìn)之后的循環(huán)神經(jīng)網(wǎng)絡(luò):長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)(Long Short Term Memory Network, LSTM)衩匣,它成功的解決了原始循環(huán)神經(jīng)網(wǎng)絡(luò)的缺陷,成為當(dāng)前最流行的RNN粥航,在語(yǔ)音識(shí)別琅捏、圖片描述、自然語(yǔ)言處理等許多領(lǐng)域中成功應(yīng)用躁锡。但不幸的一面是午绳,LSTM的結(jié)構(gòu)很復(fù)雜,因此映之,我們需要花上一些力氣拦焚,才能把LSTM以及它的訓(xùn)練算法弄明白。在搞清楚LSTM之后杠输,我們?cè)俳榻B一種LSTM的變體:GRU (Gated Recurrent Unit)赎败。 它的結(jié)構(gòu)比LSTM簡(jiǎn)單,而效果卻和LSTM一樣好蠢甲,因此僵刮,它正在逐漸流行起來(lái)。最后鹦牛,我們?nèi)匀粫?huì)動(dòng)手實(shí)現(xiàn)一個(gè)LSTM搞糕。
長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)是啥
我們首先了解一下長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)產(chǎn)生的背景÷罚回顧一下零基礎(chǔ)入門深度學(xué)習(xí)(5) - 循環(huán)神經(jīng)網(wǎng)絡(luò)中推導(dǎo)的窍仰,誤差項(xiàng)沿時(shí)間反向傳播的公式:
我們可以根據(jù)下面的不等式,來(lái)獲取的模的上界(睦袷猓可以看做對(duì)
中每一項(xiàng)值的大小的度量):
我們可以看到驹吮,誤差項(xiàng)從t時(shí)刻傳遞到k時(shí)刻,其值的上界是
的指數(shù)函數(shù)晶伦。
分別是對(duì)角矩陣
和矩陣W模的上界碟狞。顯然,除非
乘積的值位于1附近婚陪,否則族沃,當(dāng)t-k很大時(shí)(也就是誤差傳遞很多個(gè)時(shí)刻時(shí)),整個(gè)式子的值就會(huì)變得極小(當(dāng)
乘積小于1)或者極大(當(dāng)
乘積大于1)竭业,前者就是梯度消失智润,后者就是梯度爆炸。雖然科學(xué)家們搞出了很多技巧(比如怎樣初始化權(quán)重)未辆,讓
的值盡可能貼近于1窟绷,終究還是難以抵擋指數(shù)函數(shù)的威力。
梯度消失到底意味著什么咐柜?在零基礎(chǔ)入門深度學(xué)習(xí)(5) - 循環(huán)神經(jīng)網(wǎng)絡(luò)中我們已證明兼蜈,權(quán)重?cái)?shù)組W最終的梯度是各個(gè)時(shí)刻的梯度之和,即:
假設(shè)某輪訓(xùn)練中拙友,各時(shí)刻的梯度以及最終的梯度之和如下圖:
我們就可以看到为狸,從上圖的t-3時(shí)刻開始,梯度已經(jīng)幾乎減少到0了遗契。那么辐棒,從這個(gè)時(shí)刻開始再往之前走,得到的梯度(幾乎為零)就不會(huì)對(duì)最終的梯度值有任何貢獻(xiàn)牍蜂,這就相當(dāng)于無(wú)論t-3時(shí)刻之前的網(wǎng)絡(luò)狀態(tài)h是什么漾根,在訓(xùn)練中都不會(huì)對(duì)權(quán)重?cái)?shù)組W的更新產(chǎn)生影響,也就是網(wǎng)絡(luò)事實(shí)上已經(jīng)忽略了t-3時(shí)刻之前的狀態(tài)鲫竞。這就是原始RNN無(wú)法處理長(zhǎng)距離依賴的原因辐怕。
既然找到了問(wèn)題的原因,那么我們就能解決它从绘。從問(wèn)題的定位到解決寄疏,科學(xué)家們大概花了7、8年時(shí)間僵井。終于有一天陕截,Hochreiter和Schmidhuber兩位科學(xué)家發(fā)明出長(zhǎng)短時(shí)記憶網(wǎng)絡(luò),一舉解決這個(gè)問(wèn)題批什。
其實(shí)农曲,長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)的思路比較簡(jiǎn)單。原始RNN的隱藏層只有一個(gè)狀態(tài)渊季,即h朋蔫,它對(duì)于短期的輸入非常敏感罚渐。那么却汉,假如我們?cè)僭黾右粋€(gè)狀態(tài),即c荷并,讓它來(lái)保存長(zhǎng)期的狀態(tài)合砂,那么問(wèn)題不就解決了么?如下圖所示:
新增加的狀態(tài)c源织,稱為單元狀態(tài)(cell state)翩伪。我們把上圖按照時(shí)間維度展開:
上圖僅僅是一個(gè)示意圖微猖,我們可以看出,在t時(shí)刻缘屹,LSTM的輸入有三個(gè):當(dāng)前時(shí)刻網(wǎng)絡(luò)的輸入值凛剥、上一時(shí)刻LSTM的輸出值
、以及上一時(shí)刻的單元狀態(tài)
轻姿;LSTM的輸出有兩個(gè):當(dāng)前時(shí)刻LSTM輸出值
犁珠、和當(dāng)前時(shí)刻的單元狀態(tài)
。注意
互亮、
犁享、
都是向量。
LSTM的關(guān)鍵豹休,就是怎樣控制長(zhǎng)期狀態(tài)c炊昆。在這里,LSTM的思路是使用三個(gè)控制開關(guān)威根。第一個(gè)開關(guān)凤巨,負(fù)責(zé)控制繼續(xù)保存長(zhǎng)期狀態(tài)c;第二個(gè)開關(guān)医窿,負(fù)責(zé)控制把即時(shí)狀態(tài)輸入到長(zhǎng)期狀態(tài)c磅甩;第三個(gè)開關(guān),負(fù)責(zé)控制是否把長(zhǎng)期狀態(tài)c作為當(dāng)前的LSTM的輸出姥卢。三個(gè)開關(guān)的作用如下圖所示:
接下來(lái)卷要,我們要描述一下,輸出h和單元狀態(tài)c的具體計(jì)算方法独榴。
長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)的前向計(jì)算
前面描述的開關(guān)是怎樣在算法中實(shí)現(xiàn)的呢僧叉?這就用到了門(gate)的概念。門實(shí)際上就是一層全連接層棺榔,它的輸入是一個(gè)向量瓶堕,輸出是一個(gè)0到1之間的實(shí)數(shù)向量。假設(shè)W是門的權(quán)重向量症歇,是偏置項(xiàng),那么門可以表示為:
門的使用宛蚓,就是用門的輸出向量按元素乘以我們需要控制的那個(gè)向量。因?yàn)殚T的輸出是0到1之間的實(shí)數(shù)向量设塔,那么凄吏,當(dāng)門輸出為0時(shí),任何向量與之相乘都會(huì)得到0向量,這就相當(dāng)于啥都不能通過(guò)痕钢;輸出為1時(shí)图柏,任何向量與之相乘都不會(huì)有任何改變,這就相當(dāng)于啥都可以通過(guò)任连。因?yàn)?img class="math-inline" src="https://math.jianshu.com/math?formula=%5Csigma" alt="\sigma" mathimg="1">(也就是sigmoid函數(shù))的值域是(0,1)蚤吹,所以門的狀態(tài)都是半開半閉的。
LSTM用兩個(gè)門來(lái)控制單元狀態(tài)c的內(nèi)容随抠,一個(gè)是遺忘門(forget gate)距辆,它決定了上一時(shí)刻的單元狀態(tài)有多少保留到當(dāng)前時(shí)刻
;另一個(gè)是輸入門(input gate)暮刃,它決定了當(dāng)前時(shí)刻網(wǎng)絡(luò)的輸入
有多少保存到單元狀態(tài)
跨算。LSTM用輸出門(output gate)來(lái)控制單元狀態(tài)
有多少輸出到LSTM的當(dāng)前輸出值
。
我們先來(lái)看一下遺忘門:
上式中诸蚕,是遺忘門的權(quán)重矩陣,
表示把兩個(gè)向量連接成一個(gè)更長(zhǎng)的向量氧猬,
是遺忘門的偏置項(xiàng),
是sigmoid函數(shù)盅抚。如果輸入的維度是
漠魏,隱藏層的維度是
,單元狀態(tài)的維度是
(通常
)妄均,則遺忘門的權(quán)重矩陣
維度是
柱锹。事實(shí)上,權(quán)重矩陣
都是兩個(gè)矩陣拼接而成的:一個(gè)是
丰包,它對(duì)應(yīng)著輸入項(xiàng)
禁熏,其維度為
;一個(gè)是
邑彪,它對(duì)應(yīng)著輸入項(xiàng)
瞧毙,其維度為
。
可以寫為:
下圖顯示了遺忘門的計(jì)算:
接下來(lái)看看輸入門:
上式中宙彪,是輸入門的權(quán)重矩陣,
是輸入門的偏置項(xiàng)释漆。下圖表示了輸入門的計(jì)算:
接下來(lái),我們計(jì)算用于描述當(dāng)前輸入的單元狀態(tài)剪决,它是根據(jù)上一次的輸出和本次輸入來(lái)計(jì)算的:
下圖是的計(jì)算:
現(xiàn)在,我們計(jì)算當(dāng)前時(shí)刻的單元狀態(tài)柑潦。它是由上一次的單元狀態(tài)
按元素乘以遺忘門
享言,再用當(dāng)前輸入的單元狀態(tài)
按元素乘以輸入門
,再將兩個(gè)積加和產(chǎn)生的:
符號(hào)表示按元素乘渗鬼。下圖是
的計(jì)算:
這樣览露,我們就把LSTM關(guān)于當(dāng)前的記憶和長(zhǎng)期的記憶
組合在一起,形成了新的單元狀態(tài)
譬胎。由于遺忘門的控制差牛,它可以保存很久很久之前的信息,由于輸入門的控制堰乔,它又可以避免當(dāng)前無(wú)關(guān)緊要的內(nèi)容進(jìn)入記憶偏化。下面,我們要看看輸出門镐侯,它控制了長(zhǎng)期記憶對(duì)當(dāng)前輸出的影響:
下圖表示輸出門的計(jì)算:
LSTM最終的輸出,是由輸出門和單元狀態(tài)共同確定的:
下圖表示LSTM最終輸出的計(jì)算:
式1到式6就是LSTM前向計(jì)算的全部公式苟翻。至此韵卤,我們就把LSTM前向計(jì)算講完了。
長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)的訓(xùn)練
熟悉我們這個(gè)系列文章的同學(xué)都清楚崇猫,訓(xùn)練部分往往比前向計(jì)算部分復(fù)雜多了沈条。LSTM的前向計(jì)算都這么復(fù)雜,那么诅炉,可想而知蜡歹,它的訓(xùn)練算法一定是非常非常復(fù)雜的。現(xiàn)在只有做幾次深呼吸涕烧,再一頭扎進(jìn)公式海洋吧季稳。
LSTM訓(xùn)練算法框架
LSTM的訓(xùn)練算法仍然是反向傳播算法,對(duì)于這個(gè)算法澈魄,我們已經(jīng)非常熟悉了景鼠。主要有下面三個(gè)步驟:
- 前向計(jì)算每個(gè)神經(jīng)元的輸出值,對(duì)于LSTM來(lái)說(shuō)痹扇,即
铛漓、
、
鲫构、
浓恶、
五個(gè)向量的值。計(jì)算方法已經(jīng)在上一節(jié)中描述過(guò)了结笨。
- 反向計(jì)算每個(gè)神經(jīng)元的誤差項(xiàng)
值包晰。與循環(huán)神經(jīng)網(wǎng)絡(luò)一樣湿镀,LSTM誤差項(xiàng)的反向傳播也是包括兩個(gè)方向:一個(gè)是沿時(shí)間的反向傳播,即從當(dāng)前t時(shí)刻開始伐憾,計(jì)算每個(gè)時(shí)刻的誤差項(xiàng)勉痴;一個(gè)是將誤差項(xiàng)向上一層傳播。
- 根據(jù)相應(yīng)的誤差項(xiàng)树肃,計(jì)算每個(gè)權(quán)重的梯度蒸矛。
關(guān)于公式和符號(hào)的說(shuō)明
首先,我們對(duì)推導(dǎo)中用到的一些公式胸嘴、符號(hào)做一下必要的說(shuō)明雏掠。
接下來(lái)的推導(dǎo)中,我們?cè)O(shè)定gate的激活函數(shù)為sigmoid函數(shù)劣像,輸出的激活函數(shù)為tanh函數(shù)乡话。他們的導(dǎo)數(shù)分別為:
從上面可以看出,sigmoid和tanh函數(shù)的導(dǎo)數(shù)都是原函數(shù)的函數(shù)耳奕。這樣蚊伞,我們一旦計(jì)算原函數(shù)的值,就可以用它來(lái)計(jì)算出導(dǎo)數(shù)的值吮铭。
LSTM需要學(xué)習(xí)的參數(shù)共有8組时迫,分別是:遺忘門的權(quán)重矩陣和偏置項(xiàng)
谓晌、輸入門的權(quán)重矩陣
和偏置項(xiàng)
、輸出門的權(quán)重矩陣
和偏置項(xiàng)
溺欧,以及計(jì)算單元狀態(tài)的權(quán)重矩陣
和偏置項(xiàng)
柏肪。因?yàn)闄?quán)重矩陣的兩部分在反向傳播中使用不同的公式姐刁,因此在后續(xù)的推導(dǎo)中,權(quán)重矩陣
烦味、
聂使、
、
都將被寫為分開的兩個(gè)矩陣:
谬俄、
柏靶、
、
溃论、
屎蜓、
、
钥勋、
炬转。
我們解釋一下按元素乘符號(hào)辆苔。當(dāng)
作用于兩個(gè)向量時(shí),運(yùn)算如下:
當(dāng)作用于一個(gè)向量和一個(gè)矩陣時(shí),運(yùn)算如下:
當(dāng)作用于兩個(gè)矩陣時(shí),兩個(gè)矩陣對(duì)應(yīng)位置的元素相乘。按元素乘可以在某些情況下簡(jiǎn)化矩陣和向量運(yùn)算棠隐。例如,當(dāng)一個(gè)對(duì)角矩陣右乘一個(gè)矩陣時(shí)盾舌,相當(dāng)于用對(duì)角矩陣的對(duì)角線組成的向量按元素乘那個(gè)矩陣:
當(dāng)一個(gè)行向量右乘一個(gè)對(duì)角矩陣時(shí)走趋,相當(dāng)于這個(gè)行向量按元素乘那個(gè)矩陣對(duì)角線組成的向量:
上面這兩點(diǎn)榨婆,在我們后續(xù)推導(dǎo)中會(huì)多次用到。
在t時(shí)刻褒侧,LSTM的輸出值為良风。我們定義t時(shí)刻的誤差項(xiàng)
為:
注意,和前面幾篇文章不同闷供,我們這里假設(shè)誤差項(xiàng)是損失函數(shù)對(duì)輸出值的導(dǎo)數(shù)烟央,而不是對(duì)加權(quán)輸入的導(dǎo)數(shù)。因?yàn)長(zhǎng)STM有四個(gè)加權(quán)輸入歪脏,分別對(duì)應(yīng)
疑俭、
、
婿失、
钞艇,我們希望往上一層傳遞一個(gè)誤差項(xiàng)而不是四個(gè)。但我們?nèi)匀恍枰x出這四個(gè)加權(quán)輸入豪硅,以及他們對(duì)應(yīng)的誤差項(xiàng)哩照。
誤差項(xiàng)沿時(shí)間的反向傳遞
沿時(shí)間反向傳遞誤差項(xiàng)学少,就是要計(jì)算出t-1時(shí)刻的誤差項(xiàng)。
我們知道秧骑,是一個(gè)Jacobian矩陣版确。如果隱藏層h的維度是N的話扣囊,那么它就是一個(gè)
矩陣。為了求出它绒疗,我們列出
的計(jì)算公式侵歇,即前面的式6和式4:
顯然,吓蘑、
惕虑、
、
都是
的函數(shù)磨镶,那么溃蔫,利用全導(dǎo)數(shù)公式可得:
下面,我們要把式7中的每個(gè)偏導(dǎo)數(shù)都求出來(lái)琳猫。根據(jù)式6伟叛,我們可以求出:
根據(jù)式4,我們可以求出:
因?yàn)椋?/p>
我們很容易得出:
將上述偏導(dǎo)數(shù)帶入到式7侥蒙,我們得到:
根據(jù)、
匀奏、
辉哥、
的定義,可知:
式8到式12就是將誤差沿時(shí)間反向傳播一個(gè)時(shí)刻的公式攒射。有了它醋旦,我們可以寫出將誤差項(xiàng)向前傳遞到任意k時(shí)刻的公式:
將誤差項(xiàng)傳遞到上一層
我們假設(shè)當(dāng)前為第l層,定義l-1層的誤差項(xiàng)是誤差函數(shù)對(duì)l-1層加權(quán)輸入的導(dǎo)數(shù)会放,即:
本次LSTM的輸入由下面的公式計(jì)算:
上式中饲齐,表示第l-1層的激活函數(shù)。
因?yàn)?img class="math-inline" src="https://math.jianshu.com/math?formula=%5Cmathbf%7Bnet%7D_%7Bf%2Ct%7D%5El" alt="\mathbf{net}_{f,t}^l" mathimg="1">咧最、捂人、
、
都是
的函數(shù)矢沿,
又是
的函數(shù)滥搭,因此,要求出E對(duì)
的導(dǎo)數(shù)捣鲸,就需要使用全導(dǎo)數(shù)公式:
式14就是將誤差傳遞到上一層的公式瑟匆。
權(quán)重梯度的計(jì)算
對(duì)于、
栽惶、
愁溜、
的權(quán)重梯度疾嗅,我們知道它的梯度是各個(gè)時(shí)刻梯度之和(證明過(guò)程請(qǐng)參考文章零基礎(chǔ)入門深度學(xué)習(xí)(5) - 循環(huán)神經(jīng)網(wǎng)絡(luò)),我們首先求出它們?cè)趖時(shí)刻的梯度冕象,然后再求出他們最終的梯度代承。
我們已經(jīng)求得了誤差項(xiàng)、
渐扮、
论悴、
,很容易求出t時(shí)刻的
墓律、的
膀估、的
、的
:
將各個(gè)時(shí)刻的梯度加在一起只锻,就能得到最終的梯度:
對(duì)于偏置項(xiàng)紫谷、
、
祖驱、
的梯度瞒窒,也是將各個(gè)時(shí)刻的梯度加在一起捺僻。下面是各個(gè)時(shí)刻的偏置項(xiàng)梯度:
下面是最終的偏置項(xiàng)梯度术奖,即將各個(gè)時(shí)刻的偏置項(xiàng)梯度加在一起:
對(duì)于政勃、
唧龄、
、
的權(quán)重梯度奸远,只需要根據(jù)相應(yīng)的誤差項(xiàng)直接計(jì)算即可:
以上就是LSTM的訓(xùn)練算法的全部公式选侨。因?yàn)檫@里面存在很多重復(fù)的模式掖鱼,仔細(xì)看看,會(huì)發(fā)覺(jué)并不是太復(fù)雜援制。
當(dāng)然戏挡,LSTM存在著相當(dāng)多的變體,讀者可以在互聯(lián)網(wǎng)上找到很多資料晨仑。因?yàn)榇蠹乙呀?jīng)熟悉了基本LSTM的算法褐墅,因此理解這些變體比較容易,因此本文就不再贅述了洪己。
長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)的實(shí)現(xiàn)
在下面的實(shí)現(xiàn)中妥凳,LSTMLayer的參數(shù)包括輸入維度、輸出維度答捕、隱藏層維度逝钥,單元狀態(tài)維度等于隱藏層維度。gate的激活函數(shù)為sigmoid函數(shù)拱镐,輸出的激活函數(shù)為tanh艘款。
激活函數(shù)的實(shí)現(xiàn)
我們先實(shí)現(xiàn)兩個(gè)激活函數(shù):sigmoid和tanh。
class SigmoidActivator(object):
def forward(self, weighted_input):
return 1.0 / (1.0 + np.exp(-weighted_input))
def backward(self, output):
return output * (1 - output)
class TanhActivator(object):
def forward(self, weighted_input):
return 2.0 / (1.0 + np.exp(-2 * weighted_input)) - 1.0
def backward(self, output):
return 1 - output * output
LSTM初始化
和前兩篇文章代碼架構(gòu)一樣沃琅,我們把LSTM的實(shí)現(xiàn)放在LstmLayer類中哗咆。
根據(jù)LSTM前向計(jì)算和方向傳播算法,我們需要初始化一系列矩陣和向量益眉。這些矩陣和向量有兩類用途晌柬,一類是用于保存模型參數(shù),例如郭脂、
年碘、
、
展鸡、
、
傲诵、
箱硕、
;另一類是保存各種中間計(jì)算結(jié)果剧罩,以便于反向傳播算法使用栓拜,它們包括
、
、
幕与、
挑势、
、
啦鸣、
潮饱、
、
诫给、
香拉、
,以及各個(gè)權(quán)重對(duì)應(yīng)的梯度中狂。
在構(gòu)造函數(shù)的初始化中凫碌,只初始化了與forward計(jì)算相關(guān)的變量,與backward相關(guān)的變量沒(méi)有初始化胃榕。這是因?yàn)闃?gòu)造LSTM對(duì)象的時(shí)候盛险,我們還不知道它未來(lái)是用于訓(xùn)練(既有forward又有backward)還是推理(只有forward)。
class LstmLayer(object):
def __init__(self, input_width, state_width,
learning_rate):
self.input_width = input_width
self.state_width = state_width
self.learning_rate = learning_rate
# 門的激活函數(shù)
self.gate_activator = SigmoidActivator()
# 輸出的激活函數(shù)
self.output_activator = TanhActivator()
# 當(dāng)前時(shí)刻初始化為t0
self.times = 0
# 各個(gè)時(shí)刻的單元狀態(tài)向量c
self.c_list = self.init_state_vec()
# 各個(gè)時(shí)刻的輸出向量h
self.h_list = self.init_state_vec()
# 各個(gè)時(shí)刻的遺忘門f
self.f_list = self.init_state_vec()
# 各個(gè)時(shí)刻的輸入門i
self.i_list = self.init_state_vec()
# 各個(gè)時(shí)刻的輸出門o
self.o_list = self.init_state_vec()
# 各個(gè)時(shí)刻的即時(shí)狀態(tài)c~
self.ct_list = self.init_state_vec()
# 遺忘門權(quán)重矩陣Wfh, Wfx, 偏置項(xiàng)bf
self.Wfh, self.Wfx, self.bf = (
self.init_weight_mat())
# 輸入門權(quán)重矩陣Wfh, Wfx, 偏置項(xiàng)bf
self.Wih, self.Wix, self.bi = (
self.init_weight_mat())
# 輸出門權(quán)重矩陣Wfh, Wfx, 偏置項(xiàng)bf
self.Woh, self.Wox, self.bo = (
self.init_weight_mat())
# 單元狀態(tài)權(quán)重矩陣Wfh, Wfx, 偏置項(xiàng)bf
self.Wch, self.Wcx, self.bc = (
self.init_weight_mat())
def init_state_vec(self):
'''
初始化保存狀態(tài)的向量
'''
state_vec_list = []
state_vec_list.append(np.zeros(
(self.state_width, 1)))
return state_vec_list
def init_weight_mat(self):
'''
初始化權(quán)重矩陣
'''
Wh = np.random.uniform(-1e-4, 1e-4,
(self.state_width, self.state_width))
Wx = np.random.uniform(-1e-4, 1e-4,
(self.state_width, self.input_width))
b = np.zeros((self.state_width, 1))
return Wh, Wx, b
前向計(jì)算的實(shí)現(xiàn)
forward方法實(shí)現(xiàn)了LSTM的前向計(jì)算:
def forward(self, x):
'''
根據(jù)式1-式6進(jìn)行前向計(jì)算
'''
self.times += 1
# 遺忘門
fg = self.calc_gate(x, self.Wfx, self.Wfh,
self.bf, self.gate_activator)
self.f_list.append(fg)
# 輸入門
ig = self.calc_gate(x, self.Wix, self.Wih,
self.bi, self.gate_activator)
self.i_list.append(ig)
# 輸出門
og = self.calc_gate(x, self.Wox, self.Woh,
self.bo, self.gate_activator)
self.o_list.append(og)
# 即時(shí)狀態(tài)
ct = self.calc_gate(x, self.Wcx, self.Wch,
self.bc, self.output_activator)
self.ct_list.append(ct)
# 單元狀態(tài)
c = fg * self.c_list[self.times - 1] + ig * ct
self.c_list.append(c)
# 輸出
h = og * self.output_activator.forward(c)
self.h_list.append(h)
def calc_gate(self, x, Wx, Wh, b, activator):
'''
計(jì)算門
'''
h = self.h_list[self.times - 1] # 上次的LSTM輸出
net = np.dot(Wh, h) + np.dot(Wx, x) + b
gate = activator.forward(net)
return gate
從上面的代碼我們可以看到勋又,門的計(jì)算都是相同的算法苦掘,而門和的計(jì)算僅僅是激活函數(shù)不同。因此我們提出了calc_gate方法赐写,這樣減少了很多重復(fù)代碼鸟蜡。
反向傳播算法的實(shí)現(xiàn)
backward方法實(shí)現(xiàn)了LSTM的反向傳播算法膜赃。需要注意的是挺邀,與backword相關(guān)的內(nèi)部狀態(tài)變量是在調(diào)用backward方法之后才初始化的。這種延遲初始化的一個(gè)好處是跳座,如果LSTM只是用來(lái)推理端铛,那么就不需要初始化這些變量,節(jié)省了很多內(nèi)存疲眷。
def backward(self, x, delta_h, activator):
'''
實(shí)現(xiàn)LSTM訓(xùn)練算法
'''
self.calc_delta(delta_h, activator)
self.calc_gradient(x)
算法主要分成兩個(gè)部分禾蚕,一部分使計(jì)算誤差項(xiàng):
def calc_delta(self, delta_h, activator):
# 初始化各個(gè)時(shí)刻的誤差項(xiàng)
self.delta_h_list = self.init_delta() # 輸出誤差項(xiàng)
self.delta_o_list = self.init_delta() # 輸出門誤差項(xiàng)
self.delta_i_list = self.init_delta() # 輸入門誤差項(xiàng)
self.delta_f_list = self.init_delta() # 遺忘門誤差項(xiàng)
self.delta_ct_list = self.init_delta() # 即時(shí)輸出誤差項(xiàng)
# 保存從上一層傳遞下來(lái)的當(dāng)前時(shí)刻的誤差項(xiàng)
self.delta_h_list[-1] = delta_h
# 迭代計(jì)算每個(gè)時(shí)刻的誤差項(xiàng)
for k in range(self.times, 0, -1):
self.calc_delta_k(k)
def init_delta(self):
'''
初始化誤差項(xiàng)
'''
delta_list = []
for i in range(self.times + 1):
delta_list.append(np.zeros(
(self.state_width, 1)))
return delta_list
def calc_delta_k(self, k):
'''
根據(jù)k時(shí)刻的delta_h,計(jì)算k時(shí)刻的delta_f狂丝、
delta_i换淆、delta_o、delta_ct几颜,以及k-1時(shí)刻的delta_h
'''
# 獲得k時(shí)刻前向計(jì)算的值
ig = self.i_list[k]
og = self.o_list[k]
fg = self.f_list[k]
ct = self.ct_list[k]
c = self.c_list[k]
c_prev = self.c_list[k-1]
tanh_c = self.output_activator.forward(c)
delta_k = self.delta_h_list[k]
# 根據(jù)式9計(jì)算delta_o
delta_o = (delta_k * tanh_c *
self.gate_activator.backward(og))
delta_f = (delta_k * og *
(1 - tanh_c * tanh_c) * c_prev *
self.gate_activator.backward(fg))
delta_i = (delta_k * og *
(1 - tanh_c * tanh_c) * ct *
self.gate_activator.backward(ig))
delta_ct = (delta_k * og *
(1 - tanh_c * tanh_c) * ig *
self.output_activator.backward(ct))
delta_h_prev = (
np.dot(delta_o.transpose(), self.Woh) +
np.dot(delta_i.transpose(), self.Wih) +
np.dot(delta_f.transpose(), self.Wfh) +
np.dot(delta_ct.transpose(), self.Wch)
).transpose()
# 保存全部delta值
self.delta_h_list[k-1] = delta_h_prev
self.delta_f_list[k] = delta_f
self.delta_i_list[k] = delta_i
self.delta_o_list[k] = delta_o
self.delta_ct_list[k] = delta_ct
另一部分是計(jì)算梯度:
def calc_gradient(self, x):
# 初始化遺忘門權(quán)重梯度矩陣和偏置項(xiàng)
self.Wfh_grad, self.Wfx_grad, self.bf_grad = (
self.init_weight_gradient_mat())
# 初始化輸入門權(quán)重梯度矩陣和偏置項(xiàng)
self.Wih_grad, self.Wix_grad, self.bi_grad = (
self.init_weight_gradient_mat())
# 初始化輸出門權(quán)重梯度矩陣和偏置項(xiàng)
self.Woh_grad, self.Wox_grad, self.bo_grad = (
self.init_weight_gradient_mat())
# 初始化單元狀態(tài)權(quán)重梯度矩陣和偏置項(xiàng)
self.Wch_grad, self.Wcx_grad, self.bc_grad = (
self.init_weight_gradient_mat())
# 計(jì)算對(duì)上一次輸出h的權(quán)重梯度
for t in range(self.times, 0, -1):
# 計(jì)算各個(gè)時(shí)刻的梯度
(Wfh_grad, bf_grad,
Wih_grad, bi_grad,
Woh_grad, bo_grad,
Wch_grad, bc_grad) = (
self.calc_gradient_t(t))
# 實(shí)際梯度是各時(shí)刻梯度之和
self.Wfh_grad += Wfh_grad
self.bf_grad += bf_grad
self.Wih_grad += Wih_grad
self.bi_grad += bi_grad
self.Woh_grad += Woh_grad
self.bo_grad += bo_grad
self.Wch_grad += Wch_grad
self.bc_grad += bc_grad
print '-----%d-----' % t
print Wfh_grad
print self.Wfh_grad
# 計(jì)算對(duì)本次輸入x的權(quán)重梯度
xt = x.transpose()
self.Wfx_grad = np.dot(self.delta_f_list[-1], xt)
self.Wix_grad = np.dot(self.delta_i_list[-1], xt)
self.Wox_grad = np.dot(self.delta_o_list[-1], xt)
self.Wcx_grad = np.dot(self.delta_ct_list[-1], xt)
def init_weight_gradient_mat(self):
'''
初始化權(quán)重矩陣
'''
Wh_grad = np.zeros((self.state_width,
self.state_width))
Wx_grad = np.zeros((self.state_width,
self.input_width))
b_grad = np.zeros((self.state_width, 1))
return Wh_grad, Wx_grad, b_grad
def calc_gradient_t(self, t):
'''
計(jì)算每個(gè)時(shí)刻t權(quán)重的梯度
'''
h_prev = self.h_list[t-1].transpose()
Wfh_grad = np.dot(self.delta_f_list[t], h_prev)
bf_grad = self.delta_f_list[t]
Wih_grad = np.dot(self.delta_i_list[t], h_prev)
bi_grad = self.delta_f_list[t]
Woh_grad = np.dot(self.delta_o_list[t], h_prev)
bo_grad = self.delta_f_list[t]
Wch_grad = np.dot(self.delta_ct_list[t], h_prev)
bc_grad = self.delta_ct_list[t]
return Wfh_grad, bf_grad, Wih_grad, bi_grad, \
Woh_grad, bo_grad, Wch_grad, bc_grad
梯度下降算法的實(shí)現(xiàn)
下面是用梯度下降算法來(lái)更新權(quán)重:
def update(self):
'''
按照梯度下降倍试,更新權(quán)重
'''
self.Wfh -= self.learning_rate * self.Whf_grad
self.Wfx -= self.learning_rate * self.Whx_grad
self.bf -= self.learning_rate * self.bf_grad
self.Wih -= self.learning_rate * self.Whi_grad
self.Wix -= self.learning_rate * self.Whi_grad
self.bi -= self.learning_rate * self.bi_grad
self.Woh -= self.learning_rate * self.Wof_grad
self.Wox -= self.learning_rate * self.Wox_grad
self.bo -= self.learning_rate * self.bo_grad
self.Wch -= self.learning_rate * self.Wcf_grad
self.Wcx -= self.learning_rate * self.Wcx_grad
self.bc -= self.learning_rate * self.bc_grad
梯度檢查的實(shí)現(xiàn)
和RecurrentLayer一樣,為了支持梯度檢查蛋哭,我們需要支持重置內(nèi)部狀態(tài):
def reset_state(self):
# 當(dāng)前時(shí)刻初始化為t0
self.times = 0
# 各個(gè)時(shí)刻的單元狀態(tài)向量c
self.c_list = self.init_state_vec()
# 各個(gè)時(shí)刻的輸出向量h
self.h_list = self.init_state_vec()
# 各個(gè)時(shí)刻的遺忘門f
self.f_list = self.init_state_vec()
# 各個(gè)時(shí)刻的輸入門i
self.i_list = self.init_state_vec()
# 各個(gè)時(shí)刻的輸出門o
self.o_list = self.init_state_vec()
# 各個(gè)時(shí)刻的即時(shí)狀態(tài)c~
self.ct_list = self.init_state_vec()
最后县习,是梯度檢查的代碼:
def data_set():
x = [np.array([[1], [2], [3]]),
np.array([[2], [3], [4]])]
d = np.array([[1], [2]])
return x, d
def gradient_check():
'''
梯度檢查
'''
# 設(shè)計(jì)一個(gè)誤差函數(shù),取所有節(jié)點(diǎn)輸出項(xiàng)之和
error_function = lambda o: o.sum()
lstm = LstmLayer(3, 2, 1e-3)
# 計(jì)算forward值
x, d = data_set()
lstm.forward(x[0])
lstm.forward(x[1])
# 求取sensitivity map
sensitivity_array = np.ones(lstm.h_list[-1].shape,
dtype=np.float64)
# 計(jì)算梯度
lstm.backward(x[1], sensitivity_array, IdentityActivator())
# 檢查梯度
epsilon = 10e-4
for i in range(lstm.Wfh.shape[0]):
for j in range(lstm.Wfh.shape[1]):
lstm.Wfh[i,j] += epsilon
lstm.reset_state()
lstm.forward(x[0])
lstm.forward(x[1])
err1 = error_function(lstm.h_list[-1])
lstm.Wfh[i,j] -= 2*epsilon
lstm.reset_state()
lstm.forward(x[0])
lstm.forward(x[1])
err2 = error_function(lstm.h_list[-1])
expect_grad = (err1 - err2) / (2 * epsilon)
lstm.Wfh[i,j] += epsilon
print 'weights(%d,%d): expected - actural %.4e - %.4e' % (
i, j, expect_grad, lstm.Wfh_grad[i,j])
return lstm
我們只對(duì)做了檢查,讀者可以自行增加對(duì)其他梯度的檢查躁愿。下面是某次梯度檢查的結(jié)果:
GRU
前面我們講了一種普通的LSTM叛本,事實(shí)上LSTM存在很多變體,許多論文中的LSTM都或多或少的不太一樣彤钟。在眾多的LSTM變體中来候,GRU (Gated Recurrent Unit)也許是最成功的一種。它對(duì)LSTM做了很多簡(jiǎn)化逸雹,同時(shí)卻保持著和LSTM相同的效果吠勘。因此,GRU最近變得越來(lái)越流行峡眶。
GRU對(duì)LSTM做了兩個(gè)大改動(dòng):
- 將輸入門剧防、遺忘門、輸出門變?yōu)閮蓚€(gè)門:更新門(Update Gate)
和重置門(Reset Gate)
辫樱。
- 將單元狀態(tài)與輸出合并為一個(gè)狀態(tài):
峭拘。
GRU的前向計(jì)算公式為:
下圖是GRU的示意圖:
GRU的訓(xùn)練算法比LSTM簡(jiǎn)單一些,留給讀者自行推導(dǎo)狮暑,本文就不再贅述了鸡挠。
小結(jié)
至此,LSTM——也許是結(jié)構(gòu)最復(fù)雜的一類神經(jīng)網(wǎng)絡(luò)——就講完了搬男,相信拿下前幾篇文章的讀者們搞定這篇文章也不在話下吧拣展!現(xiàn)在我們已經(jīng)了解循環(huán)神經(jīng)網(wǎng)絡(luò)和它最流行的變體——LSTM,它們都可以用來(lái)處理序列缔逛。但是备埃,有時(shí)候僅僅擁有處理序列的能力還不夠,還需要處理比序列更為復(fù)雜的結(jié)構(gòu)(比如樹結(jié)構(gòu))褐奴,這時(shí)候就需要用到另外一類網(wǎng)絡(luò):遞歸神經(jīng)網(wǎng)絡(luò)(Recursive Neural Network)按脚,巧合的是,它的縮寫也是RNN敦冬。在下一篇文章中辅搬,我們將介紹遞歸神經(jīng)網(wǎng)絡(luò)和它的訓(xùn)練算法。現(xiàn)在脖旱,漫長(zhǎng)的燒腦暫告一段落堪遂,休息一下吧:)
參考資料
- CS224d: Deep Learning for Natural Language Processing
- Understanding LSTM Networks
- LSTM Forward and Backward Pass
相關(guān)文章
零基礎(chǔ)入門深度學(xué)習(xí)(1) - 感知器
零基礎(chǔ)入門深度學(xué)習(xí)(2) - 線性單元和梯度下降
零基礎(chǔ)入門深度學(xué)習(xí)(3) - 神經(jīng)網(wǎng)絡(luò)和反向傳播算法
零基礎(chǔ)入門深度學(xué)習(xí)(4) - 卷積神經(jīng)網(wǎng)絡(luò)
零基礎(chǔ)入門深度學(xué)習(xí)(5) - 循環(huán)神經(jīng)網(wǎng)絡(luò)