????????上一篇文章 TensorFlow 訓(xùn)練 CNN 分類(lèi)器 中說(shuō)明了訓(xùn)練簡(jiǎn)單 CNN 模型的整個(gè)過(guò)程栏渺,并在訓(xùn)練結(jié)束后使用 .save
函數(shù)來(lái)保存訓(xùn)練的結(jié)果曙强,其后通過(guò)使用 tf.train.import_meta_graph
和 .restore
函數(shù)來(lái)導(dǎo)入模型進(jìn)行推斷腕扶。本文承接上文,對(duì)模型保存與恢復(fù)做一個(gè)總結(jié)。
????????總的來(lái)說(shuō)懒鉴,模型在保存和恢復(fù)時(shí)最重要的是留下數(shù)據(jù)接口犬性,方便使用時(shí)傳入數(shù)據(jù)和獲取結(jié)果瞻离。TensorFlow 中常用的模型保存格式為 .ckpt 和 .pb,下面分別進(jìn)行詳細(xì)說(shuō)明仔夺。
一琐脏、ckpt 格式模型保存與恢復(fù)
????????.ckpt 格式保存與恢復(fù)都很簡(jiǎn)單,具體可參考 TensorFlow 訓(xùn)練 CNN 分類(lèi)器缸兔。
1. ckpt 格式模型保存
inputs = tf.placeholder(tf.float32, shape=[None, ···], name='inputs') <-- 入口
···
prediction = tf.nn.softmax(logits, name='prediction') <-- 出口(僅作為例子日裙,下同)
···
saver = tf.train.Saver()
···
with tf.Session() as sess:
··· <-- 訓(xùn)練過(guò)程
saver.save(sess, './xxx/xxx.ckpt') <-- 模型保存
????????如上述代碼所示,假設(shè)你定義了一個(gè) TensorFlow 模型惰蜜,數(shù)據(jù)入口由占位符 inputs
給定昂拂,結(jié)果出口由張量 prediction
給定。通過(guò)語(yǔ)句 saver = tf.train.Saver()
定義了模型保存的一個(gè)實(shí)例對(duì)象 saver
抛猖,當(dāng)模型訓(xùn)練結(jié)束之后只需要簡(jiǎn)單的一條語(yǔ)句:
saver.save(sess, path_to_model.ckpt)
就把訓(xùn)練結(jié)果保存到了指定的路徑格侯。
????????以上代碼之所以把變量 inputs
和 prediction
單獨(dú)列出,一方面是因?yàn)樗鼈兪悄P?Graph 的起點(diǎn)和終點(diǎn)(戲稱(chēng)為數(shù)據(jù)入口财著、出口)联四,另一方面的原因是它們被特別的指定了名稱(chēng),因而在模型恢復(fù)時(shí)可以通過(guò)它們的名稱(chēng)而得到 Graph 中對(duì)應(yīng)的節(jié)點(diǎn)撑教。
2. ckpt 格式模型恢復(fù)
????????當(dāng)你需要導(dǎo)入模型進(jìn)行推斷時(shí)朝墩,只需要通過(guò)張量名獲取數(shù)據(jù)入口和出口,然后傳入數(shù)據(jù)即可:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('./xxx/xxx.ckpt.meta')
saver.restore(sess, './xxx/xxx.ckpt')
inputs = tf.get_default_graph().get_tensor_by_name('inputs:0')
prediction = tf.get_default_graph().get_tensor_by_name('prediction:0')
pred = sess.run(prediction, feed_dict={inputs: xxx}
????????保存為 .ckpt 模型的一個(gè)好處是伟姐,當(dāng)需要繼續(xù)訓(xùn)練時(shí)收苏,只需要將訓(xùn)練過(guò)的模型結(jié)果導(dǎo)入,然后在這個(gè)基礎(chǔ)上再繼續(xù)訓(xùn)練愤兵。而下面的 .pb 格式則不能繼續(xù)訓(xùn)練鹿霸,因?yàn)檫@種格式保存的模型參數(shù)都已經(jīng)轉(zhuǎn)化為了常量(而不再是變量)。
二秆乳、pb 格式模型保存與恢復(fù)
????????.pb 格式模型保存與恢復(fù)相比于前面的 .ckpt 格式而言要稍微麻煩一點(diǎn)懦鼠,但使用更靈活,特別是模型恢復(fù),因?yàn)樗梢悦撾x會(huì)話(Session)而存在葛闷,便于部署憋槐。
1. pb 格式模型保存
????????與 .ckpt 格式模型保存類(lèi)似,首先定義數(shù)據(jù)入口淑趾、出口:
from tensorflow.python.framework import graph_util
···
inputs = tf.placeholder(tf.float32, shape=[None, ···], name='inputs')
···
prediction = tf.nn.softmax(logits, name='prediction')
···
with tf.Session() as sess:
··· <-- 訓(xùn)練過(guò)程
graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(
sess,
graph_def,
['prediction'] <-- 參數(shù):output_node_names阳仔,輸出節(jié)點(diǎn)名
)
with tf.gfile.GFile('./xxx/xxx.pb', 'wb') as fid:
serialized_graph = output_graph_def.SerializeToString()
fid.write(serialized_graph)
然后通過(guò)函數(shù) graph_util.convert_variables_to_constants
將模型固話,使得所有變量轉(zhuǎn)化為常量扣泊,之后寫(xiě)入到指定的路徑完成模型保存過(guò)程近范。
2. pb 格式模型恢復(fù)
????????.pb 格式模型恢復(fù)自由度較大,不需要在會(huì)話里進(jìn)行操作延蟹,可以獨(dú)立存在:
import os
def load_model(path_to_model.pb):
if not os.path.exists(path_to_model.pb):
raise ValueError("'path_to_model.pb' is not exist.")
model_graph = tf.Graph()
with model_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(path_to_model.pb, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
return model_graph
模型導(dǎo)入之后评矩,便可以獲取數(shù)據(jù)入口和出口,然后進(jìn)行推斷:
model_graph = load_model('./xxx/xxx.pb')
inputs = model_graph.get_tensor_by_name('inputs:0')
prediction = model_graph.get_tensor_by_name('prediction:0')
with model_graph.as_default():
with tf.Session(graph=model_graph) as sess:
···
pred = sess.run(prediction, feed_dict={inputs: xxx}
三阱飘、ckpt 格式轉(zhuǎn) pb 格式
????????一般情況下斥杜,為了便于從斷點(diǎn)之處繼續(xù)訓(xùn)練,模型通常保存為 .ckpt 格式沥匈,而一旦對(duì)訓(xùn)練結(jié)果很滿(mǎn)意之后則可能需要將 .ckpt 格式轉(zhuǎn)化為 .pb 格式蔗喂。轉(zhuǎn)化方法很簡(jiǎn)單,只需要綜合前面的一高帖、二兩步即可:
from tensorflow.python.framework import graph_util
with tf.Session() as sess:
# Load .ckpt file
saver = tf.train.import_meta_graph('./xxx/xxx.ckpt.meta')
saver.restore(sess, './xxx/xxx.ckpt')
# Save as .pb file
graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(
sess,
graph_def,
['prediction'] <-- 輸出節(jié)點(diǎn)名缰儿,以實(shí)際情況為準(zhǔn)
)
with tf.gfile.GFile('./xxx/xxx.pb', 'wb') as fid:
serialized_graph = output_graph_def.SerializeToString()
fid.write(serialized_graph)
????????預(yù)告:下一篇文章將簡(jiǎn)單介紹 tensorflow.contrib.slim
的應(yīng)用,敬請(qǐng)關(guān)注散址!