前因
在尋找pytorch版本的英文版Bert預(yù)訓(xùn)練模型時嘴脾,發(fā)現(xiàn)只有中文版的預(yù)訓(xùn)練模型,而且因為Tensorflow和Pytorch在讀取預(yù)訓(xùn)練模型時,讀取文件的格式不同票从,所以不能直接拿來使用吹零。在讀取這篇 Pytorch | BERT模型實現(xiàn)罩抗,提供轉(zhuǎn)換腳本【橫掃NLP】 文章后,發(fā)現(xiàn)有可以將預(yù)訓(xùn)練模型轉(zhuǎn)換至Pytorch可以讀取的文件形式的方法灿椅。
介紹
一個名為 Hugging Face ?? 的團隊公開了BERT模型的谷歌官方TensorFlow庫的 op-for-op PyTorch 重新實現(xiàn)套蒂,其中有腳本可以將Tensorflow預(yù)訓(xùn)練模型轉(zhuǎn)換為Pytorch可以讀取的形式
這個實現(xiàn)可以為BERT加載任何預(yù)訓(xùn)練的TensorFlow checkpoint(特別是谷歌的官方預(yù)訓(xùn)練模型),并提供一個轉(zhuǎn)換腳本茫蛹。
使用說明
下載 Transformers操刀,使用convert_bert_original_tf_checkpoint_to_pytorch.py腳本,你可以在PyTorch保存文件中轉(zhuǎn)換BERT的任何TensorFlow檢查點(尤其是谷歌發(fā)布的官方預(yù)訓(xùn)練模型)婴洼。
這個腳本將TensorFlow checkpoint(以bert_model.ckpt開頭的三個文件)和相關(guān)的配置文件(bert_config.json)作為輸入骨坑,并為此配置創(chuàng)建PyTorch模型,從PyTorch模型的TensorFlow checkpoint加載權(quán)重并保存生成的模型在一個標(biāo)準(zhǔn)PyTorch保存文件中柬采,可以使用 torch.load() 導(dǎo)入(請參閱extract_features.py欢唾,run_classifier.py和run_squad.py中的示例)。
只需要運行一次這個轉(zhuǎn)換腳本粉捻,在原文件夾下就可以得到一個PyTorch模型匈辱。然后,你可以忽略TensorFlow checkpoint(以bert_model.ckpt開頭的三個文件)杀迹,但是一定要保留配置文件(bert_config.json)和詞匯表文件(vocab.txt)亡脸,因為PyTorch模型也需要這些文件押搪。
要運行這個特定的轉(zhuǎn)換腳本,你需要安裝TensorFlow和PyTorch浅碾。該庫的其余部分只需要PyTorch大州。
使用方法
linux 下執(zhí)行或者使用 git 執(zhí)行 sh 指令
下面是一個預(yù)訓(xùn)練的BERT-Base Uncased 模型的轉(zhuǎn)換過程示例:
# export BERT_BASE_DIR = '絕對路徑'
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
# windows下
export BERT_BASE_DIR = F:/program/uncased_L-12_H-768_A-12
python convert_bert_original_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \
--bert_config_file $BERT_BASE_DIR/bert_config.json \
--pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin \
Google的預(yù)訓(xùn)練模型下載地址:https://github.com/google-research/bert#pre-trained-models