Continuing from the previous post: once the data has been stored as tfrecords files, the next step is to read it back from those files to train the model. Here we try the tf.data input pipeline; reading data with tf.data can speed up the data-loading process.
```python
import tensorflow as tf


def read_and_decode(loader, handle, num_epochs=1):
    """Read tfrecord-format data and build a feedable iterator."""
    batch_size = int(loader.batch_size() / FLAGS.gpu_num)
    feature_size = model_settings['fingerprint_size']

    def parse_exmp(serialized_example):
        # Variable-length sequences were serialized as sparse features;
        # 'length' holds the true (unpadded) number of frames.
        features = tf.parse_single_example(serialized_example, features={
            'feature': tf.VarLenFeature(tf.float32),
            'label': tf.VarLenFeature(tf.int64),
            'mask': tf.VarLenFeature(tf.int64),
            'length': tf.FixedLenFeature((), tf.int64)
        })
        length = tf.cast(features['length'], tf.int32)
        feature = tf.sparse_tensor_to_dense(features['feature'])
        feature = tf.reshape(feature, [length, feature_size])
        label = tf.sparse_tensor_to_dense(features['label'])
        mask = tf.sparse_tensor_to_dense(features['mask'])
        return feature, label, mask, length

    filenames = ['./train_input/tfrecords_file/train_dataset_%d.tfrecords' % i
                 for i in range(10)]
    # tf.contrib.data graduated to tf.data in TF 1.4.
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(parse_exmp, num_parallel_calls=64)
    # padded_batch pads every sequence in a batch to the longest one;
    # prefetch (placed last) overlaps input processing with training.
    dataset = dataset.shuffle(64).repeat(num_epochs).padded_batch(
        batch_size,
        padded_shapes=([None, feature_size], [None], [None], []))
    dataset = dataset.prefetch(buffer_size=1)
    train_iterator = dataset.make_initializable_iterator()
    # A feedable iterator: which dataset to read from is chosen at run
    # time through the string `handle` placeholder.
    iterator = tf.data.Iterator.from_string_handle(
        handle, dataset.output_types, dataset.output_shapes)
    batch_data, batch_label, batch_mask, batch_length = iterator.get_next()
    if FLAGS.ctc_loss:
        # For CTC, features are returned time-major while labels stay
        # batch-major.
        return (train_iterator, tf.transpose(batch_data, (1, 0, 2)),
                batch_label, batch_mask, batch_length)
    else:
        return (train_iterator, tf.transpose(batch_data, (1, 0, 2)),
                tf.transpose(batch_label, (1, 0)),
                tf.transpose(batch_mask, (1, 0)), batch_length)
```