我們已經(jīng)訓(xùn)練出了rpn網(wǎng)絡(luò)凤跑,下面利用訓(xùn)練好的rpn網(wǎng)絡(luò)來(lái)生成proposals乍迄。
下面來(lái)看一下rpn_generate函數(shù):
首先設(shè)置參數(shù):
然后得到一個(gè)pascal_voc類:imdb = get_imdb(imdb_name)
加載訓(xùn)練的rpn網(wǎng)絡(luò):rpn_net = caffe.Net(rpn_test_prototxt, rpn_model_path, caffe.TEST)
然后得到生成的proposals(候選區(qū)域苗傅,最多2000個(gè)):rpn_proposals = imdb_proposals(rpn_net, imdb)休吠,得到的rpn_proposals 是一個(gè)列表语稠,列表中的每個(gè)元素是每個(gè)圖片的rpn_proposals,酷誓,而且rpn_proposals是一個(gè)len(keep)行4列的矩陣,其中l(wèi)en(keep)的最大值為2000慈参。
好了呛牲,下面看一下imdb_proposals函數(shù)的具體結(jié)構(gòu):
這里可以看到imdb_boxes是一個(gè)列表,列表中元素初始化為空驮配。imdb_boxes也是返回值娘扩。在這個(gè)函數(shù)中,首先用cv2.imread讀取圖片數(shù)據(jù)壮锻,然后用im_proposals函數(shù)來(lái)得到proposals和對(duì)應(yīng)的前景得分琐旁。
具體來(lái)看一下im_proposals函數(shù):
首先獲取網(wǎng)絡(luò)的輸入數(shù)據(jù),_get_image_blob函數(shù)可以將imread讀取的圖片數(shù)據(jù)轉(zhuǎn)化成blob需要的格式猜绣,這個(gè)函數(shù)返回兩個(gè)值:blob(4維矩陣灰殴,當(dāng)然這里的batch=1,通道數(shù):3)和?im_info(1行3列的矩陣:[縮放后的圖片高度掰邢、縮放后的圖片寬度牺陶、縮放比例])伟阔。
然后把這兩個(gè)變量輸入到net網(wǎng)絡(luò),net網(wǎng)絡(luò)是什么呢掰伸,從rpn_generate函數(shù)中皱炉,可以看到:net對(duì)應(yīng)的prototxt文件為:rpn_test_prototxt,我們進(jìn)入到rpn_test_prototxt文件里面狮鸭,可以看到這個(gè)網(wǎng)絡(luò)的結(jié)果和我們前面訓(xùn)練的rpn網(wǎng)絡(luò)基本上是一致的合搅,只是把最后的loss部分改成了proposal layer,我們只需要看最后一層proposal layer的forward的結(jié)果就可以了歧蕉。
layer {
name: 'proposal'
type: 'Python'
bottom: 'rpn_cls_prob_reshape'
bottom: 'rpn_bbox_pred'
bottom: 'im_info'
top: 'rois'? ? # len(keep)行5列的矩陣灾部,第1列元素:0,其余4列:proposals的左上角和右下角的坐標(biāo)
top: 'scores'? # len(keep)行1列的矩陣惯退,矩陣元素為:前景得分
python_param {
module: 'rpn.proposal_layer'
layer: 'ProposalLayer'
param_str: "'feat_stride': 16"
}
}
先看一下輸入:
bottom: 'rpn_cls_prob_reshape' ?:batch * 18 * height * width ?( 1 * 18 * 14 *14)
bottom: 'rpn_bbox_pred' ?:batch * 36* height * width ?( 1 * 36* 14 *14)
bottom: 'im_info' ?:1行3列的矩陣:[縮放后的圖片高度赌髓、縮放后的圖片寬度、縮放比例]
下面蒸痹,我們到對(duì)應(yīng)的rpn.proposal_layer里面去看一下:
首先春弥,是設(shè)置一些參數(shù),然后得到一些基本數(shù)據(jù):
scores:前景得分 ?( 1 * 9 * 14 * 14)
bbox_deltas:anchor的偏移量叠荠,即:tx, ty, tw, th ???( 1 * 36* 14 *14 )
im_info:[縮放后的圖片高度匿沛、縮放后的圖片寬度、縮放比例]
height, width = scores.shape[-2:] 榛鼎,得到feature的高度和寬度逃呼,然后利用height, width來(lái)生成所有的anchor,之后對(duì)anchor進(jìn)行reshape:
anchors = anchors.reshape((K * A,4)) ? ? ? ? ? # 生成所有的anchor者娱,K*A個(gè)抡笼,K=height *?width ,A=9
然后黄鳍,把bbox_deltas 和?scores 都reshape到同樣的形式:
bbox_deltas = bbox_deltas.transpose((0,2,3,1)).reshape((-1,4))
scores = scores.transpose((0,2,3,1)).reshape((-1,1))
好了推姻,到重點(diǎn)部分了,由anchor來(lái)生成proposals:
proposals = bbox_transform_inv(anchors, bbox_deltas) :bbox_transform_inv函數(shù)很簡(jiǎn)單框沟,就是根據(jù)anchors和anchor的偏移量(tx, ty, tw, th)來(lái)生成?proposals 藏古。
接下來(lái)對(duì)proposals 進(jìn)行一系列的過(guò)濾操作,過(guò)濾之后進(jìn)行NMS操作:具體的流程是忍燥,先對(duì)前景得分scores進(jìn)行從大到小的排序拧晕,然后把排序的結(jié)果做NMS:keep = nms(np.hstack((proposals, scores)), nms_thresh),這里得到的keep是一個(gè)列表梅垄,列表的元素是進(jìn)行NMS操作之后厂捞,剩余的proposals的索引,然后根據(jù)keep索引,取出剩余的proposals 和 對(duì)應(yīng)的scores: ? ? ? ? ? ?proposals = proposals[keep, :] ? ? ? ? ? ? ? ? ? ?scores = scores[keep]靡馁。
最后欲鹏,把proposals增加1列,最前面增加1列的0臭墨,然后把proposals 和??scores 輸出貌虾。
top: 'rois'? ? # len(keep)行5列的矩陣,第1列元素:0裙犹,其余4列:proposals的左上角和右下角的坐標(biāo)
top: 'scores'? # len(keep)行1列的矩陣,矩陣元素為:前景得分
到這里衔憨,已經(jīng)得到了net的forward的結(jié)果叶圃,下面回到im_proposals函數(shù)。
將boxes = blobs_out['rois'][:,1:].copy() / scale践图,得到的boxes是把proposals對(duì)應(yīng)的原圖的結(jié)果(proposals是在縮放后的圖片中得到的)掺冠。然后輸出boxes和scores。
接著返回到imdb_proposals函數(shù):
imdb_boxes[i], scores = im_proposals(net, im)码党,從這里可以看出imdb_boxes[i]就是我們得到的box( len(keep)行4列的矩陣 )德崭,而且imdb_boxes[i]矩陣,就是im_proposals中的box揖盘。最后眉厨,返回imdb_boxes(列表,imdb中所有圖片的proposals)兽狭。
下面憾股,返回rpn_generate函數(shù),我們最終得到了imdb_name中所有圖片的rpn_proposals(列表)箕慧。
最后把rpn_proposals列表寫(xiě)入到rpn_proposals_path文件中服球。然后把rpn_proposals_path文件以字典的形式推入到子進(jìn)程的隊(duì)列中:queue.put({'proposal_path': rpn_proposals_path})。
經(jīng)過(guò)以上步驟颠焦,我們就創(chuàng)建了一個(gè)子進(jìn)程p:p = mp.Process(target=rpn_generate,kwargs=mp_kwargs)
然后啟動(dòng)子進(jìn)程:p.start()
從進(jìn)程中拿出rpn_proposals_path文件斩熊,得到rpn_proposals:rpn_stage1_out['proposal_path'] = mp_queue.get()['proposal_path']
等待進(jìn)程結(jié)束:p.join()