上一篇是網(wǎng)絡(luò)模型的加載昙衅,這一篇是輸入模型的加載,之后還有訓(xùn)練模型的加載团滥。
輸入模型的加載的開(kāi)始是train.py文件中的
create_input_dict_fn = functools.partial(input_reader_builder.build, input_config)
那就進(jìn)入input_reader_builder.build看一看颓影。
parallel_reader = tf.contrib.slim.parallel_reader
def build(input_reader_config):
#判斷類型輸入的類型是否為input_reader_pb2.InputReader)
if not isinstance(input_reader_config, input_reader_pb2.InputReader):
raise ValueError('input_reader_config not of type '
'input_reader_pb2.InputReader.')
#只接受輸入類型為tf_record_input_reader的輸入
if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
#獲取數(shù)據(jù)集存放位置
config = input_reader_config.tf_record_input_reader
_, string_tensor = parallel_reader.parallel_read(
config.input_path,
reader_class=tf.TFRecordReader,
num_epochs=(input_reader_config.num_epochs
if input_reader_config.num_epochs else None),
num_readers=input_reader_config.num_readers,
shuffle=input_reader_config.shuffle,
dtypes=[tf.string, tf.string],
capacity=input_reader_config.queue_capacity,
min_after_dequeue=input_reader_config.min_after_dequeue)
return tf_example_decoder.TfExampleDecoder().decode(string_tensor)
raise ValueError('Unsupported input_reader_config.')
可以看出核心是使用了tf.contrib.slim.parallel_reader這個(gè)庫(kù)中的函數(shù)《芰郏看看這個(gè)函數(shù)的說(shuō)明犬性。
ef parallel_read(data_sources,
reader_class,
num_epochs=None,
num_readers=4,
reader_kwargs=None,
shuffle=True,
dtypes=None,
capacity=256,
min_after_dequeue=128,
seed=None,
scope=None):
"""
#從原始的數(shù)據(jù)文件使用多個(gè)reader獲取多個(gè)record。
#并行的使用ParallelReader從多個(gè)文件讀取數(shù)據(jù)
#多個(gè)readers是根據(jù) `reader_class` 和 `reader_kwargs'進(jìn)行創(chuàng)建的腾仅。
#如果shuffle為true乒裆,則common_queue將會(huì)是一個(gè)RandomShuffleQueue ,否則就是一個(gè)FIFOQueue.
參數(shù)說(shuō)明
data_sources: 一系列的文件位置比如: /path/to/train@128, /path/to/train* or /tmp/.../train*
reader_class: 一個(gè)繼承了io_ops.ReaderBase 的子類比如 TFRecordReader
num_epochs: 間隔多少次從數(shù)據(jù)源讀取一次文件攒砖,如果沒(méi)有給缸兔,就一直讀取
num_readers: 一個(gè)整數(shù)日裙,表示創(chuàng)建多少個(gè)數(shù)據(jù)讀取器。
reader_kwargs: 一個(gè)可選的字典惰蜜,表示of kwargs for the reader.
shuffle: 是否進(jìn)行數(shù)據(jù)的打亂操作昂拂。
dtypes: 一個(gè)類型的列表,dtypes的長(zhǎng)度一定等于每一個(gè)記錄中元素的長(zhǎng)度抛猖。如果為None格侯,則為[tf.string, tf.string] for (key, value).
capacity: 整數(shù),表示common_queue中需要包含多少數(shù)據(jù).
min_after_dequeue: 一個(gè)整數(shù)财著,在出隊(duì)后common_queue中最少的數(shù)據(jù)記錄的量联四,和打亂有關(guān)。
seed:RandomShuffleQueue所需的隨機(jī)種子.
scope: Optional name scope for the ops.
Returns:
key, value: a tuple of keys and values from the data_source.
"""
當(dāng)然讀取數(shù)據(jù)的最后一句話就是對(duì)獲取到的信息進(jìn)行解析撑教。
return tf_example_decoder.TfExampleDecoder().decode(string_tensor)
tf_example_decoder是一個(gè)用于解析包含了序列化后的tensorflow.Exampleprotos的解析器朝墩。
def decode(self, tf_example_string_tensor):
# 解析序列化后的tensroflow example并返回一個(gè)tensor的dict
# 傳入?yún)?shù):一個(gè)序列化后的tensorflow example proto對(duì)象
# 傳出對(duì)象: 返回的tensor的dict包含如下內(nèi)容:
# fields.InputDataFields.image - 一個(gè)三維類型為uint8的tensor,其大小為[None, None, 3]表示的是圖片
# fields.InputDataFields.source_id - 一個(gè)string類型的tensor包含的是圖片的id
# fields.InputDataFields.key - 一個(gè)string類型的tensor伟姐,是圖片的hd5碼
# fields.InputDataFields.filename - 一個(gè)string類型的tensor收苏,包含了數(shù)據(jù)庫(kù)的名稱
# fields.InputDataFields.groundtruth_boxes - 二維的float32的 tensor格式為
# [None, 4]包含box的四個(gè)頂點(diǎn)信息.
# fields.InputDataFields.groundtruth_classes - 1維的 int64型 tensor格式為shape
# [None]包含box所對(duì)應(yīng)的object類型
# fields.InputDataFields.groundtruth_area - 1維的 float32 類型的tensor格式為
# [None] 包含了物品的像素掩膜信息。
# fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape
# [None] indicating if the boxes enclose a crowd.
# fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape
# [None] indicating if the boxes represent `difficult` instances.
# fields.InputDataFields.groundtruth_instance_masks - 3D int64 tensor of
# shape [None, None, None] containing instance masks.
# fields.InputDataFields.groundtruth_instance_classes - 1D int64 tensor
# of shape [None] containing classes for the instance masks.
serialized_example = tf.reshape(tf_example_string_tensor, shape=[])
#構(gòu)建解析器
decoder = slim_example_decoder.TFExampleDecoder(self.keys_to_features,
self.items_to_handlers)
keys = decoder.list_items()
#解析
tensors = decoder.decode(serialized_example, items=keys)
tensor_dict = dict(zip(keys, tensors))
is_crowd = fields.InputDataFields.groundtruth_is_crowd
tensor_dict[is_crowd] = tf.cast(tensor_dict[is_crowd], dtype=tf.bool)
tensor_dict[fields.InputDataFields.image].set_shape([None, None, 3])
return tensor_dict
數(shù)據(jù)已經(jīng)獲取愤兵,接下來(lái)就是solver了鹿霸。