Google 2017年論文Attention is all you need
提出了Transformer模型缕粹,完全基于Attention mechanism稚茅,拋棄了傳統(tǒng)的CNN和RNN。
1. Transformer架構(gòu)
解釋下這個(gè)結(jié)構(gòu)圖平斩。首先亚享,Transformer模型也是使用經(jīng)典的encoder-decoder架構(gòu),由encoder和decoder兩部分組成绘面。
上圖左側(cè)用Nx
框出來的欺税,就是我們encoder的一層。encoder一共有6層這樣的結(jié)構(gòu)揭璃。
上圖右側(cè)用Nx
框出來的晚凿,就是我們decoder的一層。decoder一共有6層這樣的結(jié)構(gòu)瘦馍。
輸入序列經(jīng)過word embedding
和positional embedding
相加后歼秽,輸入到encoder中。
輸出序列經(jīng)過word embedding
和positional embedding
相加后情组,輸入到decoder中燥筷。
最后,decoder輸出的結(jié)果院崇,經(jīng)過一個(gè)線性層肆氓,然后計(jì)算softmax。
2. Encoder
encoder由6層相同的層組成底瓣,每一層分別由兩部分組成:
- 第一部分是
multi-head self-attention mechanism
- 第二部分是
position-wise feed-forward network
谢揪,是一個(gè)全連接層。
兩部分濒持,都有一個(gè)殘差連接(residual connection)
键耕,然后接著一個(gè)Layer Normalization
寺滚。
3. Decoder
與encoder類似柑营,decoder也是由6個(gè)相同層組成,每一個(gè)層包括以下3個(gè)部分:
- 第一部分是
multi-head self-attention mechanism
- 第二部分是
multi-head context-attention mechanism
- 第三部分是
position-wise feed-forward network
同樣村视,上面三部分中每一部分官套,都有一個(gè)殘差連接(residual connection)
,后接著一個(gè)Layer Normalization
。
4. Attention機(jī)制
Attention是指對(duì)于某個(gè)時(shí)刻的輸出y
奶赔,它在輸入x
上各個(gè)部分的注意力惋嚎。這個(gè)注意力可以理解為權(quán)重。
attention機(jī)制有很多計(jì)算方式站刑,下面是一張比較全面的表格:
seq2seq模型中另伍,使用的是加性注意力(addtion attention)
較多。
為什么這種attention叫做addtion attention呢绞旅?很簡(jiǎn)單摆尝,對(duì)于輸入序列隱狀態(tài)和輸出序列的隱狀態(tài)
,它的處理方式很簡(jiǎn)單因悲,直接合并為
但是transformer模型使用的不是這種attention機(jī)制堕汞,使用的是另一種,叫做乘性注意力(multiplicative attention)
晃琳。
那么這種乘性注意力機(jī)制是怎么樣的呢讯检?從上表中的公式也可以看出來:兩個(gè)隱狀態(tài)進(jìn)行點(diǎn)積!
4.1 Self-attention是什么卫旱?
上面我們說的attention機(jī)制的時(shí)候人灼,都會(huì)提到兩個(gè)隱狀態(tài),分別是和
顾翼,前者是輸入序列第
個(gè)位置產(chǎn)生的隱狀態(tài)挡毅,后者是輸出序列在第
個(gè)位置產(chǎn)生的隱狀態(tài)。
所謂self-attention實(shí)際上就是輸出序列就是輸入序列暴构,因此計(jì)算自己的attention得分跪呈,就叫做self-attention!
4.2 Context-attention是什么?
context-attention
是encoder和decoder之間的attention取逾!耗绿,所以,也可以成為encoder-decoder attention砾隅!
不管是self-attention還是context-attention误阻,它們計(jì)算attention分?jǐn)?shù)的時(shí)候,可以選擇很多方式晴埂,比如上面表中提到的:
- additive attention
- local-base
- general
- dot-product
- scaled dot-product
那么Transformer模型究反,采用的是哪種呢?答案是:scaled dot-product attention
儒洛。
4.3 Scaled dot-product attention是什么精耐?
論文Attention is all you need里面對(duì)于attention機(jī)制的描述是這樣的:
An attention function can be described as a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility of the query with the corresponding key.
這句話描述得很清楚了。翻譯過來就是:通過確定Q和K之間的相似程度來選擇V琅锻!
用公式來描述更加清晰:
scaled dot-product attention和dot-product attention唯一區(qū)別是卦停,scaled dot-product attention有一個(gè)縮放因子向胡。
上面公式中表示的是
的維度,在論文中惊完,默認(rèn)是
64
僵芹。
那么為什么需要加上這個(gè)縮放因子呢?論文中給出了解釋:對(duì)于很大時(shí)小槐,點(diǎn)積得到的結(jié)果維度很大拇派,使得結(jié)果處理softmax函數(shù)梯度很小的區(qū)域。
我們知道凿跳,梯度很小時(shí)攀痊,這對(duì)反向傳播
不利。為了克服這個(gè)負(fù)面影響拄显,除以一個(gè)縮放因子苟径,在一定程度上減緩這種情況。
為什么是呢躬审?論文沒有進(jìn)一步說明棘街。個(gè)人覺得你可以使用其他縮放因子,看看模型效果有沒有提升承边。
論文中也提供了一張很清晰的結(jié)果圖遭殉,供大家參考:
首先說明一下我們的是什么:
在encoder的self-attention中险污,Q、K富岳、V都來自同一個(gè)地方(相等)蛔糯,他們是上一層encoder的輸出。對(duì)于第一層encoder窖式,它們就是
word embedding
和positional encoding
相加得到的輸入蚁飒。在decoder的self-attention中,Q萝喘、K淮逻、V都來自同一個(gè)地方(相等),他們是上一層decoder的輸出阁簸。對(duì)于第一層decoder爬早,它們就是
word embedding
和positional encoding
相加得到的輸入。但是對(duì)于decoder启妹,我們不希望它能獲得下一個(gè)time step筛严,因此我們需要進(jìn)行sequence masking。在encoder-decoder attention中翅溺,Q來自于decoder的上一層的輸出脑漫,K和V來自于encoder的輸出,K和V是一樣的咙崎。
三者的維度一樣褪猛,即
网杆。
4.4 Scaled dot-product attention代碼實(shí)現(xiàn)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class ScaledDotProductAttention(nn.Module):
"""
Scaled dot-product attention mechanism.
"""
def __init__(self, attention_dropout=0.0):
super(ScaledDotProductAttention, self).__init__()
self.dropout = nn.Dropout(attention_dropout)
self.softmax = nn.Softmax(dim=2)
def forward(self, q, k, v, scale=None, attn_mask=None):
"""
前向傳播
args:
q: Queries張量,形狀[B, L_q, D_q]
k: keys張量伊滋, 形狀[B, L_k, D_k]
v: Values張量碳却,形狀[B, L_v, D_v]
scale: 縮放因子,一個(gè)浮點(diǎn)標(biāo)量
attn_mask: Masking張量笑旺,形狀[B, L_q, L_k]
returns:
上下文張量和attention張量
"""
attention = torch.bmm(q, k.transpose(1, 2))
if scale:
attention = attention * scale
if attn_mask:
# 給需要mask的地方設(shè)置一個(gè)負(fù)無窮
attention = attention.masked_fill_(attn_mask, -np.inf)
# 計(jì)算softmax
attention = self.softmax(attention)
# 添加dropout
attention = self.dropout(attention)
# 和V做點(diǎn)積
context = torch.bmm(attention, v)
return context, attention
5. Multi-head attention是什么呢昼浦?
理解了Scaled dot-product attention
,Multi-head attention
也很簡(jiǎn)單了筒主。論文提到关噪,他們發(fā)現(xiàn)將Q、K乌妙、V通過一個(gè)線性映射之后使兔,分成份,對(duì)每一份進(jìn)行
scaled dot-product attention
效果更好藤韵。然后虐沥,把各個(gè)部分的結(jié)果合并起來,再次經(jīng)過線性映射泽艘,得到最終的輸出欲险。這就是所謂的multi-head attention
。上面的超參數(shù)就是heads數(shù)量匹涮。論文默認(rèn)是
8
盯荤。
multi-head attention的結(jié)構(gòu)圖如下:
值得注意的是,上面所說的分成份是在
維度上面進(jìn)行切分的。因此脚翘,進(jìn)入到scaled dot-product attention的
實(shí)際上等于未進(jìn)入之前的
灼卢。
Multi-head attention允許模型加入不同位置的表示子空間的信息。
Multi-head attention的公式如下:
其中来农,
論文中鞋真,所以scaled dot-product attention里面的
5.1 Multi-head attention代碼實(shí)現(xiàn)
class MultiHeadAttention(nn.Module):
def __init__(self, model_dim=512, num_heads=8, dropout=0.0):
super(MultiHeadAttention, self).__init__()
self.dim_per_head = model_dim / num_heads
self.num_heads = num_heads
self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads)
self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads)
self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads)
self.dot_product_attention = ScaledDotProductAttention(dropout)
self.linear_final = nn.Linear(model_dim, model_dim)
self.dropout = nn.Dropout(dropout)
# multi-head attention之后需要做layer norm
self.layer_num = nn.LayerNorm(model_dim)
def forward(self, query, key, value, attn_mask=None):
# 殘差連接
residual = query
batch_size = key.size(0)
# linear projection
query = self.linear_q(query) # [B, L, D]
key = self.linear_k(key) # [B, L, D]
value = self.linear_v(value) # [B, L, D]
# split by head
query = query.view(batch_size * num_heads, -1, dim_per_head) # [B * 8, , D / 8]
key = key.view(batch_size * num_heads, -1, dim_per_head) #
value = value.view(batch_size * num_heads, -1, dim_per_head)
if attn_mask:
attn_mask = attn_mask.repeat(num_heads, 1, 1)
# scaled dot product attention
scale = (key.size(-1) // num_heads) ** -0.5
context, attention = self.dot_product_attention(
query, key, value, scale, attn_mask
)
# concat heads
context = context.view(batch_size, -1, dim_per_head * num_heads)
# final linear projection
output = self.linear_final(context)
# dropout
output = self.dropout(output)
# add residual and norm layer
output = self.layer_num(residual + output)
return output, attention
上面代碼中出現(xiàn)了 Residual connection
和Layer normalization
沃于。下面進(jìn)行解釋:
5.1.1 Residual connection是什么海诲?
殘差連接其實(shí)比較簡(jiǎn)單!看圖就會(huì)比較清晰:
假設(shè)網(wǎng)絡(luò)中某個(gè)層對(duì)輸入x
作用后的輸出為檩互,那么增加
residual connection
之后特幔,變成:
這個(gè)操作被稱為
shotcut
。
殘差結(jié)構(gòu)
因?yàn)樵黾恿艘豁?xiàng)闸昨,該層網(wǎng)絡(luò)對(duì)
求偏導(dǎo)時(shí)蚯斯,為常數(shù)項(xiàng)1!所以可以在反向傳播過程中饵较,梯度連乘拍嵌,不會(huì)造成
梯度消失
!
5.1.2 Layer normalization是什么循诉?
歸一化層横辆,主要有這幾種方法,BatchNorm(2015年)茄猫、LayerNorm(2016年)龄糊、InstanceNorm(2016年)、GroupNorm(2018年)募疮;
將輸入的圖像shape記為[N,C,H,W]炫惩,這幾個(gè)方法主要區(qū)別是:
BatchNorm:batch方向做歸一化,計(jì)算NHW的均值阿浓,對(duì)小batchsize效果不好他嚷;
(BN主要缺點(diǎn)是對(duì)batchsize的大小比較敏感,由于每次計(jì)算均值和方差是在一個(gè)batch上芭毙,所以如果batchsize太小筋蓖,則計(jì)算的均值、方差不足以代表整個(gè)數(shù)據(jù)分布)
LayerNorm:channel方向做歸一化退敦,計(jì)算CHW的均值粘咖;
(對(duì)RNN作用明顯)
InstanceNorm:一個(gè)batch,一個(gè)channel內(nèi)做歸一化侈百。計(jì)算HW的均值瓮下,用在風(fēng)格化遷移;
(因?yàn)樵趫D像風(fēng)格化中钝域,生成結(jié)果主要依賴于某個(gè)圖像實(shí)例讽坏,所以對(duì)整個(gè)batch歸一化不適合圖像風(fēng)格化中,因而對(duì)HW做歸一化例证÷肺兀可以加速模型收斂,并且保持每個(gè)圖像實(shí)例之間的獨(dú)立。)
GroupNorm:將channel方向分group胀葱,然后每個(gè)group內(nèi)做歸一化漠秋,算(C//G)HW的均值;這樣與batchsize無關(guān)抵屿,不受其約束庆锦。
6. Mask是什么?
mask顧名思義就是掩碼晌该,大概意思是對(duì)某些值進(jìn)行掩蓋肥荔,使其不產(chǎn)生效果.
需要說明的是绿渣,Transformer模型中有兩種mask朝群。分別是padding mask
和sequence mask
。其中中符,padding mask
在所有的scaled dot-product attention里都需要用到姜胖,而sequence mask
只在decoder的self-attention中用到。
所以淀散,我們之前的ScaledDotProductAttention的forward
方法里的參數(shù)attn_mask
在不同的地方有不同的含義右莱。
6.1 Padding mask
什么是padding mask
呢?回想一下档插,我們的每個(gè)批次輸入序列長(zhǎng)度是不一樣的慢蜓!也就是說,我們要對(duì)輸入序列進(jìn)行對(duì)齊
郭膛!具體來說晨抡,就是給較短序列后面填充0
。因?yàn)檫@些填充位置则剃,其實(shí)沒有意義耘柱,所以我們的attention機(jī)制不應(yīng)該把注意力放在這些位置上,所以我們需要進(jìn)行一些處理棍现。
具體做法是:把這些位置的值加上一個(gè)非常大的負(fù)數(shù)(可以是負(fù)無窮)调煎,這樣的話,經(jīng)過softmax己肮,這些位置的概率就會(huì)接近0士袄。
而我們的padding mask實(shí)際上是一個(gè)張量,每個(gè)值都是一個(gè)Boolean谎僻,值為False
的地方就是我們要進(jìn)行處理的地方窖剑。
下面是代碼實(shí)現(xiàn):
def padding_mask(seq_q, seq_k):
# seq_k和seq_q的形狀都是[B,L]
len_q = seq_q.size(1)
# `PAD` is 0
pad_mask = seq_k.eq(0)
pad_mask = pad_mask.unsqueeze(1).expand(-1, len_q, -1) # shape [B,L_q,L_k]
[B,L]->[B,1,L]->[B,L,L]
F | F | T | T |
---|---|---|---|
F | F | T | T |
F | F | T | T |
F | F | T | T |
6.2 Sequence mask
sequence mask
是為了使得decoder不能看到未來的信息。也就是對(duì)于一個(gè)序列戈稿,在time step為t的時(shí)刻西土,我們的解碼輸出只能依賴于t時(shí)刻之前的輸出,而不能依賴t之后的輸出鞍盗。因此我們需要想一個(gè)辦法需了,把t之后的信息給隱藏起來跳昼。
那具體如何做呢?也很簡(jiǎn)單:產(chǎn)生一個(gè)上三角矩陣肋乍,上三角矩陣的值全為1鹅颊,下三角的值全為0,對(duì)角線值也為0墓造。把這個(gè)矩陣作用在每一個(gè)序列上堪伍,就可以達(dá)到我們的目的。
具體代碼如下:
def sequence_mask(seq):
batch_size, seq_len = seq.size()
mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),
diagonal=1)
mask = mask.unsqueeze(0).expand(batch_size, -1, -1) # [B, L, L]
return mask
[B,L,L]
0 | 1 | 1 | 1 |
---|---|---|---|
0 | 0 | 1 | 1 |
0 | 0 | 0 | 1 |
0 | 0 | 0 | 0 |
哈佛大學(xué)的文章The Annotated Transformer有一張效果圖:
值得注意的是觅闽,本來mask只需要二維矩陣即可帝雇,但是考慮到我們的輸入序列都是批量的,所以我們需要把原本二維矩陣擴(kuò)張成3維張量蛉拙。上面代碼中尸闸,已經(jīng)做了處理。
回到本節(jié)開始的問題孕锄,attn_mask
參數(shù)有幾種情況吮廉?分別是什么意思?
- 對(duì)于decoder的self-attention畸肆,里面使用的scaled dot-product attention宦芦,同時(shí)需要
padding mask
和sequence mask
作為attn_mask
,具體實(shí)現(xiàn)就是兩個(gè)mask相加作為attn_mask
轴脐。 - 其它情況调卑,
attn_mask
都等于padding mask
。
7. Positional encoding是什么豁辉?
就目前而言令野,Transformer架構(gòu)似乎少了點(diǎn)東西。沒錯(cuò)徽级,那就是它對(duì)序列的順序沒有約束气破!我們知道序列的順序是一個(gè)很重要的信息,如果缺失了這個(gè)信息餐抢,可能我們的結(jié)果就是:所有詞語都對(duì)了现使,但是無法組成有意義的語句。
為了解決這個(gè)問題旷痕,論文中提出了positional encoding
碳锈。一句話概括就是:對(duì)序列中的詞語出現(xiàn)的位置進(jìn)行編碼!如果對(duì)位置進(jìn)行編碼欺抗,那么我們的模型就可以捕捉順序信息售碳。
那么具體怎么做呢?論文的實(shí)現(xiàn)是使用正余弦函數(shù)。公式如下:
其中贸人,pos
是指詞語在序列中的位置间景。可以看出艺智,在偶數(shù)位置倘要,使用正弦編碼悉稠,在奇數(shù)位置吊输,使用余弦編碼蜒蕾。
上面公式中的是模型的維度离钝,論文默認(rèn)是
512
。
這個(gè)編碼公式的意思就是:給定詞語的位置pos倡蝙,我們可以把它編碼成維的向量苟穆!也就是說妓忍,位置編碼的每一個(gè)維度對(duì)應(yīng)正弦曲線甲喝,波長(zhǎng)構(gòu)成了從
到
的等比序列尝苇。
Postional encoding是對(duì)詞匯的位置編碼铛只。
7.1 Positional encoding代碼實(shí)現(xiàn)
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_seq_len):
"""
初始化
args:
d_model: 一個(gè)標(biāo)量埠胖。模型的維度,論文默認(rèn)是512
max_seq_len: 一個(gè)標(biāo)量淳玩。文本序列的最大長(zhǎng)度
"""
super(PositionalEncoding, self).__init__()
# 根據(jù)論文給出的公式直撤,構(gòu)造出PE矩陣
position_encoding = np.array([
[pos / np.pow(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)]
for pos in range(max_seq_len)
])
# 偶數(shù)列使用sin,奇數(shù)列使用cos
position_encoding[:, 0::2] = np.sin(position_encoding[:, 0::2])
position_encoding[:, 1::2] = np.cos(position_encoding[:, 1::2])
# 在PE矩陣的一次行蜕着,加上一個(gè)全是0的向量谋竖,代表這`PAD`的positional_encoding
# 在word embedding中也會(huì)經(jīng)常加上`UNK`,代表位置單詞的word embedding承匣,兩者十分類似
# 那么為什么需要這個(gè)額外的PAD的編碼呢蓖乘?很簡(jiǎn)單,因?yàn)槲谋拘蛄械拈L(zhǎng)度不易韧骗,我們需要對(duì)齊嘉抒,
# 短的序列我們使用0在結(jié)尾不全,我們也需要這些補(bǔ)全位置的編碼袍暴,也就是`PAD`對(duì)應(yīng)的位置編碼
pad_row = torch.zeros([1, d_model])
position_encoding = torch.cat((pad_row, position_encoding))
# 嵌入操作些侍,+1是因?yàn)樵黾恿薫PAD`這個(gè)補(bǔ)全位置的編碼
# word embedding中如果詞典增加`UNK`,我們也需要+1政模。
self.position_encoding = nn.Embedding(max_seq_len+1, d_model)
self.position_encoding.weight = nn.Parameter(position_encoding, requires_grad=False)
def forward(self, input_len):
"""
神經(jīng)網(wǎng)絡(luò)前向傳播
args:
input_len: 一個(gè)張量岗宣,形狀為[BATCH_SIZE, 1]。每一個(gè)張量的值代表這一批文本序列中對(duì)應(yīng)的長(zhǎng)度淋样。
returns:
返回這一批序列的位置編碼耗式,進(jìn)行了對(duì)齊。
"""
# 找出這一批序列的最大長(zhǎng)度
max_len = torch.max(input_len)
# 對(duì)每一個(gè)序列的位置進(jìn)行對(duì)齊,在原序列位置的后面補(bǔ)上0
# 這里range從1開始也是因?yàn)橐荛_PAD(0)的位置
input_pos = torch.LongTensor(
[list(range(1, len+1)) + [0] * (max_len-len) for len in input_len]
)
return self.position_encoding(input_pos)
8. Word embedding是什么刊咳?
Word embedding是對(duì)序列中的詞匯的編碼措嵌,把每一個(gè)詞匯編碼成維的向量!它實(shí)際上就是一個(gè)二維浮點(diǎn)矩陣芦缰,里面的權(quán)重是可訓(xùn)練參數(shù)企巢,我們只需要把這個(gè)矩陣構(gòu)建出來就完成了word embedding的工作。
embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
上面vocab_size
是詞典大小让蕾,embedding_size
是詞嵌入的維度大小浪规,論文里面就是等于。所以word embedding矩陣就是一個(gè)
vocab_size*embedding_size
的二維張量探孝。
9. Position-wise Feed-Forward netword是什么笋婿?
這是一個(gè)全連接網(wǎng)絡(luò)顿颅,包含連個(gè)線性變換和一個(gè)非線性函數(shù)(ReLU)。公式如下:
這個(gè)線性變換在不同的位置都是一樣的粱腻,并且在不同的層之間使用不同的參數(shù)庇配。
論文提到,這個(gè)公式還可以用兩個(gè)核大小為1的一維卷積來解釋绍些,卷積的輸入輸出都是捞慌,中間層維度是
。
代碼如下:
class PositionalWiseFeedForward(nn.Module):
def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0):
super(PositionalWiseFeedForward, self).__init__()
self.w1 = nn.Conv1d(model_dim, ffn_dim, 1)
self.w2 = nn.Conv2d(model_dim, ffn_dim, 1)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(model_dim)
def forward(self, x):
output = x.transpose(1, 2)
output = self.w2(F.relu(self.w1(output)))
output = self.dropout(output.transpose(1, 2))
# add residual and norm layer
output = self.layer_norm(x + output)
return output
10. 完整代碼
至此柬批,所有的細(xì)節(jié)都解釋完了啸澡。
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class ScaledDotProductAttention(nn.Module):
"""
Scaled dot-product attention mechanism.
"""
def __init__(self, attention_dropout=0.0):
super(ScaledDotProductAttention, self).__init__()
self.dropout = nn.Dropout(attention_dropout)
self.softmax = nn.Softmax(dim=2)
def forward(self, q, k, v, scale=None, attn_mask=None):
"""
前向傳播
args:
q: Queries張量,形狀[B, L_q, D_q]
k: keys張量氮帐, 形狀[B, L_k, D_k]
v: Values張量嗅虏,形狀[B, L_v, D_v]
scale: 縮放因子,一個(gè)浮點(diǎn)標(biāo)量
attn_mask: Masking張量上沐,形狀[B, L_q, L_k]
returns:
上下文張量和attention張量
"""
attention = torch.bmm(q, k.transpose(1, 2))
if scale:
attention = attention * scale
if attn_mask:
# 給需要mask的地方設(shè)置一個(gè)負(fù)無窮
attention = attention.masked_fill_(attn_mask, -np.inf)
# 計(jì)算softmax
attention = self.softmax(attention)
# 添加dropout
attention = self.dropout(attention)
# 和V做點(diǎn)積
context = torch.bmm(attention, v)
return context, attention
class MultiHeadAttention(nn.Module):
def __init__(self, model_dim=512, num_heads=8, dropout=0.0):
super(MultiHeadAttention, self).__init__()
self.dim_per_head = model_dim / num_heads
self.num_heads = num_heads
self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads)
self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads)
self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads)
self.dot_product_attention = ScaledDotProductAttention(dropout)
self.linear_final = nn.Linear(model_dim, model_dim)
self.dropout = nn.Dropout(dropout)
# multi-head attention之后需要做layer norm
self.layer_num = nn.LayerNorm(model_dim)
def forward(self, query, key, value, attn_mask=None):
# 殘差連接
residual = query
batch_size = key.size(0)
# linear projection
query = self.linear_q(query) # [B, L, D]
key = self.linear_k(key) # [B, L, D]
value = self.linear_v(value) # [B, L, D]
# split by head
query = query.view(batch_size * num_heads, -1, dim_per_head) # [B * 8, , D / 8]
key = key.view(batch_size * num_heads, -1, dim_per_head) #
value = value.view(batch_size * num_heads, -1, dim_per_head)
if attn_mask:
attn_mask = attn_mask.repeat(num_heads, 1, 1)
# scaled dot product attention
scale = (key.size(-1) // num_heads) ** -0.5
context, attention = self.dot_product_attention(
query, key, value, scale, attn_mask
)
# concat heads
context = context.view(batch_size, -1, dim_per_head * num_heads)
# final linear projection
output = self.linear_final(context)
# dropout
output = self.dropout(output)
# add residual and norm layer
output = self.layer_num(residual + output)
return output, attention
def padding_mask(seq_q, seq_k):
# seq_k和seq_q的形狀都是[B,L]
len_q = seq_q.size(1)
# `PAD` is 0
pad_mask = seq_k.eq(0)
pad_mask = pad_mask.unsqueeze(1).expand(-1, len_q, -1) # shape [B,L_q,L_k]
def sequence_mask(seq):
batch_size, seq_len = seq.size()
mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),
diagonal=1)
mask = mask.unsqueeze(0).expand(batch_size, -1, -1) # [B, L, L]
return mask
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_seq_len):
"""
初始化
args:
d_model: 一個(gè)標(biāo)量皮服。模型的維度,論文默認(rèn)是512
max_seq_len: 一個(gè)標(biāo)量奄容。文本序列的最大長(zhǎng)度
"""
super(PositionalEncoding, self).__init__()
# 根據(jù)論文給出的公式冰更,構(gòu)造出PE矩陣
position_encoding = np.array([
[pos / np.pow(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)]
for pos in range(max_seq_len)
])
# 偶數(shù)列使用sin,奇數(shù)列使用cos
position_encoding[:, 0::2] = np.sin(position_encoding[:, 0::2])
position_encoding[:, 1::2] = np.cos(position_encoding[:, 1::2])
# 在PE矩陣的一次行昂勒,加上一個(gè)全是0的向量蜀细,代表這`PAD`的positional_encoding
# 在word embedding中也會(huì)經(jīng)常加上`UNK`,代表位置單詞的word embedding戈盈,兩者十分類似
# 那么為什么需要這個(gè)額外的PAD的編碼呢奠衔?很簡(jiǎn)單谆刨,因?yàn)槲谋拘蛄械拈L(zhǎng)度不易,我們需要對(duì)齊归斤,
# 短的序列我們使用0在結(jié)尾不全痊夭,我們也需要這些補(bǔ)全位置的編碼,也就是`PAD`對(duì)應(yīng)的位置編碼
pad_row = torch.zeros([1, d_model])
position_encoding = torch.cat((pad_row, position_encoding))
# 嵌入操作脏里,+1是因?yàn)樵黾恿薫PAD`這個(gè)補(bǔ)全位置的編碼
# word embedding中如果詞典增加`UNK`她我,我們也需要+1。
self.position_encoding = nn.Embedding(max_seq_len+1, d_model)
self.position_encoding.weight = nn.Parameter(position_encoding, requires_grad=False)
def forward(self, input_len):
"""
神經(jīng)網(wǎng)絡(luò)前向傳播
args:
input_len: 一個(gè)張量迫横,形狀為[BATCH_SIZE, 1]番舆。每一個(gè)張量的值代表這一批文本序列中對(duì)應(yīng)的長(zhǎng)度。
returns:
返回這一批序列的位置編碼矾踱,進(jìn)行了對(duì)齊恨狈。
"""
# 找出這一批序列的最大長(zhǎng)度
max_len = torch.max(input_len)
# 對(duì)每一個(gè)序列的位置進(jìn)行對(duì)齊,在原序列位置的后面補(bǔ)上0
# 這里range從1開始也是因?yàn)橐荛_PAD(0)的位置
input_pos = torch.LongTensor(
[list(range(1, len+1)) + [0] * (max_len-len) for len in input_len]
)
return self.position_encoding(input_pos)
# embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
# 獲得輸入的詞嵌入編碼
# seq_embedding = seq_embedding(inputs) * np.sqrt(d_model)
class PositionalWiseFeedForward(nn.Module):
def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0):
super(PositionalWiseFeedForward, self).__init__()
self.w1 = nn.Conv1d(model_dim, ffn_dim, 1)
self.w2 = nn.Conv2d(model_dim, ffn_dim, 1)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(model_dim)
def forward(self, x):
output = x.transpose(1, 2)
output = self.w2(F.relu(self.w1(output)))
output = self.dropout(output.transpose(1, 2))
# add residual and norm layer
output = self.layer_norm(x + output)
return output
class EncoderLayer(nn.Module):
"""Encoder的一層呛讲。"""
def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0):
super(EncoderLayer, self).__init__()
self.attention = MultiHeadAttention(model_dim, num_heads, dropout)
self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout)
def forward(self, inputs, attn_mask=None):
# self attention
context, attention = self.attention(inputs, inputs, inputs, attn_mask)
# feed forward network
output = self.feed_forward(context)
return output, attention
class Encoder(nn.Module):
"""多層EncoderLayer組成的Encoder"""
def __init__(self,
vocab_size,
num_layers=6,
model_dim=512,
num_heads=8,
ffn_dim=2048,
dropout=0.0):
super(Encoder, self).__init__()
self.encoder_layers = nn.ModuleList(
[EncoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)]
)
self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0)
self.pos_embedding = PositionalEncoding(model_dim, max_seq_len)
def forward(self, inputs, inputs_len):
output = self.seq_embedding(inputs)
output += self.pos_embedding(inputs_len)
self_attention_mask = padding_mask(inputs, inputs)
attentions = []
for encoder in self.encoder_layers:
output, attention = encoder(output, self_attention_mask)
attentions.append(attention)
return output, attentions
class DecoderLayer(nn.Module):
def __init__(self, model_dim, num_heads=8, ffn_dim=2048, dropout=0.0):
super(DecoderLayer, self).__init__()
self.attention = MultiHeadAttention(model_dim, num_heads, dropout)
self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout)
def forward(self,
dec_inputs,
enc_outputs,
self_attn_mask=None,
context_attn_mask=None):
# self attention, all inputs are decoder inputs
dec_output, self_attention = self.attention(dec_inputs, dec_inputs, dec_inputs, self_attn_mask)
# context attention
# query is decoder's outputs, key and value are encoder's inputs
dec_output, context_attention = self.attention(dec_output, enc_outputs, enc_outputs, context_attn_mask)
# decoder's output, or context
dec_output = self.feed_forward(dec_output)
return dec_output, self_attention, context_attention
class Decoder(nn.Module):
def __init__(self,
vocab_size,
max_seq_len,
num_layers=6,
model_dim=512,
num_heads=8,
ffn_dim=2048,
dropout=0.0):
super(Decoder).__init__()
self.num_layers = num_layers
self.decoder_layers = nn.ModuleList(
[DecoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)]
)
self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0)
self.pos_embedding = PositionalEncoding(model_dim, max_seq_len)
def forward(self, inputs, inputs_len, enc_output, context_attn_mask=None):
output = self.seq_embedding(inputs)
output += self.pos_embedding(inputs_len)
self_attention_padding_mask = padding_mask(inputs, inputs)
seq_mask = sequence_mask(inputs)
self_attn_mask = torch.gt((self_attention_padding_mask + seq_mask), 0)
self_attentions = []
context_attentions = []
for decoder in self.decoder_layers:
output, self_attn, context_attn = decoder(
output, enc_output, self_attn_mask, context_attn_mask)
self_attentions.append(self_attn)
context_attentions.append(context_attn)
return output, self_attentions, context_attentions
class Transformer(nn.Module):
def __init__(self,
src_vocab_size,
src_max_len,
tgt_vocab_size,
tgt_max_len,
num_layers=6,
model_dim=512,
num_heads=8,
ffn_dim=2048,
dropout=0.0):
super(Transformer).__init__()
self.encoder = Encoder(src_vocab_size, src_max_len, num_layers, model_dim, num_heads, ffn_dim, dropout)
self.decoder = Decoder(tgt_vocab_size, tgt_max_len, num_layers, model_dim, num_heads, ffn_dim, dropout)
self.linear = nn.Linear(model_dim, tgt_vocab_size, bias=False)
self.softmax = nn.Softmax()
def forward(self, src_seq, src_len, tgt_seq, tgt_len):
context_attn_mask = padding_mask(tgt_seq, src_seq)
output, enc_self_attn = self.encoder(src_seq, src_len)
output, dec_self_attn, ctx_attn = self.decoder(tgt_seq, tgt_len, output, context_attn_mask)
output = self.linear(output)
output = self.softmax(output)
return output, enc_self_attn, dec_self_attn, ctx_attn