繼續(xù)前面的三篇文章Part-1之拨,Part-2茉继,Part-3,這一篇我們來完善存儲和恢復(fù)機(jī)制蚀乔。
把計算圖保存到文件save_graph_to_file
下面是增加的代碼烁竭,先不要運行,稍后一起測試:
#將圖保存到文件,必要時創(chuàng)建允許的量子化
def save_graph_to_file(graph, graph_file_name, module_spec, class_count):
sess, _, _, _, _, _ = build_eval_session(module_spec, class_count)
graph = sess.graph
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), ['final_tensor_name'])
with tf.gfile.FastGFile(graph_file_name, 'wb') as f:
f.write(output_graph_def.SerializeToString())
保存評估模型export_model
注意每次使用前必須把舊的saved_model文件夾刪除或改名吉挣。
#導(dǎo)出評估eval圖的模型pd文件用于提供服務(wù)
saved_model_dir=os.path.join(dir_path,'saved_model'+str(datetime.now()))
def export_model(module_spec, class_count):
sess, in_image, _, _, _, _ = build_eval_session(module_spec, class_count)
graph = sess.graph
with graph.as_default():
#輸入輸出點
inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)}
out_classes = sess.graph.get_tensor_by_name('final_tensor_name:0')
outputs = {
'prediction': tf.saved_model.utils.build_tensor_info(out_classes)
}
#創(chuàng)建簽名
signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs=inputs,
outputs=outputs,
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
#初始化
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
#保存saved_model
builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:signature
},
legacy_init_op=legacy_init_op)
builder.save()
改進(jìn)最終再訓(xùn)練函數(shù)
最后我們把run_final_retrain函數(shù)改進(jìn)一下派撕,增加評估eval和保存、導(dǎo)出功能睬魂。
注意每次使用前必須把舊的saved_model文件夾刪除或改名终吼。
以下是修改后的代碼(這里參數(shù)train_steps=10,所以得到的模型精度也非常糟糕氯哮。如果您的計算機(jī)允許际跪,官方默認(rèn)是4000,請量力而為):
#保存概要和checkpoint路徑設(shè)置
CHECKPOINT_NAME = os.path.join(dir_path,'checkpoints/retrain')
summaries_dir=os.path.join(dir_path,'summaries/train')
ensure_dir_exists(os.path.join(dir_path,'output'))
saved_model_path=os.path.join(dir_path,'output/out_graph.pd')
output_label_path=os.path.join(dir_path,'output/labels.txt')
#執(zhí)行訓(xùn)練兵保存checkpoint的函數(shù)
def run_final_retrain(train_steps=10,
eval_step_interval=5,
do_distort=True):
module_spec = hub.load_module_spec(HUB_MODULE)
#創(chuàng)建圖并獲取相關(guān)的張量入口
graph, bottleneck_tensor, resized_image_tensor, wants_quantization = (
create_module_graph(module_spec))
with graph.as_default():
#添加訓(xùn)練相關(guān)的張量和操作節(jié)點入口
(train_step, cross_entropy, bottleneck_input,ground_truth_input,
final_tensor) = add_final_retrain_ops(5, 'final_tensor_name',
bottleneck_tensor,wants_quantization,True)
with tf.Session(graph=graph) as sess:
init = tf.global_variables_initializer()
sess.run(init)
#添加圖片解碼相關(guān)的張量入口操作
jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding(module_spec)
#讀取圖片的bottleneck數(shù)據(jù)
if do_distort:
distorted_jpeg_data_tensor,distorted_image_tensor=add_input_distortions(module_spec,True,50,50,50)
else:
cache_bottlenecks(sess,
jpeg_data_tensor,decoded_image_tensor,
resized_image_tensor,bottleneck_tensor)
#創(chuàng)建評估新層精度的操作
evaluation_step, _ = add_evaluation_step(final_tensor, ground_truth_input)
#記錄概要信息與保存
train_saver = tf.train.Saver()
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(summaries_dir+'/retrain',sess.graph)
validation_writer = tf.summary.FileWriter(summaries_dir + '/validation')
#開始運作!
for i in range(train_steps):
#獲取圖片bottleneck數(shù)據(jù)
if do_distort:
(train_bottlenecks,train_ground_truth) = get_random_distorted_bottlenecks(
sess,BATCH_SIZE, 'training',
distorted_jpeg_data_tensor,distorted_image_tensor,
resized_image_tensor, bottleneck_tensor)
else:
(train_bottlenecks,train_ground_truth, _) = get_random_cached_bottlenecks(
sess, BATCH_SIZE, 'training',
jpeg_data_tensor,decoded_image_tensor,
resized_image_tensor, bottleneck_tensor)
#啟動訓(xùn)練
train_summary, _ = sess.run(
[merged, train_step],
feed_dict={bottleneck_input: train_bottlenecks,
ground_truth_input: train_ground_truth})
train_writer.add_summary(train_summary, i)
#間隔性啟動評估
is_last_step = (i + 1 == train_steps)
if (i % eval_step_interval) == 0 or is_last_step:
train_accuracy, cross_entropy_value = sess.run(
[evaluation_step, cross_entropy],
feed_dict={bottleneck_input: train_bottlenecks,
ground_truth_input: train_ground_truth})
tf.logging.info('%s: Step %d: Train accuracy = %.1f%%' %(datetime.now(), i, train_accuracy * 100))
tf.logging.info('%s: Step %d: Cross entropy = %f' %(datetime.now(), i, cross_entropy_value))
#使用不同的bottleneck數(shù)據(jù)進(jìn)行評估
validation_bottlenecks, validation_ground_truth, _ = (
get_random_cached_bottlenecks(
sess, 10, 'validation',
jpeg_data_tensor,decoded_image_tensor,
resized_image_tensor, bottleneck_tensor))
#啟動評估姆打!
validation_summary, validation_accuracy = sess.run(
[merged, evaluation_step],
feed_dict={bottleneck_input: validation_bottlenecks,
ground_truth_input: validation_ground_truth})
validation_writer.add_summary(validation_summary, i)
tf.logging.info('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' %(datetime.now(), i, validation_accuracy * 100,len(validation_bottlenecks)))
#間隔保存中介媒體文件良姆,為訓(xùn)練保存checkpoint
if (i % eval_step_interval == 0 and i > 0):
train_saver.save(sess, CHECKPOINT_NAME)
intermediate_file_name = (os.path.join(dir_path + 'intermediate') + str(i) + '.pb')
tf.logging.info('Save intermediate result to : '+intermediate_file_name)
save_graph_to_file(graph, intermediate_file_name, module_spec,5)
#保存模型
train_saver.save(sess, CHECKPOINT_NAME)
#執(zhí)行最終評估
run_final_eval(sess, module_spec, 5,
jpeg_data_tensor, decoded_image_tensor,
resized_image_tensor,bottleneck_tensor)
tf.logging.info('Save final result to : ' + saved_model_path)
if wants_quantization:
tf.logging.info('The model is instrumented for quantization with TF-Lite')
save_graph_to_file(graph, saved_model_path, module_spec, 5)
with tf.gfile.FastGFile(output_label_path, 'w') as f:
f.write('\n'.join(image_lists.keys()) + '\n')
export_model(module_spec, 5)
案例小結(jié)
這個案例來自Tensorflow官方教程,之前兩個相對都比較簡單幔戏,代碼量只有100行左右玛追,這個案例官方原代碼突然有1300行之多,大有才學(xué)了十以內(nèi)加減法然后就講微積分方程的感覺闲延。
這里整個案例去掉了很多官方代碼中我認(rèn)為無關(guān)緊要的部分痊剖,仍然有600多行,如果有時間我還會在整理這個案例垒玲,希望能只保留關(guān)鍵流程代碼陆馁,兩三百行不能再多了。
已經(jīng)讀到這里的用戶實屬難得侍匙,如果遇到困難氮惯,請從百度網(wǎng)盤下載(密碼:lzjg)直接下載final.py文件使用。請注意文件讀寫權(quán)限想暗,每次運行前請刪除saved_model文件夾。
探索人工智能的新邊界
如果您發(fā)現(xiàn)文章錯誤帘不,請不吝留言指正说莫;
如果您覺得有用,請點喜歡寞焙;
如果您覺得很有用储狭,感謝轉(zhuǎn)發(fā)~
END