隊(duì)列
隊(duì)列(queue)本身也是圖中的一個(gè)節(jié)點(diǎn)衡奥,是一種有狀態(tài)的節(jié)點(diǎn),其他節(jié)點(diǎn)锦庸,如入隊(duì)節(jié)點(diǎn)(enqueue)和出隊(duì)節(jié)點(diǎn)(dequeue)敲霍,可以修改它的內(nèi)容杰标。例如,入隊(duì)節(jié)點(diǎn)可以把新元素插到隊(duì)列末尾猪半,出隊(duì)節(jié)點(diǎn)可以把隊(duì)列前面的元素刪除兔朦。TensorFlow 中主要有兩種隊(duì)列偷线,即 FIFOQueue 和 RandomShuffleQueue。
- FIFOQueue 創(chuàng)建一個(gè)先入先出隊(duì)列沽甥。例如声邦,我們?cè)谟?xùn)練一些語(yǔ)音、文字樣本時(shí)摆舟,使用循環(huán)神經(jīng)網(wǎng)絡(luò)的網(wǎng)絡(luò)結(jié)構(gòu)亥曹,希望讀入的訓(xùn)練樣本是有序的,就要用 FIFOQueue恨诱。
- RandomShuffleQueue 創(chuàng)建一個(gè)隨機(jī)隊(duì)列媳瞪,在出隊(duì)列時(shí),是以隨機(jī)的順序產(chǎn)生元素的照宝。例如蛇受,我們?cè)谟?xùn)練一些圖像樣本時(shí),使用 CNN 的網(wǎng)絡(luò)結(jié)構(gòu)厕鹃,希望可以無(wú)序地讀入訓(xùn)練樣本兢仰,就要用RandomShuffleQueue,每次隨機(jī)產(chǎn)生一個(gè)訓(xùn)練樣本剂碴。RandomShuffleQueue 在 TensorFlow 使用異步計(jì)算時(shí)非常重要把将。因?yàn)?TensorFlow 的會(huì)話是支持多線程的,我們可以在主線程里執(zhí)行訓(xùn)練操作忆矛,使用 RandomShuffleQueue 作為訓(xùn)練輸入察蹲,開多個(gè)線程來(lái)準(zhǔn)備訓(xùn)練樣本,將樣本壓入隊(duì)列后催训,主線程會(huì)從隊(duì)列中每次取出 mini-batch 的樣本進(jìn)行訓(xùn)練洽议。
# FIFOQueue 先進(jìn)先出
q=tf.FIFOQueue(3,"float")
init=q.enqueue_many(([0.1,0.2,0.3],))
x=q.dequeue()
y=x+1
q_inc=q.enqueue([y])
with tf.Session() as sess:
sess.run(init)
for i in range(2):
sess.run(q_inc)
quelen=sess.run(q.size())
for i in range(quelen):
print(sess.run(q.dequeue()))
#RandomShuffleQueue
q=tf.RandomShuffleQueue(capacity=10,min_after_dequeue=2,dtypes="float")
with tf.Session() as sess:
for i in range(0,10):
sess.run(q.enqueue(i))
for i in range(0,8):
print(sess.run(q.dequeue()))
#3.0
#4.0
#8.0
#2.0
#0.0
#1.0
#7.0
#9.0
注意到RandomShuffleQueue的參數(shù)有容量和最小長(zhǎng)度。當(dāng)隊(duì)列長(zhǎng)度等于最小值瞳腌,執(zhí)行出隊(duì)操作以及隊(duì)列長(zhǎng)度等于最大值绞铃,執(zhí)行入隊(duì)操作時(shí)镜雨,會(huì)有阻斷情況發(fā)生嫂侍。只有當(dāng)隊(duì)列滿足要求后,才能繼續(xù)執(zhí)行荚坞√舫瑁可以通過設(shè)置繪畫在運(yùn)行時(shí)的等待時(shí)間來(lái)解除阻斷。
q=tf.RandomShuffleQueue(capacity=10,min_after_dequeue=2,dtypes="float")
with tf.Session() as sess:
for i in range(0,10):
sess.run(q.enqueue(i))
for i in range(0,10):
run_options = tf.RunOptions(timeout_in_ms = 10000) # 等待 10 秒
try:
print(sess.run(q.dequeue(), options=run_options))
except tf.errors.DeadlineExceededError:
print('out of range')
break;
如果入隊(duì)操作是在主線程中進(jìn)行颓影,那么當(dāng)入隊(duì)產(chǎn)生阻斷時(shí)各淀,會(huì)影響后續(xù)的讀數(shù)據(jù)以及訓(xùn)練操作。會(huì)話中可以運(yùn)行多個(gè)線程诡挂,我們使用線程管理器 QueueRunner 創(chuàng)建一系列的新線程進(jìn)行入隊(duì)操作碎浇,讓主線程繼續(xù)使用數(shù)據(jù)临谱,即訓(xùn)練網(wǎng)絡(luò)和讀取數(shù)據(jù)是異步的,主線程在訓(xùn)練網(wǎng)絡(luò)奴璃,另一個(gè)線程在將數(shù)據(jù)從硬盤讀入內(nèi)存悉默。
隊(duì)列管理器
q=tf.FIFOQueue(1000,'float')
counter=tf.Variable(0.0)
incre_op=tf.assign_add(counter,tf.constant(1.0))
enqueue_op=q.enqueue(counter)
qr=tf.train.QueueRunner(q,enqueue_ops=[incre_op,enqueue_op]*1)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
#啟動(dòng)入隊(duì)操作
enqueue_threads=qr.create_threads(sess,start=True)
#主線程是取數(shù)據(jù)操作
for i in range(10):
print(sess.run(q.dequeue()))
#3.0
#11.0
#69.0
#73.0
#90.0
#97.0
#107.0
#226.0
#266.0
#274.0
輸出的隊(duì)列也不是我們期待的自然數(shù)列,并且線程被阻斷苟穆。這是因?yàn)榧?1 操作和入隊(duì)操作不同步抄课,可能加 1 操作執(zhí)行了很多次之后,才會(huì)進(jìn)行一次入隊(duì)操作雳旅。
上述代碼最后報(bào)異常跟磨,然后會(huì)話自動(dòng)關(guān)閉。入隊(duì)線程自顧自地執(zhí)行攒盈,在需要的出隊(duì)操作完成之后抵拘,程序沒法結(jié)束,一直到超過隊(duì)列的容量之后型豁,會(huì)導(dǎo)致cancalledError仑濒。因此需要協(xié)調(diào)器來(lái)管理線程。
協(xié)調(diào)器coordinator
可以解決上面的入隊(duì)線程不受控制的情況偷遗。
q=tf.FIFOQueue(1000,'float')
counter=tf.Variable(0.0)
incre_op=tf.assign_add(counter,tf.constant(1.0))
enqueue_op=q.enqueue(counter)
qr=tf.train.QueueRunner(q,enqueue_ops=[incre_op,enqueue_op]*1)
sess=tf.Session()
sess.run(tf.global_variables_initializer())
coord=tf.train.Coordinator() #協(xié)調(diào)器墩瞳,協(xié)調(diào)線程間的關(guān)系可以視為一種信號(hào)量,用來(lái)做同步
enqueue_threads = qr.create_threads(sess, coord = coord,start=True)
for i in range(0,10):
print(sess.run(q.dequeue()))
coord.request_stop()# 通知其他線程關(guān)閉
coord.join(enqueue_threads)# join 操作等待其他線程結(jié)束氏豌,其他所有線程關(guān)閉之后喉酌,這一函數(shù)才能返回
#3.0
#20.0
#304.0
#1118.0
#1164.0
#1242.0
#1311.0
#1320.0
#1381.0
#1387.0
這個(gè)很奇怪,并沒有按照書上說的關(guān)閉線程之后再執(zhí)行出隊(duì)操作泵喘,就會(huì)拋出 tf.errors.OutOfRange 錯(cuò)誤泪电。而且
print(sess.run(q.size()))
#每次都不一樣,這次是170
for i in range(0,10):
print(sess.run(q.dequeue()))
#在此實(shí)行上面的代碼纪铺,得到10個(gè)1387