??在 PyTorch 中黎做,GRU / LSTM
模塊的調(diào)用十分方便别厘,以 GRU 為例种蘸,如下:
import torch
from torch.nn import LSTM, GRU
from torch.autograd import Variable
import numpy as np
# [batch_size, seq_len, input_feature_size]
random_input = Variable(torch.FloatTensor(1, 5, 1).normal_(), requires_grad=False)
gru = GRU(
input_size=1, hidden_size=1, num_layers=1,
batch_first=True, bidirectional=False
)
# output: [batch_size, seq_len, num_direction * hidden_size]
# hidden: [num_layers * num_directions, batch, hidden_size]
output, hidden = gru(random_input)
??其中鞠值,output[:, -1, :] 即為 hidden媚创。LSTM 只是比 GRU 多了一個返回值 cell_state,其余不變彤恶。
??當(dāng)我們將 bidirectional
參數(shù)設(shè)置為 True 的時候钞钙,GRU/LSTM 會自動地將兩個方向的狀態(tài)拼接起來。遇到一些序列分類問題声离,我們常常會將 Bi-GRU/LSTM 的最后一個隱狀態(tài)輸出到分類層中芒炼,也即使用 output[:, -1, :],那么這樣做是否正確呢术徊?
??考慮這樣一個問題:當(dāng)模型正向遍歷序列1, 2, 3, 4, 5
的時候本刽,output[:, -1, :] 是依次計算節(jié)點 1~5
之后的隱狀態(tài);當(dāng)模型反向遍歷序列1, 2, 3, 4, 5
的時候赠涮,t = 5 位置對應(yīng)的隱狀態(tài)僅僅是計算了節(jié)點 5
之后的隱狀態(tài)子寓。output[:, -1, :] 就是拼接了上述兩個向量的特征,但我們想要放入分類層的逆序特征應(yīng)該是 t=1 位置對應(yīng)的隱狀態(tài)笋除,也即依次遍歷 5~1
節(jié)點斜友、編碼整個序列信息的特征。
??下面通過具體的代碼佐證上述結(jié)論垃它,樣例主要參考 Understanding Bidirectional RNN in PyTorch:
1) 數(shù)據(jù) & 模型準(zhǔn)備
# import 如上
random_input = Variable(torch.FloatTensor(1, 5, 1).normal_(), requires_grad=False)
# random_input[0, :, 0]
# tensor([ 0.0929, 0.6335, 0.6090, -0.0992, 0.7811])
# 分別建立一個 雙向 和 單向 GRU
bi_gru = GRU(input_size=1, hidden_size=1, num_layers=1, batch_first=True, bidirectional=True)
reverse_gru = GRU(input_size=1, hidden_size=1, num_layers=1, batch_first=True, bidirectional=False)
# 使 reverse_gru 的參數(shù)與 bi_gru 中逆序計算的部分保持一致
# 這樣 reverse_gru 就可以等價于 bi_gru 的逆序部分
reverse_gru.weight_ih_l0 = bi_gru.weight_ih_l0_reverse
reverse_gru.weight_hh_l0 = bi_gru.weight_hh_l0_reverse
reverse_gru.bias_ih_l0 = bi_gru.bias_ih_l0_reverse
reverse_gru.bias_hh_l0 = bi_gru.bias_hh_l0_reverse
# random_input 正序輸入 bi_gru鲜屏,逆序輸入 reverse_gru
bi_output, bi_hidden = bi_gru(random_input)
reverse_output, reverse_hidden = reverse_gru(random_input[:, np.arange(4, -1, -1), :])
2)結(jié)果對比
bi_output
'''
# shape = [1, 5, 2]
tensor([[[0.0867, 0.7053],
[0.2305, 0.6983],
[0.3245, 0.5996],
[0.2290, 0.4437],
[0.3471, 0.3395]]], grad_fn=<TransposeBackward1>)
'''
reverse_output
# shape = [1, 5, 1]
'''
tensor([[[0.3395],
[0.4437],
[0.5996],
[0.6983],
[0.7053]]], grad_fn=<TransposeBackward1>)
'''
??捋一捋,先只看 reverse_gru国拇,這是個單向gru洛史,我們輸入了一個序列,那么編碼了真格序列信息的隱狀態(tài)自然是最后一個隱狀態(tài)酱吝,也即 0.7053
是序列 [0.7811, -0.0992, 0.609, 0.6335, 0.0929] 的最后一個隱狀態(tài)(序列向量)也殖;bi_output 的第二列代表著逆向編碼的結(jié)果,剛好是 reverse_output 的倒序务热,如果我們直接把 bi_output[:, -1, :] 作為序列向量毕源,顯然是不符合期望的。正確的做法是:
Method 1:
seq_vec = torch.cat(bi_output[:, -1, 0], bi_output[:, 0, 1])
'''
tensor([0.3471, 0.7053], grad_fn=<CatBackward>)
'''
Method 2:
seq_vec = bi_hidden.reshape([bi_hidden.shape[0], -1])
'''
tensor([[0.3471],
[0.7053]], grad_fn=<ViewBackward>)
'''
??也即 hidden 這個變量是返回了 序列編碼
的信息陕习,滿足了我們的要求,可以放心用址愿,也推薦使用第二種方法该镣,少做不必要折騰。
bi_hidden
'''
tensor([[[0.3471]],
[[0.7053]]], grad_fn=<StackBackward>)
'''
reverse_hidden
'''
tensor([[[0.7053]]], grad_fn=<StackBackward>)
'''