技巧:使用freeze_graph將SavedModel生成的模型合并成一個pb文件
在訓(xùn)練mnist模型的時候,增加一個參數(shù)昨寞,可以在訓(xùn)練完畢后導(dǎo)出SavedModel
python mnist.py --export_dir /tmp/mnist_saved_model
訓(xùn)練完成后夜涕,會以以下文件結(jié)構(gòu)在/tmp/mnist_saved_model生成模型文件
── 1530159051
├── saved_model.pb
└── variables
├── variables.data-00000-of-00001
└── variables.index
saved_model.pb
保存的是計算圖信息
variables
保存的是計算中的變量信息
我的目是想要在Android上運行測試這個模型拴事,但是從官方給的Demo中可以看到募逞,所加載的模型都是單個文件越败,暫時也沒看到加載SavedModel的接口
為了方便Android平臺使用,用freeze_graph
腳本將SavedModel
合并
查看這個腳本的源碼拆祈,在最開頭可以看到
Converts checkpoint variables into Const ops in a standalone GraphDef file.
This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
variable values stored in a checkpoint file, and output a GraphDef with all of
the variable ops converted into const ops containing the values of the
variables.
It's useful to do this when we need to load a single file in C++, especially in
environments like mobile or embedded where we may not have access to the
RestoreTensor ops and file loading calls that they rely on.
這個腳本就是用來將變量轉(zhuǎn)換為常量合并到到計算圖中的
具體怎么使用呢恨闪?
腳本中給出的例子是這樣的:
bazel build tensorflow/python/tools:freeze_graph && \
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=some_graph_def.pb \
--input_checkpoint=model.ckpt-8361242 \
--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax
因為我電腦是完整安裝了tensorflow
的,所以可以直接使用freeze_graph
而不用通過bazel
去創(chuàng)建freeze_graph
freeze_graph --input_graph=saved_model.pb --input_checkpoint=variables/variables --output_graph=merge_graph.pb --output_node_names=softmax
執(zhí)行放坏,然后掛了
from ._conv import register_converters as _register_converters
Traceback (most recent call last):
File "/home/xxx/anaconda2/bin/freeze_graph", line 11, in <module>
sys.exit(run_main())
File "/home/xxx/anaconda2/lib/python2.7/site-packages/tensorflow/python/tools/freeze_graph.py", line 379, in run_main
app.run(main=my_main, argv=[sys.argv[0]] + unparsed)
File "/home/xxx/anaconda2/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 125, in run
_sys.exit(main(argv))
File "/home/xxx/anaconda2/lib/python2.7/site-packages/tensorflow/python/tools/freeze_graph.py", line 378, in <lambda>
my_main = lambda unused_args: main(unused_args, flags)
File "/home/xxx/anaconda2/lib/python2.7/site-packages/tensorflow/python/tools/freeze_graph.py", line 272, in main
flags.saved_model_tags, checkpoint_version)
File "/home/xxx/anaconda2/lib/python2.7/site-packages/tensorflow/python/tools/freeze_graph.py", line 231, in freeze_graph
input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
File "/home/xxx/anaconda2/lib/python2.7/site-packages/tensorflow/python/tools/freeze_graph.py", line 174, in _parse_input_graph_proto
text_format.Merge(f.read(), input_graph_def)
File "/home/xxx/anaconda2/lib/python2.7/site-packages/google/protobuf/text_format.py", line 536, in Merge
descriptor_pool=descriptor_pool)
File "/home/xxx/anaconda2/lib/python2.7/site-packages/google/protobuf/text_format.py", line 590, in MergeLines
return parser.MergeLines(lines, message)
File "/home/xxx/anaconda2/lib/python2.7/site-packages/google/protobuf/text_format.py", line 623, in MergeLines
self._ParseOrMerge(lines, message)
File "/home/xxx/anaconda2/lib/python2.7/site-packages/google/protobuf/text_format.py", line 638, in _ParseOrMerge
self._MergeField(tokenizer, message)
File "/home/xxx/anaconda2/lib/python2.7/site-packages/google/protobuf/text_format.py", line 706, in _MergeField
name = tokenizer.ConsumeIdentifierOrNumber()
File "/home/xxx/anaconda2/lib/python2.7/site-packages/google/protobuf/text_format.py", line 1166, in ConsumeIdentifierOrNumber
raise self.ParseError('Expected identifier or number, got %s.' % result)
google.protobuf.text_format.ParseError: 1:1 : Expected identifier or number, got.
什么鬼錯咙咽,Expected identifier or number
字面意思應(yīng)該是傳入的文件少了點什么東西
再看看freeze_graph
腳本的源碼
在run_main()
函數(shù)里看到了一個參數(shù)
parser.add_argument(
"--input_saved_model_dir",
type=str,
default="",
help="Path to the dir with TensorFlow \'SavedModel\' file and variables.")
原來是用錯參數(shù)了,噗...
修改一下命令
freeze_graph \
--input_saved_model_dir=1530159051 \
--output_node_names=Softmax,ArgMax \
--output_graph=merge1_graph.pb
再執(zhí)行淤年,屁都沒冒一個钧敞,說明執(zhí)行成功了,沒報錯麸粮,哈哈
看下指定的輸出文件有沒有生成
ww3:~/source/TFmodels_out/mnist_saved_model$ ls -l
總用量 12804
drwxr-xr-x 3 ww3 ww3 4096 6月 28 12:10 1530159051
-rw-rw-r-- 1 ww3 ww3 13105521 6月 28 17:58 merge1_graph.pb
嗯...具體這個文件能不能直接用還有待研究