Pytorch學(xué)習(xí)記錄-Transformer(數(shù)據(jù)預(yù)處理和模型結(jié)構(gòu))

Pytorch學(xué)習(xí)記錄-torchtext和Pytorch的實(shí)例6

0. PyTorch Seq2Seq項(xiàng)目介紹

在完成基本的torchtext之后绍昂,找到了這個(gè)教程洽糟,《基于Pytorch和torchtext來(lái)理解和實(shí)現(xiàn)seq2seq模型》撩匕。
這個(gè)項(xiàng)目主要包括了6個(gè)子項(xiàng)目

  1. 使用神經(jīng)網(wǎng)絡(luò)訓(xùn)練Seq2Seq
  2. 使用RNN encoder-decoder訓(xùn)練短語(yǔ)表示用于統(tǒng)計(jì)機(jī)器翻譯
  3. 使用共同學(xué)習(xí)完成NMT的堆砌和翻譯
  4. 打包填充序列锦亦、掩碼和推理
  5. 卷積Seq2Seq
  6. Transformer

6. Transformer

OK宇弛,來(lái)到最后一章曙痘,Transformer阳准,又回到這個(gè)模型啦氛堕,繞不開(kāi)的,依舊沒(méi)有講解野蝇,只能看看代碼讼稚。
來(lái)源不用說(shuō)了,《Attention is all you need》绕沈。Transformer在之前復(fù)習(xí)了多次锐想,這次也一樣,不知道教程會(huì)如何實(shí)現(xiàn)乍狐,反正之前學(xué)得挺痛苦的赠摇。

6.1 準(zhǔn)備數(shù)據(jù)

這里使用了一個(gè)新的數(shù)據(jù)集TranslationDataset,機(jī)器翻譯數(shù)據(jù)集是 TranslationDataset 類的子類。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchtext
#機(jī)器翻譯數(shù)據(jù)集是 TranslationDataset 類的子類藕帜。
from torchtext.datasets import TranslationDataset, Multi30k
from torchtext.data import Field, BucketIterator

import spacy

import random
import math
import os
import time

SEED=1234
random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic=True

spacy_de = spacy.load('de')
spacy_en = spacy.load('en')
def tokenize_de(text):
    return [tok.text for tok in spacy_de.tokenizer(text)]
def tokenize_en(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]
SRC=Field(tokenize=tokenize_de,
         init_token='<sos>',
         eos_token='<eos>',
         lower=True,
         batch_first=True)
TRG=Field(tokenize=tokenize_en,
         init_token='<sos>',
         eos_token='<eos>',
         lower=True,
         batch_first=True)
train_data,valid_data,test_data=Multi30k.splits(
    exts=('.de','.en'),
    fields=(SRC, TRG)
)
SRC.build_vocab(train_data,min_freq=2)
TRG.build_vocab(train_data,min_freq=2)
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
BATCH_SIZE=128
train_iterator, valid_iterator, test_iterator=BucketIterator.splits(
    (train_data,valid_data,test_data),
    batch_size=BATCH_SIZE,
    device=device
)

6.2 構(gòu)建模型

Transformer結(jié)構(gòu)圖

6.2.1 encoder和decoder

Transformer模型使用經(jīng)典的encoer-decoder架構(gòu)烫罩,由encoder和decoder兩部分組成。
可以看到兩側(cè)的N_x表示encoer和decoder各有多少層洽故。

encoder和decoder.png

在原始論文中贝攒,encoder和decoder都包含有6層。
encoder的每一層是由一個(gè)Multi head self-attention和一個(gè)FeedForward構(gòu)成收津,兩個(gè)部分都會(huì)使用殘差連接(residual connection)和Layer Normalization饿这。
decoder的每一層比encoder多了一個(gè)multi-head context-attention,即是說(shuō)撞秋,每一層包括multi-head context-attention长捧、Multi head self-attention和一個(gè)FeedForward,同樣三個(gè)部分都會(huì)使用殘差連接(residual connection)和Layer Normalization吻贿。
encoder和decoder通過(guò)context-attention進(jìn)行連接串结。
對(duì)比會(huì)發(fā)現(xiàn),紅框是相同的舅列,藍(lán)框是多出來(lái)的那個(gè)block肌割。

6.2.2 使用多種attention機(jī)制(multi-head context-attention、multi-head self-attention)

在圖中帐要,每塊就是一個(gè)block把敞,可以看到里面所使用的機(jī)制。似乎都有一個(gè)***-attention
attention對(duì)于某個(gè)時(shí)刻的輸出y榨惠,它在輸入x上各個(gè)部分的注意力奋早。這個(gè)注意力實(shí)際上可以理解為權(quán)重。

看到了encoder和decoder里面的block赠橙,就自然會(huì)考慮“什么是Attention”耽装。
前面已經(jīng)說(shuō)了,Attention實(shí)際上可以理解為權(quán)重期揪。attention機(jī)制也可以分成很多種掉奄。Attention? Attention!一文有一張比較全面的表格

image.png

6.2.2.1 multi head self-attention

attention機(jī)制有兩個(gè)隱狀態(tài),分別是輸入序列隱狀態(tài)h_i和輸出序列隱狀態(tài)s_t凤薛,前者是輸入序列第i個(gè)位置產(chǎn)生的隱狀態(tài)姓建,后者是輸出序列在第t個(gè)位置產(chǎn)生的隱狀態(tài)。
所謂multi head self-attention實(shí)際上就是輸出序列就是輸入序列枉侧,即是說(shuō)計(jì)算自己的attention得分引瀑,就叫做self-attention。

6.2.2.2 multi head context-attention

multi head context-attention是encoder和decoder之間的attention榨馁,是兩個(gè)不同序列之間的attention,與self-attention相區(qū)別帜矾。

6.2.2.3 如何實(shí)現(xiàn)Attention翼虫?

Attention的實(shí)現(xiàn)有很多種方式屑柔,上面的表列出了7種attention,在Transformer中珍剑,使用的是scaled dot-product attention掸宛。
為什么使用scaled dot-product attention。Google給出的解答就是Q(Query)招拙、V(Value)唧瘾、K(Key),注意看看這里的描述别凤。通過(guò)query和key的相似性程度來(lái)確定value的權(quán)重分布饰序。
Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V
scaled dot-product attention 和 dot-product attention 唯一的區(qū)別就是,scaled dot-product attention 有一個(gè)縮放因子规哪, 叫\frac{1}{\sqrt{d_k}}求豫。d_k 表示 Key 的維度,默認(rèn)用 64诉稍。
使用縮放因子的原因是蝠嘉,對(duì)于d_k很大的時(shí)候,點(diǎn)積得到的結(jié)果維度很大杯巨,使得結(jié)果處于softmax函數(shù)梯度很小的區(qū)域蚤告。而在梯度很小的情況時(shí),對(duì)反向傳播不利服爷。為了克服這個(gè)負(fù)面影響杜恰,除以一個(gè)縮放因子,可以一定程度上減緩這種情況层扶。

我們對(duì)比一下公式和論文中的圖示箫章。
Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V

image.png

以下為實(shí)現(xiàn)scaled dot-product attention的算法和圖示對(duì)比,我盡量搞得清楚點(diǎn)镜会。

QK.png

可以發(fā)現(xiàn)scale就是比例的意思檬寂,所以多了scale就是多了一個(gè)\frac{1}{\sqrt{d_k}}

scaled.png

接下來(lái)就是一個(gè)softmax戳表,然后與V相乘


final.png

在decoder的self-attention中桶至,Q、K匾旭、V都來(lái)自于同一個(gè)地方(相等)镣屹,它們是上一層decoder的輸出。對(duì)于第一層decoder价涝,它們就是word embedding和positional encoding相加得到的輸入女蜈。但是對(duì)于decoder,我們不希望它能獲得下一個(gè)time step(即將來(lái)的信息),因此我們需要進(jìn)行sequence masking伪窖∫菰ⅲ可以看到里面還有一個(gè)Mask,這個(gè)在下面會(huì)詳細(xì)介紹覆山。

6.2.2.4 如何實(shí)現(xiàn)multi-heads attention竹伸?

理解了Scaled dot-product attention,Multi-head attention也很簡(jiǎn)單了簇宽。論文提到勋篓,他們發(fā)現(xiàn)將Q、K魏割、V通過(guò)一個(gè)線性映射之后譬嚣,分成h份,對(duì)每一份進(jìn)行scaled dot-product attention效果更好见妒。然后孤荣,把各個(gè)部分的結(jié)果合并起來(lái),再次經(jīng)過(guò)線性映射须揣,得到最終的輸出盐股。這就是所謂的multi-head attention。上面的超參數(shù)h就是heads數(shù)量耻卡。論文默認(rèn)是8疯汁。
Multi-head attention允許模型加入不同位置的表示子空間的信息。
我們對(duì)比一下公式和論文中的圖示卵酪。

multi-heads attention.png


其中

6.2.3 使用Layer-Normalization機(jī)制

Normalization的一種幌蚊,把輸入轉(zhuǎn)化成均值為0方差為1的數(shù)據(jù)。是在每一個(gè)樣本上計(jì)算均值和方差溃卡。
層歸一對(duì)應(yīng)的是每個(gè)block中的Norm


image.png

層歸一(Layer Normalization)與批歸一(Batch Normalization)的區(qū)別就在于:

  • BN在每一層的每一批數(shù)據(jù)上進(jìn)行歸一化(計(jì)算均值和方差)
  • LN在每一個(gè)樣本上計(jì)算均值和方差
    層歸一公式
    LN(x_i)=\alpha\times\frac{x_i-u_L}{\sqrt{\sigma_L^2+\epsilon}}+\beta
    其中u_L是x最后一個(gè)維度的均值(看實(shí)現(xiàn)的源碼是這樣解釋溢豆,但是為什么是最后一個(gè)維度呢)
    層歸一示意圖
    LayerNormalization.png

6.2.4 使用Mask機(jī)制

用于對(duì)輸入序列進(jìn)行對(duì)齊。這里使用的是padding mask和sequence mask瘸羡。
mask掩碼漩仙,在Transformer中就是對(duì)某些值進(jìn)行掩蓋,使其在參數(shù)更新時(shí)不產(chǎn)生效果犹赖。
Transformer模型涉及兩種mask队他。

  • padding mask
  • sequence mask,這個(gè)在之前decoder中已經(jīng)見(jiàn)過(guò)峻村,使用在multi-heads context-attention中麸折。
    其中,padding mask在所有的scaled dot-product attention里面都需要用到粘昨,而sequence mask只有在decoder的multi-heads context-attention里面用到垢啼。


    兩種mask使用的位置對(duì)比.png

    所以窜锯,我們之前ScaledDotProductAttention的forward方法里面的參數(shù)attn_mask在不同的地方會(huì)有不同的含義。這一點(diǎn)我們會(huì)在后面說(shuō)明膊夹。

6.2.4.1 padding mask

說(shuō)白了就是對(duì)齊每一句話衬浑,每個(gè)批次輸入序列長(zhǎng)度是不一樣的捌浩。就是說(shuō)要以最長(zhǎng)的那句話為標(biāo)準(zhǔn)放刨,其他句子少一個(gè)詞就填充一個(gè)0。因?yàn)檫@些填充的位置是沒(méi)有意義的尸饺,attention機(jī)制不應(yīng)該把注意力放在這些位置上进统,所以我們需要進(jìn)行一些處理。
操作方法就是把這些位置的值加上一個(gè)非常大的負(fù)數(shù)(可以是負(fù)無(wú)窮)浪听,這樣的話螟碎,經(jīng)過(guò)softmax,這些位置的概率就會(huì)接近0迹栓。
padding mask是一個(gè)張量掉分,每個(gè)值都是一個(gè)Boolen,值為False的地方就是我們要進(jìn)行處理的地方克伊。

6.2.4.2 sequence mask

sequence mask是為了使得decoder不能看見(jiàn)未來(lái)的信息酥郭。也就是對(duì)于一個(gè)序列,在time_step為t的時(shí)刻愿吹,我們的解碼輸出應(yīng)該只能依賴于t時(shí)刻之前的輸出不从,而不能依賴t之后的輸出。因此我們需要想一個(gè)辦法犁跪,把t之后的信息給隱藏起來(lái)椿息。
這部分具體操作:產(chǎn)生一個(gè)上三角矩陣,上三角的值全為1坷衍,下三角的值全為0寝优,對(duì)角線也是0。把這個(gè)矩陣作用在每一個(gè)序列上枫耳。
沒(méi)看懂乏矾,好像第一次看這部分也是沒(méi)看懂,等下實(shí)現(xiàn)的時(shí)候看看吧嘉涌,不知道這句話的意思妻熊。

6.2.5 使用殘差residual connection

避免梯度消失。
殘差連接對(duì)應(yīng)的是每個(gè)block中的add


image.png

殘差連接示意圖如下仑最。

殘差連接

假設(shè)網(wǎng)絡(luò)中某個(gè)層對(duì)輸入x作用(比如使用Relu作用)后的輸出是扔役,那么增加residual connection之后,就變成了:

這個(gè)+x操作就是一個(gè)shortcut警医。
那么殘差結(jié)構(gòu)有什么好處呢亿胸?顯而易見(jiàn):因?yàn)樵黾恿艘豁?xiàng)x坯钦,那么該層網(wǎng)絡(luò)對(duì)x求偏導(dǎo)的時(shí)候,多了一個(gè)常數(shù)項(xiàng)1侈玄。在反向傳播過(guò)程中婉刀,梯度連乘,也不會(huì)造成梯度消失序仙。

6.2.6 使用Positional-encoding

對(duì)序列中的詞語(yǔ)出現(xiàn)的位置進(jìn)行編碼突颊。這樣模型就可以捕捉順序信息。
在處理完模型的各個(gè)模塊后潘悼,開(kāi)始關(guān)注數(shù)據(jù)的輸入部分律秃,在這里重點(diǎn)是位置編碼。與CNN和RNN不同治唤,Transformer模型對(duì)于序列沒(méi)有編碼棒动,這就導(dǎo)致無(wú)法獲取每個(gè)詞之間的關(guān)系,也就是無(wú)法構(gòu)成有意義的語(yǔ)句宾添。
為了解決這個(gè)問(wèn)題船惨。論文提出了Positional encoding。核心就是對(duì)序列中的詞語(yǔ)出現(xiàn)的位置進(jìn)行編碼缕陕。如果對(duì)位置進(jìn)行編碼粱锐,那么我們的模型就可以捕捉順序信息。
論文使用正余弦函數(shù)實(shí)現(xiàn)位置編碼榄檬。這樣做的好處就是不僅可以獲取詞的絕對(duì)位置信息卜范,還可以獲取相對(duì)位置信息。
PE(pos,2i) = sin(pos/10000^{2i/d_{model}})
PE(pos,2i+1) = cos(pos/10000^{2i/d_{model}})
其中鹿榜,pos是指詞語(yǔ)在序列中的位置海雪。可以看出舱殿,在偶數(shù)位置奥裸,使用正弦編碼,在奇數(shù)位置沪袭,使用余弦編碼湾宙。

相對(duì)位置信息通過(guò)以下公式實(shí)現(xiàn)
sin(\alpha+\beta) = sin\alpha cos\beta + cos\alpha sin\beta
cos(\alpha+\beta) = cos\alpha cos\beta - sin\alpha sin\beta
上面的公式說(shuō)明,對(duì)于詞匯之間的位置偏移k冈绊,PE(pos+k)可以表示成PE(pos)PE(k)的組合形式侠鳄,這就是表達(dá)相對(duì)位置的能力。

6.2.7 Position-wise Feed-Forward network

除了attention子層之外死宣,encoder和decoder中的每個(gè)層都包含一個(gè)完全連接的前饋網(wǎng)絡(luò)(Feed-forward network)伟恶,該網(wǎng)絡(luò)分別和相同地應(yīng)用于每個(gè)位置。這包括兩個(gè)線性變換和一個(gè)ReLU激活毅该。
FFN(x)=max(0,xW1+b1)W2+b2
雖然線性變換在不同位置上是相同的博秫,但它們?cè)趯优c層之間使用不同的參數(shù)潦牛。另一種描述這種情況的方法是兩個(gè)內(nèi)核大小為1的卷積。輸入和輸出的維數(shù)是512挡育,而中間層維度為2048巴碗。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市即寒,隨后出現(xiàn)的幾起案子橡淆,更是在濱河造成了極大的恐慌,老刑警劉巖蒿叠,帶你破解...
    沈念sama閱讀 222,104評(píng)論 6 515
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件明垢,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡市咽,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,816評(píng)論 3 399
  • 文/潘曉璐 我一進(jìn)店門(mén)抵蚊,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)施绎,“玉大人,你說(shuō)我怎么就攤上這事贞绳」茸恚” “怎么了?”我有些...
    開(kāi)封第一講書(shū)人閱讀 168,697評(píng)論 0 360
  • 文/不壞的土叔 我叫張陵冈闭,是天一觀的道長(zhǎng)俱尼。 經(jīng)常有香客問(wèn)我,道長(zhǎng)萎攒,這世上最難降的妖魔是什么遇八? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 59,836評(píng)論 1 298
  • 正文 為了忘掉前任,我火速辦了婚禮耍休,結(jié)果婚禮上刃永,老公的妹妹穿的比我還像新娘。我一直安慰自己羊精,他們只是感情好斯够,可當(dāng)我...
    茶點(diǎn)故事閱讀 68,851評(píng)論 6 397
  • 文/花漫 我一把揭開(kāi)白布。 她就那樣靜靜地躺著喧锦,像睡著了一般读规。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上燃少,一...
    開(kāi)封第一講書(shū)人閱讀 52,441評(píng)論 1 310
  • 那天束亏,我揣著相機(jī)與錄音,去河邊找鬼供汛。 笑死枪汪,一個(gè)胖子當(dāng)著我的面吹牛涌穆,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播雀久,決...
    沈念sama閱讀 40,992評(píng)論 3 421
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼宿稀,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了赖捌?” 一聲冷哼從身側(cè)響起祝沸,我...
    開(kāi)封第一講書(shū)人閱讀 39,899評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎越庇,沒(méi)想到半個(gè)月后罩锐,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 46,457評(píng)論 1 318
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡卤唉,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,529評(píng)論 3 341
  • 正文 我和宋清朗相戀三年涩惑,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片桑驱。...
    茶點(diǎn)故事閱讀 40,664評(píng)論 1 352
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡竭恬,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出熬的,到底是詐尸還是另有隱情痊硕,我是刑警寧澤,帶...
    沈念sama閱讀 36,346評(píng)論 5 350
  • 正文 年R本政府宣布押框,位于F島的核電站岔绸,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏橡伞。R本人自食惡果不足惜盒揉,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 42,025評(píng)論 3 334
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望骑歹。 院中可真熱鬧预烙,春花似錦、人聲如沸道媚。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 32,511評(píng)論 0 24
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)最域。三九已至谴分,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間镀脂,已是汗流浹背牺蹄。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 33,611評(píng)論 1 272
  • 我被黑心中介騙來(lái)泰國(guó)打工, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留薄翅,地道東北人沙兰。 一個(gè)月前我還...
    沈念sama閱讀 49,081評(píng)論 3 377
  • 正文 我出身青樓氓奈,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親鼎天。 傳聞我的和親對(duì)象是個(gè)殘疾皇子舀奶,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,675評(píng)論 2 359

推薦閱讀更多精彩內(nèi)容