內容包括:
- attention機制
- 為什么需要attention
- 簡介
- Seq2Seq
attention
為什么需要attention
由于之前的輸入文本無論長度,最后都變成了一個隱藏狀態(tài)輸入到decoder解碼都伪,因為encoder最后的隱藏狀態(tài)一定程度上包含了所有文本的信息,所以理論上是可以有效果的撤卢,但是實際上一旦輸入文本很長,那么指望這么一個隱藏狀態(tài)變量清楚的記錄文本的所有信息并且能夠解碼出正確的輸出是不現(xiàn)實的,也就是導致了輸出的隱藏狀態(tài)(contenxt vector)"遺忘"了一部分信息驰怎,因此我們想到是否可以給解碼器提供更多的信息钞馁,也就是編碼器每個時間步的信息虑省,attention為解碼器提供編碼器每個隱藏狀態(tài)的信息斗搞,通過這些信息,模型可以有重點的關注編碼器中需要關注的部分
簡介
總的來說慷妙,是為編碼器每個時間步分配一個權重(注意力)僻焚,利用softmax對編碼器隱藏狀態(tài)加權求和,得到context vector
步驟:
- 獲取每個編碼器隱藏狀態(tài)的分數(shù)
-分數(shù)(標量)是通過評分函數(shù)(alignment)獲得的膝擂,其中評分函數(shù)并不是唯一的虑啤,假設為點積函數(shù)
*decoder_hidden *= [10, 5, 10]##其中一個解碼器的隱藏狀態(tài)
*encoder_hidden score*
---------------------
[0, 1, 1] 15 (= 10×0 + 5×1 + 10×1, the dot product)
[5, 0, 1] 60
[1, 1, 0] 15
[0, 5, 1] 35
以上就獲得了4個編碼器隱藏狀態(tài)關于一個解碼器隱藏狀態(tài)的分數(shù),其中[5, 0, 1]的注意力分數(shù)為60是最高的架馋,說明接下來翻譯出的這個詞將受到這個編碼器隱藏狀態(tài)的影響
- 通過softmax層狞山,歸一化
*encoder_hidden score score^*
-----------------------------
[0, 1, 1] 15 0
[5, 0, 1] 60 1
[1, 1, 0] 15 0
[0, 5, 1] 35 0
- 通過softmax獲得的權重與每個編碼器隱藏狀態(tài)相乘,也就是加權叉寂,獲得alignment向量
*encoder_hidden score score^ alignment*
----------------------------------------
[0, 1, 1] 15 0 [0, 0, 0]
[5, 0, 1] 60 1 [5, 0, 1]
[1, 1, 0] 15 0 [0, 0, 0]
[0, 5, 1] 35 0 [0, 0, 0]
- 對alignment向量求和
*encoder_hidden score score^ alignment*
----------------------------------------
[0, 1, 1] 15 0 [0, 0, 0]
[5, 0, 1] 60 1 [5, 0, 1]
[1, 1, 0] 15 0 [0, 0, 0]
[0, 5, 1] 35 0 [0, 0, 0]
*context *= [0+5+0+0, 0+0+0+0, 0+1+0+0] = [5, 0, 1]
- 將上下文向量輸入解碼器
輸入方式取決于架構萍启,一般是和解碼器的輸入一起喂進模型
'''
1. 計算評分函數(shù)
2. 求出對應的softmax歸一化后的值
3. 用上面得到的值加權value(在這里k和v都是編碼器的隱藏狀態(tài))
'''
def SequenceMask(X, X_len,value=-1e6):
maxlen = X.size(1)
#print(X.size(),torch.arange((maxlen),dtype=torch.float)[None, :],'\n',X_len[:, None] )
mask = torch.arange((maxlen),dtype=torch.float)[None, :] >= X_len[:, None] ##這是一種reshape操作,將X_reshape為()
#print(mask)
X[mask]=value
return X
def masked_softmax(X, valid_length):
'''
1. valid_length可能是(batch_size_1, )屏鳍,也可能是(batch_size, x)
2. X(batch_size_1, 1,num_hidden_size)
'''
# X: 3-D tensor, valid_length: 1-D or 2-D tensor
softmax = nn.Softmax(dim=-1)
if valid_length is None:
return softmax(X)
else:
shape = X.shape
if valid_length.dim() == 1:
try:
valid_length = torch.FloatTensor(valid_length.numpy().repeat(shape[1], axis=0))#[2,2,3,3]
except:
valid_length = torch.FloatTensor(valid_length.cpu().numpy().repeat(shape[1], axis=0))#[2,2,3,3]
else:
valid_length = valid_length.reshape((-1,))
# fill masked elements with a large negative, whose exp is 0
X = SequenceMask(X.reshape((-1, shape[-1])), valid_length)
return softmax(X).reshape(shape)
class DotProductAttention(nn.Module):
'''
1. 計算評分
2. 計算歸一化
3. 計算加權
'''
def __init__(self, dropout, **kwargs):
super(DotProductAttention, self).__init__(**kwargs)
self.dropout = nn.Dropout(dropout)
# query: (batch_size, #queries, d)
# key: (batch_size, #kv_pairs, d)
# value: (batch_size, #kv_pairs, dim_v)
# valid_length: either (batch_size, ) or (batch_size, xx)
def forward(self, query, key, value, valid_length=None):
d = query.shape[-1]
# set transpose_b=True to swap the last two dimensions of key
scores = torch.bmm(query, key.transpose(1,2)) / math.sqrt(d)
attention_weights = self.dropout(masked_softmax(scores, valid_length))
print("attention_weight\n",attention_weights)
return torch.bmm(attention_weights, value)