在TensorFlow中讀數據一般有三種方法:
- 使用placeholder讀內存中的數據
- 使用queue讀硬盤中的數據
- 使用Dataset讀內存?zhèn)€硬盤中的數據
基本概率
由于第三種方法在語法上更簡潔,因此本文主要介紹第三種方法。官方給出的Dataset API類圖:
其中終于重要的兩個基礎類:Dateset和Iterator。
Dateset是具有相同類型的“元素”的有序表啤它,元素可以是向量、字符串驼鞭、圖片等缨历。
從內存中創(chuàng)建Dataset
以數字元素為例:
從Dataset中實例化一個Iterator,然后對Iterator進行迭代讹堤。
iterator = dataset.make_one_shot_iterator()
從dataset中實例化一個iterator,是“one shot iterator”厨疙,即只能從頭到尾讀取一次洲守。
one_element = iterator.get_next()
從iterator中取出一個元素, one_element是一個tensor,因此需要調用sess.run(one_element)取出值梗醇。
如果元素被讀取完了知允,再sess.run(one_element)會拋出tf.errors.OutOfRangeError異常。解決方法:使用 dataset.repeat()
更復雜的輸入形式叙谨,例如温鸽,在圖像識別的應用中,一個元素可以使{“image”:image_tensor手负, “l(fā)abel”:lable_tensor}
dataset = tf.data.Dataset.from_tensor_slices(
{
"a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),
"b": np.random.uniform(size=(5, 2))
}
)
最終dataset中的一個元素為{"a": 1.0, "b": [0.9, 0.1]}的形式涤垫。
或者
dataset = tf.data.Dataset.from_tensor_slices(
(np.array([1.0, 2.0, 3.0, 4.0, 5.0]), np.random.uniform(size=(5, 2)))
)
對Dataset中的元素做變換:Transformation
一個Dataset通過Transformation變成一個新的Dataset。常用的操作有:
- map
- batch
- shuffle
- repeat
下面分別來介紹以上幾個操作竟终。
(1)map
map接收一個函數蝠猬,dataset中的每個元素都可以作為這個函數的輸入,并將函數的返回值作為新的dataset统捶,例如:
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0
(2)batch
將多個元素組合成batch榆芦,例如:
dataset = dataset.batch(32)
(3)shuffle
打亂dataset中的元素,參數buffersize表示打亂時buffer的大小喘鸟。
dataset = dataset.shuffle(buffer_size=10000)
(4)repeat
將整個序列重復多次匆绣,只用用來處理epoch。如果直接調用repeat()的話什黑,生成的序列就會無限重復下去崎淳,沒有結束,因此也不會拋出兑凿。tf.errors.OutOfRangeError異常:
dataset = dataset.repeat(5)
例子:讀磁盤圖片與對應的label
讀入磁盤中的圖片和圖片相應的label凯力,并將其打亂茵瘾,組成batch_size=32的訓練樣本礼华。在訓練時重復10個epoch。
# 函數的功能時將filename對應的圖片文件讀進來拗秘,并縮放到統(tǒng)一的大小
def _parse_function(filename, label):
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_image(image_string)
image_resized = tf.image.resize_images(image_decoded, [28, 28])
return image_resized, label
# 圖片文件的列表
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])
# label[i]就是圖片filenames[i]的label
labels = tf.constant([0, 37, ...])
# 此時dataset中的一個元素是(filename, label)
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
# 此時dataset中的一個元素是(image_resized, label)
dataset = dataset.map(_parse_function)
# 此時dataset中的一個元素是(image_resized_batch, label_batch)
dataset = dataset.shuffle(buffersize=1000).batch(32).repeat(10)
# 此時dataset中的一個元素是(image_resized_batch, label_batch)
# image_resized_batch的形狀為(32, 28, 28, 3)圣絮, label_batch的形狀為(32, )