In the previous chapter, Training an Object Detector,
we exported the trained model as a pb file and used it for TensorFlow object detection on the PC. In this chapter, we convert the trained model so that it can be used by TensorFlow Lite on an Android phone.
Generate the pb and pbtxt files
#~/tensorflow/models/research/object_detection$
python export_tflite_ssd_graph.py \
--pipeline_config_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_config/ssd_mobilenet_v1_raccoon.config \
--trained_checkpoint_prefix=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/model.ckpt-62236 \
--output_directory=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train \
--add_postprocessing_op=true
The output is:
(base) jiadongfeng@jiadongfeng:~/tensorflow/dataset/raccoon_dataset/jdf_train$ ls | grep tflite_graph
tflite_graph.pb
tflite_graph.pbtxt
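Before running the conversion, it can help to confirm the input and output tensor names inside tflite_graph.pb, since the toco command below refers to them explicitly. A minimal sketch, assuming TF 1.x and the paths used above:

import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile('/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/tflite_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

# Print the placeholder input and the post-processing custom op appended by
# export_tflite_ssd_graph.py.
for node in graph_def.node:
    if node.op == 'Placeholder' or node.name.startswith('TFLite_Detection_PostProcess'):
        print(node.name, node.op)
# Expected: normalized_input_image_tensor (Placeholder) and TFLite_Detection_PostProcess.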
Convert the pb file to tflite
Run the following command:
#~/anaconda2/lib/python2.7/site-packages/tensorflow/lite
toco \
--graph_def_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/tflite_graph.pb \
--output_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
--inference_type=QUANTIZED_UINT8 \
--mean_values=128 \
--std_dev_values=128 \
--change_concat_input_ranges=false \
--allow_custom_ops
The following error is reported:
F tensorflow/lite/toco/tooling_util.cc:1709] Array FeatureExtractor/MobilenetV1/MobilenetV1/Conv2d_0/Relu6, which is an input to the DepthwiseConv operator producing the output array FeatureExtractor/MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6, is lacking min/max data, which is necessary for quantization. If accuracy matters, either target a non-quantized output format, or run quantized training with your model from a floating point checkpoint to change the input graph to contain min/max information. If you don't care about accuracy, you can pass --default_ranges_min= and --default_ranges_max= for easy experimentation.
Aborted (core dumped)
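The root cause is that the checkpoint did not come from quantization-aware training, so the frozen graph contains no FakeQuant nodes and therefore no min/max range information. A quick check, as a sketch under the same assumptions as above:

import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.GFile('/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/tflite_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

# Count FakeQuant ops; an empty list is why toco cannot quantize this graph directly.
fake_quant = [n.name for n in graph_def.node if 'FakeQuant' in n.op]
print(len(fake_quant))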
Solution 1:
Use a non-quantized conversion: change --inference_type=QUANTIZED_UINT8 to --inference_type=FLOAT and add:
--default_ranges_min
--default_ranges_max
In a quantized model the weights are stored as 1-byte uint8 values, so the model is roughly a quarter of the size of the float version; how to generate a quantized model file is covered later.
Finally, run the following command:
toco \
--graph_def_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/tflite_graph.pb \
--output_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
--inference_type=FLOAT \
--mean_values=128 \
--std_dev_values=128 \
--change_concat_input_ranges=false \
--allow_custom_ops \
--default_ranges_min=0 \
--default_ranges_max=6
This generates the detect.tflite file:
(base) jiadongfeng@jiadongfeng:~/tensorflow/dataset/raccoon_dataset/jdf_train$ ls | grep detect
detect.tflite
The generated file is about 22 MB, while the stock tflite model used in Camera Integration for Object Detection, which detects 80 classes (quantized), is only about 3 MB.
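For reference, the same float conversion can also be driven from the Python API instead of the toco command line; a minimal sketch, assuming TF 1.x and the same paths:

import tensorflow as tf  # TF 1.x

base = '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train'
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file=base + '/tflite_graph.pb',
    input_arrays=['normalized_input_image_tensor'],
    output_arrays=['TFLite_Detection_PostProcess',
                   'TFLite_Detection_PostProcess:1',
                   'TFLite_Detection_PostProcess:2',
                   'TFLite_Detection_PostProcess:3'],
    input_shapes={'normalized_input_image_tensor': [1, 300, 300, 3]})
converter.allow_custom_ops = True  # TFLite_Detection_PostProcess is a custom op
tflite_model = converter.convert()
with open(base + '/detect.tflite', 'wb') as f:
    f.write(tflite_model)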
Solution 2:
Use a quantized conversion: set inference_type and input_data_type to QUANTIZED_UINT8;
the default_ranges_min and default_ranges_max parameters also need to be set.
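A note on how these flags are interpreted (standard toco semantics, shown here only as an illustration): --mean_values/--std_dev_values describe the input quantization as real_value = (uint8_value - mean) / std, and --default_ranges_min/--default_ranges_max supply the missing activation ranges (0 to 6 matches the output range of Relu6):

# Mapping implied by --mean_values=128 and --std_dev_values=128:
# uint8 inputs in [0, 255] cover roughly [-1, 1].
def dequantize(q, mean=128.0, std=128.0):
    return (q - mean) / std

print(dequantize(0), dequantize(128), dequantize(255))  # -1.0 0.0 0.9921875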
toco \
--graph_def_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/tflite_graph.pb \
--output_file=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
--inference_type=QUANTIZED_UINT8 \
--input_data_type=QUANTIZED_UINT8 \
--mean_values=128 \
--std_dev_values=128 \
--change_concat_input_ranges=false \
--allow_custom_ops \
--default_ranges_min=0 \
--default_ranges_max=6
The tflite file produced this way is about a quarter of the size of the non-quantized one, with a slight drop in accuracy.
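Before pushing detect.tflite to the phone, it can be sanity-checked with the TF Lite Python interpreter; a minimal sketch, assuming the same path as above:

import tensorflow as tf

interpreter = tf.lite.Interpreter(
    model_path='/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/detect.tflite')

# For the quantized model from Solution 2 the input dtype should be uint8 with
# shape [1, 300, 300, 3]; for the float model from Solution 1 it is float32.
for detail in interpreter.get_input_details():
    print(detail['name'], detail['shape'], detail['dtype'])
for detail in interpreter.get_output_details():
    print(detail['name'], detail['shape'], detail['dtype'])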
Integrate the tflite file into the camera app
- First, the camera integration flow follows Camera Integration for Object Detection; here we only need to replace the tflite file and update the corresponding label_map file.
- Then, note that the stock detection model is quantized, while the model from Solution 1 is a float model, so TF_OD_API_IS_QUANTIZED needs to be changed to false:
// Set to false for the float model from Solution 1;
// keep true if you use the quantized model from Solution 2.
private static final boolean TF_OD_API_IS_QUANTIZED = true;

detector =
    TFLiteObjectDetectionAPIModel.create(
        cameraActivity.getAssets(),
        TF_OD_API_MODEL_FILE,
        TF_OD_API_LABELS_FILE,
        TF_OD_API_INPUT_SIZE,
        TF_OD_API_IS_QUANTIZED);
- Finally, the detection results after integration are shown below: