????????本文主要描述如何使用 Google 開(kāi)源的目標(biāo)檢測(cè) API 來(lái)訓(xùn)練目標(biāo)檢測(cè)器,內(nèi)容包括:安裝 TensorFlow/Object Detection API 和使用 TensorFlow/Object Detection API 訓(xùn)練自己的目標(biāo)檢測(cè)器演训。
一弟孟、安裝 TensorFlow Object Detection API
????????Google 開(kāi)源的目標(biāo)檢測(cè)項(xiàng)目 object_detection 位于與 tensorflow 獨(dú)立的項(xiàng)目 models(獨(dú)立指的是:在安裝 tensorflow 的時(shí)候并沒(méi)有安裝 models 部分)內(nèi):models/research/object_detection。models 部分的 GitHub 主頁(yè)為:
https://github.com/tensorflow/models
????????要使用 models 部分內(nèi)的目標(biāo)檢測(cè)功能 object_detection样悟,需要用戶手動(dòng)安裝 object_detection拂募。下面為詳細(xì)的安裝步驟:
1. 安裝依賴項(xiàng) matplotlib庭猩,pillow,lxml 等
????????使用 pip/pip3 直接安裝:
$ sudo pip/pip3 install matplotlib pillow lxml
其中如果安裝 lxml 不成功陈症,可使用
$ sudo apt-get install python-lxml python3-lxml
安裝蔼水。
2. 安裝編譯工具
$ sudo apt install protobuf-compiler
$ sudo apt-get install python-tk
$ sudo apt-get install python3-tk
3. 克隆 TensorFlow models 項(xiàng)目
????????使用 git 克隆 models 部分到本地,在終端輸入指令:
$ git clone https://github.com/tensorflow/models.git
克隆完成后录肯,會(huì)在終端當(dāng)前目錄出現(xiàn) models 的文件夾趴腋。要使用 git(分布式版本控制系統(tǒng)),首先得安裝 git:$ sudo apt-get install git
论咏。
4. 使用 protoc 編譯
????????在 models/research 目錄下的終端執(zhí)行:
$ protoc object_detection/protos/*.proto --python_out=.
將 object_detection/protos/ 文件下的以 .proto 為后綴的文件編譯為 .py 文件輸出优炬。
5. 配置環(huán)境變量
????????在 .bashrc 文件中加入環(huán)境變量。首先打開(kāi) .bashrc 文件:
$ sudo gedit ~/.bashrc
然后在文件末尾加入新行:
export PYTHONPATH=$PYTHONPATH:/.../models/research:/.../models/research/slim
其中省略號(hào)所在的兩個(gè)目錄需要填寫為 models/research 文件夾厅贪、models/research/slim 文件夾的完整目錄蠢护。保存之后執(zhí)行如下指令:
$ source ~/.bashrc
讓改動(dòng)立即生效。
6. 測(cè)試是否安裝成功
????????在 models/research 文件下執(zhí)行:
$ python/python3 object_detection/builders/model_builder_test.py
如果返回 OK卦溢,表示安裝成功糊余。
二、訓(xùn)練 TensorFlow 目標(biāo)檢測(cè)器
????????成功安裝好 TensorFlow Object Detection API 之后单寂,就可以按照 models/research/object_detection 文件夾下的演示文件 object_detection_tutorial.ipynb 來(lái)查看 Google 自帶的目標(biāo)檢測(cè)的檢測(cè)效果。其中吐辙,Google 自己訓(xùn)練好后的目標(biāo)檢測(cè)器都放在:
可以自己下載這些模型宣决,一一查看檢測(cè)效果。以下昏苏,假設(shè)你把某些預(yù)訓(xùn)練模型下載好了尊沸,放在models/ research/ object_detection 的某個(gè)文件夾下,比如自定義文件夾 pretrained_models贤惯。
????????要訓(xùn)練自己的模型洼专,除了使用 Google 自帶的預(yù)訓(xùn)練模型之外,最關(guān)鍵的是需要準(zhǔn)備自己的訓(xùn)練數(shù)據(jù)孵构。
????????以下屁商,詳細(xì)列出訓(xùn)練過(guò)程(后續(xù)部分文章將詳細(xì)介紹一些目標(biāo)檢測(cè)算法):
1. 準(zhǔn)備標(biāo)注工具和文件格式轉(zhuǎn)化工具
????????圖像標(biāo)注可以使用標(biāo)注工具 labelImg,直接使用
$ sudo pip install labelImg
安裝(當(dāng)前好像只支持Python2.7)颈墅。另外蜡镶,在此之前,需要安裝它的依賴項(xiàng) pyqt4:
$ sudo apt-get install pyqt4-dev-tools
(另一依賴項(xiàng) lxml 前面已安裝)恤筛。要使用 labelImg官还,只需要在終端輸入 labelImg 即可。
????????為了方便后續(xù)數(shù)據(jù)格式轉(zhuǎn)化毒坛,還需要準(zhǔn)備兩個(gè)文件格式轉(zhuǎn)化工具:xml_to_csv.py 和 generate_tfrecord.py望伦,它們的代碼分別列舉如下(它們可以從資料 [1] 中 GitHub 項(xiàng)目源代碼鏈接中下載林说。其中為了方便一般化使用,我已經(jīng)修改 generate_tfrecord.py 的部分內(nèi)容使得可以自定義圖像路徑和輸入 .csv 文件屯伞、輸出 .record 文件路徑腿箩,以及 6 中的 xxx_label_map.pbtxt 文件路徑):
(1) xml_to_csv.py 文件源碼:
import os
import glob
import pandas as pd
import xml.etree.ElementTree as ET
def xml_to_csv(path):
xml_list = []
for xml_file in glob.glob(path + '/*.xml'):
tree = ET.parse(xml_file)
root = tree.getroot()
for member in root.findall('object'):
value = (root.find('filename').text,
int(root.find('size')[0].text),
int(root.find('size')[1].text),
member[0].text,
int(member[4][0].text),
int(member[4][1].text),
int(member[4][2].text),
int(member[4][3].text)
)
xml_list.append(value)
column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
xml_df = pd.DataFrame(xml_list, columns=column_name)
return xml_df
def main():
image_path = os.path.join(os.getcwd(), 'annotations')
xml_df = xml_to_csv(image_path)
xml_df.to_csv('road_signs_labels.csv', index=None)
print('Successfully converted xml to csv.')
if __name__ == '__main__':
main()
(2) 修改后的 generate_tfrecord.py 文件源碼:
"""
Usage:
# From tensorflow/models/
# Create train data:
python/python3 generate_tfrecord.py --csv_input=your path to read train.csv
--images_input=your path to read images
--output_path=your path to write train.record
--label_map_path=your path to read xxx_label_map.pbtxt
# Create validation data:
python/python3 generate_tfrecord.py --csv_input=you path to read val.csv
--images_input=you path to read images
--output_path=you path to write val.record
--label_map_path=your path to read xxx_label_map.pbtxt
"""
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import os
import io
import pandas as pd
import tensorflow as tf
from PIL import Image
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util
from collections import namedtuple
flags = tf.app.flags
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('images_input', '', 'Path to the images input')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
flags.DEFINE_string('label_map_path', '', 'Path to label map proto')
FLAGS = flags.FLAGS
def split(df, group):
data = namedtuple('data', ['filename', 'object'])
gb = df.groupby(group)
return [data(filename, gb.get_group(x)) for filename, x in
zip(gb.groups.keys(), gb.groups)]
def create_tf_example(group, label_map_dict, images_path):
with tf.gfile.GFile(os.path.join(
images_path, '{}'.format(group.filename)), 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
image = Image.open(encoded_jpg_io)
width, height = image.size
filename = group.filename.encode('utf8')
image_format = b'jpg'
xmins = []
xmaxs = []
ymins = []
ymaxs = []
classes_text = []
classes = []
for index, row in group.object.iterrows():
xmins.append(row['xmin'] / width)
xmaxs.append(row['xmax'] / width)
ymins.append(row['ymin'] / height)
ymaxs.append(row['ymax'] / height)
classes_text.append(row['class'].encode('utf8'))
classes.append(label_map_dict[row['class']])
tf_example = tf.train.Example(features=tf.train.Features(feature={
'image/height': dataset_util.int64_feature(height),
'image/width': dataset_util.int64_feature(width),
'image/filename': dataset_util.bytes_feature(filename),
'image/source_id': dataset_util.bytes_feature(filename),
'image/encoded': dataset_util.bytes_feature(encoded_jpg),
'image/format': dataset_util.bytes_feature(image_format),
'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
'image/object/class/label': dataset_util.int64_list_feature(classes),
}))
return tf_example
def main(_):
writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)
images_path = FLAGS.images_input
examples = pd.read_csv(FLAGS.csv_input)
grouped = split(examples, 'filename')
for group in grouped:
tf_example = create_tf_example(group, label_map_dict, images_path)
writer.write(tf_example.SerializeToString())
writer.close()
output_path = FLAGS.output_path
print('Successfully created the TFRecords: {}'.format(output_path))
if __name__ == '__main__':
tf.app.run()
generate_tfrecord.py 也可以由 models/research/object_detection/dataset_tools 文件夾內(nèi)的相關(guān) .py 文件修改而來(lái)。后續(xù)也會(huì)有文章介紹怎么將圖像轉(zhuǎn)化為 .record 文件愕掏,敬請(qǐng)期待度秘。
2. 創(chuàng)建工作目錄,收集圖片
????????在 Ubuntu 中新建項(xiàng)目文件夾饵撑,比如 xxx_detection(xxx 自取剑梳,下同),在該文件夾內(nèi)新建文件夾 annotations滑潘,data垢乙,images,training语卤。將所有收集到的圖片放在 images 文件夾內(nèi)追逮。
3. 標(biāo)注圖片生成 xml 文件
????????利用標(biāo)注工具 labelImg 對(duì)所有收集的圖片進(jìn)行標(biāo)注,即將要檢測(cè)的目標(biāo)用矩形框框出粹舵,填入對(duì)應(yīng)的目標(biāo)類別名稱钮孵,生成對(duì)應(yīng)的 xml 文件,放在 annotations 文件夾內(nèi)眼滤。
4. 將所有的 .xml 文件整合成 .csv 文件
????????執(zhí)行 xml_to_csv.py(放在 xxx_detection文件夾下)巴席,將所有的 xml 標(biāo)注文件匯合成一個(gè) csv 文件,再?gòu)脑?csv 文件中分出用于訓(xùn)練和驗(yàn)證的文件 train.csv 和 val.csv(分割比例自茸缧琛)漾唉,放入 data 文件夾。
5. 將 .csv 文件轉(zhuǎn)化成 TensorFlow 要求的 .TFrecord 文件
????????將 generate_tfrecord.py 文件放在 TensorFlow models/research/object_detection 文件夾下堰塌,在該文件夾目錄下的終端執(zhí)行:
$ python3 generate_tfrecord.py --csv_input=/home/.../data/train.csv
--images_input=/home/.../images
--output_path=/home/.../data/train.record
--label_map_path=/home/.../training/xxx_label_map.pbtxt
類似的赵刑,對(duì) val.csv 執(zhí)行相同操作,生成 val.record 文件场刑。(其中 xxx_label_map.pbtxt 文件見(jiàn)下面的 6)
6. 編寫 .pbtxt 文件
????????仿照 TensorFlow models/research/object_detection/data 文件夾下的 .pbtxt 文件編寫自己的 .pbtxt 文件:對(duì)每個(gè)要檢測(cè)的類別寫入
item {
id: k
name: ‘xxx’
}
其中 item 之間空一行般此,類標(biāo)號(hào)從 1 開(kāi)始,即 k >= 1摇邦。將 .pbtxt 文件命名為 xxx_label_map.pbtxt 并放入training 文件夾恤煞。
7. 配置 .config 文件
????????從 TensorFlow models/research/object_detection/samples/configs 文件夾內(nèi)選擇合適的一個(gè) .config 文件復(fù)制到項(xiàng)目工程的 training 文件夾內(nèi),將名稱改為與工程相關(guān)的 保留模型名 _xxx.config(其中保留模型名為原 .config 文件關(guān)于模型的命名字段施籍,建議命名時(shí)保留下來(lái)居扒,xxx 為與項(xiàng)目相關(guān)的自己命名字段),打開(kāi)文件作如下修改:
(1)修改模型參數(shù)
????????將 model {} 中的 num_classes 修改為工程要檢測(cè)的類別個(gè)數(shù)丑慎。另外喜喂,也可以修改訓(xùn)練參數(shù):
train_config: {} => num_steps: xxx => schedule {} => step = xxx
num_steps 表示將要訓(xùn)練的次數(shù)瓤摧,刪除這一行為不確定次數(shù)訓(xùn)練(隨時(shí)可用 Ctrl+C 中斷),后面的 step 表示學(xué)習(xí)率每過(guò) step 步后進(jìn)行衰減玉吁。這些參數(shù)由自己的經(jīng)驗(yàn)確定照弥,也可以使用默認(rèn)值。
????????其它參數(shù)一般不需要修改进副。
(2)修改文件路徑
????????將 .config 文件中所有的 ’PATH_TO_BE_CONFIGURED’ 文件路徑修改為相應(yīng)的 .ckpt(預(yù)訓(xùn)練模型文件路徑)这揣,.record,.pbtxt 文件所在路徑影斑。
????????將修改后的 保留模型名_xxx.config 文件放在 training 文件夾內(nèi)给赞。
8. 開(kāi)始本地訓(xùn)練目標(biāo)檢測(cè)器
????????在 TensorFlow models/research/object_detection 目錄下的終端執(zhí)行:
$ python3 model_main.py --model_dir=/home/.../training
--pipeline_config_path=/home/.../training/保留模型名_xxx.config
進(jìn)行模型訓(xùn)練,期間每隔一定時(shí)間會(huì)輸出若干文件到 training 文件夾矫户。在訓(xùn)練過(guò)程中可使用 Ctrl+C 任意時(shí)刻中斷訓(xùn)練片迅,之后再執(zhí)行上述代碼會(huì)從斷點(diǎn)之處繼續(xù)訓(xùn)練,而不是從頭開(kāi)始(除非把訓(xùn)練輸出文件全部刪除)皆辽。
9. 查看實(shí)時(shí)訓(xùn)練曲線
????????在任意目錄下執(zhí)行:
$ tensorboard --logdir=/home/.../training
打開(kāi)返回的 http 鏈接查看 Loss 等曲線的實(shí)時(shí)變化情況柑蛇。
10. 導(dǎo)出 .pb 文件用于推斷
????????模型訓(xùn)練完后,生成的 .ckpt 文件已經(jīng)可以調(diào)用進(jìn)行目標(biāo)檢測(cè)驱闷。也可以將 .ckpt 文件轉(zhuǎn)化為 .pb 文件用于推斷耻台。在 TensorFlow models/research/object_detection 目錄下的終端執(zhí)行:
$ python3 export_inference_graph.py --input_type image_tensor
--pipeline_config_path /home/.../training/pipeline.config
--trained_checkpoint_prefix /home/.../training/model.ckpt-200000
--output_directory /home/.../training/output_inference_graph
????????執(zhí)行上述代碼之后會(huì)在 /home/.../training 文件夾內(nèi)看到新的文件夾 output_inference_graph,里面存儲(chǔ)著訓(xùn)練好的最終模型空另,如直接調(diào)用的用于推斷的文件:frozen_inference_graph.pb粘我。其中命令中 model.ckpt-200000 表示訓(xùn)練 200000 生成的模型,實(shí)際執(zhí)行上述代碼時(shí)要修改為自己訓(xùn)練多少次后生成的模型痹换。其它路徑和文件(夾)名稱也由自己任意指定。
11. 調(diào)用訓(xùn)練好的模型進(jìn)行目標(biāo)檢測(cè)
????????調(diào)用 frozen_inference_graph.pb 進(jìn)行目標(biāo)檢測(cè)請(qǐng)參考 TensorFlow models/research/object_detection 文件夾下的 object_detection_tutorial.ipynb 都弹。但該文件只針對(duì)單張圖像娇豫,對(duì)多張圖像不友好,因?yàn)槊繖z測(cè)一張圖像都要重新打開(kāi)一個(gè)會(huì)話(語(yǔ)句 with tf.Session() as sess
每張圖像執(zhí)行一次)畅厢,而這是非常耗時(shí)的操作冯痢。可以改成如下的形式:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Nov 4 15:05:09 2017
@author: shirhe-lyh
"""
import time
import cv2
import numpy as np
import tensorflow as tf
#--------------Model preparation----------------
# Path to frozen detection graph. This is the actual model that is used for
# the object detection.
PATH_TO_CKPT = 'path_to_your_frozen_inference_graph.pb'
# Load a (frozen) Tensorflow model into memory
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
# Each box represents a part of the image where a particular
# object was detected.
gboxes = detection_graph.get_tensor_by_name('detection_boxes:0')
# Each score represent how level of confidence for each of the objects.
# Score is shown on the result image, together with the class label.
gscores = detection_graph.get_tensor_by_name('detection_scores:0')
gclasses = detection_graph.get_tensor_by_name('detection_classes:0')
gnum_detections = detection_graph.get_tensor_by_name('num_detections:0')
# TODO: Add class names showing in the image
def detect_image_objects(image, sess, detection_graph):
# Expand dimensions since the model expects images to have
# shape: [1, None, None, 3]
image_np_expanded = np.expand_dims(image, axis=0)
# Actual detection.
(boxes, scores, classes, num_detections) = sess.run(
[gboxes, gscores, gclasses, gnum_detections],
feed_dict={image_tensor: image_np_expanded})
# Visualization of the results of a detection.
boxes = np.squeeze(boxes)
scores = np.squeeze(scores)
height, width = image.shape[:2]
for i in range(boxes.shape[0]):
if (scores is None or
scores[i] > 0.5):
ymin, xmin, ymax, xmax = boxes[i]
ymin = int(ymin * height)
ymax = int(ymax * height)
xmin = int(xmin * width)
xmax = int(xmax * width)
score = None if scores is None else scores[i]
font = cv2.FONT_HERSHEY_SIMPLEX
text_x = np.max((0, xmin - 10))
text_y = np.max((0, ymin - 10))
cv2.putText(image, 'Detection score: ' + str(score),
(text_x, text_y), font, 0.4, (0, 255, 0))
cv2.rectangle(image, (xmin, ymin), (xmax, ymax),
(0, 255, 0), 2)
return image
with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
video_path = 'path_to_your_video'
capture = cv2.VideoCapture(video_path)
while capture.isOpened():
if cv2.waitKey(30) & 0xFF == ord('q'):
break
ret, frame = capture.read()
if not ret:
break
t_start = time.clock()
detect_image_objects(frame, sess, detection_graph)
t_end = time.clock()
print('detect time per frame: ', t_end - t_start)
cv2.imshow('detected', frame)
capture.release()
cv2.destroyAllWindows()
這樣改動(dòng)之后框杜,有好處也有壞處浦楣,好處是處理視頻或很多圖像時(shí)只生成一次會(huì)話節(jié)省時(shí)間,而且從原文件中去掉了語(yǔ)句:
sys.path.append("..")
from object_detection.utils import ops as utils_ops
from utils import label_map_util
from utils import visualization_utils as vis_util
使得在任意目錄下都可以執(zhí)行咪辱。壞處是:上述代碼沒(méi)有使用 label_map_util 和 vis_util 等這些 object_detection 伴隨的模塊振劳,使得檢測(cè)結(jié)果顯示的時(shí)候只能自己利用 OpenCV 來(lái)做,而存在一個(gè)較大的缺陷:不能顯示檢測(cè)出的目標(biāo)的類別名稱(待完善)油狂。
資料:
[1]目標(biāo)干脆面君:動(dòng)動(dòng)手历恐,用TensorFlow訓(xùn)練自己的目標(biāo)檢測(cè)模型寸癌,36kr
[2]利用TensorFlow Object Detection API訓(xùn)練自己的數(shù)據(jù)集,紅黑聯(lián)盟
[3]TensorFlow models/research/object_detection的GitHub文檔