Generating TFRecord
I. Why use TFRecord?
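Normally our training directory contains train, test or val folders, and each of them often holds thousands or even tens of thousands of images, text files and so on. Stored as many scattered small files, they take up extra disk space, reading them one by one is slow and tedious, and loading them can use a lot of memory (some large datasets cannot be loaded in one go at all). In this situation TFRecord is a sensible storage format. TFRecord encodes each sample with the "Protocol Buffer" binary encoding scheme and stores the records one after another in a single binary file, so a reader only has to open one file and can stream through it sequentially instead of opening thousands of small files. This is simple and fast, and especially friendly to large training sets. When the training data is large, we can also split it across several TFRecord files to improve processing efficiency.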
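The last point, splitting a large dataset across several TFRecord files, can be sketched in plain Python. The helpers below are hypothetical (they are not TensorFlow APIs): they deal examples round-robin into a fixed number of shards and name each file with an assumed "-00000-of-00004" suffix, which is a common convention rather than something the TFRecord format requires.

```python
def shard_paths(record_path, num_shards):
    """Build one file name per shard, e.g. train.record-00000-of-00004.

    The naming scheme is an assumption used for illustration only.
    """
    return ["{}-{:05d}-of-{:05d}".format(record_path, i, num_shards)
            for i in range(num_shards)]


def assign_to_shards(items, num_shards):
    """Deal items round-robin so every shard gets roughly the same count."""
    shards = [[] for _ in range(num_shards)]
    for index, item in enumerate(items):
        shards[index % num_shards].append(item)
    return shards


if __name__ == "__main__":
    for path, shard in zip(shard_paths("train.record", 4),
                           assign_to_shards(list(range(10)), 4)):
        print(path, shard)
```

Each shard would then get its own TFRecordWriter, and for reading, the list of shard paths can be passed to tf.train.string_input_producer.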
II. A simple way to generate TFRecord
We can introduce TFRecord generation in two parts: the TFRecord writer and the Example module.
- TFRecord writer
writer = tf.python_io.TFRecordWriter(record_path)
writer.write(tf_example.SerializeToString())
writer.close()
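Here writer is our TFRecord writer. We then write each sample into the tfrecord file with writer.write(tf_example.SerializeToString()). Note that the writer must be closed with writer.close() once all records have been written. tf_example.SerializeToString() serializes the feature map inside the Example into a compact binary string (a serialization step, not compression). So how is tf_example created? That is what the Example module below covers.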
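To see what writer.write actually puts on disk, here is a stdlib-only sketch of the framing a TFRecord file uses for every serialized record: an 8-byte little-endian length, a masked CRC-32C of that length, the payload bytes, and a masked CRC-32C of the payload. It is written from the publicly documented format rather than taken from TensorFlow's source, and the function names are hypothetical; in a real pipeline keep using TFRecordWriter.

```python
import struct


def _make_crc32c_table():
    """Lookup table for CRC-32C (Castagnoli), reflected polynomial 0x82F63B78."""
    table = []
    for byte in range(256):
        crc = byte
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data):
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    """Rotate the CRC right by 15 bits and add a constant, kept to 32 bits."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def encode_record(payload):
    """Frame one serialized record: length, crc(length), payload, crc(payload)."""
    header = struct.pack("<Q", len(payload))
    return (header + struct.pack("<I", masked_crc(header))
            + payload + struct.pack("<I", masked_crc(payload)))


def decode_records(blob):
    """Split a byte string of framed records back into payloads, checking CRCs."""
    records, offset = [], 0
    while offset < len(blob):
        header = blob[offset:offset + 8]
        if len(header) < 8:
            raise ValueError("truncated record header")
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", blob[offset + 8:offset + 12])
        if length_crc != masked_crc(header):
            raise ValueError("corrupted length field")
        start = offset + 12
        payload = blob[start:start + length]
        (data_crc,) = struct.unpack("<I", blob[start + length:start + length + 4])
        if data_crc != masked_crc(payload):
            raise ValueError("corrupted payload")
        records.append(payload)
        offset = start + length + 4
    return records
```

Because each record carries its own length and checksums, a reader can stream through the file record by record and detect corruption, which is why one large TFRecord file does not have to be loaded into memory all at once.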
- Example module
First, let's look at what the Example protocol buffer definition looks like.
message Example {
  Features features = 1;
};
message Features {
  map<string, Feature> feature = 1;
};
message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};
As shown above, a tf_example can hold three kinds of data: BytesList, FloatList and Int64List. So how do we write a tf_example? Here is a simple example.
def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# encoded_jpg (JPEG bytes), label, height and width are assumed to be defined already
tf_example = tf.train.Example(
    features=tf.train.Features(feature={
        'image/encoded': bytes_feature(encoded_jpg),
        'image/format': bytes_feature('jpg'.encode()),
        'image/class/label': int64_feature(label),
        'image/height': int64_feature(height),
        'image/width': int64_feature(width)}))
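Let's unpack the code above from the outside in.
(1) tf.train.Example(features=None)
Here features is an instance of tf.train.Features.
(2) tf.train.Features(feature=None)
Here feature is a dictionary: each key is the name of a piece of data to store, and each value is the data itself, which must be a tf.train.Feature instance.
(3) tf.train.Feature(...)
Each Feature holds exactly one of bytes_list, float_list or int64_list (the oneof in the definition above); int64_feature and bytes_feature are small helpers that wrap a single value in such a list.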
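III. Complete code example for generating a TFRecord file
First we need a dataset.
Looking at the image folder, there are seven classes of images in total: each class name is the name of its folder, and each class folder stores many images of that class. Below we use the following code (generate_annotation_json.py and generate_tfrecord.py) to generate train.record.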
- generate_annotation_json.py
# -*- coding: utf-8 -*-
# @Time : 2018/11/22 22:12
# @Author : MaochengHu
# @Email : wojiaohumaocheng@gmail.com
# @File : generate_annotation_json.py
# @Software: PyCharm
import os
import json


def get_annotation_dict(input_folder_path, word2number_dict):
    # Map every image path to the class id of the folder it lives in
    label_dict = {}
    father_file_list = os.listdir(input_folder_path)
    for father_file in father_file_list:
        full_father_file = os.path.join(input_folder_path, father_file)
        son_file_list = os.listdir(full_father_file)
        for image_name in son_file_list:
            label_dict[os.path.join(full_father_file, image_name)] = word2number_dict[father_file]
    return label_dict


def save_json(label_dict, json_path):
    with open(json_path, 'w') as json_file:
        json.dump(label_dict, json_file)
    print("label json file has been generated successfully!")
- generate_tfrecord.py
# -*- coding: utf-8 -*-
# @Time : 2018/11/23 0:09
# @Author : MaochengHu
# @Email : wojiaohumaocheng@gmail.com
# @File : generate_tfrecord.py
# @Software: PyCharm
import os
import io
import tensorflow as tf
from PIL import Image
from generate_annotation_json import get_annotation_dict

flags = tf.app.flags
flags.DEFINE_string('images_dir',
                    '/data2/raycloud/jingxiong_datasets/six_classes/images',
                    'Path to image(directory)')
flags.DEFINE_string('annotation_path',
                    '/data1/humaoc_file/classify/data/annotations/annotations.json',
                    'Path to annotation')
flags.DEFINE_string('record_path',
                    '/data1/humaoc_file/classify/data/train_tfrecord/train.record',
                    'Path to TFRecord')
FLAGS = flags.FLAGS


def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def process_image_channels(image):
    process_flag = False
    # process the 4-channel .png (drop the alpha channel)
    if image.mode == 'RGBA':
        r, g, b, a = image.split()
        image = Image.merge("RGB", (r, g, b))
        process_flag = True
    # convert any other mode (e.g. grayscale 'L', palette 'P', 'CMYK') to RGB
    elif image.mode != 'RGB':
        image = image.convert("RGB")
        process_flag = True
    return image, process_flag


def process_image_reshape(image, resize):
    # Scale so that the shorter side equals `resize`, keeping the aspect ratio
    width, height = image.size
    if resize is not None:
        if width > height:
            width = int(width * resize / height)
            height = resize
        else:
            # compute the height from the ORIGINAL width before overwriting it
            height = int(height * resize / width)
            width = resize
        # Image.LANCZOS replaces Image.ANTIALIAS, which newer Pillow versions removed
        image = image.resize((width, height), Image.LANCZOS)
    return image


def create_tf_example(image_path, label, resize=None):
    with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_jpg = fid.read()
    image = Image.open(io.BytesIO(encoded_jpg))
    # remember whether the original file is already a JPEG
    is_jpeg = image.format == 'JPEG'
    # process png pics with four channels and other non-RGB modes
    image, process_flag = process_image_channels(image)
    # reshape image
    image = process_image_reshape(image, resize)
    # re-encode as JPEG whenever the stored bytes would not match 'jpg'
    if process_flag or resize is not None or not is_jpeg:
        bytes_io = io.BytesIO()
        image.save(bytes_io, format='JPEG')
        encoded_jpg = bytes_io.getvalue()
    width, height = image.size
    tf_example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/encoded': bytes_feature(encoded_jpg),
                'image/format': bytes_feature(b'jpg'),
                'image/class/label': int64_feature(label),
                'image/height': int64_feature(height),
                'image/width': int64_feature(width)
            }))
    return tf_example


def generate_tfrecord(annotation_dict, record_path, resize=None):
    num_tf_example = 0
    writer = tf.python_io.TFRecordWriter(record_path)
    for image_path, label in annotation_dict.items():
        # skip missing files instead of crashing on them
        if not tf.gfile.Exists(image_path):
            print("{} does not exist".format(image_path))
            continue
        tf_example = create_tf_example(image_path, label, resize)
        writer.write(tf_example.SerializeToString())
        num_tf_example += 1
        if num_tf_example % 100 == 0:
            print("Create %d TF_Example" % num_tf_example)
    writer.close()
    print("{} tf_examples have been created successfully and saved in {}".format(num_tf_example, record_path))


def main(_):
    word2number_dict = {
        "combinations": 0,
        "details": 1,
        "sizes": 2,
        "tags": 3,
        "models": 4,
        "tileds": 5,
        "hangs": 6
    }
    images_dir = FLAGS.images_dir
    # annotation_path = FLAGS.annotation_path
    record_path = FLAGS.record_path
    annotation_dict = get_annotation_dict(images_dir, word2number_dict)
    generate_tfrecord(annotation_dict, record_path)


if __name__ == '__main__':
    tf.app.run()
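The arithmetic in process_image_reshape (scale so that the shorter side equals resize while keeping the aspect ratio) is easy to get wrong, because the new height has to be computed from the original width before width is overwritten. The function below is a hypothetical, PIL-free extraction of that logic so it can be checked on its own:

```python
def short_side_resize(width, height, resize):
    """Return (new_width, new_height) so that the shorter side equals `resize`.

    Follows the same branches as process_image_reshape; int() truncates.
    """
    if resize is None:
        return width, height
    if width > height:
        return int(width * resize / height), resize
    return resize, int(height * resize / width)


if __name__ == "__main__":
    print(short_side_resize(1333, 2000, 800))
```

For a 1333 x 2000 (width x height) portrait image and resize=800 this gives (800, 1200).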
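* Note that generate_annotation_json.py is there to build label_dict, the image annotations. It gives us the annotation dictionary we need: each key is the full path of an image and each value is its class id. For example: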
{
"/images/hangs/862e67a8-5bd9-41f1-8c6d-876a3cb270df.JPG": 6,
"/images/tags/adc264af-a76b-4477-9573-ac6c435decab.JPG": 3,
"/images/tags/fd231f5a-b42c-43ba-9e9d-4abfbaf38853.JPG": 3,
"/images/hangs/2e47d877-1954-40d6-bfa2-1b8e3952ebf9.jpg": 6,
"/images/tileds/a07beddc-4b39-4865-8ee2-017e6c257e92.png": 5,
"/images/models/642015c8-f29d-4930-b1a9-564f858c40e5.png": 4
}
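get_annotation_dict treats every entry under images_root as a class folder and every file inside it as an image, so a stray file (a .DS_Store, a text file, or a readme at the root, which would make os.listdir fail) breaks the run or ends up in the record. A minimal variant, assuming a hypothetical extension whitelist IMAGE_EXTENSIONS, could skip such entries:

```python
import os
import tempfile

# Hypothetical whitelist of image extensions; adjust it to your data.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def get_annotation_dict_filtered(input_folder_path, word2number_dict):
    """Map image path -> class id, skipping stray files and non-image entries."""
    label_dict = {}
    for class_name in sorted(os.listdir(input_folder_path)):
        class_dir = os.path.join(input_folder_path, class_name)
        # only sub-folders that have an entry in word2number_dict are classes
        if not os.path.isdir(class_dir) or class_name not in word2number_dict:
            continue
        for image_name in sorted(os.listdir(class_dir)):
            extension = os.path.splitext(image_name)[1].lower()
            # skip hidden files such as .DS_Store and anything that is not an image
            if image_name.startswith(".") or extension not in IMAGE_EXTENSIONS:
                continue
            label_dict[os.path.join(class_dir, image_name)] = word2number_dict[class_name]
    return label_dict


if __name__ == "__main__":
    # Build a throwaway folder tree and show which files survive the filter.
    root = tempfile.mkdtemp()
    os.makedirs(os.path.join(root, "tags"))
    for name in ["a.JPG", ".DS_Store", "notes.txt"]:
        open(os.path.join(root, "tags", name), "w").close()
    print(get_annotation_dict_filtered(root, {"tags": 3}))
```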
- How to run the code
(1) First, our folders are organized as follows, where images_root is the root image folder and combinations, details, sizes, tags, models, tileds, hangs are the folders holding the images of each class.
-<images_root>
    -<combinations>
        -image.jpg
    -<details>
        -image.jpg
    -<sizes>
        -image.jpg
    -<tags>
        -image.jpg
    -<models>
        -image.jpg
    -<tileds>
        -image.jpg
    -<hangs>
        -image.jpg
(2) Create a folder named TFRecord and put the two Python files generate_tfrecord.py and generate_annotation_json.py into it. Note that you need to replace the dictionary word2number_dict in generate_tfrecord.py with your own (each key is the name of a class folder and each value is the corresponding class number):
word2number_dict = {
    "combinations": 0,
    "details": 1,
    "sizes": 2,
    "tags": 3,
    "models": 4,
    "tileds": 5,
    "hangs": 6
}
(3) Then run python3 (or python2) ./TFRecord/generate_tfrecord.py --images_dir="path to images_root" --record_path="where to save the record (full path of the .record file)". For example:
python3 generate_tfrecord.py --images_dir /images/ --record_path /classify/data/train_tfrecord/train.record
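Instead of typing word2number_dict by hand, it could be derived from the folder names. The helpers below are hypothetical and not part of the original scripts. They number classes in alphabetical order, so the ids differ from the hand-written mapping (hangs would become 2 instead of 6, for instance). Whichever mapping you choose, it must stay the same between writing the TFRecord and training or evaluating.

```python
import os


def word2number_from_names(class_names):
    """Assign ids 0..N-1 to class names in sorted (alphabetical) order."""
    return {name: index for index, name in enumerate(sorted(class_names))}


def build_word2number_dict(images_root):
    """Use every non-hidden sub-folder of images_root as a class."""
    names = [name for name in os.listdir(images_root)
             if os.path.isdir(os.path.join(images_root, name))
             and not name.startswith(".")]
    return word2number_from_names(names)
```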
Reading TFRecord
Above we covered how to generate a TFRecord; now let's read it back using queues.
Reading a TFRecord relies on two important TensorFlow functions, tf.train.string_input_producer and tf.TFRecordReader, plus the parser tf.parse_single_example, as shown in the figure below.
IV. A simple way to read TFRecord
There are two ways to parse a TFRecord: one uses tf.parse_single_example, the other uses tf.contrib.slim (* recommended).
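1. The first way (tf.parse_single_example) works as follows:
(1) Step one: put the train.record file into a queue:
filename_queue = tf.train.string_input_producer([tfrecords_filename])
(2) Step two: read records from the queue with a TFRecordReader:
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)  # returns (key, value): the record key and the serialized Example
(3) Step three: decode the Example with the parser tf.parse_single_example, where keys_to_features is a dictionary of tf.FixedLenFeature specs such as the one in the complete example in section V:
features = tf.parse_single_example(serialized_example, features=keys_to_features)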
2. The second way (tf.contrib.slim) works as follows:
(1) Step one: set decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers). The keys_to_features dictionary must match the feature keys defined in the TFRecord file; the keys of items_to_handlers can be any names, but the arguments used to construct each handler must be keys from keys_to_features.
(2) Step two: set dataset = slim.dataset.Dataset(params), where params include:
a. data_sources: the path of the tfrecord file
b. reader: usually the tf.TFRecordReader reader class
c. decoder: the decoder set up in step one
d. num_samples: the number of samples
e. items_to_descriptions: descriptions of the samples and labels
f. num_classes: the number of classes
(3) Step three: set provider = slim.dataset_data_provider.DatasetDataProvider(params), where params include:
a. dataset: the dataset created in step two
b. num_readers: the number of parallel readers
c. shuffle: whether to shuffle the data
d. num_epochs: how many times each data source is read; if set to None, the data is cycled through indefinitely
e. common_queue_capacity: the capacity of the common queue the readers fill, 256 by default
f. scope: the name scope for the created ops
g. common_queue_min: the minimum number of elements kept in the common queue after a dequeue, 128 by default
(4) Step four: call provider.get to obtain the data we need.
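The rule from step one, that every handler in items_to_handlers must be built from keys present in keys_to_features, is easy to break with a typo, and the mistake only surfaces once the decoder is used. A hypothetical plain-Python check, where each handler is represented simply as the list of feature keys it reads, could catch it up front:

```python
def check_handler_keys(keys_to_features, handler_keys):
    """Return the (item, key) pairs whose key is absent from keys_to_features.

    handler_keys: {item name: [feature keys the handler reads]}, e.g. the
    'image' item reads 'image/encoded' and 'image/format'.
    """
    missing = []
    for item, keys in handler_keys.items():
        for key in keys:
            if key not in keys_to_features:
                missing.append((item, key))
    return missing
```

An empty list means the two dictionaries are consistent.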
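3. Reading TFRecords with images of different sizes and resizing them to the same size
The reshape_same_size function resizes the images to a fixed size so that they can be batched: many networks are trained one batch at a time, and images of different sizes raise an error when they are stacked into one batch, so resizing them after reading makes batching possible.
Alternatively, simply use resized_image = tf.squeeze(tf.image.resize_bilinear([image], size=[FLAG.resize_height, FLAG.resize_width])).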
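Why the resize is needed before batching, and what nearest-neighbor resizing (as in reshape_same_size) does, can be shown on tiny nested lists. The sketch assumes the simple floor(i * in_size / out_size) index mapping for the non-align-corners case; TensorFlow's exact rounding may differ between versions, so treat it as an illustration only. All names are hypothetical.

```python
def nearest_indices(in_size, out_size):
    """Source index for each output position: floor(i * in_size / out_size)."""
    return [i * in_size // out_size for i in range(out_size)]


def resize_nearest(image, out_height, out_width):
    """Nearest-neighbor resize of an image stored as a list of pixel rows."""
    rows = nearest_indices(len(image), out_height)
    cols = nearest_indices(len(image[0]), out_width)
    return [[image[r][c] for c in cols] for r in rows]


def can_batch(images):
    """A batch needs every image to share the same (height, width)."""
    return len({(len(img), len(img[0])) for img in images}) <= 1


if __name__ == "__main__":
    small, large = [[0] * 3] * 2, [[0] * 5] * 4
    print(can_batch([small, large]))
    print(can_batch([resize_nearest(small, 4, 4), resize_nearest(large, 4, 4)]))
```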
V. Complete code example for reading a TFRecord file with tf.contrib.slim
# -*- coding: utf-8 -*-
# @Time : 2018/12/1 11:06
# @Author : MaochengHu
# @Email : wojiaohumaocheng@gmail.com
# @File : read_tfrecord.py
# @Software: PyCharm
import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string('tfrecord_path', '/data1/humaoc_file/classify/data/train_tfrecord/train.record', 'path to tfrecord file')
flags.DEFINE_integer('resize_height', 800, 'resize height of image')
flags.DEFINE_integer('resize_width', 800, 'resize width of image')
FLAG = flags.FLAGS
slim = tf.contrib.slim


def print_data(image, resized_image, label, height, width):
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        # start the queue-runner threads created by DatasetDataProvider
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(10):
            print("______________________image({})___________________".format(i))
            print_image, print_resized_image, print_label, print_height, print_width = sess.run(
                [image, resized_image, label, height, width])
            print("resized_image shape is: ", print_resized_image.shape)
            print("image shape is: ", print_image.shape)
            print("image label is: ", print_label)
            print("image height is: ", print_height)
            print("image width is: ", print_width)
        coord.request_stop()
        coord.join(threads)


def reshape_same_size(image, output_height, output_width):
    """Resize images by fixed sides.

    Args:
        image: A 3-D image `Tensor`.
        output_height: The height of the image after preprocessing.
        output_width: The width of the image after preprocessing.

    Returns:
        resized_image: A 3-D tensor containing the resized image.
    """
    output_height = tf.convert_to_tensor(output_height, dtype=tf.int32)
    output_width = tf.convert_to_tensor(output_width, dtype=tf.int32)
    image = tf.expand_dims(image, 0)
    resized_image = tf.image.resize_nearest_neighbor(
        image, [output_height, output_width], align_corners=False)
    resized_image = tf.squeeze(resized_image)
    return resized_image


def read_tfrecord(tfrecord_path, num_samples=14635, num_classes=7, resize_height=800, resize_width=800):
    # must match the feature keys written by generate_tfrecord.py
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature([], default_value='', dtype=tf.string),
        'image/format': tf.FixedLenFeature([], default_value='jpeg', dtype=tf.string),
        'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=0),
        'image/height': tf.FixedLenFeature([], tf.int64, default_value=0),
        'image/width': tf.FixedLenFeature([], tf.int64, default_value=0)
    }
    # item names are free; each handler may only read keys from keys_to_features
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(image_key='image/encoded', format_key='image/format', channels=3),
        'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]),
        'height': slim.tfexample_decoder.Tensor('image/height', shape=[]),
        'width': slim.tfexample_decoder.Tensor('image/width', shape=[])
    }
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
    items_to_descriptions = {
        'image': 'An image with shape image_shape.',
        'label': 'A single integer between 0 and 6.'}
    dataset = slim.dataset.Dataset(
        data_sources=tfrecord_path,
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=num_samples,
        items_to_descriptions=items_to_descriptions,
        num_classes=num_classes,
    )
    provider = slim.dataset_data_provider.DatasetDataProvider(dataset=dataset,
                                                              num_readers=3,
                                                              shuffle=True,
                                                              common_queue_capacity=256,
                                                              common_queue_min=128,
                                                              seed=None)
    image, label, height, width = provider.get(['image', 'label', 'height', 'width'])
    # resize every image to the same size so that images can be batched later
    resized_image = tf.squeeze(tf.image.resize_bilinear([image], size=[resize_height, resize_width]))
    return resized_image, label, image, height, width


def main(_):
    resized_image, label, image, height, width = read_tfrecord(tfrecord_path=FLAG.tfrecord_path,
                                                               resize_height=FLAG.resize_height,
                                                               resize_width=FLAG.resize_width)
    # resized_image = reshape_same_size(image, FLAG.resize_height, FLAG.resize_width)
    # resized_image = tf.squeeze(tf.image.resize_bilinear([image], size=[FLAG.resize_height, FLAG.resize_width]))
    print_data(image, resized_image, label, height, width)


if __name__ == '__main__':
    # tf.app.run() parses the command-line flags before calling main
    tf.app.run()
How to run the code
python3 read_tfrecord.py --tfrecord_path /data1/humaoc_file/classify/data/train_tfrecord/train.record --resize_height 800 --resize_width 800
Finally, we can see part of what was read from the file:
______________________image(0)___________________
resized_image shape is: (800, 800, 3)
image shape is: (2000, 1333, 3)
image label is: 5
image height is: 2000
image width is: 1333
______________________image(1)___________________
resized_image shape is: (800, 800, 3)
image shape is: (667, 1000, 3)
image label is: 0
image height is: 667
image width is: 1000
______________________image(2)___________________
resized_image shape is: (800, 800, 3)
image shape is: (667, 1000, 3)
image label is: 3
image height is: 667
image width is: 1000
______________________image(3)___________________
resized_image shape is: (800, 800, 3)
image shape is: (800, 800, 3)
image label is: 5
image height is: 800
image width is: 800
______________________image(4)___________________
resized_image shape is: (800, 800, 3)
image shape is: (1424, 750, 3)
image label is: 0
image height is: 1424
image width is: 750
______________________image(5)___________________
resized_image shape is: (800, 800, 3)
image shape is: (1196, 1000, 3)
image label is: 6
image height is: 1196
image width is: 1000
______________________image(6)___________________
resized_image shape is: (800, 800, 3)
image shape is: (667, 1000, 3)
image label is: 5
image height is: 667
image width is: 1000