原文鏈接:http://www.cnblogs.com/sandy-t/p/6930608.html
循環(huán)神經(jīng)網(wǎng)絡(luò)RNN相比傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)在處理序列化數(shù)據(jù)時更有優(yōu)勢,因為RNN能夠?qū)⒓尤肷希ㄏ拢┪男畔⑦M行考慮狮腿。一個簡單的RNN如下圖所示:
將這個循環(huán)展開得到下圖:
上一時刻的狀態(tài)會傳遞到下一時刻钮糖。這種鏈?zhǔn)教匦詻Q定了RNN能夠很好的處理序列化的數(shù)據(jù)滚局,RNN 在語音識別坏为,語言建模,翻譯吹艇,圖片描述等問題上已經(jīng)取得了很到的結(jié)果麦牺。
根據(jù)輸入钮蛛、輸出的不同和是否有延遲等一些情況,RNN在應(yīng)用中有如下一些形態(tài):
RNN存在的問題
RNN能夠把狀態(tài)傳遞到下一時刻剖膳,好像對一部分信息有記憶能力一樣魏颓,如下圖:
h3h3的值可能會由x1x1,x2x2的值來決定。
但是吱晒,對于一些復(fù)雜場景
由于距離太遠(yuǎn)甸饱,中間間隔了太多狀態(tài),x1x1,x2x2對ht+1ht+1的值幾乎起不到任何作用仑濒。(梯度消失和梯度爆炸)
LSTM(Long Short Term Memory)
由于RNN不能很好地處理這種問題叹话,于是出現(xiàn)了LSTM(Long Short Term Memory)一種加強版的RNN(LSTM可以改善梯度消失問題)。簡單來說就是原始RNN沒有長期的記憶能力墩瞳,于是就給RNN加上了一些記憶控制器驼壶,實現(xiàn)對某些信息能夠較長期的記憶,而對某些信息只有短期記憶能力喉酌。
如上圖所示热凹,LSTM中存在Forget Gate,Input Gate,Output Gate來控制信息的流動程度泵喘。
RNN:
LSTN:
加號圓圈表示線性相加,乘號圓圈表示用gate來過濾信息般妙。
Understanding LSTM中對LSTM有非常詳細(xì)的介紹纪铺。(對應(yīng)的中文翻譯)
LSTM MNIST手寫數(shù)字辨識
實際上,圖片文字識別這類任務(wù)用CNN來做效果更好碟渺,但是這里想要強行用LSTM來做一波鲜锚。
MNIST_data中每一個image的大小是28*28,以行順序作為序列輸入苫拍,即第一行的28個像素作為$x_{0}
烹棉,第二行為,第二行為x_1怯疤,...,第28行的28個像素作為催束,...集峦,第28行的28個像素作為x_28$輸入,一個網(wǎng)絡(luò)結(jié)構(gòu)總共的輸入是28個維度為28的向量抠刺,輸出值是10維的向量塔淤,表示的是0-9個數(shù)字的概率值。這是一個many to one的RNN結(jié)構(gòu)速妖。
下面直接上代碼:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# 參數(shù)設(shè)置
BATCH_SIZE = 100? ? ? ? # BATCH的大小高蜂,相當(dāng)于一次處理50個image
TIME_STEP = 28? ? ? ? ? # 一個LSTM中,輸入序列的長度罕容,image有28行
INPUT_SIZE = 28? ? ? ? # x_i 的向量長度备恤,image有28列
LR = 0.01? ? ? ? ? ? ? # 學(xué)習(xí)率
NUM_UNITS = 100? ? ? ? # 多少個LTSM單元
ITERATIONS=8000? ? ? ? # 迭代次數(shù)
N_CLASSES=10? ? ? ? ? ? # 輸出大小,0-9十個數(shù)字的概率
# 定義 placeholders 以便接收x,y
train_x = tf.placeholder(tf.float32, [None, TIME_STEP * INPUT_SIZE])? ? ? # 維度是[BATCH_SIZE锦秒,TIME_STEP * INPUT_SIZE]
image = tf.reshape(train_x, [-1, TIME_STEP, INPUT_SIZE])? ? ? ? ? ? ? ? ? # 輸入的是二維數(shù)據(jù)露泊,將其還原為三維,維度是[BATCH_SIZE, TIME_STEP, INPUT_SIZE]
train_y = tf.placeholder(tf.int32, [None, N_CLASSES])? ? ? ? ? ? ? ? ? ?
# 定義RNN(LSTM)結(jié)構(gòu)
rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=NUM_UNITS)
outputs,final_state = tf.nn.dynamic_rnn(
? ? cell=rnn_cell,? ? ? ? ? ? ? # 選擇傳入的cell
? ? inputs=image,? ? ? ? ? ? ? # 傳入的數(shù)據(jù)
? ? initial_state=None,? ? ? ? # 初始狀態(tài)
? ? dtype=tf.float32,? ? ? ? ? # 數(shù)據(jù)類型
? ? time_major=False,? ? ? ? ? # False: (batch, time step, input); True: (time step, batch, input)旅择,這里根據(jù)image結(jié)構(gòu)選擇False
)
output = tf.layers.dense(inputs=outputs[:, -1, :], units=N_CLASSES)? ? ?
這里outputs,final_state = tf.nn.dynamic_rnn(...).
final_state包含兩個量惭笑,第一個為c保存了每個LSTM任務(wù)最后一個cell中每個神經(jīng)元的狀態(tài)值,第二個量h保存了每個LSTM任務(wù)最后一個cell中每個神經(jīng)元的輸出值生真,所以c和h的維度都是[BATCH_SIZE,NUM_UNITS]沉噩。
outputs的維度是[BATCH_SIZE,TIME_STEP,NUM_UNITS],保存了每個step中cell的輸出值h。
由于這里是一個many to one的任務(wù)柱蟀,只需要最后一個step的輸出outputs[:, -1, :]川蒙,output = tf.layers.dense(inputs=outputs[:, -1, :], units=N_CLASSES) 通過一個全連接層將輸出限制為N_CLASSES。
loss = tf.losses.softmax_cross_entropy(onehot_labels=train_y, logits=output) # 計算loss
train_op = tf.train.AdamOptimizer(LR).minimize(loss)? ? ? #選擇優(yōu)化方法
correct_prediction = tf.equal(tf.argmax(train_y, axis=1),tf.argmax(output, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,'float'))? #計算正確率
sess = tf.Session()
sess.run(tf.global_variables_initializer())? ? # 初始化計算圖中的變量
for step in range(ITERATIONS):? ? # 開始訓(xùn)練
? ? x, y = mnist.train.next_batch(BATCH_SIZE)?
? ? test_x, test_y = mnist.test.next_batch(5000)
? ? _, loss_ = sess.run([train_op, loss], {train_x: x, train_y: y})
? ? if step % 500 == 0:? ? ? # test(validation)
? ? ? ? accuracy_ = sess.run(accuracy, {train_x: test_x, train_y: test_y})
? ? ? ? print('train loss: %.4f' % loss_, '| test accuracy: %.2f' % accuracy_)
訓(xùn)練過程輸出:
train loss: 2.2990 | test accuracy: 0.13
train loss: 0.1347 | test accuracy: 0.96
train loss: 0.0620 | test accuracy: 0.97
train loss: 0.0788 | test accuracy: 0.98
train loss: 0.0160 | test accuracy: 0.98
train loss: 0.0084 | test accuracy: 0.99
train loss: 0.0436 | test accuracy: 0.99
train loss: 0.0104 | test accuracy: 0.98
train loss: 0.0736 | test accuracy: 0.99
train loss: 0.0154 | test accuracy: 0.98
train loss: 0.0407 | test accuracy: 0.98
train loss: 0.0109 | test accuracy: 0.98
train loss: 0.0722 | test accuracy: 0.98
train loss: 0.1133 | test accuracy: 0.98
train loss: 0.0072 | test accuracy: 0.99
train loss: 0.0352 | test accuracy: 0.98
可以看到长已,雖然RNN是擅長處理序列類的任務(wù)派歌,在MNIST手寫數(shù)字圖片辨識這個任務(wù)上弯囊,RNN同樣可以取得很高的正確率。
參考:
http://colah.github.io/posts/2015-08-Understanding-LSTMs/
https://yjango.gitbooks.io/superorganism/content/lstmgru.html
參考代碼
https://yjango.gitbooks.io/superorganism/content/lstmgru.html