RNN是兩種神經(jīng)網(wǎng)絡(luò)模型的縮寫距糖,一種是遞歸神經(jīng)網(wǎng)絡(luò)(Recursive Neural Network)豹障,一種是循環(huán)神經(jīng)網(wǎng)絡(luò)(Recurrent Neural Network)。雖然這兩種神經(jīng)網(wǎng)絡(luò)有著千絲萬縷的聯(lián)系漾峡,但是本文主要討論的是第二種神經(jīng)網(wǎng)絡(luò)模型——循環(huán)神經(jīng)網(wǎng)絡(luò)(Recurrent Neural Network)陨享。
循環(huán)神經(jīng)網(wǎng)絡(luò)是指一個隨著時間的推移,重復(fù)發(fā)生的結(jié)構(gòu)沮协。在自然語言處理(NLP),語音圖像等多個領(lǐng)域均有非常廣泛的應(yīng)用卓嫂。RNN網(wǎng)絡(luò)和其他網(wǎng)絡(luò)最大的不同就在于RNN能夠?qū)崿F(xiàn)某種“記憶功能”慷暂,是進(jìn)行時間序列分析時最好的選擇。如同人類能夠憑借自己過往的記憶更好地認(rèn)識這個世界一樣晨雳。RNN也實現(xiàn)了類似于人腦的這一機(jī)制行瑞,對所處理過的信息留存有一定的記憶奸腺,而不像其他類型的神經(jīng)網(wǎng)絡(luò)并不能對處理過的信息留存記憶。
RNN原理
循環(huán)神經(jīng)網(wǎng)絡(luò)的原理并不十分復(fù)雜血久,本節(jié)主要從原理上分析RNN的結(jié)構(gòu)和功能突照,不涉及RNN的數(shù)學(xué)推導(dǎo)和證明,整個網(wǎng)絡(luò)只有簡單的輸入輸出和網(wǎng)絡(luò)狀態(tài)參數(shù)氧吐。一個典型的RNN神經(jīng)網(wǎng)絡(luò)如圖所示:
由上圖可以看出:一個典型的RNN網(wǎng)絡(luò)包含一個輸入x讹蘑,一個輸出h和一個神經(jīng)網(wǎng)絡(luò)單元A。和普通的神經(jīng)網(wǎng)絡(luò)不同的是副砍,RNN網(wǎng)絡(luò)的神經(jīng)網(wǎng)絡(luò)單元A不僅僅與輸入和輸出存在聯(lián)系衔肢,其與自身也存在一個回路。這種網(wǎng)絡(luò)結(jié)構(gòu)就揭示了RNN的實質(zhì):上一個時刻的網(wǎng)絡(luò)狀態(tài)信息將會作用于下一個時刻的網(wǎng)絡(luò)狀態(tài)。如果上圖的網(wǎng)絡(luò)結(jié)構(gòu)仍不夠清晰刁愿,RNN網(wǎng)絡(luò)還能夠以時間序列展開成如下形式:
等號右邊是RNN的展開形式姥卢。由于RNN一般用來處理序列信息,因此下文說明時都以時間序列來舉例院峡,解釋。等號右邊的等價RNN網(wǎng)絡(luò)中最初始的輸入是x0,輸出是h0优烧,這代表著0時刻RNN網(wǎng)絡(luò)的輸入為x0,輸出為h0链峭,網(wǎng)絡(luò)神經(jīng)元在0時刻的狀態(tài)保存在A中畦娄。當(dāng)下一個時刻1到來時,此時網(wǎng)絡(luò)神經(jīng)元的狀態(tài)不僅僅由1時刻的輸入x1決定弊仪,也由0時刻的神經(jīng)元狀態(tài)決定熙卡。以后的情況都以此類推,直到時間序列的末尾t時刻励饵。
上面的過程可以用一個簡單的例子來論證:假設(shè)現(xiàn)在有一句話“I want to play basketball”驳癌,由于自然語言本身就是一個時間序列,較早的語言會與較后的語言存在某種聯(lián)系役听,例如剛才的句子中“play”這個動詞意味著后面一定會有一個名詞颓鲜,而這個名詞具體是什么可能需要更遙遠(yuǎn)的語境來決定,因此一句話也可以作為RNN的輸入典予√鸨酰回到剛才的那句話,這句話中的5個單詞是以時序出現(xiàn)的瘤袖,我們現(xiàn)在將這五個單詞編碼后依次輸入到RNN中衣摩。首先是單詞“I”,它作為時序上第一個出現(xiàn)的單詞被用作x0輸入孽椰,擁有一個h0輸出昭娩,并且改變了初始神經(jīng)元A的狀態(tài)凛篙。單詞“want”作為時序上第二個出現(xiàn)的單詞作為x1輸入,這時RNN的輸出和神經(jīng)元狀態(tài)將不僅僅由x1決定栏渺,也將由上一時刻的神經(jīng)元狀態(tài)或者說上一時刻的輸入x0決定呛梆。之后的情況以此類推,直到上述句子輸入到最后一個單詞“basketball”磕诊。
接下來我們需要關(guān)注RNN的神經(jīng)元結(jié)構(gòu):
上圖依然是一個RNN神經(jīng)網(wǎng)絡(luò)的時序展開模型填物,中間t時刻的網(wǎng)絡(luò)模型揭示了RNN的結(jié)構(gòu)■眨可以看到滞磺,原始的RNN網(wǎng)絡(luò)的內(nèi)部結(jié)構(gòu)非常簡單。神經(jīng)元A在t時刻的狀態(tài)僅僅是t-1時刻神經(jīng)元狀態(tài)與t時刻網(wǎng)絡(luò)輸入的雙曲正切函數(shù)的值莱褒,這個值不僅僅作為該時刻網(wǎng)絡(luò)的輸出击困,也作為該時刻網(wǎng)絡(luò)的狀態(tài)被傳入到下一個時刻的網(wǎng)絡(luò)狀態(tài)中,這個過程叫做RNN的正向傳播(forward propagation)广凸。注:雙曲正切函數(shù)的解析式如下:
雙曲正切函數(shù)的求導(dǎo)如下:
雙曲正切函數(shù)的圖像如下所示:
這里就帶來一個問題:為什么RNN網(wǎng)絡(luò)的激活函數(shù)要選用雙曲正切而不是sigmod呢阅茶?(RNN的激活函數(shù)除了雙曲正切,RELU函數(shù)也用的非常多)原因在于RNN網(wǎng)絡(luò)在求解時涉及時間序列上的大量求導(dǎo)運算谅海,使用sigmod函數(shù)容易出現(xiàn)梯度消失脸哀,且sigmod的導(dǎo)數(shù)形式較為復(fù)雜。事實上扭吁,即使使用雙曲正切函數(shù)撞蜂,傳統(tǒng)的RNN網(wǎng)絡(luò)依然存在梯度消失問題,無法“記憶”長時間序列上的信息侥袜,這個bug直到LSTM上引入了單元狀態(tài)后才算較好地解決蝌诡。
數(shù)學(xué)基礎(chǔ)
這一節(jié)主要介紹與RNN相關(guān)的數(shù)學(xué)推導(dǎo),由于RNN是一個時序模型系馆,因此其求解過程可能和一般的神經(jīng)網(wǎng)絡(luò)不太相同送漠。首先需要介紹一下RNN完整的結(jié)構(gòu)圖,上一節(jié)給出的RNN結(jié)構(gòu)圖省去了很多內(nèi)部參數(shù)由蘑,僅僅作為一個概念模型給出闽寡。
上圖表明了RNN網(wǎng)絡(luò)的完整拓?fù)浣Y(jié)構(gòu),從圖中我們可以看到RNN網(wǎng)絡(luò)中的參數(shù)情況尼酿。在這里我們只分析t時刻網(wǎng)絡(luò)的行為與數(shù)學(xué)推導(dǎo)爷狈。t時刻網(wǎng)絡(luò)迎來一個輸入xt,網(wǎng)絡(luò)此時刻的神經(jīng)元狀態(tài)st用如下式子表達(dá):
t時刻的網(wǎng)絡(luò)狀態(tài)st不僅僅要輸入到下一個時刻t+1的網(wǎng)絡(luò)狀態(tài)中去裳擎,還要作為該時刻的網(wǎng)絡(luò)輸出涎永。當(dāng)然,st不能直接輸出,在輸出之前還要再乘上一個系數(shù)V羡微,而且為了誤差逆?zhèn)鞑r的方便通常還要對輸出進(jìn)行歸一化處理谷饿,也就是對輸出進(jìn)行softmax化。因此妈倔,t時刻網(wǎng)絡(luò)的輸出ot表達(dá)為如下形式:
為了表達(dá)方便博投,筆者將上述兩個公式做如下變換:
以上,就是RNN網(wǎng)絡(luò)的數(shù)學(xué)表達(dá)了盯蝴,接下來我們需要求解這個模型毅哗。在論述具體解法之前首先需要明確兩個問題:優(yōu)化目標(biāo)函數(shù)是什么?待優(yōu)化的量是什么捧挺?
只有在明確了這兩個問題之后才能對模型進(jìn)行具體的推導(dǎo)和求解虑绵。關(guān)于第一個問題,筆者選取模型的損失函數(shù)作為優(yōu)化目標(biāo)闽烙;關(guān)于第二個問題翅睛,我們從RNN的結(jié)構(gòu)圖中不難發(fā)現(xiàn):只要我們得到了模型的U,V鸣峭,W這三個參數(shù)就能完全確定模型的狀態(tài)宏所。因此該優(yōu)化問題的優(yōu)化變量就是RNN的這三個參數(shù)酥艳。順便說一句摊溶,RNN模型的U,V充石,W三個參數(shù)是全局共享的莫换,也就是說不同時刻的模型參數(shù)是完全一致的,這個特性使RNN得參數(shù)變得稍微少了一些骤铃。
損失函數(shù)
不做過多的討論拉岁,RNN的損失函數(shù)選用交叉熵(Cross Entropy),這是機(jī)器學(xué)習(xí)中使用最廣泛的損失函數(shù)之一了惰爬,其通常的表達(dá)式如下所示:
上面式子是交叉熵的標(biāo)量形式喊暖,y_i是真實的標(biāo)簽值,y_i*是模型給出的預(yù)測值撕瞧,最外面之所以有一個累加符號是因為模型輸出的一般都是一個多維的向量陵叽,只有把n維損失都加和才能得到真實的損失值。交叉熵在應(yīng)用于RNN時需要做一些改變:首先丛版,RNN的輸出是向量形式巩掺,沒有必要將所有維度都加在一起,直接把損失值用向量表達(dá)就可以了页畦;其次胖替,由于RNN模型處理的是序列問題,因此其模型損失不能只是一個時刻的損失,應(yīng)該包含全部N個時刻的損失独令。
故RNN模型在t時刻的損失函數(shù)寫成如下形式:
全部N個時刻的損失函數(shù)(全局損失)表達(dá)為如下形式:
需要說明的是:yt是t時刻輸入的真實標(biāo)簽值端朵,ot為模型的預(yù)測值,N代表全部N個時刻燃箭。下文中為了書寫方便逸月,將Loss簡記為L。在結(jié)束本小節(jié)之前遍膜,最后補(bǔ)充一個softmax函數(shù)的求導(dǎo)公式:
BPTT算法
由于RNN模型與時間序列有關(guān)碗硬,因此不能直接使用BP(back propagation)算法。針對RNN問題的特殊情況瓢颅,提出了BPTT算法恩尾。BPTT的全稱是“隨時間變化的反向傳播算法”(back propagation through time)。這個方法的基礎(chǔ)仍然是常規(guī)的鏈?zhǔn)角髮?dǎo)法則挽懦,接下來開始具體推導(dǎo)翰意。雖然RNN的全局損失是與全部N個時刻有關(guān)的,但為了簡單筆者在推導(dǎo)時只關(guān)注t時刻的損失函數(shù)信柿。
首先求出t時刻下?lián)p失函數(shù)關(guān)于o_t*的微分:
求出損失函數(shù)關(guān)于參數(shù)V的微分:
因此冀偶,全局損失關(guān)于參數(shù)V的微分為:
求出t時刻的損失函數(shù)關(guān)于關(guān)于st*的微分:
求出t時刻的損失函數(shù)關(guān)于s_t-1*的微分:
求出t時刻損失函數(shù)關(guān)于參數(shù)U的偏微分。注意:由于是時間序列模型渔嚷,因此t時刻關(guān)于U的微分與前t-1個時刻都有關(guān)进鸠,在具體計算時可以限定最遠(yuǎn)回溯到前n個時刻,但在推導(dǎo)時需要將前t-1個時刻全部帶入:
因此形病,全局損失關(guān)于U的偏微分為:
求t時刻損失函數(shù)關(guān)于參數(shù)W的偏微分客年,和上面相同的道理,在這里仍然要計算全部前t-1時刻的情況:
因此漠吻,全局損失關(guān)于參數(shù)W的微分結(jié)果為:
至此量瓜,全局損失函數(shù)關(guān)于三個主要參數(shù)的微分都已經(jīng)得到了。整理如下:
接下來進(jìn)一步化簡上述微分表達(dá)式途乃,化簡的主要方向為t時刻的損失函數(shù)關(guān)于ot的微分以及關(guān)于st*的微分绍傲。已知t時刻損失函數(shù)的表達(dá)式,求關(guān)于ot的微分:
softmax函數(shù)求導(dǎo):
因此:
又因為:
且:
有了上面的數(shù)學(xué)推導(dǎo)耍共,我們可以得到全局損失關(guān)于U烫饼,V,W三個參數(shù)的梯度公式:
由于參數(shù)U和W的微分公式不僅僅與t時刻有關(guān)划提,還與前面的t-1個時刻都有關(guān)枫弟,因此無法寫出直接的計算公式。不過上面已經(jīng)給出了t時刻的損失函數(shù)關(guān)于s_t-1的微分遞推公式鹏往,想來求解這個式子也是十分簡單的淡诗,在這里就不贅述了骇塘。
以上就是關(guān)于BPTT算法的全部數(shù)學(xué)推導(dǎo)。從最終結(jié)果可以看出三個公式的偏微分結(jié)果非常簡單韩容,在具體的優(yōu)化過程中可以直接帶入進(jìn)行計算款违。對于這種優(yōu)化問題來說,最常用的方法就是梯度下降法群凶。針對本文涉及的RNN問題插爹,可以構(gòu)造出三個參數(shù)的梯度更新公式:
依靠上述梯度更新公式就能夠迭代求解三個參數(shù),直到三個參數(shù)的值發(fā)生收斂请梢。
后記
這是筆者第一次嘗試推導(dǎo)RNN的數(shù)學(xué)模型赠尾,在推導(dǎo)過程中遇到了非常多的bug。非常感謝互聯(lián)網(wǎng)上的一些公開資料和博客毅弧,給了我非常大的幫助和指引气嫁。接下來筆者將嘗試實現(xiàn)一個單隱層的RNN模型用于實現(xiàn)一個語義預(yù)測模型。