先上代碼
import tensorflow as tf
import numpy as np
from tensorflow.python.platform import gfile
from tensorflow.python.lib.io import file_io
input_tensor_key = 'Placeholder:0'
def loadNpData(filename):
tensor_key_feed_dict = {}
#inputs = preprocess_inputs_arg_string(inputs_str)
data = np.load(file_io.FileIO(filename, mode='r'))
# When no key is specified for the input file.
# Check if npz file only contains a single numpy ndarray.
if isinstance(data, np.lib.npyio.NpzFile):
variable_name_list = data.files
if len(variable_name_list) != 1:
raise RuntimeError(
'Input file %s contains more than one ndarrays. Please specify '
'the name of ndarray to use.' % filename)
tensor_key_feed_dict[input_tensor_key] = data[variable_name_list[0]]
else:
tensor_key_feed_dict[input_tensor_key] = data
return tensor_key_feed_dict
with tf.Session() as sess:
# 定義模型文件及樣本測(cè)試文件
model_filename = 'merge1_graph.pb'
example_png = 'examples.npy'
# 加載npy格式的圖片測(cè)試樣本數(shù)據(jù)
image_data = loadNpData(example_png)
#加載模型文件
with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef();
graph_def.ParseFromString(f.read())
# 獲取輸入節(jié)點(diǎn)的tensor
inputs = sess.graph.get_tensor_by_name("Placeholder:0");
#打印輸入節(jié)點(diǎn)的信息
#print inputs
# 導(dǎo)入計(jì)算圖证杭,定義輸入節(jié)點(diǎn)及輸出節(jié)點(diǎn)
output = tf.import_graph_def(graph_def, input_map={'Placeholder:0':inputs}, return_elements=[ 'ArgMax:0','Softmax:0'])
# 打印輸出節(jié)點(diǎn)的信息
#print output
results = sess.run(output, feed_dict={inputs:image_data[input_tensor_key]})
print 'ArgMax result(預(yù)測(cè)結(jié)果對(duì)應(yīng)的標(biāo)簽值):'
print results[0]
print 'Softmax result(最后一層的輸出):'
print results[1]
# 輸出node詳細(xì)信息,此處默認(rèn)只打印第一個(gè)節(jié)點(diǎn)
for node in graph_def.node:
print node
break
運(yùn)行輸出
ArgMax result(預(yù)測(cè)結(jié)果對(duì)應(yīng)的標(biāo)簽值):
[3 3]
Softmax result(最后一層的輸出):
[[4.1668140e-12 9.0696268e-18 6.4261091e-13 9.9999940e-01 1.7161388e-30
5.4321697e-07 7.6357951e-09 6.3293229e-19 1.3812791e-13 1.5360580e-12]
[1.1472046e-05 3.3404951e-10 6.0365837e-09 9.9997592e-01 9.8635665e-15
5.7557719e-07 1.1977763e-05 1.6275100e-16 7.2288098e-10 5.0601763e-08]]
此處加載的關(guān)鍵在于tf.import_graph_def
函數(shù)的參數(shù)配置俭令,三個(gè)參數(shù)graph_def
input_map
return_elements
第一個(gè)參數(shù)是導(dǎo)入的圖
input_map
是指定輸入節(jié)點(diǎn)拂封,如果不指定换吧,后面run的時(shí)候會(huì)報(bào)錯(cuò) ==You must feed a value for placeholder tensor 'Placeholder'==
return_elements
是指定運(yùn)算后的輸出節(jié)點(diǎn)泌霍,此處就是我們想要得到的標(biāo)簽估計(jì)值 ArgMax
以及 最后一層節(jié)點(diǎn)輸出 Softmax
模型的測(cè)試參考 將Tensorflow模型導(dǎo)出為單個(gè)文件