參考鏈接:http://www.reibang.com/p/091415b114e2
https://cloud.tencent.com/developer/ask/188650
由于arm nn官方提供的mnist-tf例程中提供的模型類型是prototxt或者pb文件炸庞,所以這里需要把tensorflow保存的ckpt文件轉(zhuǎn)換成pb文件
tensorflow訓(xùn)練生成的ckpt文件包含4個,分別是
1. checkpoint文件粗梭,記錄了最新的檢查點文件
2. model.data文件,是saver.save(sess)保存的結(jié)果帮哈,記錄了所有變量的值
3. model.index文件镶骗,暫不明確,待查愉择》朔玻恢復(fù)模型不必須用到
4. model.meta文件膊畴,保存了計算圖的結(jié)構(gòu),沒有變量的值
轉(zhuǎn)換方法
使用freeze_graph(見第一個參考鏈接病游,經(jīng)過測試發(fā)現(xiàn)對于很小的模型lenet5可以成功唇跨,但是對于較大的模型稠通,比如這里用到的一個400MB左右的網(wǎng)絡(luò),經(jīng)過測試买猖,會把16GB的內(nèi)存消耗干凈改橘,轉(zhuǎn)換失敗=_=)
使用convert_variables_to_constants
import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
sess = tf.Session()
saver = tf.train.import_meta_graph("meta文件目錄")
saver.restore(sess, tf.train.latest_checkpoint("checkpoint文件所在目錄"))
graph = tf.get_default_graph()
output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['輸出tensor名字'])
with tf.gfile.FastGFile('pb文件保存目錄', mode='wb') as f:
f.write(output_graph_def.SerializeToString())