在使用TensorFlow構(gòu)建模型并進(jìn)行訓(xùn)練時瞳收,如何讀取數(shù)據(jù)并將數(shù)據(jù)恰當(dāng)?shù)厮瓦M(jìn)模型奏寨,是一個首先需要考慮的問題。以往通常所用的方法無外乎以下幾種:
1.建立placeholder倦始,然后使用feed_dict將數(shù)據(jù)feed進(jìn)placeholder進(jìn)行使用煌往。使用這種方法十分靈活,可以一下子將所有數(shù)據(jù)讀入內(nèi)存糠爬,然后分batch進(jìn)行feed寇荧;也可以建立一個Python的generator,一個batch一個batch的將數(shù)據(jù)讀入执隧,并將其feed進(jìn)placeholder揩抡。這種方法很直觀,用起來也比較方便靈活镀琉,但是這種方法的效率較低峦嗤,難以滿足高速計算的需求。
2.使用TensorFlow的QueueRunner屋摔,通過一系列的Tensor操作烁设,將磁盤上的數(shù)據(jù)分批次讀入并送入模型進(jìn)行使用。這種方法效率很高钓试,但因為其牽涉到Tensor操作装黑,不夠直觀,也不方便調(diào)試弓熏,所有有時候會顯得比較困難恋谭。使用這種方法時,常用的一些操作包括tf.TextLineReader挽鞠,tf.FixedLengthRecordReader以及tf.decode_raw等等疚颊。如果需要循環(huán),條件操作信认,還需要使用TensorFlow的tf.while_loop材义,tf.case等操作,更是難上加難嫁赏。
因此其掂,在這種情況下,TensorFlow在后續(xù)的更新中橄教,自1.x版本開始清寇,逐步開發(fā)引入了tf.data.Dataset模塊喘漏,使其數(shù)據(jù)讀入的操作變得更為方便,而支持多線程(進(jìn)程)的操作华烟,也在效率上獲得了一定程度的提高翩迈。本文就將使用tf.data.Dataset過程中的一些經(jīng)驗進(jìn)行總結(jié)記錄,以便備忘盔夜。
如我們所知负饲,在使用TensorFlow建立模型進(jìn)行訓(xùn)練的時候,可以很容易生成這樣的文件喂链,來表示數(shù)據(jù):
1. data/01.jpg,貓
2. data/05.jpg,狗
3. data/03.jpg,貓
4. data/04.jpg,狗
5. data/06.jpg,狗
6. data/02.jpg,貓
這種數(shù)據(jù)格式可以很方便地進(jìn)行各種操作返十,比如劃分?jǐn)?shù)據(jù)集、shuffle等等椭微。所以我們就以將這樣的數(shù)據(jù)通過tf.data.Dataset讀入進(jìn)行訓(xùn)練為例洞坑,來講述其用法。
具體來說蝇率,使用tf.data.Dataset讀取數(shù)據(jù)迟杂,本文講述這樣三種方法:
1.首先將數(shù)據(jù)讀入內(nèi)存,然后使用tf.data.Dataset構(gòu)建數(shù)據(jù)集
具體來說本慕,因為tf.data.Dataset.from_tensor_slices()函數(shù)會對tensor和numpy array的處理一視同仁排拷,所以該函數(shù)既可以使用tensor參數(shù),也可以直接使用numpy array作參數(shù)锅尘,使用numpy array作參數(shù)监氢,即是第1種方法。
如下所示:
1. images = ...
2. labels = ...
3. data = tf.data.Dataset.from_tensor_slices((images, labels))
4. data = data.batch(batch_size)
5. iterator = tf.data.Iterator.from_structure(data.output_types,
6. data.output_shapes)
7. init_op = iterator.make_initializer(data)
8. with tf.Session() as sess:
9. sess.run(init_op)
10. try:
11. images, labels = iterator.get_next()
12. except tf.errors.OutOfRangeError:
13. sess.run(init_op)
第1~2行藤违,首先浪腐,將數(shù)據(jù)images、labels讀入內(nèi)存顿乒;
第3~4行牛欢,使用讀入內(nèi)存的數(shù)據(jù)images、labels構(gòu)建Dataset淆游,并設(shè)置Dataset的batch大小隔盛;
第5行犹菱,基于此前構(gòu)建的Dataset的數(shù)據(jù)類型和結(jié)構(gòu),構(gòu)建一個iterator吮炕;
第6行腊脱,基于此前構(gòu)建的Dataset構(gòu)建一個初始化op。
隨后的操作龙亲,即是在TensorFlow的session里陕凹,首先進(jìn)行初始化操作悍抑,然后即可通過iterator的函數(shù)逐批獲得數(shù)據(jù),并進(jìn)行使用了杜耙。
需要注意的是搜骡,iterator中的元素取完之后,會拋出OutOfRangeError異常佑女,TensorFlow沒有對這個異常進(jìn)行處理记靡,我們需要對其進(jìn)行捕捉和處理。
本方法詳細(xì)代碼可參閱這里团驱。
2.使用tf.data.Dataset包裝一個generator讀入數(shù)據(jù)
1中方法雖然簡單摸吠,但其將數(shù)據(jù)一次讀入,在面對大數(shù)據(jù)集時會束手無策嚎花。因此寸痢,我們可以建立一個讀入數(shù)據(jù)的generator,然后使用tf.data.Dataset對其進(jìn)行包裝轉(zhuǎn)換紊选,即可實現(xiàn)逐batch讀入數(shù)據(jù)的目的啼止。如下:
1. def gen():
2. with open('train.csv') as f:
3. lines = [line.strip().split(',') for line in f.readlines()]
4. index = 0
5. while True:
6. image = cv2.imread(lines[index][0])
7. image = cv2.resize(image, (224, 224))
8. label = lines[index][1]
9. yield (image, label)
10. index += 1
11. if index == len(lines):
12. index = 0
15. batch_size = 2
16. data = tf.data.Dataset.from_generator(gen, (tf.float32, tf.int32),
17. (tf.TensorShape([224, 224, 3]), tf.TensorShape([])))
18. data = data.batch(batch_size)
19. iter = data.make_one_shot_iterator()
20. with tf.Session() as sess:
21. images, labels = iter.get_next()
如上,首先構(gòu)建一個generator:gen丛楚,然后使用tf.data.Dataset的from_generator函數(shù)族壳,通過指定數(shù)據(jù)類型,數(shù)據(jù)的shape等參數(shù)趣些,構(gòu)建一個Dataset仿荆,當(dāng)然,隨后也要指定一下batch_size坏平,最后使用make_one_shot_iterator()函數(shù)拢操,構(gòu)建一個iterator。
然后其使用方法即與前述相同了舶替,不過需要說明的是令境,這里是通過一個永無盡頭的generator構(gòu)建的Dataset,所以其可以一直取數(shù)據(jù)顾瞪,而不會出現(xiàn)1中所述的OutOfRange的問題舔庶。
本方法詳細(xì)代碼可參閱這里。
3.基于Tensor操作構(gòu)建Dataset
前述兩種方法陈醒,1中需要將數(shù)據(jù)一次全部讀入內(nèi)存惕橙,2中使用generator逐batch讀入數(shù)據(jù),雖然內(nèi)存占用得到了控制钉跷,但是其效率仍然不高弥鹦,讀取速度較慢。在第3種方法里爷辙,我們通過TensorFlow提供的tensor操作來讀取數(shù)據(jù)彬坏,并基于此朦促,構(gòu)建Dataset。
示例的代碼片段如下:
1. def _parse_function(filename, label):
2. image_string = tf.read_file(filename)
3. image_decoded = tf.image.decode_jpeg(image_string, channels=3)
4. image = tf.cast(image_decoded, tf.float32)
5. image = tf.image.resize_images(image, [224, 224])
6. return image, filename, label
8. images = tf.constant(image_names)
9. labels = tf.constant(labels)
10. images = tf.random_shuffle(images, seed=0)
11. labels = tf.random_shuffle(labels, seed=0)
12. data = tf.data.Dataset.from_tensor_slices((images, labels))
14. data = data.map(_parse_function, num_parallel_calls=4)
15. data = data.prefetch(buffer_size=batch_size * 10)
16. data = data.batch(batch_size)
18. iterator = tf.data.Iterator.from_structure(data.output_types,
19. data.output_shapes)
21. init_op = iterator.make_initializer(data)
22. with tf.Session() as sess:
23. sess.run(init_op)
24. try:
25. images, filenames, labels = iterator.get_next()
26. except tf.errors.OutOfRangeError:
27. sess.run(init_op)
首先讀入image names以及相應(yīng)的labels栓始,然后通過tf.constant構(gòu)建constant Tensor:images, labels务冕,并可選擇地對其進(jìn)行shuffle。
接著使用tf.data.Dataset.from_tensor_slices()函數(shù)基于images和labels構(gòu)建Dataset混滔。
然后使用map函數(shù)將函數(shù)應(yīng)用到該Dataset上洒疚,本例中,將解析圖像的函數(shù)_parse_function應(yīng)用到Dataset上坯屿,還指定了多線程并行操作的線程數(shù)油湖。
隨后指定prefetch的buffer_size,以及batch的大小领跛。
最后乏德,基于構(gòu)建的Dataset建立iterator,并定義iterator的初始化操作op吠昭,然后就可以按照正常的方式進(jìn)行使用了喊括。
需要注意的是,本方法構(gòu)建的Dataset也會有OutOfRange的異常出現(xiàn)矢棚,需要恰當(dāng)?shù)剡M(jìn)行捕捉并處理郑什。
本方法詳細(xì)代碼可參閱這里。