一步步解析Attention is All You Need!

本文將通過細(xì)節(jié)剖析以及代碼相結(jié)合的方式,來一步步解析Attention is all you need這篇文章。

這篇文章的下載地址為:https://arxiv.org/abs/1706.03762

本文的部分圖片來自文章:https://mp.weixin.qq.com/s/RLxWevVWHXgX-UcoxDS70w筋讨,寫的非常好!

本文邊講細(xì)節(jié)邊配合代碼實(shí)戰(zhàn)摸恍,代碼地址為:https://github.com/princewen/tensorflow_practice/tree/master/basic/Basic-Transformer-Demo

數(shù)據(jù)地址為:https://pan.baidu.com/s/14XfprCqjmBKde9NmNZeCNg 密碼:lfwu

好了悉罕,廢話不多說,我們進(jìn)入正題误墓!我們從簡單到復(fù)雜蛮粮,一步步介紹該模型的結(jié)構(gòu)!

1谜慌、整體架構(gòu)

模型的整體框架如下:

整體架構(gòu)看似復(fù)雜然想,其實(shí)就是一個(gè)Seq2Seq結(jié)構(gòu),簡化一下欣范,就是這樣的:

Encoder的輸出和decoder的結(jié)合如下变泄,即最后一個(gè)encoder的輸出將和每一層的decoder進(jìn)行結(jié)合:

好了,我們主要關(guān)注的是每一層Encoder和每一層Decoder的內(nèi)部結(jié)構(gòu)恼琼。如下圖所示:

可以看到妨蛹,Encoder的每一層有兩個(gè)操作,分別是Self-Attention和Feed Forward晴竞;而Decoder的每一層有三個(gè)操作蛙卤,分別是Self-Attention、Encoder-Decoder Attention以及Feed Forward操作噩死。這里的Self-Attention和Encoder-Decoder Attention都是用的是Multi-Head Attention機(jī)制颤难,這也是我們本文重點(diǎn)講解的地方。

在介紹之前已维,我們先介紹下我們的數(shù)據(jù)行嗤,經(jīng)過處理之后,數(shù)據(jù)如下:

很簡單垛耳,上面部分是我們的x栅屏,也就是encoder的輸入飘千,下面部分是y,也就是decoder的輸入栈雳,這是一個(gè)機(jī)器翻譯的數(shù)據(jù)护奈,x中的每一個(gè)id代表一個(gè)語言中的單詞id,y中的每一個(gè)id代表另一種語言中的單詞id甫恩。后面為0的部分是填充部分逆济,代表這個(gè)句子的長度沒有達(dá)到我們設(shè)置的最大長度,進(jìn)行補(bǔ)齊磺箕。

2、Position Embedding

給定我們的輸入數(shù)據(jù)抛虫,我們首先要轉(zhuǎn)換成對(duì)應(yīng)的embedding松靡,由于我們后面要在計(jì)算attention時(shí)屏蔽掉填充的部分,所以這里我們對(duì)于填充的部分的embedding直接賦予0值建椰。Embedding的函數(shù)如下:

def embedding(inputs,
              vocab_size,
              num_units,
              zero_pad=True,
              scale=True,
              scope="embedding",
              reuse=None):
    with tf.variable_scope(scope, reuse=reuse):
        lookup_table = tf.get_variable('lookup_table',
                                       dtype=tf.float32,
                                       shape=[vocab_size, num_units],
                                       initializer=tf.contrib.layers.xavier_initializer())
        if zero_pad:
            lookup_table = tf.concat((tf.zeros(shape=[1, num_units]),
                                      lookup_table[1:, :]), 0)
        outputs = tf.nn.embedding_lookup(lookup_table, inputs)

        if scale:
            outputs = outputs * (num_units ** 0.5)

    return outputs

在本文中雕欺,Embedding操作不是普通的Embedding而是加入了位置信息的Embedding,我們稱之為Position Embedding棉姐。因?yàn)樵诒疚牡哪P椭型懒校呀?jīng)沒有了循環(huán)神經(jīng)網(wǎng)絡(luò)這樣的結(jié)構(gòu),因此序列信息已經(jīng)無法捕捉伞矩。但是序列信息非常重要笛洛,代表著全局的結(jié)構(gòu),因此必須將序列的分詞相對(duì)或者絕對(duì)position信息利用起來乃坤。位置信息的計(jì)算公式如下:

其中pos代表的是第幾個(gè)詞苛让,i代表embedding中的第幾維。這部分的代碼如下湿诊,對(duì)于padding的部分狱杰,我們還是使用全0處理。

def positional_encoding(inputs,
                        num_units,
                        zero_pad = True,
                        scale = True,
                        scope = "positional_encoding",
                        reuse=None):

    N,T = inputs.get_shape().as_list()
    with tf.variable_scope(scope,reuse=True):
        position_ind = tf.tile(tf.expand_dims(tf.range(T),0),[N,1])

        position_enc = np.array([
            [pos / np.power(10000, 2.*i / num_units) for i in range(num_units)]
            for pos in range(T)])

        position_enc[:,0::2] = np.sin(position_enc[:,0::2]) # dim 2i
        position_enc[:,1::2] = np.cos(position_enc[:,1::2]) # dim 2i+1

        lookup_table = tf.convert_to_tensor(position_enc)

        if zero_pad:
            lookup_table = tf.concat((tf.zeros(shape=[1,num_units]),lookup_table[1:,:]),0)

        outputs = tf.nn.embedding_lookup(lookup_table,position_ind)

        if scale:
            outputs = outputs * num_units ** 0.5

        return outputs

所以對(duì)于輸入厅须,我們調(diào)用上面兩個(gè)函數(shù)仿畸,并將結(jié)果相加就能得到最終Position Embedding的結(jié)果:

self.enc = embedding(self.x,
                     vocab_size=len(de2idx),
                     num_units = hp.hidden_units,
                     zero_pad=True, # 讓padding一直是0
                     scale=True,
                     scope="enc_embed")
self.enc += embedding(tf.tile(tf.expand_dims(tf.range(tf.shape(self.x)[1]),0),[tf.shape(self.x)[0],1]),
                      vocab_size = hp.maxlen,
                      num_units = hp.hidden_units,
                      zero_pad = False,
                      scale = False,
                      scope = "enc_pe")

3、Multi-Head Attention

3.1 Attention簡單回顧

Attention其實(shí)就是計(jì)算一種相關(guān)程度朗和,看下面的例子:

Attention通炒砉粒可以進(jìn)行如下描述,表示為將query(Q)和key-value pairs映射到輸出上例隆,其中query甥捺、每個(gè)key、每個(gè)value都是向量镀层,輸出是V中所有values的加權(quán)镰禾,其中權(quán)重是由Query和每個(gè)key計(jì)算出來的皿曲,計(jì)算方法分為三步:

1)計(jì)算比較Q和K的相似度,用f來表示:

2)將得到的相似度進(jìn)行softmax歸一化:

3)針對(duì)計(jì)算出來的權(quán)重吴侦,對(duì)所有的values進(jìn)行加權(quán)求和屋休,得到Attention向量:

計(jì)算相似度的方法有以下4種:

在本文中,我們計(jì)算相似度的方式是第一種备韧,本文提出的Attention機(jī)制稱為Multi-Head Attention劫樟,不過在這之前,我們要先介紹它的簡單版本 Scaled Dot-Product Attention织堂。

計(jì)算Attention首先要有query叠艳,key和value。我們前面提到了易阳,Encoder的attention是self-attention附较,Decoder里面的attention首先是self-attention,然后是encoder-decoder attention潦俺。這里的兩種attention是針對(duì)query和key-value來說的拒课,對(duì)于self-attention來說,計(jì)算得到query和key-value的過程都是使用的同樣的輸入事示,因?yàn)橐阕约焊约旱腶ttention嘛早像;而對(duì)encoder-decoder attention來說,query的計(jì)算使用的是decoder的輸入肖爵,而key-value的計(jì)算使用的是encoder的輸出卢鹦,因?yàn)槲覀円?jì)算decoder的輸入跟encoder里面每一個(gè)的相似度嘛。

因此本文下面對(duì)于attention的講解遏匆,都是基于self-attention來說的法挨,如果是encoder-decoder attention,只要改一下輸入即可幅聘,其余過程都是一樣的凡纳。

3.2 Scaled Dot-Product Attention

Scaled Dot-Product Attention的圖示如下:

接下來,我們對(duì)上述過程進(jìn)行一步步的拆解:

First Step-得到embedding

給定我們的輸入數(shù)據(jù)帝蒿,我們首先要轉(zhuǎn)換成對(duì)應(yīng)的position embedding荐糜,效果圖如下,綠色部分代表填充部分葛超,全0值:

得到Embedding的過程我們上文中已經(jīng)介紹過了暴氏,這里不再贅述。

Second Step-得到Q绣张,K答渔,V

計(jì)算Attention首先要有Query,Key和Value侥涵,我們通過一個(gè)線性變換來得到三者沼撕。我們的輸入是position embedding宋雏,過程如下:

代碼也很簡單怠肋,下面的代碼中膏执,如果是self-attention的話,query和key-value輸入的embedding是一樣的朋贬。padding的部分由于都是0笼沥,結(jié)果中該部分還是0蚪燕,所以仍然用綠色表示

# Linear projection
Q = tf.layers.dense(queries,num_units,activation=tf.nn.relu) #
K = tf.layers.dense(keys,num_units,activation=tf.nn.relu) #
V = tf.layers.dense(keys,num_units,activation=tf.nn.relu) #

Third-Step-計(jì)算相似度

接下來就是計(jì)算相似度了,我們之前說過了奔浅,本文中使用的是點(diǎn)乘的方式馆纳,所以將Q和K進(jìn)行點(diǎn)乘即可,過程如下:

文中對(duì)于相似度還除以了dk的平方根汹桦,這里dk是key的embedding長度厕诡。

這一部分的代碼如下:

outputs = tf.matmul(Q,tf.transpose(K,[0,2,1]))
outputs = outputs / (K.get_shape().as_list()[-1] ** 0.5)

你可能注意到了,這樣做其實(shí)是得到了一個(gè)注意力的矩陣营勤,每一行都是一個(gè)query和所有key的相似性,對(duì)self-attention來說壹罚,其效果如下:

不過我們還沒有進(jìn)行softmax歸一化操作葛作,因?yàn)槲覀冞€需要進(jìn)行一些處理。

Forth-Step-增加mask

剛剛得到的注意力矩陣猖凛,我們還需要做一下處理赂蠢,主要有:

  1. query和key有些部分是填充的,這些需要用mask屏蔽辨泳,一個(gè)簡單的方法就是賦予一個(gè)很小很小的值或者直接變?yōu)?值虱岂。
  2. 對(duì)于decoder的來說,我們是不能看到未來的信息的菠红,所以對(duì)于decoder的輸入第岖,我們只能計(jì)算它和它之前輸入的信息的相似度。

我們首先對(duì)key中填充的部分進(jìn)行屏蔽试溯,我們之前介紹了蔑滓,在進(jìn)行embedding時(shí),填充的部分的embedding 直接設(shè)置為全0遇绞,所以我們直接根據(jù)這個(gè)來進(jìn)行屏蔽键袱,即對(duì)embedding的向量所有維度相加得到一個(gè)標(biāo)量,如果標(biāo)量是0摹闽,那就代表是填充的部分蹄咖,否則不是:

這部分的代碼如下:

key_masks = tf.sign(tf.abs(tf.reduce_sum(keys,axis=-1)))
key_masks = tf.tile(tf.expand_dims(key_masks,1),[1,tf.shape(queries)[1],1])
paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
outputs = tf.where(tf.equal(key_masks,0),paddings,outputs)

經(jīng)過這一步處理,效果如下付鹿,我們下圖中用深灰色代表屏蔽掉的部分:

接下來的操作只針對(duì)Decoder的self-attention來說澜汤,我們首先得到一個(gè)下三角矩陣蚜迅,這個(gè)矩陣主對(duì)角線以及下方的部分是1,其余部分是0银亲,然后根據(jù)1或者0來選擇使用output還是很小的數(shù)進(jìn)行填充:

diag_vals = tf.ones_like(outputs[0,:,:])
tril = tf.contrib.linalg.LinearOperatorTriL(diag_vals).to_dense()
masks = tf.tile(tf.expand_dims(tril,0),[tf.shape(outputs)[0],1,1])

paddings = tf.ones_like(masks) * (-2 ** 32 + 1)
outputs = tf.where(tf.equal(masks,0),paddings,outputs)

得到的效果如下圖所示:

接下來慢叨,我們對(duì)query的部分進(jìn)行屏蔽,與屏蔽key的思路大致相同务蝠,不過我們這里不是用很小的值替換了拍谐,而是直接把填充的部分變?yōu)?:

query_masks = tf.sign(tf.abs(tf.reduce_sum(queries,axis=-1)))
query_masks = tf.tile(tf.expand_dims(query_masks,-1),[1,1,tf.shape(keys)[1]])
outputs *= query_masks

經(jīng)過這一步,Encoder和Decoder得到的最終的相似度矩陣如下馏段,上邊是Encoder的結(jié)果轩拨,下邊是Decoder的結(jié)果:

接下來,我們就可以進(jìn)行softmax操作了:

outputs = tf.nn.softmax(outputs)

Fifth-Step-得到最終結(jié)果

得到了Attention的相似度矩陣院喜,我們就可以和Value進(jìn)行相乘亡蓉,得到經(jīng)過attention加權(quán)的結(jié)果:

這一部分是一個(gè)簡單的矩陣相乘運(yùn)算,代碼如下:

outputs = tf.matmul(outputs,V)

不過這并不是最終的結(jié)果,這里文中還加入了殘差網(wǎng)絡(luò)的結(jié)構(gòu)喷舀,即將最終的結(jié)果和queries的輸入進(jìn)行相加:

outputs += queries

所以一個(gè)完整的Scaled Dot-Product Attention的代碼如下:

def scaled_dotproduct_attention(queries,keys,num_units=None,
                        num_heads = 0,
                        dropout_rate = 0,
                        is_training = True,
                        causality = False,
                        scope = "mulithead_attention",
                        reuse = None):
    with tf.variable_scope(scope,reuse=reuse):
        if num_units is None:
            num_units = queries.get_shape().as_list[-1]

        # Linear projection
        Q = tf.layers.dense(queries,num_units,activation=tf.nn.relu) #
        K = tf.layers.dense(keys,num_units,activation=tf.nn.relu) #
        V = tf.layers.dense(keys,num_units,activation=tf.nn.relu) #

        outputs = tf.matmul(Q,tf.transpose(K,[0,2,1]))
        outputs = outputs / (K.get_shape().as_list()[-1] ** 0.5)

        # 這里是對(duì)填充的部分進(jìn)行一個(gè)mask砍濒,這些位置的attention score變?yōu)闃O小,我們的embedding操作中是有一個(gè)padding操作的硫麻,
        # 填充的部分其embedding都是0爸邢,加起來也是0,我們就會(huì)填充一個(gè)很小的數(shù)拿愧。
        key_masks = tf.sign(tf.abs(tf.reduce_sum(keys,axis=-1)))
        key_masks = tf.tile(tf.expand_dims(key_masks,1),[1,tf.shape(queries)[1],1])

        paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
        outputs = tf.where(tf.equal(key_masks,0),paddings,outputs)

        # 這里其實(shí)就是進(jìn)行一個(gè)mask操作杠河,不給模型看到未來的信息。
        if causality:
            diag_vals = tf.ones_like(outputs[0,:,:])
            tril = tf.contrib.linalg.LinearOperatorTriL(diag_vals).to_dense()
            masks = tf.tile(tf.expand_dims(tril,0),[tf.shape(outputs)[0],1,1])

            paddings = tf.ones_like(masks) * (-2 ** 32 + 1)
            outputs = tf.where(tf.equal(masks,0),paddings,outputs)

        outputs = tf.nn.softmax(outputs)
        # Query Mask
        query_masks = tf.sign(tf.abs(tf.reduce_sum(queries,axis=-1)))
        query_masks = tf.tile(tf.expand_dims(query_masks,-1),[1,1,tf.shape(keys)[1]])
        outputs *= query_masks
        # Dropout
        outputs = tf.layers.dropout(outputs,rate = dropout_rate,training = tf.convert_to_tensor(is_training))
        # Weighted sum
        outputs = tf.matmul(outputs,V)
        # Residual connection
        outputs += queries
        # Normalize
        outputs = normalize(outputs)

    return outputs

3.3 Multi-Head Attention

Multi-Head Attention就是把Scaled Dot-Product Attention的過程做H次浇辜,然后把輸出合起來券敌。論文中,它的結(jié)構(gòu)圖如下:

這部分的示意圖如下所示柳洋,我們重復(fù)做3次相似的操作待诅,得到每一個(gè)的結(jié)果矩陣,隨后將結(jié)果矩陣進(jìn)行拼接膳灶,再經(jīng)過一次的線性操作咱士,得到最終的結(jié)果:

Scaled Dot-Product Attention可以看作是只有一個(gè)Head的Multi-Head Attention,這部分的代碼跟Scaled Dot-Product Attention大同小異轧钓,我們直接貼出:

def multihead_attention(queries,keys,num_units=None,
                        num_heads = 0,
                        dropout_rate = 0,
                        is_training = True,
                        causality = False,
                        scope = "mulithead_attention",
                        reuse = None):
    with tf.variable_scope(scope,reuse=reuse):
        if num_units is None:
            num_units = queries.get_shape().as_list[-1]

        # Linear projection
        Q = tf.layers.dense(queries,num_units,activation=tf.nn.relu) #
        K = tf.layers.dense(keys,num_units,activation=tf.nn.relu) #
        V = tf.layers.dense(keys,num_units,activation=tf.nn.relu) #

        # Split and Concat
        Q_ = tf.concat(tf.split(Q,num_heads,axis=2),axis=0) #
        K_ = tf.concat(tf.split(K,num_heads,axis=2),axis=0)
        V_ = tf.concat(tf.split(V,num_heads,axis=2),axis=0)

        outputs = tf.matmul(Q_,tf.transpose(K_,[0,2,1]))
        outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)

        # 這里是對(duì)填充的部分進(jìn)行一個(gè)mask序厉,這些位置的attention score變?yōu)闃O小,我們的embedding操作中是有一個(gè)padding操作的毕箍,
        # 填充的部分其embedding都是0弛房,加起來也是0,我們就會(huì)填充一個(gè)很小的數(shù)而柑。
        key_masks = tf.sign(tf.abs(tf.reduce_sum(keys,axis=-1)))
        key_masks = tf.tile(key_masks,[num_heads,1])
        key_masks = tf.tile(tf.expand_dims(key_masks,1),[1,tf.shape(queries)[1],1])

        paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
        outputs = tf.where(tf.equal(key_masks,0),paddings,outputs)

        # 這里其實(shí)就是進(jìn)行一個(gè)mask操作文捶,不給模型看到未來的信息荷逞。
        if causality:
            diag_vals = tf.ones_like(outputs[0,:,:])
            tril = tf.contrib.linalg.LinearOperatorTriL(diag_vals).to_dense()
            masks = tf.tile(tf.expand_dims(tril,0),[tf.shape(outputs)[0],1,1])

            paddings = tf.ones_like(masks) * (-2 ** 32 + 1)
            outputs = tf.where(tf.equal(masks,0),paddings,outputs)

        outputs = tf.nn.softmax(outputs)

        # Query Mask
        query_masks = tf.sign(tf.abs(tf.reduce_sum(queries,axis=-1)))
        query_masks = tf.tile(query_masks,[num_heads,1])
        query_masks = tf.tile(tf.expand_dims(query_masks,-1),[1,1,tf.shape(keys)[1]])
        outputs *= query_masks

        # Dropout
        outputs = tf.layers.dropout(outputs,rate = dropout_rate,training = tf.convert_to_tensor(is_training))

        # Weighted sum
        outputs = tf.matmul(outputs,V_)
        # restore shape
        outputs = tf.concat(tf.split(outputs,num_heads,axis=0),axis=2)
        # Residual connection
        outputs += queries
        # Normalize
        outputs = normalize(outputs)
    return outputs

4、Position-wise Feed-forward Networks

在進(jìn)行了Attention操作之后粹排,encoder和decoder中的每一層都包含了一個(gè)全連接前向網(wǎng)絡(luò)种远,對(duì)每個(gè)position的向量分別進(jìn)行相同的操作,包括兩個(gè)線性變換和一個(gè)ReLU激活輸出:

代碼如下:

def feedforward(inputs,
                num_units=[2048, 512],
                scope="multihead_attention",
                reuse=None):
    with tf.variable_scope(scope, reuse=reuse):
        # Inner layer
        params = {"inputs": inputs, "filters": num_units[0], "kernel_size": 1,
                  "activation": tf.nn.relu, "use_bias": True}
        outputs = tf.layers.conv1d(**params)

        # Readout layer
        params = {"inputs": outputs, "filters": num_units[1], "kernel_size": 1,
                  "activation": None, "use_bias": True}
        outputs = tf.layers.conv1d(**params)
        # Residual connection
        outputs += inputs
        # Normalize
        outputs = normalize(outputs)
    return outputs

5顽耳、Encoder的結(jié)構(gòu)

Encoder有N(默認(rèn)是6)層坠敷,每層包括兩個(gè)sub-layers:
1 )第一個(gè)sub-layer是multi-head self-attention mechanism,用來計(jì)算輸入的self-attention;
2 )第二個(gè)sub-layer是簡單的全連接網(wǎng)絡(luò)射富。
每一個(gè)sub-layer都模擬了殘差網(wǎng)絡(luò)的結(jié)構(gòu)膝迎,其網(wǎng)絡(luò)示意圖如下:

根據(jù)我們剛才定義的函數(shù),其完整的代碼如下:

with tf.variable_scope("encoder"):
    # Embedding
    self.enc = embedding(self.x,
                         vocab_size=len(de2idx),
                         num_units = hp.hidden_units,
                         zero_pad=True, # 讓padding一直是0
                         scale=True,
                         scope="enc_embed")

    ## Positional Encoding
    if hp.sinusoid:
        self.enc += positional_encoding(self.x,
                                        num_units = hp.hidden_units,
                                        zero_pad = False,
                                        scale = False,
                                        scope='enc_pe')

    else:
        self.enc += embedding(tf.tile(tf.expand_dims(tf.range(tf.shape(self.x)[1]),0),[tf.shape(self.x)[0],1]),
                              vocab_size = hp.maxlen,
                              num_units = hp.hidden_units,
                              zero_pad = False,
                              scale = False,
                              scope = "enc_pe")

    ##Drop out
    self.enc = tf.layers.dropout(self.enc,rate = hp.dropout_rate,
                                 training = tf.convert_to_tensor(is_training))

    ## Blocks
    for i in range(hp.num_blocks):
        with tf.variable_scope("num_blocks_{}".format(i)):
            ### MultiHead Attention
            self.enc = multihead_attention(queries = self.enc,
                                           keys = self.enc,
                                           num_units = hp.hidden_units,
                                           num_heads = hp.num_heads,
                                           dropout_rate = hp.dropout_rate,
                                           is_training = is_training,
                                           causality = False
                                           )
            self.enc = feedforward(self.enc,num_units = [4 * hp.hidden_units,hp.hidden_units])

6胰耗、Decoder的結(jié)構(gòu)

Decoder有N(默認(rèn)是6)層限次,每層包括三個(gè)sub-layers:
1 )第一個(gè)是Masked multi-head self-attention,也是計(jì)算輸入的self-attention柴灯,但是因?yàn)槭巧蛇^程卖漫,因此在時(shí)刻 i 的時(shí)候,大于 i 的時(shí)刻都沒有結(jié)果赠群,只有小于 i 的時(shí)刻有結(jié)果懊亡,因此需要做Mask.
2 )第二個(gè)sub-layer是對(duì)encoder的輸入進(jìn)行attention計(jì)算,這里仍然是multi-head的attention結(jié)構(gòu)乎串,只不過輸入的分別是decoder的輸入和encoder的輸出。
3 )第三個(gè)sub-layer是全連接網(wǎng)絡(luò)速警,與Encoder相同叹誉。

其網(wǎng)絡(luò)示意圖如下:

其代碼如下:

with tf.variable_scope("decoder"):
    # Embedding
    self.dec = embedding(self.decoder_inputs,
                         vocab_size=len(en2idx),
                         num_units = hp.hidden_units,
                         scale=True,
                         scope="dec_embed")

    ## Positional Encoding
    if hp.sinusoid:
        self.dec += positional_encoding(self.decoder_inputs,
                                        vocab_size = hp.maxlen,
                                        num_units = hp.hidden_units,
                                        zero_pad = False,
                                        scale = False,
                                        scope = "dec_pe")
    else:
        self.dec += embedding(
            tf.tile(tf.expand_dims(tf.range(tf.shape(self.decoder_inputs)[1]), 0), [tf.shape(self.decoder_inputs)[0], 1]),
            vocab_size=hp.maxlen,
            num_units=hp.hidden_units,
            zero_pad=False,
            scale=False,
            scope="dec_pe")

    # Dropout
    self.dec = tf.layers.dropout(self.dec,
                                rate = hp.dropout_rate,
                                training = tf.convert_to_tensor(is_training))

    ## Blocks
    for i in range(hp.num_blocks):
        with tf.variable_scope("num_blocks_{}".format(i)):
            ## Multihead Attention ( self-attention)
            self.dec = multihead_attention(queries=self.dec,
                                           keys=self.dec,
                                           num_units=hp.hidden_units,
                                           num_heads=hp.num_heads,
                                           dropout_rate=hp.dropout_rate,
                                           is_training=is_training,
                                           causality=True,
                                           scope="self_attention")

            ## Multihead Attention ( vanilla attention)
            self.dec = multihead_attention(queries=self.dec,
                                           keys=self.enc,
                                           num_units=hp.hidden_units,
                                           num_heads=hp.num_heads,
                                           dropout_rate=hp.dropout_rate,
                                           is_training=is_training,
                                           causality=False,
                                           scope="vanilla_attention")

            ## Feed Forward
            self.dec = feedforward(self.dec, num_units=[4 * hp.hidden_units, hp.hidden_units])

7、模型輸出

decoder的輸出會(huì)經(jīng)過一層全聯(lián)接網(wǎng)絡(luò)和softmax得到最終的結(jié)果闷旧,示意圖如下:

這樣长豁,一個(gè)完整的Transformer Architecture我們就介紹完了,對(duì)于文中寫的不清楚或者不到位的地方忙灼,歡迎各位留言指正匠襟!

參考文獻(xiàn)

1、原文:https://arxiv.org/abs/1706.03762
2该园、https://mp.weixin.qq.com/s/RLxWevVWHXgX-UcoxDS70w
3酸舍、https://github.com/princewen/tensorflow_practice/tree/master/basic/Basic-Transformer-Demo

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市里初,隨后出現(xiàn)的幾起案子啃勉,更是在濱河造成了極大的恐慌,老刑警劉巖双妨,帶你破解...
    沈念sama閱讀 218,941評(píng)論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件淮阐,死亡現(xiàn)場離奇詭異叮阅,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)泣特,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,397評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門浩姥,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人状您,你說我怎么就攤上這事勒叠。” “怎么了竞阐?”我有些...
    開封第一講書人閱讀 165,345評(píng)論 0 356
  • 文/不壞的土叔 我叫張陵缴饭,是天一觀的道長。 經(jīng)常有香客問我骆莹,道長颗搂,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,851評(píng)論 1 295
  • 正文 為了忘掉前任幕垦,我火速辦了婚禮丢氢,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘先改。我一直安慰自己疚察,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,868評(píng)論 6 392
  • 文/花漫 我一把揭開白布仇奶。 她就那樣靜靜地躺著貌嫡,像睡著了一般。 火紅的嫁衣襯著肌膚如雪该溯。 梳的紋絲不亂的頭發(fā)上岛抄,一...
    開封第一講書人閱讀 51,688評(píng)論 1 305
  • 那天,我揣著相機(jī)與錄音狈茉,去河邊找鬼夫椭。 笑死,一個(gè)胖子當(dāng)著我的面吹牛氯庆,可吹牛的內(nèi)容都是我干的蹭秋。 我是一名探鬼主播,決...
    沈念sama閱讀 40,414評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼堤撵,長吁一口氣:“原來是場噩夢啊……” “哼仁讨!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起实昨,我...
    開封第一講書人閱讀 39,319評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤陪竿,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體族跛,經(jīng)...
    沈念sama閱讀 45,775評(píng)論 1 315
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡闰挡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,945評(píng)論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了礁哄。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片长酗。...
    茶點(diǎn)故事閱讀 40,096評(píng)論 1 350
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖桐绒,靈堂內(nèi)的尸體忽然破棺而出夺脾,到底是詐尸還是另有隱情,我是刑警寧澤茉继,帶...
    沈念sama閱讀 35,789評(píng)論 5 346
  • 正文 年R本政府宣布咧叭,位于F島的核電站,受9級(jí)特大地震影響烁竭,放射性物質(zhì)發(fā)生泄漏菲茬。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,437評(píng)論 3 331
  • 文/蒙蒙 一派撕、第九天 我趴在偏房一處隱蔽的房頂上張望婉弹。 院中可真熱鬧,春花似錦终吼、人聲如沸镀赌。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,993評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽商佛。三九已至,卻和暖如春姆打,著一層夾襖步出監(jiān)牢的瞬間威彰,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,107評(píng)論 1 271
  • 我被黑心中介騙來泰國打工穴肘, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人舔痕。 一個(gè)月前我還...
    沈念sama閱讀 48,308評(píng)論 3 372
  • 正文 我出身青樓评抚,卻偏偏與公主長得像,于是被迫代替她去往敵國和親伯复。 傳聞我的和親對(duì)象是個(gè)殘疾皇子慨代,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,037評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容