TensorFlow 自定義模型導(dǎo)出:將 .ckpt 格式轉(zhuǎn)化為 .pb 格式

????????本文承接上文 TensorFlow-slim 訓(xùn)練 CNN 分類模型(續(xù)),闡述通過(guò) tf.contrib.slim 的函數(shù) slim.learning.train 訓(xùn)練的模型搪哪,怎么通過(guò)人為的加入數(shù)據(jù)入口(即占位符)來(lái)克服無(wú)法用于圖像推斷的問(wèn)題猾愿。要解決這個(gè)問(wèn)題嗅榕,最簡(jiǎn)單和最省時(shí)的方法是模仿窥突。我們模仿的代碼是 TensorFlow 實(shí)現(xiàn)的目標(biāo)檢測(cè) API 中的文件 exporter.py,該文件的目的正是要將 TensorFlow-slim 訓(xùn)練的目標(biāo)檢測(cè)模型由 .ckpt 格式轉(zhuǎn)化為.pb 格式今布,而且其代碼中人為添加占位符的操作也正是我們需求的经备。坦白的說(shuō)拭抬,我會(huì)用 TensorFlow 的 tf.contrib.slim 模塊來(lái)構(gòu)建和訓(xùn)練模型正是受 TensorFlow models 項(xiàng)目的影響,當(dāng)時(shí)我需要訓(xùn)練目標(biāo)檢測(cè)器侵蒙,因此變配置了 models 這個(gè)子項(xiàng)目造虎,并且從頭到尾的閱讀了其中 object_detection 中的 Faster RCNN 的源代碼,切實(shí)感受到了 slim 模塊的簡(jiǎn)便和高效(學(xué)習(xí) TensorFlow 最好的辦法除了查閱文檔之外纷闺,便是看 models 中各種項(xiàng)目的源代碼)算凿。

????????言歸正傳,現(xiàn)在我們回到主題犁功,怎么加入占位符氓轰,將前一篇文章訓(xùn)練的 CNN 分類器用于圖像分類。這個(gè)問(wèn)題在我們知道通過(guò)模仿 exporter.py 就可以解決它的時(shí)候浸卦,就變得異常簡(jiǎn)單了署鸡。我們先來(lái)理順一下解決這個(gè)問(wèn)題的邏輯:

1.定義數(shù)據(jù)入口,即定義占位符 inputs = tf.placeholder(···)镐躲;
2.將模型作用于占位符储玫,得到數(shù)據(jù)出口侍筛,即分類結(jié)果萤皂;
3.將訓(xùn)練文件從 .ckpt 格式轉(zhuǎn)化為 .pb 格式。

按照這個(gè)邏輯順序匣椰,下面我們?cè)敿?xì)的來(lái)看一下自定義模型導(dǎo)出裆熙,即模型格式轉(zhuǎn)化的代碼(命名為 exporter.py,如果沒有特別說(shuō)明禽笑,exporter.py 指的都是我們修改 TensorFlow 目標(biāo)檢測(cè)中的 exporter.py 后的自定義文件):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 15:13:27 2018
@author: shirhe-lyh
"""

"""Functions to export inference graph.
Modified from: TensorFlow models/research/object_detection/export.py
"""

import logging
import os
import tempfile
import tensorflow as tf

from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import saver as saver_lib

slim = tf.contrib.slim


# TODO: Replace with freeze_graph.freeze_graph_with_def_protos when
# newer version of Tensorflow becomes more common.
def freeze_graph_with_def_protos(
    input_graph_def,
    input_saver_def,
    input_checkpoint,
    output_node_names,
    restore_op_name,
    filename_tensor_name,
    clear_devices,
    initializer_nodes,
    variable_names_blacklist=''):
    """Converts all variables in a graph and checkpoint into constants."""
    del restore_op_name, filename_tensor_name  # Unused by updated loading code.
    
    # 'input_checkpoint' may be a prefix if we're using Saver V2 format
    if not saver_lib.checkpoint_exists(input_checkpoint):
        raise ValueError(
            "Input checkpoint ' + input_checkpoint + ' does not exist!")
        
    if not output_node_names:
        raise ValueError(
            'You must supply the name of a node to --output_node_names.')
        
    # Remove all the explicit device specifications for this node. This helps
    # to make the graph more portable.
    if clear_devices:
        for node in input_graph_def.node:
            node.device = ''
    
    with tf.Graph().as_default():
        tf.import_graph_def(input_graph_def, name='')
        config = tf.ConfigProto(graph_options=tf.GraphOptions())
        with session.Session(config=config) as sess:
            if input_saver_def:
                saver = saver_lib.Saver(saver_def=input_saver_def)
                saver.restore(sess, input_checkpoint)
            else:
                var_list = {}
                reader = pywrap_tensorflow.NewCheckpointReader(
                    input_checkpoint)
                var_to_shape_map = reader.get_variable_to_shape_map()
                for key in var_to_shape_map:
                    try:
                        tensor = sess.graph.get_tensor_by_name(key + ':0')
                    except KeyError:
                        # This tensor doesn't exist in the graph (for example
                        # it's 'global_step' or a similar housekeeping element)
                        # so skip it.
                        continue
                    var_list[key] = tensor
                saver = saver_lib.Saver(var_list=var_list)
                saver.restore(sess, input_checkpoint)
                if initializer_nodes:
                    sess.run(initializer_nodes)
            
            variable_names_blacklist = (variable_names_blacklist.split(',') if
                                        variable_names_blacklist else None)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                output_node_names.split(','),
                variable_names_blacklist=variable_names_blacklist)
    return output_graph_def


def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
    """Replaces variable values in the checkpoint with their moving averages.
    
    If the current checkpoint has shadow variables maintaining moving averages
    of the variables defined in the graph, this function generates a new
    checkpoint where the variables contain the values of their moving averages.
    
    Args:
        graph: A tf.Graph object.
        current_checkpoint_file: A checkpoint both original variables and
            their moving averages.
        new_checkpoint_file: File path to write a new checkpoint.
    """
    with graph.as_default():
        variable_averages = tf.train.ExponentialMovingAverage(0.0)
        ema_variables_to_restore = variable_averages.variables_to_restore()
        with tf.Session() as sess:
            read_saver = tf.train.Saver(ema_variables_to_restore)
            read_saver.restore(sess, current_checkpoint_file)
            write_saver = tf.train.Saver()
            write_saver.save(sess, new_checkpoint_file)


def _image_tensor_input_placeholder(input_shape=None):
    """Returns input placeholder and a 4-D uint8 image tensor."""
    if input_shape is None:
        input_shape = (None, None, None, 3)
    input_tensor = tf.placeholder(
        dtype=tf.uint8, shape=input_shape, name='image_tensor')
    return input_tensor, input_tensor


def _encoded_image_string_tensor_input_placeholder():
    """Returns input that accepts a batch of PNG or JPEG strings.
    
    Returns:
        A tuple of input placeholder and the output decoded images.
    """
    batch_image_str_placeholder = tf.placeholder(
        dtype=tf.string,
        shape=[None],
        name='encoded_image_string_tensor')
    def decode(encoded_image_string_tensor):
        image_tensor = tf.image.decode_image(encoded_image_string_tensor,
                                             channels=3)
        image_tensor.set_shape((None, None, 3))
        return image_tensor
    return (batch_image_str_placeholder,
            tf.map_fn(
                decode,
                elems=batch_image_str_placeholder,
                dtype=tf.uint8,
                parallel_iterations=32,
                back_prop=False))


input_placeholder_fn_map = {
    'image_tensor': _image_tensor_input_placeholder,
    'encoded_image_string_tensor':
        _encoded_image_string_tensor_input_placeholder,
#    'tf_example': _tf_example_input_placeholder,
    }


def _add_output_tensor_nodes(postprocessed_tensors,
                             output_collection_name='inference_op'):
    """Adds output nodes.
    
    Adjust according to specified implementations.
    
    Adds the following nodes for output tensors:
        * classes: A float32 tensor of shape [batch_size] containing class
            predictions.
    
    Args:
        postprocessed_tensors: A dictionary containing the following fields:
            'classes': [batch_size].
        output_collection_name: Name of collection to add output tensors to.
        
    Returns:
        A tensor dict containing the added output tensor nodes.
    """
    outputs = {}
    classes = postprocessed_tensors.get('classes') # Assume containing 'classes'
    outputs['classes'] = tf.identity(classes, name='classes')
    for output_key in outputs:
        tf.add_to_collection(output_collection_name, outputs[output_key])
    return outputs


def write_frozen_graph(frozen_graph_path, frozen_graph_def):
    """Writes frozen graph to disk.
    
    Args:
        frozen_graph_path: Path to write inference graph.
        frozen_graph_def: tf.GraphDef holding frozen graph.
    """
    with gfile.GFile(frozen_graph_path, 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())
    logging.info('%d ops in the final graph.', len(frozen_graph_def.node))
    
    
def write_saved_model(saved_model_path,
                      frozen_graph_def,
                      inputs,
                      outputs):
    """Writes SavedModel to disk.
    
    If checkpoint_path is not None bakes the weights into the graph thereby
    eliminating the need of checkpoint files during inference. If the model
    was trained with moving averages, setting use_moving_averages to True
    restores the moving averages, otherwise the original set of variables
    is restored.
    
    Args:
        saved_model_path: Path to write SavedModel.
        frozen_graph_def: tf.GraphDef holding frozen graph.
        inputs: The input image tensor.
        outputs: A tensor dictionary containing the outputs of a slim model.
    """
    with tf.Graph().as_default():
        with session.Session() as sess:
            tf.import_graph_def(frozen_graph_def, name='')
            
            builder = tf.saved_model.builder.SavedModelBuilder(
                saved_model_path)
            
            tensor_info_inputs = {
                'inputs': tf.saved_model.utils.build_tensor_info(inputs)}
            tensor_info_outputs = {}
            for k, v in outputs.items():
                tensor_info_outputs[k] = tf.saved_model.utils.build_tensor_info(
                    v)
                
            detection_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs=tensor_info_inputs,
                    outputs=tensor_info_outputs,
                    method_name=signature_constants.PREDICT_METHOD_NAME))
            
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map={
                    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        detection_signature,
                        },
            )
            builder.save()


def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
    """Writes the graph and the checkpoint into disk."""
    for node in inference_graph_def.node:
        node.device = ''
    with tf.Graph().as_default():
        tf.import_graph_def(inference_graph_def, name='')
        with session.Session() as sess:
            saver = saver_lib.Saver(saver_def=input_saver_def,
                                    save_relative_paths=True)
            saver.restore(sess, trained_checkpoint_prefix)
            saver.save(sess, model_path)


def _get_outputs_from_inputs(input_tensors, model, 
                             output_collection_name):
    inputs = tf.to_float(input_tensors)
    preprocessed_inputs = model.preprocess(inputs)
    output_tensors = model.predict(preprocessed_inputs)
    postprocessed_tensors = model.postprocess(output_tensors)
    return _add_output_tensor_nodes(postprocessed_tensors,
                                    output_collection_name)
    
    
def _build_model_graph(input_type, model, input_shape, 
                           output_collection_name, graph_hook_fn):
    """Build the desired graph."""
    if input_type not in input_placeholder_fn_map:
        raise ValueError('Unknown input type: {}'.format(input_type))
    placeholder_args = {}
    if input_shape is not None:
        if input_type != 'image_tensor':
            raise ValueError("Can only specify input shape for 'image_tensor' "
                             'inputs.')
        placeholder_args['input_shape'] = input_shape
    placeholder_tensor, input_tensors = input_placeholder_fn_map[input_type](
        **placeholder_args)
    outputs = _get_outputs_from_inputs(
        input_tensors=input_tensors,
        model=model,
        output_collection_name=output_collection_name)
    
    # Add global step to the graph
    slim.get_or_create_global_step()
    
    if graph_hook_fn: graph_hook_fn()
    
    return outputs, placeholder_tensor


def export_inference_graph(input_type,
                           model,
                           trained_checkpoint_prefix,
                           output_directory,
                           input_shape=None,
                           use_moving_averages=None,
                           output_collection_name='inference_op',
                           additional_output_tensor_names=None,
                           graph_hook_fn=None):
    """Exports inference graph for the desired graph.
    
    Args:
        input_type: Type of input for the graph. Can be one of ['image_tensor',
            'encoded_image_string_tensor', 'tf_example']. In this file, 
            input_type must be 'image_tensor'.
        model: A model defined by model.py.
        trained_checkpoint_prefix: Path to the trained checkpoint file.
        output_directory: Path to write outputs.
        input_shape: Sets a fixed shape for an 'image_tensor' input. If not
            specified, will default to [None, None, None, 3].
        use_moving_averages: A boolean indicating whether the 
            tf.train.ExponentialMovingAverage should be used or not.
        output_collection_name: Name of collection to add output tensors to.
            If None, does not add output tensors to a collection.
        additional_output_tensor_names: List of additional output tensors to
            include in the frozen graph.
    """
    tf.gfile.MakeDirs(output_directory)
    frozen_graph_path = os.path.join(output_directory,
                                     'frozen_inference_graph.pb')
    saved_model_path = os.path.join(output_directory, 'saved_model')
    model_path = os.path.join(output_directory, 'model.ckpt')
    
    outputs, placeholder_tensor = _build_model_graph(
        input_type=input_type,
        model=model,
        input_shape=input_shape,
        output_collection_name=output_collection_name,
        graph_hook_fn=graph_hook_fn)
    
    saver_kwargs = {}
    if use_moving_averages:
        # This check is to be compatible with both version of SaverDef.
        if os.path.isfile(trained_checkpoint_prefix):
            saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
            temp_checkpoint_prefix = tempfile.NamedTemporaryFile().name
        else:
            temp_checkpoint_prefix = tempfile.mkdtemp()
        replace_variable_values_with_moving_averages(
            tf.get_default_graph(), trained_checkpoint_prefix,
            temp_checkpoint_prefix)
        checkpoint_to_use = temp_checkpoint_prefix
    else:
        checkpoint_to_use = trained_checkpoint_prefix
    
    saver = tf.train.Saver(**saver_kwargs)
    input_saver_def = saver.as_saver_def()
    
    write_graph_and_checkpoint(
        inference_graph_def=tf.get_default_graph().as_graph_def(),
        model_path=model_path,
        input_saver_def=input_saver_def,
        trained_checkpoint_prefix=checkpoint_to_use)
    
    if additional_output_tensor_names is not None:
        output_node_names = ','.join(outputs.keys()+
                                     additional_output_tensor_names)
    else:
        output_node_names = ','.join(outputs.keys())
        
    frozen_graph_def = freeze_graph_with_def_protos(
        input_graph_def=tf.get_default_graph().as_graph_def(),
        input_saver_def=input_saver_def,
        input_checkpoint=checkpoint_to_use,
        output_node_names=output_node_names,
        restore_op_name='save/restore_all',
        filename_tensor_name='save/Const:0',
        clear_devices=True,
        initializer_nodes='')
    write_frozen_graph(frozen_graph_path, frozen_graph_def)
    write_saved_model(saved_model_path, frozen_graph_def,
                      placeholder_tensor, outputs)

首先看定義占位符的函數(shù) _image_tensor_input_placeholder_encoded_image_string_tensor_input_placeholder 入录,重點(diǎn)關(guān)注前一個(gè)函數(shù),因?yàn)樗妮斎霝橐粋€(gè)批量圖像組成的 4 維張量(正是我們需要的)佳镜,這個(gè)函數(shù)僅僅定義了一個(gè)圖像占位符 input_tensor

input_tensor = tf.placeholder(dtype=tf.uint8, shape=input_shape, name='image_tensor')

簡(jiǎn)單至極僚稿。接下來(lái)看 _build_model_graph 函數(shù),這個(gè)函數(shù)將數(shù)據(jù)輸入 input_tensor (第一個(gè)參數(shù))通過(guò)模型 model (第二個(gè)參數(shù))作用的結(jié)果 outputs 返回蟀伸。其中引用的函數(shù) _get_outputs_from_inputs蚀同,顧名思義,由輸入數(shù)據(jù)得到分類結(jié)果啊掏。它又引用了函數(shù) _add_output_tensor_nodes蠢络,這個(gè)函數(shù)比較重要,因?yàn)樗x了數(shù)據(jù)輸出結(jié)點(diǎn)

outputs['classes'] = tf.identity(classes, name='classes')

以上這些便是這個(gè)自定義文件 exporter.py 的精華迟蜜,因?yàn)樗鼘?shí)現(xiàn)了數(shù)據(jù)入口(name='image_tensor')和出口(name='classes')結(jié)點(diǎn)的定義刹孔。另一方面,這個(gè)自定義文件 exporter.py 可以作為模型導(dǎo)出的通用文件娜睛,而針對(duì)每一個(gè)特定的模型我們只需要修改與參數(shù) model(表示某個(gè)特定模型) 相關(guān)的函數(shù)即可髓霞,而所有這些函數(shù)就是以上列出的函數(shù)卦睹。

????????為了描述的完整性,也來(lái)看一看剩下的不需要修改的函數(shù)方库。我們從主函數(shù) export_inference_graph 開始分预,它是實(shí)際被調(diào)用的函數(shù)。它首先創(chuàng)建了用于保存輸出文件的文件夾薪捍,然后根據(jù)參數(shù) model 創(chuàng)建了模型數(shù)據(jù)入口和出口笼痹,接下來(lái)的 if 語(yǔ)句是說(shuō),如果使用移動(dòng)平均酪穿,則將原始 graph 中的變量用它的移動(dòng)平均值來(lái)替換(函數(shù) replace_variable_values_with_moving_averages)凳干。再下來(lái)的 write_graph_and_checkpoint 函數(shù)相當(dāng)于將上一篇文章的訓(xùn)練輸出文件復(fù)制到當(dāng)前指定的輸出路徑 output_directory,最后的函數(shù) freeze_graph_with_def_protosgraph 中的變量變成常量被济,然后通過(guò)函數(shù) write_frozen_graph 和函數(shù) write_saved_model 寫出到輸出路徑救赐。

????????最后來(lái)解釋一下函數(shù)

export_inference_graph(input_type,
                       model,
                       trained_checkpoint_prefix,
                       output_directory,
                       input_shape=None,
                       use_moving_averages=None,
                       output_collection_name='inference_op',
                       additional_output_tensor_names=None,
                       graph_hook_fn=None)

的各個(gè)參數(shù):1.input_type,指的是輸入數(shù)據(jù)的類型只磷,exporter.py 指定了只能從以下的字典中

input_placeholder_fn_map = {
    'image_tensor': _image_tensor_input_placeholder,
    'encoded_image_string_tensor':
        _encoded_image_string_tensor_input_placeholder,
#    'tf_example': _tf_example_input_placeholder,
    }

選出其中一種经磅,一般我們選擇圖像作為輸入,即 image_tensor钮追;2.model坪稽,指的是自己構(gòu)建的模型,是一個(gè)類對(duì)象畦韭,如上一篇文章定義的 Model 類的一個(gè)實(shí)例:

cls_model = model.Model(is_training=False, num_classes=10)

3.trained_checkpoint_prefix弯予,指定要導(dǎo)出的 .ckpt 文件路徑;4.output_directory刊棕,指定導(dǎo)出文件的存儲(chǔ)路徑(是一個(gè)文件夾)炭晒;5.input_shape,輸入數(shù)據(jù)的形狀甥角,缺省時(shí)為 [None, None, None, 3]网严;6.use_moving_average,是否使用移動(dòng)平均嗤无;7.output_collection_name震束,輸出的 collection 名,直接使用默認(rèn)名翁巍,不需要修改驴一;8.additional_output_tensor_names,指定額外的輸出張量名灶壶;9.graph_hook_fn肝断,意義不明,暫時(shí)不知道它的表示意義。

????????實(shí)際調(diào)用的時(shí)候胸懈,我們一般只需要指定前四個(gè)參數(shù)担扑,如(命名為 export_inference_graph.py):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 15:46:16 2018
@author: shirhe-lyh
"""

"""Tool to export a model for inference.
Outputs inference graph, asscociated checkpoint files, a frozen inference
graph and a SavedModel (https://tensorflow.github.io/serving_basic.html).
The inference graph contains one of three input nodes depending on the user
specified option.
    * 'image_tensor': Accepts a uint8 4-D tensor of shape [None, None, None, 3]
    * 'encoded_image_string_tensor': Accepts a 1-D string tensor of shape 
        [None] containg encoded PNG or JPEG images.
    * 'tf_example': Accepts a 1-D string tensor of shape [None] containing
        serialized TFExample protos.
        
and the following output nodes returned by the model.postprocess(..):
    * 'classes': Outputs float32 tensors of the form [batch_size] containing
        the classes for the predictions.
        
Example Usage:
---------------
python/python3 export_inference_graph \
    --input_type image_tensor \
    --trained_checkpoint_prefix path/to/model.ckpt \
    --output_directory path/to/exported_model_directory
    
The exported output would be in the directory
path/to/exported_model_directory (which is created if it does not exist)
with contents:
    - model.ckpt.data-00000-of-00001
    - model.ckpt.info
    - model.ckpt.meta
    - frozen_inference_graph.pb
    + saved_model (a directory)
"""
import tensorflow as tf

import exporter
import model

slim = tf.contrib.slim
flags = tf.app.flags

flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can '
                    "be one of ['image_tensor', 'encoded_image_string_tensor'"
                    ", 'tf_example']")
flags.DEFINE_string('input_shape', None, "If input_type is 'image_tensor', "
                    "this can be explicitly set the shape of this input "
                    "to a fixed size. The dimensions are to be provided as a "
                    "comma-seperated list of integers. A value of -1 can be "
                    "used for unknown dimensions. If not specified, for an "
                    "'image_tensor', the default shape will be partially "
                    "specified as '[None, None, None, 3]'.")
flags.DEFINE_string('trained_checkpoint_prefix', None,
                    'Path to trained checkpoint, typically of the form '
                    'path/to/model.ckpt')
flags.DEFINE_string('output_directory', None, 'Path to write outputs')
tf.app.flags.mark_flag_as_required('trained_checkpoint_prefix')
tf.app.flags.mark_flag_as_required('output_directory')
FLAGS = flags.FLAGS


def main(_):
    cls_model = model.Model(is_training=False, num_classes=10)
    if FLAGS.input_shape:
        input_shape = [
            int(dim) if dim != -1 else None 
            for dim in FLAGS.input_shape.split(',')
        ]
    else:
        input_shape = [None, 28, 28, 3]
    exporter.export_inference_graph(FLAGS.input_type,
                                    cls_model,
                                    FLAGS.trained_checkpoint_prefix,
                                    FLAGS.output_directory,
                                    input_shape)
    

if __name__ == '__main__':
    tf.app.run()

在終端運(yùn)行命令:

python3 export_inference_graph.py \
    --trained_checkpoint_prefix path/to/.ckpt-xxxx \
    --output_directory path/to/output/directory

很快會(huì)在 output_directory 指定的文件夾中生成一系列文件,其中的 frozen_inference_graph.pb 便是我們需要的最終用于推斷的文件趣钱。至于如何讀取 .pb 文件用于推斷涌献,則可以訪問(wèn)這個(gè)系列的文章 TensorFlow 模型保存與恢復(fù) 的第二部分。為了方便閱讀首有,我們承接上一篇文章燕垃,使用如下代碼來(lái)對(duì)訓(xùn)練的模型進(jìn)行驗(yàn)證:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr  2 14:02:05 2018
@author: shirhe-lyh
"""

"""Evaluate the trained CNN model.
Example Usage:
---------------
python3 evaluate.py \
    --frozen_graph_path: Path to model frozen graph.
"""

import numpy as np
import tensorflow as tf

from captcha.image import ImageCaptcha

flags = tf.app.flags
flags.DEFINE_string('frozen_graph_path', None, 'Path to model frozen graph.')
FLAGS = flags.FLAGS


def generate_captcha(text='1'):
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image


def main(_):
    model_graph = tf.Graph()
    with model_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(FLAGS.frozen_graph_path, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    
    with model_graph.as_default():
        with tf.Session(graph=model_graph) as sess:
            inputs = model_graph.get_tensor_by_name('image_tensor:0')
            classes = model_graph.get_tensor_by_name('classes:0')
            for i in range(10):
                label = np.random.randint(0, 10)
                image = generate_captcha(str(label))
                image_np = np.expand_dims(image, axis=0)
                predicted_label = sess.run(classes, 
                                           feed_dict={inputs: image_np})
                print(predicted_label, ' vs ', label)
            
            
if __name__ == '__main__':
    tf.app.run()

簡(jiǎn)單運(yùn)行:

python3 evaluate.py --frozen_graph_path path/to/frozen_inference_graph.pb

可以看到驗(yàn)證結(jié)果。

????????本文(及前文)的所有代碼都在 github: slim_cnn_test井联,歡迎訪問(wèn)并下載卜壕。

預(yù)告:下一篇文章將介紹 TensorFlow 如何使用預(yù)訓(xùn)練文件來(lái)精調(diào)分類模型。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末烙常,一起剝皮案震驚了整個(gè)濱河市轴捎,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌蚕脏,老刑警劉巖侦副,帶你破解...
    沈念sama閱讀 218,607評(píng)論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異驼鞭,居然都是意外死亡秦驯,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,239評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門终议,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)汇竭,“玉大人葱蝗,你說(shuō)我怎么就攤上這事穴张。” “怎么了两曼?”我有些...
    開封第一講書人閱讀 164,960評(píng)論 0 355
  • 文/不壞的土叔 我叫張陵皂甘,是天一觀的道長(zhǎng)。 經(jīng)常有香客問(wèn)我悼凑,道長(zhǎng)偿枕,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,750評(píng)論 1 294
  • 正文 為了忘掉前任户辫,我火速辦了婚禮渐夸,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘渔欢。我一直安慰自己墓塌,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,764評(píng)論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著苫幢,像睡著了一般访诱。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上韩肝,一...
    開封第一講書人閱讀 51,604評(píng)論 1 305
  • 那天触菜,我揣著相機(jī)與錄音,去河邊找鬼哀峻。 笑死涡相,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的剩蟀。 我是一名探鬼主播漾峡,決...
    沈念sama閱讀 40,347評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼喻旷!你這毒婦竟也來(lái)了生逸?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,253評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤且预,失蹤者是張志新(化名)和其女友劉穎槽袄,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體锋谐,經(jīng)...
    沈念sama閱讀 45,702評(píng)論 1 315
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡遍尺,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,893評(píng)論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了涮拗。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片乾戏。...
    茶點(diǎn)故事閱讀 40,015評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖三热,靈堂內(nèi)的尸體忽然破棺而出鼓择,到底是詐尸還是另有隱情,我是刑警寧澤就漾,帶...
    沈念sama閱讀 35,734評(píng)論 5 346
  • 正文 年R本政府宣布呐能,位于F島的核電站,受9級(jí)特大地震影響抑堡,放射性物質(zhì)發(fā)生泄漏摆出。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,352評(píng)論 3 330
  • 文/蒙蒙 一首妖、第九天 我趴在偏房一處隱蔽的房頂上張望偎漫。 院中可真熱鬧,春花似錦有缆、人聲如沸象踊。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,934評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)通危。三九已至铸豁,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間菊碟,已是汗流浹背节芥。 一陣腳步聲響...
    開封第一講書人閱讀 33,052評(píng)論 1 270
  • 我被黑心中介騙來(lái)泰國(guó)打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留逆害,地道東北人头镊。 一個(gè)月前我還...
    沈念sama閱讀 48,216評(píng)論 3 371
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像魄幕,于是被迫代替她去往敵國(guó)和親相艇。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,969評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容

  • 夜深了纯陨,卻并未寂靜坛芽,客廳里的麻將聲聲刺激著我的神經(jīng),不好的性格翼抠,讓我但凡有一點(diǎn)的吵鬧聲就不能入睡咙轩,盡管精神早已疲憊...
    余夢(mèng)人生閱讀 506評(píng)論 0 0
  • 除了做好本質(zhì)工作,在工作之外偎肃,你都做過(guò)那些努力煞烫?可是在大數(shù)的時(shí)候,我們只要下班了软棺,那就是解放了红竭,誰(shuí)還把自己...
    不二豆閱讀 461評(píng)論 2 2