Tensorflow TFRecords及多線程訓(xùn)練介紹 ——詳細

也可移步my github查看

先修知識——protocol buffer

TF框架中多處使用了protocol buffer辨液,protocol buffer全稱Google Protocol Buffer虐急,簡稱Protobuf,是一種結(jié)構(gòu)化數(shù)據(jù)存儲格式滔迈,類似于常見的Json和xml止吁,而且這種格式經(jīng)過編譯可以生成對應(yīng)C++或Java或Python類的形式被辑,即可以用編程語言讀取或修改數(shù)據(jù),不僅如此敬惦,還可以進一步將定義的結(jié)構(gòu)化數(shù)據(jù)進行序列化盼理,轉(zhuǎn)化成二進制數(shù)據(jù)存下來或發(fā)送出去,非常適合做數(shù)據(jù)存儲或 RPC 數(shù)據(jù)交換格式俄删。更具體的介紹可以參考網(wǎng)上比較推薦的文章:Google Protocol Buffer 的使用和原理宏怔。其實TensorFlow計算圖思想的實現(xiàn)也是基于protocol buffer的,感興趣的可以看一下畴椰,本文主要介紹TFRecords臊诊,TFRecords是TF官方推薦使用的數(shù)據(jù)存儲形式,也是使用了protocol buffer斜脂,下面結(jié)合TFRecords詳細介紹其使用方法和原理抓艳。

protocol buffer的使用

參考Google Protocol Buffer 的使用和原理可以發(fā)現(xiàn),要得到本地存儲的序列化數(shù)據(jù)帚戳,需要先定義.proto 文件玷或,再編譯成編程語言描述的類,然后實例化該類(該類也已自動生成setter getter修改類和序列化類等方法)片任,并序列化保存到本地或進行傳輸偏友。TFRecords的思想也是將數(shù)據(jù)集中的數(shù)據(jù)以結(jié)構(gòu)化的形式存到.proto中,然后序列化存儲到本地对供,方便使用時讀取并還原數(shù)據(jù)约谈,只不過TF又對這個過程進行了一點封裝,看起來和protocol buffer原始的使用方式略有差別犁钟。

protocol buffer中需要先將數(shù)據(jù)以結(jié)構(gòu)化文件.proto的格式展現(xiàn)棱诱,然后可以編譯成C++ Java 或python類進行后續(xù)操作,在TFRecords的應(yīng)用中tf.train.Example類就是扮演了這一角色涝动,TF中它的原始.proto文件定義在tensorflow/core/example/example.proto中,如下代碼片:

message Example {
  Features features = 1;
};

可以看到Example類中封裝的數(shù)據(jù)應(yīng)該是features,是Features類型的迈勋,而Features在python代碼中就對應(yīng)了tf.train.Features類,其原始.proto文件定義在tensorflow/core/example/feature.proto中,如下代碼片:

message Features {
  // Map from feature name to feature.
  map<string, Feature> feature = 1;
};

可以看到醋粟,Features中的數(shù)據(jù)又是feature(注意沒有s)靡菇,而feature屬性的類型是map<string, Feature>類型,string不必說了米愿,關(guān)鍵是Feature類型厦凤,和Features一樣,Feature對應(yīng)tf.train.Feature類育苟,其原始.proto文件也定義在tensorflow/core/example/feature.proto中较鼓,如下代碼片:

message Feature {
  // Each feature can be exactly one kind.
  oneof kind {
    BytesList bytes_list = 1; # bytes_list float_list int64_list也是和之前一樣,對應(yīng)一個類
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

將數(shù)據(jù)集轉(zhuǎn)化成TFRecords形式

TFRecords的定義過程就是使用了剛介紹的幾個類:tf.train.Exampletf.train.Features博烂,tf.train.Feature香椎,知道了這幾個類的定義以及它們的嵌套關(guān)系,再去理解TFRecords的產(chǎn)生就容易多了禽篱。
首先畜伐,使用tf.train.Example來封裝我們的數(shù)據(jù),然后使用tf.python_io.TFRecordWriter來寫入磁盤躺率,其中幾個類的的嵌套方式和上述一致玛界,見如下代碼:

#本段代碼來自[TensorFlow高效讀取數(shù)據(jù)](http://ycszen.github.io/2016/08/17/TensorFlow%E9%AB%98%E6%95%88%E8%AF%BB%E5%8F%96%E6%95%B0%E6%8D%AE/)

import os
import tensorflow as tf 
from PIL import Image
cwd = os.getcwd()
'''
此處我加載的數(shù)據(jù)目錄如下:
0 -- img1.jpg
     img2.jpg
     img3.jpg
     ...
1 -- img1.jpg
     img2.jpg
     ...
2 -- ...
...
'''
# 先定義writer對象,writer負責(zé)將得到的記錄寫入TFRecords文件悼吱,此處為train.tfrecords文件
writer = tf.python_io.TFRecordWriter("train.tfrecords")
for index, name in enumerate(classes):
class_path = cwd + name + "/"
  # 一張一張的寫入TFRecords文件
  for img_name in os.listdir(class_path):
    img_path = class_path + img_name
    img = Image.open(img_path)
    img = img.resize((224, 224)) #對圖片做一些預(yù)處理操作
    img_raw = img.tobytes()     #將圖片轉(zhuǎn)化為原生bytes
    # 封裝僅Example對象中
    example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))
    writer.write(example.SerializeToString())  #序列化為字符串并寫入磁盤
writer.close()

讀取數(shù)據(jù)

以上存儲數(shù)據(jù)時慎框,Example調(diào)用SerializeToString()方法將自己序列化并由writer = tf.python_io.TFRecordWriter("train.tfrecords")對象保存,最終是將所有的圖片文件和label保存到同一個tfrecords文件train.tfrecords中了舆绎。讀取數(shù)據(jù)則以上過程的逆,先獲取序列化數(shù)據(jù)们颜,再解析:由tf.python_io.tf_record_iterator("train.tfrecords")方法(注意這個是方法)返回所有本地序列化文件迭代器吕朵,然后由Example調(diào)用ParseFromString()方法解析,代碼如下:

for serialized_example in tf.python_io.tf_record_iterator("train.tfrecords"):
  # 本段代碼來自[TensorFlow高效讀取數(shù)據(jù)](http://ycszen.github.io/2016/08/17/TensorFlow%E9%AB%98%E6%95%88%E8%AF%BB%E5%8F%96%E6%95%B0%E6%8D%AE/)
  example = tf.train.Example()
  # 進行解析
  example.ParseFromString(serialized_example)
  # 逐個讀取example對象里封裝的東西
  image = example.features.feature['image'].bytes_list.value
  label = example.features.feature['label'].int64_list.value
  # 可以做一些預(yù)處理之類的
  print image, labe

這是最基本的數(shù)據(jù)讀取方式窥突,tf.python_io.tf_record_iterator方法每次解析一個.tfrecords文件努溃。而在實際應(yīng)用中,由于數(shù)據(jù)集往往很大阻问,所以往往將數(shù)據(jù)分開保存至多個tfrecords文件中梧税,在這種情況下,TF提供了其他的接口進行讀取称近,所以正常情況下我們可能不會使用上述的數(shù)據(jù)讀取方式第队,以下才是重點,但必須強調(diào)的是整體的思想是一致的刨秆,都是先獲取序列化文件凳谦,然后解析,只是接口函數(shù)稍有不同衡未。

TF的多線程訓(xùn)練是TF框架重新設(shè)計的尸执,不是簡單地使用python語言多線程來搞得,很多時候TF多線程是和TFRecords配套使用的缓醋,下面介紹的數(shù)據(jù)讀取方法也是多線程訓(xùn)練的數(shù)據(jù)讀取方式如失。十圖詳解tensorflow數(shù)據(jù)讀取機制這篇文章深入淺出>的介紹了TF多線程讀取數(shù)據(jù)和訓(xùn)練的原理,多線程這一塊接口多送粱,也比較難以理解褪贵,下面僅從使用的角度出發(fā)談?wù)勎覀€人的理解,不詳細追究里面的實現(xiàn)原理抗俄。

假設(shè)我們按照上述方式將數(shù)據(jù)保存到了兩個tfrecords文件中竭鞍,分別為'1.tfrecords'和'2.tfrecords'板惑,保存在DATA_ROOT路徑中,那么我們分幾步讀取數(shù)據(jù)偎快,參考如下代碼:

    1. 讀取tfrecords文件名到隊列中冯乘,使用tf.train.string_input_producer函數(shù),該函數(shù)可以接收一個文件名列表晒夹,并自動返回一個對應(yīng)的文件名隊列filename_queue裆馒,之所以用隊列是為了后續(xù)多線程考慮(隊列和多線程經(jīng)常搭配使用)
    1. 實例化tf.TFRecordReader()類生成reader對象,接收filename_queue參數(shù)丐怯,并讀取該隊列中文件名對應(yīng)的文件喷好,得到serialized_example(讀到的就是.tfrecords序列化文件)
    1. 解析,注意這里的解析不是用的Example對象里的函數(shù)读跷,而是tf.parse_single_example函數(shù)梗搅,該函數(shù)能從serialized_example中解析出一條數(shù)據(jù),當(dāng)然也可以用tf.parse_example解析多條數(shù)據(jù)效览,此處暫不贅述无切。這里tf.parse_single_example函數(shù)傳入?yún)?shù)serialized_examplefeatures,其中features是字典的形式丐枉,指定每個key的解析方式哆键,比如image_raw使用tf.FixedLenFeature方法解析,這種解析方式返回一個Tensor瘦锹,大多數(shù)解析方式也都是這種籍嘹,另一種是tf.VarLenFeature方法,返回SparseTensor弯院,用于處理稀疏數(shù)據(jù)辱士,不贅述。這里還要注意必須告訴解析函數(shù)以何種數(shù)據(jù)類型解析听绳,這必須與生成TFRecords文件時指定的數(shù)據(jù)類型一致识补。最后返回features是一個字典,里面存放了每一項的解析結(jié)果辫红。
    1. 最后只要讀出features中的數(shù)據(jù)即可凭涂。比如,features['label'],features['pixels']贴妻。但要注意的是切油,此時的image_raw依然是字符串類型的(可以看寫入代碼中的image_raw),需要進一步還原成像素數(shù)組名惩,用TF提供的函數(shù)tf.decode_raw來搞定images = tf.decode_raw(features['image_raw'],tf.uint8)澎胡。

至此,就定義好了完成一次數(shù)據(jù)讀取的代碼,有了它攻谁,下面的訓(xùn)練時的多線程方法就有了數(shù)據(jù)來源稚伍,下節(jié)討論。

# 讀取文件戚宦。
filename_queue = tf.train.string_input_producer(["Records/output.tfrecords"])
reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue)

# 解析讀取的樣例个曙。
features = tf.parse_single_example(
    serialized_example,
    features={
        'image_raw':tf.FixedLenFeature([],tf.string),
        'pixels':tf.FixedLenFeature([],tf.int64),
        'label':tf.FixedLenFeature([],tf.int64)
    })

images = tf.decode_raw(features['image_raw'],tf.uint8)
labels = tf.cast(features['label'],tf.int32) #需要用tf.cast做一個類型轉(zhuǎn)換
pixels = tf.cast(features['pixels'],tf.int32)

# 下面的代碼下節(jié)討論
sess = tf.Session()

# 啟動多線程處理輸入數(shù)據(jù)。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)

for i in range(10):
    image, label, pixel = sess.run([images, labels, pixels])

TF多線程機制

假設(shè)已將數(shù)據(jù)集文件轉(zhuǎn)換成了TFRecords格式受楼,共兩個文件垦搬,每個文件中存儲兩條數(shù)據(jù),兩個文件如下艳汽,下面用多線程的方式讀取并訓(xùn)練猴贰,分為以下幾個步驟:

/patah/to/data.tfrecords-00000-of-00002
/patah/to/data.tfrecords-00001-of-00002
    1. 獲取TFRecords文件隊列。TF提供了tf.train.match_filenames_once函數(shù)幫助獲取所有滿足條件的TFRecords文件河狐,tf.train.match_filenames_once函數(shù)參數(shù)為正則表達式米绕,返回匹配上的所有文件名集合變量。當(dāng)然馋艺,也可以選擇不用該函數(shù)栅干,用純python也可以匹配,python的話最終返回一個list類型即可丈钙,但正規(guī)起見非驮,還是推薦使用TF提供的方法交汤。然后tf.train.string_input_producer函數(shù)依此生成文件名隊列filename_queue雏赦。
files = tf.train.match_filenames_once("/patah/to/data.tfrecords-*") # 
filename_queue = tf.train.string_input_producer(files, shuffle=False)
    1. 解析TFRecords文件中的數(shù)據(jù),和上面一樣芙扎,不贅述星岗。
# 讀取文件。
reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue)

# 解析讀取的樣例戒洼。
features = tf.parse_single_example(
    serialized_example,
    features={
        'image_raw':tf.FixedLenFeature([],tf.string),
        'pixels':tf.FixedLenFeature([],tf.int64),
        'label':tf.FixedLenFeature([],tf.int64)
    })

decoded_images = tf.decode_raw(features['image_raw'],tf.uint8)
retyped_images = tf.cast(decoded_images, tf.float32)
#pixels = tf.cast(features['pixels'],tf.int32)
# 最后只要labels和images
labels = tf.cast(features['label'],tf.int32)
images = tf.reshape(retyped_images, [784])
  • 3)將讀取到的數(shù)據(jù)打包為batch俏橘。上一段代碼得到了labelsimages,這是一條數(shù)據(jù)圈浇,訓(xùn)練一次需要一個batch的數(shù)據(jù)寥掐,怎么搞?難道將上述代碼用for循環(huán)反復(fù)執(zhí)行batch_size次磷蜀?這樣做未嘗不可召耘,但效率很低,TF提供了tf.train.shuffle_batch函數(shù)褐隆,上述解析代碼只要提供一次污它,然后將labelsimages作為tf.train.shuffle_batch函數(shù)的參數(shù),tf.train.shuffle_batch就能自動獲取到一個batch的labelsimagestf.train.shuffle_batch函數(shù)獲取batch的過程需要生成一個隊列(加入計算圖中)衫贬,然后一個一個入隊labelsimages德澈,然后出隊組合batch。關(guān)于里面參數(shù)的解釋固惯,batch_size就是batch的大小梆造,capacity指的是隊列的容量,比如capacity設(shè)為1缝呕,而batch_szie為3澳窑,那么組成一個batch的過程中,出隊的操作就會因為數(shù)據(jù)不足而頻繁地被阻塞來等待入隊加入數(shù)據(jù)供常,運行效率很低摊聋。相反,如果capacity被設(shè)置的很大栈暇,比如設(shè)為1000麻裁,而batch_size設(shè)置為3,那么入隊操作在空閑時就會頻繁入隊源祈,供過于求并非壞事煎源,糟糕的是這樣會占用很多內(nèi)存資源,而且沒有得到多少效率上的提升香缺。還有一點值得注意手销,當(dāng)使用tf.train.shuffle_batch時,為了使得shuffle效果好一點图张,出隊后隊列剩余元素必須得足夠多锋拖,因為太少的話也沒什么必要打亂了,因此tf.train.shuffle_batch函數(shù)要求提供min_after_dequeue參數(shù)來保證出隊后隊內(nèi)元素足夠多祸轮,這樣隊列就會等隊內(nèi)元素足夠多時才會出隊兽埃。顯而易見,capacity必須大于min_after_dequeue适袜。關(guān)于capacitymin_after_dequeue的設(shè)置柄错,參考《TensorFlow 實戰(zhàn)Google深度學(xué)習(xí)框架》,給出了設(shè)置capacity大小的一種比較科學(xué)的方式苦酱,min_after_dequeue根據(jù)數(shù)據(jù)集大小和batch_size綜合考慮售貌,而capacity則設(shè)置為capacity= min_after_dequeue+ 3*batch_size,在效率和資源占用之間取得平衡疫萤。組合batch_size的代碼如下:
min_after_dequeue = 10000
batch_size = 100
capacity = min_after_dequeue + 3 * batch_size

image_batch, label_batch = tf.train.shuffle_batch([images, labels], 
                                                    batch_size=batch_size, 
                                                    capacity=capacity, 
                                                    min_after_dequeue=min_after_dequeue)
    1. 啟動多線程訓(xùn)練模型颂跨。訓(xùn)練過程和單線程的基本一致,唯一的區(qū)別就是多了一個tf.train.start_queue_runners函數(shù)给僵,這個函數(shù)中傳入?yún)?shù)sess,就可以做到多線程訓(xùn)練毫捣,具體地細節(jié)還不是很了解详拙,但照壺畫瓢應(yīng)該沒問題了,有空再深挖下蔓同。
# 前向傳播
y = inference(image_batch)
    
# 計算交叉熵及其平均值
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=label_batch)
cross_entropy_mean = tf.reduce_mean(cross_entropy)
    
# 計算最后的損失函數(shù)(加入正則化)
regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
regularaztion = regularizer(weights1) + regularizer(weights2)
loss = cross_entropy_mean + regularaztion

# 優(yōu)化損失函數(shù)
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    
# 初始化會話饶辙,并開始訓(xùn)練過程。
with tf.Session() as sess:
  # 初始化所有變量
  tf.global_variables_initializer().run()
   
  coord = tf.train.Coordinator()

  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  # 循環(huán)的訓(xùn)練神經(jīng)網(wǎng)絡(luò)斑粱。
  for i in range(TRAINING_STEPS):
    if i % 1000 == 0:
      print("After %d training step(s), loss is %g " % (i, sess.run(loss)))              
    sess.run(train_step) 

    coord.request_stop()
    coord.join(threads

參考

TensorFlow高效讀取數(shù)據(jù)

Google Protocol Buffer 的使用和原理

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末弃揽,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子则北,更是在濱河造成了極大的恐慌矿微,老刑警劉巖,帶你破解...
    沈念sama閱讀 206,968評論 6 482
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件尚揣,死亡現(xiàn)場離奇詭異涌矢,居然都是意外死亡,警方通過查閱死者的電腦和手機快骗,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,601評論 2 382
  • 文/潘曉璐 我一進店門娜庇,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人方篮,你說我怎么就攤上這事名秀。” “怎么了藕溅?”我有些...
    開封第一講書人閱讀 153,220評論 0 344
  • 文/不壞的土叔 我叫張陵匕得,是天一觀的道長。 經(jīng)常有香客問我巾表,道長汁掠,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 55,416評論 1 279
  • 正文 為了忘掉前任攒发,我火速辦了婚禮调塌,結(jié)果婚禮上晋南,老公的妹妹穿的比我還像新娘惠猿。我一直安慰自己,他們只是感情好负间,可當(dāng)我...
    茶點故事閱讀 64,425評論 5 374
  • 文/花漫 我一把揭開白布偶妖。 她就那樣靜靜地躺著,像睡著了一般政溃。 火紅的嫁衣襯著肌膚如雪趾访。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 49,144評論 1 285
  • 那天董虱,我揣著相機與錄音扼鞋,去河邊找鬼申鱼。 笑死,一個胖子當(dāng)著我的面吹牛云头,可吹牛的內(nèi)容都是我干的捐友。 我是一名探鬼主播,決...
    沈念sama閱讀 38,432評論 3 401
  • 文/蒼蘭香墨 我猛地睜開眼溃槐,長吁一口氣:“原來是場噩夢啊……” “哼匣砖!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起昏滴,我...
    開封第一講書人閱讀 37,088評論 0 261
  • 序言:老撾萬榮一對情侶失蹤猴鲫,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后谣殊,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體拂共,經(jīng)...
    沈念sama閱讀 43,586評論 1 300
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 36,028評論 2 325
  • 正文 我和宋清朗相戀三年姻几,在試婚紗的時候發(fā)現(xiàn)自己被綠了匣缘。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 38,137評論 1 334
  • 序言:一個原本活蹦亂跳的男人離奇死亡鲜棠,死狀恐怖肌厨,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情豁陆,我是刑警寧澤柑爸,帶...
    沈念sama閱讀 33,783評論 4 324
  • 正文 年R本政府宣布,位于F島的核電站盒音,受9級特大地震影響表鳍,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜祥诽,卻給世界環(huán)境...
    茶點故事閱讀 39,343評論 3 307
  • 文/蒙蒙 一譬圣、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧雄坪,春花似錦厘熟、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,333評論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至阔挠,卻和暖如春飘庄,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背购撼。 一陣腳步聲響...
    開封第一講書人閱讀 31,559評論 1 262
  • 我被黑心中介騙來泰國打工跪削, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留谴仙,地道東北人。 一個月前我還...
    沈念sama閱讀 45,595評論 2 355
  • 正文 我出身青樓碾盐,卻偏偏與公主長得像狞甚,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子廓旬,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 42,901評論 2 345

推薦閱讀更多精彩內(nèi)容