Pytorch學(xué)習(xí)記錄-torchtext和Pytorch的實(shí)例6
0. PyTorch Seq2Seq項(xiàng)目介紹
在完成基本的torchtext之后绍昂,找到了這個(gè)教程洽糟,《基于Pytorch和torchtext來(lái)理解和實(shí)現(xiàn)seq2seq模型》撩匕。
這個(gè)項(xiàng)目主要包括了6個(gè)子項(xiàng)目
使用神經(jīng)網(wǎng)絡(luò)訓(xùn)練Seq2Seq使用RNN encoder-decoder訓(xùn)練短語(yǔ)表示用于統(tǒng)計(jì)機(jī)器翻譯使用共同學(xué)習(xí)完成NMT的堆砌和翻譯打包填充序列锦亦、掩碼和推理卷積Seq2Seq- Transformer
6. Transformer
OK宇弛,來(lái)到最后一章曙痘,Transformer阳准,又回到這個(gè)模型啦氛堕,繞不開(kāi)的,依舊沒(méi)有講解野蝇,只能看看代碼讼稚。
來(lái)源不用說(shuō)了,《Attention is all you need》绕沈。Transformer在之前復(fù)習(xí)了多次锐想,這次也一樣,不知道教程會(huì)如何實(shí)現(xiàn)乍狐,反正之前學(xué)得挺痛苦的赠摇。
6.1 準(zhǔn)備數(shù)據(jù)
這里使用了一個(gè)新的數(shù)據(jù)集TranslationDataset,機(jī)器翻譯數(shù)據(jù)集是 TranslationDataset 類的子類。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchtext
#機(jī)器翻譯數(shù)據(jù)集是 TranslationDataset 類的子類藕帜。
from torchtext.datasets import TranslationDataset, Multi30k
from torchtext.data import Field, BucketIterator
import spacy
import random
import math
import os
import time
SEED=1234
random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic=True
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')
def tokenize_de(text):
return [tok.text for tok in spacy_de.tokenizer(text)]
def tokenize_en(text):
return [tok.text for tok in spacy_en.tokenizer(text)]
SRC=Field(tokenize=tokenize_de,
init_token='<sos>',
eos_token='<eos>',
lower=True,
batch_first=True)
TRG=Field(tokenize=tokenize_en,
init_token='<sos>',
eos_token='<eos>',
lower=True,
batch_first=True)
train_data,valid_data,test_data=Multi30k.splits(
exts=('.de','.en'),
fields=(SRC, TRG)
)
SRC.build_vocab(train_data,min_freq=2)
TRG.build_vocab(train_data,min_freq=2)
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
BATCH_SIZE=128
train_iterator, valid_iterator, test_iterator=BucketIterator.splits(
(train_data,valid_data,test_data),
batch_size=BATCH_SIZE,
device=device
)
6.2 構(gòu)建模型
6.2.1 encoder和decoder
Transformer模型使用經(jīng)典的encoer-decoder架構(gòu)烫罩,由encoder和decoder兩部分組成。
可以看到兩側(cè)的N_x表示encoer和decoder各有多少層洽故。
在原始論文中贝攒,encoder和decoder都包含有6層。
encoder的每一層是由一個(gè)Multi head self-attention和一個(gè)FeedForward構(gòu)成收津,兩個(gè)部分都會(huì)使用殘差連接(residual connection)和Layer Normalization饿这。
decoder的每一層比encoder多了一個(gè)multi-head context-attention,即是說(shuō)撞秋,每一層包括multi-head context-attention长捧、Multi head self-attention和一個(gè)FeedForward,同樣三個(gè)部分都會(huì)使用殘差連接(residual connection)和Layer Normalization吻贿。
encoder和decoder通過(guò)context-attention進(jìn)行連接串结。
對(duì)比會(huì)發(fā)現(xiàn),紅框是相同的舅列,藍(lán)框是多出來(lái)的那個(gè)block肌割。
6.2.2 使用多種attention機(jī)制(multi-head context-attention、multi-head self-attention)
在圖中帐要,每塊就是一個(gè)block把敞,可以看到里面所使用的機(jī)制。似乎都有一個(gè)***-attention
attention對(duì)于某個(gè)時(shí)刻的輸出y榨惠,它在輸入x上各個(gè)部分的注意力奋早。這個(gè)注意力實(shí)際上可以理解為權(quán)重。
看到了encoder和decoder里面的block赠橙,就自然會(huì)考慮“什么是Attention”耽装。
前面已經(jīng)說(shuō)了,Attention實(shí)際上可以理解為權(quán)重期揪。attention機(jī)制也可以分成很多種掉奄。Attention? Attention!一文有一張比較全面的表格
6.2.2.1 multi head self-attention
attention機(jī)制有兩個(gè)隱狀態(tài),分別是輸入序列隱狀態(tài)和輸出序列隱狀態(tài)
凤薛,前者是輸入序列第i個(gè)位置產(chǎn)生的隱狀態(tài)姓建,后者是輸出序列在第t個(gè)位置產(chǎn)生的隱狀態(tài)。
所謂multi head self-attention實(shí)際上就是輸出序列就是輸入序列枉侧,即是說(shuō)計(jì)算自己的attention得分引瀑,就叫做self-attention。
6.2.2.2 multi head context-attention
multi head context-attention是encoder和decoder之間的attention榨馁,是兩個(gè)不同序列之間的attention,與self-attention相區(qū)別帜矾。
6.2.2.3 如何實(shí)現(xiàn)Attention翼虫?
Attention的實(shí)現(xiàn)有很多種方式屑柔,上面的表列出了7種attention,在Transformer中珍剑,使用的是scaled dot-product attention掸宛。
為什么使用scaled dot-product attention。Google給出的解答就是Q(Query)招拙、V(Value)唧瘾、K(Key),注意看看這里的描述别凤。通過(guò)query和key的相似性程度來(lái)確定value的權(quán)重分布饰序。
scaled dot-product attention 和 dot-product attention 唯一的區(qū)別就是,scaled dot-product attention 有一個(gè)縮放因子规哪, 叫求豫。
表示 Key 的維度,默認(rèn)用 64诉稍。
使用縮放因子的原因是蝠嘉,對(duì)于d_k很大的時(shí)候,點(diǎn)積得到的結(jié)果維度很大杯巨,使得結(jié)果處于softmax函數(shù)梯度很小的區(qū)域蚤告。而在梯度很小的情況時(shí),對(duì)反向傳播不利服爷。為了克服這個(gè)負(fù)面影響杜恰,除以一個(gè)縮放因子,可以一定程度上減緩這種情況层扶。
我們對(duì)比一下公式和論文中的圖示箫章。
以下為實(shí)現(xiàn)scaled dot-product attention的算法和圖示對(duì)比,我盡量搞得清楚點(diǎn)镜会。
可以發(fā)現(xiàn)scale就是比例的意思檬寂,所以多了scale就是多了一個(gè)。
接下來(lái)就是一個(gè)softmax戳表,然后與V相乘
在decoder的self-attention中桶至,Q、K匾旭、V都來(lái)自于同一個(gè)地方(相等)镣屹,它們是上一層decoder的輸出。對(duì)于第一層decoder价涝,它們就是word embedding和positional encoding相加得到的輸入女蜈。但是對(duì)于decoder,我們不希望它能獲得下一個(gè)time step(即將來(lái)的信息),因此我們需要進(jìn)行sequence masking伪窖∫菰ⅲ可以看到里面還有一個(gè)Mask,這個(gè)在下面會(huì)詳細(xì)介紹覆山。
6.2.2.4 如何實(shí)現(xiàn)multi-heads attention竹伸?
理解了Scaled dot-product attention,Multi-head attention也很簡(jiǎn)單了簇宽。論文提到勋篓,他們發(fā)現(xiàn)將Q、K魏割、V通過(guò)一個(gè)線性映射之后譬嚣,分成h份,對(duì)每一份進(jìn)行scaled dot-product attention效果更好见妒。然后孤荣,把各個(gè)部分的結(jié)果合并起來(lái),再次經(jīng)過(guò)線性映射须揣,得到最終的輸出盐股。這就是所謂的multi-head attention。上面的超參數(shù)h就是heads數(shù)量耻卡。論文默認(rèn)是8疯汁。
Multi-head attention允許模型加入不同位置的表示子空間的信息。
我們對(duì)比一下公式和論文中的圖示卵酪。
其中
6.2.3 使用Layer-Normalization機(jī)制
Normalization的一種幌蚊,把輸入轉(zhuǎn)化成均值為0方差為1的數(shù)據(jù)。是在每一個(gè)樣本上計(jì)算均值和方差溃卡。
層歸一對(duì)應(yīng)的是每個(gè)block中的Norm
層歸一(Layer Normalization)與批歸一(Batch Normalization)的區(qū)別就在于:
- BN在每一層的每一批數(shù)據(jù)上進(jìn)行歸一化(計(jì)算均值和方差)
- LN在每一個(gè)樣本上計(jì)算均值和方差
層歸一公式
其中是x最后一個(gè)維度的均值(看實(shí)現(xiàn)的源碼是這樣解釋溢豆,但是為什么是最后一個(gè)維度呢)
層歸一示意圖
LayerNormalization.png
6.2.4 使用Mask機(jī)制
用于對(duì)輸入序列進(jìn)行對(duì)齊。這里使用的是padding mask和sequence mask瘸羡。
mask掩碼漩仙,在Transformer中就是對(duì)某些值進(jìn)行掩蓋,使其在參數(shù)更新時(shí)不產(chǎn)生效果犹赖。
Transformer模型涉及兩種mask队他。
- padding mask
-
sequence mask,這個(gè)在之前decoder中已經(jīng)見(jiàn)過(guò)峻村,使用在multi-heads context-attention中麸折。
其中,padding mask在所有的scaled dot-product attention里面都需要用到粘昨,而sequence mask只有在decoder的multi-heads context-attention里面用到垢啼。
兩種mask使用的位置對(duì)比.png
所以窜锯,我們之前ScaledDotProductAttention的forward方法里面的參數(shù)attn_mask在不同的地方會(huì)有不同的含義。這一點(diǎn)我們會(huì)在后面說(shuō)明膊夹。
6.2.4.1 padding mask
說(shuō)白了就是對(duì)齊每一句話衬浑,每個(gè)批次輸入序列長(zhǎng)度是不一樣的捌浩。就是說(shuō)要以最長(zhǎng)的那句話為標(biāo)準(zhǔn)放刨,其他句子少一個(gè)詞就填充一個(gè)0。因?yàn)檫@些填充的位置是沒(méi)有意義的尸饺,attention機(jī)制不應(yīng)該把注意力放在這些位置上进统,所以我們需要進(jìn)行一些處理。
操作方法就是把這些位置的值加上一個(gè)非常大的負(fù)數(shù)(可以是負(fù)無(wú)窮)浪听,這樣的話螟碎,經(jīng)過(guò)softmax,這些位置的概率就會(huì)接近0迹栓。
padding mask是一個(gè)張量掉分,每個(gè)值都是一個(gè)Boolen,值為False的地方就是我們要進(jìn)行處理的地方克伊。
6.2.4.2 sequence mask
sequence mask是為了使得decoder不能看見(jiàn)未來(lái)的信息酥郭。也就是對(duì)于一個(gè)序列,在time_step為t的時(shí)刻愿吹,我們的解碼輸出應(yīng)該只能依賴于t時(shí)刻之前的輸出不从,而不能依賴t之后的輸出。因此我們需要想一個(gè)辦法犁跪,把t之后的信息給隱藏起來(lái)椿息。
這部分具體操作:產(chǎn)生一個(gè)上三角矩陣,上三角的值全為1坷衍,下三角的值全為0寝优,對(duì)角線也是0。把這個(gè)矩陣作用在每一個(gè)序列上枫耳。
沒(méi)看懂乏矾,好像第一次看這部分也是沒(méi)看懂,等下實(shí)現(xiàn)的時(shí)候看看吧嘉涌,不知道這句話的意思妻熊。
6.2.5 使用殘差residual connection
避免梯度消失。
殘差連接對(duì)應(yīng)的是每個(gè)block中的add
殘差連接示意圖如下仑最。
假設(shè)網(wǎng)絡(luò)中某個(gè)層對(duì)輸入x作用(比如使用Relu作用)后的輸出是扔役,那么增加residual connection之后,就變成了:
這個(gè)+x操作就是一個(gè)shortcut警医。
那么殘差結(jié)構(gòu)有什么好處呢亿胸?顯而易見(jiàn):因?yàn)樵黾恿艘豁?xiàng)x坯钦,那么該層網(wǎng)絡(luò)對(duì)x求偏導(dǎo)的時(shí)候,多了一個(gè)常數(shù)項(xiàng)1侈玄。在反向傳播過(guò)程中婉刀,梯度連乘,也不會(huì)造成梯度消失序仙。
6.2.6 使用Positional-encoding
對(duì)序列中的詞語(yǔ)出現(xiàn)的位置進(jìn)行編碼突颊。這樣模型就可以捕捉順序信息。
在處理完模型的各個(gè)模塊后潘悼,開(kāi)始關(guān)注數(shù)據(jù)的輸入部分律秃,在這里重點(diǎn)是位置編碼。與CNN和RNN不同治唤,Transformer模型對(duì)于序列沒(méi)有編碼棒动,這就導(dǎo)致無(wú)法獲取每個(gè)詞之間的關(guān)系,也就是無(wú)法構(gòu)成有意義的語(yǔ)句宾添。
為了解決這個(gè)問(wèn)題船惨。論文提出了Positional encoding。核心就是對(duì)序列中的詞語(yǔ)出現(xiàn)的位置進(jìn)行編碼缕陕。如果對(duì)位置進(jìn)行編碼粱锐,那么我們的模型就可以捕捉順序信息。
論文使用正余弦函數(shù)實(shí)現(xiàn)位置編碼榄檬。這樣做的好處就是不僅可以獲取詞的絕對(duì)位置信息卜范,還可以獲取相對(duì)位置信息。
其中鹿榜,pos是指詞語(yǔ)在序列中的位置海雪。可以看出舱殿,在偶數(shù)位置奥裸,使用正弦編碼,在奇數(shù)位置沪袭,使用余弦編碼湾宙。
相對(duì)位置信息通過(guò)以下公式實(shí)現(xiàn)
上面的公式說(shuō)明,對(duì)于詞匯之間的位置偏移k冈绊,可以表示成
和
的組合形式侠鳄,這就是表達(dá)相對(duì)位置的能力。
6.2.7 Position-wise Feed-Forward network
除了attention子層之外死宣,encoder和decoder中的每個(gè)層都包含一個(gè)完全連接的前饋網(wǎng)絡(luò)(Feed-forward network)伟恶,該網(wǎng)絡(luò)分別和相同地應(yīng)用于每個(gè)位置。這包括兩個(gè)線性變換和一個(gè)ReLU激活毅该。
雖然線性變換在不同位置上是相同的博秫,但它們?cè)趯优c層之間使用不同的參數(shù)潦牛。另一種描述這種情況的方法是兩個(gè)內(nèi)核大小為1的卷積。輸入和輸出的維數(shù)是512挡育,而中間層維度為2048巴碗。