1.寫(xiě)在前面
最近組里有個(gè)項(xiàng)目與目標(biāo)識(shí)別有關(guān)侨嘀,去網(wǎng)上找了一下妨猩,發(fā)現(xiàn)目前SOTA的目標(biāo)識(shí)別算法基本都是one-stage的燕差,比如SSD遭笋、DSSD、RetinaNet徒探、YOLO等瓦呼,但是速度上YOLO是最快的。而且看了下YOLO主頁(yè)测暗,作者的風(fēng)格我很喜歡央串。所以仔細(xì)研究了一下。本文的內(nèi)容基于GluonCV碗啄、OpenCV和YoloV3质和,運(yùn)行平臺(tái)為Ubuntu16.04版本。ps:因?yàn)榻M里采購(gòu)的服務(wù)器還沒(méi)到挫掏,目前只能在我自己筆記本的虛擬機(jī)上跑侦另,而虛擬機(jī)的顯卡是模擬出來(lái)的,無(wú)法安裝CUDA和CUDNN(這個(gè)坑也是我安裝CUDA遇到了各種坑后發(fā)現(xiàn)的),各位有條件的還是使用CUDA+CUDNN環(huán)境褒傅,速度會(huì)快不少弃锐。
2.環(huán)境搭建
2.1 GluonCV
GuonCV是一個(gè)計(jì)算機(jī)視覺(jué)深度學(xué)習(xí)的工具箱,功能非常強(qiáng)大殿托,包含了圖像分類(lèi)霹菊,目標(biāo)識(shí)別,語(yǔ)義分割支竹,實(shí)例分割等旋廷。GluonCV的安裝在他們主頁(yè)上面有介紹,安裝很簡(jiǎn)單礼搁,python2和python3都可以饶碘,但是你的pip版本要大于9.0,同時(shí)還要安裝一個(gè)mxnet框架馒吴。同時(shí)他們主頁(yè)還提供了一些簡(jiǎn)單的demo教你使用扎运,還可以查詢(xún)API的源代碼。
2.2 OpenCV
OpenCV是一個(gè)用于圖像處理饮戳、分析豪治、機(jī)器視覺(jué)方面的開(kāi)源函數(shù)庫(kù). 無(wú)論你是做科學(xué)研究,還是商業(yè)應(yīng)用扯罐,OpenCV都可以作為你理想的工具庫(kù)负拟,因?yàn)椋瑢?duì)于這兩者歹河,它完全是免費(fèi)的掩浙。該庫(kù)采用C及C++語(yǔ)言編寫(xiě),可以在windows, linux, mac OSX系統(tǒng)上面運(yùn)行秸歧。該庫(kù)的所有代碼都經(jīng)過(guò)優(yōu)化涣脚,計(jì)算效率很高,因?yàn)榱让#鼘?zhuān)注于設(shè)計(jì)成為一種用于實(shí)時(shí)系統(tǒng)的開(kāi)源庫(kù)。opencv采用C語(yǔ)言進(jìn)行優(yōu)化矾麻,而且纱耻,在多核機(jī)器上面,其運(yùn)行速度會(huì)更快险耀。它的一個(gè)目標(biāo)是提供友好的機(jī)器視覺(jué)接口函數(shù)弄喘,從而使得復(fù)雜的機(jī)器視覺(jué)產(chǎn)品可以加速面世。該庫(kù)包含了橫跨工業(yè)產(chǎn)品檢測(cè)甩牺、醫(yī)學(xué)圖像處理蘑志、安防、用戶(hù)界面、攝像頭標(biāo)定急但、三維成像澎媒、機(jī)器視覺(jué)等領(lǐng)域的超過(guò)500個(gè)接口函數(shù)。
OpenCV安裝很簡(jiǎn)單波桩,直接pip install opencv-python
即可戒努。你也可以使用源代碼安裝,官網(wǎng)的下載速度很痛苦镐躲,我給個(gè)OpenCV3.4.7版本的鏈接,需要的朋友可以自却⒚怠:
https://pan.baidu.com/s/1Zts9WR7VtH-2L0e9fIaNHw
提取碼:498k
源碼的安裝教程網(wǎng)上很多,我貼一個(gè)別人https://jingyan.baidu.com/article/a3761b2be162951576f9aace.html萤皂,需要安裝cmake工具撒穷,沒(méi)有安裝的直接apt install cmake
就可以了。
2.3 YoloV3
YoloV3在他們主頁(yè)有很詳細(xì)的教程(基于darknet)裆熙,有興趣可以去看下他們的論文端礼,寫(xiě)的很有趣,傳統(tǒng)的識(shí)別方法是當(dāng)做一個(gè)分類(lèi)問(wèn)題弛车,而作者當(dāng)做一個(gè)回歸問(wèn)題來(lái)處理齐媒,同時(shí)并不像傳統(tǒng)算法那樣需要很多滑動(dòng)窗口,他是end to end直接輸出結(jié)果纷跛,這也是他們的名字YOLO(you only look once)的由來(lái)喻括。同時(shí)推薦新手使用darknet,他是一個(gè)很輕量級(jí)的框架贫奠,但是內(nèi)容很多唬血,且易于上手。
3.代碼
代碼主要分為三個(gè)模塊唤崭,utils模塊拷恨,detection模塊和main模塊。
3.1 utils模塊
utils模塊包括data_preset.py谢肾,yolov3.py腕侄,bbox.py等文件
[圖片上傳失敗...(image-4f99d3-1569487002238)]
3.2 detection模塊
detection模塊包括model,mobilefacedetnet.py等文件
[圖片上傳失敗...(image-7077ae-1569487002238)]
3.3 main模塊
main模塊包括cap.py函數(shù)芦疏,其實(shí)就是執(zhí)行函數(shù)冕杠。使用python3 cap.py
執(zhí)行就行。ps:我設(shè)置了一些命令行參數(shù)酸茴,比如--video
選擇本地視頻分预,--camera
選擇攝像頭,--gpu
選擇是否使用GPU薪捍。大家可以使用python3 cap.py -h
查看使用方法,比如
[圖片上傳失敗...(image-4ec4d7-1569487002238)]
cap.py代碼如下:
from mxnet import nd
import gluoncv as gcv
from mxnet.gluon.nn import BatchNorm
from gluoncv.data.transforms import presets
from matplotlib import pyplot as plt
sys.path.append(os.path.abspath(os.path.dirname(__file__)) + os.sep + '../MobileFace_Detection/utils/')
from data_presets import data_trans
sys.path.append(os.path.abspath(os.path.dirname(__file__)) + os.sep + '../MobileFace_Detection/')
from mobilefacedetnet import mobilefacedetnet_v2
sys.path.append(os.path.abspath(os.path.dirname(__file__)) + os.sep + '../MobileFace_Tracking/')
from mobileface_sort_v1 import Sort
def parse_args():
parser = argparse.ArgumentParser(description='Test with YOLO networks.')
parser.add_argument('--model', type=str,
default='../MobileFace_Detection/model/mobilefacedet_v2_gluoncv.params',
help='Pretrained model path.')
parser.add_argument('--video', type=str, default='friends1.mp4',
help='Test video path.')
parser.add_argument('--camera', type=int, default=None,
help='Camera select')
parser.add_argument('--gpus', type=str, default='',
help='Default is cpu , you can specify 1,3 for example with GPUs.')
parser.add_argument('--pretrained', type=str, default='True',
help='Load weights from previously saved parameters.')
parser.add_argument('--thresh', type=float, default=0.5,
help='Threshold of object score when visualize the bboxes.')
parser.add_argument('--sort_max_age', type=int, default=10,
help='Threshold of object score when visualize the bboxes.')
parser.add_argument('--sort_min_hits', type=int, default=3,
help='Threshold of object score when visualize the bboxes.')
parser.add_argument('--output', type=str,
default='./tracking_result/result_friends1_tracking.avi',
help='Output video path and name.')
args = parser.parse_args()
return args
def main():
args = parse_args()
# context list
ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
ctx = [mx.cpu()] if not ctx else ctx
net = mobilefacedetnet_v2(args.model)
net.set_nms(0.45, 200)
net.collect_params().reset_ctx(ctx = ctx)
mot_tracker = Sort(args.sort_max_age, args.sort_min_hits)
img_short = 256
colors = np.random.rand(32, 3) * 255
winName = 'MobileFace for face detection and tracking'
cv2.namedWindow(winName, cv2.WINDOW_NORMAL)
if args.camera == None:
cap = cv2.VideoCapture(args.video)
else:
cap = cv2.VideoCapture(args.camera)
output_video = args.output
# video_writer = cv2.VideoWriter(output_video, cv2.VideoWriter_fourcc('M','J','P','G'), 30, (round(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),round(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))))
video_writer = cv2.VideoWriter(output_video, cv2.VideoWriter_fourcc('M','J','P','G'), 30, (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))))
# while(cap.isOpened()):
while cv2.waitKey(1) < 0:
ret, frame = cap.read()
if not ret:
print("Done processing !!!")
print("Output file is stored as ", output_video)
cv2.waitKey(3000)
break
dets = []
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_nd = nd.array(frame_rgb)
x, img = data_trans(frame_nd, short=img_short)
x = x.as_in_context(ctx[0])
# ids, scores, bboxes = [xx[0].asnumpy() for xx in net(x)]
tic = time.time()
result = net(x)
toc = time.time() - tic
#print('Detection inference time:%fms' % (toc*1000))
ids, scores, bboxes = [xx[0].asnumpy() for xx in result]
h, w, c = frame.shape
scale = float(img_short) / float(min(h, w))
for i, bbox in enumerate(bboxes):
if scores[i]< args.thresh:
continue
xmin, ymin, xmax, ymax = [int(x/scale) for x in bbox]
# result = [xmin, ymin, xmax, ymax, ids[i], scores[i]]
result = [xmin, ymin, xmax, ymax, ids[i]]
dets.append(result)
dets = np.array(dets)
tic = time.time()
trackers = mot_tracker.update(dets)
toc = time.time() - tic
#print('Tracking time:%fms' % (toc*1000))
for d in trackers:
color = (int(colors[int(d[4]) % 32, 0]), int(colors[int(d[4]) % 32,1]), int(colors[int(d[4]) % 32, 2]))
cv2.rectangle(frame, (int(d[0]), int(d[1])), (int(d[2]), int(d[3])), color, 3)
# cv2.putText(frame, str('%s%0.2f' % (net.classes[int(d[4])], d[5])),
# (d[0], d[1] - 5), cv2.FONT_HERSHEY_COMPLEX , 0.8, color, 2)
cv2.putText(frame, str('%s%d' % ('face', d[4])),
(int(d[0]), int(d[1]) - 5), cv2.FONT_HERSHEY_COMPLEX , 0.8, color, 2)
video_writer.write(frame.astype(np.uint8))
cv2.imshow(winName, frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
if __name__ == '__main__':
warnings.filterwarnings("ignore")
main()
4.后續(xù)
項(xiàng)目我會(huì)放到我的GitHub上笼痹,更新了會(huì)告訴大家配喳,如果有想要的可以聯(lián)系我maplect@sina.com,我看到會(huì)發(fā)給你凳干。