Keras中的stateful LSTM可以說是所有學(xué)習(xí)者的夢魘丐枉,令人混淆的機(jī)制,說明不到位的文檔掘托,中文資料的匱乏瘦锹。
通過此文,旨在幫助有困惑的人理解statefulness這一狀態(tài)闪盔。
警告: 永遠(yuǎn)不要在不熟悉stateful LSTM的情況下使用它
參考目錄:
- 官方文檔
- Stateful LSTM in Keras (必讀圣經(jīng))
- 案例靈感來自此GitHub
- Stateful and Stateless LSTM for Time Series Forecasting with Python (這篇可以看完本文再看)
官方文檔簡介
stateful: Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch.
使 RNN 具有狀態(tài)意味著每批樣品的狀態(tài)將被重新用作下一批樣品的初始狀態(tài)弯院。
注意,此處的狀態(tài)表示的是原論文公式里的c泪掀,h听绳,即LSTM特有的一些記憶參數(shù),并非w權(quán)重异赫。
當(dāng)使用有狀態(tài) RNN 時椅挣,假定:
- 所有的批次都有相同數(shù)量的樣本
- 如果
x1
和x2
是連續(xù)批次的樣本,則x2[i]
是x1[i]
的后續(xù)序列祝辣,對于每個i
贴妻。
要在 RNN 中使用狀態(tài)切油,你需要:
- 通過將
batch_size
參數(shù)傳遞給模型的第一層來顯式指定你正在使用的批大小蝙斜。例如,對于 10 個時間步長的 32 樣本的batch
澎胡,每個時間步長具有 16 個特征孕荠,batch_size = 32
娩鹉。 - 在 RNN 層中設(shè)置
stateful = True
。 - 在調(diào)用
fit()
時指定shuffle = False
稚伍。
重置累積狀態(tài):
- 使用
model.reset_states()
來重置模型中所有層的狀態(tài) - 使用
layer.reset_states()
來重置指定有狀態(tài) RNN 層的狀態(tài)
疑問解答:
將一個很長的序列(例如時間序列)分成小序列來構(gòu)建我的輸入矩陣弯予。那LSTM網(wǎng)絡(luò)會發(fā)現(xiàn)我這些小序列之間的關(guān)聯(lián)依賴嗎?
不會个曙,除非你使用 stateful LSTM 锈嫩。大多數(shù)問題使用stateless LSTM即可解決,所以如果你想使用stateful LSTM垦搬,請確保自己是真的需要它呼寸。在stateless時,長期記憶網(wǎng)絡(luò)并不意味著你的LSTM將記住之前batch
的內(nèi)容猴贰。在Keras中stateless LSTM中的stateless指的是?
注意对雪,此文所說的stateful是指的在Keras中特有的,是batch之間的記憶cell狀態(tài)傳遞米绕。而非說的是LSTM論文模型中表示那些記憶門瑟捣,遺忘門,c栅干,h
等等在同一sequence
中不同timesteps
時間步之間的狀態(tài)傳遞迈套。
假定我們的輸入X
是一個三維矩陣,shape = (nb_samples, timesteps, input_dim)
碱鳞,每一個row
代表一個sample
交汤,每個sample
都是一個sequence
小序列。X[i]
表示輸入矩陣中第i
個sample
劫笙。步長啥的我們先不用管芙扎。
當(dāng)我們在默認(rèn)狀態(tài)stateless
下,Keras會在訓(xùn)練每個sequence小序列(=sample)開始時填大,將LSTM網(wǎng)絡(luò)中的記憶狀態(tài)參數(shù)reset初始化(指的是c戒洼,h
而并非權(quán)重w
),即調(diào)用model.reset_states()
允华。為啥stateless LSTM每次訓(xùn)練都要初始化記憶參數(shù)?
因為Keras在訓(xùn)練時會默認(rèn)地shuffle samples
圈浇,所以導(dǎo)致sequence
之間的依賴性消失,sample
和sample
之間就沒有時序關(guān)系靴寂,順序被打亂磷蜀,這時記憶參數(shù)在batch
、小序列之間進(jìn)行傳遞就沒意義了百炬,所以Keras要把記憶參數(shù)初始化褐隆。那stateful LSTM到底怎么傳遞記憶參數(shù)?
首先要明確一點剖踊,LSTM作為有記憶的網(wǎng)絡(luò)庶弃,它的有記憶指的是在一個sequence中衫贬,記憶在不同的timesteps中傳播。舉個例子歇攻,就是你有一篇文章X固惯,分解,然后把每個句子作為一個sample訓(xùn)練對象(sequence)缴守,X[i]就代表一句話葬毫,而一句話里的每個word各自代表一個timestep時間步,LSTM的有記憶即指的是在一句話里屡穗,X[i][0]
第一個單詞(時間步)的信息可以被記憶供常,傳遞到第5個單詞(時間步)X[i][5]
中。
而我們突然覺得鸡捐,這還遠(yuǎn)遠(yuǎn)不夠栈暇,因為句子和句子之間沒有任何的記憶啊,假設(shè)文章一共1000句話箍镜,我們想預(yù)測出第1001句是什么源祈,不想丟棄前1000句里的一些時序性特征(stateless時這1000句訓(xùn)練時會被打亂,時序性特征丟失)色迂。那么,stateful LSTM就可以做到歇僧。
在stateful = True
時诈悍,我們要在fit中手動使得shuffle = False
侥钳。隨后,在X[i]
(表示輸入矩陣中第i
個sample
)這個小序列訓(xùn)練完之后苦酱,Keras會將將訓(xùn)練完的記憶參數(shù)傳遞給X[i+bs]
(表示第i+bs個sample),作為其初始的記憶參數(shù)疫萤。bs = batch_size
敢伸。這樣一來,我們的記憶參數(shù)就能順利地在sample
和sample
之間傳遞,X[i+n*bs]
也能知道X[i]
的信息饶辙。
用圖片可以更好地展示,如下圖弃揽,藍(lán)色箭頭就代表了記憶參數(shù)的傳遞矿微,如果
stateful = False
涌矢,則沒有這些藍(lán)色箭頭。
stateful LSTM中為何一定要提供batch_size參數(shù)娜庇?
我們可以發(fā)現(xiàn),記憶參數(shù)(state)是在每個batch
對應(yīng)的位置跳躍著傳播的励负,所以batch_size
參數(shù)至關(guān)重要继榆,在stateful lstm層中必須提供略吨。那stateful時考阱,對權(quán)重參數(shù)w有影響嗎羔砾?
我們上面所說的一切記憶參數(shù)都是LSTM模型的特有記憶c姜凄,h
參數(shù),和權(quán)重參數(shù)w沒有任何關(guān)系董虱。無論是stateful還是stateless愤诱,都是在模型接受一個batch
后淫半,計算每個sequence的輸出,然后平均它們的梯度昏滴,反向傳播更新所有的各種參數(shù)谣殊。
總結(jié)
如果你還是不理解牺弄,沒關(guān)系势告,簡單的說:
- stateful LSTM:能讓模型學(xué)習(xí)到你輸入的samples之間的時序特征培慌,適合一些長序列的預(yù)測,哪個sample在前盒音,那個sample在后對模型是有影響的祥诽。
- stateless LSTM:輸入samples后瓮恭,默認(rèn)就會shuffle屯蹦,可以說是每個sample獨(dú)立登澜,之間無前后關(guān)系,適合輸入一些沒有關(guān)系的樣本购撼。
如果你還是不理解迂求,沒關(guān)系……舉個例子:
stateful LSTM:我想根據(jù)一篇1000句的文章預(yù)測第1001句,每一句是一個sample毫玖。我會選用stateful涩盾,因為這文章里的1000句是有前后關(guān)聯(lián)的,是有時序的特征的砸西,我不想丟棄這個特征芹枷。利用這個時序性能讓第一句的特征傳遞到我們預(yù)測的第1001句鸳慈。(
batch_size = 10
時)stateless LSTM:我想訓(xùn)練LSTM自動寫詩句走芋,我想訓(xùn)練1000首詩翁逞,每一首是一個sample溉仑,我會選用stateless LSTM浊竟,因為這1000首詩是獨(dú)立的振定,不存在關(guān)聯(lián),哪怕打亂它們的順序棚赔,對于模型訓(xùn)練來說也沒區(qū)別丧肴。
實戰(zhàn)
如果感興趣芋浮,可以看看官方的example——lstm_stateful.py纸巷,個人不推薦眶痰,用例繁瑣存哲,還沒畫圖七婴,講的不清楚修肠。
本實戰(zhàn)代碼地址:GitHub
具體代碼里面可以自己看户盯,我就不多說細(xì)節(jié)了嵌施,這里主要來展示下結(jié)果。
目標(biāo):
很簡單先舷,就是用LSTM去預(yù)測一個cos曲線艰管。-
訓(xùn)練集:
訓(xùn)練集如下圖:
生產(chǎn)訓(xùn)練集數(shù)據(jù):
類似滑動窗口,假設(shè)我們有1000組數(shù)據(jù)蒋川,若滑動窗口大小為20牲芋,則第i
組數(shù)據(jù)trainX =Y[i:i+20]
, trainY =Y[i+20]
,一共980組訓(xùn)練數(shù)據(jù)。-
普通多層神經(jīng)網(wǎng)絡(luò)預(yù)測結(jié)果:
-
stateless LSTM預(yù)測結(jié)果:
-
單層Stateful LSTM預(yù)測結(jié)果:
-
雙層stacked Stateful LSTM預(yù)測結(jié)果:
- 注意:訓(xùn)練存在不穩(wěn)定性捺球,若預(yù)測結(jié)果偏差過大缸浦,請重新訓(xùn)練氮兵。另外弥姻,不要迷信GPU薪缆,LSTM用CPU訓(xùn)練效率可能更高。