谷歌推出的Bert署咽,最近有多火,估計(jì)做自然語(yǔ)言處理的都知道生音。據(jù)稱在SQuAD等11項(xiàng)任務(wù)當(dāng)中達(dá)到了state of the art宁否。bert的原理可參考論文,或者網(wǎng)上其他人翻譯的資料缀遍。谷歌已經(jīng)在github上開(kāi)源了代碼慕匠,相信每一個(gè)從事NLP的都應(yīng)該和我一樣摩拳擦掌,迫不及待地想要學(xué)習(xí)它了吧域醇。
就我個(gè)人而言學(xué)習(xí)一個(gè)開(kāi)源項(xiàng)目台谊,第一步是安裝蓉媳,第二步是跑下demo,第三步才是閱讀源碼锅铅。安裝bert簡(jiǎn)單酪呻,直接github上拉下來(lái)就可以了,跑demo其實(shí)也不難盐须,參照README.md一步步操作就行了玩荠,但是經(jīng)我實(shí)操過(guò)后,發(fā)現(xiàn)里面有個(gè)小坑贼邓,所以用這篇文章記錄下來(lái)阶冈,供讀者參考。
閑言少敘塑径,書(shū)歸正傳女坑。本次介紹的demo只有兩個(gè),一個(gè)是基于MRPC(Microsoft Research Paraphrase Corpus )的句子對(duì)分類任務(wù)统舀,一個(gè)是基于SQuAD語(yǔ)料的閱讀理解任務(wù)堂飞。run demo分為以下幾步:
1、下載bert源碼
這沒(méi)什么好說(shuō)的绑咱,直接clone
git clone https://github.com/google-research/bert.git
2、下載預(yù)訓(xùn)練模型
為什么選擇BERT-Base, Uncased
這個(gè)模型呢枢泰?原因有三:1描融、訓(xùn)練語(yǔ)料為英文,所以不選擇中文或者多語(yǔ)種衡蚂;2窿克、設(shè)備條件有限,如果您的顯卡內(nèi)存小于16個(gè)G毛甲,那就請(qǐng)乖乖選擇base,不要折騰large了年叮;3、cased表示區(qū)分大小寫玻募,uncased表示不區(qū)分大小寫只损。除非你明確知道你的任務(wù)對(duì)大小寫敏感(比如命名實(shí)體識(shí)別、詞性標(biāo)注等)那么通常情況下uncased效果更好七咧。
3跃惫、下載訓(xùn)練數(shù)據(jù):
(1)下載MRPC語(yǔ)料:
官網(wǎng)上指定的方式是通過(guò)跑腳本download_glue_data.py來(lái)下載 GLUE data 。指定數(shù)據(jù)存放地址為:glue_data艾栋, 下載任務(wù)為:MRPC爆存,執(zhí)行(本篇中所有python3的命令同樣適用于python):
python3 download_glue_data.py --data_dir glue_data --tasks MRPC
執(zhí)行后發(fā)現(xiàn)下載失敗,究其原因是下面這兩個(gè)鏈接訪問(wèn)不上蝗砾,幾天后試了一次又能下載了先较,可能對(duì)方服務(wù)端不穩(wěn)定携冤。
MRPC_TRAIN = 'https://s3.amazonaws.com/senteval/senteval_data/msr_paraphrase_train.txt'
MRPC_TEST = 'https://s3.amazonaws.com/senteval/senteval_data/msr_paraphrase_test.txt'
如果不能下載,可以參考我當(dāng)時(shí)的做法:
1闲勺、手動(dòng)下載dev_ids.tsv映射表保存在glue_data/MRPC文件夾下
"MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
2曾棕、因?yàn)?GLUE data官網(wǎng)也訪問(wèn)不了,所以只能去微軟官網(wǎng)下載:https://www.microsoft.com/en-ca/download/details.aspx?id=52398
將 msr_paraphrase_test.txt霉翔, msr_paraphrase_train.txt兩個(gè)解壓后的文件放在mrpc_ori_corpus文件夾下
3睁蕾、注釋掉腳本download_glue_data.py里下載dev_ids.tsv文件的語(yǔ)句(如果你的服務(wù)器能下載可以不注釋,使用代碼下載不必手動(dòng)下載):
65 # urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))
執(zhí)行
python3 download_glue_data.py --data_dir glue_data --tasks MRPC --path_to_mrpc mrpc_ori_corpus
如果在glue_data/MRPC文件下出現(xiàn) dev.tsv债朵,test.tsv子眶,train.tsv這三個(gè)文件,說(shuō)明MRPC語(yǔ)料下載成功序芦。
(2)下載SQuAD語(yǔ)料:
基本上沒(méi)什么波折臭杰,可以使用下面三個(gè)鏈接直接下載,放置于$SQUAD_DIR路徑下
4谚中、run demo
(1) 基于MRPC語(yǔ)料的句子對(duì)分類任務(wù)
訓(xùn)練:
設(shè)置環(huán)境變量渴杆,指定預(yù)訓(xùn)練模型文件和語(yǔ)料地址
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export GLUE_DIR=/path/to/glue_data
在bert源碼文件里執(zhí)行run_classifier.py,基于預(yù)訓(xùn)練模型進(jìn)行fine-tune
python run_classifier.py \
--task_name=MRPC \
--do_train=true \
--do_eval=true \
--data_dir=$GLUE_DIR/MRPC \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--max_seq_length=128 \
--train_batch_size=32 \
--learning_rate=2e-5 \
--num_train_epochs=3.0 \
--output_dir=/tmp/mrpc_output/
模型保存在output_dir宪塔, 驗(yàn)證結(jié)果為:
***** Eval results *****
eval_accuracy = 0.845588
eval_loss = 0.505248
global_step = 343
loss = 0.505248
預(yù)測(cè):
指定fine-tune之后模型文件所在地址
export TRAINED_CLASSIFIER=/path/to/fine/tuned/classifier
執(zhí)行以下語(yǔ)句完成預(yù)測(cè)任務(wù)磁奖,預(yù)測(cè)結(jié)果輸出在output_dir文件夾中
python run_classifier.py \
--task_name=MRPC \
--do_predict=true \
--data_dir=$GLUE_DIR/MRPC \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$TRAINED_CLASSIFIER \
--max_seq_length=128 \
--output_dir=/tmp/mrpc_output/
(2)基于SQuAD語(yǔ)料的閱讀理解任務(wù)
設(shè)置為語(yǔ)料所在文件夾為$SQUAD_DIR
python run_squad.py \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--do_train=True \
--train_file=$SQUAD_DIR/train-v1.1.json \
--do_predict=True \
--predict_file=$SQUAD_DIR/dev-v1.1.json \
--train_batch_size=12 \
--learning_rate=3e-5 \
--num_train_epochs=2.0 \
--max_seq_length=384 \
--doc_stride=128 \
--output_dir=/tmp/squad_base/
在output_dir文件夾下會(huì)輸出一個(gè)predictions.json文件,執(zhí)行:
python3 $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json predictions.json
看到以下結(jié)果某筐,說(shuō)明執(zhí)行無(wú)誤:
{"f1": 88.41249612335034, "exact_match": 81.2488174077578}
5比搭、總結(jié):
本篇內(nèi)容主要解決了以下兩個(gè)問(wèn)題:
(1) 基于MRPC語(yǔ)料的句子對(duì)分類任務(wù)和基于SQuAD語(yǔ)料的閱讀理解任務(wù)的demo執(zhí)行,主要是翻譯源碼中README.md的部分內(nèi)容南誊;
(2) 對(duì)于部分語(yǔ)料無(wú)法下載的情況身诺,提供了其他的搜集方式。
系列后續(xù)將對(duì)bert源碼進(jìn)行解讀抄囚,敬請(qǐng)關(guān)注
系列文章
Bert系列(二)——模型主體源碼解讀
Bert系列(三)——源碼解讀之Pre-train
Bert系列(四)——源碼解讀之Fine-tune
Bert系列(五)——中文分詞實(shí)踐 F1 97.8%(附代碼)
Reference
1.https://github.com/google-research/bert