一升略、數(shù)據(jù)集準備
我這次用到的數(shù)據(jù)集來自這里:撲克牌數(shù)據(jù)集微王,下載下整個zip文件再從中復制出來屡限。劃重點:但是,我不建議你直接從他那下載炕倘,慢不說囚霸,后來我遇到了一個問題,就是會出現(xiàn)下面這樣的報錯:
FileNotFoundError: img file does not exist: /home/ymz/lsm/mmdetection/data/VOCdevkit/VOC2007/JPEGImages/IMG_2608.jpg
后來我發(fā)現(xiàn)確實數(shù)據(jù)集里面有IMG_2608.JPG激才,唯一的區(qū)別就是文件格式大寫了,后來證實確實mmcv.imread
讀不了大寫的.JPG额嘿,所以我這里把所有文件格式小寫之后的數(shù)據(jù)集鏈接放這里瘸恼,提取碼:vmsy
這個數(shù)據(jù)集并沒有收集全部的撲克牌類別,里面只有6類:nine,ten,jack,queen,king,ace册养。一共364張东帅,所以之后訓練不會耗時很久,基本20個epoch半個小時就完事了球拦。解壓之后的文件目錄是這樣的:
├── poker
│ ├── VOC2007
│ │ ├── Annotations
│ │ ├── JPEGImages
│ │ ├── ImageSets
│ │ │ ├── Main
│ │ │ │ ├── val.txt
│ │ │ │ ├── train.txt
二靠闭、mmdetection的安裝
mmdetection是一個基于pytorch的目標檢測框架,非常好用坎炼,支持模型也比較全愧膀,Github上目前star已有8k,而且commit也非骋ス猓活躍檩淋。這次就想熟悉一下怎么使用這個框架,故用了自己找的數(shù)據(jù)集跑一遍萄金。
安裝的話基本照著官網(wǎng)的說明文檔就行蟀悦,不過似乎最近也有一些小改動,這個https://mmdetection.readthedocs.io/en/latest/上面會更新慢一點氧敢。我在這里放一下全部整合的命令:
# 注意官方的Requirements
conda create -n open-mmlab python=3.7 -y
conda activate open-mmlab
# 安裝pytorch和torchvision自己來也行
conda install -c pytorch pytorch torchvision -y
# cython一定要安裝日戈,編譯需要
conda install cython -y
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
pip install -v -e .
# 官方建議創(chuàng)建軟連接,節(jié)省硬盤空間孙乖,在mmdetection目錄下運行下面的命令
mkdir data
ln -s $COCO_ROOT data
當然我們這里是自己VOC格式的數(shù)據(jù)集浙炼,最后一個軟鏈接就不能是上面最后一行,根據(jù)上面的數(shù)據(jù)集目錄結(jié)構(gòu)應該是:
mkdir data
cd data
ln -s /home/你的存放路徑/poker VOCdevkit
這樣就符合官方的推薦結(jié)構(gòu)了唯袄。
三鼓拧、修改相關(guān)文件
1. 修改class_names.py文件
修改mmdetection/mmdet/core/evaluation下的class_names.py中的voc_classes,將其改為要訓練的數(shù)據(jù)集的類別名稱越妈,否則測試的結(jié)果的名稱還會是aeroplane, bicycle, bird, boat,…這些季俩。改完后如圖:
2. 修改voc.py文件
修改mmdetection/mmdet/datasets/voc.py 下的類別,如果只有一個類梅掠,因為CLASSES是一個元組酌住,所以要加上一個逗號店归,否則將會報錯,改完后如圖:3. 修改配置文件
配置文件就是mmdetection/configs下一堆的名稱諸如cascade_rcnn_r50_fpn_1x.py的文件酪我,因為我們使用的是VOC格式消痛,這些默認是COCO格式(除了mmdetection/configs/pascal_voc文件夾下的幾個),所以我就挑了cascade_rcnn_r50_fpn_1x.py都哭,將它復制重命名為cascade_rcnn_r50_fpn_1x_poker.py秩伞,有下面幾個地方需要修改:
1、修改num_classes變量欺矫,就是背景類加上要分類的數(shù)量纱新,所以我們這里為7:
2、修改data settings部分穆趴,主要是了dataset_type脸爱、data_root、img_scale未妹、ann_file簿废、img_prefix變量的值:
最后的runtime settings也可以修改一下,比如total_epochs和workflow【[('train', 1)]表示只訓練络它,不驗證族檬;[('train', 2), ('val', 1)] 表示2個epoch訓練,1個epoch驗證】化戳,我將total_epochs設(shè)置成20导梆,所以學習率設(shè)置為step=[8, 15],checkpoint_config = dict(interval=2)迂烁,其他都保持默認看尼。
四、開始訓練
到現(xiàn)在就可以開始訓練了盟步,在mmdetection目錄下:
python tools/train.py configs/cascade_rcnn_r50_fpn_1x_poker.py
這樣就能成功訓練了藏斩,屏幕上會打印很多l(xiāng)og日志,當然訓練完成之后會在work_dirs目錄下出現(xiàn)如下圖的東西:有.log日志和.log.json却盘,還有每隔一定epoch(我這里是每隔2個epoch)保存模型狰域,為了方便后面的測試,還有最后的模型latest.pth黄橘。
五兆览、測試并計算mAP
1. 測試一張圖片的效果
我模仿demo/webcam_demo.py文件寫了試用于一張圖片的demo腳本image_demo.py:
import argparse
import torch
from mmdet.apis import inference_detector, init_detector, show_result
def parse_args():
parser = argparse.ArgumentParser(description='MMDetection image demo')
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument('imagepath', help='camera device id')
parser.add_argument('--device', type=int, default=0, help='CUDA device id')
parser.add_argument(
'--score-thr', type=float, default=0.5, help='bbox score threshold')
args = parser.parse_args()
return args
def main():
args = parse_args()
model = init_detector(
args.config, args.checkpoint, device=torch.device('cuda', args.device))
result = inference_detector(model, args.imagepath)
show_result(
args.imagepath, result, model.CLASSES, score_thr=args.score_thr, wait_time=0)
if __name__ == '__main__':
main()
然后運行下面的命令:
python demo/image_demo.py configs/cascade_rcnn_r50_fpn_1x_poker.py work_dirs/cascade_rcnn_r50_fpn_1x_poker/latest.pth demo/poker_test.jpg
得到下面的結(jié)果,可以看到bbox框得非常tight塞关,分類也都正確了:
2. 計算mAP
計算mAP之前需要修改mmdetection/tools/voc_eval.py文件中的voc_eval函數(shù)抬探,改完后的圖:
然后通過下面命令產(chǎn)生poker_results.pkl文件:
python tools/test.py configs/cascade_rcnn_r50_fpn_1x_poker.py work_dirs/cascade_rcnn_r50_fpn_1x_poker/latest.pth --out poker_results.pkl
然后執(zhí)行如下命令,采用voc標準計算mAP:
python tools/voc_eval.py poker_results.pkl configs/cascade_rcnn_r50_fpn_1x_poker.py
便得到了下面的結(jié)果,可以看到mAP高達0.977小压,這當然因為撲克牌方方正正很容易檢測的緣故啦:
好线梗,算是玩了一下mmdetection吧。以后會常碰到它的~~