What is an LSTM
If you don't yet know what an LSTM is, head over to
http://colah.github.io/posts/2015-08-Understanding-LSTMs/
When I first saw the LSTM, I marveled at how many parameters this network has. After working with it for a while, though, you find that the essence of the LSTM lies in its three gates: forget, input, and output. The formulas around these three gates are all basically similar, so memorizing the LSTM equations is actually quite easy.
Why use an LSTM
Because a vanilla RNN easily suffers from vanishing and exploding gradients. The main reason is that backpropagating through an RNN applies the chain rule across time steps, which produces repeated multiplication by the recurrent weight matrix: when the matrix entries are larger than 1 the gradient explodes, and when they are smaller than 1 it vanishes.
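To make the chain-rule argument concrete, here is a minimal sketch (my own notation; the original gives no formula) for a vanilla RNN with hidden state h_t = \tanh(W h_{t-1} + U x_t). Backpropagating a loss from step T to an earlier step k multiplies T-k Jacobians together:

\frac{\partial Loss}{\partial h_k} = \frac{\partial Loss}{\partial h_T} \prod_{t=k+1}^{T} \frac{\partial h_t}{\partial h_{t-1}} = \frac{\partial Loss}{\partial h_T} \prod_{t=k+1}^{T} \mathrm{diag}\left(\tanh'(\cdot)\right) W

so the gradient norm scales roughly like \|W\|^{T-k}: it grows exponentially when the weights are large and shrinks exponentially when they are small.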
Through its gating mechanism, the LSTM can effectively prevent vanishing gradients (pay attention!!!), but exploding gradients can still occur, so LSTM training usually adds gradient clipping. In PyTorch, gradient clipping can be done with
import torch.nn as nn
# clip_grad_norm was renamed clip_grad_norm_ in PyTorch 0.4;
# call it after loss.backward() and before optimizer.step()
nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), max_norm=max_norm)
I will not use gradient clipping in the code below; if you need it, add the lines above yourself. For a detailed analysis of why gradients vanish and explode, see
http://www.cs.toronto.edu/~rgrosse/courses/csc321_2017/readings/L15%20Exploding%20and%20Vanishing%20Gradients.pdf
Why use a BiLSTM
Bi stands for bidirectional. Using a BiLSTM is actually somewhat controversial, because humans understand temporal signals in the order in which time flows, so does a time-reversed signal still carry meaning? Some say yes: for example, the exact stroke order in which someone writes a character doesn't really affect our ability to guess the character (an example I made up myself). Some say no: listening to someone's speech played backwards simply doesn't work. Whatever the debate, in practice a BiLSTM outperforms a plain LSTM nine times out of ten, so just use it.
The structure of a bidirectional LSTM is actually quite simple: two unidirectional LSTMs each propagate forward through time (one over the original sequence, one over the reversed sequence) and through the network layers, and their final outputs are concatenated.
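As a minimal shape check (the tensor sizes here are my own toy values), PyTorch's built-in bidirectional LSTM performs exactly this concatenation, which is why its output feature dimension doubles:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=1,
               batch_first=True, bidirectional=True)
x = torch.randn(4, 7, 10)  # batch of 4 sequences, length 7, feature dim 10
out, (h, c) = lstm(x)
print(out.shape)  # torch.Size([4, 7, 40]): forward and backward outputs concatenated
print(h.shape)    # torch.Size([2, 4, 20]): one final hidden state per direction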
Let's first build a BiLSTM for a classification task
First, define some notation:
- B is the batch size.
- L_i is the length of the i-th sequence in the batch; L \in R^B is a vector of length B.
- x(i,0:L_i,0:d_{input}) is the i-th sequence in the batch, of length L_i, where each frame has dimension d_{input}. A full batch x has shape x \in R^{B\times L_{max}\times d_{input}}, where L_{max} is the maximum value in L; sequences shorter than L_{max} should be zero-padded in advance.
- y(i,0:L_i) holds the per-frame class labels of the i-th sequence in the batch. A full batch y has shape y \in R^{B\times L_{max}}, where L_{max} is the maximum value in L; sequences shorter than L_{max} should be padded with -1 in advance (to avoid confusion with class 0; any value would do, it only marks padding). A toy batch built under these conventions is sketched right after this list.
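For concreteness, here is a toy batch following the conventions above (sizes and values are my own):

import torch

B, d_input = 3, 2
L = torch.tensor([4, 2, 3])                         # true lengths L_i
L_max = int(L.max())
x = torch.zeros(B, L_max, d_input)                  # zero-padded inputs, B x L_max x d_input
y = torch.full((B, L_max), -1, dtype=torch.long)    # labels padded with -1
for i in range(B):
    Li = int(L[i])
    x[i, :Li] = torch.randn(Li, d_input)            # fill in the real frames
    y[i, :Li] = torch.randint(0, 5, (Li,))          # per-frame labels (5 classes here)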
Here I will first build a BiLSTM using PyTorch's native API. Let me complain up front about how convoluted PyTorch makes variable-length sequence handling. The basic steps for processing sequences are:
- Prepare data=x, label=y, length=L, and so on, in torch.Tensor format
- Sort the batch by length, done by the function sort_batch
- Apply pack_padded_sequence
- Feed the packed batch into the LSTM for training
The function sort_batch:
import numpy as np
import torch

def sort_batch(data, label, length):
    # convert lengths to numpy to get the indices that sort the batch
    # by decreasing length (numpy's argsort is ascending, hence the [::-1])
    inx = torch.from_numpy(np.argsort(length.numpy())[::-1].copy())
    data = data[inx]
    label = label[inx]
    length = length[inx]
    # length is returned as a plain list, no longer a torch.Tensor
    length = list(length.numpy())
    return (data, label, length)
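Continuing the toy batch from above (still just a sketch), sorting and packing look like this:

from torch.nn.utils.rnn import pack_padded_sequence

data, label, length = sort_batch(x, y, L)   # reordered by decreasing length: [4, 3, 2]
packed = pack_padded_sequence(data, length, batch_first=True)
# packed.data holds only the real (non-padded) frames; packed.batch_sizes records
# how many sequences are still active at each time step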
The network:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, biFlag, dropout=0.5):
        # input_dim   input feature dimension d_{input}
        # hidden_dim  size of the hidden layer
        # output_dim  size of the output layer (number of classes)
        # num_layers  number of LSTM layers
        # biFlag      whether to use the bidirectional variant
        super(Net, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        if biFlag:
            self.bi_num = 2
        else:
            self.bi_num = 1
        self.biFlag = biFlag
        # change the device as needed
        self.device = torch.device("cuda")
        # define the LSTM: input/output sizes, number of layers, batch_first,
        # dropout ratio, and whether it is bidirectional
        self.layer1 = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim,
                              num_layers=num_layers, batch_first=True,
                              dropout=dropout, bidirectional=biFlag)
        # define the linear classification layer, with log-softmax output
        self.layer2 = nn.Sequential(
            nn.Linear(hidden_dim * self.bi_num, output_dim),
            nn.LogSoftmax(dim=2)
        )
        self.to(self.device)

    def init_hidden(self, batch_size):
        # define the initial hidden state
        return (torch.zeros(self.num_layers * self.bi_num, batch_size, self.hidden_dim).to(self.device),
                torch.zeros(self.num_layers * self.bi_num, batch_size, self.hidden_dim).to(self.device))

    def forward(self, x, y, length):
        # inputs: raw data x, labels y, and the length vector length
        # preparation
        batch_size = x.size(0)
        max_length = torch.max(length)
        # truncate the batch to the longest true length
        x = x[:, 0:max_length, :]
        y = y[:, 0:max_length]
        x, y, length = sort_batch(x, y, length)
        x, y = x.to(self.device), y.to(self.device)
        # pack sequence
        x = pack_padded_sequence(x, length, batch_first=True)
        # run the network
        hidden1 = self.init_hidden(batch_size)
        out, hidden1 = self.layer1(x, hidden1)
        # out, _ = self.layer1(x) is also ok if you don't want to refer to the hidden state
        # unpack sequence
        out, length = pad_packed_sequence(out, batch_first=True)
        out = self.layer2(out)
        # return the ground-truth labels, the predicted labels, and the length vector
        return y, out, length
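A minimal training-step sketch under the conventions above (the hyperparameters and the toy batch x, y, L are my own, and a CUDA device is assumed since the model pins itself to cuda). Because labels are padded with -1, NLLLoss's ignore_index masks the padding:

import torch
import torch.nn as nn

model = Net(input_dim=2, hidden_dim=16, output_dim=5, num_layers=2, biFlag=True)
criterion = nn.NLLLoss(ignore_index=-1)              # padded positions contribute no loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

y_true, y_pred, length = model(x, y, L)              # toy batch from earlier
loss = criterion(y_pred.permute(0, 2, 1), y_true)    # NLLLoss expects B x C x T
optimizer.zero_grad()
loss.backward()
optimizer.step()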
The official BiLSTM has a flaw
The code above looks fine, but it actually has one intolerable problem: it is non-reproducible. That is, this bidirectional LSTM gives different results on every run, even after fixing all random seeds. Honestly, that is fatal for anyone doing research, so reproducibility is the most basic requirement I place on a model.
According to my experiments, the LSTM is non-reproducible in the following case:
- nn.LSTM with bidirectional=True and dropout>0
and reproducible in the following cases:
- nn.LSTM with bidirectional=True and dropout=0
- nn.LSTM with bidirectional=False
In other words, a bidirectional LSTM becomes non-reproducible once dropout is added. Reportedly this is a cuDNN issue that PyTorch cannot fix; for details see
https://discuss.pytorch.org/t/non-deterministic-result-on-multi-layer-lstm-with-dropout/9700
https://github.com/soumith/cudnn.torch/issues/197
Being obsessive about this, I obviously cannot tolerate non-reproducibility. Fortunately the unidirectional LSTM is reproducible, so the only way out is to build a bidirectional LSTM by hand.
Do it yourself
Here we need to introduce a new function, reverse_padded_sequence, which reverses each sequence in the batch (think of it as reversing the second dimension L of the batch x \in R^{B\times L_{max}\times d_{input}} while leaving the zero-padded positions in place; it behaves the same as tf.reverse_sequence in TensorFlow).
import torch
from torch.autograd import Variable

def reverse_padded_sequence(inputs, lengths, batch_first=True):
    """Reverses sequences according to their lengths.

    This was written for Variable inputs; since Variable was merged into
    Tensor in PyTorch 0.4.0, a plain tensor can be passed directly.

    Inputs should have size ``T x B x *`` if ``batch_first`` is False, or
    ``B x T x *`` if True. T is the length of the longest sequence (or larger),
    B is the batch size, and * is any number of dimensions (including 0).

    Arguments:
        inputs (Variable): padded batch of variable length sequences.
        lengths (list[int]): list of sequence lengths
        batch_first (bool, optional): if True, inputs should be B x T x *.

    Returns:
        A Variable with the same size as inputs, but with each sequence
        reversed according to its length.
    """
    if batch_first:
        inputs = inputs.transpose(0, 1)
    max_length, batch_size = inputs.size(0), inputs.size(1)
    if len(lengths) != batch_size:
        raise ValueError("inputs is incompatible with lengths.")
    # per sequence: reversed indices for the real frames, identity for the padding
    ind = [list(reversed(range(0, length))) + list(range(length, max_length))
           for length in lengths]
    ind = torch.LongTensor(ind).transpose(0, 1)
    for dim in range(2, inputs.dim()):
        ind = ind.unsqueeze(dim)
    ind = Variable(ind.expand_as(inputs))
    if inputs.is_cuda:
        ind = ind.cuda(inputs.get_device())
    reversed_inputs = torch.gather(inputs, 0, ind)
    if batch_first:
        reversed_inputs = reversed_inputs.transpose(0, 1)
    return reversed_inputs
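A quick sanity check (toy values mine): the real frames get reversed while the zero padding stays at the end:

import torch

x = torch.tensor([[[1.], [2.], [3.], [0.]],    # length 3, one padded frame
                  [[4.], [5.], [0.], [0.]]])   # length 2, two padded frames
print(reverse_padded_sequence(x, [3, 2], batch_first=True).squeeze(2))
# tensor([[3., 2., 1., 0.],
#         [5., 4., 0., 0.]])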
Next, build the bidirectional LSTM network by hand; it is basically similar to the previous one:
class Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, biFlag, dropout=0.5):
        super(Net, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        if biFlag:
            self.bi_num = 2
        else:
            self.bi_num = 1
        self.biFlag = biFlag
        # change the device as needed
        self.device = torch.device("cuda")
        # one unidirectional LSTM per direction, kept in a ModuleList
        self.layer1 = nn.ModuleList()
        self.layer1.append(nn.LSTM(input_size=input_dim, hidden_size=hidden_dim,
                                   num_layers=num_layers, batch_first=True,
                                   dropout=dropout, bidirectional=False))
        if biFlag:
            # if bidirectional, add an extra layer for the reversed direction
            self.layer1.append(nn.LSTM(input_size=input_dim, hidden_size=hidden_dim,
                                       num_layers=num_layers, batch_first=True,
                                       dropout=dropout, bidirectional=False))
        self.layer2 = nn.Sequential(
            nn.Linear(hidden_dim * self.bi_num, output_dim),
            nn.LogSoftmax(dim=2)
        )
        self.to(self.device)

    def init_hidden(self, batch_size):
        # each direction has its own unidirectional LSTM, so its initial hidden
        # state is num_layers deep (not num_layers*bi_num)
        return (torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(self.device),
                torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(self.device))

    def forward(self, x, y, length):
        batch_size = x.size(0)
        max_length = torch.max(length)
        x = x[:, 0:max_length, :]
        y = y[:, 0:max_length]
        x, y, length = sort_batch(x, y, length)
        x, y = x.to(self.device), y.to(self.device)
        hidden = [self.init_hidden(batch_size) for l in range(self.bi_num)]
        # the forward layer sees the original batch, the backward layer the reversed one
        out = [x, reverse_padded_sequence(x, length, batch_first=True)]
        for l in range(self.bi_num):
            # pack sequence
            out[l] = pack_padded_sequence(out[l], length, batch_first=True)
            out[l], hidden[l] = self.layer1[l](out[l], hidden[l])
            # unpack
            out[l], _ = pad_packed_sequence(out[l], batch_first=True)
            # the backward layer's output must be reversed back to the original order
            if l == 1:
                out[l] = reverse_padded_sequence(out[l], length, batch_first=True)
        if self.bi_num == 1:
            out = out[0]
        else:
            out = torch.cat(out, 2)
        out = self.layer2(out)
        out = torch.squeeze(out)
        return y, out, length
Done. I have verified that this network is reproducible.
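One way to check this claim (a sketch of mine, reusing the toy batch x, y, L from earlier and the seed-fixing code from the Appendix below): run two freshly seeded forward passes and compare the outputs:

def run_once(seed=0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    model = Net(input_dim=2, hidden_dim=16, output_dim=5, num_layers=2, biFlag=True)
    _, out, _ = model(x, y, L)
    return out

print(torch.equal(run_once(), run_once()))  # expected: True for the hand-built BiLSTM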
Appendix
Fixing random seeds in PyTorch
import torch
import numpy as np

seed = 0  # any fixed integer of your choice
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
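On top of fixing the seeds, some cuDNN kernels are themselves nondeterministic; if reproducibility still fails, these two standard PyTorch flags (my addition, not from the original recipe) are worth setting as well:

torch.backends.cudnn.deterministic = True   # force deterministic cuDNN algorithms where available
torch.backends.cudnn.benchmark = False      # disable the autotuner, which may pick different kernels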