最近項(xiàng)目里需要一個小型的目標(biāo)檢測模型,SSD怪得、YOLO等一通模型調(diào)參試下來咱枉,直接調(diào)用TensorFlow object detect API居然效果最好,大廠的產(chǎn)品不得不服啊徒恋。使用mobilenet ssd v2模型蚕断,配置文件也未修改參數(shù),訓(xùn)練后的模型不光檢測效果不錯入挣,在CPU上的運(yùn)行時間也在70ms左右亿乳。之后將模型移植到安卓手機(jī)上(魅族MX4,老的不是一點(diǎn)點(diǎn)),卡頓明顯葛假;改用同事的華為障陶,在麒麟960上略微流暢了一些,但仍然不能達(dá)到實(shí)時檢測聊训。而且訓(xùn)練得到的pb模型居然有19M抱究,實(shí)在太大了,于是又探索了一波模型的壓縮和量化带斑。
模型壓縮
說到模型壓縮媳维,最簡單粗暴的方法當(dāng)然是減少卷積層數(shù)。在使用Tensorflow的API之前遏暴,我訓(xùn)練過一個SSD模型侄刽,檢測效果不錯,但耗時接近1s朋凉。為了提高檢測速度我果斷開始減少卷積層數(shù)州丹,并做了不同層數(shù)的對比試驗(yàn)。結(jié)果和原始的VGG16骨干相比杂彭,要么檢測效果相近墓毒,耗時也沒少多少,要么耗時大減亲怠,但漏檢率飆升所计。也就是在這個情況下,我轉(zhuǎn)投了mobilenet網(wǎng)絡(luò)团秽。
所以這次面臨模型壓縮時主胧, 我沒有再嘗試這個選項(xiàng)(當(dāng)然也有配置文件不支持刪減層數(shù),要刪就要去改slim里的源碼這個原因习勤。我一個前同事是中科院計(jì)算機(jī)博士踪栋,他的格言就是覺得源碼不好就別調(diào)用,自己寫图毕;要調(diào)用就盡量避免改源碼夷都,因?yàn)槟憧隙]有源碼寫得好)。這樣看下來予颤,就只能在配置文件的范圍內(nèi)自由發(fā)揮了囤官。
修改配置文件
首先,附上Tensorflow object detection API中支持的各大模型的配置文件地址:
models/research/object_detection/samples/configs at master · tensorflow/models · GitHub
這里面關(guān)于mobilenet_ssd_v2的有好幾個:
我使用的是最經(jīng)典的基于COCO數(shù)據(jù)集訓(xùn)練的配置文件蛤虐,也就是第一個党饮。圖里的最后一個也是基于COCO數(shù)據(jù)集的,不過是有量化的模型笆焰,這個文件我在后面也有用到劫谅。
打開配置文件,里面主要分成model、train和eval三塊捏检。在調(diào)用API訓(xùn)練自己的數(shù)據(jù)時荞驴,train和eval的數(shù)據(jù)當(dāng)然是要修改的:
回到model部分,在feature_extractor那里贯城,有一個depth_multiplier熊楼,這個參數(shù)作為一個因子與網(wǎng)絡(luò)中各層的channel數(shù)相乘,換言之能犯,depth_multiplier越小鲫骗,網(wǎng)絡(luò)中feature map的channel數(shù)越少,模型參數(shù)自然也就少了很多踩晶。depth_multiplier默認(rèn)為1执泰,在我的實(shí)驗(yàn)里改成了0.25,試就試一把大的渡蜻。
訓(xùn)練模型
之前depth_multiplier為1時术吝, 我訓(xùn)練是加載了預(yù)訓(xùn)練模型的,模型地址:
models/detection_model_zoo.md at master · tensorflow/models · GitHub
從圖中可以看出茸苇,mobilenet_v1的預(yù)訓(xùn)練模型中有一種0.75_depth的版本排苍,這就是depth_multiplier取0.75時在COCO數(shù)據(jù)集上訓(xùn)練出來的模型。對于mobilenet_v2学密,只提供了非量化版和量化版(個人覺得應(yīng)該0.25淘衙、0.5、0.75這幾個常用檔都提供一個腻暮,難道是官方不建議壓縮太多嗎彤守。。西壮。)
由于沒有對應(yīng)的預(yù)訓(xùn)練模型遗增,所以可以選擇加載或者不加載模型。
加載模型的話款青,開始訓(xùn)練后命令行會打印一大堆XXX?is available in checkpoint, but has an incompatible shape with model variable. This variable will not be initialized from the checkpoint. 不過這并不影響訓(xùn)練,忽略就可以了霍狰。
不加載的話抡草,就將配置文件里fine_tune_checkpoint的那兩行注釋掉。
進(jìn)入到object detection目錄蔗坯,運(yùn)行python object_detection/model_main.py? --pipeline_config_path=xxxxxxx/ssd_mobilenet_v2_coco.config? --model_dir=xxxxxxxx即可
PS:訓(xùn)練過程中是不會打印訓(xùn)練信息的康震,看命令行會以為電腦卡住了。宾濒。腿短。直到eval才會打印出信息
PPS:可以通過TensorBoard來監(jiān)聽訓(xùn)練過程,判斷訓(xùn)練是在正常進(jìn)行還是電腦真的卡住了(這種情況可能是因?yàn)閎atch size和輸入圖片大小太大。默認(rèn)是24和300*300橘忱,但也都可以改)
模型導(dǎo)出
訓(xùn)練完成之后赴魁,還是在object detection目錄下,運(yùn)行python export_inference_graph.py钝诚,必要的參數(shù)分別是輸入的ckpt的文件地址颖御,輸出的pb文件的文件夾以及配置文件地址。
在深度壓縮至0.25倍之后凝颇, 我的pb模型大小僅為2.2M潘拱,效果卓群。當(dāng)然網(wǎng)絡(luò)的縮減會帶來精度的損失拧略,我的AR和AP分別降了2個點(diǎn)和3個點(diǎn)芦岂。
模型移植
生成tflite模型
Tensorflow object detection API訓(xùn)練出的模型,講道理從ckpt轉(zhuǎn)成tflite只需要兩步:
第一步垫蛆,將ckpt轉(zhuǎn)成pb文件禽最,這次使用的是python?export_tflite_ssd_graph.py,操作難度不大月褥,會得到tflite_graph.pb和tflite_graph.pbtxt兩個文件弛随;
第二步,將pb轉(zhuǎn)為tflite文件宁赤,我搜到的方法大都是使用bazel編譯tensorflow/contirb/lite/toco下面的toca文件舀透,但我反復(fù)嘗試,報(bào)了多種錯誤决左,依舊沒有成功愕够。。佛猛。最后我在stackoverflow上搜到了一位小哥的回復(fù)惑芭,進(jìn)入tensorflow/contrib/lite/python目錄,運(yùn)行python tflite_convert.py继找,參數(shù)設(shè)置為
--graph_def_file=XXX/tflite_graph.pb 上一步生成的pb文件地址
--output_file=XXX/xxx.tflite 輸出的tflite文件地址
--input_arrays=normalized_input_image_tensor 輸入輸出的數(shù)組名稱對于mobilenet ssd是固定的遂跟,不用改?
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'
?--input_shape=1,XXX,XXX,3 輸入的圖片大小,需要與配置文件中一致
--allow_custom_ops
驗(yàn)證tflite模型
在將tflite模型放進(jìn)手機(jī)之前婴渡,我在python里加載tflite模型測試了一次幻锁,流程類似加載pb模型
第一步,導(dǎo)入模型
interpreter = tf.contrib.lite.Interpreter(model_path="compress_export/detect.tflite")
interpreter.allocate_tensors()
第二步边臼,獲得輸入和輸出的tensor
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
第三步哄尔,讀取輸入圖像,feed給輸入tensor
可以采用PIL或cv2將圖像讀入柠并,轉(zhuǎn)為numpy數(shù)組岭接,然后賦值給input_data
input_data = np.array(XXX)
interpreter.set_tensor(input_details[0]['index'], input_data)
第四步富拗,運(yùn)行模型
interpreter.invoke()
第五步, 獲得輸出
參考輸入tensor的表示方法鸣戴,目標(biāo)檢測的輸出有4個啃沪,具體的值可以通過output_details[0]['index']、output_details[1]['index']葵擎、output_details[2]['index']谅阿、output_details[3]['index']獲得
這里有一個我踩到的坑,驗(yàn)證tflite模型時酬滤,我采用了和加載pb模型完全相同的圖片預(yù)處理步驟签餐,輸出的結(jié)果完全不同。幾番檢查之后盯串,發(fā)現(xiàn)問題出在模型轉(zhuǎn)換時氯檐。運(yùn)行python tflite_convert.py時,輸入數(shù)組的名稱為normalized_input_image_tensor体捏,而我訓(xùn)練時采用的是未經(jīng)normalized的數(shù)組冠摄。所以在模型轉(zhuǎn)換時,tensorflow內(nèi)置了對input進(jìn)行normalized的步驟几缭。因此在調(diào)用tflite模型時河泳,同樣需要在圖像預(yù)處理中加入這一步。 nomlized的方法為除以128.0再減去1年栓,保證輸入的值在[-1,1)范圍內(nèi)拆挥。
參考
https://blog.csdn.net/qq_26535271/article/details/84930868
Tensorflow Convert pb file to TFLITE using python - Stack Overflow