RNN
現(xiàn)在我們的數(shù)據(jù)樣本是這樣的序列的形式:
序列中的每個(gè)元素都不是獨(dú)立的,而是和其他元素存在著一定聯(lián)系虹蓄,例如一個(gè)句字就是這種形式,句子中的每個(gè)詞和其他詞是存在上下文關(guān)系的。現(xiàn)在我們要對(duì)這個(gè)句子建模,捕捉它整體的含義艾栋,那我們的模型就必須要考慮這種數(shù)據(jù)結(jié)構(gòu)中元素和元素間的關(guān)聯(lián)切诀。為了處理這樣的序列數(shù)據(jù),RNN應(yīng)運(yùn)而生缨睡,它的一般形式如下:
輸入有兩個(gè),x和隱藏狀態(tài)h陈辱;輸出有兩個(gè)奖年,y和下一個(gè)時(shí)間步的隱藏狀態(tài)h。RNN并不是一次性把整個(gè)序列輸入沛贪,而是每次按時(shí)間步t把序列中元素依次輸入陋守,這樣的做法可以讓模型處理每個(gè)樣本序列長度不一的情況。RNN的關(guān)鍵就是這個(gè)隱藏狀態(tài)h利赋,每個(gè)時(shí)間步結(jié)合上一步的h和當(dāng)前的x更新h水评,當(dāng)前x的信息和前文的信息被整合進(jìn)h,然后傳入下一時(shí)間步媚送,實(shí)現(xiàn)了對(duì)前文的記憶中燥。
通常來說,RNN可以用下面這樣的形式實(shí)現(xiàn)塘偎,將來自上個(gè)時(shí)間步的h和x做拼接然后用全連接層映射疗涉。注意h到y(tǒng)的映射不是必須的,有的實(shí)現(xiàn)中直接用h作為y吟秩,這一點(diǎn)很好理解咱扣,你確實(shí)需要的話,對(duì)h用一層全連接映射到y(tǒng)就是了涵防,這種額外的操作是容易實(shí)現(xiàn)的偏窝。
盡管RNN被設(shè)計(jì)用來記憶信息,但它的記憶力實(shí)在是不怎么樣武学,觀察RNN的輸出祭往,它經(jīng)過激活函數(shù)映射到(0,1),在多步運(yùn)算后數(shù)值衰減很快火窒,也就是學(xué)習(xí)的信息在幾步運(yùn)算后就趨于零了硼补。那我們不用激活函數(shù)呢?且不提激活函數(shù)引入非線性性的事熏矿,激活函數(shù)可以把輸出映射到一定的區(qū)間范圍內(nèi)已骇,假設(shè)撇掉x,計(jì)算h時(shí)沒有激活函數(shù)票编,h在多步運(yùn)算后像下面這樣褪储,若w的特征值小于1,h100幾乎全為0慧域,w特征值大于1鲤竹,h100數(shù)值爆炸。
從反向傳播的角度來看RNN也是不行昔榴,同一層神經(jīng)元在求導(dǎo)時(shí)共享權(quán)重矩陣辛藻,多個(gè)時(shí)間步的計(jì)算后碘橘,梯度從后向前一通鏈?zhǔn)椒▌t連乘過來,求導(dǎo)時(shí)雅可比矩陣特征值小于1吱肌,梯度消失痘拆,特征值大于1,梯度爆炸氮墨。梯度爆炸還可以裁剪梯度纺蛆,梯度消失就很頭疼了。
綜上所述规揪,RNN無法很好地處理長程依賴桥氏。
LSTM
為了解決RNN在長序列訓(xùn)練過程中的梯度消失和梯度爆炸問題,LSTM誕生了粒褒。LSTM仍然是一種RNN,不過在計(jì)算的時(shí)候增加了一些額外的操作诚镰。RNN只有一個(gè)傳遞狀態(tài)h奕坟,LSTM有兩個(gè)傳遞狀態(tài)h和c,RNN中的h對(duì)應(yīng)LSTM中的c清笨。LSTM中c改變慢月杉,通常是上一個(gè)狀態(tài)傳遞過來加上一些數(shù)值;h在不同節(jié)點(diǎn)下改變明顯抠艾。
來看LSTM具體的計(jì)算苛萎,同樣先將上一個(gè)時(shí)間步的h和當(dāng)前輸入x拼接,但用全連接層計(jì)算出四個(gè)狀態(tài):
后面三個(gè)介于0和1之間的狀態(tài)作為一種門控狀態(tài)检号。它們的作用類似于權(quán)重系數(shù)腌歉,控制信息通過或屏蔽的程度。這些門控狀態(tài)全是x和h拼接到一起計(jì)算得到的齐苛,不是單獨(dú)看x或者h(yuǎn)翘盖,因?yàn)槟闶歉鶕?jù)當(dāng)前的輸入和前面的歷史信息一起來決定哪些重要哪些不重要的。然后凹蜂,LSTM的計(jì)算分為幾個(gè)階段:
忘記階段:用zf控制上一個(gè)傳遞狀態(tài)c哪些需要留下哪些需要忘記馍驯。
-
選擇記憶階段:用zi控制選擇,對(duì)當(dāng)前輸入x選擇記憶玛痊,哪些重要哪些不重要汰瘫。在此之前還有一件事,先把x處理成z擂煞,為啥要這么做混弥?因?yàn)槟愕冒褁弄成和zi一樣的形狀然后再選擇。
此時(shí)对省,傳遞狀態(tài)和輸入都被處理了剑逃,你可以來更新傳遞狀態(tài)了浙宜。
輸出階段:用zo控制,哪些被當(dāng)前狀態(tài)輸出蛹磺,以及用tanh縮放粟瞬。
用黑話翻譯一下上面的過程(⊙是元素級(jí)乘法):
同樣的,h到y(tǒng)的映射不是必須的萤捆,有的實(shí)現(xiàn)中直接用h作為y裙品,Pytorch中就是如此。Pytorch中LSTM的具體計(jì)算實(shí)現(xiàn)如下:
這里x和h分別和對(duì)應(yīng)的w相乘在相加俗或,實(shí)際上和上面寫的x和h拼接后再乘w的形式是一樣的市怎,但注意Pytorch每個(gè)狀態(tài)的計(jì)算用了兩個(gè)偏置項(xiàng)。所以如下形式LSTM的參數(shù)量為:
lstm = nn.LSTM(input_size=32, hidden_size=16, num_layers=2, batch_first=True, bidirectional=False)
Pytorch中LSTM輸入是(input, (h_0, c_0))
的形式辛慰,(h_0, c_0)
默認(rèn)設(shè)置為零向量区匠。輸出是(output, (h_n, c_n))
的形式,當(dāng)有多層LSTM時(shí)帅腌,這里output
是最后一層LSTM每個(gè)時(shí)間步的h的集合[h_1,h_2,...,h_n]
驰弄,(h_n, c_n)
是每層LSTM最后一個(gè)時(shí)間步的h和c。通過下面這段簡單的代碼可以理解pytorch中LSTM的輸入輸出和參數(shù)量:
import torch
import torch.nn as nn
x = torch.randn((3, 5, 32)) # [n_seq, seq_length, n_feature]
lstm = nn.LSTM(input_size=32, hidden_size=16, num_layers=2, batch_first=True, bidirectional=False)
(o, (h, c)) = lstm(x, )
print(sum(p.numel() for p in lstm.parameters()))
print(o.size()) # [n_seq, seq_length, hidden_size*n_direction]
print(h.size()) # [n_layer*n_direction, n_seq, hidden_size]
print(c.size()) # [n_layer*n_direction, n_seq, hidden_size]
5376
torch.Size([3, 5, 16])
torch.Size([2, 3, 16])
torch.Size([2, 3, 16])
雙向RNN
我們可以分別構(gòu)建兩個(gè)RNN分別從左向右速客,和從右向左戚篙,各自輸出自己的狀態(tài)向量,然后拼接起來:
雙向RNN比單向RNN表現(xiàn)好的原因:
減輕對(duì)前面記憶的遺忘:從左向右的RNN輸出的h可能遺忘掉左端的信息溺职,從右向左的RNN輸出的h可能遺忘掉右端的信息岔擂,把兩者結(jié)合到一起就能不足對(duì)方遺忘的信息。
補(bǔ)足后文的信息:對(duì)序列中某個(gè)元素的理解可能不僅僅依靠前面的信息浪耘,也需要借助后文乱灵,所以單方向的移動(dòng)可能是不夠的。
在Pytorch中實(shí)現(xiàn)雙向的LSTM只需要設(shè)置參數(shù)bidirectional=True
即可:
import torch
import torch.nn as nn
x = torch.randn((3, 5, 32)) # [n_seq, seq_length, n_feature]
lstm = nn.LSTM(input_size=32, hidden_size=16, num_layers=2, batch_first=True, bidirectional=True)
(o, (h, c)) = lstm(x, )
print(sum(p.numel() for p in lstm.parameters()))
print(o.size()) # [n_seq, seq_length, hidden_size*n_direction]
print(h.size()) # [n_layer*n_direction, n_seq, hidden_size]
print(c.size()) # [n_layer*n_direction, n_seq, hidden_size]
12800
torch.Size([3, 5, 32])
torch.Size([4, 3, 16])
torch.Size([4, 3, 16])