通过 tf.train.Saver 类实现保存模型
# Save a model with tf.train.Saver (TF1-style graph construction).
import tensorflow as tf

# Two 1-element variables and their product; the variables are what the
# checkpoint stores.
a = tf.Variable(tf.constant(1.0, shape=[1]), name='a')
b = tf.Variable(tf.constant(1.0, shape=[1]), name='b')
c = tf.multiply(a, b)  # op is unnamed, so its default graph name is 'Mul'

saver = tf.train.Saver()
# Use a context manager so the session is always closed (the original
# leaked the session).
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Writes Model/model.ckpt.{meta,index,data-*} plus a checkpoint file.
    saver.save(sess, "Model/model.ckpt")
通过 tf.train.Saver 类实现恢复模型
# Restore the model saved above. The graph must be re-defined with matching
# variable names; restore() then loads the stored values into the variables.
import tensorflow as tf

a = tf.Variable(tf.constant(1.0, shape=[1]), name='a')
# The 2.0 initializer is never run: restore() overwrites b with the saved 1.0.
b = tf.Variable(tf.constant(2.0, shape=[1]), name='b')
c = tf.multiply(a, b)

saver = tf.train.Saver()
# Context manager so the session is closed on exit (original leaked it).
with tf.Session() as sess:
    saver.restore(sess, './Model/model.ckpt')
    # The checkpoint holds a=1.0 and b=1.0, so the product is [1.]
    # (the original comment claimed [2.], which is incorrect).
    print(sess.run(c))  # [1.]
通过 tf.train.Saver 类实现恢复模型，支持在加载时给变量重命名
# Restore while renaming variables: checkpoint names ('a', 'b') are mapped
# onto this graph's variables d (named 'dd') and e (named 'ee').
# The original had `import tensorflow` without `as tf`, so every `tf.`
# reference below raised NameError.
import tensorflow as tf

d = tf.Variable(tf.constant(0.0, shape=[1]), name='dd')
e = tf.Variable(tf.constant(0.0, shape=[1]), name='ee')
c = tf.multiply(d, e)

# Mapping: checkpoint variable name -> variable object in the current graph.
saver = tf.train.Saver({'a': d, 'b': e})
# Context manager so the session is closed on exit (original leaked it).
with tf.Session() as sess:
    saver.restore(sess, './Model/model.ckpt')
    # Checkpoint has a=1.0 and b=1.0, so the product is [1.]
    # (the original comment claimed [2.], which is incorrect).
    print(sess.run(c))  # [1.]
通过 tf.train.Saver 类实现恢复部分模型，比如不要 ResNet50 最后一层 fc
# Partially restore a model: load every variable except the final fc layer
# (typical when fine-tuning ResNet50 on a new classification task).
variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['resnet50/fc'])
saver = tf.train.Saver(variables_to_restore)
# Context manager so the session is closed on exit (original leaked it).
with tf.Session() as sess:
    saver.restore(sess, './Model/Resnet.ckpt')
通过 tf.train.import_meta_graph() 恢复模型，无需重复定义计算图。
# Restore via tf.train.import_meta_graph(): the graph structure is read from
# the .meta file, so it does not have to be re-defined in code.
import tensorflow as tf

# saver.save(sess, "Model/model.ckpt") writes "Model/model.ckpt.meta"
# (the original "Model/model.meta" would not exist).
saver = tf.train.import_meta_graph("Model/model.ckpt.meta")
with tf.Session() as sess:
    saver.restore(sess, "./Model/model.ckpt")
    graph = tf.get_default_graph()
    # The save script never named the multiply op, so its output tensor is
    # "Mul:0"; "c:0" in the original is just the Python variable name and
    # would raise KeyError.
    c = graph.get_tensor_by_name("Mul:0")
    # Checkpoint has a=1.0 and b=1.0, so the product is [1.].
    print(sess.run(c))  # [1.]
还能在之前模型的基础上增加自己的计算图层。我们用 meta 图导入了一个预训练的 TextCNN 网络，然后将最后一层的输出个数改成 2，用于微调新的数据。
# Fine-tune on top of a pre-trained TextCNN: import its graph from the meta
# file, freeze everything below the last dense layer, and add a new 2-class head.
import tensorflow as tf

saver = tf.train.import_meta_graph("Model/TextCNN.meta")
sess = tf.Session()  # left open on purpose: training of the new head would follow
# The checkpoint prefix must belong to the same model as the imported meta
# graph; "Model/TextCNN.meta" pairs with prefix "./Model/TextCNN".
# (The original restored "./Model/model.ckpt", a different toy graph.)
saver.restore(sess, "./Model/TextCNN")

graph = tf.get_default_graph()
# Look up the pre-trained network's input placeholders and last dense output.
tf_x = graph.get_tensor_by_name('input_x:0')
tf_y = graph.get_tensor_by_name('input_y:0')
dense = graph.get_tensor_by_name('dense:0')
# Only the new head should train, so stop gradients from flowing back into
# the pre-trained layers.
dense = tf.stop_gradient(dense)
logits = tf.layers.dense(dense, 2)  # new 2-unit output layer for the new task
pred = tf.argmax(tf.nn.softmax(logits), 1)
通过 convert_variables_to_constants 函数将计算图中的变量及其取值以常量的方式保存到 pb 文件中
# Freeze the graph: convert variables to constants and write one .pb file.
import tensorflow as tf
from tensorflow.python.framework import graph_util

a = tf.Variable(tf.constant(1.0, shape=[1]), name='a')
b = tf.Variable(tf.constant(1.0, shape=[1]), name='b')
# Name the op explicitly so it matches the output-node list below; unnamed,
# its default graph name would be 'Mul' and the freeze would fail.
c = tf.multiply(a, b, name='multiply')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    graph_def = tf.get_default_graph().as_graph_def()
    # Keep only the subgraph needed to compute the listed output nodes,
    # with variable values baked in as constants.
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['multiply'])
    with tf.gfile.GFile('Model/model.pb', 'wb') as f:
        # The original wrote `output_grapg_def` here — a NameError at runtime.
        f.write(output_graph_def.SerializeToString())
从 pb 文件中恢复模型，并实现预测
# Load a frozen .pb graph and run inference, remapping its input placeholder
# onto a new one defined here.
import tensorflow as tf
from tensorflow.python.platform import gfile

tf_x = tf.placeholder(tf.float32, shape=[None, None], name='x')

with tf.Session() as sess:
    with gfile.FastGFile('Model/model.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Rewire the frozen graph's 'x:0' input to our placeholder and fetch 'pred:0'.
    # NOTE(review): these node names must actually exist in the .pb — the toy
    # model frozen above only contains 'a', 'b' and 'multiply'; adjust per model.
    output = tf.import_graph_def(graph_def, input_map={'x:0': tf_x},
                                 return_elements=['pred:0'])
    # The original fed an undefined variable `x` (NameError); feed concrete
    # 2-D data matching the [None, None] placeholder instead.
    x_value = [[1.0, 2.0], [3.0, 4.0]]
    pred = sess.run(output, feed_dict={tf_x: x_value})