BERT代碼解讀(1)-輸入

google開源的tensorflow版本的bert 源碼見 https://github.com/google-research/bert端盆。本文主要對該官方代碼的一些關鍵部分進行解讀骑疆。

首先我們來看數(shù)據(jù)預處理部分吱雏,分析原始數(shù)據(jù)集是如何轉(zhuǎn)化成能夠送入bert模型的特征的畦木。

DataProcessor

DataProcessor這個抽象基類定義了get_train_examples, get_dev_examples, get_test_examples, get_labels這四個需要子類實現(xiàn)的方法蛙粘,還定義了一個_read_tsv函數(shù)來讀取原始數(shù)據(jù)集tsv文件。

針對文本二分類任務囱稽,我們可以通過學習繼承DataProcessor類的子類ColaProcessor的具體實現(xiàn)過程來了解數(shù)據(jù)處理的過程派昧。我們可以發(fā)現(xiàn)子類ColaProcessor處理原始數(shù)據(jù)的關鍵函數(shù)如下:

class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version)."""
    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
    
    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            # Only the test set has a header
            if set_type == "test" and i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            if set_type == "test":
                text_a = tokenization.convert_to_unicode(line[1])
                label = "0"
            else:
                text_a = tokenization.convert_to_unicode(line[3])
                label = tokenization.convert_to_unicode(line[1])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples

class InputExample(object):
    """A single training/test example for simple sequence classification."""
    def __init__(self, guid, text_a, text_b=None, label=None):
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label

該函數(shù)首先通過_read_tsv讀入原始數(shù)據(jù)集文件train.tsv,然后調(diào)用_create_examples將數(shù)據(jù)集中的每一行轉(zhuǎn)換成一個InputExample對象剃浇。

  • 在函數(shù)_create_examples中巾兆,如果是訓練集和驗證集,那么line[1]就是label虎囚,line[3]就是文本內(nèi)容角塑,而對于測試集,line[1]就是文本內(nèi)容淘讥,沒有l(wèi)abel圃伶,因此全部設成0。這個具體可看CoLA數(shù)據(jù)集蒲列。注意這里將所有字符串用tokenization.convert_to_unicode轉(zhuǎn)成unicode字符串窒朋,是為了兼容python2和python3。
  • 對象InputExample有四個屬性蝗岖,guid僅僅是一個唯一的id標識侥猩,text_a表示第一個句子,text_b表示第二個句子(可選抵赢,針對句子對任務)欺劳,label表示標簽(可選唧取,測試集沒有)。

tokenizer

對InputExample里的句子字符串text_atext_b進行分詞操作的主要函數(shù)是FullTokenizer划提。

tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

class FullTokenizer(object):
  """Runs end-to-end tokenziation."""

  def __init__(self, vocab_file, do_lower_case=True):
    self.vocab = load_vocab(vocab_file)
    self.inv_vocab = {v: k for k, v in self.vocab.items()}
    self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
    self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

  def tokenize(self, text):
    split_tokens = []
    for token in self.basic_tokenizer.tokenize(text):
      for sub_token in self.wordpiece_tokenizer.tokenize(token):
        split_tokens.append(sub_token)

    return split_tokens

  def convert_tokens_to_ids(self, tokens):
    return convert_by_vocab(self.vocab, tokens)

  def convert_ids_to_tokens(self, ids):
    return convert_by_vocab(self.inv_vocab, ids)

該函數(shù)通過load_vocab加載詞典枫弟,方便將后續(xù)分詞得到的token映射到對應的id。通過調(diào)用BasicTokenizerWordpieceTokenizer進行分詞鹏往,前者根據(jù)標點符號淡诗、空格等進行普通的分詞,后者則會對前者的結果進行更細粒度的分詞伊履。

  • 注意BasicTokenizer會將中文切分成一個個的漢字韩容,也就是在中文字符(字)前后加上空格,從而后續(xù)分詞將每個中文字符當成一個詞湾碎。
  • WordpieceTokenizer基于傳入的詞典vocab,對單詞進行更細粒度的切分奠货,比如"unaffable"被進一步切分為["un", "##aff", "##able"]介褥。對于中文來說,WordpieceTokenizer什么也不干递惋,因為前一步分詞已經(jīng)是基于字符的了柔滔。注意,對于在詞典vocab中找不到的單詞萍虽,會設置為[UNK]token睛廊。
  • WordPiece是一種解決OOV問題的方法,具體可參考google/sentencepiece項目杉编。

convert_single_example

接下來我們對Processor處理后得到的InputExample進行處理超全,得到能夠送入網(wǎng)絡的特征。

    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    if FLAGS.do_train:
        train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
        file_based_convert_examples_to_features(
            train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = file_based_input_fn_builder(
            input_file=train_file,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

## 提取特征的函數(shù)
def file_based_convert_examples_to_features(
        examples, label_list, max_seq_length, tokenizer, output_file):
    """Convert a set of `InputExample`s to a TFRecord file."""

    writer = tf.python_io.TFRecordWriter(output_file)

    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))

        feature = convert_single_example(ex_index, example, label_list,
                                         max_seq_length, tokenizer)

        def create_int_feature(values):
            f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
            return f

        features = collections.OrderedDict()
        features["input_ids"] = create_int_feature(feature.input_ids)
        features["input_mask"] = create_int_feature(feature.input_mask)
        features["segment_ids"] = create_int_feature(feature.segment_ids)
        features["label_ids"] = create_int_feature([feature.label_id])
        features["is_real_example"] = create_int_feature(
            [int(feature.is_real_example)])

        tf_example = tf.train.Example(features=tf.train.Features(feature=features))
        writer.write(tf_example.SerializeToString())
    writer.close()

對于Processor處理后得到每個InputExample對象邓馒,file_based_convert_examples_to_features函數(shù)會把這些對象轉(zhuǎn)化成能夠送入bert網(wǎng)絡的特征嘶朱,并將其保存到一個TFRecord文件中」夂ǎ可以發(fā)現(xiàn)疏遏,該過程提取特征的關鍵函數(shù)是convert_single_example

def convert_single_example(ex_index, example, label_list, max_seq_length,
                           tokenizer):
    """Converts a single `InputExample` into a single `InputFeatures`."""
    ## 將label映射為 id
    label_map = {}
    for (i, label) in enumerate(label_list):
        label_map[label] = i

    tokens_a = tokenizer.tokenize(example.text_a)
    tokens_b = None
    if example.text_b:
        tokens_b = tokenizer.tokenize(example.text_b)

    if tokens_b:
        # Modifies `tokens_a` and `tokens_b` in place so that the total
        # length is less than the specified length.
        # Account for [CLS], [SEP], [SEP] with "- 3"
        _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    else:
        # Account for [CLS] and [SEP] with "- 2"
        if len(tokens_a) > max_seq_length - 2:
            tokens_a = tokens_a[0:(max_seq_length - 2)]

    # The convention in BERT is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0     0   0   0  0     0 0
   
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = []
    segment_ids = []
    tokens.append("[CLS]")
    segment_ids.append(0)
    for token in tokens_a:
        tokens.append(token)
        segment_ids.append(0)
    tokens.append("[SEP]")
    segment_ids.append(0)

    if tokens_b:
        for token in tokens_b:
            tokens.append(token)
            segment_ids.append(1)
        tokens.append("[SEP]")
        segment_ids.append(1)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    label_id = label_map[example.label]
    
    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_id=label_id,
        is_real_example=True)
    return feature

  • 首先調(diào)用tokenizer函數(shù)對text_atext_b進行分詞,將句子轉(zhuǎn)化為tokens救军,若分詞后的句子(一個或兩個)長度過長财异,則需要進行截斷,保證在句子首尾加了[CLS][SEP]之后的總長度在max_seq_length范圍內(nèi)唱遭。
  • segment_ids也就是type_ids用來區(qū)分單詞來自第一條句子還是第二條句子戳寸,type=0type=1對應的embedding會在模型pre-train階段學得。盡管理論上這不是必要的拷泽,因為[SEP]可以區(qū)分句子的邊界庆揩,但是加上type后模型會更容易知道這個詞屬于哪個序列俐东。
  • convert_tokens_to_ids利用詞典vocab,將句子分詞后的token映射為id订晌。
  • 當句子長度小于max_seq_length時虏辫,會進行padding,補充到固定的max_seq_length長度锈拨。input_mask=1表示該token來自于句子砌庄,input_mask=0表示該token是padding的。
  • 最后將提取的input_ids, input_mask, segment_ids封裝到InputFeatures對象中奕枢。至此娄昆,送入網(wǎng)絡前的數(shù)據(jù)處理過程完成。

下一篇會解讀模型的網(wǎng)絡結構缝彬,以及輸入的ids到詞向量的映射過程等萌焰。

參考:BERT代碼閱讀-李理的博客

?著作權歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市谷浅,隨后出現(xiàn)的幾起案子扒俯,更是在濱河造成了極大的恐慌,老刑警劉巖一疯,帶你破解...
    沈念sama閱讀 210,914評論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件撼玄,死亡現(xiàn)場離奇詭異,居然都是意外死亡墩邀,警方通過查閱死者的電腦和手機掌猛,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 89,935評論 2 383
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來眉睹,“玉大人荔茬,你說我怎么就攤上這事≈窈#” “怎么了兔院?”我有些...
    開封第一講書人閱讀 156,531評論 0 345
  • 文/不壞的土叔 我叫張陵,是天一觀的道長站削。 經(jīng)常有香客問我坊萝,道長,這世上最難降的妖魔是什么许起? 我笑而不...
    開封第一講書人閱讀 56,309評論 1 282
  • 正文 為了忘掉前任十偶,我火速辦了婚禮,結果婚禮上园细,老公的妹妹穿的比我還像新娘惦积。我一直安慰自己,他們只是感情好猛频,可當我...
    茶點故事閱讀 65,381評論 5 384
  • 文/花漫 我一把揭開白布狮崩。 她就那樣靜靜地躺著蛛勉,像睡著了一般。 火紅的嫁衣襯著肌膚如雪睦柴。 梳的紋絲不亂的頭發(fā)上诽凌,一...
    開封第一講書人閱讀 49,730評論 1 289
  • 那天,我揣著相機與錄音坦敌,去河邊找鬼侣诵。 笑死,一個胖子當著我的面吹牛狱窘,可吹牛的內(nèi)容都是我干的杜顺。 我是一名探鬼主播,決...
    沈念sama閱讀 38,882評論 3 404
  • 文/蒼蘭香墨 我猛地睜開眼蘸炸,長吁一口氣:“原來是場噩夢啊……” “哼躬络!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起搭儒,我...
    開封第一講書人閱讀 37,643評論 0 266
  • 序言:老撾萬榮一對情侶失蹤穷当,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后仗嗦,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體膘滨,經(jīng)...
    沈念sama閱讀 44,095評論 1 303
  • 正文 獨居荒郊野嶺守林人離奇死亡甘凭,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 36,448評論 2 325
  • 正文 我和宋清朗相戀三年稀拐,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片丹弱。...
    茶點故事閱讀 38,566評論 1 339
  • 序言:一個原本活蹦亂跳的男人離奇死亡德撬,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出躲胳,到底是詐尸還是另有隱情蜓洪,我是刑警寧澤焦人,帶...
    沈念sama閱讀 34,253評論 4 328
  • 正文 年R本政府宣布硫兰,位于F島的核電站,受9級特大地震影響蒲讯,放射性物質(zhì)發(fā)生泄漏粹湃。R本人自食惡果不足惜恐仑,卻給世界環(huán)境...
    茶點故事閱讀 39,829評論 3 312
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望为鳄。 院中可真熱鬧裳仆,春花似錦、人聲如沸孤钦。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,715評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至静袖,卻和暖如春觉鼻,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背勾徽。 一陣腳步聲響...
    開封第一講書人閱讀 31,945評論 1 264
  • 我被黑心中介騙來泰國打工滑凉, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人喘帚。 一個月前我還...
    沈念sama閱讀 46,248評論 2 360
  • 正文 我出身青樓畅姊,卻偏偏與公主長得像,于是被迫代替她去往敵國和親吹由。 傳聞我的和親對象是個殘疾皇子若未,可洞房花燭夜當晚...
    茶點故事閱讀 43,440評論 2 348

推薦閱讀更多精彩內(nèi)容