背景
使用dataset進(jìn)行數(shù)據(jù)管道化處理時惩嘉,通常我們會加上batch(batch_size)來獲取批量樣本函匕。這里有個容易忽視的點(diǎn)癌瘾,batch本身還提供了一個參數(shù)drop_remaindar,用于標(biāo)示是否對于最后一個batch如果數(shù)據(jù)量達(dá)不到batch_size時保留還是拋棄本刽。本次的小坑就是由于這個參數(shù)導(dǎo)致的挪蹭。
案例
show me the code:
with tf.name_scope('input'):
dataset = tf.data.Dataset.from_tensor_slices(files).interleave(lambda x: tf.data.TFRecordDataset(x).prefetch(10), cycle_length=num_preprocess_threads)
dataset = dataset.batch(batch_size, drop_remainder=True)
# dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
dataset = dataset.map(lambda x: _decode(x, type), num_parallel_calls=2)
dataset = dataset.shuffle(buffer_size=batch_size * 10)
dataset = dataset.prefetch(buffer_size=1000)
這是一段簡單的使用dataset來解析tfrecord的代碼亭饵,為了方便在創(chuàng)建dataset時,就將所有的數(shù)據(jù)集的batch_size設(shè)為了相同的值嚣潜。那就導(dǎo)致在數(shù)據(jù)消費(fèi)的時候冬骚,最后一個batch的數(shù)量達(dá)不到batch_size椅贱,所以這里我們將drop_remainder設(shè)為true懂算,運(yùn)行出錯。
經(jīng)過排查后發(fā)現(xiàn)庇麦,tf1.10之后的版本才支持這種方式计技,而之前的版本只能使用tf.contrib.data.batch_and_drop_remainder(batch_size)。
備注:小坑記錄下山橄,做留念