A programmer abroad shared his experience building a cute raccoon detector.
The original author's GitHub:
https://github.com/datitran/raccoon_dataset
(the author's dataset can be downloaded there)
We will first walk through the whole pipeline using the original author's steps and dataset; later we will train our own detector.
First, prepare the dataset.
The dataset has two parts: the source images and the label data for those images.
Here we use the original author's data directly:
https://github.com/datitran/raccoon_dataset
(base) jiadongfeng:~/tensorflow/dataset/raccoon$ git clone https://github.com/datitran/raccoon_dataset.git
Cloning into 'raccoon_dataset'...
remote: Enumerating objects: 652, done.
remote: Total 652 (delta 0), reused 0 (delta 0), pack-reused 652
Receiving objects: 100% (652/652), 48.01 MiB | 245.00 KiB/s, done.
Resolving deltas: 100% (415/415), done.
The downloaded directory contains:
- images: source image directory, containing 200 raccoon picture files
- annotations: annotation directory, containing the label files for the 200 images
(base) jiadongfeng:~/tensorflow/dataset/raccoon_dataset/annotations$ ls
...
raccoon-100.xml raccoon-146.xml raccoon-191.xml raccoon-55.xml
raccoon-101.xml raccoon-147.xml raccoon-192.xml raccoon-56.xml
...
raccoon-144.xml raccoon-18.xml raccoon-53.xml raccoon-99.xml
raccoon-145.xml raccoon-190.xml raccoon-54.xml raccoon-9.xml
...
The annotation for the first image (raccoon-1.jpg) is:
<annotation verified="yes">
<folder>images</folder><!-- folder containing the image -->
<filename>raccoon-1.jpg</filename><!-- image file name -->
<!-- image path; change it to your own, e.g. /tensorflow/dataset/raccoon_dataset/images/raccoon-1.jpg -->
<path>/Users/datitran/Desktop/raccoon/images/raccoon-1.jpg</path>
<source>
<database>Unknown</database>
</source>
<!-- image size -->
<size>
<width>650</width>
<height>417</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<!-- information about the raccoon(s) in the image -->
<object>
<name>raccoon</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>81</xmin>
<ymin>88</ymin>
<xmax>522</xmax>
<ymax>408</ymax>
</bndbox>
</object>
</annotation>
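These annotation fields can be inspected with Python's standard library alone; a minimal sketch (the helper name `parse_voc_xml` is mine, not part of the dataset repo):

```python
import xml.etree.ElementTree as ET

def parse_voc_xml(xml_path):
    """Parse a PASCAL VOC annotation file into a simple dict."""
    root = ET.parse(xml_path).getroot()
    size = root.find('size')
    boxes = []
    for obj in root.findall('object'):
        b = obj.find('bndbox')
        boxes.append({
            'name': obj.findtext('name'),
            'xmin': int(b.findtext('xmin')),
            'ymin': int(b.findtext('ymin')),
            'xmax': int(b.findtext('xmax')),
            'ymax': int(b.findtext('ymax')),
        })
    return {
        'filename': root.findtext('filename'),
        'width': int(size.findtext('width')),
        'height': int(size.findtext('height')),
        'objects': boxes,
    }
```

For raccoon-1.xml this returns the 650×417 size and the single raccoon box (81, 88, 522, 408) shown above.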
Generating the TFRecord data
Next, we convert the images and annotation files into TFRecord, a format TensorFlow supports well, to simplify downstream processing.
This assumes you have already set up the environment as described in Object Detection 1: installing and verifying the Object Detection API.
- Create the Python file
In the ~/tensorflow/models/research/object_detection/dataset_tools directory you will find create_pascal_tf_record.py, the script TensorFlow provides for converting PASCAL VOC data to TFRecord. Copy it and adapt the copy to the raccoon dataset:
cp ~/tensorflow/models/research/object_detection/dataset_tools/create_pascal_tf_record.py /home/jdf/tensorflow/raccoon_dataset/create_raccoon_tf_record.py
The modified Python file is:
/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/create_raccoon_tf_record.py
# -*- coding: utf-8 -*-
#create_raccoon_tf_record.py
r"""Convert raw PASCAL dataset to TFRecord for object_detection.
Example usage:
python create_raccoon_tf_record.py \
--data_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/images \
--set=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/test_db.txt \
--annotations_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/annotations \
--year='VOC2007' \
--output_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/test_db.record \
--label_map_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/raccoon_label_map.pbtxt \
--ignore_difficult_instances=False
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import io
import logging
import os
import sys
from lxml import etree
import PIL.Image
import tensorflow as tf
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util
import warnings
warnings.filterwarnings("ignore")
#default value
data_dir_default = '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/images'
set = '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/train_db.txt'
annotations_dir = '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/annotations'
year='VOC2007'
output_path = '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/train_db.record'
label_map_path = '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/raccoon_label_map.pbtxt'
ignore_difficult_instances = False
# tf.app.flags.FLAGS receives command-line arguments passed in from the terminal
flags = tf.app.flags
# Define a string flag; the three arguments are the flag name, its default value, and a usage description
flags.DEFINE_string('data_dir', data_dir_default, 'Root directory to raw PASCAL VOC dataset.')
flags.DEFINE_string('set', set, 'Convert training set, validation set or '
'merged set.')
flags.DEFINE_string('annotations_dir', annotations_dir,
'(Relative) path to annotations directory.')
flags.DEFINE_string('year', year, 'Desired challenge year.')
flags.DEFINE_string('output_path', output_path, 'Path to output TFRecord')
flags.DEFINE_string('label_map_path', label_map_path,
'Path to label map proto')
flags.DEFINE_boolean('ignore_difficult_instances', ignore_difficult_instances, 'Whether to ignore '
'difficult instances')
FLAGS = flags.FLAGS
SETS = ['train', 'val', 'trainval', 'test']
YEARS = ['VOC2007', 'VOC2012', 'merged']
def dict_to_tf_example(data,
                       dataset_directory,
                       label_map_dict,
                       ignore_difficult_instances=False,
                       image_subdirectory='JPEGImages'):
  """Convert XML derived dict to tf.Example proto.

  Notice that this function normalizes the bounding box coordinates provided
  by the raw data.

  Args:
    data: dict holding PASCAL XML fields for a single image (obtained by
      running dataset_util.recursive_parse_xml_to_dict)
    dataset_directory: Path to root directory holding PASCAL dataset
    label_map_dict: A map from string label names to integer ids.
    ignore_difficult_instances: Whether to skip difficult instances in the
      dataset (default: False).
    image_subdirectory: String specifying subdirectory within the
      PASCAL dataset directory holding the actual image data.

  Returns:
    example: The converted tf.Example.

  Raises:
    ValueError: if the image pointed to by data['filename'] is not a valid JPEG
  """
  file_name = data['filename']
  full_path = os.path.join(dataset_directory, file_name)
  with tf.gfile.GFile(full_path, 'rb') as fid:
    encoded_jpg = fid.read()
  encoded_jpg_io = io.BytesIO(encoded_jpg)
  image = PIL.Image.open(encoded_jpg_io)
  if image.format != 'JPEG':
    raise ValueError('Image format not JPEG')
  key = hashlib.sha256(encoded_jpg).hexdigest()

  width = int(data['size']['width'])
  height = int(data['size']['height'])

  xmin = []
  ymin = []
  xmax = []
  ymax = []
  classes = []
  classes_text = []
  truncated = []
  poses = []
  difficult_obj = []
  if 'object' in data:
    for obj in data['object']:
      difficult = bool(int(obj['difficult']))
      if ignore_difficult_instances and difficult:
        continue
      difficult_obj.append(int(difficult))
      # Normalize pixel coordinates to [0, 1]
      xmin.append(float(obj['bndbox']['xmin']) / width)
      ymin.append(float(obj['bndbox']['ymin']) / height)
      xmax.append(float(obj['bndbox']['xmax']) / width)
      ymax.append(float(obj['bndbox']['ymax']) / height)
      classes_text.append(obj['name'].encode('utf8'))
      classes.append(label_map_dict[obj['name']])
      truncated.append(int(obj['truncated']))
      poses.append(obj['pose'].encode('utf8'))

  example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/source_id': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
      'image/encoded': dataset_util.bytes_feature(encoded_jpg),
      'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
      'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
      'image/object/truncated': dataset_util.int64_list_feature(truncated),
      'image/object/view': dataset_util.bytes_list_feature(poses),
  }))
  return example
# main function: entry point
def main(_):
  # if FLAGS.set not in SETS:
  #   raise ValueError('set must be in : {}'.format(SETS))
  # if FLAGS.year not in YEARS:
  #   raise ValueError('year must be in : {}'.format(YEARS))
  print("#start main...")
  data_dir = FLAGS.data_dir
  print("--data_dir [%s]" % data_dir)
  years = ['VOC2007', 'VOC2012']
  if FLAGS.year != 'merged':
    years = [FLAGS.year]
  print("--years [%s]" % str(years))

  writer = tf.io.TFRecordWriter(FLAGS.output_path)
  print("--writer to [%s]" % str(FLAGS.output_path))
  print("--label_map_path[%s]" % str(FLAGS.label_map_path))
  label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)
  print("--label_map_dict[%s]" % str(label_map_dict))

  print("#start for years....")
  for year in years:
    print('#Reading from PASCAL %s dataset.' % year)
    examples_path = FLAGS.set
    print("--examples_path[%s]" % examples_path)
    annotations_dir = FLAGS.annotations_dir
    print("--annotations_dir[%s]" % annotations_dir)
    examples_list = dataset_util.read_examples_list(examples_path)
    for idx, example in enumerate(examples_list):
      if idx % 10 == 0:
        print('On image %d of %d' % (idx, len(examples_list)))
      path = os.path.join(annotations_dir, example + '.xml')
      with tf.gfile.GFile(path, 'r') as fid:
        xml_str = fid.read()
      xml = etree.fromstring(xml_str)
      # Parse each annotation XML file in 'raccoon_dataset/annotations'
      data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
      tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
                                      FLAGS.ignore_difficult_instances)
      writer.write(tf_example.SerializeToString())

  writer.close()
  print("#close writer %s" % FLAGS.output_path)


if __name__ == '__main__':
  # Run the main function. If the entry point is not main() -- for
  # example, test() -- call tf.app.run(test) instead.
  tf.app.run()
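The core transformation in dict_to_tf_example is dividing pixel coordinates by the image width and height, so every bounding box is stored normalized to [0, 1]. For the raccoon-1.jpg annotation shown earlier (650×417 image, box 81/88/522/408) the arithmetic looks like this:

```python
width, height = 650, 417                    # from the <size> element
xmin, ymin, xmax, ymax = 81, 88, 522, 408   # from the <bndbox> element

# These are the values written to the image/object/bbox/* features
norm = {
    'xmin': xmin / width,
    'ymin': ymin / height,
    'xmax': xmax / width,
    'ymax': ymax / height,
}
for k, v in norm.items():
    print(k, round(v, 4))
```

This yields xmin ≈ 0.1246 and ymax ≈ 0.9784; the model only ever sees these relative coordinates, so the resizer can scale images freely.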
Of the files passed to the script above, the important ones are:
- raccoon_label_map.pbtxt
location: /home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/raccoon_label_map.pbtxt
The framework needs a mapping between our class IDs and class names, usually kept in a .pbtxt file, with content as follows:
item {
id: 1
name: 'raccoon'
}
Since we have only one class, a single item is enough; if you have several classes, you need one item per class.
Note: ids start at 1, and name must match the class name used in the annotation files. If you labelled the images as raccoon, write raccoon here, not the Chinese word "浣熊".
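The Object Detection API reads this file with label_map_util.get_label_map_dict; for a flat file like ours, the same name→id mapping can be recovered with a quick regex sketch (illustrative only, not the API's actual parser):

```python
import re

def load_label_map(pbtxt_text):
    """Build {name: id} from a simple label-map pbtxt (flat items only)."""
    ids = re.findall(r"id:\s*(\d+)", pbtxt_text)
    names = re.findall(r"name:\s*'([^']+)'", pbtxt_text)
    return dict(zip(names, (int(i) for i in ids)))

pbtxt = """
item {
  id: 1
  name: 'raccoon'
}
"""
print(load_label_map(pbtxt))  # {'raccoon': 1}
```

This is the dictionary that dict_to_tf_example uses to turn the string label 'raccoon' into the integer class id 1.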
- train_db.txt
location: /home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/train_db.txt
The training set: 160 of the 200 examples are selected for training.
raccoon-5
raccoon-12
raccoon-107
raccoon-116
raccoon-123
raccoon-70
raccoon-152
raccoon-63
raccoon-135
raccoon-161
raccoon-171
raccoon-118
raccoon-124
raccoon-169
raccoon-38
raccoon-98
raccoon-158
raccoon-93
raccoon-34
raccoon-69
raccoon-35
raccoon-146
raccoon-78
raccoon-19
raccoon-127
raccoon-66
raccoon-117
raccoon-62
raccoon-200
raccoon-122
raccoon-173
raccoon-33
raccoon-73
raccoon-77
raccoon-7
raccoon-191
raccoon-86
raccoon-180
raccoon-61
raccoon-60
raccoon-49
raccoon-32
raccoon-27
raccoon-197
raccoon-126
raccoon-189
raccoon-75
raccoon-156
raccoon-192
raccoon-57
raccoon-167
raccoon-45
raccoon-65
raccoon-82
raccoon-184
raccoon-3
raccoon-178
raccoon-30
raccoon-164
raccoon-67
raccoon-44
raccoon-166
raccoon-43
raccoon-168
raccoon-170
raccoon-132
raccoon-108
raccoon-101
raccoon-20
raccoon-2
raccoon-22
raccoon-11
raccoon-74
raccoon-176
raccoon-114
raccoon-14
raccoon-36
raccoon-129
raccoon-177
raccoon-141
raccoon-151
raccoon-94
raccoon-179
raccoon-130
raccoon-128
raccoon-193
raccoon-104
raccoon-8
raccoon-137
raccoon-76
raccoon-185
raccoon-26
raccoon-81
raccoon-190
raccoon-120
raccoon-175
raccoon-112
raccoon-90
raccoon-46
raccoon-91
raccoon-13
raccoon-119
raccoon-149
raccoon-50
raccoon-181
raccoon-162
raccoon-136
raccoon-53
raccoon-143
raccoon-48
raccoon-163
raccoon-125
raccoon-31
raccoon-188
raccoon-37
raccoon-154
raccoon-157
raccoon-195
raccoon-47
raccoon-97
raccoon-187
raccoon-80
raccoon-153
raccoon-139
raccoon-147
raccoon-25
raccoon-84
raccoon-174
raccoon-110
raccoon-59
raccoon-52
raccoon-99
raccoon-4
raccoon-92
raccoon-186
raccoon-1
raccoon-41
raccoon-71
raccoon-194
raccoon-10
raccoon-134
raccoon-140
raccoon-16
raccoon-142
raccoon-172
raccoon-24
raccoon-109
raccoon-89
raccoon-160
raccoon-111
raccoon-54
raccoon-15
raccoon-182
raccoon-18
raccoon-144
raccoon-138
raccoon-39
raccoon-6
raccoon-51
raccoon-103
The Python snippet that generated this list:
import os

i = 0
pt = "/home/jdf/tensorflow/raccoon_dataset/images"
image_name = os.listdir(pt)
for temp in image_name:
    if temp.endswith(".jpg"):
        if i < 160:
            print(temp.replace('.jpg', ''))
        i = i + 1
Run it directly to get the 160 entries.
- test_db.txt
location: /home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/test_db.txt
The test set: the other 40 of the 200 examples.
raccoon-23
raccoon-21
raccoon-85
raccoon-131
raccoon-29
raccoon-115
raccoon-183
raccoon-199
raccoon-72
raccoon-17
raccoon-83
raccoon-9
raccoon-56
raccoon-68
raccoon-87
raccoon-100
raccoon-79
raccoon-145
raccoon-64
raccoon-96
raccoon-196
raccoon-58
raccoon-105
raccoon-106
raccoon-148
raccoon-42
raccoon-55
raccoon-40
raccoon-155
raccoon-88
raccoon-165
raccoon-28
raccoon-102
raccoon-133
raccoon-113
raccoon-95
raccoon-121
raccoon-150
raccoon-159
raccoon-198
The Python snippet that generated this list:
import os

i = 0
pt = "/home/jdf/tensorflow/raccoon_dataset/images"
image_name = os.listdir(pt)
for temp in image_name:
    if temp.endswith(".jpg"):
        if i >= 160:
            print(temp.replace('.jpg', ''))
        i = i + 1
The code that lists the full image set is:
import os

pt = "/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/images"
image_name = os.listdir(pt)
for temp in image_name:
    if temp.endswith(".jpg"):
        print(temp.replace('.jpg', ''))
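The listing scripts above depend on os.listdir ordering, which is not guaranteed to be stable across filesystems. A reproducible alternative, as a sketch (the seed and the 80/20 ratio are my choices, matching the 160/40 split):

```python
import os
import random

def split_dataset(image_dir, train_ratio=0.8, seed=42):
    """Split image stems into reproducible train/test lists."""
    stems = sorted(f[:-len('.jpg')] for f in os.listdir(image_dir)
                   if f.endswith('.jpg'))
    random.Random(seed).shuffle(stems)
    n_train = int(len(stems) * train_ratio)
    return stems[:n_train], stems[n_train:]

# Usage: train, test = split_dataset('.../raccoon_dataset/images')
# then write one stem per line to train_db.txt / test_db.txt.
```

Sorting before shuffling with a fixed seed means the same train/test split comes out on every machine.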
- Generate the training-set TFRecord file
Before running the commands below, make sure you have compiled the protos and set PYTHONPATH:
cd ~/tensorflow/models/research/
protoc object_detection/protos/*.proto --python_out=.
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
python object_detection/builders/model_builder_test.py
Otherwise you will see this exception:
ImportError: No module named object_detection.utils
Run the following command to generate the TFRecord file:
python create_raccoon_tf_record.py \
--data_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/images \
--set=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/train_db.txt \
--annotations_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/annotations \
--year='VOC2007' \
--output_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/train_db.record \
--label_map_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/raccoon_label_map.pbtxt \
--ignore_difficult_instances=False
data_dir: the directory of collected source images
set: the 160-entry training list file
annotations_dir: the directory of per-image annotation files
year: the dataset year; the script validates it, and you can adjust that code as needed. It has no special meaning here.
label_map_path: as described above, the file holding the class information for training
The output is:
/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/train_db.record
- Generate the test-set TFRecord file
python create_raccoon_tf_record.py \
--data_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/images \
--set=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/test_db.txt \
--annotations_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/annotations \
--year='VOC2007' \
--output_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/test_db.record \
--label_map_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/raccoon_label_map.pbtxt \
--ignore_difficult_instances=False
The parameters have the same meaning as for the training set.
The output is:
/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/test_db.record
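To sanity-check the two .record files without loading the full API, you can count records by following the TFRecord framing: each record is a little-endian uint64 payload length, a 4-byte CRC of the length, the payload, then a 4-byte CRC of the payload. A minimal sketch that skips CRC verification (the function name is mine):

```python
import struct

def count_tfrecords(path):
    """Count records in a TFRecord file by following its length framing."""
    n = 0
    with open(path, 'rb') as f:
        while True:
            header = f.read(8)            # uint64 payload length
            if len(header) < 8:
                break
            (length,) = struct.unpack('<Q', header)
            f.seek(4 + length + 4, 1)     # skip length CRC, payload, payload CRC
            n += 1
    return n
```

count_tfrecords on train_db.record should report 160 and on test_db.record 40, matching the split sizes above.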
Download the pre-trained model
In the chapter on installing and verifying the Object Detection API, we already downloaded the API, which includes the model checkpoint files (model.ckpt). We use the stock ssd_mobilenet_v1 as the pre-trained model and fine-tune it from there via transfer learning.
(base) jiadongfeng:~/tensorflow/models/research/object_detection/ssd_mobilenet_v1_coco_2017_11_17$ ls
checkpoint model.ckpt.index
frozen_inference_graph.pb model.ckpt.meta
model.ckpt.data-00000-of-00001 saved_model
Create the config file
We start from an existing config file: copy
object_detection/samples/configs/ssd_mobilenet_v1_coco.config
to
/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_config/ssd_mobilenet_v1_raccoon.config
# SSD with Mobilenet v1 configuration for MSCOCO Dataset.
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.
model {
ssd {
num_classes: 1
box_coder {
faster_rcnn_box_coder {
y_scale: 10.0
x_scale: 10.0
height_scale: 5.0
width_scale: 5.0
}
}
matcher {
argmax_matcher {
matched_threshold: 0.5
unmatched_threshold: 0.5
ignore_thresholds: false
negatives_lower_than_unmatched: true
force_match_for_each_row: true
}
}
similarity_calculator {
iou_similarity {
}
}
anchor_generator {
ssd_anchor_generator {
num_layers: 6
min_scale: 0.2
max_scale: 0.95
aspect_ratios: 1.0
aspect_ratios: 2.0
aspect_ratios: 0.5
aspect_ratios: 3.0
aspect_ratios: 0.3333
}
}
image_resizer {
fixed_shape_resizer {
height: 300
width: 300
}
}
box_predictor {
convolutional_box_predictor {
min_depth: 0
max_depth: 0
num_layers_before_predictor: 0
use_dropout: false
dropout_keep_probability: 0.8
kernel_size: 1
box_code_size: 4
apply_sigmoid_to_scores: false
conv_hyperparams {
activation: RELU_6,
regularizer {
l2_regularizer {
weight: 0.00004
}
}
initializer {
truncated_normal_initializer {
stddev: 0.03
mean: 0.0
}
}
batch_norm {
train: true,
scale: true,
center: true,
decay: 0.9997,
epsilon: 0.001,
}
}
}
}
feature_extractor {
type: 'ssd_mobilenet_v1'
min_depth: 16
depth_multiplier: 1.0
conv_hyperparams {
activation: RELU_6,
regularizer {
l2_regularizer {
weight: 0.00004
}
}
initializer {
truncated_normal_initializer {
stddev: 0.03
mean: 0.0
}
}
batch_norm {
train: true,
scale: true,
center: true,
decay: 0.9997,
epsilon: 0.001,
}
}
}
loss {
classification_loss {
weighted_sigmoid {
}
}
localization_loss {
weighted_smooth_l1 {
}
}
hard_example_miner {
num_hard_examples: 3000
iou_threshold: 0.99
loss_type: CLASSIFICATION
max_negatives_per_positive: 3
min_negatives_per_image: 0
}
classification_weight: 1.0
localization_weight: 1.0
}
normalize_loss_by_num_matches: true
post_processing {
batch_non_max_suppression {
score_threshold: 1e-8
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 100
}
score_converter: SIGMOID
}
}
}
train_config: {
batch_size: 24
optimizer {
rms_prop_optimizer: {
learning_rate: {
exponential_decay_learning_rate {
initial_learning_rate: 0.004
decay_steps: 800720
decay_factor: 0.95
}
}
momentum_optimizer_value: 0.9
decay: 0.9
epsilon: 1.0
}
}
# Change 1: point fine_tune_checkpoint at the pre-trained model for transfer learning
fine_tune_checkpoint: "/home/jiadongfeng/tensorflow/models/research/object_detection/ssd_mobilenet_v1_coco_2017_11_17/model.ckpt"
from_detection_checkpoint: true
# Note: The below line limits the training process to 200K steps, which we
# empirically found to be sufficient enough to train the pets dataset. This
# effectively bypasses the learning rate schedule (the learning rate will
# never decay). Remove the below line to train indefinitely.
num_steps: 200000
data_augmentation_options {
random_horizontal_flip {
}
}
data_augmentation_options {
ssd_random_crop {
}
}
}
# Change 2: training-set record file and label map
train_input_reader: {
tf_record_input_reader {
input_path: "/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/train_db.record"
}
label_map_path: "/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/raccoon_label_map.pbtxt"
}
# Change 3: number of test images and maximum evaluation passes
eval_config: {
num_examples: 40
# Note: The below line limits the evaluation process to 10 evaluations.
# Remove the below line to evaluate indefinitely.
max_evals: 10
}
# Change 4: test-set record file and label map
eval_input_reader: {
tf_record_input_reader {
input_path: "/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/test_db.record"
}
label_map_path: "/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/raccoon_label_map.pbtxt"
shuffle: false
num_readers: 1
}
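The exponential_decay_learning_rate block corresponds to lr(step) = initial_learning_rate × decay_factor^(step / decay_steps). With decay_steps set to 800720, the rate barely moves over a 200k-step run, which is why the config's comment says the schedule is effectively bypassed. A quick check:

```python
def lr(step, initial=0.004, decay_factor=0.95, decay_steps=800720):
    """Continuous exponential decay matching the train_config above."""
    return initial * decay_factor ** (step / decay_steps)

print(lr(0))       # 0.004
print(lr(200000))  # still ~0.0039, almost no decay after the full run
```

Only after 800720 steps would the rate drop by a full factor of 0.95, far beyond the 200k-step limit.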
Start training
From the object_detection directory, run the following command to start training:
python ./legacy/train.py --logtostderr \
  --pipeline_config_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_config/ssd_mobilenet_v1_raccoon.config \
  --train_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train
Output:
...
I0303 12:19:59.530555 140558114309504 learning.py:507] global step 10274: loss = 0.8229 (4.883 sec/step)
...
Training takes a long time: the config sets 200k steps, and at roughly 5 seconds per step that is about 11.5 days in total.
I stopped it manually after 48k-odd steps:
INFO:tensorflow:global step 48267: loss = 1.3767 (8.383 sec/step)
I0308 16:24:15.022326 140503706822016 learning.py:507] global step 48267: loss = 1.3767 (8.383 sec/step)
INFO:tensorflow:global step 48268: loss = 0.7260 (11.070 sec/step)
I0308 16:24:26.094815 140503706822016 learning.py:507] global step 48268: loss = 0.7260 (11.070 sec/step)
Evaluate the training result on the test set
python ./legacy/eval.py --logtostderr \
  --pipeline_config_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_config/ssd_mobilenet_v1_raccoon.config \
  --checkpoint_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train \
  --eval_dir=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/eval
Output:
W0314 16:56:10.875488 139739513279872 deprecation_wrapper.py:119] From /home/jiadongfeng/tensorflow/models/research/object_detection/exporter.py:274: The name tf.saved_model.tag_constants.SERVING is deprecated. Please use tf.saved_model.SERVING instead.
INFO:tensorflow:No assets to save.
I0314 16:56:10.876045 139739513279872 builder_impl.py:636] No assets to save.
INFO:tensorflow:No assets to write.
I0314 16:56:10.876254 139739513279872 builder_impl.py:456] No assets to write.
INFO:tensorflow:SavedModel written to: /home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/saved_model/saved_model.pb
I0314 16:56:11.369002 139739513279872 builder_impl.py:421] SavedModel written to: /home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/saved_model/saved_model.pb
WARNING:tensorflow:From /home/jiadongfeng/tensorflow/models/research/object_detection/utils/config_util.py:180: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.
INFO:tensorflow:Writing pipeline config file to /home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/pipeline.config
I0314 16:56:11.420886 139739513279872 config_util.py:182] Writing pipeline config file to /home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/pipeline.config
Export the checkpoint as a frozen model file
python export_inference_graph.py \
--pipeline_config_path=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_config/ssd_mobilenet_v1_raccoon.config \
--trained_checkpoint_prefix=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/model.ckpt-48253 \
--output_directory=/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train
The exported result is:
/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/frozen_inference_graph.pb
Run raccoon detection with the model
Modify the verification demo from Object Detection 1: installing and verifying the Object Detection API (the complete TensorFlow object detection demo):
....
# the pb file produced by the training above
PATH_TO_FROZEN_GRAPH = '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_train/frozen_inference_graph.pb'
# the label_map file for the image classes
PATH_TO_LABELS = '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/jdf_data/raccoon_label_map.pbtxt'
...
###Detection
# For the sake of simplicity we will use only 2 images:
# image1.jpg
# image2.jpg
# If you want to test the code with your images, just add path to the images to the TEST_IMAGE_PATHS.
PATH_TO_TEST_IMAGES_DIR = 'test_images'
# path(s) of the image(s) to predict on
TEST_IMAGE_PATHS = [ '/home/jiadongfeng/tensorflow/dataset/raccoon_dataset/images/raccoon-33.jpg' ]
...
For the complete Python file, see the raccoon detection demo.
Prediction results:
Error summary:
- ImportError: No module named 'pycocotools'
Fix: download the COCO API source from https://github.com/waleedka/coco. Note that cloning in a terminal with
git clone https://github.com/waleedka/coco
can be slow, so we suggest downloading the ZIP from GitHub directly (Download ZIP); if you are working on a remote server over SSH, download it locally first and scp it up.
Then enter the downloaded coco-master directory and its PythonAPI subdirectory:
PythonAPI$ pip install pycocotools
- limits.h:194:15: fatal error: limits.h: No such file or directory
sudo apt-get install build-essential    # install build-essential (optional)
sudo apt-get update                     # then install the kernel headers
sudo apt-get install linux-headers-$(uname -r)
# or in one line:
sudo apt-get update && sudo apt-get install build-essential linux-headers-$(uname -r)
- The train.py process gets killed while running
Go to the log directory:
cd /var/log/
and inspect the killed-process messages:
journalctl -xb | egrep -i 'killed process'
Feb 28 20:28:44 jiadongfeng-VirtualBox kernel: Killed process 1612 (train.py) total-vm:6352780kB, anon-rss:3783324kB, file-rss:0kB, shmem-rss:0kB
Feb 28 20:43:14 jiadongfeng-VirtualBox kernel: Killed process 1748 (train.py) total-vm:6257152kB, anon-rss:3692180kB, file-rss:0kB, shmem-rss:0kB
Training needs a lot of memory, so the process was killed when memory ran out. Raise the memory limit of your virtual machine; I added another 8 GB of RAM to my machine for this.