零基礎(chǔ)入門深度學(xué)習(xí)(6) - 長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)(LSTM)

往期回顧

在上一篇文章中啸罢,我們介紹了循環(huán)神經(jīng)網(wǎng)絡(luò)以及它的訓(xùn)練算法。我們也介紹了循環(huán)神經(jīng)網(wǎng)絡(luò)很難訓(xùn)練的原因胎食,這導(dǎo)致了它在實(shí)際應(yīng)用中扰才,很難處理長(zhǎng)距離的依賴。在本文中厕怜,我們將介紹一種改進(jìn)之后的循環(huán)神經(jīng)網(wǎng)絡(luò):長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)(Long Short Term Memory Network, LSTM)衩匣,它成功的解決了原始循環(huán)神經(jīng)網(wǎng)絡(luò)的缺陷,成為當(dāng)前最流行的RNN粥航,在語(yǔ)音識(shí)別琅捏、圖片描述、自然語(yǔ)言處理等許多領(lǐng)域中成功應(yīng)用躁锡。但不幸的一面是午绳,LSTM的結(jié)構(gòu)很復(fù)雜,因此映之,我們需要花上一些力氣拦焚,才能把LSTM以及它的訓(xùn)練算法弄明白。在搞清楚LSTM之后杠输,我們?cè)俳榻B一種LSTM的變體:GRU (Gated Recurrent Unit)赎败。 它的結(jié)構(gòu)比LSTM簡(jiǎn)單,而效果卻和LSTM一樣好蠢甲,因此僵刮,它正在逐漸流行起來(lái)。最后鹦牛,我們?nèi)匀粫?huì)動(dòng)手實(shí)現(xiàn)一個(gè)LSTM搞糕。

長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)是啥

我們首先了解一下長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)產(chǎn)生的背景÷罚回顧一下零基礎(chǔ)入門深度學(xué)習(xí)(5) - 循環(huán)神經(jīng)網(wǎng)絡(luò)中推導(dǎo)的窍仰,誤差項(xiàng)沿時(shí)間反向傳播的公式:

\begin{align} \delta_k^T=&\delta_t^T\prod_{i=k}^{t-1}diag[f'(\mathbf{net}_{i})]W\\ \end{align}

我們可以根據(jù)下面的不等式,來(lái)獲取\delta_k^T的模的上界(睦袷猓可以看做對(duì)\delta_k^T中每一項(xiàng)值的大小的度量):

\begin{align} \|\delta_k^T\|\leqslant&\|\delta_t^T\|\prod_{i=k}^{t-1}\|diag[f'(\mathbf{net}_{i})]\|\|W\|\\ \leqslant&\|\delta_t^T\|(\beta_f\beta_W)^{t-k} \end{align}

我們可以看到驹吮,誤差項(xiàng)\delta從t時(shí)刻傳遞到k時(shí)刻,其值的上界是\beta_f\beta_w的指數(shù)函數(shù)晶伦。\beta_f\beta_w分別是對(duì)角矩陣diag[f'(\mathbf{net}_{i})]和矩陣W模的上界碟狞。顯然,除非\beta_f\beta_w乘積的值位于1附近婚陪,否則族沃,當(dāng)t-k很大時(shí)(也就是誤差傳遞很多個(gè)時(shí)刻時(shí)),整個(gè)式子的值就會(huì)變得極小(當(dāng)\beta_f\beta_w乘積小于1)或者極大(當(dāng)\beta_f\beta_w乘積大于1)竭业,前者就是梯度消失智润,后者就是梯度爆炸。雖然科學(xué)家們搞出了很多技巧(比如怎樣初始化權(quán)重)未辆,讓\beta_f\beta_w的值盡可能貼近于1窟绷,終究還是難以抵擋指數(shù)函數(shù)的威力。

梯度消失到底意味著什么咐柜?在零基礎(chǔ)入門深度學(xué)習(xí)(5) - 循環(huán)神經(jīng)網(wǎng)絡(luò)中我們已證明兼蜈,權(quán)重?cái)?shù)組W最終的梯度是各個(gè)時(shí)刻的梯度之和,即:

\begin{align} \nabla_WE&=\sum_{k=1}^t\nabla_{Wk}E\\ &=\nabla_{Wt}E+\nabla_{Wt-1}E+\nabla_{Wt-2}E+...+\nabla_{W1}E \end{align}

假設(shè)某輪訓(xùn)練中拙友,各時(shí)刻的梯度以及最終的梯度之和如下圖:

我們就可以看到为狸,從上圖的t-3時(shí)刻開始,梯度已經(jīng)幾乎減少到0了遗契。那么辐棒,從這個(gè)時(shí)刻開始再往之前走,得到的梯度(幾乎為零)就不會(huì)對(duì)最終的梯度值有任何貢獻(xiàn)牍蜂,這就相當(dāng)于無(wú)論t-3時(shí)刻之前的網(wǎng)絡(luò)狀態(tài)h是什么漾根,在訓(xùn)練中都不會(huì)對(duì)權(quán)重?cái)?shù)組W的更新產(chǎn)生影響,也就是網(wǎng)絡(luò)事實(shí)上已經(jīng)忽略了t-3時(shí)刻之前的狀態(tài)鲫竞。這就是原始RNN無(wú)法處理長(zhǎng)距離依賴的原因辐怕。

既然找到了問(wèn)題的原因,那么我們就能解決它从绘。從問(wèn)題的定位到解決寄疏,科學(xué)家們大概花了7、8年時(shí)間僵井。終于有一天陕截,Hochreiter和Schmidhuber兩位科學(xué)家發(fā)明出長(zhǎng)短時(shí)記憶網(wǎng)絡(luò),一舉解決這個(gè)問(wèn)題批什。

其實(shí)农曲,長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)的思路比較簡(jiǎn)單。原始RNN的隱藏層只有一個(gè)狀態(tài)渊季,即h朋蔫,它對(duì)于短期的輸入非常敏感罚渐。那么却汉,假如我們?cè)僭黾右粋€(gè)狀態(tài),即c荷并,讓它來(lái)保存長(zhǎng)期的狀態(tài)合砂,那么問(wèn)題不就解決了么?如下圖所示:

新增加的狀態(tài)c源织,稱為單元狀態(tài)(cell state)翩伪。我們把上圖按照時(shí)間維度展開:

上圖僅僅是一個(gè)示意圖微猖,我們可以看出,在t時(shí)刻缘屹,LSTM的輸入有三個(gè):當(dāng)前時(shí)刻網(wǎng)絡(luò)的輸入值\mathbf{x}_t凛剥、上一時(shí)刻LSTM的輸出值\mathbf{h}_{t-1}、以及上一時(shí)刻的單元狀態(tài)\mathbf{c}_{t-1}轻姿;LSTM的輸出有兩個(gè):當(dāng)前時(shí)刻LSTM輸出值\mathbf{h}_t犁珠、和當(dāng)前時(shí)刻的單元狀態(tài)\mathbf{c}_t。注意\mathbf{x}互亮、\mathbf{h}犁享、\mathbf{c}都是向量

LSTM的關(guān)鍵豹休,就是怎樣控制長(zhǎng)期狀態(tài)c炊昆。在這里,LSTM的思路是使用三個(gè)控制開關(guān)威根。第一個(gè)開關(guān)凤巨,負(fù)責(zé)控制繼續(xù)保存長(zhǎng)期狀態(tài)c;第二個(gè)開關(guān)医窿,負(fù)責(zé)控制把即時(shí)狀態(tài)輸入到長(zhǎng)期狀態(tài)c磅甩;第三個(gè)開關(guān),負(fù)責(zé)控制是否把長(zhǎng)期狀態(tài)c作為當(dāng)前的LSTM的輸出姥卢。三個(gè)開關(guān)的作用如下圖所示:

接下來(lái)卷要,我們要描述一下,輸出h和單元狀態(tài)c的具體計(jì)算方法独榴。

長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)的前向計(jì)算

前面描述的開關(guān)是怎樣在算法中實(shí)現(xiàn)的呢僧叉?這就用到了門(gate)的概念。門實(shí)際上就是一層全連接層棺榔,它的輸入是一個(gè)向量瓶堕,輸出是一個(gè)0到1之間的實(shí)數(shù)向量。假設(shè)W是門的權(quán)重向量症歇,\mathbf郎笆是偏置項(xiàng),那么門可以表示為:

g(\mathbf{x})=\sigma(W\mathbf{x}+\mathbf忘晤)

門的使用宛蚓,就是用門的輸出向量按元素乘以我們需要控制的那個(gè)向量。因?yàn)殚T的輸出是0到1之間的實(shí)數(shù)向量设塔,那么凄吏,當(dāng)門輸出為0時(shí),任何向量與之相乘都會(huì)得到0向量,這就相當(dāng)于啥都不能通過(guò)痕钢;輸出為1時(shí)图柏,任何向量與之相乘都不會(huì)有任何改變,這就相當(dāng)于啥都可以通過(guò)任连。因?yàn)?img class="math-inline" src="https://math.jianshu.com/math?formula=%5Csigma" alt="\sigma" mathimg="1">(也就是sigmoid函數(shù))的值域是(0,1)蚤吹,所以門的狀態(tài)都是半開半閉的。

LSTM用兩個(gè)門來(lái)控制單元狀態(tài)c的內(nèi)容随抠,一個(gè)是遺忘門(forget gate)距辆,它決定了上一時(shí)刻的單元狀態(tài)\mathbf{c}_{t-1}有多少保留到當(dāng)前時(shí)刻\mathbf{c}_t;另一個(gè)是輸入門(input gate)暮刃,它決定了當(dāng)前時(shí)刻網(wǎng)絡(luò)的輸入\mathbf{x}_t有多少保存到單元狀態(tài)\mathbf{c}_t跨算。LSTM用輸出門(output gate)來(lái)控制單元狀態(tài)\mathbf{c}_t有多少輸出到LSTM的當(dāng)前輸出值\mathbf{h}_t

我們先來(lái)看一下遺忘門:

\mathbf{f}_t=\sigma(W_f\cdot[\mathbf{h}_{t-1},\mathbf{x}_t]+\mathbf椭懊_f)\qquad\quad(式1)

上式中诸蚕,W_f是遺忘門的權(quán)重矩陣,[\mathbf{h}_{t-1},\mathbf{x}_t]表示把兩個(gè)向量連接成一個(gè)更長(zhǎng)的向量氧猬,\mathbf背犯_f是遺忘門的偏置項(xiàng),\sigma是sigmoid函數(shù)盅抚。如果輸入的維度是d_x漠魏,隱藏層的維度是d_h,單元狀態(tài)的維度是d_c(通常d_c=d_h)妄均,則遺忘門的權(quán)重矩陣W_f維度是d_c\times (d_h+d_x)柱锹。事實(shí)上,權(quán)重矩陣W_f都是兩個(gè)矩陣拼接而成的:一個(gè)是W_{fh}丰包,它對(duì)應(yīng)著輸入項(xiàng)\mathbf{h}_{t-1}禁熏,其維度為d_c\times d_h;一個(gè)是W_{fx}邑彪,它對(duì)應(yīng)著輸入項(xiàng)\mathbf{x}_t瞧毙,其維度為d_c\times d_xW_f可以寫為:

\begin{align} \begin{bmatrix}W_f\end{bmatrix}\begin{bmatrix}\mathbf{h}_{t-1}\\ \mathbf{x}_t\end{bmatrix}&= \begin{bmatrix}W_{fh}&W_{fx}\end{bmatrix}\begin{bmatrix}\mathbf{h}_{t-1}\\ \mathbf{x}_t\end{bmatrix}\\ &=W_{fh}\mathbf{h}_{t-1}+W_{fx}\mathbf{x}_t \end{align}

下圖顯示了遺忘門的計(jì)算:

接下來(lái)看看輸入門:

\mathbf{i}_t=\sigma(W_i\cdot[\mathbf{h}_{t-1},\mathbf{x}_t]+\mathbf寄症_i)\qquad\quad(式2)

上式中宙彪,W_i是輸入門的權(quán)重矩陣,\mathbf有巧_i是輸入門的偏置項(xiàng)释漆。下圖表示了輸入門的計(jì)算:

接下來(lái),我們計(jì)算用于描述當(dāng)前輸入的單元狀態(tài)\mathbf{\tilde{c}}_t剪决,它是根據(jù)上一次的輸出和本次輸入來(lái)計(jì)算的:

\mathbf{\tilde{c}}_t=\tanh(W_c\cdot[\mathbf{h}_{t-1},\mathbf{x}_t]+\mathbf灵汪_c)\qquad\quad(式3)

下圖是\mathbf{\tilde{c}}_t的計(jì)算:

現(xiàn)在,我們計(jì)算當(dāng)前時(shí)刻的單元狀態(tài)\mathbf{c}_t柑潦。它是由上一次的單元狀態(tài)\mathbf{c}_{t-1}按元素乘以遺忘門f_t享言,再用當(dāng)前輸入的單元狀態(tài)\mathbf{\tilde{c}}_t按元素乘以輸入門i_t,再將兩個(gè)積加和產(chǎn)生的:

\mathbf{c}_t=f_t\circ{\mathbf{c}_{t-1}}+i_t\circ{\mathbf{\tilde{c}}_t}\qquad\quad(式4)

符號(hào)\circ表示按元素乘渗鬼。下圖是\mathbf{c}_t的計(jì)算:

這樣览露,我們就把LSTM關(guān)于當(dāng)前的記憶\mathbf{\tilde{c}}_t和長(zhǎng)期的記憶\mathbf{c}_{t-1}組合在一起,形成了新的單元狀態(tài)\mathbf{c}_t譬胎。由于遺忘門的控制差牛,它可以保存很久很久之前的信息,由于輸入門的控制堰乔,它又可以避免當(dāng)前無(wú)關(guān)緊要的內(nèi)容進(jìn)入記憶偏化。下面,我們要看看輸出門镐侯,它控制了長(zhǎng)期記憶對(duì)當(dāng)前輸出的影響:

\mathbf{o}_t=\sigma(W_o\cdot[\mathbf{h}_{t-1},\mathbf{x}_t]+\mathbf侦讨_o)\qquad\quad(式5)

下圖表示輸出門的計(jì)算:

LSTM最終的輸出,是由輸出門和單元狀態(tài)共同確定的:

\mathbf{h}_t=\mathbf{o}_t\circ \tanh(\mathbf{c}_t)\qquad\quad(式6)

下圖表示LSTM最終輸出的計(jì)算:

式1式6就是LSTM前向計(jì)算的全部公式苟翻。至此韵卤,我們就把LSTM前向計(jì)算講完了。

長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)的訓(xùn)練

熟悉我們這個(gè)系列文章的同學(xué)都清楚崇猫,訓(xùn)練部分往往比前向計(jì)算部分復(fù)雜多了沈条。LSTM的前向計(jì)算都這么復(fù)雜,那么诅炉,可想而知蜡歹,它的訓(xùn)練算法一定是非常非常復(fù)雜的。現(xiàn)在只有做幾次深呼吸涕烧,再一頭扎進(jìn)公式海洋吧季稳。

LSTM訓(xùn)練算法框架

LSTM的訓(xùn)練算法仍然是反向傳播算法,對(duì)于這個(gè)算法澈魄,我們已經(jīng)非常熟悉了景鼠。主要有下面三個(gè)步驟:

  1. 前向計(jì)算每個(gè)神經(jīng)元的輸出值,對(duì)于LSTM來(lái)說(shuō)痹扇,即\mathbf{f}_t铛漓、\mathbf{i}_t\mathbf{c}_t鲫构、\mathbf{o}_t浓恶、\mathbf{h}_t五個(gè)向量的值。計(jì)算方法已經(jīng)在上一節(jié)中描述過(guò)了结笨。
  2. 反向計(jì)算每個(gè)神經(jīng)元的誤差項(xiàng)\delta值包晰。與循環(huán)神經(jīng)網(wǎng)絡(luò)一樣湿镀,LSTM誤差項(xiàng)的反向傳播也是包括兩個(gè)方向:一個(gè)是沿時(shí)間的反向傳播,即從當(dāng)前t時(shí)刻開始伐憾,計(jì)算每個(gè)時(shí)刻的誤差項(xiàng)勉痴;一個(gè)是將誤差項(xiàng)向上一層傳播。
  3. 根據(jù)相應(yīng)的誤差項(xiàng)树肃,計(jì)算每個(gè)權(quán)重的梯度蒸矛。

關(guān)于公式和符號(hào)的說(shuō)明

首先,我們對(duì)推導(dǎo)中用到的一些公式胸嘴、符號(hào)做一下必要的說(shuō)明雏掠。

接下來(lái)的推導(dǎo)中,我們?cè)O(shè)定gate的激活函數(shù)為sigmoid函數(shù)劣像,輸出的激活函數(shù)為tanh函數(shù)乡话。他們的導(dǎo)數(shù)分別為:

\begin{align} \sigma(z)&=y=\frac{1}{1+e^{-z}}\\ \sigma'(z)&=y(1-y)\\ \tanh(z)&=y=\frac{e^z-e^{-z}}{e^z+e^{-z}}\\ \tanh'(z)&=1-y^2 \end{align}

從上面可以看出,sigmoid和tanh函數(shù)的導(dǎo)數(shù)都是原函數(shù)的函數(shù)耳奕。這樣蚊伞,我們一旦計(jì)算原函數(shù)的值,就可以用它來(lái)計(jì)算出導(dǎo)數(shù)的值吮铭。

LSTM需要學(xué)習(xí)的參數(shù)共有8組时迫,分別是:遺忘門的權(quán)重矩陣W_f和偏置項(xiàng)\mathbf_f谓晌、輸入門的權(quán)重矩陣W_i和偏置項(xiàng)\mathbf掠拳_i、輸出門的權(quán)重矩陣W_o和偏置項(xiàng)\mathbf纸肉_o溺欧,以及計(jì)算單元狀態(tài)的權(quán)重矩陣W_c和偏置項(xiàng)\mathbf_c柏肪。因?yàn)闄?quán)重矩陣的兩部分在反向傳播中使用不同的公式姐刁,因此在后續(xù)的推導(dǎo)中,權(quán)重矩陣W_f烦味、W_i聂使、W_cW_o都將被寫為分開的兩個(gè)矩陣:W_{fh}谬俄、W_{fx}柏靶、W_{ih}W_{ix}溃论、W_{oh}屎蜓、W_{ox}W_{ch}钥勋、W_{cx}炬转。

我們解釋一下按元素乘\circ符號(hào)辆苔。當(dāng)\circ作用于兩個(gè)向量時(shí),運(yùn)算如下:

\mathbf{a}\circ\mathbf=\begin{bmatrix} a_1\\a_2\\a_3\\...\\a_n \end{bmatrix}\circ\begin{bmatrix} b_1\\b_2\\b_3\\...\\b_n \end{bmatrix}=\begin{bmatrix} a_1b_1\\a_2b_2\\a_3b_3\\...\\a_nb_n \end{bmatrix}

當(dāng)\circ作用于一個(gè)向量和一個(gè)矩陣時(shí),運(yùn)算如下:

\begin{align} \mathbf{a}\circ X&=\begin{bmatrix} a_1\\a_2\\a_3\\...\\a_n \end{bmatrix}\circ\begin{bmatrix} x_{11} & x_{12} & x_{13} & ... & x_{1n}\\ x_{21} & x_{22} & x_{23} & ... & x_{2n}\\ x_{31} & x_{32} & x_{33} & ... & x_{3n}\\ & & ...\\ x_{n1} & x_{n2} & x_{n3} & ... & x_{nn}\\ \end{bmatrix}\\ &=\begin{bmatrix} a_1x_{11} & a_1x_{12} & a_1x_{13} & ... & a_1x_{1n}\\ a_2x_{21} & a_2x_{22} & a_2x_{23} & ... & a_2x_{2n}\\ a_3x_{31} & a_3x_{32} & a_3x_{33} & ... & a_3x_{3n}\\ & & ...\\ a_nx_{n1} & a_nx_{n2} & a_nx_{n3} & ... & a_nx_{nn}\\ \end{bmatrix} \end{align}

當(dāng)\circ作用于兩個(gè)矩陣時(shí),兩個(gè)矩陣對(duì)應(yīng)位置的元素相乘。按元素乘可以在某些情況下簡(jiǎn)化矩陣和向量運(yùn)算棠隐。例如,當(dāng)一個(gè)對(duì)角矩陣右乘一個(gè)矩陣時(shí)盾舌,相當(dāng)于用對(duì)角矩陣的對(duì)角線組成的向量按元素乘那個(gè)矩陣:

diag[\mathbf{a}]X=\mathbf{a}\circ X

當(dāng)一個(gè)行向量右乘一個(gè)對(duì)角矩陣時(shí)走趋,相當(dāng)于這個(gè)行向量按元素乘那個(gè)矩陣對(duì)角線組成的向量:

\mathbf{a}^Tdiag[\mathbf]=\mathbf{a}\circ\mathbf森逮

上面這兩點(diǎn)榨婆,在我們后續(xù)推導(dǎo)中會(huì)多次用到。

在t時(shí)刻褒侧,LSTM的輸出值為\mathbf{h}_t良风。我們定義t時(shí)刻的誤差項(xiàng)\delta_t為:

\delta_t\overset{def}{=}\frac{\partial{E}}{\partial{\mathbf{h}_t}}

注意,和前面幾篇文章不同闷供,我們這里假設(shè)誤差項(xiàng)是損失函數(shù)對(duì)輸出值的導(dǎo)數(shù)烟央,而不是對(duì)加權(quán)輸入net_t^l的導(dǎo)數(shù)。因?yàn)長(zhǎng)STM有四個(gè)加權(quán)輸入歪脏,分別對(duì)應(yīng)\mathbf{f}_t疑俭、\mathbf{i}_t\mathbf{c}_t婿失、\mathbf{o}_t钞艇,我們希望往上一層傳遞一個(gè)誤差項(xiàng)而不是四個(gè)。但我們?nèi)匀恍枰x出這四個(gè)加權(quán)輸入豪硅,以及他們對(duì)應(yīng)的誤差項(xiàng)哩照。

\begin{align} \mathbf{net}_{f,t}&=W_f[\mathbf{h}_{t-1},\mathbf{x}_t]+\mathbf_f\\ &=W_{fh}\mathbf{h}_{t-1}+W_{fx}\mathbf{x}_t+\mathbf懒浮_f\\ \mathbf{net}_{i,t}&=W_i[\mathbf{h}_{t-1},\mathbf{x}_t]+\mathbf飘弧_i\\ &=W_{ih}\mathbf{h}_{t-1}+W_{ix}\mathbf{x}_t+\mathbf_i\\ \mathbf{net}_{\tilde{c},t}&=W_c[\mathbf{h}_{t-1},\mathbf{x}_t]+\mathbf砚著_c\\ &=W_{ch}\mathbf{h}_{t-1}+W_{cx}\mathbf{x}_t+\mathbf眯牧_c\\ \mathbf{net}_{o,t}&=W_o[\mathbf{h}_{t-1},\mathbf{x}_t]+\mathbf_o\\ &=W_{oh}\mathbf{h}_{t-1}+W_{ox}\mathbf{x}_t+\mathbf赖草_o\\ \delta_{f,t}&\overset{def}{=}\frac{\partial{E}}{\partial{\mathbf{net}_{f,t}}}\\ \delta_{i,t}&\overset{def}{=}\frac{\partial{E}}{\partial{\mathbf{net}_{i,t}}}\\ \delta_{\tilde{c},t}&\overset{def}{=}\frac{\partial{E}}{\partial{\mathbf{net}_{\tilde{c},t}}}\\ \delta_{o,t}&\overset{def}{=}\frac{\partial{E}}{\partial{\mathbf{net}_{o,t}}}\\ \end{align}

誤差項(xiàng)沿時(shí)間的反向傳遞

沿時(shí)間反向傳遞誤差項(xiàng)学少,就是要計(jì)算出t-1時(shí)刻的誤差項(xiàng)\delta_{t-1}

\begin{align} \delta_{t-1}^T&=\frac{\partial{E}}{\partial{\mathbf{h_{t-1}}}}\\ &=\frac{\partial{E}}{\partial{\mathbf{h_t}}}\frac{\partial{\mathbf{h_t}}}{\partial{\mathbf{h_{t-1}}}}\\ &=\delta_{t}^T\frac{\partial{\mathbf{h_t}}}{\partial{\mathbf{h_{t-1}}}} \end{align}

我們知道秧骑,\frac{\partial{\mathbf{h_t}}}{\partial{\mathbf{h_{t-1}}}}是一個(gè)Jacobian矩陣版确。如果隱藏層h的維度是N的話扣囊,那么它就是一個(gè)N\times N矩陣。為了求出它绒疗,我們列出\mathbf{h}_t的計(jì)算公式侵歇,即前面的式6式4

\begin{align} \mathbf{h}_t&=\mathbf{o}_t\circ \tanh(\mathbf{c}_t)\\ \mathbf{c}_t&=\mathbf{f}_t\circ\mathbf{c}_{t-1}+\mathbf{i}_t\circ\mathbf{\tilde{c}}_t \end{align}

顯然,\mathbf{o}_t吓蘑、\mathbf{f}_t惕虑、\mathbf{i}_t\mathbf{\tilde{c}}_t都是\mathbf{h}_{t-1}的函數(shù)磨镶,那么溃蔫,利用全導(dǎo)數(shù)公式可得:

\begin{align} \delta_t^T\frac{\partial{\mathbf{h_t}}}{\partial{\mathbf{h_{t-1}}}}&=\delta_t^T\frac{\partial{\mathbf{h_t}}}{\partial{\mathbf{o}_t}}\frac{\partial{\mathbf{o}_t}}{\partial{\mathbf{net}_{o,t}}}\frac{\partial{\mathbf{net}_{o,t}}}{\partial{\mathbf{h_{t-1}}}} +\delta_t^T\frac{\partial{\mathbf{h_t}}}{\partial{\mathbf{c}_t}}\frac{\partial{\mathbf{c}_t}}{\partial{\mathbf{f_{t}}}}\frac{\partial{\mathbf{f}_t}}{\partial{\mathbf{net}_{f,t}}}\frac{\partial{\mathbf{net}_{f,t}}}{\partial{\mathbf{h_{t-1}}}} +\delta_t^T\frac{\partial{\mathbf{h_t}}}{\partial{\mathbf{c}_t}}\frac{\partial{\mathbf{c}_t}}{\partial{\mathbf{i_{t}}}}\frac{\partial{\mathbf{i}_t}}{\partial{\mathbf{net}_{i,t}}}\frac{\partial{\mathbf{net}_{i,t}}}{\partial{\mathbf{h_{t-1}}}} +\delta_t^T\frac{\partial{\mathbf{h_t}}}{\partial{\mathbf{c}_t}}\frac{\partial{\mathbf{c}_t}}{\partial{\mathbf{\tilde{c}}_{t}}}\frac{\partial{\mathbf{\tilde{c}}_t}}{\partial{\mathbf{net}_{\tilde{c},t}}}\frac{\partial{\mathbf{net}_{\tilde{c},t}}}{\partial{\mathbf{h_{t-1}}}}\\ &=\delta_{o,t}^T\frac{\partial{\mathbf{net}_{o,t}}}{\partial{\mathbf{h_{t-1}}}} +\delta_{f,t}^T\frac{\partial{\mathbf{net}_{f,t}}}{\partial{\mathbf{h_{t-1}}}} +\delta_{i,t}^T\frac{\partial{\mathbf{net}_{i,t}}}{\partial{\mathbf{h_{t-1}}}} +\delta_{\tilde{c},t}^T\frac{\partial{\mathbf{net}_{\tilde{c},t}}}{\partial{\mathbf{h_{t-1}}}}\qquad\quad(式7) \end{align}

下面,我們要把式7中的每個(gè)偏導(dǎo)數(shù)都求出來(lái)琳猫。根據(jù)式6伟叛,我們可以求出:

\begin{align} \frac{\partial{\mathbf{h_t}}}{\partial{\mathbf{o}_t}}&=diag[\tanh(\mathbf{c}_t)]\\ \frac{\partial{\mathbf{h_t}}}{\partial{\mathbf{c}_t}}&=diag[\mathbf{o}_t\circ(1-\tanh(\mathbf{c}_t)^2)] \end{align}

根據(jù)式4,我們可以求出:

\begin{align} \frac{\partial{\mathbf{c}_t}}{\partial{\mathbf{f_{t}}}}&=diag[\mathbf{c}_{t-1}]\\ \frac{\partial{\mathbf{c}_t}}{\partial{\mathbf{i_{t}}}}&=diag[\mathbf{\tilde{c}}_t]\\ \frac{\partial{\mathbf{c}_t}}{\partial{\mathbf{\tilde{c}_{t}}}}&=diag[\mathbf{i}_t]\\ \end{align}

因?yàn)椋?/p>

\begin{align} \mathbf{o}_t&=\sigma(\mathbf{net}_{o,t})\\ \mathbf{net}_{o,t}&=W_{oh}\mathbf{h}_{t-1}+W_{ox}\mathbf{x}_t+\mathbf脐嫂_o\\\\ \mathbf{f}_t&=\sigma(\mathbf{net}_{f,t})\\ \mathbf{net}_{f,t}&=W_{fh}\mathbf{h}_{t-1}+W_{fx}\mathbf{x}_t+\mathbf统刮_f\\\\ \mathbf{i}_t&=\sigma(\mathbf{net}_{i,t})\\ \mathbf{net}_{i,t}&=W_{ih}\mathbf{h}_{t-1}+W_{ix}\mathbf{x}_t+\mathbf_i\\\\ \mathbf{\tilde{c}}_t&=\tanh(\mathbf{net}_{\tilde{c},t})\\ \mathbf{net}_{\tilde{c},t}&=W_{ch}\mathbf{h}_{t-1}+W_{cx}\mathbf{x}_t+\mathbf账千_c\\ \end{align}

我們很容易得出:

\begin{align} \frac{\partial{\mathbf{o}_t}}{\partial{\mathbf{net}_{o,t}}}&=diag[\mathbf{o}_t\circ(1-\mathbf{o}_t)]\\ \frac{\partial{\mathbf{net}_{o,t}}}{\partial{\mathbf{h_{t-1}}}}&=W_{oh}\\ \frac{\partial{\mathbf{f}_t}}{\partial{\mathbf{net}_{f,t}}}&=diag[\mathbf{f}_t\circ(1-\mathbf{f}_t)]\\ \frac{\partial{\mathbf{net}_{f,t}}}{\partial{\mathbf{h}_{t-1}}}&=W_{fh}\\ \frac{\partial{\mathbf{i}_t}}{\partial{\mathbf{net}_{i,t}}}&=diag[\mathbf{i}_t\circ(1-\mathbf{i}_t)]\\ \frac{\partial{\mathbf{net}_{i,t}}}{\partial{\mathbf{h}_{t-1}}}&=W_{ih}\\ \frac{\partial{\mathbf{\tilde{c}}_t}}{\partial{\mathbf{net}_{\tilde{c},t}}}&=diag[1-\mathbf{\tilde{c}}_t^2]\\ \frac{\partial{\mathbf{net}_{\tilde{c},t}}}{\partial{\mathbf{h}_{t-1}}}&=W_{ch} \end{align}

將上述偏導(dǎo)數(shù)帶入到式7侥蒙,我們得到:

\begin{align} \delta_{t-1}&=\delta_{o,t}^T\frac{\partial{\mathbf{net}_{o,t}}}{\partial{\mathbf{h_{t-1}}}} +\delta_{f,t}^T\frac{\partial{\mathbf{net}_{f,t}}}{\partial{\mathbf{h_{t-1}}}} +\delta_{i,t}^T\frac{\partial{\mathbf{net}_{i,t}}}{\partial{\mathbf{h_{t-1}}}} +\delta_{\tilde{c},t}^T\frac{\partial{\mathbf{net}_{\tilde{c},t}}}{\partial{\mathbf{h_{t-1}}}}\\ &=\delta_{o,t}^T W_{oh} +\delta_{f,t}^TW_{fh} +\delta_{i,t}^TW_{ih} +\delta_{\tilde{c},t}^TW_{ch}\qquad\quad(式8)\\ \end{align}

根據(jù)\delta_{o,t}\delta_{f,t}匀奏、\delta_{i,t}辉哥、\delta_{\tilde{c},t}的定義,可知:

\begin{align} \delta_{o,t}^T&=\delta_t^T\circ\tanh(\mathbf{c}_t)\circ\mathbf{o}_t\circ(1-\mathbf{o}_t)\qquad\quad(式9)\\ \delta_{f,t}^T&=\delta_t^T\circ\mathbf{o}_t\circ(1-\tanh(\mathbf{c}_t)^2)\circ\mathbf{c}_{t-1}\circ\mathbf{f}_t\circ(1-\mathbf{f}_t)\qquad(式10)\\ \delta_{i,t}^T&=\delta_t^T\circ\mathbf{o}_t\circ(1-\tanh(\mathbf{c}_t)^2)\circ\mathbf{\tilde{c}}_t\circ\mathbf{i}_t\circ(1-\mathbf{i}_t)\qquad\quad(式11)\\ \delta_{\tilde{c},t}^T&=\delta_t^T\circ\mathbf{o}_t\circ(1-\tanh(\mathbf{c}_t)^2)\circ\mathbf{i}_t\circ(1-\mathbf{\tilde{c}}^2)\qquad\quad(式12)\\ \end{align}

式8式12就是將誤差沿時(shí)間反向傳播一個(gè)時(shí)刻的公式攒射。有了它醋旦,我們可以寫出將誤差項(xiàng)向前傳遞到任意k時(shí)刻的公式:

\delta_k^T=\prod_{j=k}^{t-1}\delta_{o,j}^TW_{oh} +\delta_{f,j}^TW_{fh} +\delta_{i,j}^TW_{ih} +\delta_{\tilde{c},j}^TW_{ch}\qquad\quad(式13)

將誤差項(xiàng)傳遞到上一層

我們假設(shè)當(dāng)前為第l層,定義l-1層的誤差項(xiàng)是誤差函數(shù)對(duì)l-1層加權(quán)輸入的導(dǎo)數(shù)会放,即:

\delta_t^{l-1}\overset{def}{=}\frac{\partial{E}}{\mathbf{net}_t^{l-1}}

本次LSTM的輸入x_t由下面的公式計(jì)算:

\mathbf{x}_t^l=f^{l-1}(\mathbf{net}_t^{l-1})

上式中饲齐,f^{l-1}表示第l-1層的激活函數(shù)

因?yàn)?img class="math-inline" src="https://math.jianshu.com/math?formula=%5Cmathbf%7Bnet%7D_%7Bf%2Ct%7D%5El" alt="\mathbf{net}_{f,t}^l" mathimg="1">咧最、\mathbf{net}_{i,t}^l捂人、\mathbf{net}_{\tilde{c},t}^l\mathbf{net}_{o,t}^l都是\mathbf{x}_t的函數(shù)矢沿,\mathbf{x}_t又是\mathbf{net}_t^{l-1}的函數(shù)滥搭,因此,要求出E對(duì)\mathbf{net}_t^{l-1}的導(dǎo)數(shù)捣鲸,就需要使用全導(dǎo)數(shù)公式:

\begin{align} \frac{\partial{E}}{\partial{\mathbf{net}_t^{l-1}}}&=\frac{\partial{E}}{\partial{\mathbf{\mathbf{net}_{f,t}^l}}}\frac{\partial{\mathbf{\mathbf{net}_{f,t}^l}}}{\partial{\mathbf{x}_t^l}}\frac{\partial{\mathbf{x}_t^l}}{\partial{\mathbf{\mathbf{net}_t^{l-1}}}} +\frac{\partial{E}}{\partial{\mathbf{\mathbf{net}_{i,t}^l}}}\frac{\partial{\mathbf{\mathbf{net}_{i,t}^l}}}{\partial{\mathbf{x}_t^l}}\frac{\partial{\mathbf{x}_t^l}}{\partial{\mathbf{\mathbf{net}_t^{l-1}}}} +\frac{\partial{E}}{\partial{\mathbf{\mathbf{net}_{\tilde{c},t}^l}}}\frac{\partial{\mathbf{\mathbf{net}_{\tilde{c},t}^l}}}{\partial{\mathbf{x}_t^l}}\frac{\partial{\mathbf{x}_t^l}}{\partial{\mathbf{\mathbf{net}_t^{l-1}}}} +\frac{\partial{E}}{\partial{\mathbf{\mathbf{net}_{o,t}^l}}}\frac{\partial{\mathbf{\mathbf{net}_{o,t}^l}}}{\partial{\mathbf{x}_t^l}}\frac{\partial{\mathbf{x}_t^l}}{\partial{\mathbf{\mathbf{net}_t^{l-1}}}}\\ &=\delta_{f,t}^TW_{fx}\circ f'(\mathbf{net}_t^{l-1})+\delta_{i,t}^TW_{ix}\circ f'(\mathbf{net}_t^{l-1})+\delta_{\tilde{c},t}^TW_{cx}\circ f'(\mathbf{net}_t^{l-1})+\delta_{o,t}^TW_{ox}\circ f'(\mathbf{net}_t^{l-1})\\ &=(\delta_{f,t}^TW_{fx}+\delta_{i,t}^TW_{ix}+\delta_{\tilde{c},t}^TW_{cx}+\delta_{o,t}^TW_{ox})\circ f'(\mathbf{net}_t^{l-1})\qquad\quad(式14) \end{align}

式14就是將誤差傳遞到上一層的公式瑟匆。

權(quán)重梯度的計(jì)算

對(duì)于W_{fh}W_{ih}栽惶、W_{ch}愁溜、W_{oh}的權(quán)重梯度疾嗅,我們知道它的梯度是各個(gè)時(shí)刻梯度之和(證明過(guò)程請(qǐng)參考文章零基礎(chǔ)入門深度學(xué)習(xí)(5) - 循環(huán)神經(jīng)網(wǎng)絡(luò)),我們首先求出它們?cè)趖時(shí)刻的梯度冕象,然后再求出他們最終的梯度代承。

我們已經(jīng)求得了誤差項(xiàng)\delta_{o,t}\delta_{f,t}渐扮、\delta_{i,t}论悴、\delta_{\tilde{c},t},很容易求出t時(shí)刻的W_{oh}墓律、的W_{ih}膀估、的W_{fh}、的W_{ch}

\begin{align} \frac{\partial{E}}{\partial{W_{oh,t}}}&=\frac{\partial{E}}{\partial{\mathbf{net}_{o,t}}}\frac{\partial{\mathbf{net}_{o,t}}}{\partial{W_{oh,t}}}\\ &=\delta_{o,t}\mathbf{h}_{t-1}^T\\\\ \frac{\partial{E}}{\partial{W_{fh,t}}}&=\frac{\partial{E}}{\partial{\mathbf{net}_{f,t}}}\frac{\partial{\mathbf{net}_{f,t}}}{\partial{W_{fh,t}}}\\ &=\delta_{f,t}\mathbf{h}_{t-1}^T\\\\ \frac{\partial{E}}{\partial{W_{ih,t}}}&=\frac{\partial{E}}{\partial{\mathbf{net}_{i,t}}}\frac{\partial{\mathbf{net}_{i,t}}}{\partial{W_{ih,t}}}\\ &=\delta_{i,t}\mathbf{h}_{t-1}^T\\\\ \frac{\partial{E}}{\partial{W_{ch,t}}}&=\frac{\partial{E}}{\partial{\mathbf{net}_{\tilde{c},t}}}\frac{\partial{\mathbf{net}_{\tilde{c},t}}}{\partial{W_{ch,t}}}\\ &=\delta_{\tilde{c},t}\mathbf{h}_{t-1}^T\\ \end{align}

將各個(gè)時(shí)刻的梯度加在一起只锻,就能得到最終的梯度:

\begin{align} \frac{\partial{E}}{\partial{W_{oh}}}&=\sum_{j=1}^t\delta_{o,j}\mathbf{h}_{j-1}^T\\ \frac{\partial{E}}{\partial{W_{fh}}}&=\sum_{j=1}^t\delta_{f,j}\mathbf{h}_{j-1}^T\\ \frac{\partial{E}}{\partial{W_{ih}}}&=\sum_{j=1}^t\delta_{i,j}\mathbf{h}_{j-1}^T\\ \frac{\partial{E}}{\partial{W_{ch}}}&=\sum_{j=1}^t\delta_{\tilde{c},j}\mathbf{h}_{j-1}^T\\ \end{align}

對(duì)于偏置項(xiàng)\mathbf玖像_f紫谷、\mathbf齐饮_i\mathbf笤昨_c祖驱、\mathbf_o的梯度瞒窒,也是將各個(gè)時(shí)刻的梯度加在一起捺僻。下面是各個(gè)時(shí)刻的偏置項(xiàng)梯度:

\begin{align} \frac{\partial{E}}{\partial{\mathbf_{o,t}}}&=\frac{\partial{E}}{\partial{\mathbf{net}_{o,t}}}\frac{\partial{\mathbf{net}_{o,t}}}{\partial{\mathbf崇裁_{o,t}}}\\ &=\delta_{o,t}\\\\ \frac{\partial{E}}{\partial{\mathbf匕坯_{f,t}}}&=\frac{\partial{E}}{\partial{\mathbf{net}_{f,t}}}\frac{\partial{\mathbf{net}_{f,t}}}{\partial{\mathbf_{f,t}}}\\ &=\delta_{f,t}\\\\ \frac{\partial{E}}{\partial{\mathbf拔稳_{i,t}}}&=\frac{\partial{E}}{\partial{\mathbf{net}_{i,t}}}\frac{\partial{\mathbf{net}_{i,t}}}{\partial{\mathbf葛峻_{i,t}}}\\ &=\delta_{i,t}\\\\ \frac{\partial{E}}{\partial{\mathbf_{c,t}}}&=\frac{\partial{E}}{\partial{\mathbf{net}_{\tilde{c},t}}}\frac{\partial{\mathbf{net}_{\tilde{c},t}}}{\partial{\mathbf巴比_{c,t}}}\\ &=\delta_{\tilde{c},t}\\ \end{align}

下面是最終的偏置項(xiàng)梯度术奖,即將各個(gè)時(shí)刻的偏置項(xiàng)梯度加在一起:

\begin{align} \frac{\partial{E}}{\partial{\mathbf_o}}&=\sum_{j=1}^t\delta_{o,j}\\ \frac{\partial{E}}{\partial{\mathbf轻绞_i}}&=\sum_{j=1}^t\delta_{i,j}\\ \frac{\partial{E}}{\partial{\mathbf采记_f}}&=\sum_{j=1}^t\delta_{f,j}\\ \frac{\partial{E}}{\partial{\mathbf_c}}&=\sum_{j=1}^t\delta_{\tilde{c},j}\\ \end{align}

對(duì)于W_{fx}政勃、W_{ix}唧龄、W_{cx}W_{ox}的權(quán)重梯度奸远,只需要根據(jù)相應(yīng)的誤差項(xiàng)直接計(jì)算即可:

\begin{align} \frac{\partial{E}}{\partial{W_{ox}}}&=\frac{\partial{E}}{\partial{\mathbf{net}_{o,t}}}\frac{\partial{\mathbf{net}_{o,t}}}{\partial{W_{ox}}}\\ &=\delta_{o,t}\mathbf{x}_{t}^T\\\\ \frac{\partial{E}}{\partial{W_{fx}}}&=\frac{\partial{E}}{\partial{\mathbf{net}_{f,t}}}\frac{\partial{\mathbf{net}_{f,t}}}{\partial{W_{fx}}}\\ &=\delta_{f,t}\mathbf{x}_{t}^T\\\\ \frac{\partial{E}}{\partial{W_{ix}}}&=\frac{\partial{E}}{\partial{\mathbf{net}_{i,t}}}\frac{\partial{\mathbf{net}_{i,t}}}{\partial{W_{ix}}}\\ &=\delta_{i,t}\mathbf{x}_{t}^T\\\\ \frac{\partial{E}}{\partial{W_{cx}}}&=\frac{\partial{E}}{\partial{\mathbf{net}_{\tilde{c},t}}}\frac{\partial{\mathbf{net}_{\tilde{c},t}}}{\partial{W_{cx}}}\\ &=\delta_{\tilde{c},t}\mathbf{x}_{t}^T\\ \end{align}

以上就是LSTM的訓(xùn)練算法的全部公式选侨。因?yàn)檫@里面存在很多重復(fù)的模式掖鱼,仔細(xì)看看,會(huì)發(fā)覺(jué)并不是太復(fù)雜援制。

當(dāng)然戏挡,LSTM存在著相當(dāng)多的變體,讀者可以在互聯(lián)網(wǎng)上找到很多資料晨仑。因?yàn)榇蠹乙呀?jīng)熟悉了基本LSTM的算法褐墅,因此理解這些變體比較容易,因此本文就不再贅述了洪己。

長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)的實(shí)現(xiàn)

在下面的實(shí)現(xiàn)中妥凳,LSTMLayer的參數(shù)包括輸入維度、輸出維度答捕、隱藏層維度逝钥,單元狀態(tài)維度等于隱藏層維度。gate的激活函數(shù)為sigmoid函數(shù)拱镐,輸出的激活函數(shù)為tanh艘款。

激活函數(shù)的實(shí)現(xiàn)

我們先實(shí)現(xiàn)兩個(gè)激活函數(shù):sigmoid和tanh。

class SigmoidActivator(object):
    def forward(self, weighted_input):
        return 1.0 / (1.0 + np.exp(-weighted_input))

    def backward(self, output):
        return output * (1 - output)


class TanhActivator(object):
    def forward(self, weighted_input):
        return 2.0 / (1.0 + np.exp(-2 * weighted_input)) - 1.0

    def backward(self, output):
        return 1 - output * output

LSTM初始化

和前兩篇文章代碼架構(gòu)一樣沃琅,我們把LSTM的實(shí)現(xiàn)放在LstmLayer類中哗咆。

根據(jù)LSTM前向計(jì)算和方向傳播算法,我們需要初始化一系列矩陣和向量益眉。這些矩陣和向量有兩類用途晌柬,一類是用于保存模型參數(shù),例如W_f郭脂、W_i年碘、W_oW_c展鸡、\mathbf屿衅_f\mathbf娱颊_i傲诵、\mathbf_o箱硕、\mathbf拴竹_c;另一類是保存各種中間計(jì)算結(jié)果剧罩,以便于反向傳播算法使用栓拜,它們包括\mathbf{h}_t\mathbf{f}_t\mathbf{i}_t幕与、\mathbf{o}_t挑势、\mathbf{c}_t\mathbf{\tilde{c}}_t啦鸣、\delta_t潮饱、\delta_{f,t}\delta_{i,t}诫给、\delta_{o,t}香拉、\delta_{\tilde{c},t},以及各個(gè)權(quán)重對(duì)應(yīng)的梯度中狂。

在構(gòu)造函數(shù)的初始化中凫碌,只初始化了與forward計(jì)算相關(guān)的變量,與backward相關(guān)的變量沒(méi)有初始化胃榕。這是因?yàn)闃?gòu)造LSTM對(duì)象的時(shí)候盛险,我們還不知道它未來(lái)是用于訓(xùn)練(既有forward又有backward)還是推理(只有forward)。

class LstmLayer(object):
    def __init__(self, input_width, state_width, 
                 learning_rate):
        self.input_width = input_width
        self.state_width = state_width
        self.learning_rate = learning_rate
        # 門的激活函數(shù)
        self.gate_activator = SigmoidActivator()
        # 輸出的激活函數(shù)
        self.output_activator = TanhActivator()
        # 當(dāng)前時(shí)刻初始化為t0
        self.times = 0       
        # 各個(gè)時(shí)刻的單元狀態(tài)向量c
        self.c_list = self.init_state_vec()
        # 各個(gè)時(shí)刻的輸出向量h
        self.h_list = self.init_state_vec()
        # 各個(gè)時(shí)刻的遺忘門f
        self.f_list = self.init_state_vec()
        # 各個(gè)時(shí)刻的輸入門i
        self.i_list = self.init_state_vec()
        # 各個(gè)時(shí)刻的輸出門o
        self.o_list = self.init_state_vec()
        # 各個(gè)時(shí)刻的即時(shí)狀態(tài)c~
        self.ct_list = self.init_state_vec()
        # 遺忘門權(quán)重矩陣Wfh, Wfx, 偏置項(xiàng)bf
        self.Wfh, self.Wfx, self.bf = (
            self.init_weight_mat())
        # 輸入門權(quán)重矩陣Wfh, Wfx, 偏置項(xiàng)bf
        self.Wih, self.Wix, self.bi = (
            self.init_weight_mat())
        # 輸出門權(quán)重矩陣Wfh, Wfx, 偏置項(xiàng)bf
        self.Woh, self.Wox, self.bo = (
            self.init_weight_mat())
        # 單元狀態(tài)權(quán)重矩陣Wfh, Wfx, 偏置項(xiàng)bf
        self.Wch, self.Wcx, self.bc = (
            self.init_weight_mat())

    def init_state_vec(self):
        '''
        初始化保存狀態(tài)的向量
        '''
        state_vec_list = []
        state_vec_list.append(np.zeros(
            (self.state_width, 1)))
        return state_vec_list

    def init_weight_mat(self):
        '''
        初始化權(quán)重矩陣
        '''
        Wh = np.random.uniform(-1e-4, 1e-4,
            (self.state_width, self.state_width))
        Wx = np.random.uniform(-1e-4, 1e-4,
            (self.state_width, self.input_width))
        b = np.zeros((self.state_width, 1))
        return Wh, Wx, b

前向計(jì)算的實(shí)現(xiàn)

forward方法實(shí)現(xiàn)了LSTM的前向計(jì)算:

    def forward(self, x):
        '''
        根據(jù)式1-式6進(jìn)行前向計(jì)算
        '''
        self.times += 1
        # 遺忘門
        fg = self.calc_gate(x, self.Wfx, self.Wfh, 
            self.bf, self.gate_activator)
        self.f_list.append(fg)
        # 輸入門
        ig = self.calc_gate(x, self.Wix, self.Wih,
            self.bi, self.gate_activator)
        self.i_list.append(ig)
        # 輸出門
        og = self.calc_gate(x, self.Wox, self.Woh,
            self.bo, self.gate_activator)
        self.o_list.append(og)
        # 即時(shí)狀態(tài)
        ct = self.calc_gate(x, self.Wcx, self.Wch,
            self.bc, self.output_activator)
        self.ct_list.append(ct)
        # 單元狀態(tài)
        c = fg * self.c_list[self.times - 1] + ig * ct
        self.c_list.append(c)
        # 輸出
        h = og * self.output_activator.forward(c)
        self.h_list.append(h)

    def calc_gate(self, x, Wx, Wh, b, activator):
        '''
        計(jì)算門
        '''
        h = self.h_list[self.times - 1] # 上次的LSTM輸出
        net = np.dot(Wh, h) + np.dot(Wx, x) + b
        gate = activator.forward(net)
        return gate

從上面的代碼我們可以看到勋又,門的計(jì)算都是相同的算法苦掘,而門和\mathbf{\tilde{c}_t}的計(jì)算僅僅是激活函數(shù)不同。因此我們提出了calc_gate方法赐写,這樣減少了很多重復(fù)代碼鸟蜡。

反向傳播算法的實(shí)現(xiàn)

backward方法實(shí)現(xiàn)了LSTM的反向傳播算法膜赃。需要注意的是挺邀,與backword相關(guān)的內(nèi)部狀態(tài)變量是在調(diào)用backward方法之后才初始化的。這種延遲初始化的一個(gè)好處是跳座,如果LSTM只是用來(lái)推理端铛,那么就不需要初始化這些變量,節(jié)省了很多內(nèi)存疲眷。

    def backward(self, x, delta_h, activator):
        '''
        實(shí)現(xiàn)LSTM訓(xùn)練算法
        '''
        self.calc_delta(delta_h, activator)
        self.calc_gradient(x)

算法主要分成兩個(gè)部分禾蚕,一部分使計(jì)算誤差項(xiàng):

    def calc_delta(self, delta_h, activator):
        # 初始化各個(gè)時(shí)刻的誤差項(xiàng)
        self.delta_h_list = self.init_delta()  # 輸出誤差項(xiàng)
        self.delta_o_list = self.init_delta()  # 輸出門誤差項(xiàng)
        self.delta_i_list = self.init_delta()  # 輸入門誤差項(xiàng)
        self.delta_f_list = self.init_delta()  # 遺忘門誤差項(xiàng)
        self.delta_ct_list = self.init_delta() # 即時(shí)輸出誤差項(xiàng)

        # 保存從上一層傳遞下來(lái)的當(dāng)前時(shí)刻的誤差項(xiàng)
        self.delta_h_list[-1] = delta_h
        
        # 迭代計(jì)算每個(gè)時(shí)刻的誤差項(xiàng)
        for k in range(self.times, 0, -1):
            self.calc_delta_k(k)

    def init_delta(self):
        '''
        初始化誤差項(xiàng)
        '''
        delta_list = []
        for i in range(self.times + 1):
            delta_list.append(np.zeros(
                (self.state_width, 1)))
        return delta_list

    def calc_delta_k(self, k):
        '''
        根據(jù)k時(shí)刻的delta_h,計(jì)算k時(shí)刻的delta_f狂丝、
        delta_i换淆、delta_o、delta_ct几颜,以及k-1時(shí)刻的delta_h
        '''
        # 獲得k時(shí)刻前向計(jì)算的值
        ig = self.i_list[k]
        og = self.o_list[k]
        fg = self.f_list[k]
        ct = self.ct_list[k]
        c = self.c_list[k]
        c_prev = self.c_list[k-1]
        tanh_c = self.output_activator.forward(c)
        delta_k = self.delta_h_list[k]

        # 根據(jù)式9計(jì)算delta_o
        delta_o = (delta_k * tanh_c * 
            self.gate_activator.backward(og))
        delta_f = (delta_k * og * 
            (1 - tanh_c * tanh_c) * c_prev *
            self.gate_activator.backward(fg))
        delta_i = (delta_k * og * 
            (1 - tanh_c * tanh_c) * ct *
            self.gate_activator.backward(ig))
        delta_ct = (delta_k * og * 
            (1 - tanh_c * tanh_c) * ig *
            self.output_activator.backward(ct))
        delta_h_prev = (
                np.dot(delta_o.transpose(), self.Woh) +
                np.dot(delta_i.transpose(), self.Wih) +
                np.dot(delta_f.transpose(), self.Wfh) +
                np.dot(delta_ct.transpose(), self.Wch)
            ).transpose()

        # 保存全部delta值
        self.delta_h_list[k-1] = delta_h_prev
        self.delta_f_list[k] = delta_f
        self.delta_i_list[k] = delta_i
        self.delta_o_list[k] = delta_o
        self.delta_ct_list[k] = delta_ct

另一部分是計(jì)算梯度:

    def calc_gradient(self, x):
        # 初始化遺忘門權(quán)重梯度矩陣和偏置項(xiàng)
        self.Wfh_grad, self.Wfx_grad, self.bf_grad = (
            self.init_weight_gradient_mat())
        # 初始化輸入門權(quán)重梯度矩陣和偏置項(xiàng)
        self.Wih_grad, self.Wix_grad, self.bi_grad = (
            self.init_weight_gradient_mat())
        # 初始化輸出門權(quán)重梯度矩陣和偏置項(xiàng)
        self.Woh_grad, self.Wox_grad, self.bo_grad = (
            self.init_weight_gradient_mat())
        # 初始化單元狀態(tài)權(quán)重梯度矩陣和偏置項(xiàng)
        self.Wch_grad, self.Wcx_grad, self.bc_grad = (
            self.init_weight_gradient_mat())

       # 計(jì)算對(duì)上一次輸出h的權(quán)重梯度
        for t in range(self.times, 0, -1):
            # 計(jì)算各個(gè)時(shí)刻的梯度
            (Wfh_grad, bf_grad,
            Wih_grad, bi_grad,
            Woh_grad, bo_grad,
            Wch_grad, bc_grad) = (
                self.calc_gradient_t(t))
            # 實(shí)際梯度是各時(shí)刻梯度之和
            self.Wfh_grad += Wfh_grad
            self.bf_grad += bf_grad
            self.Wih_grad += Wih_grad
            self.bi_grad += bi_grad
            self.Woh_grad += Woh_grad
            self.bo_grad += bo_grad
            self.Wch_grad += Wch_grad
            self.bc_grad += bc_grad
            print '-----%d-----' % t
            print Wfh_grad
            print self.Wfh_grad

        # 計(jì)算對(duì)本次輸入x的權(quán)重梯度
        xt = x.transpose()
        self.Wfx_grad = np.dot(self.delta_f_list[-1], xt)
        self.Wix_grad = np.dot(self.delta_i_list[-1], xt)
        self.Wox_grad = np.dot(self.delta_o_list[-1], xt)
        self.Wcx_grad = np.dot(self.delta_ct_list[-1], xt)

    def init_weight_gradient_mat(self):
        '''
        初始化權(quán)重矩陣
        '''
        Wh_grad = np.zeros((self.state_width,
            self.state_width))
        Wx_grad = np.zeros((self.state_width,
            self.input_width))
        b_grad = np.zeros((self.state_width, 1))
        return Wh_grad, Wx_grad, b_grad

    def calc_gradient_t(self, t):
        '''
        計(jì)算每個(gè)時(shí)刻t權(quán)重的梯度
        '''
        h_prev = self.h_list[t-1].transpose()
        Wfh_grad = np.dot(self.delta_f_list[t], h_prev)
        bf_grad = self.delta_f_list[t]
        Wih_grad = np.dot(self.delta_i_list[t], h_prev)
        bi_grad = self.delta_f_list[t]
        Woh_grad = np.dot(self.delta_o_list[t], h_prev)
        bo_grad = self.delta_f_list[t]
        Wch_grad = np.dot(self.delta_ct_list[t], h_prev)
        bc_grad = self.delta_ct_list[t]
        return Wfh_grad, bf_grad, Wih_grad, bi_grad, \
               Woh_grad, bo_grad, Wch_grad, bc_grad

梯度下降算法的實(shí)現(xiàn)

下面是用梯度下降算法來(lái)更新權(quán)重:

    def update(self):
        '''
        按照梯度下降倍试,更新權(quán)重
        '''
        self.Wfh -= self.learning_rate * self.Whf_grad
        self.Wfx -= self.learning_rate * self.Whx_grad
        self.bf -= self.learning_rate * self.bf_grad
        self.Wih -= self.learning_rate * self.Whi_grad
        self.Wix -= self.learning_rate * self.Whi_grad
        self.bi -= self.learning_rate * self.bi_grad
        self.Woh -= self.learning_rate * self.Wof_grad
        self.Wox -= self.learning_rate * self.Wox_grad
        self.bo -= self.learning_rate * self.bo_grad
        self.Wch -= self.learning_rate * self.Wcf_grad
        self.Wcx -= self.learning_rate * self.Wcx_grad
        self.bc -= self.learning_rate * self.bc_grad

梯度檢查的實(shí)現(xiàn)

和RecurrentLayer一樣,為了支持梯度檢查蛋哭,我們需要支持重置內(nèi)部狀態(tài):

    def reset_state(self):
        # 當(dāng)前時(shí)刻初始化為t0
        self.times = 0       
        # 各個(gè)時(shí)刻的單元狀態(tài)向量c
        self.c_list = self.init_state_vec()
        # 各個(gè)時(shí)刻的輸出向量h
        self.h_list = self.init_state_vec()
        # 各個(gè)時(shí)刻的遺忘門f
        self.f_list = self.init_state_vec()
        # 各個(gè)時(shí)刻的輸入門i
        self.i_list = self.init_state_vec()
        # 各個(gè)時(shí)刻的輸出門o
        self.o_list = self.init_state_vec()
        # 各個(gè)時(shí)刻的即時(shí)狀態(tài)c~
        self.ct_list = self.init_state_vec()

最后县习,是梯度檢查的代碼:

def data_set():
    x = [np.array([[1], [2], [3]]),
         np.array([[2], [3], [4]])]
    d = np.array([[1], [2]])
    return x, d

def gradient_check():
    '''
    梯度檢查
    '''
    # 設(shè)計(jì)一個(gè)誤差函數(shù),取所有節(jié)點(diǎn)輸出項(xiàng)之和
    error_function = lambda o: o.sum()
    
    lstm = LstmLayer(3, 2, 1e-3)

    # 計(jì)算forward值
    x, d = data_set()
    lstm.forward(x[0])
    lstm.forward(x[1])
    
    # 求取sensitivity map
    sensitivity_array = np.ones(lstm.h_list[-1].shape,
                                dtype=np.float64)
    # 計(jì)算梯度
    lstm.backward(x[1], sensitivity_array, IdentityActivator())
    
    # 檢查梯度
    epsilon = 10e-4
    for i in range(lstm.Wfh.shape[0]):
        for j in range(lstm.Wfh.shape[1]):
            lstm.Wfh[i,j] += epsilon
            lstm.reset_state()
            lstm.forward(x[0])
            lstm.forward(x[1])
            err1 = error_function(lstm.h_list[-1])
            lstm.Wfh[i,j] -= 2*epsilon
            lstm.reset_state()
            lstm.forward(x[0])
            lstm.forward(x[1])
            err2 = error_function(lstm.h_list[-1])
            expect_grad = (err1 - err2) / (2 * epsilon)
            lstm.Wfh[i,j] += epsilon
            print 'weights(%d,%d): expected - actural %.4e - %.4e' % (
                i, j, expect_grad, lstm.Wfh_grad[i,j])
    return lstm

我們只對(duì)W_{fh}做了檢查,讀者可以自行增加對(duì)其他梯度的檢查躁愿。下面是某次梯度檢查的結(jié)果:

GRU

前面我們講了一種普通的LSTM叛本,事實(shí)上LSTM存在很多變體,許多論文中的LSTM都或多或少的不太一樣彤钟。在眾多的LSTM變體中来候,GRU (Gated Recurrent Unit)也許是最成功的一種。它對(duì)LSTM做了很多簡(jiǎn)化逸雹,同時(shí)卻保持著和LSTM相同的效果吠勘。因此,GRU最近變得越來(lái)越流行峡眶。

GRU對(duì)LSTM做了兩個(gè)大改動(dòng):

  1. 將輸入門剧防、遺忘門、輸出門變?yōu)閮蓚€(gè)門:更新門(Update Gate)\mathbf{z}_t和重置門(Reset Gate)\mathbf{r}_t辫樱。
  2. 將單元狀態(tài)與輸出合并為一個(gè)狀態(tài):\mathbf{h}峭拘。

GRU的前向計(jì)算公式為:

\begin{align} \mathbf{z}_t&=\sigma(W_z\cdot[\mathbf{h}_{t-1},\mathbf{x}_t])\\ \mathbf{r}_t&=\sigma(W_r\cdot[\mathbf{h}_{t-1},\mathbf{x}_t])\\ \mathbf{\tilde{h}}_t&=\tanh(W\cdot[\mathbf{r}_t\circ\mathbf{h}_{t-1},\mathbf{x}_t])\\ \mathbf{h}&=(1-\mathbf{z}_t)\circ\mathbf{h}_{t-1}+\mathbf{z}_t\circ\mathbf{\tilde{h}}_t \end{align}

下圖是GRU的示意圖:

GRU的訓(xùn)練算法比LSTM簡(jiǎn)單一些,留給讀者自行推導(dǎo)狮暑,本文就不再贅述了鸡挠。

小結(jié)

至此,LSTM——也許是結(jié)構(gòu)最復(fù)雜的一類神經(jīng)網(wǎng)絡(luò)——就講完了搬男,相信拿下前幾篇文章的讀者們搞定這篇文章也不在話下吧拣展!現(xiàn)在我們已經(jīng)了解循環(huán)神經(jīng)網(wǎng)絡(luò)和它最流行的變體——LSTM,它們都可以用來(lái)處理序列缔逛。但是备埃,有時(shí)候僅僅擁有處理序列的能力還不夠,還需要處理比序列更為復(fù)雜的結(jié)構(gòu)(比如樹結(jié)構(gòu))褐奴,這時(shí)候就需要用到另外一類網(wǎng)絡(luò):遞歸神經(jīng)網(wǎng)絡(luò)(Recursive Neural Network)按脚,巧合的是,它的縮寫也是RNN敦冬。在下一篇文章中辅搬,我們將介紹遞歸神經(jīng)網(wǎng)絡(luò)和它的訓(xùn)練算法。現(xiàn)在脖旱,漫長(zhǎng)的燒腦暫告一段落堪遂,休息一下吧:)

參考資料

  1. CS224d: Deep Learning for Natural Language Processing
  2. Understanding LSTM Networks
  3. LSTM Forward and Backward Pass

相關(guān)文章

零基礎(chǔ)入門深度學(xué)習(xí)(1) - 感知器
零基礎(chǔ)入門深度學(xué)習(xí)(2) - 線性單元和梯度下降
零基礎(chǔ)入門深度學(xué)習(xí)(3) - 神經(jīng)網(wǎng)絡(luò)和反向傳播算法
零基礎(chǔ)入門深度學(xué)習(xí)(4) - 卷積神經(jīng)網(wǎng)絡(luò)
零基礎(chǔ)入門深度學(xué)習(xí)(5) - 循環(huán)神經(jīng)網(wǎng)絡(luò)

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市萌庆,隨后出現(xiàn)的幾起案子溶褪,更是在濱河造成了極大的恐慌,老刑警劉巖踊兜,帶你破解...
    沈念sama閱讀 222,464評(píng)論 6 517
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件竿滨,死亡現(xiàn)場(chǎng)離奇詭異佳恬,居然都是意外死亡,警方通過(guò)查閱死者的電腦和手機(jī)于游,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 95,033評(píng)論 3 399
  • 文/潘曉璐 我一進(jìn)店門毁葱,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái),“玉大人贰剥,你說(shuō)我怎么就攤上這事倾剿。” “怎么了蚌成?”我有些...
    開封第一講書人閱讀 169,078評(píng)論 0 362
  • 文/不壞的土叔 我叫張陵前痘,是天一觀的道長(zhǎng)。 經(jīng)常有香客問(wèn)我担忧,道長(zhǎng)芹缔,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 59,979評(píng)論 1 299
  • 正文 為了忘掉前任瓶盛,我火速辦了婚禮最欠,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘惩猫。我一直安慰自己芝硬,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 69,001評(píng)論 6 398
  • 文/花漫 我一把揭開白布轧房。 她就那樣靜靜地躺著拌阴,像睡著了一般。 火紅的嫁衣襯著肌膚如雪奶镶。 梳的紋絲不亂的頭發(fā)上迟赃,一...
    開封第一講書人閱讀 52,584評(píng)論 1 312
  • 那天,我揣著相機(jī)與錄音实辑,去河邊找鬼捺氢。 笑死藻丢,一個(gè)胖子當(dāng)著我的面吹牛剪撬,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播悠反,決...
    沈念sama閱讀 41,085評(píng)論 3 422
  • 文/蒼蘭香墨 我猛地睜開眼残黑,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了斋否?” 一聲冷哼從身側(cè)響起梨水,我...
    開封第一講書人閱讀 40,023評(píng)論 0 277
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎茵臭,沒(méi)想到半個(gè)月后疫诽,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 46,555評(píng)論 1 319
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,626評(píng)論 3 342
  • 正文 我和宋清朗相戀三年奇徒,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了雏亚。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 40,769評(píng)論 1 353
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡摩钙,死狀恐怖罢低,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情胖笛,我是刑警寧澤网持,帶...
    沈念sama閱讀 36,439評(píng)論 5 351
  • 正文 年R本政府宣布,位于F島的核電站长踊,受9級(jí)特大地震影響功舀,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜身弊,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 42,115評(píng)論 3 335
  • 文/蒙蒙 一日杈、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧佑刷,春花似錦莉擒、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,601評(píng)論 0 25
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)。三九已至麦萤,卻和暖如春鹿鳖,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背壮莹。 一陣腳步聲響...
    開封第一講書人閱讀 33,702評(píng)論 1 274
  • 我被黑心中介騙來(lái)泰國(guó)打工翅帜, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人命满。 一個(gè)月前我還...
    沈念sama閱讀 49,191評(píng)論 3 378
  • 正文 我出身青樓涝滴,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親胶台。 傳聞我的和親對(duì)象是個(gè)殘疾皇子歼疮,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,781評(píng)論 2 361

推薦閱讀更多精彩內(nèi)容