Tensorflow-花分類-圖像再訓(xùn)練-part-4-整理翻譯

繼續(xù)前面的三篇文章Part-1之拨,Part-2茉继,Part-3,這一篇我們來完善存儲和恢復(fù)機(jī)制蚀乔。


把計算圖保存到文件save_graph_to_file

下面是增加的代碼烁竭,先不要運行,稍后一起測試:

#將圖保存到文件,必要時創(chuàng)建允許的量子化    
def save_graph_to_file(graph, graph_file_name, module_spec, class_count):
    sess, _, _, _, _, _ = build_eval_session(module_spec, class_count)
    graph = sess.graph

    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), ['final_tensor_name'])

    with tf.gfile.FastGFile(graph_file_name, 'wb') as f:
        f.write(output_graph_def.SerializeToString())  

保存評估模型export_model

注意每次使用前必須把舊的saved_model文件夾刪除或改名吉挣。

#導(dǎo)出評估eval圖的模型pd文件用于提供服務(wù)
saved_model_dir=os.path.join(dir_path,'saved_model'+str(datetime.now()))
def export_model(module_spec, class_count):
    sess, in_image, _, _, _, _ = build_eval_session(module_spec, class_count)
    graph = sess.graph
    with graph.as_default():
        #輸入輸出點
        inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)}
        out_classes = sess.graph.get_tensor_by_name('final_tensor_name:0')
        outputs = {
            'prediction': tf.saved_model.utils.build_tensor_info(out_classes)
        }
        #創(chuàng)建簽名
        signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=inputs,
            outputs=outputs,
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

        #初始化
        legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

        #保存saved_model
        builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:signature
            },
            legacy_init_op=legacy_init_op)
    builder.save()


改進(jìn)最終再訓(xùn)練函數(shù)

最后我們把run_final_retrain函數(shù)改進(jìn)一下派撕,增加評估eval和保存、導(dǎo)出功能睬魂。

注意每次使用前必須把舊的saved_model文件夾刪除或改名终吼。

以下是修改后的代碼(這里參數(shù)train_steps=10,所以得到的模型精度也非常糟糕氯哮。如果您的計算機(jī)允許际跪,官方默認(rèn)是4000,請量力而為):

#保存概要和checkpoint路徑設(shè)置
CHECKPOINT_NAME = os.path.join(dir_path,'checkpoints/retrain')
summaries_dir=os.path.join(dir_path,'summaries/train')
ensure_dir_exists(os.path.join(dir_path,'output')) 
saved_model_path=os.path.join(dir_path,'output/out_graph.pd')
output_label_path=os.path.join(dir_path,'output/labels.txt')

#執(zhí)行訓(xùn)練兵保存checkpoint的函數(shù)
def run_final_retrain(train_steps=10,
             eval_step_interval=5,
             do_distort=True):
    module_spec = hub.load_module_spec(HUB_MODULE)
    
    #創(chuàng)建圖并獲取相關(guān)的張量入口
    graph, bottleneck_tensor, resized_image_tensor, wants_quantization = (
        create_module_graph(module_spec))
    
    with graph.as_default(): 
        #添加訓(xùn)練相關(guān)的張量和操作節(jié)點入口
        (train_step, cross_entropy, bottleneck_input,ground_truth_input,
         final_tensor) = add_final_retrain_ops(5, 'final_tensor_name', 
                                               bottleneck_tensor,wants_quantization,True)    

    with tf.Session(graph=graph) as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        
        #添加圖片解碼相關(guān)的張量入口操作
        jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding(module_spec)
        
        #讀取圖片的bottleneck數(shù)據(jù)
        if do_distort:
            distorted_jpeg_data_tensor,distorted_image_tensor=add_input_distortions(module_spec,True,50,50,50)
        else:
            cache_bottlenecks(sess, 
                              jpeg_data_tensor,decoded_image_tensor, 
                              resized_image_tensor,bottleneck_tensor)           
            
        #創(chuàng)建評估新層精度的操作
        evaluation_step, _ = add_evaluation_step(final_tensor, ground_truth_input)
        
        #記錄概要信息與保存
        train_saver = tf.train.Saver()
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(summaries_dir+'/retrain',sess.graph)      
        validation_writer = tf.summary.FileWriter(summaries_dir + '/validation')        
        
        #開始運作!
        for i in range(train_steps):
            #獲取圖片bottleneck數(shù)據(jù)
            if do_distort:
                (train_bottlenecks,train_ground_truth) = get_random_distorted_bottlenecks(
                    sess,BATCH_SIZE, 'training',
                    distorted_jpeg_data_tensor,distorted_image_tensor, 
                    resized_image_tensor, bottleneck_tensor)
            else:
                (train_bottlenecks,train_ground_truth, _) = get_random_cached_bottlenecks(
                    sess, BATCH_SIZE, 'training',
                    jpeg_data_tensor,decoded_image_tensor, 
                    resized_image_tensor, bottleneck_tensor)
            
            #啟動訓(xùn)練
            train_summary, _ = sess.run(
                [merged, train_step],
                feed_dict={bottleneck_input: train_bottlenecks,
                           ground_truth_input: train_ground_truth})
            train_writer.add_summary(train_summary, i)
            
            #間隔性啟動評估
            is_last_step = (i + 1 == train_steps)
            if (i % eval_step_interval) == 0 or is_last_step:
                train_accuracy, cross_entropy_value = sess.run(
                    [evaluation_step, cross_entropy],
                    feed_dict={bottleneck_input: train_bottlenecks,
                               ground_truth_input: train_ground_truth})
                
                tf.logging.info('%s: Step %d: Train accuracy = %.1f%%' %(datetime.now(), i, train_accuracy * 100))
                tf.logging.info('%s: Step %d: Cross entropy = %f' %(datetime.now(), i, cross_entropy_value))
            
                #使用不同的bottleneck數(shù)據(jù)進(jìn)行評估
                validation_bottlenecks, validation_ground_truth, _ = (
                    get_random_cached_bottlenecks(
                        sess, 10, 'validation', 
                        jpeg_data_tensor,decoded_image_tensor, 
                        resized_image_tensor, bottleneck_tensor))
                #啟動評估姆打!
                validation_summary, validation_accuracy = sess.run(
                    [merged, evaluation_step],
                    feed_dict={bottleneck_input: validation_bottlenecks,
                               ground_truth_input: validation_ground_truth})
                
                validation_writer.add_summary(validation_summary, i)
                tf.logging.info('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' %(datetime.now(), i, validation_accuracy * 100,len(validation_bottlenecks)))            
            
        
            #間隔保存中介媒體文件良姆,為訓(xùn)練保存checkpoint
            if (i % eval_step_interval == 0 and i > 0):
                train_saver.save(sess, CHECKPOINT_NAME)
                intermediate_file_name = (os.path.join(dir_path + 'intermediate') + str(i) + '.pb')
                tf.logging.info('Save intermediate result to : '+intermediate_file_name)
                save_graph_to_file(graph, intermediate_file_name, module_spec,5)

        #保存模型    
        train_saver.save(sess, CHECKPOINT_NAME)        
        
        #執(zhí)行最終評估
        run_final_eval(sess, module_spec, 5,
                       jpeg_data_tensor, decoded_image_tensor,
                       resized_image_tensor,bottleneck_tensor)

        
        tf.logging.info('Save final result to : ' + saved_model_path)
        if wants_quantization:
            tf.logging.info('The model is instrumented for quantization with TF-Lite')
        save_graph_to_file(graph, saved_model_path, module_spec, 5)
        with tf.gfile.FastGFile(output_label_path, 'w') as f:
            f.write('\n'.join(image_lists.keys()) + '\n')
        export_model(module_spec, 5)

案例小結(jié)

這個案例來自Tensorflow官方教程,之前兩個相對都比較簡單幔戏,代碼量只有100行左右玛追,這個案例官方原代碼突然有1300行之多,大有才學(xué)了十以內(nèi)加減法然后就講微積分方程的感覺闲延。

這里整個案例去掉了很多官方代碼中我認(rèn)為無關(guān)緊要的部分痊剖,仍然有600多行,如果有時間我還會在整理這個案例垒玲,希望能只保留關(guān)鍵流程代碼陆馁,兩三百行不能再多了。

已經(jīng)讀到這里的用戶實屬難得侍匙,如果遇到困難氮惯,請從百度網(wǎng)盤下載(密碼:lzjg)直接下載final.py文件使用。請注意文件讀寫權(quán)限想暗,每次運行前請刪除saved_model文件夾。


探索人工智能的新邊界

如果您發(fā)現(xiàn)文章錯誤帘不,請不吝留言指正说莫;
如果您覺得有用,請點喜歡寞焙;
如果您覺得很有用储狭,感謝轉(zhuǎn)發(fā)~


END

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市捣郊,隨后出現(xiàn)的幾起案子辽狈,更是在濱河造成了極大的恐慌,老刑警劉巖呛牲,帶你破解...
    沈念sama閱讀 206,602評論 6 481
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件刮萌,死亡現(xiàn)場離奇詭異,居然都是意外死亡娘扩,警方通過查閱死者的電腦和手機(jī)着茸,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,442評論 2 382
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來琐旁,“玉大人涮阔,你說我怎么就攤上這事』遗梗” “怎么了敬特?”我有些...
    開封第一講書人閱讀 152,878評論 0 344
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經(jīng)常有香客問我伟阔,道長辣之,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 55,306評論 1 279
  • 正文 為了忘掉前任减俏,我火速辦了婚禮召烂,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘娃承。我一直安慰自己奏夫,他們只是感情好,可當(dāng)我...
    茶點故事閱讀 64,330評論 5 373
  • 文/花漫 我一把揭開白布历筝。 她就那樣靜靜地躺著酗昼,像睡著了一般。 火紅的嫁衣襯著肌膚如雪梳猪。 梳的紋絲不亂的頭發(fā)上麻削,一...
    開封第一講書人閱讀 49,071評論 1 285
  • 那天,我揣著相機(jī)與錄音春弥,去河邊找鬼呛哟。 笑死,一個胖子當(dāng)著我的面吹牛匿沛,可吹牛的內(nèi)容都是我干的扫责。 我是一名探鬼主播,決...
    沈念sama閱讀 38,382評論 3 400
  • 文/蒼蘭香墨 我猛地睜開眼逃呼,長吁一口氣:“原來是場噩夢啊……” “哼鳖孤!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起抡笼,我...
    開封第一講書人閱讀 37,006評論 0 259
  • 序言:老撾萬榮一對情侶失蹤苏揣,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后推姻,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體平匈,經(jīng)...
    沈念sama閱讀 43,512評論 1 300
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 35,965評論 2 325
  • 正文 我和宋清朗相戀三年拾碌,在試婚紗的時候發(fā)現(xiàn)自己被綠了吐葱。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 38,094評論 1 333
  • 序言:一個原本活蹦亂跳的男人離奇死亡校翔,死狀恐怖弟跑,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情防症,我是刑警寧澤孟辑,帶...
    沈念sama閱讀 33,732評論 4 323
  • 正文 年R本政府宣布哎甲,位于F島的核電站,受9級特大地震影響饲嗽,放射性物質(zhì)發(fā)生泄漏炭玫。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 39,283評論 3 307
  • 文/蒙蒙 一貌虾、第九天 我趴在偏房一處隱蔽的房頂上張望吞加。 院中可真熱鬧,春花似錦尽狠、人聲如沸衔憨。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,286評論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽践图。三九已至,卻和暖如春沉馆,著一層夾襖步出監(jiān)牢的瞬間码党,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 31,512評論 1 262
  • 我被黑心中介騙來泰國打工斥黑, 沒想到剛下飛機(jī)就差點兒被人妖公主榨干…… 1. 我叫王不留揖盘,地道東北人。 一個月前我還...
    沈念sama閱讀 45,536評論 2 354
  • 正文 我出身青樓锌奴,卻偏偏與公主長得像扣讼,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子缨叫,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 42,828評論 2 345

推薦閱讀更多精彩內(nèi)容