一種工程級(jí)方便的存取模型的方法,saved_model
通過(guò)存取一個(gè)簡(jiǎn)單的模型來(lái)作為示范
首先是模型定義
import tensorflow as tf
import numpy as np
W = tf.get_variable(name="demo", initializer=tf.ones([10, 32],dtype=tf.float32))
x = tf.placeholder(dtype=tf.float32, shape=[None, 10])
y = tf.matmul(x, W)
y_ = np.ones(shape=[10, 32], dtype=np.float32) # 使用np來(lái)創(chuàng)造兩個(gè)label
cost = tf.nn.sigmoid_cross_entropy_with_logits(logits=y, labels=y_, name=None)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(cost)
這里定義了一個(gè)簡(jiǎn)單的矩陣乘枢泰, 然后我們來(lái)簡(jiǎn)單的訓(xùn)練幾步
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
feed_dict = {x: np.ones([10, 10])}
for i in range(100):
sess.run(train_op, feed_dict=feed_dict)
print(sess.run(y, feed_dict=feed_dict))
現(xiàn)在我們想把這個(gè)模型存儲(chǔ)起來(lái)危号,傳統(tǒng)的做法是用ckpt來(lái)做今野,現(xiàn)在tensorflow提供一種更強(qiáng)大簡(jiǎn)便的方法
首先構(gòu)建兩個(gè)字典干奢,inputs 和 outputs, 把要存入的變量放入字典中
其中 tf.saved_model.utils.build_tensor_info是把變量變成可緩存對(duì)象的函數(shù)
saved_model_dir = "save_model"
signature_key = 'test_signature'
input_key = 'input_x'
output_key = 'output'
# x 為輸入tensor
inputs = {input_key: tf.saved_model.utils.build_tensor_info(x)}
# y 為最終需要的輸出結(jié)果tensor
outputs = {output_key: tf.saved_model.utils.build_tensor_info(y)}
然后把兩個(gè)字典打包放入 signature 中
signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs=inputs,
outputs=outputs,
method_name=signature_key)
然后建立SavedModelBuilder另患,并以signature的形式添加要存儲(chǔ)的變量
builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
builder.add_meta_graph_and_variables(
sess=sess,
tags=['test_saved_model'],
signature_def_map={signature_key: signature},
clear_devices=True)
builder.save()
saved_model_dir 是要存模型的文件夾,可以是一個(gè)不存在的目錄名芳绩,save之后掀亥,包括圖結(jié)構(gòu),變量的內(nèi)容妥色,都會(huì)被存入到新創(chuàng)建的 saved_model_dir 目錄內(nèi)搪花,下圖就是存好的模型
下面我們來(lái)取出一個(gè)訓(xùn)練好的模型
用 tf.saved_model.loader.load 從 模型文件夾中取出模型
其中tags字段是['test_saved_model'], 與存模型時(shí)候指定的字段相同
把模型導(dǎo)入到session之后嘹害, 取出signature 就從signature中取出存入的變量了
saved_model_dir = "save_model"
signature_key = 'test_signature'
input_key = 'input_x'
output_key = 'output'
with tf.Session() as sess1:
meta_graph_def = tf.saved_model.loader.load(sess1, ['test_saved_model'], saved_model_dir)
signature = meta_graph_def.signature_def
x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name
print(x_tensor_name)
print(y_tensor_name)
x = sess1.graph.get_tensor_by_name(x_tensor_name)
y = sess1.graph.get_tensor_by_name(y_tensor_name)
feed_dict = {x: np.ones([1, 10])}
print(sess1.run(y, feed_dict=feed_dict))
我們看到撮竿,首先我們從signature 和 inputs/outputs都是一種字典的封裝,把tensor_name存入到了字典中
傳統(tǒng)的導(dǎo)入 需要用get_tensor_by_name 吼拥, 這樣就需要記錄tensor的name熟悉倚聚,很麻煩。
通過(guò)signature凿可,我們可以指定變量的別名惑折,方便存取。
另外枯跑,存模型和變量的時(shí)候惨驶,會(huì)把全部的模型圖存入,并不是只存我們指定幾個(gè)變量敛助,而signature只是方便我們存取想要使用的變量粗卜。
一個(gè)坑,使用tf.Session的時(shí)候纳击,切記默認(rèn)圖和指定圖的區(qū)別续扔。tf.Session()會(huì)導(dǎo)入默認(rèn)圖的結(jié)構(gòu), 而導(dǎo)入模型是需要依附于sess的圖焕数, 在默認(rèn)圖中導(dǎo)入模型纱昧,如果默認(rèn)圖定義了其他計(jì)算圖,會(huì)導(dǎo)致圖沖突堡赔,模型導(dǎo)不進(jìn)去识脆。!!灼捂!