如何使用雙向RNN
在《深度學(xué)習(xí)之TensorFlow入門演痒、原理與進階實戰(zhàn)》一書的9.4.2中的第4小節(jié)中,介紹過變長動態(tài)RNN的實現(xiàn)。這里在來延伸的講解一下雙向動態(tài)rnn在處理變長序列時的應(yīng)用。其實雙向RNN的使用中铅匹,有一個隱含的注意事項,非常容易犯錯饺藤。?
本文就在介紹下雙向RNN的常用函數(shù)包斑、用法及注意事項。
動態(tài)雙向rnn有兩個函數(shù):
stack_bidirectional_dynamic_rnn
bidirectional_dynamic_rnn?
二者的實現(xiàn)上大同小異涕俗,放置的位置也不一樣罗丰,前者放在contrib下面,而后者顯得更加根紅苗正再姑,放在了tf的核心庫下面萌抵。在使用時二者的返回值也有所區(qū)別。下面就來一一介紹元镀。
先以GRU的cell代碼為例:
import tensorflow as tf
import numpy as np
tf.reset_default_graph()# 創(chuàng)建輸入數(shù)據(jù)X = np.random.randn(2, 4, 5)# 批次 绍填、序列長度、樣本維度# 第二個樣本長度為3X[1,2:] = 0seq_lengths = [4, 2]
Gstacked_rnn = []
Gstacked_bw_rnn = []
for i in range(3):
? ? Gstacked_rnn.append(tf.contrib.rnn.GRUCell(3))
? ? Gstacked_bw_rnn.append(tf.contrib.rnn.GRUCell(3))#建立前向和后向的三層RNNGmcell = tf.contrib.rnn.MultiRNNCell(Gstacked_rnn)
Gmcell_bw = tf.contrib.rnn.MultiRNNCell(Gstacked_bw_rnn)
sGbioutputs, sGoutput_state_fw, sGoutput_state_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn([Gmcell],[Gmcell_bw], X,sequence_length=seq_lengths,
dtype=tf.float64)
Gbioutputs, Goutput_state_fw = tf.nn.bidirectional_dynamic_rnn(Gmcell,Gmcell_bw, X,sequence_length=seq_lengths,dtype=tf.float64)
是創(chuàng)建雙向RNN的方法示例栖疑√钟溃可以看到帶有stack的雙向RNN會輸出3個返回值,而不帶有stack的雙向RNN會輸出2個返回值遇革。?
這里面還要注意的是卿闹,在沒有未cell初始化時必須要將dtype參數(shù)賦值。不然會報錯萝快。
下面添加代碼锻霎,將輸出的值打印出來,看一下揪漩,這兩個函數(shù)到底是輸出的是啥量窘?
#建立一個會話sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sgbresult,sgstate_fw,sgstate_bw=sess.run([sGbioutputs,sGoutput_state_fw,sGoutput_state_bw])
print("全序列:\n", sgbresult[0])
print("短序列:\n", sgbresult[1])
print('Gru的狀態(tài):',len(sgstate_fw[0]),'\n',sgstate_fw[0][0],'\n',sgstate_fw[0][1],'\n',sgstate_fw[0][2])
print('Gru的狀態(tài):',len(sgstate_bw[0]),'\n',sgstate_bw[0][0],'\n',sgstate_bw[0][1],'\n',sgstate_bw[0][2])
先看一下帶有stack的雙向RNN輸出的內(nèi)容:?
我們輸入的數(shù)據(jù)的批次是2,第一個序列長度是4氢拥,第二個序列長度是2.?
圖中共有4部分輸出,可以看到锨侯,第一部分(全序列)就是序列長度為4的結(jié)果,第二部分(短序列)就是序列長度為2的結(jié)果囚痴。由于沒一層都是由3個RNN的GRU cell組成,所以每個序列的輸出都為3.很顯然深滚,對于這樣的結(jié)果輸出涣觉,必須要將短序列后面的0去掉才可以用。?
好在該函數(shù)還有第二個輸出值血柳,GRU的狀態(tài)官册∧寻疲可以直接使用狀態(tài)里的值,而不需要對原始結(jié)果進行去0的變化根吁。
由于單個GRU本來就是沒有狀態(tài)的员淫。所以該函數(shù)將最后的輸出作為狀態(tài)返回。該函數(shù)有兩個狀態(tài)返回击敌,分別代表前向和后向介返。每一個方向的狀態(tài)都會返回3個元素沃斤。這是因為每個方向的網(wǎng)絡(luò)都有3層GRU組成。在使用時轰枝,一般都會取最后一個狀態(tài)。圖中紅色部分為前向中鞍陨,兩個樣本對應(yīng)的輸出,這個很好理解缭裆。
重點要看藍色的部分,即反向的狀態(tài)值對應(yīng)的是原始數(shù)據(jù)中最其實的序列輸入澈驼。因為是反向RNN筛武,在反向循環(huán)時,是會把序列中最后的放在最前面徘六,所以反向網(wǎng)絡(luò)的生成結(jié)果就會與最開始的序列相對應(yīng)。?
對于特征提取任務(wù)處理時漠其,正向與反向的最后值都為該序列的特征,需要合并起來統(tǒng)一處理和屎。但是對于下一個序列預(yù)測任務(wù)時,建議直接使用正向的RNN網(wǎng)絡(luò)就可以了套啤。?
如果要獲取雙向RNN的結(jié)果颠印,尤其是變長情況下,通過狀態(tài)拿到值直接拼接起來才是正確的做法线罕。即便不是變長。直接使用輸出值來拼接钞楼,會損失掉反向的一部分特征結(jié)果。這是需要值得注意的地方燃乍。
好了宛琅。在接著看下不帶stack的函數(shù)輸出是什么樣子的
gbresult,state_fw=sess.run([Gbioutputs,Goutput_state_fw])print("正向:\n", gbresult[0])print("反向:\n", gbresult[1])print('狀態(tài):',len(state_fw),'\n',state_fw[0],'\n',state_fw[1])? #state_fw[0]:【層,批次嘿辟,cell個數(shù)】 重頭到最后一個序列print(state_fw[0][-1],state_fw[1][-1])out? = np.concatenate((state_fw[0][-1],state_fw[1][-1]),axis = 1)print("拼接",out)
這次舆瘪,在輸出基本內(nèi)容基礎(chǔ)上红伦,直接將結(jié)果拼接起來。上面代碼運行后會輸出如下內(nèi)容昙读。
同樣正向用紅色蛮浑,反向用藍色。改函數(shù)返回的輸出值沮稚,沒有將正反向拼接。輸出的狀態(tài)雖然是一個值壮虫,但是里面有兩個元素,一個代表正向狀態(tài)剩拢,一個代表反向狀態(tài).?
從輸出中可以看到饶唤,最后一行實現(xiàn)了最終結(jié)果的真正拼接。在使用雙向rnn時可以按照上面的例子代碼將其狀態(tài)拼接成一條完整輸出募狂,然后在進行處理。
類似的如果想使用LSTM cell性穿。將前面的GRU部分替換即可雷滚,代碼如下:
stacked_rnn = []stacked_bw_rnn = []for iinrange(3):? ? stacked_rnn.append(tf.contrib.rnn.LSTMCell(3))? ? stacked_bw_rnn.append(tf.contrib.rnn.LSTMCell(3))mcell = tf.contrib.rnn.MultiRNNCell(stacked_rnn)mcell_bw = tf.contrib.rnn.MultiRNNCell(stacked_bw_rnn)? ? bioutputs, output_state_fw, output_state_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn([mcell],[mcell_bw],X,sequence_length=seq_lengths,? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
dtype=tf.float64)bioutputs, output_state_fw = tf.nn.bidirectional_dynamic_rnn(mcell,mcell_bw,X,sequence_length=seq_lengths,? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? dtype=tf.float64)
至于輸出的內(nèi)容是什么,可以按照前面GRU的輸出部分顯示出來自己觀察呆万。如何拼接车份,也可以參照GRU的例子來做。
通過將正反向的狀態(tài)拼接起來才可以獲得雙向RNN的最終輸出特征扫沼。千萬不要直接拿著輸出不加處理的來進行后續(xù)的運算,這會損失一大部分的運算特征以政。
該部分內(nèi)容屬于《深度學(xué)習(xí)之TensorFlow入門伴找、原理與進階實戰(zhàn)》一書的內(nèi)容補充。關(guān)于RNN的更多介紹可以參看書中第九章的詳細內(nèi)容技矮。