代碼文件為bert_lstm_ner.py,下面進(jìn)行逐行解析:
tf.logging.set_verbosity(tf.logging.INFO)#運行代碼時烘浦,將會看到info日志輸出INFO:tensorflow:loss = 1.18812, step = 1INFO:tensorflow:loss = #0.210323, step = 101INFO:tensorflow:loss = 0.109025, step = 201
processors = {
? ? ? ? "ner": NerProcessor
? ? }#定義一個ner:NerProcessor的字典
bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)#將bert參數(shù)傳到bert_config中
if FLAGS.max_seq_length > bert_config.max_position_embeddings:#假如最大總輸入序列長度大于bert最大的wordembedding長度勒魔,報錯
? ? raise ValueError(
? ? ? ? "Cannot use sequence length %d because the BERT model "
? ? ? ? "was only trained up to sequence length %d" %
? ? ? ? (FLAGS.max_seq_length, bert_config.max_position_embeddings))
# 在train 的時候,才刪除上一輪產(chǎn)出的文件,在predicted 的時候不做clean
if FLAGS.clean and FLAGS.do_train:#默認(rèn)是兩個ture
? ? if os.path.exists(FLAGS.output_dir):#假如輸出文件位置存在
? ? ? ? def del_file(path):#設(shè)置個刪文件的函數(shù)
? ? ? ? ? ? ls = os.listdir(path)#listdir函數(shù)返回文件夾中的所有文件名字
? ? ? ? ? ? for i in ls:
? ? ? ? ? ? ? ? c_path = os.path.join(path, i)#os.path.join()函數(shù)用于路徑拼接文件路徑
? ? ? ? ? ? ? ? if os.path.isdir(c_path):#如果該文件存在
? ? ? ? ? ? ? ? ? ? del_file(c_path)#刪除文件
? ? ? ? ? ? ? ? else:
? ? ? ? ? ? ? ? ? ? os.remove(c_path)#刪除文件
? ? ? ? try:
? ? ? ? ? ? del_file(FLAGS.output_dir)#嘗試刪除文件忽冻,否則報錯
? ? ? ? except Exception as e:
? ? ? ? ? ? print(e)
? ? ? ? ? ? print('pleace remove the files of output dir and data.conf')
? ? ? ? ? ? exit(-1)
? ? if os.path.exists(FLAGS.data_config_path):#如果保存數(shù)據(jù)的位置存在
? ? ? ? try:
? ? ? ? ? ? os.remove(FLAGS.data_config_path)#嘗試刪除
? ? ? ? except Exception as e:
? ? ? ? ? ? print(e)
? ? ? ? ? ? print('pleace remove the files of output dir and data.conf')
? ? ? ? ? ? exit(-1)
task_name = FLAGS.task_name.lower()#task_name是要訓(xùn)練的任務(wù)的名稱馅而,值為ner
if task_name not in processors:#如果processor里面沒有ner溉贿,報錯
? ? raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()#返回NerProcessor()函數(shù)
label_list = processor.get_labels()#label_list值為["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X", "[CLS]", "[SEP]"]
tokenizer = tokenization.FullTokenizer(
? ? vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)#輸出函數(shù)(對bert的詞匯文件在進(jìn)行變小寫后進(jìn)行fulltokenizer)
tpu_cluster_resolver = None#不使用tpu集群
if FLAGS.use_tpu and FLAGS.tpu_name:#不考慮
? ? tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
? ? ? ? FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2#如果為PER_HOST_V1或PER_HOST_V2仑嗅,則在每個主機(jī)上調(diào)用一次input_fn。 #使用每核心輸入管道配置痹束,每個核心調(diào)用一次检疫。 具有全局批量大小
run_config = tf.contrib.tpu.RunConfig(#定義tpu函數(shù)
? ? cluster=tpu_cluster_resolver,#false
? ? master=FLAGS.master,#none‘TensorFlow master URL.’
? ? model_dir=FLAGS.output_dir,#輸出位置
? ? save_checkpoints_steps=FLAGS.save_checkpoints_steps,#" 保存模型checkpoint的頻率."為1000
? ? tpu_config=tf.contrib.tpu.TPUConfig(#定義tpu函數(shù)2
? ? ? ? iterations_per_loop=FLAGS.iterations_per_loop,#"在每個評估單元調(diào)用中要執(zhí)行多少步驟."1000
? ? ? ? num_shards=FLAGS.num_tpu_cores,#tpu核數(shù),8
? ? ? ? per_host_input_for_training=is_per_host))#PER_HOST_V2
train_examples = None#none
num_train_steps = None#none
num_warmup_steps = None#none
if os.path.exists(FLAGS.data_config_path):#如果data config 文件祷嘶,保存訓(xùn)練和dev config存在
? ? with codecs.open(FLAGS.data_config_path) as fd:#打開文件路徑
? ? ? ? data_config = json.load(fd)#加載數(shù)據(jù)到data_config中
else:
? ? data_config = {}#否則設(shè)為空
if FLAGS.do_train:
? ? ? ? # 加載訓(xùn)練數(shù)據(jù)
? ? if len(data_config) == 0:#如果為空
? ? ? ? train_examples = processor.get_train_examples(FLAGS.data_dir)#將訓(xùn)練樣本輸入到變量中
? ? ? ? num_train_steps = int(
? ? ? ? ? ? len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)#訓(xùn)練執(zhí)行總批次數(shù)為樣本長度/訓(xùn)練總批次*訓(xùn)練總次數(shù)
? ? ? ? num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)#上面數(shù)值*進(jìn)行線性學(xué)習(xí)率熱身訓(xùn)練的比例屎媳。
? ? ? ? data_config['num_train_steps'] = num_train_steps#數(shù)據(jù)參數(shù)設(shè)定1
? ? ? ? data_config['num_warmup_steps'] = num_warmup_steps#數(shù)據(jù)參數(shù)設(shè)定2
? ? ? ? data_config['num_train_size'] = len(train_examples)#數(shù)據(jù)參數(shù)設(shè)定3(數(shù)據(jù)長度)
? ? else:
? ? ? ? num_train_steps = int(data_config['num_train_steps'])#直接調(diào)用1
? ? ? ? num_warmup_steps = int(data_config['num_warmup_steps'])#直接調(diào)用2
? ? # 返回的model_dn 是一個函數(shù)夺溢,其定義了模型,訓(xùn)練烛谊,評測方法风响,并且使用鉤子參數(shù),加載了BERT模型的參數(shù)進(jìn)行了自己模型的參數(shù)初始化過程
? ? # tf 新的架構(gòu)方法丹禀,通過定義model_fn 函數(shù)状勤,定義模型,然后通過EstimatorAPI進(jìn)行模型的其他工作湃崩,Es就可以控制模型的訓(xùn)練荧降,預(yù)測,評估工作等攒读。
model_fn = model_fn_builder(
? ? bert_config=bert_config,#從bert文件中獲得
? ? num_labels=len(label_list) + 1,#標(biāo)簽數(shù)量
? ? init_checkpoint=FLAGS.init_checkpoint,#r'D:\bert\chinese_L-12_H-768_A-12\bert_model.ckpt?"初始檢查點(通常來自預(yù)先訓(xùn)練的bert模型)."
? ? learning_rate=FLAGS.learning_rate,#學(xué)習(xí)率?5e-5,
? ? num_train_steps=num_train_steps,#總批次
? ? num_warmup_steps=num_warmup_steps,#warmup數(shù)
#warmup就是先采用小的學(xué)習(xí)率(0.01)進(jìn)行訓(xùn)練,訓(xùn)練了400iterations之后將學(xué)習(xí)率調(diào)整至0.1開始正式訓(xùn)練
? ? use_tpu=FLAGS.use_tpu,#none
? ? use_one_hot_embeddings=FLAGS.use_tpu)#none
print(model_fn)
estimator = tf.contrib.tpu.TPUEstimator(#定義評估器
? ? use_tpu=FLAGS.use_tpu,#none
? ? model_fn=model_fn,#將上面定義的model加入
? ? config=run_config,#將上面定義的runconfig參數(shù)加入
? ? train_batch_size=FLAGS.train_batch_size,#訓(xùn)練批次 64
? ? eval_batch_size=FLAGS.eval_batch_size,#評估批次 8
? ? predict_batch_size=FLAGS.predict_batch_size)# 預(yù)測批次 8
train_file =r'C:\Users\dell\Desktop\Name-Entity-Recognition-master\BERT-BiLSTM-CRF-NER\train.tf_record'
filed_based_convert_examples_to_features(
? ? train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file)#將數(shù)據(jù)轉(zhuǎn)化為TF_Record 結(jié)構(gòu)辛友,作為模型數(shù)據(jù)輸入:樣本薄扁,標(biāo)簽,最#大長度废累,tokenizer邓梅,數(shù)據(jù)
num_train_size = num_train_size = int(data_config['num_train_size'])
tf.logging.info("***** Running training *****")
tf.logging.info("? Num examples = %d", num_train_size)#20864
tf.logging.info("? Batch size = %d", FLAGS.train_batch_size)#64
tf.logging.info("? Num steps = %d", num_train_steps)#978
train_input_fn = file_based_input_fn_builder(
? ? input_file=train_file,#訓(xùn)練文件
? ? seq_length=FLAGS.max_seq_length,#最大序列長度 128
? ? is_training=True,#確定訓(xùn)練
? ? drop_remainder=True)#沒查到。邑滨。日缨。
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)#進(jìn)行訓(xùn)練
if FLAGS.do_eval:#進(jìn)行評估
? ? if data_config.get('eval.tf_record_path', '') == '':#如果字典中沒有評估路徑
? ? ? ? eval_examples = processor.get_dev_examples(FLAGS.data_dir)#讀到data_dir的dev文件
? ? ? ? eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")#獲得輸出位置的eval.tf_record文件
? ? ? ? filed_based_convert_examples_to_features(
? ? ? ? ? ? eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)#將評估文件轉(zhuǎn)換
? ? ? ? data_config['eval.tf_record_path'] = eval_file#將評估文件加入數(shù)據(jù)
? ? ? ? data_config['num_eval_size'] = len(eval_examples)#將評估文件長度加入數(shù)據(jù)
? ? else:
? ? ? ? eval_file = data_config['eval.tf_record_path']#將評估數(shù)據(jù)文件讀出
? ? ? ? # 打印驗證集數(shù)據(jù)信息
? ? num_eval_size = data_config.get('num_eval_size', 0)#將評估文件長度讀出
? ? tf.logging.info("***** Running evaluation *****")
? ? tf.logging.info("? Num examples = %d", num_eval_size)#2318
? ? tf.logging.info("? Batch size = %d", FLAGS.eval_batch_size)#8
? ? eval_steps = None
? ? if FLAGS.use_tpu:#none
? ? ? ? eval_steps = int(num_eval_size / FLAGS.eval_batch_size)#不管
? ? eval_drop_remainder = True if FLAGS.use_tpu else False#false
? ? eval_input_fn = file_based_input_fn_builder(
? ? ? ? input_file=eval_file,#評估文件
? ? ? ? seq_length=FLAGS.max_seq_length,#最大序列長度
? ? ? ? is_training=False,#不訓(xùn)練
? ? ? ? drop_remainder=eval_drop_remainder)#none
? ? result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)#step=none(這里報錯)
? ? output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")#輸出文件
? ? with codecs.open(output_eval_file, "w", encoding='utf-8') as writer:
? ? ? ? tf.logging.info("***** Eval results *****")
? ? ? ? for key in sorted(result.keys()):
? ? ? ? ? ? tf.logging.info("? %s = %s", key, str(result[key]))#報出文件
? ? ? ? ? ? writer.write("%s = %s\n" % (key, str(result[key])))#寫入文件
# 保存數(shù)據(jù)的配置文件,避免在以后的訓(xùn)練過程中多次讀取訓(xùn)練以及測試數(shù)據(jù)集掖看,消耗時間
if not os.path.exists(FLAGS.data_config_path):
? ? with codecs.open(FLAGS.data_config_path, 'a', encoding='utf-8') as fd:
? ? ? ? json.dump(data_config, fd)#把a作為data_config_path存入data_config
if FLAGS.do_predict:#開始預(yù)測
? ? token_path = os.path.join(FLAGS.output_dir, "token_test.txt")#導(dǎo)入測試集輸出位置
? ? if os.path.exists(token_path):#如果測試集存在
? ? ? ? os.remove(token_path)#刪了
? ? with codecs.open(os.path.join(FLAGS.output_dir, 'label2id.pkl'), 'rb') as rf:#打開label2id的文件
? ? ? ? label2id = pickle.load(rf)
? ? ? ? id2label = {value: key for key, value in label2id.items()}#轉(zhuǎn)成字典
? ? predict_examples = processor.get_test_examples(FLAGS.data_dir)#得到test文件
? ? predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")#得到預(yù)測的tf_record文件
? ? filed_based_convert_examples_to_features(predict_examples, label_list,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? FLAGS.max_seq_length, tokenizer,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? predict_file, mode="test")#建立測試的tf_record文件
? ? tf.logging.info("***** Running prediction*****")
? ? tf.logging.info("? Num examples = %d", len(predict_examples))#4636
? ? tf.logging.info("? Batch size = %d", FLAGS.predict_batch_size)#8
? ? if FLAGS.use_tpu:
? ? ? ? ? ? # Warning: According to tpu_estimator.py Prediction on TPU is an
? ? ? ? ? ? # experimental feature and hence not supported here
? ? ? ? ? ? raise ValueError("Prediction in TPU not supported")
? ? predict_drop_remainder = True if FLAGS.use_tpu else False#false
? ? predict_input_fn = file_based_input_fn_builder(
? ? ? ? input_file=predict_file,#輸入文件
? ? ? ? seq_length=FLAGS.max_seq_length,#最大序列
? ? ? ? is_training=False,#不訓(xùn)練
? ? ? ? drop_remainder=predict_drop_remainder)#none
? ? predicted_result = estimator.evaluate(input_fn=predict_input_fn)#報錯匣距。。哎壳。
? ? output_eval_file = os.path.join(FLAGS.output_dir, "predicted_results.txt")#輸出預(yù)測結(jié)果
? ? with codecs.open(output_eval_file, "w", encoding='utf-8') as writer:
? ? ? ? tf.logging.info("***** Predict results *****")
? ? ? ? for key in sorted(predicted_result.keys()):
? ? ? ? ? ? tf.logging.info("? %s = %s", key, str(predicted_result[key]))
? ? ? ? ? ? writer.write("%s = %s\n" % (key, str(predicted_result[key])))#寫入文件
? ? result = estimator.predict(input_fn=predict_input_fn)#預(yù)測
? ? output_predict_file = os.path.join(FLAGS.output_dir, "label_test.txt")#輸出文件
? ? def result_to_pair(writer):#這里是寫入函數(shù)
? ? ? ? for predict_line, prediction in zip(predict_examples, result):
? ? ? ? ? ? idx = 0
? ? ? ? ? ? line = ''
? ? ? ? ? ? line_token = str(predict_line.text).split(' ')
? ? ? ? ? ? label_token = str(predict_line.label).split(' ')
? ? ? ? ? ? if len(line_token) != len(label_token):
? ? ? ? ? ? ? ? tf.logging.info(predict_line.text)
? ? ? ? ? ? ? ? tf.logging.info(predict_line.label)
? ? ? ? ? ? for id in prediction:
? ? ? ? ? ? ? ? if id == 0:
? ? ? ? ? ? ? ? ? ? continue
? ? ? ? ? ? ? ? curr_labels = id2label[id]
? ? ? ? ? ? ? ? if curr_labels in ['[CLS]', '[SEP]']:
? ? ? ? ? ? ? ? ? ? continue
? ? ? ? ? ? ? ? ? ? # 不知道為什么毅待,這里會出現(xiàn)idx out of range 的錯誤。归榕。尸红。do not know why here cache list out of range exception!
? ? ? ? ? ? ? ? try:
? ? ? ? ? ? ? ? ? ? line += line_token[idx] + ' ' + label_token[idx] + ' ' + curr_labels + '\n'
? ? ? ? ? ? ? ? except Exception as e:
? ? ? ? ? ? ? ? ? ? tf.logging.info(e)
? ? ? ? ? ? ? ? ? ? tf.logging.info(predict_line.text)
? ? ? ? ? ? ? ? ? ? tf.logging.info(predict_line.label)
? ? ? ? ? ? ? ? ? ? line = ''
? ? ? ? ? ? ? ? ? ? break
? ? ? ? ? ? ? ? idx += 1
? ? ? ? ? ? writer.write(line + '\n')
? ? with codecs.open(output_predict_file, 'w', encoding='utf-8') as writer:
? ? ? ? result_to_pair(writer)#寫入文件
? ? from conlleval import return_report
? ? eval_result = return_report(output_predict_file)#百度找不到,猜測是得到評估結(jié)果的函數(shù)
? ? print(eval_result)