Transformer架構(gòu)詳解

Google 2017年論文Attention is all you need提出了Transformer模型缕粹,完全基于Attention mechanism稚茅,拋棄了傳統(tǒng)的CNNRNN

1. Transformer架構(gòu)

Transformer

解釋下這個(gè)結(jié)構(gòu)圖平斩。首先亚享,Transformer模型也是使用經(jīng)典的encoder-decoder架構(gòu),由encoder和decoder兩部分組成绘面。

上圖左側(cè)用Nx框出來的欺税,就是我們encoder的一層。encoder一共有6層這樣的結(jié)構(gòu)揭璃。

上圖右側(cè)用Nx框出來的晚凿,就是我們decoder的一層。decoder一共有6層這樣的結(jié)構(gòu)瘦馍。

輸入序列經(jīng)過word embeddingpositional embedding相加后歼秽,輸入到encoder中。

輸出序列經(jīng)過word embeddingpositional embedding相加后情组,輸入到decoder中燥筷。

最后,decoder輸出的結(jié)果院崇,經(jīng)過一個(gè)線性層肆氓,然后計(jì)算softmax。

2. Encoder

encoder由6層相同的層組成底瓣,每一層分別由兩部分組成:

  • 第一部分是multi-head self-attention mechanism
  • 第二部分是position-wise feed-forward network谢揪,是一個(gè)全連接層。

兩部分濒持,都有一個(gè)殘差連接(residual connection)键耕,然后接著一個(gè)Layer Normalization寺滚。

3. Decoder

與encoder類似柑营,decoder也是由6個(gè)相同層組成,每一個(gè)層包括以下3個(gè)部分:

  • 第一部分是multi-head self-attention mechanism
  • 第二部分是multi-head context-attention mechanism
  • 第三部分是position-wise feed-forward network

同樣村视,上面三部分中每一部分官套,都有一個(gè)殘差連接(residual connection),后接著一個(gè)Layer Normalization

4. Attention機(jī)制

Attention是指對(duì)于某個(gè)時(shí)刻的輸出y奶赔,它在輸入x上各個(gè)部分的注意力惋嚎。這個(gè)注意力可以理解為權(quán)重

attention機(jī)制有很多計(jì)算方式站刑,下面是一張比較全面的表格:

image.png

seq2seq模型中另伍,使用的是加性注意力(addtion attention)較多。

為什么這種attention叫做addtion attention呢绞旅?很簡(jiǎn)單摆尝,對(duì)于輸入序列隱狀態(tài)h_i和輸出序列的隱狀態(tài)s_t,它的處理方式很簡(jiǎn)單因悲,直接合并為[s_t;h_i]

但是transformer模型使用的不是這種attention機(jī)制堕汞,使用的是另一種,叫做乘性注意力(multiplicative attention)晃琳。

那么這種乘性注意力機(jī)制是怎么樣的呢讯检?從上表中的公式也可以看出來:兩個(gè)隱狀態(tài)進(jìn)行點(diǎn)積!

4.1 Self-attention是什么卫旱?

上面我們說的attention機(jī)制的時(shí)候人灼,都會(huì)提到兩個(gè)隱狀態(tài),分別是h_is_t顾翼,前者是輸入序列第i個(gè)位置產(chǎn)生的隱狀態(tài)挡毅,后者是輸出序列在第t個(gè)位置產(chǎn)生的隱狀態(tài)。

所謂self-attention實(shí)際上就是輸出序列就是輸入序列暴构,因此計(jì)算自己的attention得分跪呈,就叫做self-attention!

4.2 Context-attention是什么?

context-attention是encoder和decoder之間的attention取逾!耗绿,所以,也可以成為encoder-decoder attention砾隅!

不管是self-attention還是context-attention误阻,它們計(jì)算attention分?jǐn)?shù)的時(shí)候,可以選擇很多方式晴埂,比如上面表中提到的:

  • additive attention
  • local-base
  • general
  • dot-product
  • scaled dot-product

那么Transformer模型究反,采用的是哪種呢?答案是:scaled dot-product attention儒洛。

4.3 Scaled dot-product attention是什么精耐?

論文Attention is all you need里面對(duì)于attention機(jī)制的描述是這樣的:

An attention function can be described as a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility of the query with the corresponding key.

這句話描述得很清楚了。翻譯過來就是:通過確定Q和K之間的相似程度來選擇V琅锻!

用公式來描述更加清晰:
Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V\tag{4.3.1}

scaled dot-product attentiondot-product attention唯一區(qū)別是卦停,scaled dot-product attention有一個(gè)縮放因子\frac{1}{\sqrt{d_k}}向胡。

上面公式中d_k表示的是K的維度,在論文中惊完,默認(rèn)是64僵芹。

那么為什么需要加上這個(gè)縮放因子呢?論文中給出了解釋:對(duì)于d_k很大時(shí)小槐,點(diǎn)積得到的結(jié)果維度很大拇派,使得結(jié)果處理softmax函數(shù)梯度很小的區(qū)域。

我們知道凿跳,梯度很小時(shí)攀痊,這對(duì)反向傳播不利。為了克服這個(gè)負(fù)面影響拄显,除以一個(gè)縮放因子苟径,在一定程度上減緩這種情況。

為什么是\frac{1}{\sqrt{d_k}}呢躬审?論文沒有進(jìn)一步說明棘街。個(gè)人覺得你可以使用其他縮放因子,看看模型效果有沒有提升承边。

論文中也提供了一張很清晰的結(jié)果圖遭殉,供大家參考:

image.png

首先說明一下我們的K、Q博助、V是什么:

  • 在encoder的self-attention中险污,Q、K富岳、V都來自同一個(gè)地方(相等)蛔糯,他們是上一層encoder的輸出。對(duì)于第一層encoder窖式,它們就是word embeddingpositional encoding相加得到的輸入蚁飒。

  • 在decoder的self-attention中,Q萝喘、K淮逻、V都來自同一個(gè)地方(相等),他們是上一層decoder的輸出阁簸。對(duì)于第一層decoder爬早,它們就是word embeddingpositional encoding相加得到的輸入。但是對(duì)于decoder启妹,我們不希望它能獲得下一個(gè)time step筛严,因此我們需要進(jìn)行sequence masking

  • 在encoder-decoder attention中翅溺,Q來自于decoder的上一層的輸出脑漫,K和V來自于encoder的輸出,K和V是一樣的咙崎。

  • Q优幸、K、V三者的維度一樣褪猛,即d_q=d_k=d_v网杆。

4.4 Scaled dot-product attention代碼實(shí)現(xiàn)

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    """
    Scaled dot-product attention mechanism.
    """

    def __init__(self, attention_dropout=0.0):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, scale=None, attn_mask=None):
        """
        前向傳播

        args:
            q: Queries張量,形狀[B, L_q, D_q]
            k: keys張量伊滋, 形狀[B, L_k, D_k]
            v: Values張量碳却,形狀[B, L_v, D_v]
            scale: 縮放因子,一個(gè)浮點(diǎn)標(biāo)量
            attn_mask: Masking張量笑旺,形狀[B, L_q, L_k]
        returns:
            上下文張量和attention張量
        """
        attention = torch.bmm(q, k.transpose(1, 2))
        if scale:
            attention = attention * scale
        if attn_mask:
            # 給需要mask的地方設(shè)置一個(gè)負(fù)無窮
            attention = attention.masked_fill_(attn_mask, -np.inf)
        # 計(jì)算softmax
        attention = self.softmax(attention)
        # 添加dropout
        attention = self.dropout(attention)
        # 和V做點(diǎn)積
        context = torch.bmm(attention, v)

        return context, attention

5. Multi-head attention是什么呢昼浦?

理解了Scaled dot-product attentionMulti-head attention也很簡(jiǎn)單了筒主。論文提到关噪,他們發(fā)現(xiàn)將Q、K乌妙、V通過一個(gè)線性映射之后使兔,分成h份,對(duì)每一份進(jìn)行scaled dot-product attention效果更好藤韵。然后虐沥,把各個(gè)部分的結(jié)果合并起來,再次經(jīng)過線性映射泽艘,得到最終的輸出欲险。這就是所謂的multi-head attention。上面的超參數(shù)h就是heads數(shù)量匹涮。論文默認(rèn)是8盯荤。

multi-head attention的結(jié)構(gòu)圖如下:


image.png

值得注意的是,上面所說的分成h份是在d_k焕盟、d_q秋秤、d_v維度上面進(jìn)行切分的。因此脚翘,進(jìn)入到scaled dot-product attention的d_k實(shí)際上等于未進(jìn)入之前的D_k/h灼卢。

Multi-head attention允許模型加入不同位置的表示子空間的信息。

Multi-head attention的公式如下:
MultiHead(Q,K,V)=Concat(head_1,...,head_h)W^O\tag{5.1}

其中来农,
head_i=Attention(QW_i^Q,KW_i^K,VW_i^V)\tag{5.2}

論文中鞋真,d_{model}=512, h=8。所以scaled dot-product attention里面的
d_q=d_k=d_v=d_{model}/h=512/8=64

5.1 Multi-head attention代碼實(shí)現(xiàn)

class MultiHeadAttention(nn.Module):

    def __init__(self, model_dim=512, num_heads=8, dropout=0.0):
        super(MultiHeadAttention, self).__init__()

        self.dim_per_head = model_dim / num_heads
        self.num_heads = num_heads
        self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads)

        self.dot_product_attention = ScaledDotProductAttention(dropout)
        self.linear_final = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
        # multi-head attention之后需要做layer norm
        self.layer_num = nn.LayerNorm(model_dim)

    def forward(self, query, key, value, attn_mask=None):
        # 殘差連接
        residual = query

        batch_size = key.size(0)

        # linear projection
        query = self.linear_q(query) # [B, L, D]
        key = self.linear_k(key) # [B, L, D]
        value = self.linear_v(value) # [B, L, D]

        # split by head
        query = query.view(batch_size * num_heads, -1, dim_per_head) # [B * 8, , D / 8]
        key = key.view(batch_size * num_heads, -1, dim_per_head) # 
        value = value.view(batch_size * num_heads, -1, dim_per_head)

        if attn_mask:
            attn_mask = attn_mask.repeat(num_heads, 1, 1)
        # scaled dot product attention
        scale = (key.size(-1) // num_heads) ** -0.5
        context, attention = self.dot_product_attention(
            query, key, value, scale, attn_mask
        ) 

        # concat heads
        context = context.view(batch_size, -1, dim_per_head * num_heads)
        
        # final linear projection
        output = self.linear_final(context)

        # dropout
        output = self.dropout(output)

        # add residual and norm layer
        output = self.layer_num(residual + output)

        return output, attention

上面代碼中出現(xiàn)了 Residual connectionLayer normalization沃于。下面進(jìn)行解釋:

5.1.1 Residual connection是什么海诲?

殘差連接其實(shí)比較簡(jiǎn)單!看圖就會(huì)比較清晰:

image.png

假設(shè)網(wǎng)絡(luò)中某個(gè)層對(duì)輸入x作用后的輸出為F(x)檩互,那么增加residual connection之后特幔,變成:

F(x) + x \tag{5.2.1}

這個(gè)+x操作被稱為shotcut

殘差結(jié)構(gòu)因?yàn)樵黾恿艘豁?xiàng)x闸昨,該層網(wǎng)絡(luò)對(duì)x求偏導(dǎo)時(shí)蚯斯,為常數(shù)項(xiàng)1!所以可以在反向傳播過程中饵较,梯度連乘拍嵌,不會(huì)造成梯度消失

5.1.2 Layer normalization是什么循诉?

歸一化層横辆,主要有這幾種方法,BatchNorm(2015年)茄猫、LayerNorm(2016年)龄糊、InstanceNorm(2016年)、GroupNorm(2018年)募疮;
將輸入的圖像shape記為[N,C,H,W]炫惩,這幾個(gè)方法主要區(qū)別是:

  • BatchNorm:batch方向做歸一化,計(jì)算NHW的均值阿浓,對(duì)小batchsize效果不好他嚷;(BN主要缺點(diǎn)是對(duì)batchsize的大小比較敏感,由于每次計(jì)算均值和方差是在一個(gè)batch上芭毙,所以如果batchsize太小筋蓖,則計(jì)算的均值、方差不足以代表整個(gè)數(shù)據(jù)分布)

  • LayerNorm:channel方向做歸一化退敦,計(jì)算CHW的均值粘咖;(對(duì)RNN作用明顯)

  • InstanceNorm:一個(gè)batch,一個(gè)channel內(nèi)做歸一化侈百。計(jì)算HW的均值瓮下,用在風(fēng)格化遷移;(因?yàn)樵趫D像風(fēng)格化中钝域,生成結(jié)果主要依賴于某個(gè)圖像實(shí)例讽坏,所以對(duì)整個(gè)batch歸一化不適合圖像風(fēng)格化中,因而對(duì)HW做歸一化例证÷肺兀可以加速模型收斂,并且保持每個(gè)圖像實(shí)例之間的獨(dú)立。)

  • GroupNorm:將channel方向分group胀葱,然后每個(gè)group內(nèi)做歸一化漠秋,算(C//G)HW的均值;這樣與batchsize無關(guān)抵屿,不受其約束庆锦。

Normalization layers

6. Mask是什么?

mask顧名思義就是掩碼晌该,大概意思是對(duì)某些值進(jìn)行掩蓋肥荔,使其不產(chǎn)生效果.

需要說明的是绿渣,Transformer模型中有兩種mask朝群。分別是padding masksequence mask。其中中符,padding mask在所有的scaled dot-product attention里都需要用到姜胖,而sequence mask只在decoder的self-attention中用到。

所以淀散,我們之前的ScaledDotProductAttention的forward方法里的參數(shù)attn_mask在不同的地方有不同的含義右莱。

6.1 Padding mask

什么是padding mask呢?回想一下档插,我們的每個(gè)批次輸入序列長(zhǎng)度是不一樣的慢蜓!也就是說,我們要對(duì)輸入序列進(jìn)行對(duì)齊郭膛!具體來說晨抡,就是給較短序列后面填充0。因?yàn)檫@些填充位置则剃,其實(shí)沒有意義耘柱,所以我們的attention機(jī)制不應(yīng)該把注意力放在這些位置上,所以我們需要進(jìn)行一些處理棍现。

具體做法是:把這些位置的值加上一個(gè)非常大的負(fù)數(shù)(可以是負(fù)無窮)调煎,這樣的話,經(jīng)過softmax己肮,這些位置的概率就會(huì)接近0士袄。

而我們的padding mask實(shí)際上是一個(gè)張量,每個(gè)值都是一個(gè)Boolean谎僻,值為False的地方就是我們要進(jìn)行處理的地方窖剑。

下面是代碼實(shí)現(xiàn):

def padding_mask(seq_q, seq_k):
    # seq_k和seq_q的形狀都是[B,L]
    len_q = seq_q.size(1)
    # `PAD` is 0
    pad_mask = seq_k.eq(0)
    pad_mask = pad_mask.unsqueeze(1).expand(-1, len_q, -1) # shape [B,L_q,L_k]

[B,L]->[B,1,L]->[B,L,L]

F F T T
F F T T
F F T T
F F T T

6.2 Sequence mask

sequence mask是為了使得decoder不能看到未來的信息。也就是對(duì)于一個(gè)序列戈稿,在time step為t的時(shí)刻西土,我們的解碼輸出只能依賴于t時(shí)刻之前的輸出,而不能依賴t之后的輸出鞍盗。因此我們需要想一個(gè)辦法需了,把t之后的信息給隱藏起來跳昼。

那具體如何做呢?也很簡(jiǎn)單:產(chǎn)生一個(gè)上三角矩陣肋乍,上三角矩陣的值全為1鹅颊,下三角的值全為0,對(duì)角線值也為0墓造。把這個(gè)矩陣作用在每一個(gè)序列上堪伍,就可以達(dá)到我們的目的。

具體代碼如下:

def sequence_mask(seq):
    batch_size, seq_len = seq.size()
    mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),
                    diagonal=1)
    mask = mask.unsqueeze(0).expand(batch_size, -1, -1)  # [B, L, L]
    return mask

[B,L,L]

0 1 1 1
0 0 1 1
0 0 0 1
0 0 0 0

哈佛大學(xué)的文章The Annotated Transformer有一張效果圖:

image.png

值得注意的是觅闽,本來mask只需要二維矩陣即可帝雇,但是考慮到我們的輸入序列都是批量的,所以我們需要把原本二維矩陣擴(kuò)張成3維張量蛉拙。上面代碼中尸闸,已經(jīng)做了處理。

回到本節(jié)開始的問題孕锄,attn_mask參數(shù)有幾種情況吮廉?分別是什么意思?

  • 對(duì)于decoder的self-attention畸肆,里面使用的scaled dot-product attention宦芦,同時(shí)需要padding masksequence mask作為attn_mask,具體實(shí)現(xiàn)就是兩個(gè)mask相加作為attn_mask轴脐。
  • 其它情況调卑,attn_mask都等于padding mask

7. Positional encoding是什么豁辉?

就目前而言令野,Transformer架構(gòu)似乎少了點(diǎn)東西。沒錯(cuò)徽级,那就是它對(duì)序列的順序沒有約束气破!我們知道序列的順序是一個(gè)很重要的信息,如果缺失了這個(gè)信息餐抢,可能我們的結(jié)果就是:所有詞語都對(duì)了现使,但是無法組成有意義的語句。

為了解決這個(gè)問題旷痕,論文中提出了positional encoding碳锈。一句話概括就是:對(duì)序列中的詞語出現(xiàn)的位置進(jìn)行編碼!如果對(duì)位置進(jìn)行編碼欺抗,那么我們的模型就可以捕捉順序信息售碳。

那么具體怎么做呢?論文的實(shí)現(xiàn)是使用正余弦函數(shù)。公式如下:
PF(pos,2i)=sin(pos/10000^{2i/d_{model}})\tag{7.1}

PF(pos,2i+1)=cos(pos/10000^{2i/d_{model}})\tag{7.2}

其中贸人,pos是指詞語在序列中的位置间景。可以看出艺智,在偶數(shù)位置倘要,使用正弦編碼悉稠,在奇數(shù)位置吊输,使用余弦編碼蜒蕾。

上面公式中的d_{model}是模型的維度离钝,論文默認(rèn)是512

這個(gè)編碼公式的意思就是:給定詞語的位置pos倡蝙,我們可以把它編碼成d_{model}維的向量苟穆!也就是說妓忍,位置編碼的每一個(gè)維度對(duì)應(yīng)正弦曲線甲喝,波長(zhǎng)構(gòu)成了從2\pi10000*2\pi的等比序列尝苇。

Postional encoding是對(duì)詞匯的位置編碼铛只。

7.1 Positional encoding代碼實(shí)現(xiàn)

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_seq_len):
        """
        初始化

        args:
            d_model: 一個(gè)標(biāo)量埠胖。模型的維度,論文默認(rèn)是512
            max_seq_len: 一個(gè)標(biāo)量淳玩。文本序列的最大長(zhǎng)度
        """
        super(PositionalEncoding, self).__init__()

        # 根據(jù)論文給出的公式直撤,構(gòu)造出PE矩陣
        position_encoding = np.array([
            [pos / np.pow(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)]
            for pos in range(max_seq_len)
        ])
        # 偶數(shù)列使用sin,奇數(shù)列使用cos
        position_encoding[:, 0::2] = np.sin(position_encoding[:, 0::2])
        position_encoding[:, 1::2] = np.cos(position_encoding[:, 1::2])

        # 在PE矩陣的一次行蜕着,加上一個(gè)全是0的向量谋竖,代表這`PAD`的positional_encoding
        # 在word embedding中也會(huì)經(jīng)常加上`UNK`,代表位置單詞的word embedding承匣,兩者十分類似
        # 那么為什么需要這個(gè)額外的PAD的編碼呢蓖乘?很簡(jiǎn)單,因?yàn)槲谋拘蛄械拈L(zhǎng)度不易韧骗,我們需要對(duì)齊嘉抒,
        # 短的序列我們使用0在結(jié)尾不全,我們也需要這些補(bǔ)全位置的編碼袍暴,也就是`PAD`對(duì)應(yīng)的位置編碼
        pad_row = torch.zeros([1, d_model])
        position_encoding = torch.cat((pad_row, position_encoding))

        # 嵌入操作些侍,+1是因?yàn)樵黾恿薫PAD`這個(gè)補(bǔ)全位置的編碼
        # word embedding中如果詞典增加`UNK`,我們也需要+1政模。
        self.position_encoding = nn.Embedding(max_seq_len+1, d_model)
        self.position_encoding.weight = nn.Parameter(position_encoding, requires_grad=False)

    def forward(self, input_len):
        """
        神經(jīng)網(wǎng)絡(luò)前向傳播

        args:
            input_len: 一個(gè)張量岗宣,形狀為[BATCH_SIZE, 1]。每一個(gè)張量的值代表這一批文本序列中對(duì)應(yīng)的長(zhǎng)度淋样。

        returns:
            返回這一批序列的位置編碼耗式,進(jìn)行了對(duì)齊。
        """

        # 找出這一批序列的最大長(zhǎng)度
        max_len = torch.max(input_len)
        # 對(duì)每一個(gè)序列的位置進(jìn)行對(duì)齊,在原序列位置的后面補(bǔ)上0
        # 這里range從1開始也是因?yàn)橐荛_PAD(0)的位置
        input_pos = torch.LongTensor(
            [list(range(1, len+1)) + [0] * (max_len-len) for len in input_len]
        )
        return self.position_encoding(input_pos)

8. Word embedding是什么刊咳?

Word embedding是對(duì)序列中的詞匯的編碼措嵌,把每一個(gè)詞匯編碼成d_{model}維的向量!它實(shí)際上就是一個(gè)二維浮點(diǎn)矩陣芦缰,里面的權(quán)重是可訓(xùn)練參數(shù)企巢,我們只需要把這個(gè)矩陣構(gòu)建出來就完成了word embedding的工作。

embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0)

上面vocab_size是詞典大小让蕾,embedding_size是詞嵌入的維度大小浪规,論文里面就是等于d_{model}=512。所以word embedding矩陣就是一個(gè)vocab_size*embedding_size的二維張量探孝。

9. Position-wise Feed-Forward netword是什么笋婿?

這是一個(gè)全連接網(wǎng)絡(luò)顿颅,包含連個(gè)線性變換和一個(gè)非線性函數(shù)(ReLU)。公式如下:
FFN(x)=max(0,xW_1+b_1)W2+b2\tag{9.1}

這個(gè)線性變換在不同的位置都是一樣的粱腻,并且在不同的層之間使用不同的參數(shù)庇配。

論文提到,這個(gè)公式還可以用兩個(gè)核大小為1的一維卷積來解釋绍些,卷積的輸入輸出都是d_{model}=512捞慌,中間層維度是d_{ff}=2048

代碼如下:

class PositionalWiseFeedForward(nn.Module):

    def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0):
        super(PositionalWiseFeedForward, self).__init__()
        self.w1 = nn.Conv1d(model_dim, ffn_dim, 1)
        self.w2 = nn.Conv2d(model_dim, ffn_dim, 1)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(self, x):
        output = x.transpose(1, 2)
        output = self.w2(F.relu(self.w1(output)))
        output = self.dropout(output.transpose(1, 2))

        # add residual and norm layer
        output = self.layer_norm(x + output)
        return output

10. 完整代碼

至此柬批,所有的細(xì)節(jié)都解釋完了啸澡。

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    """
    Scaled dot-product attention mechanism.
    """

    def __init__(self, attention_dropout=0.0):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, scale=None, attn_mask=None):
        """
        前向傳播

        args:
            q: Queries張量,形狀[B, L_q, D_q]
            k: keys張量氮帐, 形狀[B, L_k, D_k]
            v: Values張量嗅虏,形狀[B, L_v, D_v]
            scale: 縮放因子,一個(gè)浮點(diǎn)標(biāo)量
            attn_mask: Masking張量上沐,形狀[B, L_q, L_k]
        returns:
            上下文張量和attention張量
        """
        attention = torch.bmm(q, k.transpose(1, 2))
        if scale:
            attention = attention * scale
        if attn_mask:
            # 給需要mask的地方設(shè)置一個(gè)負(fù)無窮
            attention = attention.masked_fill_(attn_mask, -np.inf)
        # 計(jì)算softmax
        attention = self.softmax(attention)
        # 添加dropout
        attention = self.dropout(attention)
        # 和V做點(diǎn)積
        context = torch.bmm(attention, v)

        return context, attention

class MultiHeadAttention(nn.Module):

    def __init__(self, model_dim=512, num_heads=8, dropout=0.0):
        super(MultiHeadAttention, self).__init__()

        self.dim_per_head = model_dim / num_heads
        self.num_heads = num_heads
        self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads)
        self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads)

        self.dot_product_attention = ScaledDotProductAttention(dropout)
        self.linear_final = nn.Linear(model_dim, model_dim)
        self.dropout = nn.Dropout(dropout)
        # multi-head attention之后需要做layer norm
        self.layer_num = nn.LayerNorm(model_dim)

    def forward(self, query, key, value, attn_mask=None):
        # 殘差連接
        residual = query

        batch_size = key.size(0)

        # linear projection
        query = self.linear_q(query) # [B, L, D]
        key = self.linear_k(key) # [B, L, D]
        value = self.linear_v(value) # [B, L, D]

        # split by head
        query = query.view(batch_size * num_heads, -1, dim_per_head) # [B * 8, , D / 8]
        key = key.view(batch_size * num_heads, -1, dim_per_head) # 
        value = value.view(batch_size * num_heads, -1, dim_per_head)

        if attn_mask:
            attn_mask = attn_mask.repeat(num_heads, 1, 1)
        # scaled dot product attention
        scale = (key.size(-1) // num_heads) ** -0.5
        context, attention = self.dot_product_attention(
            query, key, value, scale, attn_mask
        ) 

        # concat heads
        context = context.view(batch_size, -1, dim_per_head * num_heads)
        
        # final linear projection
        output = self.linear_final(context)

        # dropout
        output = self.dropout(output)

        # add residual and norm layer
        output = self.layer_num(residual + output)

        return output, attention

def padding_mask(seq_q, seq_k):
    # seq_k和seq_q的形狀都是[B,L]
    len_q = seq_q.size(1)
    # `PAD` is 0
    pad_mask = seq_k.eq(0)
    pad_mask = pad_mask.unsqueeze(1).expand(-1, len_q, -1) # shape [B,L_q,L_k]

def sequence_mask(seq):
    batch_size, seq_len = seq.size()
    mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),
                    diagonal=1)
    mask = mask.unsqueeze(0).expand(batch_size, -1, -1)  # [B, L, L]
    return mask

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_seq_len):
        """
        初始化

        args:
            d_model: 一個(gè)標(biāo)量皮服。模型的維度,論文默認(rèn)是512
            max_seq_len: 一個(gè)標(biāo)量奄容。文本序列的最大長(zhǎng)度
        """
        super(PositionalEncoding, self).__init__()

        # 根據(jù)論文給出的公式冰更,構(gòu)造出PE矩陣
        position_encoding = np.array([
            [pos / np.pow(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)]
            for pos in range(max_seq_len)
        ])
        # 偶數(shù)列使用sin,奇數(shù)列使用cos
        position_encoding[:, 0::2] = np.sin(position_encoding[:, 0::2])
        position_encoding[:, 1::2] = np.cos(position_encoding[:, 1::2])

        # 在PE矩陣的一次行昂勒,加上一個(gè)全是0的向量蜀细,代表這`PAD`的positional_encoding
        # 在word embedding中也會(huì)經(jīng)常加上`UNK`,代表位置單詞的word embedding戈盈,兩者十分類似
        # 那么為什么需要這個(gè)額外的PAD的編碼呢奠衔?很簡(jiǎn)單谆刨,因?yàn)槲谋拘蛄械拈L(zhǎng)度不易,我們需要對(duì)齊归斤,
        # 短的序列我們使用0在結(jié)尾不全痊夭,我們也需要這些補(bǔ)全位置的編碼,也就是`PAD`對(duì)應(yīng)的位置編碼
        pad_row = torch.zeros([1, d_model])
        position_encoding = torch.cat((pad_row, position_encoding))

        # 嵌入操作脏里,+1是因?yàn)樵黾恿薫PAD`這個(gè)補(bǔ)全位置的編碼
        # word embedding中如果詞典增加`UNK`她我,我們也需要+1。
        self.position_encoding = nn.Embedding(max_seq_len+1, d_model)
        self.position_encoding.weight = nn.Parameter(position_encoding, requires_grad=False)

    def forward(self, input_len):
        """
        神經(jīng)網(wǎng)絡(luò)前向傳播

        args:
            input_len: 一個(gè)張量迫横,形狀為[BATCH_SIZE, 1]番舆。每一個(gè)張量的值代表這一批文本序列中對(duì)應(yīng)的長(zhǎng)度。

        returns:
            返回這一批序列的位置編碼矾踱,進(jìn)行了對(duì)齊恨狈。
        """

        # 找出這一批序列的最大長(zhǎng)度
        max_len = torch.max(input_len)
        # 對(duì)每一個(gè)序列的位置進(jìn)行對(duì)齊,在原序列位置的后面補(bǔ)上0
        # 這里range從1開始也是因?yàn)橐荛_PAD(0)的位置
        input_pos = torch.LongTensor(
            [list(range(1, len+1)) + [0] * (max_len-len) for len in input_len]
        )
        return self.position_encoding(input_pos)

# embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
# 獲得輸入的詞嵌入編碼
# seq_embedding = seq_embedding(inputs) * np.sqrt(d_model)

class PositionalWiseFeedForward(nn.Module):

    def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0):
        super(PositionalWiseFeedForward, self).__init__()
        self.w1 = nn.Conv1d(model_dim, ffn_dim, 1)
        self.w2 = nn.Conv2d(model_dim, ffn_dim, 1)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(model_dim)

    def forward(self, x):
        output = x.transpose(1, 2)
        output = self.w2(F.relu(self.w1(output)))
        output = self.dropout(output.transpose(1, 2))

        # add residual and norm layer
        output = self.layer_norm(x + output)
        return output

class EncoderLayer(nn.Module):
    """Encoder的一層呛讲。"""
    def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0):
        super(EncoderLayer, self).__init__()

        self.attention = MultiHeadAttention(model_dim, num_heads, dropout)
        self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout)

    def forward(self, inputs, attn_mask=None):
        # self attention
        context, attention = self.attention(inputs, inputs, inputs, attn_mask)

        # feed forward network
        output = self.feed_forward(context)

        return output, attention


class Encoder(nn.Module):
    """多層EncoderLayer組成的Encoder"""
    def __init__(self,
                vocab_size,
                num_layers=6,
                model_dim=512,
                num_heads=8,
                ffn_dim=2048,
                dropout=0.0):
        super(Encoder, self).__init__()

        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)]
        )

        self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0)
        self.pos_embedding = PositionalEncoding(model_dim, max_seq_len)

    def forward(self, inputs, inputs_len):
        output = self.seq_embedding(inputs)
        output += self.pos_embedding(inputs_len)

        self_attention_mask = padding_mask(inputs, inputs)

        attentions = []
        for encoder in self.encoder_layers:
            output, attention = encoder(output, self_attention_mask)
            attentions.append(attention)

        return output, attentions

class DecoderLayer(nn.Module):
    def __init__(self, model_dim, num_heads=8, ffn_dim=2048, dropout=0.0):
        super(DecoderLayer, self).__init__()

        self.attention = MultiHeadAttention(model_dim, num_heads, dropout)
        self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout)

    def forward(self,
                dec_inputs,
                enc_outputs,
                self_attn_mask=None,
                context_attn_mask=None):
        # self attention, all inputs are decoder inputs
        dec_output, self_attention = self.attention(dec_inputs, dec_inputs, dec_inputs, self_attn_mask)

        # context attention
        # query is decoder's outputs, key and value are encoder's inputs
        dec_output, context_attention = self.attention(dec_output, enc_outputs, enc_outputs, context_attn_mask)

        # decoder's output, or context
        dec_output = self.feed_forward(dec_output)

        return dec_output, self_attention, context_attention

class Decoder(nn.Module):
    def __init__(self,
                vocab_size,
                max_seq_len,
                num_layers=6,
                model_dim=512,
                num_heads=8,
                ffn_dim=2048,
                dropout=0.0):
        super(Decoder).__init__()

        self.num_layers = num_layers

        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)]
        )
        
        self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0)
        self.pos_embedding = PositionalEncoding(model_dim, max_seq_len)

    def forward(self, inputs, inputs_len, enc_output, context_attn_mask=None):
        output = self.seq_embedding(inputs)
        output += self.pos_embedding(inputs_len)

        self_attention_padding_mask = padding_mask(inputs, inputs)
        seq_mask = sequence_mask(inputs)
        self_attn_mask = torch.gt((self_attention_padding_mask + seq_mask), 0)

        self_attentions = []
        context_attentions = []
        for decoder in self.decoder_layers:
            output, self_attn, context_attn = decoder(
            output, enc_output, self_attn_mask, context_attn_mask)
            self_attentions.append(self_attn)
            context_attentions.append(context_attn)

        return output, self_attentions, context_attentions

    
class Transformer(nn.Module):
    def __init__(self,
                src_vocab_size,
                src_max_len,
                tgt_vocab_size,
                tgt_max_len,
                num_layers=6,
                model_dim=512,
                num_heads=8,
                ffn_dim=2048,
                dropout=0.0):
        super(Transformer).__init__()

        self.encoder = Encoder(src_vocab_size, src_max_len, num_layers, model_dim, num_heads, ffn_dim, dropout)
        self.decoder = Decoder(tgt_vocab_size, tgt_max_len, num_layers, model_dim, num_heads, ffn_dim, dropout)

        self.linear = nn.Linear(model_dim, tgt_vocab_size, bias=False)
        self.softmax = nn.Softmax()

    def forward(self, src_seq, src_len, tgt_seq, tgt_len):
        context_attn_mask = padding_mask(tgt_seq, src_seq)

        output, enc_self_attn = self.encoder(src_seq, src_len)

        output, dec_self_attn, ctx_attn = self.decoder(tgt_seq, tgt_len, output, context_attn_mask)

        output = self.linear(output)
        output = self.softmax(output)

        return output, enc_self_attn, dec_self_attn, ctx_attn
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末禾怠,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子贝搁,更是在濱河造成了極大的恐慌吗氏,老刑警劉巖,帶你破解...
    沈念sama閱讀 219,427評(píng)論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件徘公,死亡現(xiàn)場(chǎng)離奇詭異牲证,居然都是意外死亡哮针,警方通過查閱死者的電腦和手機(jī)关面,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,551評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來十厢,“玉大人等太,你說我怎么就攤上這事÷牛” “怎么了缩抡?”我有些...
    開封第一講書人閱讀 165,747評(píng)論 0 356
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)包颁。 經(jīng)常有香客問我瞻想,道長(zhǎng),這世上最難降的妖魔是什么娩嚼? 我笑而不...
    開封第一講書人閱讀 58,939評(píng)論 1 295
  • 正文 為了忘掉前任蘑险,我火速辦了婚禮,結(jié)果婚禮上岳悟,老公的妹妹穿的比我還像新娘佃迄。我一直安慰自己泼差,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,955評(píng)論 6 392
  • 文/花漫 我一把揭開白布呵俏。 她就那樣靜靜地躺著堆缘,像睡著了一般。 火紅的嫁衣襯著肌膚如雪普碎。 梳的紋絲不亂的頭發(fā)上吼肥,一...
    開封第一講書人閱讀 51,737評(píng)論 1 305
  • 那天,我揣著相機(jī)與錄音麻车,去河邊找鬼潜沦。 笑死,一個(gè)胖子當(dāng)著我的面吹牛绪氛,可吹牛的內(nèi)容都是我干的唆鸡。 我是一名探鬼主播,決...
    沈念sama閱讀 40,448評(píng)論 3 420
  • 文/蒼蘭香墨 我猛地睜開眼枣察,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼争占!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起序目,我...
    開封第一講書人閱讀 39,352評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤臂痕,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后猿涨,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體握童,經(jīng)...
    沈念sama閱讀 45,834評(píng)論 1 317
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,992評(píng)論 3 338
  • 正文 我和宋清朗相戀三年叛赚,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了澡绩。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 40,133評(píng)論 1 351
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡俺附,死狀恐怖肥卡,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情事镣,我是刑警寧澤步鉴,帶...
    沈念sama閱讀 35,815評(píng)論 5 346
  • 正文 年R本政府宣布,位于F島的核電站璃哟,受9級(jí)特大地震影響氛琢,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜随闪,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,477評(píng)論 3 331
  • 文/蒙蒙 一阳似、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧蕴掏,春花似錦障般、人聲如沸调鲸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,022評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)藐石。三九已至,卻和暖如春定拟,著一層夾襖步出監(jiān)牢的瞬間于微,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,147評(píng)論 1 272
  • 我被黑心中介騙來泰國(guó)打工青自, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留株依,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 48,398評(píng)論 3 373
  • 正文 我出身青樓延窜,卻偏偏與公主長(zhǎng)得像恋腕,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子逆瑞,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,077評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容