編者按:年初疫情在家期間開(kāi)始大量閱讀NLP領(lǐng)域的經(jīng)典論文对扶,在學(xué)習(xí)《Attention Is All You Need》時(shí)發(fā)現(xiàn)了一位現(xiàn)居日本的數(shù)據(jù)科學(xué)家LeeMeng寫(xiě)的Transformer詳解博客月而,理論講解+代碼實(shí)操+動(dòng)畫(huà)演示的寫(xiě)作風(fēng)格,在眾多文章中獨(dú)樹(shù)一幟挖滤,實(shí)為新手學(xué)習(xí)Transformer的上乘資料崩溪,在通讀以及實(shí)操多遍之后,現(xiàn)在將其編輯整理成簡(jiǎn)體中文分享給大家斩松。由于原文實(shí)在太長(zhǎng)伶唯,為了便于閱讀學(xué)習(xí),這里將其分為四個(gè)部分:
- 透過(guò)機(jī)器翻譯理解Transformer(一):關(guān)于機(jī)器翻譯
- 透過(guò)機(jī)器翻譯理解Transformer(二):師傅引進(jìn)門(mén)惧盹,修行在個(gè)人—建立輸入管道
- 透過(guò)機(jī)器翻譯理解Transformer(三):理解 Transformer 之旅:跟著多維向量去冒險(xiǎn)
- 透過(guò)機(jī)器翻譯理解Transformer(四):打造 Transformer:疊疊樂(lè)時(shí)間
在涉及代碼部分乳幸,強(qiáng)烈推薦大家在Google的Colab Notebooks中實(shí)際操作一遍,之所以推薦Colab Notebooks是因?yàn)?).這里有免費(fèi)可以使用的GPU資源钧椰;2). 可以避免很多安裝包出錯(cuò)的問(wèn)題
本文用到的數(shù)據(jù):
鏈接:https://pan.baidu.com/s/1Ku1GH8a_NqHUxYs-uf0htg
提取碼:tbor
本節(jié)目錄
- 師傅引進(jìn)門(mén)粹断,修行在個(gè)人
- Transformer 11 個(gè)重要概念回顧
- 安裝Python庫(kù)并設(shè)置環(huán)境
- 建立輸入管道
- 4.1 下載并準(zhǔn)備數(shù)據(jù)集
- 4.2 切割數(shù)據(jù)集
- 4.3 建立中文與英文字典
- 4.4 數(shù)據(jù)預(yù)處理
- 建立輸入管道
1. 師傅引進(jìn)門(mén),修行在個(gè)人
你回來(lái)了嗎嫡霞?還是等不及待地想繼續(xù)往下閱讀瓶埋?
接下來(lái)我們會(huì)進(jìn)入實(shí)際的代碼實(shí)現(xiàn)。但跟前半段相比難度呈指數(shù)型上升诊沪,因此我只推薦符合以下條件的讀者閱讀:
- 想透過(guò)實(shí)現(xiàn) Transformer 來(lái)徹底了解其內(nèi)部運(yùn)作原理的人
- 愿意先花 1 小時(shí)了解 Transformer 的細(xì)節(jié)概念與理論的人
你馬上就會(huì)知道 1 個(gè)小時(shí)代表什么意思养筒。如果你覺(jué)得這聽(tīng)起來(lái)很 ok,那可以繼續(xù)閱讀端姚。
在機(jī)器翻譯近代史一章我們已經(jīng)花了不少篇幅講解了許多在實(shí)現(xiàn) Transformer 時(shí)會(huì)有幫助的重要概念晕粪,其中包含:
- Seq2Seq 模型的運(yùn)作原理
- 注意力機(jī)制的概念與計(jì)算過(guò)程
- 自注意力機(jī)制與 Transformer 的精神
壞消息是,深度學(xué)習(xí)里頭理論跟實(shí)現(xiàn)的差異常常是很大的渐裸。盡管這些背景知識(shí)對(duì)理解Transformer 的精神非常有幫助兵多,對(duì)從來(lái)沒(méi)有用過(guò)RNN 實(shí)現(xiàn)文本生成或是以Seq2Seq 模型+ 注意力機(jī)制實(shí)現(xiàn)過(guò)NMT 的人來(lái)說(shuō),要在第一次就正確實(shí)現(xiàn)Transformer 仍是一個(gè)巨大的挑戰(zhàn)橄仆。
就算不說(shuō)理論跟實(shí)現(xiàn)的差異剩膘,讓我們看看 TensorFlow 官方釋出的最新 Transformer 教學(xué)里頭有多少內(nèi)容:
上面是我用這輩子最快的速度卷動(dòng)該頁(yè)面再加速后的結(jié)果,可以看出內(nèi)容還真不少盆顾。盡管中文化很重要怠褐,我在這篇文章里不會(huì)幫你把其中的敘述翻成中文(畢竟你的英文可能比我好)
反之,我將利用 TensorFlow 官方的代碼您宪,以最適合「初學(xué)者」理解的實(shí)現(xiàn)順序來(lái)講述 Transformer 的重要技術(shù)細(xì)節(jié)及概念奈懒。在閱讀本文之后,你將有能力自行理解 TensorFlow 官方教學(xué)以及其他網(wǎng)絡(luò)上的實(shí)現(xiàn)(比方說(shuō) HarvardNLP 以 Pytorch 實(shí)現(xiàn)的 The Annotated Transformer宪巨。
李宏毅教授前陣子才在他 2019 年的臺(tái)大機(jī)器學(xué)習(xí)課程發(fā)布了 Transformer 的教學(xué)影片磷杏,而這可以說(shuō)是世界上最好的中文教學(xué)影片。如果你真的想要深入理解 Transformer捏卓,在實(shí)現(xiàn)前至少把上面的影片看完吧极祸!你可以少走很多彎路。
實(shí)現(xiàn)時(shí)我會(huì)盡量重述關(guān)鍵概念,但如果有先看影片你會(huì)比較容易理解我在碎碎念什么遥金。如果看完影片你的小宇宙開(kāi)始發(fā)光發(fā)熱浴捆,也可以先讀讀 Transformer 的原始論文,跟很多學(xué)術(shù)論文比起來(lái)相當(dāng)好讀稿械,真心不騙选泻。
重申一次,除非你已經(jīng)了解基本注意力機(jī)制的運(yùn)算以及 Transformer 的整體架構(gòu)美莫,否則我不建議繼續(xù)閱讀页眯。
2. Transformer 11 個(gè)重要概念回顧
怎么樣?你應(yīng)該已經(jīng)從教授的課程中學(xué)到不少重要概念了吧厢呵?我不知道你還記得多少窝撵,但讓我非常簡(jiǎn)單地幫你復(fù)習(xí)一下。
自注意力層(Self-Attention Layer)跟 RNN 一樣述吸,輸入是一個(gè)序列,輸出一個(gè)序列锣笨。但是該層可以并行計(jì)算蝌矛,且輸出序列中的每個(gè)向量都已經(jīng)看了整個(gè)序列的信息。
自注意力層將輸入序列
I
里頭的每個(gè)位置的向量i
透過(guò)3 個(gè)線(xiàn)性轉(zhuǎn)換分別變成3 個(gè)向量:q
错英、k
和v
入撒,并將每個(gè)位置的q
拿去跟序列中其他位置的k
做匹配茅逮,算出匹配程度后利用softmax 層取得介于0 到1 之間的權(quán)重值挺身,并以此權(quán)重跟每個(gè)位置的v
作加權(quán)平均贱傀,最后取得該位置的輸出向量o
椰棘。全部位置的輸出向量可以同時(shí)并行計(jì)算巨朦,最后輸出序列O
棚蓄。計(jì)算匹配程度(注意)的方法不只一種,只要能吃進(jìn) 2 個(gè)向量并吐出一個(gè)數(shù)值即可褥紫。但在 Transformer 論文原文是將 2 向量做 dot product 算匹配程度绳军。
我們可以透過(guò)大量矩陣運(yùn)算以及 GPU 將概念 2 提到的注意力機(jī)制的計(jì)算全部并行化奶是,加快訓(xùn)練效率(也是本文實(shí)現(xiàn)的重點(diǎn))沮趣。
多頭注意力機(jī)制(Multi-head Attention)是將輸入序列中的每個(gè)位置的
q
翁狐、k
和v
切割成多個(gè)qi
懈词、ki
和vi
再分別各自進(jìn)行注意力機(jī)制。各自處理完以后把所有結(jié)果串接并視情況降維计贰。這樣的好處是能讓各個(gè) head 各司其職钦睡,學(xué)會(huì)關(guān)注序列中不同位置在不同 representaton spaces 的信息蒂窒。自注意力機(jī)制這樣的計(jì)算的好處是「天涯若比鄰」:序列中每個(gè)位置都可以在 O(1) 的距離內(nèi)關(guān)注任一其他位置的信息躁倒,運(yùn)算效率較雙向 RNN 優(yōu)秀。
自注意力層可以取代 Seq2Seq 模型里頭以 RNN 為基礎(chǔ)的 Encoder / Decoder洒琢,而實(shí)際上全部替換掉后就(大致上)是 Transformer秧秉。
自注意力機(jī)制預(yù)設(shè)沒(méi)有「先后順序」的概念,而這也是為何其可以快速并行運(yùn)算的原因衰抑。在進(jìn)行如機(jī)器翻譯等序列生成任務(wù)時(shí)象迎,我們需要額外加入位置編碼(Positioning Encoding)來(lái)加入順序信息。而在 Transformer 原論文中此值為手設(shè)而非訓(xùn)練出來(lái)的模型權(quán)重呛踊。
Transformer 是一個(gè) Seq2Seq 模型砾淌,自然包含了 Encoder / Decoder,而 Encoder 及 Decoder 可以包含多層結(jié)構(gòu)相同的 blocks谭网,里頭每層都會(huì)有 multi-head attention 以及 Feed Forward Network汪厨。
在每個(gè) Encoder / Decoder block 里頭,我們還會(huì)使用殘差連結(jié)(Residual Connection)以及 Layer Normalization愉择。這些能幫助模型穩(wěn)定訓(xùn)練劫乱。
Decoder 在關(guān)注 Encoder 輸出時(shí)會(huì)需要遮罩(mask)來(lái)避免看到未來(lái)信息织中。我們后面會(huì)看到,事實(shí)上還會(huì)需要其他遮罩衷戈。
這些應(yīng)該是你在看完影片后學(xué)到的東西狭吼。如果你想要快速?gòu)?fù)習(xí),這里則是教授課程的 PDF 殖妇。
另外你之后也可以隨時(shí)透過(guò)左側(cè)導(dǎo)覽的圖片 icon 來(lái)快速回顧 Transformer 的整體架構(gòu)以及教授添加的注解刁笙。我相信在實(shí)現(xiàn)的時(shí)候它可以幫得上點(diǎn)忙:
有了這些背景知識(shí)以后,在理解代碼時(shí)會(huì)輕松許多拉一。你也可以一邊執(zhí)行 TensorFlow 官方的 Colab 筆記本一邊參考底下實(shí)現(xiàn)采盒。
好戲登場(chǎng)!
3. 安裝Python庫(kù)并設(shè)置環(huán)境
在這邊我們導(dǎo)入一些常用的 Python 庫(kù)蔚润,這應(yīng)該不需要特別說(shuō)明磅氨。
from google.colab import drive
drive.mount('/content/drive/')
Mounted at /content/drive/
import os
import time
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from pprint import pprint
from IPython.display import clear_output
比較值得注意的是我們將以最新的 TensorFlow 2 Beta 版本(編者注:在這里我使用的是TF==2.3)來(lái)實(shí)現(xiàn)本文的 Transformer。另外也會(huì)透過(guò) TensorFlow Datasets 來(lái)使用前人幫我們準(zhǔn)備好的英中翻譯資料集:
!pip install tensorflow # stable
clear_output()
!pip show tensorflow
Name: tensorflow
Version: 2.3.0
Summary: TensorFlow is an open source machine learning framework for everyone.
Home-page: https://www.tensorflow.org/
Author: Google Inc.
Author-email: packages@tensorflow.org
License: Apache 2.0
Location: /usr/local/lib/python3.6/dist-packages
Requires: google-pasta, scipy, gast, absl-py, astunparse, numpy, protobuf, termcolor, wrapt, grpcio, wheel, tensorboard, six, tensorflow-estimator, opt-einsum, h5py, keras-preprocessing
Required-by: fancyimpute
!pip install -q tensorflow-datasets
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
print(tf.__version__)
2.3.0
另外為了避免 TensorFlow 吐給我們太多不必要的信息嫡纠,在此文中我也將改變 logging 等級(jí)烦租。在 TensorFlow 2 里頭因?yàn)?tf.logging 被 deprecated,我們可以直接用 logging
模組來(lái)做到這件事情:
import logging
# logging.basicConfig(level="error")
np.set_printoptions(suppress=True)
我們同時(shí)也讓 numpy 不要顯示科學(xué)記號(hào)除盏。這樣可以讓我們之后在做一些 Tensor 運(yùn)算的時(shí)候版面能干凈一點(diǎn)叉橱。
接著定義一些之后在儲(chǔ)存檔案時(shí)會(huì)用到的路徑變量:
output_dir = "nmt"
en_vocab_file = os.path.join(output_dir, "en_vocab")
zh_vocab_file = os.path.join(output_dir, "zh_vocab")
checkpoint_path = os.path.join(output_dir, "checkpoints")
log_dir = os.path.join(output_dir, 'logs')
download_dir = "tensorflow-datasets/downloads"
if not os.path.exists(output_dir):
os.makedirs(output_dir)
4. 建立輸入管道
現(xiàn)行的 GPU 以及 TPU 能透過(guò)并行運(yùn)算幫我們顯著地縮短訓(xùn)練一個(gè) step 所需的時(shí)間。而為了讓并行計(jì)算能發(fā)揮最佳性能者蠕,我們需要最佳化輸入管道(Input pipeline)窃祝,以在當(dāng)前訓(xùn)練步驟完成之前就準(zhǔn)備好下一個(gè)時(shí)間點(diǎn) GPU 要用的數(shù)據(jù)。
而我們將透過(guò) tf.data API以及前面導(dǎo)入的 TensorFlow Datasets 來(lái)建立高效的輸入管道踱侣,并將機(jī)器翻譯競(jìng)賽 WMT 2019 的中英數(shù)據(jù)集準(zhǔn)備好粪小。
4.1 下載并準(zhǔn)備數(shù)據(jù)集
首先看看 tfds
里頭 WMT 2019 的中英翻譯有哪些資料來(lái)源:
tmp_builder = tfds.builder("wmt19_translate/zh-en")
pprint(tmp_builder.subsets)
{NamedSplit('train'): ['newscommentary_v14',
'wikititles_v1',
'uncorpus_v1',
'casia2015',
'casict2011',
'casict2015',
'datum2015',
'datum2017',
'neu2017'],
NamedSplit('validation'): ['newstest2018']}
可以看到在 WMT 2019 里中英對(duì)照的數(shù)據(jù)來(lái)源還算不少。其中幾個(gè)很好猜到其性質(zhì):
- 聯(lián)合國(guó)數(shù)據(jù):
uncorpus_v1
- 維基百科標(biāo)題:
wikititles_v1
- 新聞評(píng)論:
newscommentary_v14
雖然大量數(shù)據(jù)對(duì)訓(xùn)練神經(jīng)網(wǎng)路很有幫助抡句,本文為了節(jié)省訓(xùn)練 Transformer 所需的時(shí)間探膊,在這里我們就只選擇一個(gè)數(shù)據(jù)來(lái)源當(dāng)作數(shù)據(jù)集。至于要選哪個(gè)數(shù)據(jù)來(lái)源呢待榔?
聯(lián)合國(guó)的數(shù)據(jù)非常龐大逞壁,而維基百科標(biāo)題通常內(nèi)容很短,新聞評(píng)論感覺(jué)是一個(gè)相對(duì)適合的選擇锐锣。我們可以在設(shè)定檔 config
里頭指定新聞評(píng)論這個(gè)數(shù)據(jù)來(lái)源并請(qǐng) TensorFlow Datasets 下載:
config = tfds.translate.wmt.WmtConfig(
version=tfds.core.Version("1.0.0"),
language_pair=("zh", "en"),
subsets={
tfds.Split.TRAIN: ["newscommentary_v14"]
}
)
builder = tfds.builder("wmt_translate", config=config)
builder.download_and_prepare(download_dir=download_dir)
clear_output()
上面的指令約需 2 分鐘完成腌闯,而在過(guò)程中tfds
幫我們完成不少工作:
- 下載包含原始數(shù)據(jù)的壓縮文檔
- 解壓縮得到 CSV 文件
- 逐行讀取該 CSV 里頭所有中英句子
- 將不符合格式的 row 自動(dòng)過(guò)濾
- Shuffle 數(shù)據(jù)
- 將原數(shù)據(jù)轉(zhuǎn)換成 TFRecord 數(shù)據(jù)以加速讀取
多花點(diǎn)時(shí)間把相關(guān) API 文件看熟,你就能把清理雕憔、準(zhǔn)備數(shù)據(jù)的時(shí)間花在建構(gòu)模型以及跑實(shí)驗(yàn)上面姿骏。
4.2 切割數(shù)據(jù)集
雖然我們只下載了一個(gè)新聞評(píng)論的數(shù)據(jù)集,里頭還是有超過(guò) 30 萬(wàn)筆的中英平行句子橘茉。為了減少訓(xùn)練所需的時(shí)間工腋,將剛剛處理好的新聞評(píng)論數(shù)據(jù)集再進(jìn)一步切成 3 個(gè)部分姨丈,數(shù)據(jù)量分布如下:
- Split 1:20% 數(shù)據(jù)
- Split 2:1% 數(shù)據(jù)
- Split 3:79% 數(shù)據(jù)
我們將前兩個(gè) splits 拿來(lái)當(dāng)作訓(xùn)練以及驗(yàn)證集,剩余的部分(第 3 個(gè) split)舍棄不用:
train_examples = builder.as_dataset(split='train[:20%]', as_supervised=True)
val_examples = builder.as_dataset(split='train[20%:21%]', as_supervised=True)
print(train_examples)
print(val_examples)
<DatasetV1Adapter shapes: ((), ()), types: (tf.string, tf.string)>
<DatasetV1Adapter shapes: ((), ()), types: (tf.string, tf.string)>
你可以在這里找到更多跟 [split] 相關(guān)的用法擅腰。
這時(shí)候 train_examples
跟 val_examples
都已經(jīng)是 tf.data.Dataset蟋恬。我們?cè)?strong>數(shù)據(jù)預(yù)處理一節(jié)會(huì)看到這些數(shù)據(jù)在被丟入神經(jīng)網(wǎng)絡(luò)前需要經(jīng)過(guò)什么轉(zhuǎn)換,不過(guò)現(xiàn)在先讓我們簡(jiǎn)單讀幾筆數(shù)據(jù)出來(lái)看看:
for en, zh in train_examples.take(3):
print(en)
print(zh)
print('-' * 10)
tf.Tensor(b'The fear is real and visceral, and politicians ignore it at their peril.', shape=(), dtype=string)
tf.Tensor(b'\xe8\xbf\x99\xe7\xa7\x8d\xe6\x81\x90\xe6\x83\xa7\xe6\x98\xaf\xe7\x9c\x9f\xe5\xae\x9e\xe8\x80\x8c\xe5\x86\x85\xe5\x9c\xa8\xe7\x9a\x84\xe3\x80\x82 \xe5\xbf\xbd\xe8\xa7\x86\xe5\xae\x83\xe7\x9a\x84\xe6\x94\xbf\xe6\xb2\xbb\xe5\xae\xb6\xe4\xbb\xac\xe5\x89\x8d\xe9\x80\x94\xe5\xa0\xaa\xe5\xbf\xa7\xe3\x80\x82', shape=(), dtype=string)
----------
tf.Tensor(b'In fact, the German political landscape needs nothing more than a truly liberal party, in the US sense of the word \xe2\x80\x9cliberal\xe2\x80\x9d \xe2\x80\x93 a champion of the cause of individual freedom.', shape=(), dtype=string)
tf.Tensor(b'\xe4\xba\x8b\xe5\xae\x9e\xe4\xb8\x8a\xef\xbc\x8c\xe5\xbe\xb7\xe5\x9b\xbd\xe6\x94\xbf\xe6\xb2\xbb\xe5\xb1\x80\xe5\x8a\xbf\xe9\x9c\x80\xe8\xa6\x81\xe7\x9a\x84\xe4\xb8\x8d\xe8\xbf\x87\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe7\xac\xa6\xe5\x90\x88\xe7\xbe\x8e\xe5\x9b\xbd\xe6\x89\x80\xe8\xb0\x93\xe2\x80\x9c\xe8\x87\xaa\xe7\x94\xb1\xe2\x80\x9d\xe5\xae\x9a\xe4\xb9\x89\xe7\x9a\x84\xe7\x9c\x9f\xe6\xad\xa3\xe7\x9a\x84\xe8\x87\xaa\xe7\x94\xb1\xe5\x85\x9a\xe6\xb4\xbe\xef\xbc\x8c\xe4\xb9\x9f\xe5\xb0\xb1\xe6\x98\xaf\xe4\xb8\xaa\xe4\xba\xba\xe8\x87\xaa\xe7\x94\xb1\xe4\xba\x8b\xe4\xb8\x9a\xe7\x9a\x84\xe5\x80\xa1\xe5\xaf\xbc\xe8\x80\x85\xe3\x80\x82', shape=(), dtype=string)
----------
tf.Tensor(b'Shifting to renewable-energy sources will require enormous effort and major infrastructure investment.', shape=(), dtype=string)
tf.Tensor(b'\xe5\xbf\x85\xe9\xa1\xbb\xe4\xbb\x98\xe5\x87\xba\xe5\xb7\xa8\xe5\xa4\xa7\xe7\x9a\x84\xe5\x8a\xaa\xe5\x8a\x9b\xe5\x92\x8c\xe5\x9f\xba\xe7\xa1\x80\xe8\xae\xbe\xe6\x96\xbd\xe6\x8a\x95\xe8\xb5\x84\xe6\x89\x8d\xe8\x83\xbd\xe5\xae\x8c\xe6\x88\x90\xe5\x90\x91\xe5\x8f\xaf\xe5\x86\x8d\xe7\x94\x9f\xe8\x83\xbd\xe6\xba\x90\xe7\x9a\x84\xe8\xbf\x87\xe6\xb8\xa1\xe3\x80\x82', shape=(), dtype=string)
----------
跟預(yù)期一樣趁冈,每一個(gè)例子(每一次的 take
)都包含了 2 個(gè)以 unicode 呈現(xiàn)的tf.Tensor
歼争。它們有一樣的語(yǔ)義,只是一個(gè)是英文渗勘,一個(gè)是中文沐绒。
讓我們將這些 Tensors 實(shí)際儲(chǔ)存的字串利用 numpy()
取出并解碼看看:
sample_examples = []
num_samples = 10
for en_t, zh_t in train_examples.take(num_samples):
en = en_t.numpy().decode("utf-8")
zh = zh_t.numpy().decode("utf-8")
print(en)
print(zh)
print('-' * 10)
# 之後用來(lái)簡(jiǎn)單評(píng)估模型的訓(xùn)練情況
sample_examples.append((en, zh))
The fear is real and visceral, and politicians ignore it at their peril.
這種恐懼是真實(shí)而內(nèi)在的。 忽視它的政治家們前途堪憂(yōu)旺坠。
----------
In fact, the German political landscape needs nothing more than a truly liberal party, in the US sense of the word “l(fā)iberal” – a champion of the cause of individual freedom.
事實(shí)上乔遮,德國(guó)政治局勢(shì)需要的不過(guò)是一個(gè)符合美國(guó)所謂“自由”定義的真正的自由黨派,也就是個(gè)人自由事業(yè)的倡導(dǎo)者取刃。
----------
Shifting to renewable-energy sources will require enormous effort and major infrastructure investment.
必須付出巨大的努力和基礎(chǔ)設(shè)施投資才能完成向可再生能源的過(guò)渡蹋肮。
----------
In this sense, it is critical to recognize the fundamental difference between “urban villages” and their rural counterparts.
在這方面,關(guān)鍵在于認(rèn)識(shí)到“城市村落”和農(nóng)村村落之間的根本區(qū)別璧疗。
----------
A strong European voice, such as Nicolas Sarkozy’s during the French presidency of the EU, may make a difference, but only for six months, and at the cost of reinforcing other European countries’ nationalist feelings in reaction to the expression of “Gallic pride.”
法國(guó)擔(dān)任輪值主席國(guó)期間尼古拉·薩科奇統(tǒng)一的歐洲聲音可能讓人耳目一新坯辩,但這種聲音卻只持續(xù)了短短六個(gè)月,而且付出了讓其他歐洲國(guó)家在面對(duì)“高盧人的驕傲”時(shí)民族主義情感進(jìn)一步被激發(fā)的代價(jià)崩侠。
----------
Most of Japan’s bondholders are nationals (if not the central bank) and have an interest in political stability.
日本債券持有人大多為本國(guó)國(guó)民(甚至中央銀行 ) 漆魔, 政治穩(wěn)定符合他們的利益。
----------
Paul Romer, one of the originators of new growth theory, has accused some leading names, including the Nobel laureate Robert Lucas, of what he calls “mathiness” – using math to obfuscate rather than clarify.
新增長(zhǎng)理論創(chuàng)始人之一的保羅·羅默(Paul Romer)也批評(píng)一些著名經(jīng)濟(jì)學(xué)家却音,包括諾貝爾獎(jiǎng)獲得者羅伯特·盧卡斯(Robert Lucas)在內(nèi)改抡,說(shuō)他們“數(shù)學(xué)性 ” ( 羅默的用語(yǔ))太重,結(jié)果是讓問(wèn)題變得更加模糊而不是更加清晰僧家。
----------
It is, in fact, a capsule depiction of the United States Federal Reserve and the European Central Bank.
事實(shí)上雀摘,這就是對(duì)美聯(lián)儲(chǔ)和歐洲央行的簡(jiǎn)略描述裸删。
----------
Given these variables, the degree to which migration is affected by asylum-seekers will not be easy to predict or control.
考慮到這些變量八拱,移民受尋求庇護(hù)者的影響程度很難預(yù)測(cè)或控制。
----------
WASHINGTON, DC – In the 2016 American presidential election, Hillary Clinton and Donald Trump agreed that the US economy is suffering from dilapidated infrastructure, and both called for greater investment in renovating and upgrading the country’s public capital stock.
華盛頓—在2016年美國(guó)總統(tǒng)選舉中涯塔,希拉里·克林頓和唐納德·特朗普都認(rèn)為美國(guó)經(jīng)濟(jì)飽受基礎(chǔ)設(shè)施陳舊的拖累肌稻,兩人都要求加大投資用于修繕和升級(jí)美國(guó)公共資本存量。
----------
想像一下沒(méi)有對(duì)應(yīng)的中文匕荸,要閱讀這些英文得花多少時(shí)間爹谭。你可以試著消化其中幾句中文與其對(duì)應(yīng)的英文句子,并比較一下所需要的時(shí)間差異榛搔。
雖然只是隨意列出的 10 個(gè)中英句子诺凡,你應(yīng)該跟我一樣也能感受到機(jī)器翻譯研究的重要以及其能帶給我們的價(jià)值东揣。
4.3 建立中文與英文字典
就跟大多數(shù) NLP項(xiàng)目相同,有了原始的中英句子以后我們得分別為其建立字典來(lái)將每個(gè)詞匯轉(zhuǎn)成索引(Index)腹泌。 tfds.features.text
底下的 SubwordTextEncoder
提供非常方便的 API 讓我們掃過(guò)整個(gè)訓(xùn)練資料集并建立字典嘶卧。
首先為英文語(yǔ)料建立字典:
%%time
try:
subword_encoder_en =tfds.deprecated.text.SubwordTextEncoder.load_from_file(en_vocab_file)
print(f"載入已建立的字典: {en_vocab_file}")
except:
print("沒(méi)有已建立的字典,從頭建立凉袱。")
subword_encoder_en = tfds.deprecated.text.SubwordTextEncoder.build_from_corpus(
(en.numpy() for en, _ in train_examples),
target_vocab_size=2**13) # 有需要可以調(diào)整字典大小
# 將創(chuàng)建的字典存下以方便下次 warmstart
subword_encoder_en.save_to_file(en_vocab_file)
print(f"字典大薪嬉鳌:{subword_encoder_en.vocab_size}")
print(f"前 10 個(gè) subwords:{subword_encoder_en.subwords[:10]}")
print()
沒(méi)有已建立的字典,從頭建立专甩。
字典大兄油摇:8113
前 10 個(gè) subwords:[', ', 'the_', 'of_', 'to_', 'and_', 's_', 'in_', 'a_', 'is_', 'that_']
CPU times: user 1min 11s, sys: 3.32 s, total: 1min 14s
Wall time: 1min 6s
如果你的語(yǔ)料庫(kù)(corpus) 不小,要掃過(guò)整數(shù)據(jù)集并建立一個(gè)字典得花不少時(shí)間涤躲。因此實(shí)現(xiàn)上我們會(huì)先使用load_from_file
函式嘗試讀取之前已經(jīng)建好的字典檔案棺耍,失敗才 build_from_corpus
。
這招很基本种樱,但在你需要重復(fù)處理巨大語(yǔ)料庫(kù)時(shí)非常重要烈掠。
subword_encoder_en
則是利用 GNMT 當(dāng)初推出的 wordpieces 來(lái)進(jìn)行分詞,而簡(jiǎn)單來(lái)說(shuō)其產(chǎn)生的子詞(subword)介于這兩者之間:
- 用英文字母分隔的斷詞(character-delimited)
- 用空白分隔的斷詞(word-delimited)
在掃過(guò)所有英文句子以后缸托,subword_encoder_en
建立一個(gè)有 8135 個(gè)子詞的字典左敌。我們可以用該字典來(lái)幫我們將一個(gè)英文句子轉(zhuǎn)成對(duì)應(yīng)的索引序列(index sequence):
sample_string = 'Taiwan is beautiful.'
indices = subword_encoder_en.encode(sample_string)
indices
[3461, 7889, 9, 3502, 4379, 1134, 7903]
這樣的索引序列你應(yīng)該已經(jīng)見(jiàn)怪不怪了。我們?cè)谝郧暗?NLP 入門(mén)文章也使用 tf.keras
里頭的 Tokenizer
做過(guò)類(lèi)似的事情俐镐。
接著讓我們將這些索引還原矫限,看看它們的長(zhǎng)相:
print("{0:10}{1:6}".format("Index", "Subword"))
print("-" * 15)
for idx in indices:
subword = subword_encoder_en.decode([idx])
print('{0:5}{1:6}'.format(idx, ' ' * 5 + subword))
Index Subword
---------------
3461 Taiwan
7889
9 is
3502 bea
4379 uti
1134 ful
7903 .
當(dāng) subword tokenizer 遇到從沒(méi)出現(xiàn)過(guò)在字典里的詞匯,會(huì)將該詞拆成多個(gè)子詞(subwords)佩抹。比方說(shuō)上面句中的beautiful
就被拆成bea
uti
ful
叼风。這也是為何這種分詞方法比較不怕沒(méi)有出現(xiàn)過(guò)在字典里的字(out-of-vocabulary words)。
另外別在意我為了對(duì)齊寫(xiě)的 print
語(yǔ)法棍苹。重點(diǎn)是我們可以用 subword_encoder_en
的 decode
函數(shù)再度將索引數(shù)字轉(zhuǎn)回其對(duì)應(yīng)的子詞无宿。編碼與解碼是 2 個(gè)完全可逆(invertable)的操作:
sample_string = 'Beijing is beautiful.'
indices = subword_encoder_en.encode(sample_string)
decoded_string = subword_encoder_en.decode(indices)
assert decoded_string == sample_string
pprint((sample_string, decoded_string))
('Beijing is beautiful.', 'Beijing is beautiful.')
接著讓我們?nèi)绶ㄅ谥疲瑸橹形囊步⒁粋€(gè)字典:
%%time
try:
subword_encoder_zh = tfds.deprecated.text.SubwordTextEncoder.load_from_file(zh_vocab_file)
print(f"載入已建立的字典: {zh_vocab_file}")
except:
print("沒(méi)有已建立的字典枢里,從頭建立孽鸡。")
subword_encoder_zh = tfds.deprecated.text.SubwordTextEncoder.build_from_corpus(
(zh.numpy() for _, zh in train_examples),
target_vocab_size=2**13, # 有需要可以調(diào)整字典大小
max_subword_length=1) # 每一個(gè)中文字就是字典里的一個(gè)單位
# 將字典檔案存下以方便下次 warmstart
subword_encoder_zh.save_to_file(zh_vocab_file)
print(f"字典大小:{subword_encoder_zh.vocab_size}")
print(f"前 10 個(gè) subwords:{subword_encoder_zh.subwords[:10]}")
print()
沒(méi)有已建立的字典栏豺,從頭建立彬碱。
字典大小:4205
前 10 個(gè) subwords:['的', '奥洼,', '巷疼。', '國(guó)', '在', '是', '一', '和', '不', '這']
CPU times: user 6min 10s, sys: 1.6 s, total: 6min 12s
Wall time: 6min 9s
在使用 build_from_corpus
函數(shù)掃過(guò)整個(gè)中文數(shù)據(jù)集時(shí),我們將 max_subword_length
參數(shù)設(shè)置為 1灵奖。這樣可以讓每個(gè)漢字都會(huì)被視為字典里頭的一個(gè)單位嚼沿。畢竟跟英文的 abc 字母不同估盘,一個(gè)漢字代表的意思可多得多了。而且如果使用 n-gram 的話(huà)可能的詞匯組合太多骡尽,在小數(shù)據(jù)集的情況非常容易遇到不存在字典里頭的字忿檩。
另外所有漢字也就大約 4000 ~ 5000 個(gè)可能,作為一個(gè)分類(lèi)問(wèn)題(classification problem)還是可以接受的爆阶。
讓我們挑個(gè)中文句子來(lái)測(cè)試看看:
sample_string = sample_examples[0][1]
indices = subword_encoder_zh.encode(sample_string)
print(sample_string)
print(indices)
這種恐懼是真實(shí)而內(nèi)在的燥透。 忽視它的政治家們前途堪憂(yōu)。
[10, 151, 574, 1298, 6, 374, 55, 29, 193, 5, 1, 3, 3981, 931, 431, 125, 1, 17, 124, 33, 20, 97, 1089, 1247, 861, 3]
好的辨图,我們把中英文分詞及字典的部分都搞定了“嗵祝現(xiàn)在給定一個(gè)例子(example,在這邊以及后文指的都是一組包含同語(yǔ)義的中英平行句子)故河,我們都能將其轉(zhuǎn)換成對(duì)應(yīng)的索引序列了:
en = "The eurozone’s collapse forces a major realignment of European politics."
zh = "歐元區(qū)的瓦解強(qiáng)迫歐洲政治進(jìn)行一次重大改組吱韭。"
# 將文字轉(zhuǎn)成為 subword indices
en_indices = subword_encoder_en.encode(en)
zh_indices = subword_encoder_zh.encode(zh)
print("[英中原文](轉(zhuǎn)換前)")
print(en)
print(zh)
print()
print('-' * 20)
print()
print("[英中序列](轉(zhuǎn)換后)")
print(en_indices)
print(zh_indices)
[英中原文](轉(zhuǎn)換前)
The eurozone’s collapse forces a major realignment of European politics.
歐元區(qū)的瓦解強(qiáng)迫歐洲政治進(jìn)行一次重大改組。
--------------------
[英中序列](轉(zhuǎn)換后)
[16, 900, 11, 6, 1527, 874, 8, 230, 2259, 2728, 239, 3, 89, 1236, 7903]
[44, 202, 168, 1, 852, 201, 231, 592, 44, 87, 17, 124, 106, 38, 7, 279, 86, 18, 212, 265, 3]
接著讓我們針對(duì)這些索引序列(index sequence)做一些預(yù)處理鱼的。
4.4 數(shù)據(jù)預(yù)處理
在處理序列數(shù)據(jù)時(shí)我們時(shí)常會(huì)在一個(gè)序列的前后各加入一個(gè)特殊的 token理盆,以標(biāo)記該序列的開(kāi)始與完結(jié),而它們常有許多不同的稱(chēng)呼:
- 開(kāi)始 token凑阶、Begin of Sentence猿规、BOS、
<start>
- 結(jié)束 token宙橱、End of Sentence姨俩、EOS、
<end>
這邊我們定義了一個(gè)將被 tf.data.Dataset
使用的 encode
函數(shù)师郑,它的輸入是一筆包含 2 個(gè)string
Tensors 的例子环葵,輸出則是 2 個(gè)包含 BOS / EOS 的索引序列:
def encode(en_t, zh_t):
# 因?yàn)樽值涞乃饕龔?0 開(kāi)始,
# 我們可以使用 subword_encoder_en.vocab_size 這個(gè)值作為 BOS 的索引值
# 用 subword_encoder_en.vocab_size + 1 作為 EOS 的索引值
en_indices = [subword_encoder_en.vocab_size] + subword_encoder_en.encode(
en_t.numpy()) + [subword_encoder_en.vocab_size + 1]
# 作為 EOS 的索引值
zh_indices = [subword_encoder_zh.vocab_size] + subword_encoder_zh.encode(
zh_t.numpy()) + [subword_encoder_zh.vocab_size + 1]
return en_indices, zh_indices
因?yàn)?tf.data.Dataset
里頭都是在操作 Tensors(而非 Python 字串)宝冕,所以這個(gè)encode
函數(shù)預(yù)期的輸入也是 TensorFlow 里的 Eager Tensors张遭。但只要我們使用 numpy()
將 Tensor 里的實(shí)際字串取出以后,做的事情就跟上一節(jié)完全相同地梨。
讓我們從訓(xùn)練集里隨意取一組中英的 Tensors 來(lái)看看這個(gè)函數(shù)的實(shí)際輸出:
en_t, zh_t = next(iter(train_examples))
en_indices, zh_indices = encode(en_t, zh_t)
print('英文 BOS 的 index:', subword_encoder_en.vocab_size)
print('英文 EOS 的 index:', subword_encoder_en.vocab_size + 1)
print('中文 BOS 的 index:', subword_encoder_zh.vocab_size)
print('中文 EOS 的 index:', subword_encoder_zh.vocab_size + 1)
print('\n輸入為 2 個(gè) Tensors:')
pprint((en_t, zh_t))
print('-' * 15)
print('輸出為 2 個(gè)索引序列:')
print((en_indices))
print((zh_indices))
英文 BOS 的 index: 8113
英文 EOS 的 index: 8114
中文 BOS 的 index: 4205
中文 EOS 的 index: 4206
輸入為 2 個(gè) Tensors:
(<tf.Tensor: shape=(), dtype=string, numpy=b'The fear is real and visceral, and politicians ignore it at their peril.'>,
<tf.Tensor: shape=(), dtype=string, numpy=b'\xe8\xbf\x99\xe7\xa7\x8d\xe6\x81\x90\xe6\x83\xa7\xe6\x98\xaf\xe7\x9c\x9f\xe5\xae\x9e\xe8\x80\x8c\xe5\x86\x85\xe5\x9c\xa8\xe7\x9a\x84\xe3\x80\x82 \xe5\xbf\xbd\xe8\xa7\x86\xe5\xae\x83\xe7\x9a\x84\xe6\x94\xbf\xe6\xb2\xbb\xe5\xae\xb6\xe4\xbb\xac\xe5\x89\x8d\xe9\x80\x94\xe5\xa0\xaa\xe5\xbf\xa7\xe3\x80\x82'>)
---------------
輸出為 2 個(gè)索引序列:
[8113, 16, 1284, 9, 243, 5, 1275, 1756, 156, 1, 5, 1016, 5566, 21, 38, 33, 2982, 7965, 7903, 8114]
[4205, 10, 151, 574, 1298, 6, 374, 55, 29, 193, 5, 1, 3, 3981, 931, 431, 125, 1, 17, 124, 33, 20, 97, 1089, 1247, 861, 3, 4206]
你可以看到不管是英文還是中文的索引序列菊卷,前面都加了一個(gè)代表 BOS 的索引(分別為 8113 與 4205),最后一個(gè)索引則代表 EOS(分別為 8114 與 4206)
但如果我們將encode
函數(shù)直接套用到整個(gè)訓(xùn)練資料集時(shí)會(huì)產(chǎn)生以下的錯(cuò)誤信息:
train_dataset = train_examples.map(encode)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-26-81637afc26ae> in <module>()
----> 1 train_dataset = train_examples.map(encode)
......
/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
263 except Exception as e: # pylint:disable=broad-except
264 if hasattr(e, 'ag_error_metadata'):
--> 265 raise e.ag_error_metadata.to_exception(e)
266 else:
267 raise
AttributeError: in user code:
<ipython-input-24-cb8784cd1930>:5 encode *
en_indices = [subword_encoder_en.vocab_size] + subword_encoder_en.encode(
AttributeError: 'Tensor' object has no attribute 'numpy'
這是因?yàn)槟壳?tf.data.Dataset.map
函數(shù)里頭的計(jì)算是在計(jì)算圖模式(Graph mode)下執(zhí)行湿刽,所以里頭的 Tensors 并不會(huì)有 Eager Execution 下才有的 numpy
屬性的烁。
解法是使用 tf.py_function將我們剛剛定義的 encode
函數(shù)包成一個(gè)以 eager 模式執(zhí)行的 TensorFlow Op:
def tf_encode(en_t, zh_t):
# 在 `tf_encode` 函數(shù)里頭的 `en_t` 與 `zh_t` 都不是 Eager Tensors
# 要到 `tf.py_funtion` 里頭才是
# 另外因?yàn)樗饕际钦麛?shù)褐耳,所以使用 `tf.int64`
return tf.py_function(encode, [en_t, zh_t], [tf.int64, tf.int64])
# `tmp_dataset` 為說(shuō)明用資料集诈闺,說(shuō)明完所有重要的 func,
# 我們會(huì)從頭建立一個(gè)正式的 `train_dataset`
tmp_dataset = train_examples.map(tf_encode)
en_indices, zh_indices = next(iter(tmp_dataset))
print(en_indices)
print(zh_indices)
tf.Tensor(
[8113 16 1284 9 243 5 1275 1756 156 1 5 1016 5566 21
38 33 2982 7965 7903 8114], shape=(20,), dtype=int64)
tf.Tensor(
[4205 10 151 574 1298 6 374 55 29 193 5 1 3 3981
931 431 125 1 17 124 33 20 97 1089 1247 861 3 4206], shape=(28,), dtype=int64)
有點(diǎn) tricky 但任務(wù)完成铃芦!注意在套用map
函數(shù)以后雅镊,tmp_dataset
的輸出已經(jīng)是兩個(gè)索引序列襟雷,而非原文字串。
為了讓 Transformer 快點(diǎn)完成訓(xùn)練仁烹,讓我們將長(zhǎng)度超過(guò) 40 個(gè) tokens 的序列都去掉吧耸弄!我們?cè)诘紫露x了一個(gè)布林(boolean)函數(shù),其輸入為一個(gè)包含兩個(gè)英中序列en, zh
的例子卓缰,并在只有這2 個(gè)序列的長(zhǎng)度都小于40 的時(shí)候回傳真值(True) :
MAX_LENGTH = 40
def filter_max_length(en, zh, max_length=MAX_LENGTH):
# en, zh 分別代表英文與中文的索引序列
return tf.logical_and(tf.size(en) <= max_length,
tf.size(zh) <= max_length)
# tf.data.Dataset.filter(func) 只會(huì)回傳 func 為真的例子
tmp_dataset = tmp_dataset.filter(filter_max_length)
簡(jiǎn)單檢查是否有序列超過(guò)我們指定的長(zhǎng)度计呈,順便計(jì)算過(guò)濾掉過(guò)長(zhǎng)序列后剩余的訓(xùn)練集筆數(shù):
# 因?yàn)槲覀償?shù)據(jù)量小可以這樣 count
num_examples = 0
for en_indices, zh_indices in tmp_dataset:
cond1 = len(en_indices) <= MAX_LENGTH
cond2 = len(zh_indices) <= MAX_LENGTH
assert cond1 and cond2
num_examples += 1
print(f"所有英文與中文序列長(zhǎng)度都不超過(guò) {MAX_LENGTH} 個(gè) tokens")
print(f"訓(xùn)練資料集里總共有 {num_examples} 筆數(shù)據(jù)")
所有英文與中文序列長(zhǎng)度都不超過(guò) 40 個(gè) tokens
訓(xùn)練資料集里總共有 29784 筆數(shù)據(jù)
過(guò)濾掉較長(zhǎng)句子后還有接近 3 萬(wàn)筆的訓(xùn)練例子,看來(lái)不用擔(dān)心數(shù)據(jù)太少征唬。
最后值得注意的是每個(gè)例子里的索引序列長(zhǎng)度不一捌显,這在建立 batch 時(shí)可能會(huì)發(fā)生問(wèn)題。不過(guò)別擔(dān)心总寒,輪到padded_batch
函數(shù)出場(chǎng)了:
BATCH_SIZE = 64
# 將 batch 里的所有序列都 pad 到同樣長(zhǎng)度
tmp_dataset = tmp_dataset.padded_batch(BATCH_SIZE, padded_shapes=([-1], [-1]))
en_batch, zh_batch = next(iter(tmp_dataset))
print("英文索引序列的 batch")
print(en_batch)
print('-' * 20)
print("中文索引序列的 batch")
print(zh_batch)
英文索引序列的 batch
tf.Tensor(
[[8113 16 1284 ... 0 0 0]
[8113 1894 1302 ... 0 0 0]
[8113 44 40 ... 0 0 0]
...
[8113 122 506 ... 0 0 0]
[8113 16 215 ... 0 0 0]
[8113 7443 7889 ... 0 0 0]], shape=(64, 39), dtype=int64)
--------------------
中文索引序列的 batch
tf.Tensor(
[[4205 10 151 ... 0 0 0]
[4205 206 275 ... 0 0 0]
[4205 5 10 ... 0 0 0]
...
[4205 34 6 ... 0 0 0]
[4205 317 256 ... 0 0 0]
[4205 167 326 ... 0 0 0]], shape=(64, 40), dtype=int64)
padded_batch
函數(shù)能幫我們將每個(gè) batch 里頭的序列都補(bǔ) 0 到跟當(dāng)下 batch 里頭最長(zhǎng)的序列一樣長(zhǎng)扶歪。
比方說(shuō)英文 batch 里最長(zhǎng)的序列為 34;而中文 batch 里最長(zhǎng)的序列則長(zhǎng)達(dá) 40 個(gè) tokens摄闸,剛好是我們前面設(shè)定過(guò)的序列長(zhǎng)度上限善镰。
好啦,現(xiàn)在讓我們從頭建立訓(xùn)練集與驗(yàn)證集年枕,順便看看這些中英句子是如何被轉(zhuǎn)換成它們的最終形態(tài)的:
MAX_LENGTH = 40
BATCH_SIZE = 128
BUFFER_SIZE = 15000
# 訓(xùn)練集
train_dataset = (train_examples # 輸出:(英文句子, 中文句子)
.map(tf_encode) # 輸出:(英文索引序列, 中文索引序列)
.filter(filter_max_length) # 同上炫欺,且序列長(zhǎng)度都不超過(guò) 40
.cache() # 加快讀取數(shù)據(jù)
.shuffle(BUFFER_SIZE) # 將例子洗牌確保隨機(jī)性
.padded_batch(BATCH_SIZE, # 將 batch 里的序列都 pad 到一樣長(zhǎng)度
padded_shapes=([-1], [-1]))
.prefetch(tf.data.experimental.AUTOTUNE)) # 加速
# 驗(yàn)證集
val_dataset = (val_examples
.map(tf_encode)
.filter(filter_max_length)
.padded_batch(BATCH_SIZE,
padded_shapes=([-1], [-1])))
建構(gòu)訓(xùn)練數(shù)據(jù)集時(shí)我們還添加了些沒(méi)提過(guò)的函數(shù)。它們的用途大都是用來(lái)提高輸入效率熏兄,并不會(huì)影響到輸出格式竣稽。如果你想深入了解這些函數(shù)的運(yùn)作方式,可以參考 tf.data 的官方教學(xué)霍弹。
現(xiàn)在讓我們看看最后建立出來(lái)的資料集長(zhǎng)什么樣子:
en_batch, zh_batch = next(iter(train_dataset))
print("英文索引序列的 batch")
print(en_batch)
print('-' * 20)
print("中文索引序列的 batch")
print(zh_batch)
英文索引序列的 batch
tf.Tensor(
[[8113 571 91 ... 0 0 0]
[8113 246 4266 ... 0 0 0]
[8113 4077 3168 ... 0 0 0]
...
[8113 367 693 ... 0 0 0]
[8113 435 1062 ... 0 0 0]
[8113 122 2 ... 0 0 0]], shape=(128, 37), dtype=int64)
--------------------
中文索引序列的 batch
tf.Tensor(
[[4205 378 100 ... 0 0 0]
[4205 826 97 ... 0 0 0]
[4205 1275 154 ... 0 0 0]
...
[4205 7 28 ... 4206 0 0]
[4205 52 11 ... 0 0 0]
[4205 29 305 ... 0 0 0]], shape=(128, 40), dtype=int64)
我們建立了一個(gè)可供訓(xùn)練的輸入管道(Input pipeline)毫别!
你會(huì)發(fā)現(xiàn)訓(xùn)練集:
- 一次回傳大小為 128 的 2 個(gè) batch,分別包含 128 個(gè)英文典格、中文的索引序列
- 每個(gè)序列開(kāi)頭皆為 BOS岛宦,英文的 BOS 索引是 8113;中文的 BOS 索引則為 4205
- 兩語(yǔ)言 batch 里的序列都被「拉長(zhǎng)」到我們先前定義的最長(zhǎng)序列長(zhǎng)度:40
- 驗(yàn)證集也是相同的輸出形式耍缴。
現(xiàn)在你應(yīng)該可以想像我們?cè)诿總€(gè)訓(xùn)練步驟會(huì)拿出來(lái)的數(shù)據(jù)長(zhǎng)什么樣子了:2 個(gè)shape 為(batch_size, seq_len) 的Tensors砾肺,而里頭的每一個(gè)索引數(shù)字都代表著一個(gè)中/ 英文子詞(或是BOS / EOS)。
在這一節(jié)我們建立了一個(gè)通用數(shù)據(jù)集防嗡。 「通用」代表不限于 Transformer变汪,你也能用一般搭配注意力機(jī)制的 Seq2Seq 模型來(lái)處理這個(gè)數(shù)據(jù)集并做中英翻譯。
但從下節(jié)開(kāi)始讓我們把這個(gè)數(shù)據(jù)集先擺一邊蚁趁,將注意力全部放到 Transformer 身上并逐一實(shí)現(xiàn)其架構(gòu)里頭的各個(gè)元件裙盾。