Github上BERT的README里面已經(jīng)給出了相當(dāng)詳細(xì)的使用說明,GOOGLE BERT地址。
Fine-tuning就是載入預(yù)訓(xùn)練好的Bert模型,在自己的語料上再訓(xùn)練一段時間。載入模型和使用模型繼續(xù)訓(xùn)練這部分github上代碼已經(jīng)幫忙做好了,我們fine-tuning需要做的工作就是在官方代碼的run_classifier.py這個文件里面添加本地任務(wù)的Processor
贪壳。
仿照官方代碼中的XnliProcessor
,添加一個文本分類的fine-tuning處理類TextProcessor
蚜退,它需要繼承DataProcessor
闰靴,實現(xiàn)get_train_examples
,get_labels
方法钻注。
使用的文本數(shù)據(jù)不需要分詞(bert代碼會幫你做掉)蚂且,以'\t'
為分隔分為兩列。第一列是文本的標(biāo)簽幅恋,第二列是輸入文本杏死。
train.tsv:
2 比奔馳寶馬比不了比相應(yīng)的的車輛小五你是第一
0 美國車就是費(fèi)油既然你買的起就能燒得起油
class TextProcessor(DataProcessor):
"""Processor for the Text classification"""
# 讀入訓(xùn)練集語料 (.tsv即使用'\t'分隔的csv文件)
def get_train_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
# 讀入evalation使用的語料, 如果--do_eval=False, 可以不實現(xiàn)
def get_dev_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")),
"dev_matched")
# 讀入測試集數(shù)據(jù), --do_predict的時候使用
def get_test_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
# 文本分類的標(biāo)簽
def get_labels(self):
return ["0", "1", "2"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
# lines是_read_tsv()方法讀入的數(shù)據(jù), 每一行樣本按照"\t"切成一個list
for (i, line) in enumerate(lines):
# guid是樣本的id, 保證唯一性就可以了
guid = "%s-%s" % (set_type, i)
# 輸入的文本數(shù)據(jù) (不需要分詞)
text_a = tokenization.convert_to_unicode(line[1])
# 只有一個輸出文本, text_b置為None
text_b = None
if set_type == "test":
# 這邊假設(shè)test數(shù)據(jù)沒有標(biāo)簽, 隨便給個標(biāo)簽就行了
label = "0"
else:
# 訓(xùn)練和驗證使用的數(shù)據(jù)標(biāo)簽
label = tokenization.convert_to_unicode(line[0])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
還需要在def main()
函數(shù)中添加任務(wù)
processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mrpc": MrpcProcessor,
"xnli": XnliProcessor,
"text": TextProcessor
}
最后在shell腳本中添加以下內(nèi)容,運(yùn)行:
# 下載的預(yù)訓(xùn)練模型文件的目錄
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
# fine-tuning使用的文件目錄 (目錄包含train.tsv, dev.tsv, test.tsv文件)
export TEXT_DIR=/path/to/text
CUDA_VISIBLE_DEVICES=0 python run_classifier.py \
--task_name=text \
--do_train=true \
--do_eval=true \
--data_dir=$TEXT_DIR \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--max_seq_length=128 \
--train_batch_size=32 \
--learning_rate=2e-5 \
--num_train_epochs=3.0 \
--output_dir=/tmp/text_output/
run_classifier.py
使用的參數(shù)含義在文件開頭都有解釋捆交,這里就不再贅述了淑翼。
如果想利用fine-tuning好的模型來對test數(shù)據(jù)進(jìn)行預(yù)測,可以參考以下shell腳本
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export DATA_DIR=/path/to/text
# 前面fine-tuning模型的輸出目錄
export TRAINED_CLASSIFIER=/tmp/text_output/
python run_classifier.py \
--task_name=text \
--do_predict=true \
--data_dir=$DATA_DIR \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$TRAINED_CLASSIFIER \
--max_seq_length=128 \
--output_dir=/tmp/text_output/pred/
預(yù)測完成后會在/tmp/text_output/pred/目錄下生成一個test_results.tsv文件品追。文件每行代表模型對每個類別預(yù)測的分?jǐn)?shù)玄括,對應(yīng)順序為TextProcessor
中get_labels
返回的標(biāo)簽順序。