2019-01-16 解析bert代碼

代碼文件為bert_lstm_ner.py,下面進(jìn)行逐行解析:

tf.logging.set_verbosity(tf.logging.INFO)#運行代碼時烘浦,將會看到info日志輸出INFO:tensorflow:loss = 1.18812, step = 1INFO:tensorflow:loss = #0.210323, step = 101INFO:tensorflow:loss = 0.109025, step = 201

processors = {

? ? ? ? "ner": NerProcessor

? ? }#定義一個ner:NerProcessor的字典

bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)#將bert參數(shù)傳到bert_config中

if FLAGS.max_seq_length > bert_config.max_position_embeddings:#假如最大總輸入序列長度大于bert最大的wordembedding長度勒魔,報錯

? ? raise ValueError(

? ? ? ? "Cannot use sequence length %d because the BERT model "

? ? ? ? "was only trained up to sequence length %d" %

? ? ? ? (FLAGS.max_seq_length, bert_config.max_position_embeddings))

# 在train 的時候,才刪除上一輪產(chǎn)出的文件,在predicted 的時候不做clean

if FLAGS.clean and FLAGS.do_train:#默認(rèn)是兩個ture

? ? if os.path.exists(FLAGS.output_dir):#假如輸出文件位置存在

? ? ? ? def del_file(path):#設(shè)置個刪文件的函數(shù)

? ? ? ? ? ? ls = os.listdir(path)#listdir函數(shù)返回文件夾中的所有文件名字

? ? ? ? ? ? for i in ls:

? ? ? ? ? ? ? ? c_path = os.path.join(path, i)#os.path.join()函數(shù)用于路徑拼接文件路徑

? ? ? ? ? ? ? ? if os.path.isdir(c_path):#如果該文件存在

? ? ? ? ? ? ? ? ? ? del_file(c_path)#刪除文件

? ? ? ? ? ? ? ? else:

? ? ? ? ? ? ? ? ? ? os.remove(c_path)#刪除文件

? ? ? ? try:

? ? ? ? ? ? del_file(FLAGS.output_dir)#嘗試刪除文件忽冻,否則報錯

? ? ? ? except Exception as e:

? ? ? ? ? ? print(e)

? ? ? ? ? ? print('pleace remove the files of output dir and data.conf')

? ? ? ? ? ? exit(-1)

? ? if os.path.exists(FLAGS.data_config_path):#如果保存數(shù)據(jù)的位置存在

? ? ? ? try:

? ? ? ? ? ? os.remove(FLAGS.data_config_path)#嘗試刪除

? ? ? ? except Exception as e:

? ? ? ? ? ? print(e)

? ? ? ? ? ? print('pleace remove the files of output dir and data.conf')

? ? ? ? ? ? exit(-1)

task_name = FLAGS.task_name.lower()#task_name是要訓(xùn)練的任務(wù)的名稱馅而,值為ner

if task_name not in processors:#如果processor里面沒有ner溉贿,報錯

? ? raise ValueError("Task not found: %s" % (task_name))

processor = processors[task_name]()#返回NerProcessor()函數(shù)

label_list = processor.get_labels()#label_list值為["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"]

tokenizer = tokenization.FullTokenizer(

? ? vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)#輸出函數(shù)(對bert的詞匯文件在進(jìn)行變小寫后進(jìn)行fulltokenizer)

tpu_cluster_resolver = None#不使用tpu集群

if FLAGS.use_tpu and FLAGS.tpu_name:#不考慮

? ? tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(

? ? ? ? FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2#如果為PER_HOST_V1或PER_HOST_V2仑嗅,則在每個主機(jī)上調(diào)用一次input_fn。 #使用每核心輸入管道配置痹束,每個核心調(diào)用一次检疫。 具有全局批量大小

run_config = tf.contrib.tpu.RunConfig(#定義tpu函數(shù)

? ? cluster=tpu_cluster_resolver,#false

? ? master=FLAGS.master,#none‘TensorFlow master URL.’

? ? model_dir=FLAGS.output_dir,#輸出位置

? ? save_checkpoints_steps=FLAGS.save_checkpoints_steps,#" 保存模型checkpoint的頻率."為1000

? ? tpu_config=tf.contrib.tpu.TPUConfig(#定義tpu函數(shù)2

? ? ? ? iterations_per_loop=FLAGS.iterations_per_loop,#"在每個評估單元調(diào)用中要執(zhí)行多少步驟."1000

? ? ? ? num_shards=FLAGS.num_tpu_cores,#tpu核數(shù),8

? ? ? ? per_host_input_for_training=is_per_host))#PER_HOST_V2

train_examples = None#none

num_train_steps = None#none

num_warmup_steps = None#none

if os.path.exists(FLAGS.data_config_path):#如果data config 文件祷嘶,保存訓(xùn)練和dev config存在

? ? with codecs.open(FLAGS.data_config_path) as fd:#打開文件路徑

? ? ? ? data_config = json.load(fd)#加載數(shù)據(jù)到data_config中

else:

? ? data_config = {}#否則設(shè)為空

if FLAGS.do_train:

? ? ? ? # 加載訓(xùn)練數(shù)據(jù)

? ? if len(data_config) == 0:#如果為空

? ? ? ? train_examples = processor.get_train_examples(FLAGS.data_dir)#將訓(xùn)練樣本輸入到變量中

? ? ? ? num_train_steps = int(

? ? ? ? ? ? len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)#訓(xùn)練執(zhí)行總批次數(shù)為樣本長度/訓(xùn)練總批次*訓(xùn)練總次數(shù)

? ? ? ? num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)#上面數(shù)值*進(jìn)行線性學(xué)習(xí)率熱身訓(xùn)練的比例屎媳。

? ? ? ? data_config['num_train_steps'] = num_train_steps#數(shù)據(jù)參數(shù)設(shè)定1

? ? ? ? data_config['num_warmup_steps'] = num_warmup_steps#數(shù)據(jù)參數(shù)設(shè)定2

? ? ? ? data_config['num_train_size'] = len(train_examples)#數(shù)據(jù)參數(shù)設(shè)定3(數(shù)據(jù)長度)

? ? else:

? ? ? ? num_train_steps = int(data_config['num_train_steps'])#直接調(diào)用1

? ? ? ? num_warmup_steps = int(data_config['num_warmup_steps'])#直接調(diào)用2

? ? # 返回的model_dn 是一個函數(shù)夺溢,其定義了模型,訓(xùn)練烛谊,評測方法风响,并且使用鉤子參數(shù),加載了BERT模型的參數(shù)進(jìn)行了自己模型的參數(shù)初始化過程

? ? # tf 新的架構(gòu)方法丹禀,通過定義model_fn 函數(shù)状勤,定義模型,然后通過EstimatorAPI進(jìn)行模型的其他工作湃崩,Es就可以控制模型的訓(xùn)練荧降,預(yù)測,評估工作等攒读。

model_fn = model_fn_builder(

? ? bert_config=bert_config,#從bert文件中獲得

? ? num_labels=len(label_list) + 1,#標(biāo)簽數(shù)量

? ? init_checkpoint=FLAGS.init_checkpoint,#r'D:\bert\chinese_L-12_H-768_A-12\bert_model.ckpt?"初始檢查點(通常來自預(yù)先訓(xùn)練的bert模型)."

? ? learning_rate=FLAGS.learning_rate,#學(xué)習(xí)率?5e-5,

? ? num_train_steps=num_train_steps,#總批次

? ? num_warmup_steps=num_warmup_steps,#warmup數(shù)

#warmup就是先采用小的學(xué)習(xí)率(0.01)進(jìn)行訓(xùn)練,訓(xùn)練了400iterations之后將學(xué)習(xí)率調(diào)整至0.1開始正式訓(xùn)練

? ? use_tpu=FLAGS.use_tpu,#none

? ? use_one_hot_embeddings=FLAGS.use_tpu)#none

print(model_fn)

estimator = tf.contrib.tpu.TPUEstimator(#定義評估器

? ? use_tpu=FLAGS.use_tpu,#none

? ? model_fn=model_fn,#將上面定義的model加入

? ? config=run_config,#將上面定義的runconfig參數(shù)加入

? ? train_batch_size=FLAGS.train_batch_size,#訓(xùn)練批次 64

? ? eval_batch_size=FLAGS.eval_batch_size,#評估批次 8

? ? predict_batch_size=FLAGS.predict_batch_size)# 預(yù)測批次 8

train_file =r'C:\Users\dell\Desktop\Name-Entity-Recognition-master\BERT-BiLSTM-CRF-NER\train.tf_record'

filed_based_convert_examples_to_features(

? ? train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)#將數(shù)據(jù)轉(zhuǎn)化為TF_Record 結(jié)構(gòu)辛友,作為模型數(shù)據(jù)輸入:樣本薄扁,標(biāo)簽,最#大長度废累,tokenizer邓梅,數(shù)據(jù)

num_train_size = num_train_size = int(data_config['num_train_size'])

tf.logging.info("***** Running training *****")

tf.logging.info("? Num examples = %d", num_train_size)#20864

tf.logging.info("? Batch size = %d", FLAGS.train_batch_size)#64

tf.logging.info("? Num steps = %d", num_train_steps)#978

train_input_fn = file_based_input_fn_builder(

? ? input_file=train_file,#訓(xùn)練文件

? ? seq_length=FLAGS.max_seq_length,#最大序列長度 128

? ? is_training=True,#確定訓(xùn)練

? ? drop_remainder=True)#沒查到。邑滨。日缨。

estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)#進(jìn)行訓(xùn)練

if FLAGS.do_eval:#進(jìn)行評估

? ? if data_config.get('eval.tf_record_path', '') == '':#如果字典中沒有評估路徑

? ? ? ? eval_examples = processor.get_dev_examples(FLAGS.data_dir)#讀到data_dir的dev文件

? ? ? ? eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")#獲得輸出位置的eval.tf_record文件

? ? ? ? filed_based_convert_examples_to_features(

? ? ? ? ? ? eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)#將評估文件轉(zhuǎn)換

? ? ? ? data_config['eval.tf_record_path'] = eval_file#將評估文件加入數(shù)據(jù)

? ? ? ? data_config['num_eval_size'] = len(eval_examples)#將評估文件長度加入數(shù)據(jù)

? ? else:

? ? ? ? eval_file = data_config['eval.tf_record_path']#將評估數(shù)據(jù)文件讀出

? ? ? ? # 打印驗證集數(shù)據(jù)信息

? ? num_eval_size = data_config.get('num_eval_size', 0)#將評估文件長度讀出

? ? tf.logging.info("***** Running evaluation *****")

? ? tf.logging.info("? Num examples = %d", num_eval_size)#2318

? ? tf.logging.info("? Batch size = %d", FLAGS.eval_batch_size)#8

? ? eval_steps = None

? ? if FLAGS.use_tpu:#none

? ? ? ? eval_steps = int(num_eval_size / FLAGS.eval_batch_size)#不管

? ? eval_drop_remainder = True if FLAGS.use_tpu else False#false

? ? eval_input_fn = file_based_input_fn_builder(

? ? ? ? input_file=eval_file,#評估文件

? ? ? ? seq_length=FLAGS.max_seq_length,#最大序列長度

? ? ? ? is_training=False,#不訓(xùn)練

? ? ? ? drop_remainder=eval_drop_remainder)#none

? ? result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)#step=none(這里報錯)

? ? output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")#輸出文件

? ? with codecs.open(output_eval_file, "w", encoding='utf-8') as writer:

? ? ? ? tf.logging.info("***** Eval results *****")

? ? ? ? for key in sorted(result.keys()):

? ? ? ? ? ? tf.logging.info("? %s = %s", key, str(result[key]))#報出文件

? ? ? ? ? ? writer.write("%s = %s\n" % (key, str(result[key])))#寫入文件

# 保存數(shù)據(jù)的配置文件,避免在以后的訓(xùn)練過程中多次讀取訓(xùn)練以及測試數(shù)據(jù)集掖看,消耗時間

if not os.path.exists(FLAGS.data_config_path):

? ? with codecs.open(FLAGS.data_config_path, 'a', encoding='utf-8') as fd:

? ? ? ? json.dump(data_config, fd)#把a作為data_config_path存入data_config

if FLAGS.do_predict:#開始預(yù)測

? ? token_path = os.path.join(FLAGS.output_dir, "token_test.txt")#導(dǎo)入測試集輸出位置

? ? if os.path.exists(token_path):#如果測試集存在

? ? ? ? os.remove(token_path)#刪了

? ? with codecs.open(os.path.join(FLAGS.output_dir, 'label2id.pkl'), 'rb') as rf:#打開label2id的文件

? ? ? ? label2id = pickle.load(rf)

? ? ? ? id2label = {value: key for key, value in label2id.items()}#轉(zhuǎn)成字典

? ? predict_examples = processor.get_test_examples(FLAGS.data_dir)#得到test文件

? ? predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")#得到預(yù)測的tf_record文件

? ? filed_based_convert_examples_to_features(predict_examples, label_list,

? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? FLAGS.max_seq_length, tokenizer,

? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? predict_file, mode="test")#建立測試的tf_record文件

? ? tf.logging.info("***** Running prediction*****")

? ? tf.logging.info("? Num examples = %d", len(predict_examples))#4636

? ? tf.logging.info("? Batch size = %d", FLAGS.predict_batch_size)#8

? ? if FLAGS.use_tpu:

? ? ? ? ? ? # Warning: According to tpu_estimator.py Prediction on TPU is an

? ? ? ? ? ? # experimental feature and hence not supported here

? ? ? ? ? ? raise ValueError("Prediction in TPU not supported")

? ? predict_drop_remainder = True if FLAGS.use_tpu else False#false

? ? predict_input_fn = file_based_input_fn_builder(

? ? ? ? input_file=predict_file,#輸入文件

? ? ? ? seq_length=FLAGS.max_seq_length,#最大序列

? ? ? ? is_training=False,#不訓(xùn)練

? ? ? ? drop_remainder=predict_drop_remainder)#none

? ? predicted_result = estimator.evaluate(input_fn=predict_input_fn)#報錯匣距。。哎壳。

? ? output_eval_file = os.path.join(FLAGS.output_dir, "predicted_results.txt")#輸出預(yù)測結(jié)果

? ? with codecs.open(output_eval_file, "w", encoding='utf-8') as writer:

? ? ? ? tf.logging.info("***** Predict results *****")

? ? ? ? for key in sorted(predicted_result.keys()):

? ? ? ? ? ? tf.logging.info("? %s = %s", key, str(predicted_result[key]))

? ? ? ? ? ? writer.write("%s = %s\n" % (key, str(predicted_result[key])))#寫入文件

? ? result = estimator.predict(input_fn=predict_input_fn)#預(yù)測

? ? output_predict_file = os.path.join(FLAGS.output_dir, "label_test.txt")#輸出文件

? ? def result_to_pair(writer):#這里是寫入函數(shù)

? ? ? ? for predict_line, prediction in zip(predict_examples, result):

? ? ? ? ? ? idx = 0

? ? ? ? ? ? line = ''

? ? ? ? ? ? line_token = str(predict_line.text).split(' ')

? ? ? ? ? ? label_token = str(predict_line.label).split(' ')

? ? ? ? ? ? if len(line_token) != len(label_token):

? ? ? ? ? ? ? ? tf.logging.info(predict_line.text)

? ? ? ? ? ? ? ? tf.logging.info(predict_line.label)

? ? ? ? ? ? for id in prediction:

? ? ? ? ? ? ? ? if id == 0:

? ? ? ? ? ? ? ? ? ? continue

? ? ? ? ? ? ? ? curr_labels = id2label[id]

? ? ? ? ? ? ? ? if curr_labels in ['[CLS]', '[SEP]']:

? ? ? ? ? ? ? ? ? ? continue

? ? ? ? ? ? ? ? ? ? # 不知道為什么毅待,這里會出現(xiàn)idx out of range 的錯誤。归榕。尸红。do not know why here cache list out of range exception!

? ? ? ? ? ? ? ? try:

? ? ? ? ? ? ? ? ? ? line += line_token[idx] + ' ' + label_token[idx] + ' ' + curr_labels + '\n'

? ? ? ? ? ? ? ? except Exception as e:

? ? ? ? ? ? ? ? ? ? tf.logging.info(e)

? ? ? ? ? ? ? ? ? ? tf.logging.info(predict_line.text)

? ? ? ? ? ? ? ? ? ? tf.logging.info(predict_line.label)

? ? ? ? ? ? ? ? ? ? line = ''

? ? ? ? ? ? ? ? ? ? break

? ? ? ? ? ? ? ? idx += 1

? ? ? ? ? ? writer.write(line + '\n')

? ? with codecs.open(output_predict_file, 'w', encoding='utf-8') as writer:

? ? ? ? result_to_pair(writer)#寫入文件

? ? from conlleval import return_report

? ? eval_result = return_report(output_predict_file)#百度找不到,猜測是得到評估結(jié)果的函數(shù)

? ? print(eval_result)

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末刹泄,一起剝皮案震驚了整個濱河市外里,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌特石,老刑警劉巖盅蝗,帶你破解...
    沈念sama閱讀 221,548評論 6 515
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異县匠,居然都是意外死亡风科,警方通過查閱死者的電腦和手機(jī)撒轮,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,497評論 3 399
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來贼穆,“玉大人题山,你說我怎么就攤上這事」嗜” “怎么了顶瞳?”我有些...
    開封第一講書人閱讀 167,990評論 0 360
  • 文/不壞的土叔 我叫張陵,是天一觀的道長愕秫。 經(jīng)常有香客問我慨菱,道長,這世上最難降的妖魔是什么戴甩? 我笑而不...
    開封第一講書人閱讀 59,618評論 1 296
  • 正文 為了忘掉前任符喝,我火速辦了婚禮,結(jié)果婚禮上甜孤,老公的妹妹穿的比我還像新娘协饲。我一直安慰自己,他們只是感情好缴川,可當(dāng)我...
    茶點故事閱讀 68,618評論 6 397
  • 文/花漫 我一把揭開白布茉稠。 她就那樣靜靜地躺著,像睡著了一般把夸。 火紅的嫁衣襯著肌膚如雪而线。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 52,246評論 1 308
  • 那天恋日,我揣著相機(jī)與錄音膀篮,去河邊找鬼。 笑死谚鄙,一個胖子當(dāng)著我的面吹牛各拷,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播闷营,決...
    沈念sama閱讀 40,819評論 3 421
  • 文/蒼蘭香墨 我猛地睜開眼烤黍,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了傻盟?” 一聲冷哼從身側(cè)響起速蕊,我...
    開封第一講書人閱讀 39,725評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎娘赴,沒想到半個月后规哲,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 46,268評論 1 320
  • 正文 獨居荒郊野嶺守林人離奇死亡诽表,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 38,356評論 3 340
  • 正文 我和宋清朗相戀三年唉锌,在試婚紗的時候發(fā)現(xiàn)自己被綠了隅肥。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 40,488評論 1 352
  • 序言:一個原本活蹦亂跳的男人離奇死亡袄简,死狀恐怖腥放,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情绿语,我是刑警寧澤秃症,帶...
    沈念sama閱讀 36,181評論 5 350
  • 正文 年R本政府宣布,位于F島的核電站吕粹,受9級特大地震影響种柑,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜匹耕,卻給世界環(huán)境...
    茶點故事閱讀 41,862評論 3 333
  • 文/蒙蒙 一聚请、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧泌神,春花似錦良漱、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,331評論 0 24
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽矾兜。三九已至损趋,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間椅寺,已是汗流浹背浑槽。 一陣腳步聲響...
    開封第一講書人閱讀 33,445評論 1 272
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點兒被人妖公主榨干…… 1. 我叫王不留返帕,地道東北人桐玻。 一個月前我還...
    沈念sama閱讀 48,897評論 3 376
  • 正文 我出身青樓,卻偏偏與公主長得像荆萤,于是被迫代替她去往敵國和親镊靴。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 45,500評論 2 359