hands on machine learning with scikit-learn and tensorflow
reading note
CHAPTER 14: RNN
分析任意長度的序列化(sequences)數據
循環(huán)神經元
在傳統(tǒng)神經元的基礎上, 添加了上階段輸出作為本階段的一個輸入
因此, 循環(huán)神經元的輸入包含兩部分, 1)特征輸入; 2)上階段的輸出
循環(huán)神經元數據的計算公式如下(也就是多了一份輸入)
當由一層循環(huán)神經元構成時, 輸出的y為一個向量(一個神經元對應一個輸出)
其中, 黃色的框為一個cell(單元)
對應的計算公式為
注1: RNN輸入X的形式
形象一點的輸入形式
在訓練模型時, 既是對參數的更新:
- 使用一個或多個樣本去更新參數(SGD, MB-SGD)
- 每個樣本包含多個時間的輸入
- 每個時間的輸入為一個向量(Word Enbedding)
注2: NN, LR與樹模型的區(qū)別
- NN和LR模型的結構是確定的, 通過樣本(一個或多個)去更新模型的參數, 來訓練模型
- 樹模型的結構是不確定的, 因此需要全部的樣本來確定結構
輸入與輸出序列
- seq2seq(輸入為序列, 輸出為序列)
- 時間序列的預測, 序列生成
- 機器翻譯(Encoder2Decoder)
- 語音2文字
- seq2vec(輸入為序列, 輸出為向量(僅保留了最后的一個(輸出)狀態(tài)))
- 分類任務(音樂分類, 情感分類)
- 預測用戶下次可能觀看的電影(協(xié)同過濾)
- vec2seq(輸入為向量, 輸出為序列)
- 給圖片添加描述
- 輸入歌手, 創(chuàng)建播放列表
- encoder2decoder(輸入為序列, 輸出為序列)
- seq2seq的一種特殊形式
- 機器翻譯
訓練RNN模型
RNN的訓練技巧(BPTT)
- 按時間展開
- 反向傳播
虛線為正向預測過程, 實線為反向訓練過程(反向傳播), 每個時刻向損失函數的負梯度更新模型, 比如計算了$Y_{(2)}$的梯度, 只更新$Y_{(2)}$, 不會更新$Y_{(1)}, Y_{(0)}$; 另外, $W, b$在每個階段都是一致的(參數共享, 這也是梯度爆炸和梯度消失的原因)
參考: RNN訓練詳解
當RNN用作分類時, 直接輸出最后一個的狀態(tài)向量, 然后連接一個全連接層, 轉換為一個普通的NN模型
RNN對于長序列的訓練困難
問題: 當序列較長時, 會出現
- 訓練時間長(收斂困難)
- 序列越長, 相當于展開的RNN更深, 又因為RNN權值共享, 因此容易造成梯度爆炸/消失問題
- 解決方法: 限制序列長度, 但會丟失長期記憶
- 序列越長, 相當于展開的RNN更深, 又因為RNN權值共享, 因此容易造成梯度爆炸/消失問題
- 長期記憶退化, 僅保留了短期記憶
引出: 如何保存長期記憶? -> LSTM, GRU
[1] 訓練時間慢: 1. 初始化參數方法; 2. 不飽和的激活函數; 3. BN; 4. 梯度修建; 5. 更快的優(yōu)化方法.
LSTM Cell
LSTM對RNN的提升: 收斂更快, 能夠檢測出長期依賴信息
LSTM的關鍵思想: 網絡有能力學習到哪些長期信息應該被丟棄, 哪些應該被記憶
LSTM管理了兩條狀態(tài)向量, 一條為長期記憶, 一條為短期記憶
LSTM包含了四個全連接層(一個輸出, 三條控制), 三個門(遺忘門, 輸入門, 輸出門), 兩條狀態(tài)向量(長期記憶, 短期記憶)
[注1] 這里可以對logistic和tanh兩個激活函數的作用做一個思考:
- logistic的取值范圍[0, 1], 用于gate(控制), 控制輸出的量, 相當于過濾
- tanh的取值范圍[-1, 1], 用于計算, 計算出輸出的值
GRU Cell
LSTM的精簡版本
- 合并兩條狀態(tài)向量為一條狀態(tài)向量
- 合并了遺忘門和輸入門的計算(遺忘與輸入的對立)
- 沒有輸出門, 但是多了一個對狀態(tài)輸入進行過濾的門
[注] 書上的最后一個公式存在錯誤, 付上的為修改后的公式
Word Embeddings
降低維度, 使相同的詞語有相同的表示, 表示具有泛化能力, 有距離的性質
機器翻譯過程(Encoder-Decoder 網絡)
訓練過程:
- 對單詞進行Embedding, 轉換每個單詞為向量
- 訓練時輸入包含兩個部分, 一個是原始的輸入, 一個是翻譯的輸入
- 翻譯的輸入比原始的輸入延后一步
- 原始的輸入逆序(這里并不是絕對的, 逆序輸入是為了讓翻譯有總結的能力, 雙向的網絡則有更多的信息)
- 使用softmax計算概率, 選擇概率最高的詞語(因此這里有一個詞語個數的問題, 詞語過多會造成計算復雜(解決方法, 抽樣))
測試過程: 不再有翻譯的輸入