So you want to create a new model!怀大!
在本節(jié)中砍的,我們將討論用于定義檢測模型的一些抽象拭嫁。如果您想定義一個新的模型體系結(jié)構(gòu)以進行檢測并在Tensorflow Detection API中使用它,那么本節(jié)還應(yīng)該作為需要編輯以使新模型正常工作的文件的高級指南伶唯。
DetectionModels(object_detection/core/model.py
)
為了使用我們提供的二進制文件進行訓(xùn)練骡送,評估和導(dǎo)出冈敛,Tensorflow Object Detection API下的所有模型(Faser RCNN待笑,Mask RCNN,SSD等)都使用DetectionModel
接口(請參閱完整定義object_detection/core/model.py
)。DetectionModel實現(xiàn)5個功能:
-
preprocess
:在圖片輸入到檢測器之前抓谴,需要對圖片進行必要的預(yù)處理(例如暮蹂,縮放/移位/整形)。 -
predict
:生成可以傳遞給損失函數(shù)或后處理函數(shù)(postprocess)的“原始圖片”預(yù)測張量癌压。 -
postprocess
:將預(yù)測(predict)輸出張量轉(zhuǎn)換為最終檢測的圖片椎侠。 -
loss
:針對提供的真是標簽(ground_truth)計算標量損失張量。 -
restore
:將檢查點加載到Tensorflow圖中措拇。
給定DetectionModel
訓(xùn)練時間我纪,我們通過以下函數(shù)序列傳遞每個圖像批次,以計算可通過SGD優(yōu)化的損失:
inputs (images tensor) -> preprocess -> predict -> loss -> outputs (loss tensor)
在eval時間,我們通過以下函數(shù)序列傳遞每個圖像批次以生成一組檢測:
inputs (images tensor) -> preprocess -> predict -> postprocess -> outputs (boxes tensor, scores tensor, classes tensor, num_detections tensor)
一些規(guī)定:
-
DetectionModel
不應(yīng)該對輸入大小或?qū)捀弑茸鋈魏渭僭O(shè)(也就是可以對任意圖片進行檢測) - 它們負責(zé)進行必要的調(diào)整大小/重新整形(參見preprocess
函數(shù)的注釋 )浅悉。 - 輸出類始終是在
[0, num_classes)
整數(shù)范圍內(nèi)的數(shù)趟据,沒有預(yù)先假設(shè)背景類別。 - 檢測到的框?qū)⒈唤忉尀?
[y_min, x_min, y_max, x_max]
格式化并相對于圖像窗口標準化术健。 - 我們沒有具體假設(shè)對分數(shù)的任何概率解釋 - 僅僅進行了相對排序汹碱。因此,后處理功能的實現(xiàn)可以自由地輸出對數(shù)荞估,概率咳促,校準概率或其他任何東西。
定義新的Faster R-CNN or SSD Feature Extractor
在大多數(shù)情況下勘伺,不會從頭寫DetectionModel
- 一般是創(chuàng)建一個新的功能提取器跪腹,供其中一個SSD或Faster R-CNN 的meta-architectures.模型使用。(meta-architectures是DetectionModel
子的類)飞醉。
注意:為了使下面的討論有意義冲茸,建議首先熟悉Faster R-CNN 論文。
如果使用一種全新的網(wǎng)絡(luò)架構(gòu)(比如說缅帘,“InceptionV100”)進行分類轴术,并希望了解InceptionV100如何作為檢測的特征提取器(例如,使用Faster R-CNN)钦无。
要使用InceptionV100逗栽,我們必須定義一個新的 FasterRCNNFeatureExtractor
并將其FasterRCNNMetaArch
作為輸入傳遞給我們的構(gòu)造函數(shù)。
在object_detection/meta_architectures/faster_rcnn_meta_arch.py
失暂。分別定義了FasterRCNNFeatureExtractor
和FasterRCNNMetaArch
祭陷。
FasterRCNNFeatureExtractor
必須定義的幾個功能:
-
preprocess
:在輸入圖像上運行檢測器之前,運行對輸入值進行的任何預(yù)處理趣席。 -
_extract_proposal_features
:提取第一階段區(qū)域提議網(wǎng)絡(luò)(RPN)功能。 -
_extract_box_classifier_features
:提取第二階段Box分類器功能醇蝴。 -
restore_from_classification_checkpoint_fn
:將檢查點加載到Tensorflow圖中宣肚。
使用object_detection/models/faster_rcnn_resnet_v1_feature_extractor.py
舉一個例子。:
- 使用Slim Resnet-101分類檢查點的權(quán)重來初始化此特征提取器的權(quán)重 悠栓,在此檢查點模型內(nèi)部對圖像進行了預(yù)處理霉涨,通過從每個輸入圖像中減去通道平均值。因此惭适,需要實現(xiàn)預(yù)處理函數(shù)來重現(xiàn)相同的通道平均減法行為笙瑟。
- 在slim中定義的“完整”resnet分類網(wǎng)絡(luò)被分成兩部分 - 除last “resnet block”之外的所有部分都被傳入到
_extract_proposal_features
函數(shù)中,last “resnet block”傳入到_extract_box_classifier_features function
函數(shù)中癞志。一般情況下往枷,可能需要進行一些實驗來確定最佳層,以便將特征提取器“切割”為這兩個部分,以實現(xiàn)FasterRCNN错洁。
配置自己的模型參數(shù)
假設(shè)feature extractor不需要標準配置秉宿,理想情況下,希望能夠簡單地更改配置中的“feature_extractor.type”字段以指向新的功能提取器屯碴。為了讓我們的API知道如何理解這種新類型描睦,您首先必須使用模型構(gòu)建器(object_detection/builders/model_builder.py
)編寫新的feature extractor,其作用是從配置原型創(chuàng)建模型导而。
創(chuàng)建很簡單---只需添加一個指針忱叭,該指針指向您在object_detection/builders/model_builder.py
文件頂部的一個SSD或FasterRCNN特征提取器類映射中定義的新的Feature Extractor類 。建議添加一個測試今艺,object_detection/builders/model_builder_test.py
以確保解析新的proto將按預(yù)期工作韵丑。(在model_builder.py有個字典把自己的CNN模型添加進去就可以了)
把新模型做的更加性感一點
創(chuàng)建好模型之后,就可以使用新的模型洼滚!最終提示:
- 要節(jié)省調(diào)試時間埂息,請首先嘗試在本地運行配置文件(包括培訓(xùn)和評估)。
- 學(xué)習(xí)一定的學(xué)習(xí)率遥巴,以確定哪種學(xué)習(xí)率最適合新的模型千康。
- 一個小但通常很重要的細節(jié):可能會發(fā)現(xiàn)有必要禁用BN訓(xùn)練(即,從分類檢查點加載批處理規(guī)范參數(shù)铲掐,但在梯度下降期間不要更新它們)拾弃。