版權(quán)聲明:本文為博主原創(chuàng)文章,轉(zhuǎn)載請(qǐng)注明出處.
上篇文章介紹了如何安裝和使用BERT進(jìn)行文本相似度任務(wù),包括如何修改代碼進(jìn)行訓(xùn)練和測(cè)試。本文在此基礎(chǔ)上介紹如何進(jìn)行文本分類任務(wù)剑肯。
文本相似度任務(wù)具體見(jiàn): BERT介紹及中文文本相似度任務(wù)實(shí)踐
文本相似度任務(wù)和文本分類任務(wù)的區(qū)別在于數(shù)據(jù)集的準(zhǔn)備以及run_classifier.py
中數(shù)據(jù)類的構(gòu)造部分。
0. 準(zhǔn)備工作
如果想要根據(jù)我們準(zhǔn)備的數(shù)據(jù)集進(jìn)行fine-tuning
观堂,則需要先下載預(yù)訓(xùn)練模型让网。由于是處理中文文本,因此下載對(duì)應(yīng)的中文預(yù)訓(xùn)練模型师痕。
BERTgit
地址: google-research/bert
- BERT-Base, Chinese: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters
文件名為 chinese_L-12_H-768_A-12.zip
溃睹。將其解壓至bert
文件夾,包含以下三種文件:
- 配置文件(bert_config.json):用于指定模型的超參數(shù)
- 詞典文件(vocab.txt):用于WordPiece 到 Word id的映射
- Tensorflow checkpoint(bert_model.ckpt):包含了預(yù)訓(xùn)練模型的權(quán)重(實(shí)際包含三個(gè)文件)
1. 數(shù)據(jù)集的準(zhǔn)備
對(duì)于文本分類任務(wù)胰坟,需要準(zhǔn)備的數(shù)據(jù)集的格式如下:
label, 文本
因篇,其中標(biāo)簽可以是中文字符串,也可以是數(shù)字腕铸。
如: 天氣, 一會(huì)好像要下雨了
或者0, 一會(huì)好像要下雨了
將準(zhǔn)備好的數(shù)據(jù)存放于文本文件中惜犀,如.txt
, .csv
等狠裹。至于用什么名字和后綴虽界,只要與數(shù)據(jù)類中的名稱一致即可。
如涛菠,在run_classifier.py
中的數(shù)據(jù)類get_train_examples
方法中莉御,默認(rèn)訓(xùn)練集文件是train.csv
,可以修改為自己命名的文件名即可俗冻。
def get_train_examples(self, data_dir):
"""See base class."""
file_path = os.path.join(data_dir, 'train.csv')
2. 增加自定義數(shù)據(jù)類
將新增的用于文本分類的數(shù)據(jù)類命名為 TextClassifierProcessor
礁叔,如下
class TextClassifierProcessor(DataProcessor):
重寫其父類的四個(gè)方法,從而實(shí)現(xiàn)數(shù)據(jù)的獲取過(guò)程迄薄。
-
get_train_examples
:對(duì)訓(xùn)練集獲取InputExample
的集合 -
get_dev_examples
:對(duì)驗(yàn)證集... -
get_test_examples
:對(duì)測(cè)試集... -
get_labels
:獲取數(shù)據(jù)集分類標(biāo)簽列表
InputExample
類的作用是對(duì)于單個(gè)分類序列的訓(xùn)練/測(cè)試樣例琅关。構(gòu)建了一個(gè)InputExample
,包含id, text_a, text_b, label
讥蔽。
其定義如下:
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
重寫get_train_examples
方法涣易, 對(duì)于文本分類任務(wù)画机,只需要label
和一個(gè)文本即可,因此新症,只需要賦值給text_a
步氏。
因?yàn)闇?zhǔn)備的數(shù)據(jù)集 標(biāo)簽 和 文本 是以逗號(hào)隔開(kāi)的,因此先將每行數(shù)據(jù)以逗號(hào)隔開(kāi)徒爹,則split_line[0]
為標(biāo)簽賦值給label
荚醒,split_line[1]
為文本賦值給text_a
。
此處隆嗅,準(zhǔn)備的數(shù)據(jù)集標(biāo)簽和文本是以逗號(hào)隔開(kāi)的界阁,難免文本中沒(méi)有同樣的英文逗號(hào),為了避免獲取到不完整的文本數(shù)據(jù)榛瓮,建議使用
str.find(',')
找到第一個(gè)逗號(hào)出現(xiàn)的位置铺董,則label = line[:line.find(',')].strip()
對(duì)于測(cè)試集和驗(yàn)證集的處理相同。
def get_train_examples(self, data_dir):
"""See base class."""
file_path = os.path.join(data_dir, 'train.csv')
examples = []
with open(file_path, encoding='utf-8') as f:
reader = f.readlines()
for (i, line) in enumerate(reader):
guid = "train-%d" % (i)
split_line = line.strip().split(",")
text_a = tokenization.convert_to_unicode(split_line[1])
text_b = None
label = str(split_line[0])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
get_labels
方法用于獲取數(shù)據(jù)集所有的類別標(biāo)簽禀晓,此處使用數(shù)字1,2,3.... 來(lái)表示精续,如有66個(gè)類別(1—66),則實(shí)現(xiàn)方法如下:
def get_labels(self):
"""See base class."""
labels = [str(i) for i in range(1,67)]
return labels
<注意>
為了方便粹懒,可以構(gòu)建一個(gè)字典類型的變量重付,存放數(shù)字類別和文本標(biāo)簽中間的對(duì)應(yīng)關(guān)系。當(dāng)然也可以直接使用文本標(biāo)簽凫乖,想用哪種用哪種确垫。
定義完TextClassifierProcessor
類之后,還需要將其加入到main
函數(shù)中的processors
變量中去帽芽。
找到main()函數(shù)删掀,增加新定義數(shù)據(jù)類,如下所示:
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mrpc": MrpcProcessor,
"xnli": XnliProcessor,
"sim": SimProcessor,
"classifier":TextClassifierProcessor, # 增加此行
}
3. 修改predict輸出
在run_classifier.py
文件中导街,預(yù)測(cè)部分的會(huì)輸出兩個(gè)文件披泪,分別是 predict.tf_record
和test_results.tsv
。其中test_results.tsv
中存放的是每個(gè)測(cè)試數(shù)據(jù)得到的屬于所有類別的概率值搬瑰,維度為[n*num_labels]款票。
但這個(gè)結(jié)果并不能直接反應(yīng)得到的預(yù)測(cè)結(jié)果,因此增加處理代碼泽论,直接獲取得到的預(yù)測(cè)類別艾少。
原始代碼如下:
if FLAGS.do_predict:
print('*'*30,'do_predict', '*'*30)
predict_examples = processor.get_test_examples(FLAGS.data_dir)
num_actual_predict_examples = len(predict_examples)
if FLAGS.use_tpu:
# TPU requires a fixed batch size for all batches, therefore the number
# of examples must be a multiple of the batch size, or else examples
# will get dropped. So we pad with fake examples which are ignored
# later on.
while len(predict_examples) % FLAGS.predict_batch_size != 0:
predict_examples.append(PaddingInputExample())
predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
file_based_convert_examples_to_features(predict_examples, label_list,
FLAGS.max_seq_length, tokenizer,
predict_file)
tf.logging.info("***** Running prediction*****")
tf.logging.info(" Num examples = %d (%d actual, %d padding)",
len(predict_examples), num_actual_predict_examples,
len(predict_examples) - num_actual_predict_examples)
tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
predict_drop_remainder = True if FLAGS.use_tpu else False
predict_input_fn = file_based_input_fn_builder(
input_file=predict_file,
seq_length=FLAGS.max_seq_length,
is_training=False,
drop_remainder=predict_drop_remainder)
result = estimator.predict(input_fn=predict_input_fn)
output_predict_file = os.path.join(
FLAGS.output_dir, "test_results.tsv")
with tf.gfile.GFile(output_predict_file, "w") as writer:
num_written_lines = 0
tf.logging.info("***** Predict results *****")
for (i, prediction) in enumerate(result):
probabilities = prediction["probabilities"]
if i >= num_actual_predict_examples:
break
output_line = "\t".join(
str(class_probability)
for class_probability in probabilities) + "\n"
writer.write(output_line)
num_written_lines += 1
assert num_written_lines == num_actual_predict_examples
修改后的代碼如下:
result_predict_file = os.path.join(
FLAGS.output_dir, "test_labels_out.txt")
right = 0 # 預(yù)測(cè)正確的個(gè)數(shù)
f_res = open(result_predict_file, 'w') #將結(jié)果保存到此文件中
with tf.gfile.GFile(output_predict_file, "w") as writer:
num_written_lines = 0
tf.logging.info("***** Predict results *****")
for (i, prediction) in enumerate(result):
probabilities = prediction["probabilities"] #預(yù)測(cè)結(jié)果
if i >= num_actual_predict_examples:
break
output_line = "\t".join(
str(class_probability)
for class_probability in probabilities) + "\n"
# 獲取概率值最大的類別的下標(biāo)Index
index = np.argmax(probabilities, axis = 0)
# 將真實(shí)標(biāo)簽和預(yù)測(cè)標(biāo)簽及對(duì)應(yīng)的概率值寫入到結(jié)果文件中
res_line = 'real: %s, \tpred:%s, \tscore = %.2f\n' \
%(lable_to_cate[real_label[i]], lable_to_cate[index+1], probabilities[index])
f_res.write(res_line)
writer.write(output_line)
num_written_lines += 1
if real_label[i] == (index+1):
right += 1
print('precision = %.2f' %(right / len(real_label)))
4.fine-tuning模型
準(zhǔn)備好數(shù)據(jù)集,修改完數(shù)據(jù)類后翼悴,接下來(lái)就是如何fine-tuning
模型缚够。
查看 run_classifier.py
文件的入口部分,包含了fine-tuning模型所需的必要參數(shù),如下:
if __name__ == "__main__":
flags.mark_flag_as_required("data_dir")
flags.mark_flag_as_required("task_name")
flags.mark_flag_as_required("vocab_file")
flags.mark_flag_as_required("bert_config_file")
flags.mark_flag_as_required("output_dir")
tf.app.run()
部分參數(shù)說(shuō)明
data_dir
:數(shù)據(jù)存放路徑
task_mask
:processor的名字潮瓶,對(duì)于文本分類任務(wù)陶冷,則為classifier
vocab_file
:字典文件的地址
bert_config_file
:配置文件
output_dir
:模型輸出地址
由于需要設(shè)置的參數(shù)較多钙姊,因此將其統(tǒng)一放置到sh腳本中毯辅,名稱fine-tuning_classifier.sh
,如下所示:
#!/usr/bin/env bash
export BERT_BASE_DIR=/**/NLP/bert/chinese_L-12_H-768_A-12 #全局變量 下載的預(yù)訓(xùn)練bert地址
export MY_DATASET=/**/NLP/bert/data/text_classifition #全局變量 數(shù)據(jù)集所在地址
python run_classifier.py \
--task_name=classifier \
--do_train=true \
--do_eval=true \
--do_predict=true \
--data_dir=$MY_DATASET \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--max_seq_length=32 \
--train_batch_size=64 \
--learning_rate=5e-5 \
--num_train_epochs=10.0 \
--output_dir=./fine_tuning_out/text_classifier_64_epoch10_5e5
執(zhí)行命令
sh ./fine-tuning_classifier.sh
生成的模型文件煞额,在output_dir
目錄中思恐,如下:
得到的測(cè)試結(jié)果文件
test_labels_out.txt
內(nèi)容如下:
real: 天氣, pred:天氣, score = 1.00
使用tensorboard
查看loss
走勢(shì),如下所示:
文本相似度任務(wù)具體見(jiàn): BERT介紹及中文文本相似度任務(wù)實(shí)踐