關(guān)于RNN和LSTM的理解

REFER
人人都能看懂的LSTM
RNN模型與NLP應(yīng)用
torch.nn.LSTM

RNN

現(xiàn)在我們的數(shù)據(jù)樣本是這樣的序列的形式:
[x_1,x_2,...,x_t,..., x_n]
序列中的每個(gè)元素都不是獨(dú)立的,而是和其他元素存在著一定聯(lián)系虹蓄,例如一個(gè)句字就是這種形式,句子中的每個(gè)詞和其他詞是存在上下文關(guān)系的。現(xiàn)在我們要對(duì)這個(gè)句子建模,捕捉它整體的含義艾栋,那我們的模型就必須要考慮這種數(shù)據(jù)結(jié)構(gòu)中元素和元素間的關(guān)聯(lián)切诀。為了處理這樣的序列數(shù)據(jù),RNN應(yīng)運(yùn)而生缨睡,它的一般形式如下:
y_t, h_{t} = f(x_t, h_{t-1})
輸入有兩個(gè),x和隱藏狀態(tài)h陈辱;輸出有兩個(gè)奖年,y和下一個(gè)時(shí)間步的隱藏狀態(tài)h。RNN并不是一次性把整個(gè)序列輸入沛贪,而是每次按時(shí)間步t把序列中元素依次輸入陋守,這樣的做法可以讓模型處理每個(gè)樣本序列長度不一的情況。RNN的關(guān)鍵就是這個(gè)隱藏狀態(tài)h利赋,每個(gè)時(shí)間步結(jié)合上一步的h和當(dāng)前的x更新h水评,當(dāng)前x的信息和前文的信息被整合進(jìn)h,然后傳入下一時(shí)間步媚送,實(shí)現(xiàn)了對(duì)前文的記憶中燥。

通常來說,RNN可以用下面這樣的形式實(shí)現(xiàn)塘偎,將來自上個(gè)時(shí)間步的h和x做拼接然后用全連接層映射疗涉。注意h到y(tǒng)的映射不是必須的,有的實(shí)現(xiàn)中直接用h作為y吟秩,這一點(diǎn)很好理解咱扣,你確實(shí)需要的話,對(duì)h用一層全連接映射到y(tǒng)就是了涵防,這種額外的操作是容易實(shí)現(xiàn)的偏窝。

h_t = \sigma(W_a [h_{t-1}, x_t] + b) \\ y_t = \sigma(W_b h_t + b)

盡管RNN被設(shè)計(jì)用來記憶信息,但它的記憶力實(shí)在是不怎么樣武学,觀察RNN的輸出祭往,它經(jīng)過激活函數(shù)映射到(0,1),在多步運(yùn)算后數(shù)值衰減很快火窒,也就是學(xué)習(xí)的信息在幾步運(yùn)算后就趨于零了硼补。那我們不用激活函數(shù)呢?且不提激活函數(shù)引入非線性性的事熏矿,激活函數(shù)可以把輸出映射到一定的區(qū)間范圍內(nèi)已骇,假設(shè)撇掉x,計(jì)算h時(shí)沒有激活函數(shù)票编,h在多步運(yùn)算后像下面這樣褪储,若w的特征值小于1,h100幾乎全為0慧域,w特征值大于1鲤竹,h100數(shù)值爆炸。
h_{100} = Wh_{99} = W^{2}h_{98} ... = W^{100}h_0
從反向傳播的角度來看RNN也是不行昔榴,同一層神經(jīng)元在求導(dǎo)時(shí)共享權(quán)重矩陣辛藻,多個(gè)時(shí)間步的計(jì)算后碘橘,梯度從后向前一通鏈?zhǔn)椒▌t連乘過來,求導(dǎo)時(shí)雅可比矩陣特征值小于1吱肌,梯度消失痘拆,特征值大于1,梯度爆炸氮墨。梯度爆炸還可以裁剪梯度纺蛆,梯度消失就很頭疼了。

綜上所述规揪,RNN無法很好地處理長程依賴桥氏。

LSTM

為了解決RNN在長序列訓(xùn)練過程中的梯度消失和梯度爆炸問題,LSTM誕生了粒褒。LSTM仍然是一種RNN,不過在計(jì)算的時(shí)候增加了一些額外的操作诚镰。RNN只有一個(gè)傳遞狀態(tài)h奕坟,LSTM有兩個(gè)傳遞狀態(tài)h和c,RNN中的h對(duì)應(yīng)LSTM中的c清笨。LSTM中c改變慢月杉,通常是上一個(gè)狀態(tài)傳遞過來加上一些數(shù)值;h在不同節(jié)點(diǎn)下改變明顯抠艾。
y_t,(h_t,c_t) = f(x_t, (h_{t-1},c_{t-1}))
來看LSTM具體的計(jì)算苛萎,同樣先將上一個(gè)時(shí)間步的h和當(dāng)前輸入x拼接,但用全連接層計(jì)算出四個(gè)狀態(tài):
\begin{cases} z = tanh(W_z [x_t,h_{t-1}] + b). \quad (-1,1) \\ z_i = sigmoid(W_i [x_t, h_{t-1}] + b) \quad (0,1) \\ z_f = sigmoid(W_f [x_t, h_{t-1}] + b) \quad (0,1) \\ z_o = sigmoid(W_o [x_t, h_{t-1}] + b) \quad (0,1) \\ \end{cases}
后面三個(gè)介于0和1之間的狀態(tài)作為一種門控狀態(tài)检号。它們的作用類似于權(quán)重系數(shù)腌歉,控制信息通過或屏蔽的程度。這些門控狀態(tài)全是x和h拼接到一起計(jì)算得到的齐苛,不是單獨(dú)看x或者h(yuǎn)翘盖,因?yàn)槟闶歉鶕?jù)當(dāng)前的輸入和前面的歷史信息一起來決定哪些重要哪些不重要的。然后凹蜂,LSTM的計(jì)算分為幾個(gè)階段:

  • 忘記階段:用zf控制上一個(gè)傳遞狀態(tài)c哪些需要留下哪些需要忘記馍驯。

  • 選擇記憶階段:用zi控制選擇,對(duì)當(dāng)前輸入x選擇記憶玛痊,哪些重要哪些不重要汰瘫。在此之前還有一件事,先把x處理成z擂煞,為啥要這么做混弥?因?yàn)槟愕冒褁弄成和zi一樣的形狀然后再選擇。

    此時(shí)对省,傳遞狀態(tài)和輸入都被處理了剑逃,你可以來更新傳遞狀態(tài)了浙宜。

  • 輸出階段:用zo控制,哪些被當(dāng)前狀態(tài)輸出蛹磺,以及用tanh縮放粟瞬。

用黑話翻譯一下上面的過程(⊙是元素級(jí)乘法):
\begin{cases} c_t = z_f \odot c_{t-1} + z_i \odot z \\ h_t = z_o \odot tanh(c_t) \\ y_t = sigmoid(W h_t + b) \end{cases}
同樣的,h到y(tǒng)的映射不是必須的萤捆,有的實(shí)現(xiàn)中直接用h作為y裙品,Pytorch中就是如此。Pytorch中LSTM的具體計(jì)算實(shí)現(xiàn)如下:


這里x和h分別和對(duì)應(yīng)的w相乘在相加俗或,實(shí)際上和上面寫的x和h拼接后再乘w的形式是一樣的市怎,但注意Pytorch每個(gè)狀態(tài)的計(jì)算用了兩個(gè)偏置項(xiàng)。所以如下形式LSTM的參數(shù)量為:
5376 = 第一層LSTM參數(shù)量3200 + 第二層LSTM參數(shù)量2176 \\ = (hidden\_size*input\_size + bias\_size + hidden\_size*hidden\_size+ bias\_size)*4 + ... \\ = (16*32+16+16*16+16)*4 + (16*16+16+16*16+16)*4

lstm = nn.LSTM(input_size=32, hidden_size=16, num_layers=2, batch_first=True, bidirectional=False)

Pytorch中LSTM輸入是(input, (h_0, c_0)) 的形式辛慰,(h_0, c_0)默認(rèn)設(shè)置為零向量区匠。輸出是(output, (h_n, c_n))的形式,當(dāng)有多層LSTM時(shí)帅腌,這里output是最后一層LSTM每個(gè)時(shí)間步的h的集合[h_1,h_2,...,h_n]驰弄,(h_n, c_n)是每層LSTM最后一個(gè)時(shí)間步的h和c。通過下面這段簡單的代碼可以理解pytorch中LSTM的輸入輸出和參數(shù)量:

import torch
import torch.nn as nn
x = torch.randn((3, 5, 32))  # [n_seq, seq_length, n_feature]
lstm = nn.LSTM(input_size=32, hidden_size=16, num_layers=2, batch_first=True, bidirectional=False)
(o, (h, c)) = lstm(x, )
print(sum(p.numel() for p in lstm.parameters()))
print(o.size())  # [n_seq, seq_length, hidden_size*n_direction]
print(h.size())  # [n_layer*n_direction, n_seq, hidden_size]
print(c.size())  # [n_layer*n_direction, n_seq, hidden_size]
5376
torch.Size([3, 5, 16])
torch.Size([2, 3, 16])
torch.Size([2, 3, 16])

雙向RNN

我們可以分別構(gòu)建兩個(gè)RNN分別從左向右速客,和從右向左戚篙,各自輸出自己的狀態(tài)向量,然后拼接起來:



雙向RNN比單向RNN表現(xiàn)好的原因:

  • 減輕對(duì)前面記憶的遺忘:從左向右的RNN輸出的h可能遺忘掉左端的信息溺职,從右向左的RNN輸出的h可能遺忘掉右端的信息岔擂,把兩者結(jié)合到一起就能不足對(duì)方遺忘的信息。

  • 補(bǔ)足后文的信息:對(duì)序列中某個(gè)元素的理解可能不僅僅依靠前面的信息浪耘,也需要借助后文乱灵,所以單方向的移動(dòng)可能是不夠的。

在Pytorch中實(shí)現(xiàn)雙向的LSTM只需要設(shè)置參數(shù)bidirectional=True即可:

import torch
import torch.nn as nn
x = torch.randn((3, 5, 32))  # [n_seq, seq_length, n_feature]
lstm = nn.LSTM(input_size=32, hidden_size=16, num_layers=2, batch_first=True, bidirectional=True)
(o, (h, c)) = lstm(x, )
print(sum(p.numel() for p in lstm.parameters()))
print(o.size())  # [n_seq, seq_length, hidden_size*n_direction]
print(h.size())  # [n_layer*n_direction, n_seq, hidden_size]
print(c.size())  # [n_layer*n_direction, n_seq, hidden_size]
12800
torch.Size([3, 5, 32])
torch.Size([4, 3, 16])
torch.Size([4, 3, 16])

12800 = 第一層LSTM參數(shù)量6400 + 第二層LSTM參數(shù)量6400 \\ = (16*32+16+16*16+16)*4*2 + (16*32+16+16*16+16)*4*2

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末七冲,一起剝皮案震驚了整個(gè)濱河市阔蛉,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌癞埠,老刑警劉巖状原,帶你破解...
    沈念sama閱讀 211,042評(píng)論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異苗踪,居然都是意外死亡颠区,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 89,996評(píng)論 2 384
  • 文/潘曉璐 我一進(jìn)店門通铲,熙熙樓的掌柜王于貴愁眉苦臉地迎上來毕莱,“玉大人,你說我怎么就攤上這事∨蠼兀” “怎么了蛹稍?”我有些...
    開封第一講書人閱讀 156,674評(píng)論 0 345
  • 文/不壞的土叔 我叫張陵,是天一觀的道長部服。 經(jīng)常有香客問我唆姐,道長,這世上最難降的妖魔是什么廓八? 我笑而不...
    開封第一講書人閱讀 56,340評(píng)論 1 283
  • 正文 為了忘掉前任奉芦,我火速辦了婚禮,結(jié)果婚禮上剧蹂,老公的妹妹穿的比我還像新娘声功。我一直安慰自己,他們只是感情好宠叼,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,404評(píng)論 5 384
  • 文/花漫 我一把揭開白布先巴。 她就那樣靜靜地躺著,像睡著了一般冒冬。 火紅的嫁衣襯著肌膚如雪伸蚯。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 49,749評(píng)論 1 289
  • 那天窄驹,我揣著相機(jī)與錄音朝卒,去河邊找鬼证逻。 笑死乐埠,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的囚企。 我是一名探鬼主播丈咐,決...
    沈念sama閱讀 38,902評(píng)論 3 405
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢(mèng)啊……” “哼龙宏!你這毒婦竟也來了棵逊?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 37,662評(píng)論 0 266
  • 序言:老撾萬榮一對(duì)情侶失蹤银酗,失蹤者是張志新(化名)和其女友劉穎辆影,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體黍特,經(jīng)...
    沈念sama閱讀 44,110評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡蛙讥,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,451評(píng)論 2 325
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了灭衷。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片次慢。...
    茶點(diǎn)故事閱讀 38,577評(píng)論 1 340
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出迫像,到底是詐尸還是另有隱情劈愚,我是刑警寧澤,帶...
    沈念sama閱讀 34,258評(píng)論 4 328
  • 正文 年R本政府宣布闻妓,位于F島的核電站菌羽,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏纷闺。R本人自食惡果不足惜算凿,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,848評(píng)論 3 312
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望犁功。 院中可真熱鬧氓轰,春花似錦、人聲如沸浸卦。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,726評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽限嫌。三九已至靴庆,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間怒医,已是汗流浹背炉抒。 一陣腳步聲響...
    開封第一講書人閱讀 31,952評(píng)論 1 264
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留稚叹,地道東北人焰薄。 一個(gè)月前我還...
    沈念sama閱讀 46,271評(píng)論 2 360
  • 正文 我出身青樓,卻偏偏與公主長得像扒袖,于是被迫代替她去往敵國和親塞茅。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,452評(píng)論 2 348

推薦閱讀更多精彩內(nèi)容

  • 先來一篇簡單易懂的激活函數(shù)理解解解乏 RNN RNN(Recurrent Neural Networks,循環(huán)神經(jīng)...
    mugtmag閱讀 25,707評(píng)論 2 14
  • 1.1 認(rèn)識(shí)RNN模型 學(xué)習(xí)目標(biāo) 了解什么是RNN模型. 了解RNN模型的作用. 了解RNN模型的分類. 什么是R...
    遲耿耿閱讀 1,296評(píng)論 0 0
  • 1.RNN解決了什么問題? RNN主要用來解決序列問題飒泻,強(qiáng)調(diào)的是先后順序鞭光,在NLP中引申出上下文的概念,一個(gè)翻譯問...
    sudop閱讀 11,205評(píng)論 0 8
  • 總說 RNN( Recurrent Neural Network 循環(huán)(遞歸)神經(jīng)網(wǎng)絡(luò)) 跟人的大腦記憶差不多。我...
    城市中迷途小書童閱讀 1,058評(píng)論 0 9
  • 16宿命:用概率思維提高你的勝算 以前的我是風(fēng)險(xiǎn)厭惡者刹孔,不喜歡去冒險(xiǎn)啡省,但是人生放棄了冒險(xiǎn)娜睛,也就放棄了無數(shù)的可能。 ...
    yichen大刀閱讀 6,038評(píng)論 0 4