tensorflow有兩種數(shù)據(jù)輸入方法,比較簡(jiǎn)單的一種是使用feed_dict囊扳,這種方法在畫graph的時(shí)候使用placeholder來站位瓶堕,在真正run的時(shí)候通過feed字典把真實(shí)的輸入傳進(jìn)去雕薪。比較簡(jiǎn)單不再介紹。
比較惱火的是第二種方法拥刻,直接從文件中讀取數(shù)據(jù)(其實(shí)第一種也可以我們自己從文件中讀出來之后使用feed_dict傳進(jìn)去,但方法二tf提供很完善的一套類和函數(shù)形成一個(gè)類似pipeline一樣的讀取線):
1.使用tf.train.string_input_producer函數(shù)把我們需要的全部文件打包為一個(gè)tf內(nèi)部的queue類型父泳,之后tf開文件就從這個(gè)queue中取目錄了般哼,要注意一點(diǎn)的是這個(gè)函數(shù)的shuffle參數(shù)默認(rèn)是True,也就是你傳給他文件順序是1234惠窄,但是到時(shí)候讀就不一定了蒸眠,我一開始每次跑訓(xùn)練第一次迭代的樣本都不一樣,還納悶了好久杆融,就是這個(gè)原因楞卡。
files_in = ["./data/data_batch%d.bin" % i for i in range(1, 6)]
files = tf.train.string_input_producer(files_in)
2.搞一個(gè)reader,不同reader對(duì)應(yīng)不同的文件結(jié)構(gòu)脾歇,比如度bin文件tf.FixedLengthRecordReader就比較好蒋腮,因?yàn)槊看巫x等長的一段數(shù)據(jù)。如果要讀什么別的結(jié)構(gòu)也有相應(yīng)的reader藕各。
reader = tf.FixedLengthRecordReader(record_bytes=1+32*32*3)
3.用reader的read方法池摧,這個(gè)方法需要一個(gè)IO類型的參數(shù),就是我們上邊string_input_producer輸出的那個(gè)queue了激况,reader從這個(gè)queue中取一個(gè)文件目錄作彤,然后打開它經(jīng)行一次讀取膘魄,reader的返回是一個(gè)tensor(這一點(diǎn)很重要,我們現(xiàn)在寫的這些讀取代碼并不是真的在讀數(shù)據(jù)宦棺,還是在畫graph瓣距,和定義神經(jīng)網(wǎng)絡(luò)是一樣的,這時(shí)候的操作在run之前都不會(huì)執(zhí)行代咸,這個(gè)返回的tensor也沒有值蹈丸,他僅僅代表graph中的一個(gè)結(jié)點(diǎn))。
key, value = reader.read(files)
4.對(duì)這個(gè)tensor做些數(shù)據(jù)與處理呐芥,比如CIFAR1-10中l(wèi)abel和image數(shù)據(jù)是糅在一起的逻杖,這里用slice把他們切開,切成兩個(gè)tensor(注意這個(gè)兩個(gè)tensor是對(duì)應(yīng)的思瘟,一個(gè)image對(duì)一個(gè)label荸百,對(duì)叉了后便訓(xùn)練就完了),然后對(duì)image的tensor做data augmentation滨攻。
data = tf.decode_raw(value, tf.uint8)
label = tf.cast(tf.slice(data, [0], [1]), tf.int64)
raw_image = tf.reshape(tf.slice(data, [1], [32*32*3]), [3, 32, 32])
image = tf.cast(tf.transpose(raw_image, [1, 2, 0]), tf.float32)
lr_image = tf.image.random_flip_left_right(image)
br_image = tf.image.random_brightness(lr_image, max_delta=63)
rc_image = tf.image.random_contrast(br_image, lower=0.2, upper=1.8)
std_image = tf.image.per_image_standardization(rc_image)
5.這時(shí)候可以發(fā)現(xiàn)够话,這個(gè)tensor代表的是一個(gè)樣本([高寬管道]),但是訓(xùn)練網(wǎng)絡(luò)的時(shí)候的輸入一般都是一推樣本([樣本數(shù)高寬*管道])光绕,我們就要用tf.train.batch或者tf.train.shuffle_batch這個(gè)函數(shù)把一個(gè)一個(gè)小樣本的tensor打包成一個(gè)高一維度的樣本batch女嘲,這些函數(shù)的輸入是單個(gè)樣本,輸出就是4D的樣本batch了诞帐,其內(nèi)部原理似乎是創(chuàng)建了一個(gè)queue欣尼,然后不斷調(diào)用你的單樣本tensor獲得樣本,直到queue里邊有足夠的樣本停蕉,然后一次返回一堆樣本愕鼓,組成樣本batch。
images, labels = tf.train.batch([std_image, label],
batch_size=100,
num_threads=16,
capacity=int(50000* 0.4 + 3 * batch_size))
5.事實(shí)上一直到上一部的images這個(gè)tensor慧起,都還沒有真實(shí)的數(shù)據(jù)在里邊菇晃,我們必須用Session run一下這個(gè)4D的tensor,才會(huì)真的有數(shù)據(jù)出來蚓挤。這個(gè)原理就和我們定義好的神經(jīng)網(wǎng)絡(luò)run一下出結(jié)果一樣谋旦,你一run這個(gè)4D tensor,他就會(huì)順著自己的operator找自己依賴的其他tensor屈尼,一路最后找到最開始reader那里册着。
除了上邊講的原理,其中還要注意幾點(diǎn)
1.tf.train.start_queue_runners(sess=sess)這一步一定要運(yùn)行脾歧,且其位置要在定義好讀取graph之后甲捏,在真正run之前,其作用是把queue里邊的內(nèi)容初始化鞭执,不跑這句一開始string_input_producer那里就沒用司顿,整個(gè)讀取流水線都沒用了芒粹。
training_images = tf.train.batch(XXXXXXXXXXXXXXX)
tf.train.start_queue_runners(sess=self.sess)
real_images = sess.run(training_images)
2.image和label一定要一起run,要記清楚我們的image和label是在一張graph里邊的大溜,跑一次那個(gè)graph化漆,這兩個(gè)tensor都會(huì)出結(jié)果,且同一次跑出來的image和label才是對(duì)應(yīng)的钦奋,如果你run兩次座云,第一次為了拿image第二次為了拿label,那整個(gè)就叉了付材,因?yàn)榈谝淮闻艹鰜淼?到100號(hào)image和0到100號(hào)label朦拖,第二次跑出來第100到200的image和第100到200的label,你拿到了0100的image和100200的label厌衔,整個(gè)樣本分類全不對(duì)璧帝,最后網(wǎng)絡(luò)肯定跑不出結(jié)果。
training_images, training_labels = read_image()
tf.train.start_queue_runners(sess=self.sess)
real_images = sess.run(training_images) # 讀出來是真的圖片富寿,但是和label對(duì)不上
real_labels = sess.run(training_labels) # 讀出來是真的label睬隶,但是和image對(duì)不上
# 正確調(diào)用方法,通過跑一次graph页徐,將成套的label和image讀出來
real_images, real_labels = sess.run([training_images, training_labels])
因?yàn)椴欢@個(gè)道理的up主跑了一下午正確率還是10%理疙。。泞坦。。(10類別分類10%正確率不就是亂猜嗎)