1. 簡介
Tensor2Tensor
是google出品的一個神仙級工具包辣卒,能大大簡化類似模型的開發(fā)調試時間掷贾。在眾多的深度學習工具中,個人認為這貨屬于那種門檻還比較高的工具荣茫。并且google家的文檔一向做得很辣雞想帅,都是直接看源碼注釋摸索怎么使用。
Tensor2Tensor的版本和Tensorflow版本是對應的啡莉,我電腦上是tensorflow 1.14.0博脑,就這樣安裝了pip install tensor2tensor==1.14.1
。
2. 基礎模塊
import os
import tensorflow as tf
from tensor2tensor.utils import registry
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators.text_problems import Text2TextProblem
from tensor2tensor.data_generators.text_problems import VocabType
from tensor2tensor.models import transformer
text_encoder
中預定義了一些把string轉化為ids的類型票罐。
problem
不知道怎么說叉趣,看官方解釋就行了,反正新增加任務都需要自己寫個Problem類型该押。
Problems consist of features such as inputs and targets, and metadata such
as each feature's modality (e.g. symbol, image, audio) and vocabularies. Problem
features are given by a dataset, which is stored as a TFRecord
file with tensorflow.Example
protocol buffers. All
problems are imported in all_problems.py
or are registered with @registry.register_problem
.
這里是直接調用了models里面的transformers
疗杉,如果自己該模型,還需要使用@registry.register_model
注冊模型。
3. Problems寫法
最好看下Text2TextProblem
模塊的源碼看下google的結構思路烟具,以文本生成任務為例梢什。
# 數(shù)據(jù)格式
想看你的美照<TAB>親我一口就給你看
我親兩口<TAB>討厭人家拿小拳拳捶你胸口
......
新建文件my_task.py
@registry.register_problem
class Seq2SeqDemo(Text2TextProblem):
TRAIN_FILES = "train.txt"
EVAL_FILES = "dev.txt"
@property
def vocab_type(self):
# 見父類的說明
return VocabType.TOKEN
@property
def oov_token(self):
return "<UNK>"
def _generate_vocab(self, tmp_dir):
vocab_list = [self.oov_token]
user_vocab_file = os.path.join(tmp_dir, "vocab.txt")
with tf.gfile.GFile(user_vocab_file, "r") as vocab_file:
for line in vocab_file:
token = line.strip().split("\t")[0]
vocab_list.append(token)
token_encoder = text_encoder.TokenTextEncoder(None, vocab_list=vocab_list)
return token_encoder
def _generate_samples(self, data_dir, tmp_dir, dataset_split):
del data_dir
is_training = dataset_split == problem.DatasetSplit.TRAIN
files = self.TRAIN_FILES if is_training else self.EVAL_FILES
files = os.path.join(tmp_dir, files)
with tf.gfile.GFile(files, "r") as fin:
for line in fin:
inputs, targets = line.strip().split("\t")
yield {"inputs": inputs, "targets": targets}
def generator_samples(self, data_dir, tmp_dir, dataset_split):
vocab_filepath = os.path.join(data_dir, self.vocab_filename)
if not tf.gfile.Exists(vocab_filepath):
token_encoder = self._generate_vocab(tmp_dir)
token_encoder.store_to_file(vocab_filepath)
return self._generate_samples(data_dir, tmp_dir, dataset_split)
tmp_dir
是真實的訓練文本和字典存放的地方,data_dir
是處理后的字典和TFRcord存在的地方朝聋。
關鍵就一個方法generator_samples
嗡午,它有兩個作用,讀入字典和轉換數(shù)據(jù)文件方便后面轉化為TFRecord
的形式冀痕。
其中有個天坑荔睹,generator_samples
和_generator_samples
我是故意拆開寫的。如果合并了言蛇,因為生成器的性質僻他,在沒有遍歷之前generator_samples
return之前的代碼都不會執(zhí)行。但是注意到父類中有個有個方法generate_encoded_samples
其中有兩行:
generator = self.generate_samples(data_dir, tmp_dir, dataset_split)
encoder = self.get_or_create_vocab(data_dir, tmp_dir)
generator_samples
如果沒有執(zhí)行到token_encoder = self._generate_vocab(tmp_dir)
腊尚,self.get_or_create_vocab
這邊就直接炸了吨拗,找不到字典。因此婿斥,這兩個函數(shù)不能拆開劝篷,problem調用結構就這樣,暫時只知道這么改民宿。
最后娇妓,還要添加transformer
的模型參數(shù),想了解參數(shù)請看transformer.transformer_base
源碼(°ー°〃)勘高。
@registry.register_hparams
def my_param():
hparams = transformer.transformer_base()
hparams.summarize_vars = True
hparams.num_hidden_layers = 4
hparams.batch_size = 64
hparams.max_length = 40
hparams.hidden_size = 512
hparams.num_heads = 8
hparams.dropout = 0.1
hparams.attention_dropout = 0.1
hparams.filter_size = 1024
hparams.layer_prepostprocess_dropout = 0.1
hparams.learning_rate_warmup_steps = 1000
hparams.learning_rate_decay_steps = 800
hparams.learning_rate = 3e-5
return hparams
4. 運行
首先生成TFRecod文件峡蟋,執(zhí)行命令
t2t_datagen \
--t2t_usr_dir=/code_path (to my_task.py)
--data_dir=/record_data_path
--tmp_dir=/data_path
--problem=Seq2SeqDemo
然后訓練
t2t_trainer \
--data_dir=/same_as_above
--problem=Seq2SeqDemo
--model=transformer
--hparams_set=my_param
--output_dir=~/output_dir
--job-dir=~/output_dir
--train_steps=8000
--eval_steps=2000
訓練好的模型進行預測(decode)
t2t_decoder \
--data_dir=/same_as_above
--problem=Seq2SeqDemo
--model=transformer
--hparams_set=my_param
--output_dir=~/output_dir
--decode_from_file=/dev_file_path
--decode_to_file=/file_save_path