透過機(jī)器翻譯理解Transformer(四): 打造 Transformer:疊疊樂時(shí)間

編者按:年初疫情在家期間開始大量閱讀NLP領(lǐng)域的經(jīng)典論文芬沉,在學(xué)習(xí)《Attention Is All You Need》時(shí)發(fā)現(xiàn)了一位現(xiàn)居日本的臺(tái)灣數(shù)據(jù)科學(xué)家LeeMeng寫的Transformer詳解博客哨查,理論講解+代碼實(shí)操+動(dòng)畫演示的寫作風(fēng)格斥黑,在眾多文章中獨(dú)樹一幟,實(shí)為新手學(xué)習(xí)Transformer的上乘資料买喧,在通讀以及實(shí)操多遍之后糟秘,現(xiàn)在將其編輯整理成簡(jiǎn)體中文分享給大家男旗。由于原文實(shí)在太長(zhǎng),為了便于閱讀學(xué)習(xí)强缘,這里將其分為四個(gè)部分:

在涉及代碼部分,強(qiáng)烈推薦大家在Google的Colab Notebooks中實(shí)際操作一遍旅掂,之所以推薦Colab Notebooks是因?yàn)?).這里有免費(fèi)可以使用的GPU資源赏胚;2). 可以避免很多安裝包出錯(cuò)的問題

本節(jié)目錄

    1. 打造 Transformer:疊疊樂時(shí)間
      • 6.1 Position-wise Feed-Forward Networks
      • 6.2 Encoder layer:Encoder 小弟
      • 6.3 Decoder layer:Decoder 小弟
      • 6.4 Positional encoding:神奇數(shù)字
      • 6.5 Encoder
      • 6.6 Decoder
      • 6.7 第一個(gè) Transformer
    1. 定義損失函數(shù)與指標(biāo)
    1. 設(shè)置超參數(shù)
    1. 設(shè)置 Optimizer
    1. 實(shí)際訓(xùn)練以及定時(shí)存檔
    1. 實(shí)際進(jìn)行英翻中
    1. 可視化注意權(quán)重
    1. 在你離開之前

6. 打造 Transformer:疊疊樂時(shí)間

以前我們?cè)岬缴疃葘W(xué)習(xí)模型就是一層層的幾何運(yùn)算過程。 Transformer 也不例外商虐,剛才實(shí)現(xiàn)的 mutli-head attention layer 就是一個(gè)最明顯的例子觉阅。而它正好是 Transformer 里頭最重要的一層運(yùn)算。

在這節(jié)我們會(huì)把 Transformer 里頭除了注意力機(jī)制的其他運(yùn)算通通實(shí)現(xiàn)成一個(gè)個(gè)的 layers秘车,并將它們?nèi)俊腐B」起來典勇。

你可以通過下方的影片來了解接下來的實(shí)現(xiàn)順序:


steps-to-build-transformer.gif

影片中左側(cè)就是我們接下來會(huì)依序?qū)崿F(xiàn)的 layers。 Transformer 是一種使用自注意力機(jī)制的 Seq2Seq 模型 叮趴,里頭包含了兩個(gè)重要角色割笙,分別為 Encoder 與 Decoder:

  • 最初輸入的英文序列會(huì)通過 Encoder 中 N 個(gè) Encoder layers 并被轉(zhuǎn)換成一個(gè)相同長(zhǎng)度的序列。每個(gè) layer 都會(huì)為自己的輸入序列里頭的子詞產(chǎn)生新的 repr.眯亦,然后交給下一個(gè) layer伤溉。
  • Decoder 在生成(預(yù)測(cè))下一個(gè)中文子詞時(shí)會(huì)一邊觀察 Encoder 輸出序列里所有英文子詞的 repr.,一邊觀察自己前面已經(jīng)生成的中文子詞搔驼。

值得一提的是谈火,N = 1 (Encoder / Decoder layer 數(shù)目 = 1)時(shí)就是最陽春版的 Transformer。但在深度學(xué)習(xí)領(lǐng)域里頭我們常常想對(duì)原始數(shù)據(jù)做多層的轉(zhuǎn)換舌涨,因此會(huì)將 N 設(shè)為影片最后出現(xiàn)的 2 層或是 Transformer 論文中的 6 層 Encoder / Decoder layers糯耍。

Encoder 里頭的 Encoder layer 里又分兩個(gè) sub-layers扔字,而 Decoder 底下的 Decoder layer 則包含 3 個(gè) sub-layers。真的是 layer layer 相扣温技。將這些 layers 的階層關(guān)系簡(jiǎn)單列出來大概就長(zhǎng)這樣(位置 Encoding 等在實(shí)現(xiàn)時(shí)會(huì)做解釋):

  • Transformer
    • Encoder
      • 輸入 Embedding
      • 位置 Encoding
      • N 個(gè) Encoder layers
        • sub-layer 1: Encoder 自注意力機(jī)制
        • sub-layer 2: Feed Forward
    • Decoder
      • 輸出 Embedding
      • 位置 Encoding
      • N 個(gè) Decoder layers
        • sub-layer 1: Decoder 自注意力機(jī)制
        • sub-layer 2: Decoder-Encoder 注意力機(jī)制
        • sub-layer 3: Feed Forward
    • Final Dense Layer

不過就像影片中顯示的一樣革为,實(shí)現(xiàn)的時(shí)候我們傾向從下往上疊上去。畢竟地基打得好舵鳞,樓才蓋得高震檩,對(duì)吧?

6.1 Position-wise Feed-Forward Networks

如同影片中所看到的蜓堕, Encoder layer 跟 Decoder layer 里頭都各自有一個(gè) Feed Forward 的元件抛虏。此元件構(gòu)造簡(jiǎn)單,不用像前面的multi-head attention 建立定制化的keras layer套才,只需要寫一個(gè)Python 函數(shù)讓它在被調(diào)用的時(shí)候返回一個(gè)新的tf.keras.Sequential 模型給我們即可:

# 建立 Transformer 里 Encoder / Decoder layer 都有使用到的 Feed Forward 元件
def point_wise_feed_forward_network(d_model, dff):
  
  # 此 FFN 對(duì)輸入做兩個(gè)線性轉(zhuǎn)換迂猴,中間加了一個(gè) ReLU activation func
  return tf.keras.Sequential([
      tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
      tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
  ])

此函數(shù)在每次被調(diào)用的時(shí)候都會(huì)返回一組新的全連接前饋神經(jīng)網(wǎng)路(Fully-connected Feed Forward Network,F(xiàn)FN)背伴,其輸入張量與輸出張量的最后一個(gè)維度皆為d_model 沸毁,而在FFN 中間層的維度則為dff。一般會(huì)讓 dff大于 d_model傻寂,讓 FFN 從輸入的d_model維度里頭擷取些有用的信息息尺。在論文中d_model為 512,dff 則為 4 倍的d_model: 2048疾掰。兩個(gè)都是可以調(diào)整的超參數(shù)搂誉。

讓我們建立一個(gè) FFN 試試:

batch_size = 64
seq_len = 10
d_model = 512
dff = 2048

x = tf.random.uniform((batch_size, seq_len, d_model))
ffn = point_wise_feed_forward_network(d_model, dff)
out = ffn(x)
print("x.shape:", x.shape)
print("out.shape:", out.shape)
x.shape: (64, 10, 512)
out.shape: (64, 10, 512)

在輸入張量的最后一維已經(jīng)是 d_model 的情況,F(xiàn)FN 的輸出張量基本上會(huì)跟輸入一模一樣:

  • 輸入:(batch_size, seq_len, d_model)
  • 輸出:(batch_size, seq_len, d_model)

FFN 輸出 / 輸入張量的 shape 相同很容易理解个绍。比較沒那么明顯的是這個(gè) FFN 事實(shí)上對(duì)序列中的所有位置做的線性轉(zhuǎn)換都是一樣的勒葱。我們可以假想一個(gè) 2 維的 duumy_sentence,里頭有 5 個(gè)以 4 維向量表示的子詞:

d_model = 4 # FFN 的輸入輸出張量的最后一維皆為 `d_model`
dff = 6

# 建立一個(gè)小 FFN
small_ffn = point_wise_feed_forward_network(d_model, dff)
# 懂子詞梗的站出來
dummy_sentence = tf.constant([[5, 5, 6, 6], 
                              [5, 5, 6, 6], 
                              [9, 5, 2, 7], 
                              [9, 5, 2, 7],
                              [9, 5, 2, 7]], dtype=tf.float32)
small_ffn(dummy_sentence)
<tf.Tensor: shape=(5, 4), dtype=float32, numpy=
array([[ 2.8674245, -2.174698 , -1.3073452, -6.4233937],
       [ 2.8674245, -2.174698 , -1.3073452, -6.4233937],
       [ 3.650207 , -0.973258 , -2.4126565, -6.5094995],
       [ 3.650207 , -0.973258 , -2.4126565, -6.5094995],
       [ 3.650207 , -0.973258 , -2.4126565, -6.5094995]], dtype=float32)>

你會(huì)發(fā)現(xiàn)同一個(gè)子詞不會(huì)因?yàn)槲恢玫母淖兌斐?FFN 的輸出結(jié)果產(chǎn)生差異巴柿。但因?yàn)槲覀儗?shí)際上會(huì)有多個(gè) Encoder / Decoder layers凛虽,而每個(gè) layers 都會(huì)有不同參數(shù)的 FFN,因此每個(gè) layer 里頭的 FFN 做的轉(zhuǎn)換都會(huì)有所不同。

值得一提的是,盡管對(duì)所有位置的子詞都做一樣的轉(zhuǎn)換辟犀,但是這個(gè)轉(zhuǎn)換是獨(dú)立進(jìn)行的,因此被稱作 Position-wise Feed-Forward Networks至非。

6.2 Encoder layer:Encoder 小弟

有了 Multi-Head Attention(MHA)以及 Feed-Forward Network(FFN),我們事實(shí)上已經(jīng)可以實(shí)現(xiàn)第一個(gè) Encoder layer 了糠聪。讓我們復(fù)習(xí)一下這 layer 里頭有什么重要元件:


Encoder layer 里的重要元件

我想上面的動(dòng)畫已經(jīng)很清楚了荒椭。一個(gè) Encoder layer 里頭會(huì)有兩個(gè) sub-layers,分別為 MHA 以及 FFN舰蟆。在 Add & Norm 步驟里頭趣惠,每個(gè) sub-layer 會(huì)有一個(gè)殘差連結(jié)(residual connection)來幫助減緩梯度消失(Gradient Vanishing)的問題狸棍。接著兩個(gè) sub-layers 都會(huì)針對(duì)最后一維 d_model 做 layer normalization,將 batch 里頭每個(gè)子詞的輸出獨(dú)立做轉(zhuǎn)換味悄,使其平均與標(biāo)準(zhǔn)差分別靠近 0 和 1 之后輸出草戈。

另外在將 sub-layer 的輸出與其輸入相加之前,我們還會(huì)做點(diǎn) regularization侍瑟,對(duì)該 sub-layer 的輸出使用 dropout唐片。

總結(jié)一下。如果輸入是 x涨颜,最后輸出寫作out的話费韭,則每個(gè) sub-layer 的處理邏輯如下:

sub_layer_out = Sublayer(x)
sub_layer_out = Dropout(sub_layer_out)
out = LayerNorm(x + sub_layer_out)

Sublayer 則可以是 MHA 或是 FFN。現(xiàn)在讓我們看看 Encoder layer 的實(shí)現(xiàn):

# Encoder 里頭會(huì)有 N 個(gè) EncoderLayers咐低,而每個(gè) EncoderLayer 里又有兩個(gè) sub-layers: MHA & FFN
class EncoderLayer(tf.keras.layers.Layer):
    
    # Transformer 論文內(nèi)預(yù)設(shè) dropout rate 為 0.1
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(EncoderLayer, self).__init__()

        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)

       # layer norm 很常在 RNN-based 的模型被使用揽思。一個(gè) sub-layer 一個(gè) layer norm
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    
        # 一樣,一個(gè) sub-layer 一個(gè) dropout layer
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    
    # 需要丟入 `training` 參數(shù)是因?yàn)?dropout 在訓(xùn)練以及測(cè)試的行為有所不同
    def call(self, x, training, mask):
        # 除了 `attn`见擦,其他張量的 shape 皆為 (batch_size, input_seq_len, d_model)
        # attn.shape == (batch_size, num_heads, input_seq_len, input_seq_len)
    
        # sub-layer 1: MHA
        # Encoder 利用注意機(jī)制關(guān)注自己當(dāng)前的序列,因此 v, k, q 全部都是自己
        # 另外別忘了我們還需要 padding mask 來遮住輸入序列中的 <pad> token
        attn_output, attn = self.mha(x, x, x, mask)  
        attn_output = self.dropout1(attn_output, training=training) 
        out1 = self.layernorm1(x + attn_output)  
    
        # sub-layer 2: FFN
        ffn_output = self.ffn(out1) 
        ffn_output = self.dropout2(ffn_output, training=training)  # 記得 training
        out2 = self.layernorm2(out1 + ffn_output)
    
        return out2

跟當(dāng)初 MHA layer 的實(shí)作比起來輕松多了羹令,對(duì)吧鲤屡?

基本上 Encoder layer 里頭就是兩個(gè)架構(gòu)一模一樣的 sub-layer,只差一個(gè)是 MHA福侈,一個(gè)是 FFN酒来。另外為了方便 residual connection 的計(jì)算,所有 sub-layers 的輸出維度都是 d_model肪凛。而 sub-layer 內(nèi)部產(chǎn)生的維度當(dāng)然就隨我們開心啦堰汉!我們可以為 FFN 設(shè)置不同的 dff 值,也能設(shè)定不同的 num_heads 來改變 MHA 內(nèi)部每個(gè) head 里頭的維度伟墙。

論文里頭的 d_model為 512翘鸭,而我們 demo 用的英文詞嵌入張量的d_model 維度則為 4:

# 之后可以調(diào)的超參數(shù)。這邊為了 demo 設(shè)小一點(diǎn)
d_model = 4
num_heads = 2
dff = 8

# 新建一個(gè)使用上述參數(shù)的 Encoder Layer
enc_layer = EncoderLayer(d_model, num_heads, dff)
padding_mask = create_padding_mask(inp)  # 建立一個(gè)當(dāng)前輸入 batch 使用的 padding mask
enc_out = enc_layer(emb_inp, training=False, mask=padding_mask)  # (batch_size, seq_len, d_model)

print("inp:", inp)
print("-" * 20)
print("padding_mask:", padding_mask)
print("-" * 20)
print("emb_inp:", emb_inp)
print("-" * 20)
print("enc_out:", enc_out)
assert emb_inp.shape == enc_out.shape

inp: tf.Tensor(
[[8113  103    9 1066 7903 8114    0    0]
 [8113   16 4111 6735   12 2750 7903 8114]], shape=(2, 8), dtype=int64)
 --------------------
padding_mask: tf.Tensor(
[[[[0. 0. 0. 0. 0. 0. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0.]]]], shape=(2, 1, 1, 8), dtype=float32)
--------------------
emb_inp: tf.Tensor(
[[[ 0.0041508   0.04106052  0.00270988 -0.00628465]
  [ 0.0261193   0.04892724 -0.03637441  0.00032102]
  [-0.0315491   0.03012072 -0.03764988 -0.00832593]
  [-0.00863073  0.01537497  0.00647591  0.01622475]
  [ 0.01064278  0.02867876  0.0471475   0.02418466]
  [-0.0357633  -0.02500458  0.00584758  0.00984917]
  [ 0.02766568 -0.02055204  0.0366873  -0.04519999]
  [ 0.02766568 -0.02055204  0.0366873  -0.04519999] 
  [ 0.0041508   0.04106052  0.00270988 -0.00628465]
  [-0.03440493  0.0245572  -0.04154334  0.01249687]
  [-0.04102417 -0.04214551 -0.03087332  0.03536062]
  [ 0.00288613 -0.00550915  0.02198391 -0.02721313]
  [ 0.03594044 -0.02207484  0.00774273 -0.01938369]
  [-0.00556026  0.04242435  0.03270287 -0.00513189]
  [ 0.01064278  0.02867876  0.0471475   0.02418466]
  [-0.0357633  -0.02500458  0.00584758  0.00984917]]], shape=(2, 8, 4), dtype=float32)
 --------------------
enc_out: tf.Tensor(
[[[-0.1656846   1.4814154  -1.3332843   0.01755357]
  [ 0.05347645  1.2417278  -1.5466218   0.25141746]
  [-0.8423737   1.4621214  -1.0028969   0.3831491 ]
  [-1.1612244   0.4753281  -0.7035671   1.3894634 ]
  [-1.0288012  -0.7085241   1.5507177   0.1866076 ]
  [-0.5757953  -1.1105288   0.13135773  1.5549664 ]
  [ 1.5314106  -0.519994    0.1549343  -1.1663508 ]
  [ 1.5314106  -0.519994    0.1549343  -1.1663508 ]]

 [[-0.34800935  1.5336158  -1.234706    0.04909949]
  [-0.97635764  1.3417366  -0.9507113   0.58533245]
  [-0.53843904 -0.48348504 -0.7043885   1.7263125 ]
  [ 1.208463   -0.2577439   0.529937   -1.4806561 ]
  [ 1.6743237  -0.9758253  -0.33426592 -0.36423233]
  [-1.0195854   1.6443692  -0.13730906 -0.48747474]
  [-1.4697037  -0.00313468  1.3509609   0.12187762]
  [-0.8544105  -0.8589976   0.12724805  1.5861602 ]]], shape=(2, 8, 4), dtype=float32)

在本來的輸入維度即為 d_model 的情況下戳葵,Encoder layer 就是給我們一個(gè)一模一樣 shape 的張量就乓。當(dāng)然,實(shí)際上內(nèi)部透過 MHA 以及 FFN sub-layer 的轉(zhuǎn)換拱烁,每個(gè)子詞的 repr. 都大幅改變了生蚁。

有了 Encoder layer,接著讓我們看看 Decoder layer 的實(shí)現(xiàn)戏自。

6.3 Decoder layer:Decoder 小弟

一個(gè) Decoder layer 里頭有 3 個(gè) sub-layers:

  1. Decoder layer 自身的 Masked MHA 1
  2. Decoder layer 關(guān)注 Encoder 輸出序列的 MHA 2
  3. FFN

你也可以看一下影片來回顧它們所在的位置:

Decoder layer 中的 sub-layers

跟實(shí)現(xiàn) Encoder layer 時(shí)一樣邦投,每個(gè) sub-layer 的邏輯同下:

sub_layer_out = Sublayer(x)
sub_layer_out = Dropout(sub_layer_out)
out = LayerNorm(x + sub_layer_out)

Decoder layer 用 MHA 1 來關(guān)注輸出序列,查詢 Q擅笔、鍵值 K 以及值 V 都是自己志衣。而之所以有個(gè) masked 是因?yàn)椋ㄖ形模┹敵鲂蛄谐烁ㄓ⑽模┹斎胄蛄幸粯有枰?padding mask 以外见芹,還需要 look ahead mask 來避免 Decoder layer 關(guān)注到未來的子詞。 look ahead mask 在前面章節(jié)已經(jīng)有詳細(xì)說明了蠢涝。

MHA1 處理完的輸出序列會(huì)成為 MHA 2 的 Q玄呛,而 K 與 V 則使用 Encoder 的輸出序列。這個(gè)運(yùn)算的概念是讓一個(gè) Decoder layer 在生成新的中文子詞時(shí)先參考先前已經(jīng)產(chǎn)生的中文子詞和二,并為當(dāng)下要生成的子詞產(chǎn)生一個(gè)包含前文語義的 repr. 徘铝。接著將此 repr. 拿去跟 Encoder 那邊的英文序列做匹配,看當(dāng)下子詞的 repr. 有多好并予以修正惯吕。

用簡(jiǎn)單點(diǎn)的說法就是 Decoder 在生成中文子詞時(shí)除了參考自己已經(jīng)生成的中文子詞以外惕它,也會(huì)去關(guān)注 Encoder 輸出的英文子詞(的 repr.)。

# Decoder 里頭會(huì)有 N 個(gè) DecoderLayer废登,
# 而 DecoderLayer 又有三個(gè) sub-layers: 自注意的 MHA, 關(guān)注 Encoder 輸出的 MHA & FFN

class DecoderLayer(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super(DecoderLayer, self).__init__()

    # 3 個(gè) sub-layers 的主角們
    self.mha1 = MultiHeadAttention(d_model, num_heads)
    self.mha2 = MultiHeadAttention(d_model, num_heads)
    self.ffn = point_wise_feed_forward_network(d_model, dff)
 
    # 定義每個(gè) sub-layer 用的 LayerNorm
    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    
    # 定義每個(gè) sub-layer 用的 Dropout
    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)
    self.dropout3 = tf.keras.layers.Dropout(rate)
    
    
  def call(self, x, enc_output, training, 
           combined_mask, inp_padding_mask):
    # 所有 sub-layers 的主要輸出皆為 (batch_size, target_seq_len, d_model)
    # enc_output 為 Encoder 輸出序列淹魄,shape 為 (batch_size, input_seq_len, d_model)
    # attn_weights_block_1 則為 (batch_size, num_heads, target_seq_len, target_seq_len)
    # attn_weights_block_2 則為 (batch_size, num_heads, target_seq_len, input_seq_len)

    # sub-layer 1: Decoder layer 自己對(duì)輸出序列做注意力。
    # 我們同時(shí)需要 look ahead mask 以及輸出序列的 padding mask 
    # 來避免前面已生成的子詞關(guān)注到未來的子詞以及 <pad>
    attn1, attn_weights_block1 = self.mha1(x, x, x, combined_mask)
    attn1 = self.dropout1(attn1, training=training)
    out1 = self.layernorm1(attn1 + x)
    
    # sub-layer 2: Decoder layer 關(guān)注 Encoder 的最後輸出
    # 記得我們一樣需要對(duì) Encoder 的輸出套用 padding mask 避免關(guān)注到 <pad>
    attn2, attn_weights_block2 = self.mha2(
        enc_output, enc_output, out1, inp_padding_mask)  # (batch_size, target_seq_len, d_model)
    attn2 = self.dropout2(attn2, training=training)
    out2 = self.layernorm2(attn2 + out1)  # (batch_size, target_seq_len, d_model)
    
    # sub-layer 3: FFN 部分跟 Encoder layer 完全一樣
    ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)

    ffn_output = self.dropout3(ffn_output, training=training)
    out3 = self.layernorm3(ffn_output + out2)  # (batch_size, target_seq_len, d_model)
    
    # 除了主要輸出 `out3` 以外堡距,輸出 multi-head 注意權(quán)重方便之後理解模型內(nèi)部狀況
    return out3, attn_weights_block1, attn_weights_block2

Decoder layer 的實(shí)現(xiàn)跟 Encoder layer 大同小異甲锡,不過還是有幾點(diǎn)細(xì)節(jié)特別需要注意:

  • 在做 Masked MHA(MHA 1)的時(shí)候我們需要同時(shí)套用兩種遮罩:輸出序列的 padding mask 以及 look ahead mask。因此 Decoder layer 預(yù)期的遮罩是兩者結(jié)合的combined_mask
  • MHA 1 因?yàn)槭?Decoder layer 關(guān)注自己羽戒,multi-head attention 的參數(shù) v缤沦、k 以及 q都是 x
  • MHA 2 是 Decoder layer 關(guān)注 Encoder 輸出序列,因此易稠,multi-head attention 的參數(shù) v缸废、kenc_outputq 則為 MHA 1 sub-layer 的結(jié)果 out1

產(chǎn)生comined_mask也很簡(jiǎn)單驶社,我們只要把兩個(gè)遮罩取大的即可:

tar_padding_mask = create_padding_mask(tar)
look_ahead_mask = create_look_ahead_mask(tar.shape[-1])
combined_mask = tf.maximum(tar_padding_mask, look_ahead_mask)

print("tar:", tar)
print("-" * 20)
print("tar_padding_mask:", tar_padding_mask)
print("-" * 20)
print("look_ahead_mask:", look_ahead_mask)
print("-" * 20)
print("combined_mask:", combined_mask)
tar: tf.Tensor(
[[4205   10  241   86   27    3 4206    0    0    0]
 [4205  165  489  398  191   14    7  560    3 4206]], shape=(2, 10), dtype=int64)
--------------------
tar_padding_mask: tf.Tensor(
[[[[0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]], shape=(2, 1, 1, 10), dtype=float32)
--------------------
look_ahead_mask: tf.Tensor(
 [[0. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [0. 0. 1. 1. 1. 1. 1. 1. 1. 1.]
  [0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
  [0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
  [0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
  [0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
  [0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]
  [0. 0. 0. 0. 0. 0. 0. 0. 1. 1.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(10, 10), dtype=float32)
 --------------------
combined_mask: tf.Tensor(
[[[[0. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
   [0. 0. 1. 1. 1. 1. 1. 1. 1. 1.]
   [0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
   [0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
   [0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
   [0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
   [0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]
   [0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]
   [0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]
   [0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]]]

[[[0. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
  [0. 0. 1. 1. 1. 1. 1. 1. 1. 1.]
  [0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
  [0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
  [0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
  [0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
  [0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]
  [0. 0. 0. 0. 0. 0. 0. 0. 1. 1.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]], shape=(2, 1, 10, 10), dtype=float32)    

注意 combined_mask 的 shape 以及里頭遮罩所在的位置企量。利用 broadcasting 我們將 combined_mask 的 shape 也擴(kuò)充到 4 維:

(batch_size, num_heads, seq_len_tar, seq_len_tar)= (2, 1, 10, 10)

這方便之后 multi-head attention 的計(jì)算。另外因?yàn)槲覀?demo 的中文 batch 里頭的第一個(gè)句子有 <pad>亡电,combined_mask除了 look ahead 的效果以外還加了 padding mask届巩。

因?yàn)閯倓倢?shí)現(xiàn)的是 Decoder layer,這次讓我們把中文(目標(biāo)語言)的詞嵌入張量以及相關(guān)的遮罩丟進(jìn)去看看:

# 超參數(shù)
d_model = 4
num_heads = 2
dff = 8
dec_layer = DecoderLayer(d_model, num_heads, dff)

# 來源逊抡、目標(biāo)語言的序列都需要 padding mask
inp_padding_mask = create_padding_mask(inp)
tar_padding_mask = create_padding_mask(tar)

# masked MHA 用的遮罩姆泻,把 padding 跟未來子詞都蓋住
look_ahead_mask = create_look_ahead_mask(tar.shape[-1])
combined_mask = tf.maximum(tar_padding_mask, look_ahead_mask)

# 實(shí)際初始一個(gè) decoder layer 并做 3 個(gè) sub-layers 的計(jì)算
dec_out, dec_self_attn_weights, dec_enc_attn_weights = dec_layer(
    emb_tar, enc_out, False, combined_mask, inp_padding_mask)

print("emb_tar:", emb_tar)
print("-" * 20)
print("enc_out:", enc_out)
print("-" * 20)
print("dec_out:", dec_out)
assert emb_tar.shape == dec_out.shape
print("-" * 20)
print("dec_self_attn_weights.shape:", dec_self_attn_weights.shape)
print("dec_enc_attn_weights:", dec_enc_attn_weights.shape)
emb_tar: tf.Tensor(
[[[-0.00084939 -0.02029408 -0.04978932 -0.02889797]
  [-0.01320463  0.00070287  0.00797179 -0.00549082]
  [-0.01859868 -0.04142375  0.02479618 -0.00794141]
  [ 0.04030085 -0.04564189 -0.03584541 -0.04098076]
  [ 0.02629851  0.01072141 -0.01055797  0.04544314]
  [-0.00223017  0.02058548  0.01649131 -0.01385387]
  [ 0.00302396 -0.03152249  0.0396189  -0.03036447]
  [ 0.00433234  0.04481849  0.04129448  0.04720709]
  [ 0.00433234  0.04481849  0.04129448  0.04720709]
  [ 0.00433234  0.04481849  0.04129448  0.04720709]]

 [[-0.00084939 -0.02029408 -0.04978932 -0.02889797]
  [-0.04702241  0.01816512 -0.02416607 -0.01993601]
  [ 0.04391925 -0.03093947 -0.01225864 -0.03517971]
  [ 0.03755457  0.00626134  0.04324439  0.00490584]
  [ 0.00495391 -0.03399891  0.04144105  0.02539945]
  [ 0.0282723  -0.0164601  -0.00685417 -0.02280444]
  [ 0.04738505 -0.01041915 -0.02054645 -0.00066562]
  [-0.00438491  0.02117647 -0.04890387 -0.01620366]
  [-0.00223017  0.02058548  0.01649131 -0.01385387]
  [ 0.00302396 -0.03152249  0.0396189  -0.03036447]]], shape=(2, 10, 4), dtype=float32)
--------------------
enc_out: tf.Tensor(
[[[-0.1656846   1.4814154  -1.3332843   0.01755357]
  [ 0.05347645  1.2417278  -1.5466218   0.25141746]
  [-0.8423737   1.4621214  -1.0028969   0.3831491 ]
  [-1.1612244   0.4753281  -0.7035671   1.3894634 ]
  [-1.0288012  -0.7085241   1.5507177   0.1866076 ]
  [-0.5757953  -1.1105288   0.13135773  1.5549664 ]
  [ 1.5314106  -0.519994    0.1549343  -1.1663508 ]
  [ 1.5314106  -0.519994    0.1549343  -1.1663508 ]]

 [[-0.34800935  1.5336158  -1.234706    0.04909949]
  [-0.97635764  1.3417366  -0.9507113   0.58533245]
  [-0.53843904 -0.48348504 -0.7043885   1.7263125 ]
  [ 1.208463   -0.2577439   0.529937   -1.4806561 ]
  [ 1.6743237  -0.9758253  -0.33426592 -0.36423233]
  [-1.0195854   1.6443692  -0.13730906 -0.48747474]
  [-1.4697037  -0.00313468  1.3509609   0.12187762]
  [-0.8544105  -0.8589976   0.12724805  1.5861602 ]]], shape=(2, 8, 4), dtype=float32)
--------------------
dec_out: tf.Tensor(
[[[ 1.2991211   0.6467309  -0.99355525 -0.9522968 ]
  [-0.68756247 -0.44788587  1.7257465  -0.5902982 ]
  [ 0.21567897 -1.6887752   0.6456864   0.8274099 ]
  [ 1.3437784  -1.2335085  -0.6324715   0.52220154]
  [ 0.5747509  -1.1840664  -0.71563935  1.3249549 ]
  [-0.4092589   0.41854465  1.3476295  -1.3569155 ]
  [ 0.47711575 -1.7147235   0.8007993   0.43680844]
  [-1.132223   -0.82594645  1.222668    0.73550147]
  [-1.132223   -0.82594645  1.222668    0.73550147]
  [-1.132223   -0.82594645  1.222668    0.73550147]]

 [[ 1.3999227   0.49366176 -0.9038905  -0.989694  ]
  [-0.86972106  1.1954616   0.77558595 -1.1013266 ]
  [ 1.6006857  -1.068229   -0.5445589   0.01210219]
  [ 0.7155672  -1.6947896   0.750581    0.2286414 ]
  [ 0.1127052  -1.6265972   0.4442618   1.0696301 ]
  [ 1.4985088  -1.2589391  -0.38515666  0.14558706]
  [ 1.3210055  -0.90092945 -1.033153    0.6130771 ]
  [-0.0833452   1.6214814  -1.0698308  -0.4683055 ]
  [-0.4484089   0.17643274  1.5017867  -1.2298107 ]
  [ 0.44141728 -1.6816832   0.94259256  0.2976733 ]]], shape=(2, 10, 4), dtype=float32)
--------------------
dec_self_attn_weights.shape: (2, 2, 10, 10)
dec_enc_attn_weights: (2, 2, 10, 8)

跟 Encoder layer 相同,Decoder layer 輸出張量的最后一維也是 d_model冒嫡。而 dec_self_attn_weights 則代表著 Decoder layer 的自注意力權(quán)重拇勃,因此最后兩個(gè)維度皆為中文序列的長(zhǎng)度 10;而 dec_enc_attn_weights因?yàn)?Encoder 輸出序列的長(zhǎng)度為8孝凌,最后一維即為 8方咆。

都讀到這里了,判斷每一維的物理意義對(duì)你來說應(yīng)該是小菜一碟了蟀架。

6.4 Positional encoding:神奇數(shù)字

透過多層的自注意力層瓣赂,Transformer 在處理序列時(shí)里頭所有子詞都是「天涯若比鄰」:想要關(guān)注序列中任何位置的信息只要 O(1) 就能辦到榆骚。這讓 Transformer 能很好地 model 序列中長(zhǎng)距離的依賴關(guān)系(long-range dependencise)。但反過來說 Transformer 則無法 model 序列中字詞的順序關(guān)系煌集,所以我們得額外加入一些「位置信息」給 Transformer妓肢。

這個(gè)信息被稱作位置編碼(Positional Encoding),實(shí)作上是直接加到最一開始的英文 / 中文詞嵌入向量(word embedding)里頭苫纤。其直觀的想法是想辦法讓被加入位置編碼的 word embedding 在d_model維度的空間里頭不只會(huì)因?yàn)檎Z義相近而靠近碉钠,也會(huì)因?yàn)槲恢每拷谠摽臻g里頭靠近。

論文里頭使用的位置編碼的公式如下:


position-encoding-equation.jpg

論文里頭提到他們之所以這樣設(shè)計(jì)位置編碼(Positional Encoding, PE)是因?yàn)檫@個(gè)函數(shù)有個(gè)很好的特性:給定任一位置pos 的位置編碼PE(pos)卷拘,跟它距離k個(gè)單位的位置pos + k 的位置編碼PE(pos + k)可以表示為PE(pos) 的一個(gè)線性函數(shù)(linear function)喊废。

因此透過在 word embedding 里加入這樣的信息,作者們認(rèn)為可以幫助 Transformer 學(xué)會(huì) model 序列中的子詞的相對(duì)位置關(guān)系栗弟。

就算我們無法自己想出論文里頭的位置編碼公式污筷,還是可以直接把 TensorFlow 官方的實(shí)現(xiàn)搬過來使用:

# 以下直接參考 TensorFlow 官方 tutorial
def get_angles(pos, i, d_model):
  angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
  return pos * angle_rates

def positional_encoding(position, d_model):
  angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)
  
  # apply sin to even indices in the array; 2i
  sines = np.sin(angle_rads[:, 0::2])
  
  # apply cos to odd indices in the array; 2i+1
  cosines = np.cos(angle_rads[:, 1::2])
  
  pos_encoding = np.concatenate([sines, cosines], axis=-1)
  
  pos_encoding = pos_encoding[np.newaxis, ...]
    
  return tf.cast(pos_encoding, dtype=tf.float32)


seq_len = 50
d_model = 512

pos_encoding = positional_encoding(seq_len, d_model)
pos_encoding
<tf.Tensor: shape=(1, 50, 512), dtype=float32, numpy=
array([[[ 0.        ,  0.        ,  0.        , ...,  1.        ,
          1.        ,  1.        ],
        [ 0.84147096,  0.8218562 ,  0.8019618 , ...,  1.        ,
          1.        ,  1.        ],
        [ 0.9092974 ,  0.9364147 ,  0.95814437, ...,  1.        ,
          1.        ,  1.        ],
        ...,
        [ 0.12357312,  0.97718984, -0.24295525, ...,  0.9999863 ,
          0.99998724,  0.99998814],
        [-0.76825464,  0.7312359 ,  0.63279754, ...,  0.9999857 ,
          0.9999867 ,  0.9999876 ],
        [-0.95375264, -0.14402692,  0.99899054, ...,  0.9999851 ,
          0.9999861 ,  0.9999871 ]]], dtype=float32)>

一路看下來你應(yīng)該也可以猜到位置編碼的每一維意義了:

  • 第 1 維代表 batch_size,之后可以 broadcasting
  • 第 2 維是序列長(zhǎng)度乍赫,我們會(huì)為每個(gè)在輸入 / 輸出序列里頭的子詞都加入位置編碼
  • 第 3 維跟詞嵌入向量同維度

因?yàn)槭且~嵌入向量相加瓣蛀,位置編碼的維度也得是 d_model。我們也可以把位置編碼畫出感受一下:

plt.pcolormesh(pos_encoding[0], cmap='RdBu')
plt.xlabel('d_model')
plt.xlim((0, 512))
plt.ylabel('Position')
plt.colorbar()
plt.show()
positional-encoding

這圖你應(yīng)該在很多教學(xué)文章以及教授的影片里都看過了耿焊。就跟我們前面看過的各種 2 維矩陣相同揪惦,x 軸代表著跟詞嵌入向量相同的維度 d_model,y 軸則代表序列中的每個(gè)位置罗侯。之后我們會(huì)看輸入 / 輸出序列有多少個(gè)子詞,就加入幾個(gè)位置編碼溪猿。

關(guān)于位置編碼我們現(xiàn)在只需要知道這些就夠了钩杰,但如果你想知道更多相關(guān)的數(shù)學(xué)計(jì)算,可以參考這個(gè)筆記本诊县。

6.5 Encoder

Encoder 里頭主要包含了 3 個(gè)元件:

  • 輸入的詞嵌入層
  • 位置編碼
  • N 個(gè) Encoder layers

大部分的工作都交給 Encoder layer 小弟做了讲弄,因此 Encoder 的實(shí)現(xiàn)很單純:

class Encoder(tf.keras.layers.Layer):
  # Encoder 的初始參數(shù)除了本來就要給 EncoderLayer 的參數(shù)還多了:
  # - num_layers: 決定要有幾個(gè) EncoderLayers, 前面影片中的 `N`
  # - input_vocab_size: 用來把索引轉(zhuǎn)成詞嵌入向量
  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, 
               rate=0.1):
    super(Encoder, self).__init__()

    self.d_model = d_model
    
    self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
    self.pos_encoding = positional_encoding(input_vocab_size, self.d_model)
    
    # 建立 `num_layers` 個(gè) EncoderLayers
    self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]

    self.dropout = tf.keras.layers.Dropout(rate)
        
  def call(self, x, training, mask):
    # 輸入的 x.shape == (batch_size, input_seq_len)
    # 以下各 layer 的輸出皆為 (batch_size, input_seq_len, d_model)
    input_seq_len = tf.shape(x)[1]
    
    # 將 2 維的索引序列轉(zhuǎn)成 3 維的詞嵌入張量,並依照論文乘上 sqrt(d_model)
    # 再加上對(duì)應(yīng)長(zhǎng)度的位置編碼
    x = self.embedding(x)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :input_seq_len, :]

    # 對(duì) embedding 跟位置編碼的總合做 regularization
    # 這在 Decoder 也會(huì)做
    x = self.dropout(x, training=training)
    
    # 通過 N 個(gè) EncoderLayer 做編碼
    for i, enc_layer in enumerate(self.enc_layers):
      x = enc_layer(x, training, mask)
      # 以下只是用來 demo EncoderLayer outputs
      #print('-' * 20)
      #print(f"EncoderLayer {i + 1}'s output:", x)
      
    
    return x 

比較值得注意的是我們依照論文將 word embedding 乘上 sqrt(d_model)依痊,并在 embedding 跟位置編碼相加以后通過 dropout 層來達(dá)到 regularization 的效果避除。

現(xiàn)在我們可以直接將索引序列 inp 丟入 Encoder:

# 超參數(shù)
num_layers = 2 # 2 層的 Encoder
d_model = 4
num_heads = 2
dff = 8
input_vocab_size = subword_encoder_en.vocab_size + 2 # 記得加上 <start>, <end>

# 初始化一個(gè) Encoder
encoder = Encoder(num_layers, d_model, num_heads, dff, input_vocab_size)

# 將 2 維的索引序列丟入 Encoder 做編碼
enc_out = encoder(inp, training=False, mask=None)
print("inp:", inp)
print("-" * 20)
print("enc_out:", enc_out)
inp: tf.Tensor(
[[8113  103    9 1066 7903 8114    0    0]
 [8113   16 4111 6735   12 2750 7903 8114]], shape=(2, 8), dtype=int64)
--------------------
enc_out: tf.Tensor(
[[[-0.7849332  -0.5919684  -0.33270505  1.7096066 ]
  [-0.5070654  -0.5110136  -0.7082318   1.726311  ]
  [-0.39270183 -0.03102639 -1.158362    1.5820901 ]
  [-0.5561629   0.38050282 -1.2407898   1.4164499 ]
  [-0.90432     0.19381052 -0.8472892   1.5577985 ]
  [-0.97321564 -0.22992788 -0.4652462   1.6683896 ]
  [-0.84681976 -0.5434473  -0.31013608  1.7004032 ]
  [-0.62432766 -0.56790507 -0.539001    1.7312336 ]]

 [[-0.77423775 -0.6076471  -0.32800597  1.7098908 ]
  [-0.47978252 -0.5615605  -0.68602914  1.7273722 ]
  [-0.30068305 -0.07366991 -1.1973959   1.5717487 ]
  [-0.5147841   0.2787246  -1.2290851   1.4651446 ]
  [-0.89634496  0.2675462  -0.8954112   1.52421   ]
  [-0.97553635 -0.22618684 -0.4656965   1.6674198 ]
  [-0.87600434 -0.5448401  -0.27099532  1.6918398 ]
  [-0.60130465 -0.5993665  -0.5306774   1.7313484 ]]], shape=(2, 8, 4), dtype=float32)

注意因?yàn)?Encoder 已經(jīng)包含了詞嵌入層,因此我們不用再像調(diào)用 Encoder layer 時(shí)一樣還得自己先做 word embedding⌒剜遥現(xiàn)在的輸入及輸出張量為:

  • 輸入:(batch_size, seq_len)
  • 輸出:(batch_size, seq_len, d_model)

有了 Encoder瓶摆,我們之后就可以直接把 2 維的索引序列 inp丟入 Encoder,讓它幫我們把里頭所有的英文序列做一連串的轉(zhuǎn)換性宏。

6.6 Decoder

Decoder layer 本來就只跟 Encoder layer 差在一個(gè) MHA群井,而這邏輯被包起來以后調(diào)用它的 Decoder 做的事情就跟 Encoder 基本上沒有兩樣了。

在 Decoder 里頭我們只需要建立一個(gè)專門給中文用的詞嵌入層以及位置編碼即可毫胜。我們?cè)谡{(diào)用每個(gè) Decoder layer 的時(shí)候也順便把其注意力權(quán)重存下來书斜,方便我們了解模型訓(xùn)練完后是怎么做翻譯的诬辈。

以下則是實(shí)現(xiàn):

class Decoder(tf.keras.layers.Layer):
  # 初始參數(shù)跟 Encoder 只差在用 `target_vocab_size` 而非 `inp_vocab_size`
  def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, 
               rate=0.1):
    super(Decoder, self).__init__()

    self.d_model = d_model
    
    # 為中文(目標(biāo)語言)建立詞嵌入層
    self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
    self.pos_encoding = positional_encoding(target_vocab_size, self.d_model)
    
    self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(rate)
  
  # 呼叫時(shí)的參數(shù)跟 DecoderLayer 一模一樣
  def call(self, x, enc_output, training, 
           combined_mask, inp_padding_mask):
    
    tar_seq_len = tf.shape(x)[1]
    attention_weights = {}  # 用來存放每個(gè) Decoder layer 的注意權(quán)重
    
    # 這邊跟 Encoder 做的事情完全一樣
    x = self.embedding(x)  # (batch_size, tar_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :tar_seq_len, :]
    x = self.dropout(x, training=training)

    
    for i, dec_layer in enumerate(self.dec_layers):
      x, block1, block2 = dec_layer(x, enc_output, training,
                                    combined_mask, inp_padding_mask)
      
      # 將從每個(gè) Decoder layer 取得的注意權(quán)重全部存下來回傳,方便我們觀察
      attention_weights['decoder_layer{}_block1'.format(i + 1)] = block1
      attention_weights['decoder_layer{}_block2'.format(i + 1)] = block2
    
    # x.shape == (batch_size, tar_seq_len, d_model)
    return x, attention_weights

接著讓我們初始并調(diào)用一個(gè) Decoder 看看:

# 超參數(shù)
num_layers = 2 # 2 層的 Decoder
d_model = 4
num_heads = 2
dff = 8
target_vocab_size = subword_encoder_zh.vocab_size + 2 # 記得加上 <start>, <end>

# 遮罩
inp_padding_mask = create_padding_mask(inp)
tar_padding_mask = create_padding_mask(tar)
look_ahead_mask = create_look_ahead_mask(tar.shape[1])
combined_mask = tf.math.maximum(tar_padding_mask, look_ahead_mask)

# 初始化一個(gè) Decoder
decoder = Decoder(num_layers, d_model, num_heads, dff, target_vocab_size)

# 將 2 維的索引序列以及遮罩丟入 Decoder
print("tar:", tar)
print("-" * 20)
print("combined_mask:", combined_mask)
print("-" * 20)
print("enc_out:", enc_out)
print("-" * 20)
print("inp_padding_mask:", inp_padding_mask)
print("-" * 20)
dec_out, attn = decoder(tar, enc_out, training=False, 
                        combined_mask=combined_mask,
                        inp_padding_mask=inp_padding_mask)
print("dec_out:", dec_out)
print("-" * 20)
for block_name, attn_weights in attn.items():
    print(f"{block_name}.shape: {attn_weights.shape}")
tar: tf.Tensor(
[[4205   10  241   86   27    3 4206    0    0    0]
 [4205  165  489  398  191   14    7  560    3 4206]], shape=(2, 10), dtype=int64)
--------------------
combined_mask: tf.Tensor(
[[[[0. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
   [0. 0. 1. 1. 1. 1. 1. 1. 1. 1.]
   [0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
   [0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
   [0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
   [0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
   [0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]
   [0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]
   [0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]
   [0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]]]

 [[[0. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
    [0. 0. 1. 1. 1. 1. 1. 1. 1. 1.]
    [0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
    [0. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
    [0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
    [0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
    [0. 0. 0. 0. 0. 0. 0. 1. 1. 1.]
    [0. 0. 0. 0. 0. 0. 0. 0. 1. 1.]
    [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
    [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]], shape=(2, 1, 10, 10), dtype=float32)
   --------------------
enc_out: tf.Tensor(
[[[-0.7849332  -0.5919684  -0.33270505  1.7096066 ]
  [-0.5070654  -0.5110136  -0.7082318   1.726311  ]
  [-0.39270183 -0.03102639 -1.158362    1.5820901 ]
  [-0.5561629   0.38050282 -1.2407898   1.4164499 
  [-0.90432     0.19381052 -0.8472892   1.5577985 ]
  [-0.97321564 -0.22992788 -0.4652462   1.6683896 ]
  [-0.84681976 -0.5434473  -0.31013608  1.7004032 ]
  [-0.62432766 -0.56790507 -0.539001    1.7312336 ]]

 [[-0.77423775 -0.6076471  -0.32800597  1.7098908 ]
  [-0.47978252 -0.5615605  -0.68602914  1.7273722 ]
  [-0.30068305 -0.07366991 -1.1973959   1.5717487 ]
  [-0.5147841   0.2787246  -1.2290851   1.4651446 ]
  [-0.89634496  0.2675462  -0.8954112   1.52421   ]
  [-0.97553635 -0.22618684 -0.4656965   1.6674198 ]
  [-0.87600434 -0.5448401  -0.27099532  1.6918398 ]
  [-0.60130465 -0.5993665  -0.5306774   1.7313484 ]]], shape=(2, 8, 4), dtype=float32)
--------------------
inp_padding_mask: tf.Tensor(
[[[[0. 0. 0. 0. 0. 0. 1. 1.]]]
 [[[0. 0. 0. 0. 0. 0. 0. 0.]]]], shape=(2, 1, 1, 8), dtype=float32)
--------------------
dec_out: tf.Tensor(
[[[-0.5652141  -1.0581813   1.600075    0.02332011]
  [-0.34019774 -1.2377603   1.5330346   0.04492359]
  [ 0.3675252  -1.4228352   1.3287866  -0.2734765 ]
  [ 0.09472068 -1.353683    1.4559422  -0.19697984]
  [-0.38392055 -1.0940721   1.6231282  -0.14513558]
  [-0.41729763 -1.0276326   1.6514215  -0.20649135]
  [-0.3302343  -1.0454822   1.6500466  -0.27433014]
  [-0.1923209  -1.1254803   1.6149355  -0.29713422]
  [ 0.40822834 -1.3586452   1.3515034  -0.40108633]
  [ 0.19979587 -1.4183372   1.3857942  -0.1672527 ]]

 [[-0.56504554 -1.054449    1.602678    0.01681651]
  [-0.36043385 -1.2348608   1.5300139   0.0652808 ]
  [ 0.24521776 -1.4295446   1.3651297  -0.18080302]
  [-0.06483467 -1.3449186   1.4773033  -0.06755   ]
  [-0.41885287 -1.0775515   1.6267892  -0.1303851 ]
  [-0.40018192 -1.0338533   1.6504982  -0.21646297]
  [-0.3531929  -1.0375831   1.6523482  -0.26157203]
  [-0.24463172 -1.1371143   1.6107951  -0.22904922]
  [ 0.19615419 -1.362728    1.4271017  -0.2605278 ]
  [ 0.08419974 -1.3687493   1.4467624  -0.16221291]]], shape=(2, 10, 4), dtype=float32)
--------------------
decoder_layer1_block1.shape: (2, 2, 10, 10)
decoder_layer1_block2.shape: (2, 2, 10, 8)
decoder_layer2_block1.shape: (2, 2, 10, 10)
decoder_layer2_block2.shape: (2, 2, 10, 8)

麻雀雖小荐吉,五臟俱全焙糟。雖然我們是使用 demo 數(shù)據(jù),但基本上這就是你在呼叫 Decoder 時(shí)需要做的所有事情:

  • 初始時(shí)給它中文(目標(biāo)語言)的字典大小样屠、其他超參數(shù)
  • 輸入中文 batch 的索引序列
  • 也要輸入兩個(gè)遮罩以及 Encoder 輸出enc_out

Decoder 的輸出你現(xiàn)在應(yīng)該都可以很輕松地解讀才是穿撮。基本上跟 Decoder layer 一模一樣,只差在我們額外輸出一個(gè) Python dict,里頭存放所有 Decoder layers 的注意權(quán)重怎爵。

6.7 第一個(gè) Transformer

沒錯(cuò)脯燃,終于到了這個(gè)時(shí)刻。在實(shí)現(xiàn) Transformer 之前先點(diǎn)擊影片來簡(jiǎn)單回顧一下我們?cè)谶@一章實(shí)現(xiàn)了什么些玩意兒:

transformer-imple.gif

Transformer 本身只有 3 個(gè) layers

在我們前面已經(jīng)將大大小小的 layers 一一實(shí)作并組裝起來以后炊豪,真正的 Transformer 模型只需要 3 個(gè)元件:

  1. Encoder
  2. Decoder
  3. Final linear layer

馬上讓我們看看 Transformer 的實(shí)現(xiàn):

# Transformer 之上已經(jīng)沒有其他 layers 了,我們使用 tf.keras.Model 建立一個(gè)模型
class Transformer(tf.keras.Model):
    # 初始參數(shù)包含 Encoder & Decoder 都需要超參數(shù)以及中英字典數(shù)目
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, rate=0.1):
        super(Transformer, self).__init__()

        self.encoder = Encoder(num_layers, d_model, num_heads, dff, input_vocab_size, rate)

        self.decoder = Decoder(num_layers, d_model, num_heads, dff, target_vocab_size, rate)
        # 這個(gè) FFN 輸出跟中文字典一樣大的 logits 數(shù),等通過 softmax 就代表每個(gè)中文字的出現(xiàn)機(jī)率
        self.final_layer = tf.keras.layers.Dense(target_vocab_size)
  
    # enc_padding_mask 跟 dec_padding_mask 都是英文序列的 padding mask傍衡,
    # 只是一個(gè)給 Encoder layer 的 MHA 用,一個(gè)是給 Decoder layer 的 MHA 2 使用
    def call(self, inp, tar, training, enc_padding_mask, combined_mask, dec_padding_mask):
        enc_output = self.encoder(inp, training, enc_padding_mask)  # (batch_size, inp_seq_len, d_model)
    
        # dec_output.shape == (batch_size, tar_seq_len, d_model)
        dec_output, attention_weights = self.decoder(tar, enc_output, training, combined_mask, dec_padding_mask)
    
        # 將 Decoder 輸出通過最後一個(gè) linear layer
        final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size)
    
        return final_output, attention_weights

扣掉注解负蠕,Transformer 的實(shí)現(xiàn)本身非常簡(jiǎn)短蛙埂。

被輸入Transformer 的多個(gè)2 維英文張量inp 會(huì)一路通過Encoder 里頭的詞嵌入層,位置編碼以及N 個(gè)Encoder layers 后被轉(zhuǎn)換成Encoder 輸出enc_output遮糖,接著對(duì)應(yīng)的中文序列tar 則會(huì)在Decoder 里頭走過相似的旅程并在每一層的Decoder layer 利用MHA 2 關(guān)注Encoder 的輸出enc_output绣的,最后被Decoder 輸出。

而 Decoder 的輸出 dec_output 則會(huì)通過 Final linear layer欲账,被轉(zhuǎn)成進(jìn)入 Softmax 前的 logits final_output屡江,其 logit 的數(shù)目則跟中文字典里的子詞數(shù)相同。

因?yàn)門ransformer 把Decoder 也包起來了赛不,現(xiàn)在我們連Encoder 輸出enc_output也不用管惩嘉,只要把英文(來源)以及中文(目標(biāo))的索引序列batch 丟入Transformer,它就會(huì)輸出最后一維為中文字典大小的張量踢故。第 2 維是輸出序列文黎,里頭每一個(gè)位置的向量就代表著該位置的中文字的概率分布(事實(shí)上通過 softmax 才是,但這邊先這樣說方便你理解):

  • 輸入:
    • 英文序列:(batch_size, inp_seq_len)
    • 中文序列:(batch_size, tar_seq_len)
  • 輸出:
    • 生成序列:(batch_size, tar_seq_len, target_vocab_size)
    • 注意權(quán)重的 dict

讓我們馬上建一個(gè) Transformer殿较,并假設(shè)我們已經(jīng)準(zhǔn)備好用 demo 數(shù)據(jù)來訓(xùn)練它做英翻中:

# 超參數(shù)
num_layers = 1
d_model = 4
num_heads = 2
dff = 8

# + 2 是為了 <start> & <end> token
input_vocab_size = subword_encoder_en.vocab_size + 2
output_vocab_size = subword_encoder_zh.vocab_size + 2

# 重點(diǎn)中的重點(diǎn)耸峭。訓(xùn)練時(shí)用前一個(gè)字來預(yù)測(cè)下一個(gè)中文字
tar_inp = tar[:, :-1]
tar_real = tar[:, 1:]

# 來源 / 目標(biāo)語言用的遮罩。注意 `comined_mask` 已經(jīng)將目標(biāo)語言的兩種遮罩合而為一
inp_padding_mask = create_padding_mask(inp)
tar_padding_mask = create_padding_mask(tar_inp)
look_ahead_mask = create_look_ahead_mask(tar_inp.shape[1])
combined_mask = tf.math.maximum(tar_padding_mask, look_ahead_mask)

# 初始化我們的第一個(gè) transformer
transformer = Transformer(num_layers, d_model, num_heads, dff, 
                          input_vocab_size, output_vocab_size)

# 將英文斜脂、中文序列丟入取得 Transformer 預(yù)測(cè)下個(gè)中文字的結(jié)果
predictions, attn_weights = transformer(inp, tar_inp, False, inp_padding_mask, 
                                        combined_mask, inp_padding_mask)

print("tar:", tar)
print("-" * 20)
print("tar_inp:", tar_inp)
print("-" * 20)
print("tar_real:", tar_real)
print("-" * 20)
print("predictions:", predictions)
tar: tf.Tensor(
[[4205   10  241   86   27    3 4206    0    0    0]
 [4205  165  489  398  191   14    7  560    3 4206]], shape=(2, 10), dtype=int64)
--------------------
tar_inp: tf.Tensor(
[[4205   10  241   86   27    3 4206    0    0]
 [4205  165  489  398  191   14    7  560    3]], shape=(2, 9), dtype=int64)
--------------------
tar_real: tf.Tensor(
[[  10  241   86   27    3 4206    0    0    0]
 [ 165  489  398  191   14    7  560    3 4206]], shape=(2, 9), dtype=int64)
--------------------
predictions: tf.Tensor(
[[[ 0.01349578 -0.00199539 -0.00217387 ... -0.03862738 -0.03212879
   -0.07692747]
  [ 0.037483    0.01585471 -0.02548708 ... -0.04276202 -0.02495992
   -0.05491882]
  [ 0.05718528  0.0288353  -0.04577483 ... -0.0450176  -0.01315334
   -0.03639907]
  ...
  [ 0.01202047 -0.00400385 -0.00099438 ... -0.03859971 -0.03085513
   -0.0797975 ]
  [ 0.0235797   0.00501019 -0.01193091 ... -0.04091505 -0.02892826
   -0.06939011]
  [ 0.04867784  0.02382022 -0.03683803 ... -0.04392421 -0.01941058
   -0.04347047]]

 [[ 0.01676657 -0.00080312 -0.00556347 ... -0.03981712 -0.02937311
   -0.07665333]
  [ 0.03873826  0.01607161 -0.02685272 ... -0.04328423 -0.02345929
   -0.05522631]
  [ 0.0564083   0.02865588 -0.04492006 ... -0.04475704 -0.014088
   -0.03639095]
  ...
  [ 0.01514172 -0.00298804 -0.00426158 ... -0.03976889 -0.02800199
   -0.07974622]
  [ 0.02867933  0.00800282 -0.01704068 ... -0.04215823 -0.02618418
   -0.06638923]
  [ 0.05056309  0.02489874 -0.03880978 ... -0.04421616 -0.01803544
   -0.04204436]]], shape=(2, 9, 4207), dtype=float32)

有了前面的各種 layers抓艳,建立一個(gè) Transformer 并不難。但要輸入什么數(shù)據(jù)就是一門大學(xué)問了:

...
tar_inp = tar[:, :-1]
tar_real = tar[:, 1:]
predictions, attn_weights = transformer(inp, tar_inp, False, ...)
...

為何是丟少了尾巴一個(gè)字的 tar_inp 序列進(jìn)去 Transformer,而不是直接丟 tar 呢玷或?

別忘記我們才剛初始一個(gè) Transformer儡首,里頭所有 layers 的權(quán)重都是隨機(jī)的,你可不能指望它真的會(huì)什么「黑魔法」來幫你翻譯偏友。我們得先訓(xùn)練才行蔬胯。但訓(xùn)練時(shí)如果你把整個(gè)正確的中文序列 tar都進(jìn)去給 Transformer 看,你期待它產(chǎn)生什么位他?一首新的中文詩嗎氛濒?

如果你曾經(jīng)實(shí)現(xiàn)過序列生成模型或是看過我之前的語言模型文章,就會(huì)知道在序列生成任務(wù)里頭鹅髓,模型獲得的正確答案是輸入序列往左位移一個(gè)位置的結(jié)果舞竿。

這樣講很抽象,讓我們看個(gè)影片了解序列生成是怎么運(yùn)作的:

了解序列生成以及如何訓(xùn)練一個(gè)生成模型

你現(xiàn)在應(yīng)該明白 Transformer 在訓(xùn)練的時(shí)候并不是吃進(jìn)去整個(gè)中文序列窿冯,而是吃進(jìn)去一個(gè)去掉尾巴的序列 tar_inp骗奖,然后試著去預(yù)測(cè)「左移」一個(gè)字以后的序列 tar_real。同樣概念當(dāng)然也可以運(yùn)用到以 RNN 或是 CNN-based 的模型上面醒串。

從影片中你也可以發(fā)現(xiàn)給定 tar_inp 序列中的任一位置执桌,其對(duì)應(yīng)位置的 tar_real 就是下個(gè)時(shí)間點(diǎn)模型應(yīng)該要預(yù)測(cè)的中文字。

序列生成任務(wù)可以被視為是一個(gè)分類任務(wù)(Classification)芜赌,而每一個(gè)中文字都是一個(gè)分類仰挣。而 Transformer 就是要去產(chǎn)生一個(gè)中文字的概率分布,想辦法跟正解越接近越好缠沈。

跟用已訓(xùn)練的Transformer 做預(yù)測(cè)時(shí)不同膘壶,在訓(xùn)練時(shí)為了穩(wěn)定模型表現(xiàn),我們并不會(huì)將Transformer 的輸出再度丟回去當(dāng)做其輸入(人形蜈蚣洲愤?)香椎,而是像影片中所示,給它左移一個(gè)位置后的序列tar_real 當(dāng)作正解讓它去最小化error禽篱。

這種無視模型預(yù)測(cè)結(jié)果,而將正確解答丟入的訓(xùn)練方法一般被稱作 teacher forcing馍惹。你也可以參考教授的 Sequence-to-sequence Learning 教學(xué)躺率。

7. 定義損失函數(shù)與指標(biāo)

因?yàn)楸灰暈槭且粋€(gè)分類任務(wù),我們可以使用 cross entropy 來計(jì)算序列生成任務(wù)中實(shí)際的中文字跟模型預(yù)測(cè)的中文字分布(distribution)相差有多遠(yuǎn)万矾。

這邊簡(jiǎn)單定義一個(gè)損失函數(shù):

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

# 假設(shè)我們要解的是一個(gè) binary classifcation悼吱, 0 跟 1 個(gè)代表一個(gè) label
real = tf.constant([1, 1, 0], shape=(1, 3), dtype=tf.float32)
pred = tf.constant([[0, 1], [0, 1], [0, 1]], dtype=tf.float32)
loss_object(real, pred)
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.31326166, 0.31326166, 1.3132616 ], dtype=float32)>

如果你曾做過分類問題,應(yīng)該能看出預(yù)測(cè)序列pred 里頭的第 3 個(gè)預(yù)測(cè)結(jié)果出錯(cuò)因此 entropy 值上升良狈。損失函數(shù)loss_object 做的事情就是比較 2 個(gè)序列并計(jì)算 cross entropy:

  • real:一個(gè)包含 N 個(gè)正確 labels 的序列
  • pred:一個(gè)包含 N 個(gè)維度為 label 數(shù)的 logit 序列

我們?cè)谶@邊將 reduction 參數(shù)設(shè)為 none后添,請(qǐng)loss_object 不要把每個(gè)位置的 error 加總。而這是因?yàn)槲覀冎笠约喊?<pad> token 出現(xiàn)的位置的損失舍棄不計(jì)薪丁。

而將 from_logits 參數(shù)設(shè)為True 是因?yàn)閺?Transformer 得到的預(yù)測(cè)還沒有經(jīng)過 softmax遇西,因此加和還不等于 1:

print("predictions:", predictions)
print("-" * 20)
print(tf.reduce_sum(predictions, axis=-1))
predictions: tf.Tensor(
[[[ 0.01349578 -0.00199539 -0.00217387 ... -0.03862738 -0.03212879
   -0.07692747]
  [ 0.037483    0.01585471 -0.02548708 ... -0.04276202 -0.02495992
   -0.05491882]
  [ 0.05718528  0.0288353  -0.04577483 ... -0.0450176  -0.01315334
   -0.03639907]
  ...
  [ 0.01202047 -0.00400385 -0.00099438 ... -0.03859971 -0.03085513
   -0.0797975 ]
  [ 0.0235797   0.00501019 -0.01193091 ... -0.04091505 -0.02892826
   -0.06939011]
  [ 0.04867784  0.02382022 -0.03683803 ... -0.04392421 -0.01941058
   -0.04347047]]

 [[ 0.01676657 -0.00080312 -0.00556347 ... -0.03981712 -0.02937311
   -0.07665333]
  [ 0.03873826  0.01607161 -0.02685272 ... -0.04328423 -0.02345929
   -0.05522631]
  [ 0.0564083   0.02865588 -0.04492006 ... -0.04475704 -0.014088
   -0.03639095]
  ...
  [ 0.01514172 -0.00298804 -0.00426158 ... -0.03976889 -0.02800199
   -0.07974622]
  [ 0.02867933  0.00800282 -0.01704068 ... -0.04215823 -0.02618418
   -0.06638923]
  [ 0.05056309  0.02489874 -0.03880978 ... -0.04421616 -0.01803544
   -0.04204436]]], shape=(2, 9, 4207), dtype=float32)
--------------------
tf.Tensor(
[[1.3761909 2.9352095 3.8687317 3.4191105 2.608357  1.5664345 1.1489892
  1.9882674 3.5525477]
 [1.4309797 2.9219136 3.873899  3.5009165 2.6499162 1.6611676 1.1839213
  2.2150593 3.6206641]], shape=(2, 9), dtype=float32)

有了 loss_object 實(shí)際算 cross entropy 以后馅精,我們需要另外一個(gè)函數(shù)來建立遮罩并加總序列里頭不包含 token位置的損失:

def loss_function(real, pred):
  # 這次的 mask 將序列中不等于 0 的位置視為 1,其余為 0
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  # 照樣計(jì)算所有位置的 cross entropy 但不加總
  loss_ = loss_object(real, pred)
  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask  # 只計(jì)算非 <pad> 位置的損失
  
  return tf.reduce_mean(loss_)

我另外再定義兩個(gè) tf.keras.metrics粱檀,方便之后使用 TensorBoard 來追蹤模型 performance:

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='train_accuracy')

8. 設(shè)置超參數(shù)

前面實(shí)現(xiàn)了那么多 layers洲敢,你應(yīng)該還記得有哪些是你自己可以調(diào)整的超參數(shù)吧?

讓我?guī)湍闳苛谐鰜恚?/p>

  • num_layers 決定 Transfomer 里頭要有幾個(gè) Encoder / Decoder layers
  • d_model 決定我們子詞的 representation space 維度
  • num_heads 要做幾頭的自注意力運(yùn)算
  • dff 決定 FFN 的中間維度
  • dropout_rate 預(yù)設(shè) 0.1茄蚯,一般用預(yù)設(shè)值即可
  • input_vocab_size:輸入語言(英文)的字典大小
  • target_vocab_size:輸出語言(中文)的字典大小

論文里頭最基本的 Transformer 配置為:

  • num_layers=6
  • d_model=512
  • dff=2048

有大量數(shù)據(jù)以及大的 Transformer压彭,你可以在很多機(jī)器學(xué)習(xí)任務(wù)都達(dá)到不錯(cuò)的成績(jī)。為了不要讓訓(xùn)練時(shí)間太長(zhǎng)渗常,在這篇文章里頭我會(huì)把 Transformer 里頭的超參數(shù)設(shè)小一點(diǎn):

num_layers = 4 
d_model = 128
dff = 512
num_heads = 8

input_vocab_size = subword_encoder_en.vocab_size + 2
target_vocab_size = subword_encoder_zh.vocab_size + 2
dropout_rate = 0.1  # 預(yù)設(shè)值

print("input_vocab_size:", input_vocab_size)
print("target_vocab_size:", target_vocab_size)
input_vocab_size: 8115
target_vocab_size: 4207

4 層 Encoder / Decoder layers 不算貪心壮不,小巫見大巫(笑

9. 設(shè)置 Optimizer

我們?cè)谶@邊跟論文一致,使用 Adam optimizer 以及自定義的 learning rate scheduler:

lr-equation.jpg

這 schedule 讓訓(xùn)練過程的前 warmup_steps 的 learning rate 線性增加皱碘,在那之后則跟步驟數(shù) step_num的反平方根成比例下降询一。不用擔(dān)心你沒有完全理解這公式,我們一樣可以直接使用 TensorFlow 官方教學(xué)的實(shí)現(xiàn):

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    # 論文預(yù)設(shè) `warmup_steps` = 4000
    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()
    
        self.d_model = d_model
        self.d_model = tf.cast(self.d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
    
# 將客制化 learning rate schdeule 丟入 Adam opt.
# Adam opt. 的參數(shù)都跟論文相同
learning_rate = CustomSchedule(d_model)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, 
                                     epsilon=1e-9)

我們可以觀察看看這個(gè) schedule 是怎么隨著訓(xùn)練步驟而改變 learning rate 的:

d_models = [128, 256, 512]
warmup_steps = [1000 * i for i in range(1, 4)]

schedules = []
labels = []
colors = ["blue", "red", "black"]
for d in d_models:
  schedules += [CustomSchedule(d, s) for s in warmup_steps]
  labels += [f"d_model: c22oaio, warm: {s}" for s in warmup_steps]

for i, (schedule, label) in enumerate(zip(schedules, labels)):
  plt.plot(schedule(tf.range(10000, dtype=tf.float32)), 
           label=label, color=colors[i // 3])

plt.legend()

plt.ylabel("Learning Rate")
plt.xlabel("Train Step")
Text(0.5, 0, 'Train Step')
不同 d_model 以及 warmup_steps 的 learning rate 變化

你可以明顯地看到所有 schedules 都先經(jīng)過 warmup_steps 個(gè)步驟直線提升 learning rate尸执,接著逐漸平滑下降家凯。另外我們也會(huì)給比較高維的 d_model 維度比較小的 learning rate。

10. 實(shí)際訓(xùn)練以及定時(shí)存檔

好啦如失,什么都準(zhǔn)備齊全了绊诲,讓我們開始訓(xùn)練 Transformer 吧!記得使用前面已經(jīng)定義好的超參數(shù)來初始化一個(gè)全新的 Transformer:

transformer = Transformer(num_layers, d_model, num_heads, dff,
                          input_vocab_size, target_vocab_size, dropout_rate)

print(f"""這個(gè) Transformer 有 {num_layers} 層 Encoder / Decoder layers
d_model: {d_model}
num_heads: {num_heads}
dff: {dff}
input_vocab_size: {input_vocab_size}
target_vocab_size: {target_vocab_size}
dropout_rate: {dropout_rate}

""")

這個(gè) Transformer 有 4 層 Encoder / Decoder layers
d_model: 128
num_heads: 8
dff: 512
input_vocab_size: 8115
target_vocab_size: 4207
dropout_rate: 0.1

打游戲時(shí)你會(huì)記得要定期存檔以防任何意外發(fā)生褪贵,訓(xùn)練深度學(xué)習(xí)模型也是同樣道理掂之。設(shè)置 checkpoint 來定期儲(chǔ)存 / 讀取模型及 optimizer 是必備的。

我們?cè)诘紫聲?huì)定義一個(gè) checkpoint 路徑脆丁,此路徑包含了各種超參數(shù)的信息世舰,方便之后比較不同實(shí)驗(yàn)的結(jié)果并載入已訓(xùn)練的進(jìn)度。我們也需要一個(gè) checkpoint manager 來做所有跟存讀模型有關(guān)的雜事槽卫,并只保留最新 5 個(gè) checkpoints 以避免占用太多空間:

# 方便比較不同實(shí)驗(yàn)/ 不同超參數(shù)設(shè)定的結(jié)果
run_id = f"{num_layers}layers_{d_model}d_{num_heads}heads_{dff}dff_{train_perc}train_perc"
checkpoint_path = os.path.join(checkpoint_path, run_id)
log_dir = os.path.join(log_dir, run_id)

# tf.train.Checkpoint 可以幫我們把想要存下來的東西整合起來跟压,方便儲(chǔ)存與讀取
# 一般來說你會(huì)想存下模型以及 optimizer 的狀態(tài)
ckpt = tf.train.Checkpoint(transformer=transformer,
                           optimizer=optimizer)

# ckpt_manager 會(huì)去 checkpoint_path 看有沒有符合 ckpt 里頭定義的東西
# 存檔的時(shí)候只保留最近 5 次 checkpoints,其他自動(dòng)刪除
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# 如果在 checkpoint 路徑上有發(fā)現(xiàn)檔案就讀進(jìn)來
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
  
    # 如果在 checkpoint 路徑上有發(fā)現(xiàn)檔案就讀進(jìn)來
    last_epoch = int(ckpt_manager.latest_checkpoint.split("-")[-1])
    print(f'已讀取最新的 checkpoint歼培,模型已訓(xùn)練 {last_epoch} epochs震蒋。')
else:
    last_epoch = 0
    print("沒找到 checkpoint,從頭訓(xùn)練躲庄。")
沒找到 checkpoint查剖,從頭訓(xùn)練。

我知道你在想什么噪窘。

「誒K褡? 你不當(dāng)場(chǎng)訓(xùn)練嗎?」「直接載入已訓(xùn)練的模型太狗了吧直砂!」

拜托菌仁,我都訓(xùn)練 N 遍了,每次都重新訓(xùn)練也太沒意義了哆键。而且你能想像為了寫一個(gè)章節(jié)我就得重新訓(xùn)練一個(gè) Transformer 來 demo 嗎掘托?這樣太沒效率了。比起每次重新訓(xùn)練模型籍嘹,這才是你在真實(shí)世界中應(yīng)該做的事情:盡可能恢復(fù)之前的訓(xùn)練進(jìn)度來節(jié)省時(shí)間闪盔。

不過放心,我仍會(huì)秀出完整的訓(xùn)練代碼讓你可以執(zhí)行第一次的訓(xùn)練辱士。當(dāng)你想要依照本文訓(xùn)練自己的 Transformer 時(shí)會(huì)感謝有 checkpoint manager 的存在±嵯疲現(xiàn)在假設(shè)我們還沒有 checkpoints。

在實(shí)際訓(xùn)練 Transformer 之前還需要定義一個(gè)簡(jiǎn)單函數(shù)來產(chǎn)生所有的遮罩:

# 為 Transformer 的 Encoder / Decoder 準(zhǔn)備遮罩
def create_masks(inp, tar):
  # 英文句子的 padding mask颂碘,要交給 Encoder layer 自注意力機(jī)制用的
  enc_padding_mask = create_padding_mask(inp)
  
  # 同樣也是英文句子的 padding mask异赫,但是是要交給 Decoder layer 的 MHA 2 
  # 關(guān)注 Encoder 輸出序列用的
  dec_padding_mask = create_padding_mask(inp)
  
  # Decoder layer 的 MHA1 在做自注意力機(jī)制用的
  # `combined_mask` 是中文句子的 padding mask 跟 look ahead mask 的疊加
  look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
  dec_target_padding_mask = create_padding_mask(tar)
  combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
  
  return enc_padding_mask, combined_mask, dec_padding_mask

如果沒有本文前面針對(duì)遮罩的詳細(xì)說明,很多第一次實(shí)現(xiàn)的人得花不少時(shí)間來確實(shí)地掌握這些遮罩的用途头岔。不過對(duì)現(xiàn)在的你來說應(yīng)該也是小菜一碟塔拳。

一個(gè)數(shù)據(jù)集包含多個(gè) batch,而每次拿一個(gè) batch 來訓(xùn)練的步驟就稱作 train_step峡竣。為了讓程式碼更簡(jiǎn)潔以及容易優(yōu)化靠抑,我們會(huì)定義 Transformer 在一次訓(xùn)練步驟(處理一個(gè) batch)所需要做的所有事情。

不限于 Transformer适掰,一般來說 train_step 函數(shù)里會(huì)有幾個(gè)重要步驟:

  • 對(duì)訓(xùn)練數(shù)據(jù)做些必要的前處理
  • 將數(shù)據(jù)丟入模型颂碧,取得預(yù)測(cè)結(jié)果
  • 用預(yù)測(cè)結(jié)果跟正確解答計(jì)算 loss
  • 取出梯度并利用 optimizer 做梯度下降

有了這個(gè)概念以后看看代碼:

@tf.function  # 讓 TensorFlow 幫我們將 eager code 優(yōu)化并加快運(yùn)算
def train_step(inp, tar):
    # 前面說過的,用去尾的原始序列去預(yù)測(cè)下一個(gè)字的序列
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]
  
    # 建立 3 個(gè)遮罩
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
  
    # 紀(jì)錄 Transformer 的所有運(yùn)算過程以方便之后做梯度下降
    with tf.GradientTape() as tape:
        # 注意是丟入 `tar_inp` 而非 `tar`类浪。記得將 `training` 參數(shù)設(shè)定為 True
        predictions, _ = transformer(inp, tar_inp, 
                                     True, 
                                     enc_padding_mask, 
                                     combined_mask, 
                                     dec_padding_mask)
        # 跟影片中顯示的相同载城,計(jì)算左移一個(gè)字的序列跟模型預(yù)測(cè)分布之間的差異,當(dāng)作 loss
        loss = loss_function(tar_real, predictions)

    # 取出梯度并呼叫前面定義的 Adam optimizer 幫我們更新 Transformer 里頭可訓(xùn)練的參數(shù)
    gradients = tape.gradient(loss, transformer.trainable_variables)    
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
  
    # 將 loss 以及訓(xùn)練 acc 記錄到 TensorBoard 上费就,非必要
    train_loss(loss)
    train_accuracy(tar_real, predictions)

如果你曾經(jīng)以TensorFlow 2 實(shí)現(xiàn)過稍微復(fù)雜一點(diǎn)的模型诉瓦,應(yīng)該就知道 train_step函數(shù)的寫法非常固定:

  • 對(duì)輸入數(shù)據(jù)做些前處理(本文中的遮罩、將輸出序列左移當(dāng)成正解 etc.)
  • 利用 tf.GradientTape 輕松記錄數(shù)據(jù)被模型做的所有轉(zhuǎn)換并計(jì)算 loss
  • 將梯度取出并讓 optimzier 對(duì)可被訓(xùn)練的權(quán)重做梯度下降(上升)

你完全可以用一模一樣的方式將任何復(fù)雜模型的處理過程包在train_step 函數(shù)力细,這樣可以讓我們之后在 iterate 數(shù)據(jù)集時(shí)非常輕松垦搬。而且最重要的是可以用 tf.function 來提高此函數(shù)里頭運(yùn)算的速度。你可以點(diǎn)擊連結(jié)來了解更多艳汽。

處理一個(gè) batch 的 train_step 函數(shù)也有了,就只差寫個(gè) for loop 將數(shù)據(jù)集跑個(gè)幾遍了对雪。我之前的模型雖然訓(xùn)練了 50 個(gè) epochs河狐,但事實(shí)上大概 30 epochs 翻譯的結(jié)果就差不多穩(wěn)定了。所以讓我們將 EPOCHS 設(shè)定為 30:

# 定義我們要看幾遍數(shù)據(jù)集
EPOCHS = 30
print(f"此超參數(shù)組合的 Transformer 已經(jīng)訓(xùn)練 {last_epoch} epochs。")
print(f"剩余 epochs:{min(0, last_epoch - EPOCHS)}")


# 用來寫資訊到 TensorBoard馋艺,非必要但十分推薦
summary_writer = tf.summary.create_file_writer(log_dir)

# 比對(duì)設(shè)定的 `EPOCHS` 以及已訓(xùn)練的 `last_epoch` 來決定還要訓(xùn)練多少 epochs
for epoch in range(last_epoch, EPOCHS):
    start = time.time()
  
    # 重置紀(jì)錄 TensorBoard 的 metrics
    train_loss.reset_states()
    train_accuracy.reset_states()
  
    # 一個(gè) epoch 就是把我們定義的訓(xùn)練資料集一個(gè)一個(gè) batch 拿出來處理栅干,直到看完整個(gè)數(shù)據(jù)集
    for (step_idx, (inp, tar)) in enumerate(train_dataset):
        # 每次 step 就是將數(shù)據(jù)丟入 Transformer,讓它生預(yù)測(cè)結(jié)果并計(jì)算梯度最小化 loss
        train_step(inp, tar)  

    # 每個(gè) epoch 完成就存一次檔
    if (epoch + 1) % 1 == 0:
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,ckpt_save_path))
    
    # 將 loss 以及 accuracy 寫到 TensorBoard 上
    with summary_writer.as_default():
        tf.summary.scalar("train_loss", train_loss.result(), step=epoch + 1)
        tf.summary.scalar("train_acc", train_accuracy.result(), step=epoch + 1)
  
    print('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, 
                                                        train_loss.result(), 
                                                        train_accuracy.result()))
    print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
此超參數(shù)組合的 Transformer 已經(jīng)訓(xùn)練 0 epochs捐祠。
剩余 epochs:-30
Saving checkpoint for epoch 1 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-1
Epoch 1 Loss 5.1843 Accuracy 0.0219
Time taken for 1 epoch: 89.55020833015442 secs

Saving checkpoint for epoch 2 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-2
Epoch 2 Loss 4.2425 Accuracy 0.0604
Time taken for 1 epoch: 21.873889207839966 secs

Saving checkpoint for epoch 3 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-3
Epoch 3 Loss 3.7423 Accuracy 0.0987
Time taken for 1 epoch: 21.901566743850708 secs

Saving checkpoint for epoch 4 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-4
Epoch 4 Loss 3.2644 Accuracy 0.1512
Time taken for 1 epoch: 22.083024501800537 secs

Saving checkpoint for epoch 5 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-5
Epoch 5 Loss 2.9634 Accuracy 0.1810
Time taken for 1 epoch: 22.050684452056885 secs

Saving checkpoint for epoch 6 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-6
Epoch 6 Loss 2.7756 Accuracy 0.1988
Time taken for 1 epoch: 25.719687461853027 secs

Saving checkpoint for epoch 7 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-7
Epoch 7 Loss 2.6346 Accuracy 0.2122
Time taken for 1 epoch: 22.85287618637085 secs

Saving checkpoint for epoch 8 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-8
Epoch 8 Loss 2.5183 Accuracy 0.2242
Time taken for 1 epoch: 18.721409797668457 secs

Saving checkpoint for epoch 9 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-9
Epoch 9 Loss 2.4171 Accuracy 0.2353
Time taken for 1 epoch: 18.663178205490112 secs

Saving checkpoint for epoch 10 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-10
Epoch 10 Loss 2.3204 Accuracy 0.2458
Time taken for 1 epoch: 25.891611576080322 secs

Saving checkpoint for epoch 11 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-11
Epoch 11 Loss 2.2223 Accuracy 0.2573
Time taken for 1 epoch: 18.789816856384277 secs

Saving checkpoint for epoch 12 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-12
Epoch 12 Loss 2.1319 Accuracy 0.2685
Time taken for 1 epoch: 22.33806586265564 secs

Saving checkpoint for epoch 13 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-13
Epoch 13 Loss 2.0458 Accuracy 0.2796
Time taken for 1 epoch: 18.877813816070557 secs

Saving checkpoint for epoch 14 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-14
Epoch 14 Loss 1.9643 Accuracy 0.2912
Time taken for 1 epoch: 18.858903884887695 secs

Saving checkpoint for epoch 15 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-15
Epoch 15 Loss 1.8875 Accuracy 0.3020
Time taken for 1 epoch: 18.890562295913696 secs

Saving checkpoint for epoch 16 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-16
Epoch 16 Loss 1.8178 Accuracy 0.3120
Time taken for 1 epoch: 22.47147297859192 secs

Saving checkpoint for epoch 17 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-17
Epoch 17 Loss 1.7531 Accuracy 0.3211
Time taken for 1 epoch: 18.98854422569275 secs

Saving checkpoint for epoch 18 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-18
Epoch 18 Loss 1.6899 Accuracy 0.3305
Time taken for 1 epoch: 18.987966775894165 secs

Saving checkpoint for epoch 19 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-19
Epoch 19 Loss 1.6200 Accuracy 0.3406
Time taken for 1 epoch: 18.95727038383484 secs

Saving checkpoint for epoch 20 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-20
Epoch 20 Loss 1.5555 Accuracy 0.3499
Time taken for 1 epoch: 18.99857258796692 secs

Saving checkpoint for epoch 21 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-21
Epoch 21 Loss 1.4968 Accuracy 0.3590
Time taken for 1 epoch: 19.01795792579651 secs

Saving checkpoint for epoch 22 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-22
Epoch 22 Loss 1.4447 Accuracy 0.3668
Time taken for 1 epoch: 19.078711986541748 secs

Saving checkpoint for epoch 23 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-23
Epoch 23 Loss 1.3984 Accuracy 0.3738
Time taken for 1 epoch: 19.144370317459106 secs

Saving checkpoint for epoch 24 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-24
Epoch 24 Loss 1.3535 Accuracy 0.3805
Time taken for 1 epoch: 19.05727791786194 secs

Saving checkpoint for epoch 25 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-25
Epoch 25 Loss 1.3142 Accuracy 0.3866
Time taken for 1 epoch: 22.631419897079468 secs

Saving checkpoint for epoch 26 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-26
Epoch 26 Loss 1.2765 Accuracy 0.3926
Time taken for 1 epoch: 19.017268657684326 secs

Saving checkpoint for epoch 27 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-27
Epoch 27 Loss 1.2441 Accuracy 0.3969
Time taken for 1 epoch: 19.065359115600586 secs

Saving checkpoint for epoch 28 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-28
Epoch 28 Loss 1.2106 Accuracy 0.4023
Time taken for 1 epoch: 19.06916570663452 secs

Saving checkpoint for epoch 29 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-29
Epoch 29 Loss 1.1835 Accuracy 0.4068
Time taken for 1 epoch: 19.07039451599121 secs

Saving checkpoint for epoch 30 at nmt/checkpoints/4layers_128d_8heads_512dff_20train_perc/ckpt-30
Epoch 30 Loss 1.1560 Accuracy 0.4107
Time taken for 1 epoch: 19.10555648803711 secs

如信息所示碱鳞,當(dāng)指定的 EPOCHS「落后」于之前的訓(xùn)練進(jìn)度我們就不再訓(xùn)練了。但如果是第一次訓(xùn)練或是訓(xùn)練到指定 EPOCHS的一部分踱蛀,我們都會(huì)從正確的地方開始訓(xùn)練并存檔窿给,不會(huì)浪費(fèi)到訓(xùn)練時(shí)間或計(jì)算資源。

這邊的邏輯也很簡(jiǎn)單率拒,在每個(gè) epoch 都:

  • (非必要)重置寫到 TensorBoard 的 metrics 的值
  • 將整個(gè)數(shù)據(jù)集的 batch 取出崩泡,交給 train_step 函數(shù)處理
  • (非必要)存 checkpoints
  • (非必要)將當(dāng)前 epoch 結(jié)果寫到 TensorBoard
  • (非必要)在標(biāo)準(zhǔn)輸出顯示當(dāng)前 epoch 結(jié)果

是的,如果你真的只是想要訓(xùn)練個(gè)模型猬膨,什么其他事情都不想考慮的話那你可以:

# 87 分角撞,不能再高了。
for epoch in range(EPOCHS):
  for inp, tar in train_dataset:
    train_step(inp, tar)

嗯 ... 話是這么說勃痴,但我仍然建議你至少要記得存檔并將訓(xùn)練過程顯示出來谒所。

編者按:我是在Google的Colab Notebooks中進(jìn)行的訓(xùn)練,在這個(gè)計(jì)算能力下沛申,我們定義的 4 層 Transformer 大約每 19 秒就可以看完一遍有 3 萬筆訓(xùn)練例子的數(shù)據(jù)集劣领,而且你從上面的 loss 以及 accuracy 可以看出來 Transformer 至少在訓(xùn)練集里頭進(jìn)步地挺快的。

而就我自己的觀察大約經(jīng)過 30 個(gè) epochs 翻譯結(jié)果就很穩(wěn)定了污它。所以你大約只需半個(gè)小時(shí)就能有一個(gè)非常簡(jiǎn)單剖踊,有點(diǎn)水準(zhǔn)的英翻中 Transformer(在至少有個(gè)一般 GPU 的情況)。

但跟看上面的 log 比起來衫贬,我個(gè)人還是比較推薦使用 TensorBoard德澈。在 TensorFlow 2 里頭,你甚至能直接在 Jupyter Notebook 或是 Colab 里頭開啟它:

%load_ext tensorboard
%tensorboard --logdir {log_dir}
<IPython.core.display.Javascript object>
使用 TensorBoard 可以讓你輕松比較不同超參數(shù)的訓(xùn)練結(jié)果

透過 TensorBoard固惯,你能非常清楚地比較不同實(shí)驗(yàn)以及不同點(diǎn)子的效果梆造,知道什么 work 什么不 work,進(jìn)而修正之后嘗試的方向葬毫。如果只是簡(jiǎn)單寫個(gè)print镇辉,那你永遠(yuǎn)只會(huì)看到最新一次訓(xùn)練過程的 log仆葡,然后忘記之前到底發(fā)生過什么事习劫。

11. 實(shí)際進(jìn)行英翻中

有了已經(jīng)訓(xùn)練一陣子的 Transformer,當(dāng)然得拿它來實(shí)際做做翻譯损敷。

跟訓(xùn)練的時(shí)候不同烂斋,在做預(yù)測(cè)時(shí)我們不需做 teacher forcing 來穩(wěn)定 Transformer 的訓(xùn)練過程屹逛。反之础废,我們將 Transformer 在每個(gè)時(shí)間點(diǎn)生成的中文索引加到之前已經(jīng)生成的序列尾巴,并以此新序列作為其下一次的輸入罕模。這是因?yàn)?Transformer 事實(shí)上是一個(gè)自回歸模型(Auto-regressive model):依據(jù)自己生成的結(jié)果預(yù)測(cè)下次輸出评腺。

利用 Transformer 進(jìn)行翻譯(預(yù)測(cè))的邏輯如下:

  • 將輸入的英文句子利用 Subword Tokenizer 轉(zhuǎn)換成子詞索引序列(還記得 inp 吧?)

  • 在該英文索引序列前后加上代表英文 BOS / EOS 的tokens

  • 在 Transformer 輸出序列長(zhǎng)度達(dá)到 MAX_LENGTH 之前重復(fù)以下步驟:

    • 為目前已經(jīng)生成的中文索引序列產(chǎn)生新的遮罩
    • 將剛剛的英文序列淑掌、當(dāng)前的中文序列以及各種遮罩放入 Transformer
    • 將 Transformer 輸出序列的最后一個(gè)位置的向量取出蒿讥,并取 argmax 取得新的預(yù)測(cè)中文索引
    • 將此索引加到目前的中文索引序列里頭作為 Transformer 到此為止的輸出結(jié)果
    • 如果新生成的中文索引為 <end> 則代表中文翻譯已全部生成完畢,直接回傳
  • 將最后得到的中文索引序列回傳作為翻譯結(jié)果

是的抛腕,一個(gè)時(shí)間點(diǎn)生成一個(gè)中文字芋绸,而在第一個(gè)時(shí)間點(diǎn)因?yàn)?Transformer 還沒有任何輸出,我們會(huì)丟中文字的 <start> token 進(jìn)去兽埃。你可能會(huì)想:

為何每次翻譯開頭都是 start token侥钳,Transformer 還能產(chǎn)生不一樣且正確的結(jié)果?

答案也很簡(jiǎn)單柄错,因?yàn)?Decoder 可以透過「關(guān)注」 Encoder 處理完不同英文句子的輸出來獲得語義信息舷夺,了解它在當(dāng)下該生成什么中文字作為第一個(gè)輸出。

現(xiàn)在讓我們定義一個(gè) evaluate函數(shù)實(shí)現(xiàn)上述邏輯售貌。此函數(shù)的輸入是一個(gè)完全沒有經(jīng)過處理的英文句子(以字串表示)给猾,輸出則是一個(gè)索引序列,里頭的每個(gè)索引就代表著 Transformer 預(yù)測(cè)的中文字颂跨。

讓我們實(shí)際看看 evaluate 函數(shù):

# 給定一個(gè)英文句子敢伸,輸出預(yù)測(cè)的中文索引數(shù)字序列以及注意權(quán)重 dict
def evaluate(inp_sentence):
  
  # 準(zhǔn)備英文句子前後會(huì)加上的 <start>, <end>
  start_token = [subword_encoder_en.vocab_size]
  end_token = [subword_encoder_en.vocab_size + 1]
  
  # inp_sentence 是字串,我們用 Subword Tokenizer 將其變成子詞的索引序列
  # 並在前後加上 BOS / EOS
  inp_sentence = start_token + subword_encoder_en.encode(inp_sentence) + end_token
  encoder_input = tf.expand_dims(inp_sentence, 0)
  
  # 跟我們?cè)谟捌e看到的一樣恒削,Decoder 在第一個(gè)時(shí)間點(diǎn)吃進(jìn)去的輸入
  # 是一個(gè)只包含一個(gè)中文 <start> token 的序列
  decoder_input = [subword_encoder_zh.vocab_size]
  output = tf.expand_dims(decoder_input, 0)  # 增加 batch 維度
  
  # auto-regressive池颈,一次生成一個(gè)中文字並將預(yù)測(cè)加到輸入再度餵進(jìn) Transformer
  for i in range(MAX_LENGTH):
    # 每多一個(gè)生成的字就得產(chǎn)生新的遮罩
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        encoder_input, output)
  
    # predictions.shape == (batch_size, seq_len, vocab_size)
    predictions, attention_weights = transformer(encoder_input, 
                                                 output,
                                                 False,
                                                 enc_padding_mask,
                                                 combined_mask,
                                                 dec_padding_mask)
    

    # 將序列中最後一個(gè) distribution 取出,並將裡頭值最大的當(dāng)作模型最新的預(yù)測(cè)字
    predictions = predictions[: , -1:, :]  # (batch_size, 1, vocab_size)

    predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
    
    # 遇到 <end> token 就停止回傳钓丰,代表模型已經(jīng)產(chǎn)生完結(jié)果
    if tf.equal(predicted_id, subword_encoder_zh.vocab_size + 1):
      return tf.squeeze(output, axis=0), attention_weights
    
    #將 Transformer 新預(yù)測(cè)的中文索引加到輸出序列中躯砰,讓 Decoder 可以在產(chǎn)生
    # 下個(gè)中文字的時(shí)候關(guān)注到最新的 `predicted_id`
    output = tf.concat([output, predicted_id], axis=-1)

  # 將 batch 的維度去掉後回傳預(yù)測(cè)的中文索引序列
  return tf.squeeze(output, axis=0), attention_weights

我知道這章代碼很多很長(zhǎng),但搭配注解后你會(huì)發(fā)現(xiàn)它們實(shí)際上都不難携丁,而且這也是你看這篇文章的主要目的:實(shí)際了解 Transformer 是怎么做英中翻譯的琢歇。你不想只是紙上談兵,對(duì)吧梦鉴?

有了 evaluate 函數(shù)李茫,要透過 Transformer 做翻譯非常容易:

# 要被翻譯的英文句子
sentence = "China, India, and others have enjoyed continuing economic growth."

# 取得預(yù)測(cè)的中文索引序列
predicted_seq, _ = evaluate(sentence)

# 過濾掉 <start> & <end> tokens 并用中文的 subword tokenizer 幫我們將索引序列還原回中文句子
target_vocab_size = subword_encoder_zh.vocab_size
predicted_seq_without_bos_eos = [idx for idx in predicted_seq if idx < target_vocab_size]
predicted_sentence = subword_encoder_zh.decode(predicted_seq_without_bos_eos)

print("sentence:", sentence)
print("-" * 20)
print("predicted_seq:", predicted_seq)
print("-" * 20)
print("predicted_sentence:", predicted_sentence)
sentence: China, India, and others have enjoyed continuing economic growth.
--------------------
predicted_seq: tf.Tensor(
[4205   16    4   36  378  100    8   35   32    4   33  111  945  189
   22   49  105   83    3], shape=(19,), dtype=int32)
--------------------
predicted_sentence: 中國(guó)、印度和其他國(guó)家都享受經(jīng)濟(jì)增長(zhǎng)肥橙。

考慮到這個(gè) Transformer 不算巨大(約 400 萬個(gè)參數(shù))魄宏,且模型訓(xùn)練時(shí)用的數(shù)據(jù)集不大的情況下,我們達(dá)到相當(dāng)不錯(cuò)的結(jié)果存筏,你說是吧娜庇?在這個(gè)例子里頭該翻的詞匯都翻了出來塔次,句子本身也還算自然。

transformer.summary()
Model: "transformer_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
encoder_2 (Encoder)          multiple                  1831808   
_________________________________________________________________
decoder_2 (Decoder)          multiple                  1596800   
_________________________________________________________________
dense_137 (Dense)            multiple                  542703    
=================================================================
Total params: 3,971,311
Trainable params: 3,971,311
Non-trainable params: 0
_________________________________________________________________

12. 可視化注意權(quán)重

除了其運(yùn)算高度并行以及表現(xiàn)不錯(cuò)以外名秀,Transformer 另外一個(gè)優(yōu)點(diǎn)在于我們可以透過可視化注意權(quán)重(attention weights)來了解模型實(shí)際在生成序列的時(shí)候放「注意力」在哪里。別忘記我們當(dāng)初在 Decoder layers 做完 multi-head attention 之后都將注意權(quán)重輸出∨航Γ現(xiàn)在正是它們派上用場(chǎng)的時(shí)候了匕得。

先讓我們看看有什么注意權(quán)重可以拿來可視化:

predicted_seq, attention_weights = evaluate(sentence)

# 在這邊我們自動(dòng)選擇最后一個(gè) Decoder layer 的 MHA 2,也就是 Decoder 關(guān)注 Encoder 的 MHA
layer_name = f"decoder_layer{num_layers}_block2"

print("sentence:", sentence)
print("-" * 20)
print("predicted_seq:", predicted_seq)
print("-" * 20)
print("attention_weights.keys():")
for layer_name, attn in attention_weights.items():
  print(f"{layer_name}.shape: {attn.shape}")
print("-" * 20)
print("layer_name:", layer_name)
sentence: China, India, and others have enjoyed continuing economic growth.
--------------------
predicted_seq: tf.Tensor(
[4205   16    4   36  378  100    8   35   32    4   33  111  945  189
   22   49  105   83    3], shape=(19,), dtype=int32)
--------------------
attention_weights.keys():
decoder_layer1_block1.shape: (1, 8, 19, 19)
decoder_layer1_block2.shape: (1, 8, 19, 15)
decoder_layer2_block1.shape: (1, 8, 19, 19)
decoder_layer2_block2.shape: (1, 8, 19, 15)
decoder_layer3_block1.shape: (1, 8, 19, 19)
decoder_layer3_block2.shape: (1, 8, 19, 15)
decoder_layer4_block1.shape: (1, 8, 19, 19)
decoder_layer4_block2.shape: (1, 8, 19, 15)
--------------------
layer_name: decoder_layer4_block2
  • block1 代表是Decoder layer 自己關(guān)注自己的MHA 1巾表,因此倒數(shù)兩個(gè)維度都跟中文序列長(zhǎng)度相同汁掠;

  • block2 則是Decoder layer 用來關(guān)注Encoder 輸出的MHA 2 ,在這邊我們選擇最后一個(gè)Decoder layer 的MHA 2來看Transformer 在生成中文序列時(shí)關(guān)注在英文句子的那些位置集币。

但首先考阱,我們得要有一個(gè)繪圖的函數(shù)才行:

import matplotlib as mpl
# 你可能會(huì)需要自行下載一個(gè)中文字體檔案以讓 matplotlib 正確顯示中文
zhfont = mpl.font_manager.FontProperties(fname='tensorflow-datasets/SimHei.ttf')
plt.style.use("seaborn-whitegrid")

# 這個(gè)函數(shù)將英 -> 中翻譯的注意權(quán)重視覺化(注意:我們將注意權(quán)重 transpose 以最佳化渲染結(jié)果
def plot_attention_weights(attention_weights, sentence, predicted_seq, layer_name, max_len_tar=None):
    
    fig = plt.figure(figsize=(17, 7))
    sentence = subword_encoder_en.encode(sentence)
  
    # 只顯示中文序列前 `max_len_tar` 個(gè)字以避免畫面太過壅擠
    if max_len_tar:
        predicted_seq = predicted_seq[:max_len_tar]
    else:
        max_len_tar = len(predicted_seq)
  
    # 將某一個(gè)特定 Decoder layer 里頭的 MHA 1 或 MHA2 的注意權(quán)重拿出來并去掉 batch 維度
    attention_weights = tf.squeeze(attention_weights[layer_name], axis=0)  
    # (num_heads, tar_seq_len, inp_seq_len)
  
    # 將每個(gè) head 的注意權(quán)重畫出
    for head in range(attention_weights.shape[0]):
        ax = fig.add_subplot(2, 4, head + 1)

        # [注意]我為了將長(zhǎng)度不短的英文子詞顯示在 y 軸,將注意權(quán)重做了 transpose
        attn_map = np.transpose(attention_weights[head][:max_len_tar, :])
        ax.matshow(attn_map, cmap='viridis')  # (inp_seq_len, tar_seq_len)
    
        fontdict = {"fontproperties": zhfont}
    
        ax.set_xticks(range(max(max_len_tar, len(predicted_seq))))
        ax.set_xlim(-0.5, max_len_tar -1.5)
    
        ax.set_yticks(range(len(sentence) + 2))
        ax.set_xticklabels([subword_encoder_zh.decode([i]) for i in predicted_seq 
                            if i < subword_encoder_zh.vocab_size], 
                           fontdict=fontdict, fontsize=18)    
    
        ax.set_yticklabels(
            ['<start>'] + [subword_encoder_en.decode([i]) for i in sentence] + ['<end>'], 
            fontdict=fontdict)
    
        ax.set_xlabel('Head {}'.format(head + 1))
        ax.tick_params(axis="x", labelsize=12)
        ax.tick_params(axis="y", labelsize=12)
        
      
    plt.tight_layout()
    plt.show()
    plt.close(fig)

這個(gè)函數(shù)不難鞠苟,且里頭不少是調(diào)整圖片的細(xì)節(jié)設(shè)定因此我將它留給你自行參考乞榨。

比較值得注意的是因?yàn)槲覀冊(cè)谶@篇文章是做英文(來源)到中文(目標(biāo))的翻譯,注意權(quán)重的 shape 為:

(batch_size, num_heads, zh_seq_len, en_seq_len)

如果你直接把注意權(quán)重繪出的話 y 軸就會(huì)是每個(gè)中文字当娱,而 x 軸則會(huì)是每個(gè)英文子詞吃既。而英文子詞繪在 x 軸太占空間,我將每個(gè)注意權(quán)重都做 transpose 并呈現(xiàn)結(jié)果跨细,這點(diǎn)你得注意一下鹦倚。

讓我們馬上畫出剛剛翻譯的注意權(quán)重看看:

plot_attention_weights(attention_weights, sentence, 
                       predicted_seq, layer_name, max_len_tar=18)
注意力權(quán)重可視化.png

盡管其運(yùn)算機(jī)制十分錯(cuò)綜復(fù)雜,閱讀本文后 Transformer 對(duì)你來說不再是黑魔法冀惭,也不再是遙不可及的存在震叙。如果你現(xiàn)在覺得「Transformer 也不過就這樣嘛!」那就達(dá)成我寫這篇文章的目的了散休。

自注意力機(jī)制以及Transformer 在推出之后就被非常廣泛地使用并改進(jìn)媒楼,但在我自己開始接觸相關(guān)知識(shí)以后一直沒有發(fā)現(xiàn)完整的繁中教學(xué),因此寫了這篇當(dāng)初的我殷殷期盼的文章溃槐,也希望能幫助到更多人學(xué)習(xí)匣砖。

在進(jìn)入結(jié)語之前,讓我們看看文中的 Transformer 是怎么逐漸學(xué)會(huì)做好翻譯的:

attention_weights_change_by_time.gif

13. 在你離開之前

這篇是當(dāng)初在學(xué)習(xí) Transformer 的我希望有人分享給自己的文章昏滴。

我相信人類之所以強(qiáng)大是因?yàn)榧w知識(shí):我們能透過書籍猴鲫、影片以及語言將一個(gè)人腦中的知識(shí)與思想共享給其他人,讓寶貴的知識(shí)能夠「scale」谣殊,在更多人的腦袋中發(fā)光發(fā)熱拂共,創(chuàng)造更多價(jià)值。

我希望你有從本文中學(xué)到一點(diǎn)東西姻几,并幫助我將本文的這些知識(shí)「scale」宜狐,把文章分享給更多有興趣的人势告,并利用所學(xué)應(yīng)用在一些你一直想要完成的任務(wù)上面。

最后一點(diǎn)提醒抚恒,就算Transformer比古早時(shí)代的方法好再多多終究也只是個(gè)工具咱台,其最大價(jià)值不會(huì)超過于被你拿來應(yīng)用的問題之上。就好像現(xiàn)在已有很多超越基本Transformer的翻譯 方法俭驮,但我們?nèi)匀怀掷m(xù)在追尋更好的機(jī)器翻譯系統(tǒng)回溺。

工具會(huì)被淘汰,需求一直都在混萝。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末遗遵,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子逸嘀,更是在濱河造成了極大的恐慌车要,老刑警劉巖,帶你破解...
    沈念sama閱讀 216,324評(píng)論 6 498
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件崭倘,死亡現(xiàn)場(chǎng)離奇詭異翼岁,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)绳姨,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,356評(píng)論 3 392
  • 文/潘曉璐 我一進(jìn)店門登澜,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人飘庄,你說我怎么就攤上這事脑蠕。” “怎么了跪削?”我有些...
    開封第一講書人閱讀 162,328評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵谴仙,是天一觀的道長(zhǎng)。 經(jīng)常有香客問我碾盐,道長(zhǎng)晃跺,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,147評(píng)論 1 292
  • 正文 為了忘掉前任毫玖,我火速辦了婚禮掀虎,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘付枫。我一直安慰自己烹玉,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,160評(píng)論 6 388
  • 文/花漫 我一把揭開白布阐滩。 她就那樣靜靜地躺著二打,像睡著了一般。 火紅的嫁衣襯著肌膚如雪掂榔。 梳的紋絲不亂的頭發(fā)上继效,一...
    開封第一講書人閱讀 51,115評(píng)論 1 296
  • 那天症杏,我揣著相機(jī)與錄音,去河邊找鬼瑞信。 笑死厉颤,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的凡简。 我是一名探鬼主播走芋,決...
    沈念sama閱讀 40,025評(píng)論 3 417
  • 文/蒼蘭香墨 我猛地睜開眼,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼潘鲫!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起肋杖,我...
    開封第一講書人閱讀 38,867評(píng)論 0 274
  • 序言:老撾萬榮一對(duì)情侶失蹤溉仑,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后状植,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體浊竟,經(jīng)...
    沈念sama閱讀 45,307評(píng)論 1 310
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,528評(píng)論 2 332
  • 正文 我和宋清朗相戀三年津畸,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了振定。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 39,688評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡肉拓,死狀恐怖后频,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情暖途,我是刑警寧澤卑惜,帶...
    沈念sama閱讀 35,409評(píng)論 5 343
  • 正文 年R本政府宣布,位于F島的核電站驻售,受9級(jí)特大地震影響露久,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜欺栗,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,001評(píng)論 3 325
  • 文/蒙蒙 一毫痕、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧迟几,春花似錦消请、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,657評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至存哲,卻和暖如春因宇,著一層夾襖步出監(jiān)牢的瞬間七婴,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 32,811評(píng)論 1 268
  • 我被黑心中介騙來泰國(guó)打工察滑, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留打厘,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 47,685評(píng)論 2 368
  • 正文 我出身青樓贺辰,卻偏偏與公主長(zhǎng)得像户盯,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子饲化,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,573評(píng)論 2 353

推薦閱讀更多精彩內(nèi)容