Preparing Inputs
Fair warning: lots of code ahead.
The TensorFlow Object Detection API reads its data in the TFRecord file format, and it ships with two example conversion scripts, create_pascal_tf_record.py and create_pet_tf_record.py. Here we walk through create_pascal_tf_record.py in detail.
If you already know how to read and write TFRecords, feel free to skip ahead.
First, the license:
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
The script's main job is to convert the PASCAL VOC dataset into a TFRecord file.
Usage:
./create_pascal_tf_record --data_dir=/home/user/VOCdevkit \
    --year=VOC2012 \
    --output_path=/home/user/pascal.record
Imports
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import io
import logging
import os
from lxml import etree
import PIL.Image
import tensorflow as tf
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util
Up to this point the script is only importing libraries; install whichever ones you are missing.
flags
flags = tf.app.flags
flags.DEFINE_string('data_dir', '', 'Root directory to raw PASCAL VOC dataset.')
flags.DEFINE_string('set', 'train', 'Convert training set, validation set or '
'merged set.')
flags.DEFINE_string('annotations_dir', 'Annotations',
'(Relative) path to annotations directory.')
flags.DEFINE_string('year', 'VOC2007', 'Desired challenge year.')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
flags.DEFINE_string('label_map_path', 'data/pascal_label_map.pbtxt',
'Path to label map proto')
flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore '
'difficult instances')
FLAGS = flags.FLAGS
TensorFlow's flags module is the framework's built-in way of handling command-line arguments, similar in spirit to argv/argparse. The basic pattern is flags.DEFINE_&lt;type&gt;('flag_name', default_value, 'description'), and the parsed values are then read from FLAGS. For a longer introduction, see the post 「tensorflow 學(xué)習(xí)(三)使用flags定義命令行參數(shù)」.
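As a minimal, self-contained sketch of that define-then-read pattern (the script name and flag names below are made up purely for illustration, they are not part of the API):

# minimal_flags_demo.py -- a standalone toy example, not part of the Object Detection API
import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string('name', 'world', 'Who to greet.')
flags.DEFINE_boolean('shout', False, 'Whether to greet in upper case.')
FLAGS = flags.FLAGS

def main(_):
    # Flag values are available as attributes of FLAGS once parsing has happened.
    greeting = 'hello, {}'.format(FLAGS.name)
    print(greeting.upper() if FLAGS.shout else greeting)

if __name__ == '__main__':
    # tf.app.run() parses --name=... / --shout from sys.argv and then calls main().
    tf.app.run()

Running python minimal_flags_demo.py --name=pascal --shout would print HELLO, PASCAL.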
dict_to_tf_example
SETS = ['train', 'val', 'trainval', 'test']
YEARS = ['VOC2007', 'VOC2012', 'merged']
def dict_to_tf_example(data,
dataset_directory,
label_map_dict,
ignore_difficult_instances=False,
image_subdirectory='JPEGImages'):
img_path = os.path.join(data['folder'], image_subdirectory, data['filename'])
full_path = os.path.join(dataset_directory, img_path)
with tf.gfile.GFile(full_path, 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
image = PIL.Image.open(encoded_jpg_io)
if image.format != 'JPEG':
raise ValueError('Image format not JPEG')
key = hashlib.sha256(encoded_jpg).hexdigest()
width = int(data['size']['width'])
height = int(data['size']['height'])
xmin = []
ymin = []
xmax = []
ymax = []
classes = []
classes_text = []
truncated = []
poses = []
difficult_obj = []
for obj in data['object']:
difficult = bool(int(obj['difficult']))
if ignore_difficult_instances and difficult:
continue
difficult_obj.append(int(difficult))
xmin.append(float(obj['bndbox']['xmin']) / width)
ymin.append(float(obj['bndbox']['ymin']) / height)
xmax.append(float(obj['bndbox']['xmax']) / width)
ymax.append(float(obj['bndbox']['ymax']) / height)
classes_text.append(obj['name'].encode('utf8'))
classes.append(label_map_dict[obj['name']])
truncated.append(int(obj['truncated']))
poses.append(obj['pose'].encode('utf8'))
example = tf.train.Example(features=tf.train.Features(feature={
'image/height': dataset_util.int64_feature(height),
'image/width': dataset_util.int64_feature(width),
'image/filename': dataset_util.bytes_feature(
data['filename'].encode('utf8')),
'image/source_id': dataset_util.bytes_feature(
data['filename'].encode('utf8')),
'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
'image/encoded': dataset_util.bytes_feature(encoded_jpg),
'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
'image/object/class/label': dataset_util.int64_list_feature(classes),
'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
'image/object/truncated': dataset_util.int64_list_feature(truncated),
'image/object/view': dataset_util.bytes_list_feature(poses),
}))
return example
This block defines a function, dict_to_tf_example, that converts the parsed XML annotation of a single PASCAL image into a tf.Example.
Its parameters are:
- data: a dict holding the annotation information of one image. In the PASCAL dataset every image has a corresponding XML annotation file; in main this dict is obtained by parsing the XML with dataset_util.recursive_parse_xml_to_dict.
- dataset_directory: the root directory of the raw PASCAL VOC dataset.
- label_map_dict: maps each class name to an integer id; it is loaded from the label map file at the default path.
- ignore_difficult_instances: whether to skip objects marked as difficult; the default is fine.
- image_subdirectory: the subdirectory of the PASCAL dataset that contains the JPEG images; again, the default is fine.
After building the image's absolute path (full_path), the raw JPEG bytes are read with tf.gfile.GFile and opened with PIL.Image. PIL is only used here to verify that the file really is a JPEG (a PIL Image reports its size as (width, height); it is not decoded into a [c, h, w] array), and a SHA-256 hash of the raw bytes is computed to serve as a key.
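As a quick illustration of what this step yields (the image path below is made up):

# Hypothetical snippet mirroring the read-and-verify step above.
import hashlib
import io
import PIL.Image
import tensorflow as tf

path = '/home/user/VOCdevkit/VOC2012/JPEGImages/2008_000008.jpg'  # made-up example path
with tf.gfile.GFile(path, 'rb') as fid:
    encoded_jpg = fid.read()
image = PIL.Image.open(io.BytesIO(encoded_jpg))
print(image.format)   # 'JPEG' -- anything else makes the script raise ValueError
print(image.size)     # (width, height), e.g. (500, 375)
print(hashlib.sha256(encoded_jpg).hexdigest())  # the per-image key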
Next, the annotation fields carried in data are converted into normalized form (xmin/width, ymin/height, and so on) and appended to the per-field lists, together with the class names, label ids, truncation flags and poses. dataset_util.recursive_parse_xml_to_dict deserves special praise here: going straight from XML to a Python dict is very convenient.
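To make that concrete, the dict it returns for a typical PASCAL annotation looks roughly like the sketch below (the file names and numbers are made up; note that every leaf value is a string, which is why the code above casts with int()/float(), and that 'object' is a list with one entry per annotated object):

# Hypothetical shape of dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
data = {
    'folder': 'VOC2012',
    'filename': '2008_000008.jpg',
    'size': {'width': '500', 'height': '375', 'depth': '3'},
    'object': [  # one entry per annotated object
        {
            'name': 'horse',
            'pose': 'Left',
            'truncated': '0',
            'difficult': '0',
            'bndbox': {'xmin': '53', 'ymin': '87', 'xmax': '471', 'ymax': '370'},
        },
    ],
}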
Then a tf.train.Example instance named example is built, all of the collected fields are packed into its feature map, and example is returned.
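For context (this is not part of this script; the Object Detection API has its own decoder), here is a sketch of a feature spec that could read these fields back inside a tf.data input pipeline, assuming the output path from the usage example above:

# Sketch only: parsing the written features back in a tf.data pipeline.
import tensorflow as tf

def parse_fn(serialized_example):
    # Feature spec mirroring the keys written by dict_to_tf_example.
    feature_spec = {
        'image/encoded': tf.FixedLenFeature((), tf.string),
        'image/height': tf.FixedLenFeature((), tf.int64),
        'image/width': tf.FixedLenFeature((), tf.int64),
        'image/object/bbox/xmin': tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymin': tf.VarLenFeature(tf.float32),
        'image/object/bbox/xmax': tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymax': tf.VarLenFeature(tf.float32),
        'image/object/class/label': tf.VarLenFeature(tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, feature_spec)
    image = tf.image.decode_jpeg(parsed['image/encoded'], channels=3)
    labels = tf.sparse_tensor_to_dense(parsed['image/object/class/label'])
    return image, labels

dataset = tf.data.TFRecordDataset('/home/user/pascal.record').map(parse_fn)

With the per-image conversion covered, the main function below ties everything together.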
def main(_):
if FLAGS.set not in SETS:
raise ValueError('set must be in : {}'.format(SETS))
if FLAGS.year not in YEARS:
raise ValueError('year must be in : {}'.format(YEARS))
data_dir = FLAGS.data_dir
years = ['VOC2007', 'VOC2012']
if FLAGS.year != 'merged':
years = [FLAGS.year]
writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)
for year in years:
logging.info('Reading from PASCAL %s dataset.', year)
examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
'aeroplane_' + FLAGS.set + '.txt')
annotations_dir = os.path.join(data_dir, year, FLAGS.annotations_dir)
examples_list = dataset_util.read_examples_list(examples_path)
for idx, example in enumerate(examples_list):
if idx % 100 == 0:
logging.info('On image %d of %d', idx, len(examples_list))
path = os.path.join(annotations_dir, example + '.xml')
with tf.gfile.GFile(path, 'r') as fid:
xml_str = fid.read()
xml = etree.fromstring(xml_str)
data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
FLAGS.ignore_difficult_instances)
writer.write(tf_example.SerializeToString())
writer.close()
if __name__ == '__main__':
tf.app.run()
main validates the set and year flags, opens a TFRecordWriter for the output path, and for each selected year reads the list of example ids from ImageSets/Main/aeroplane_&lt;set&gt;.txt (the aeroplane list is used only as a convenient index of all image ids in the split). For each id it loads the corresponding XML annotation, parses it into a dict with recursive_parse_xml_to_dict, converts it with dict_to_tf_example, and writes the serialized example to the TFRecord file.
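To sanity-check the result, one way (a sketch, assuming the output path from the usage example above) to read the generated file back and print a few fields:

# Iterate over the generated TFRecord and inspect the first example.
import tensorflow as tf

for serialized in tf.python_io.tf_record_iterator('/home/user/pascal.record'):
    example = tf.train.Example.FromString(serialized)
    feature = example.features.feature
    filename = feature['image/filename'].bytes_list.value[0].decode('utf8')
    height = feature['image/height'].int64_list.value[0]
    width = feature['image/width'].int64_list.value[0]
    labels = list(feature['image/object/class/label'].int64_list.value)
    print(filename, height, width, labels)
    break  # only look at the first example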