BERT 文本分類 fine-tuning

版權(quán)聲明:本文為博主原創(chuàng)文章,轉(zhuǎn)載請(qǐng)注明出處.

上篇文章介紹了如何安裝和使用BERT進(jìn)行文本相似度任務(wù),包括如何修改代碼進(jìn)行訓(xùn)練和測(cè)試。本文在此基礎(chǔ)上介紹如何進(jìn)行文本分類任務(wù)剑肯。

文本相似度任務(wù)具體見(jiàn): BERT介紹及中文文本相似度任務(wù)實(shí)踐

文本相似度任務(wù)和文本分類任務(wù)的區(qū)別在于數(shù)據(jù)集的準(zhǔn)備以及run_classifier.py中數(shù)據(jù)類的構(gòu)造部分。

0. 準(zhǔn)備工作

如果想要根據(jù)我們準(zhǔn)備的數(shù)據(jù)集進(jìn)行fine-tuning观堂,則需要先下載預(yù)訓(xùn)練模型让网。由于是處理中文文本,因此下載對(duì)應(yīng)的中文預(yù)訓(xùn)練模型师痕。

BERTgit地址: google-research/bert

  • BERT-Base, Chinese: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters

文件名為 chinese_L-12_H-768_A-12.zip溃睹。將其解壓至bert文件夾,包含以下三種文件:

  • 配置文件(bert_config.json):用于指定模型的超參數(shù)
  • 詞典文件(vocab.txt):用于WordPiece 到 Word id的映射
  • Tensorflow checkpoint(bert_model.ckpt):包含了預(yù)訓(xùn)練模型的權(quán)重(實(shí)際包含三個(gè)文件)

1. 數(shù)據(jù)集的準(zhǔn)備

對(duì)于文本分類任務(wù)胰坟,需要準(zhǔn)備的數(shù)據(jù)集的格式如下:
label, 文本 因篇,其中標(biāo)簽可以是中文字符串,也可以是數(shù)字腕铸。
如: 天氣, 一會(huì)好像要下雨了 或者0, 一會(huì)好像要下雨了

將準(zhǔn)備好的數(shù)據(jù)存放于文本文件中惜犀,如.txt.csv等狠裹。至于用什么名字和后綴虽界,只要與數(shù)據(jù)類中的名稱一致即可。
如涛菠,在run_classifier.py中的數(shù)據(jù)類get_train_examples方法中莉御,默認(rèn)訓(xùn)練集文件是train.csv,可以修改為自己命名的文件名即可俗冻。

    def get_train_examples(self, data_dir):
        """See base class."""
        file_path = os.path.join(data_dir, 'train.csv')

2. 增加自定義數(shù)據(jù)類

將新增的用于文本分類的數(shù)據(jù)類命名為 TextClassifierProcessor礁叔,如下

class TextClassifierProcessor(DataProcessor):

重寫其父類的四個(gè)方法,從而實(shí)現(xiàn)數(shù)據(jù)的獲取過(guò)程迄薄。

  • get_train_examples:對(duì)訓(xùn)練集獲取InputExample的集合
  • get_dev_examples:對(duì)驗(yàn)證集...
  • get_test_examples:對(duì)測(cè)試集...
  • get_labels:獲取數(shù)據(jù)集分類標(biāo)簽列表

InputExample類的作用是對(duì)于單個(gè)分類序列的訓(xùn)練/測(cè)試樣例琅关。構(gòu)建了一個(gè)InputExample,包含id, text_a, text_b, label讥蔽。
其定義如下:

class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs a InputExample.

        Args:
          guid: Unique id for the example.
          text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
          text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
          label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label

重寫get_train_examples方法涣易, 對(duì)于文本分類任務(wù)画机,只需要label和一個(gè)文本即可,因此新症,只需要賦值給text_a步氏。

因?yàn)闇?zhǔn)備的數(shù)據(jù)集 標(biāo)簽文本以逗號(hào)隔開(kāi)的,因此先將每行數(shù)據(jù)以逗號(hào)隔開(kāi)徒爹,則split_line[0]為標(biāo)簽賦值給label荚醒,split_line[1]為文本賦值給text_a

此處隆嗅,準(zhǔn)備的數(shù)據(jù)集標(biāo)簽和文本是以逗號(hào)隔開(kāi)的界阁,難免文本中沒(méi)有同樣的英文逗號(hào),為了避免獲取到不完整的文本數(shù)據(jù)榛瓮,建議使用 str.find(',')找到第一個(gè)逗號(hào)出現(xiàn)的位置铺董,則 label = line[:line.find(',')].strip()

對(duì)于測(cè)試集和驗(yàn)證集的處理相同。

    def get_train_examples(self, data_dir):
        """See base class."""
        file_path = os.path.join(data_dir, 'train.csv')
        examples = []
        with open(file_path, encoding='utf-8') as f:
            reader = f.readlines()
        for (i, line) in enumerate(reader):
            guid = "train-%d" % (i)
            split_line = line.strip().split(",")
            text_a = tokenization.convert_to_unicode(split_line[1])
            text_b = None
            label = str(split_line[0])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples

get_labels方法用于獲取數(shù)據(jù)集所有的類別標(biāo)簽禀晓,此處使用數(shù)字1,2,3.... 來(lái)表示精续,如有66個(gè)類別(1—66),則實(shí)現(xiàn)方法如下:

   def get_labels(self):
        """See base class."""
        labels = [str(i) for i in range(1,67)]
        return labels

<注意>

為了方便粹懒,可以構(gòu)建一個(gè)字典類型的變量重付,存放數(shù)字類別和文本標(biāo)簽中間的對(duì)應(yīng)關(guān)系。當(dāng)然也可以直接使用文本標(biāo)簽凫乖,想用哪種用哪種确垫。

定義完TextClassifierProcessor類之后,還需要將其加入到main函數(shù)中的processors變量中去帽芽。

找到main()函數(shù)删掀,增加新定義數(shù)據(jù)類,如下所示:

def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "xnli": XnliProcessor,
        "sim": SimProcessor,
        "classifier":TextClassifierProcessor,  # 增加此行
    }

3. 修改predict輸出

run_classifier.py文件中导街,預(yù)測(cè)部分的會(huì)輸出兩個(gè)文件披泪,分別是 predict.tf_recordtest_results.tsv。其中test_results.tsv中存放的是每個(gè)測(cè)試數(shù)據(jù)得到的屬于所有類別的概率值搬瑰,維度為[n*num_labels]款票。

但這個(gè)結(jié)果并不能直接反應(yīng)得到的預(yù)測(cè)結(jié)果,因此增加處理代碼泽论,直接獲取得到的預(yù)測(cè)類別艾少。

原始代碼如下:

    if FLAGS.do_predict:
        print('*'*30,'do_predict', '*'*30)
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        num_actual_predict_examples = len(predict_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on.
            while len(predict_examples) % FLAGS.predict_batch_size != 0:
                predict_examples.append(PaddingInputExample())

        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        file_based_convert_examples_to_features(predict_examples, label_list,
                                                FLAGS.max_seq_length, tokenizer,
                                                predict_file)

        tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(predict_examples), num_actual_predict_examples,
                        len(predict_examples) - num_actual_predict_examples)
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_drop_remainder = True if FLAGS.use_tpu else False
        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=predict_drop_remainder)

        result = estimator.predict(input_fn=predict_input_fn)

        output_predict_file = os.path.join(
            FLAGS.output_dir, "test_results.tsv")
        with tf.gfile.GFile(output_predict_file, "w") as writer:
            num_written_lines = 0
            tf.logging.info("***** Predict results *****")
            for (i, prediction) in enumerate(result):
                probabilities = prediction["probabilities"]
                if i >= num_actual_predict_examples:
                    break
                output_line = "\t".join(
                    str(class_probability)
                    for class_probability in probabilities) + "\n"
                writer.write(output_line)
                num_written_lines += 1
        assert num_written_lines == num_actual_predict_examples

修改后的代碼如下:

        result_predict_file = os.path.join(
            FLAGS.output_dir, "test_labels_out.txt")

        right = 0 # 預(yù)測(cè)正確的個(gè)數(shù)
        f_res = open(result_predict_file, 'w') #將結(jié)果保存到此文件中
        with tf.gfile.GFile(output_predict_file, "w") as writer:
            num_written_lines = 0
            tf.logging.info("***** Predict results *****")
            for (i, prediction) in enumerate(result):
                probabilities = prediction["probabilities"] #預(yù)測(cè)結(jié)果
                if i >= num_actual_predict_examples:
                    break
                output_line = "\t".join(
                    str(class_probability)
                    for class_probability in probabilities) + "\n"
                # 獲取概率值最大的類別的下標(biāo)Index
                index = np.argmax(probabilities, axis = 0)
                # 將真實(shí)標(biāo)簽和預(yù)測(cè)標(biāo)簽及對(duì)應(yīng)的概率值寫入到結(jié)果文件中
                res_line = 'real: %s, \tpred:%s, \tscore = %.2f\n' \
                        %(lable_to_cate[real_label[i]], lable_to_cate[index+1], probabilities[index])
                f_res.write(res_line)
                writer.write(output_line)
                num_written_lines += 1

                if real_label[i] == (index+1):
                    right += 1

            print('precision = %.2f' %(right / len(real_label)))

4.fine-tuning模型

準(zhǔn)備好數(shù)據(jù)集,修改完數(shù)據(jù)類后翼悴,接下來(lái)就是如何fine-tuning模型缚够。
查看 run_classifier.py文件的入口部分,包含了fine-tuning模型所需的必要參數(shù),如下:

if __name__ == "__main__":
    flags.mark_flag_as_required("data_dir")
    flags.mark_flag_as_required("task_name")
    flags.mark_flag_as_required("vocab_file")
    flags.mark_flag_as_required("bert_config_file")
    flags.mark_flag_as_required("output_dir")
    tf.app.run()

部分參數(shù)說(shuō)明
data_dir :數(shù)據(jù)存放路徑
task_mask :processor的名字潮瓶,對(duì)于文本分類任務(wù)陶冷,則為classifier
vocab_file :字典文件的地址
bert_config_file :配置文件
output_dir :模型輸出地址

由于需要設(shè)置的參數(shù)較多钙姊,因此將其統(tǒng)一放置到sh腳本中毯辅,名稱fine-tuning_classifier.sh,如下所示:

#!/usr/bin/env bash
export BERT_BASE_DIR=/**/NLP/bert/chinese_L-12_H-768_A-12 #全局變量 下載的預(yù)訓(xùn)練bert地址
export MY_DATASET=/**/NLP/bert/data/text_classifition #全局變量 數(shù)據(jù)集所在地址

python run_classifier.py \
  --task_name=classifier  \
  --do_train=true \
  --do_eval=true \
  --do_predict=true \
  --data_dir=$MY_DATASET \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=32  \
  --train_batch_size=64 \
  --learning_rate=5e-5 \
  --num_train_epochs=10.0 \
  --output_dir=./fine_tuning_out/text_classifier_64_epoch10_5e5

執(zhí)行命令

sh ./fine-tuning_classifier.sh

生成的模型文件煞额,在output_dir目錄中思恐,如下:

在這里插入圖片描述

得到的測(cè)試結(jié)果文件test_labels_out.txt內(nèi)容如下:

real: 天氣, pred:天氣, score = 1.00

使用tensorboard查看loss走勢(shì),如下所示:

在這里插入圖片描述

文本相似度任務(wù)具體見(jiàn): BERT介紹及中文文本相似度任務(wù)實(shí)踐

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末膊毁,一起剝皮案震驚了整個(gè)濱河市胀莹,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌婚温,老刑警劉巖描焰,帶你破解...
    沈念sama閱讀 216,591評(píng)論 6 501
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異栅螟,居然都是意外死亡荆秦,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,448評(píng)論 3 392
  • 文/潘曉璐 我一進(jìn)店門力图,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)步绸,“玉大人,你說(shuō)我怎么就攤上這事吃媒∪拷椋” “怎么了?”我有些...
    開(kāi)封第一講書人閱讀 162,823評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵赘那,是天一觀的道長(zhǎng)刑桑。 經(jīng)常有香客問(wèn)我,道長(zhǎng)募舟,這世上最難降的妖魔是什么祠斧? 我笑而不...
    開(kāi)封第一講書人閱讀 58,204評(píng)論 1 292
  • 正文 為了忘掉前任,我火速辦了婚禮胃珍,結(jié)果婚禮上梁肿,老公的妹妹穿的比我還像新娘。我一直安慰自己觅彰,他們只是感情好吩蔑,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,228評(píng)論 6 388
  • 文/花漫 我一把揭開(kāi)白布。 她就那樣靜靜地躺著填抬,像睡著了一般烛芬。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上,一...
    開(kāi)封第一講書人閱讀 51,190評(píng)論 1 299
  • 那天赘娄,我揣著相機(jī)與錄音仆潮,去河邊找鬼。 笑死遣臼,一個(gè)胖子當(dāng)著我的面吹牛性置,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播揍堰,決...
    沈念sama閱讀 40,078評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼鹏浅,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了屏歹?” 一聲冷哼從身側(cè)響起隐砸,我...
    開(kāi)封第一講書人閱讀 38,923評(píng)論 0 274
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎蝙眶,沒(méi)想到半個(gè)月后季希,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,334評(píng)論 1 310
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡幽纷,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,550評(píng)論 2 333
  • 正文 我和宋清朗相戀三年式塌,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片霹崎。...
    茶點(diǎn)故事閱讀 39,727評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡珊搀,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出尾菇,到底是詐尸還是另有隱情境析,我是刑警寧澤,帶...
    沈念sama閱讀 35,428評(píng)論 5 343
  • 正文 年R本政府宣布派诬,位于F島的核電站劳淆,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏默赂。R本人自食惡果不足惜沛鸵,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,022評(píng)論 3 326
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望缆八。 院中可真熱鬧曲掰,春花似錦、人聲如沸奈辰。這莊子的主人今日做“春日...
    開(kāi)封第一講書人閱讀 31,672評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)奖恰。三九已至吊趾,卻和暖如春宛裕,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背论泛。 一陣腳步聲響...
    開(kāi)封第一講書人閱讀 32,826評(píng)論 1 269
  • 我被黑心中介騙來(lái)泰國(guó)打工揩尸, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人屁奏。 一個(gè)月前我還...
    沈念sama閱讀 47,734評(píng)論 2 368
  • 正文 我出身青樓岩榆,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親了袁。 傳聞我的和親對(duì)象是個(gè)殘疾皇子朗恳,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,619評(píng)論 2 354

推薦閱讀更多精彩內(nèi)容