也可移步my github查看
先修知識——protocol buffer
TF框架中多處使用了protocol buffer辨液,protocol buffer全稱Google Protocol Buffer虐急,簡稱Protobuf,是一種結(jié)構(gòu)化數(shù)據(jù)存儲格式滔迈,類似于常見的Json和xml止吁,而且這種格式經(jīng)過編譯可以生成對應(yīng)C++或Java或Python類的形式被辑,即可以用編程語言讀取或修改數(shù)據(jù),不僅如此敬惦,還可以進一步將定義的結(jié)構(gòu)化數(shù)據(jù)進行序列化盼理,轉(zhuǎn)化成二進制數(shù)據(jù)存下來或發(fā)送出去,非常適合做數(shù)據(jù)存儲或 RPC 數(shù)據(jù)交換格式俄删。更具體的介紹可以參考網(wǎng)上比較推薦的文章:Google Protocol Buffer 的使用和原理宏怔。其實TensorFlow計算圖思想的實現(xiàn)也是基于protocol buffer的,感興趣的可以看一下畴椰,本文主要介紹TFRecords臊诊,TFRecords是TF官方推薦使用的數(shù)據(jù)存儲形式,也是使用了protocol buffer斜脂,下面結(jié)合TFRecords詳細介紹其使用方法和原理抓艳。
protocol buffer的使用
參考Google Protocol Buffer 的使用和原理可以發(fā)現(xiàn),要得到本地存儲的序列化數(shù)據(jù)帚戳,需要先定義.proto 文件玷或,再編譯成編程語言描述的類,然后實例化該類(該類也已自動生成setter getter修改類和序列化類等方法)片任,并序列化保存到本地或進行傳輸偏友。TFRecords的思想也是將數(shù)據(jù)集中的數(shù)據(jù)以結(jié)構(gòu)化的形式存到.proto中,然后序列化存儲到本地对供,方便使用時讀取并還原數(shù)據(jù)约谈,只不過TF又對這個過程進行了一點封裝,看起來和protocol buffer原始的使用方式略有差別犁钟。
protocol buffer中需要先將數(shù)據(jù)以結(jié)構(gòu)化文件.proto的格式展現(xiàn)棱诱,然后可以編譯成C++ Java 或python類進行后續(xù)操作,在TFRecords的應(yīng)用中tf.train.Example
類就是扮演了這一角色涝动,TF中它的原始.proto文件定義在tensorflow/core/example/example.proto
中,如下代碼片:
message Example {
Features features = 1;
};
可以看到Example類中封裝的數(shù)據(jù)應(yīng)該是features
,是Features
類型的迈勋,而Features
在python代碼中就對應(yīng)了tf.train.Features
類,其原始.proto文件定義在tensorflow/core/example/feature.proto
中,如下代碼片:
message Features {
// Map from feature name to feature.
map<string, Feature> feature = 1;
};
可以看到醋粟,Features
中的數(shù)據(jù)又是feature
(注意沒有s)靡菇,而feature
屬性的類型是map<string, Feature>
類型,string
不必說了米愿,關(guān)鍵是Feature
類型厦凤,和Features
一樣,Feature
對應(yīng)tf.train.Feature
類育苟,其原始.proto文件也定義在tensorflow/core/example/feature.proto
中较鼓,如下代碼片:
message Feature {
// Each feature can be exactly one kind.
oneof kind {
BytesList bytes_list = 1; # bytes_list float_list int64_list也是和之前一樣,對應(yīng)一個類
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
將數(shù)據(jù)集轉(zhuǎn)化成TFRecords形式
TFRecords的定義過程就是使用了剛介紹的幾個類:tf.train.Example
,tf.train.Features
博烂,tf.train.Feature
香椎,知道了這幾個類的定義以及它們的嵌套關(guān)系,再去理解TFRecords的產(chǎn)生就容易多了禽篱。
首先畜伐,使用tf.train.Example來封裝我們的數(shù)據(jù),然后使用tf.python_io.TFRecordWriter來寫入磁盤躺率,其中幾個類的的嵌套方式和上述一致玛界,見如下代碼:
#本段代碼來自[TensorFlow高效讀取數(shù)據(jù)](http://ycszen.github.io/2016/08/17/TensorFlow%E9%AB%98%E6%95%88%E8%AF%BB%E5%8F%96%E6%95%B0%E6%8D%AE/)
import os
import tensorflow as tf
from PIL import Image
cwd = os.getcwd()
'''
此處我加載的數(shù)據(jù)目錄如下:
0 -- img1.jpg
img2.jpg
img3.jpg
...
1 -- img1.jpg
img2.jpg
...
2 -- ...
...
'''
# 先定義writer對象,writer負責(zé)將得到的記錄寫入TFRecords文件悼吱,此處為train.tfrecords文件
writer = tf.python_io.TFRecordWriter("train.tfrecords")
for index, name in enumerate(classes):
class_path = cwd + name + "/"
# 一張一張的寫入TFRecords文件
for img_name in os.listdir(class_path):
img_path = class_path + img_name
img = Image.open(img_path)
img = img.resize((224, 224)) #對圖片做一些預(yù)處理操作
img_raw = img.tobytes() #將圖片轉(zhuǎn)化為原生bytes
# 封裝僅Example對象中
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString()) #序列化為字符串并寫入磁盤
writer.close()
讀取數(shù)據(jù)
以上存儲數(shù)據(jù)時慎框,Example
調(diào)用SerializeToString()
方法將自己序列化并由writer = tf.python_io.TFRecordWriter("train.tfrecords")
對象保存,最終是將所有的圖片文件和label保存到同一個tfrecords文件train.tfrecords
中了舆绎。讀取數(shù)據(jù)則以上過程的逆,先獲取序列化數(shù)據(jù)们颜,再解析:由tf.python_io.tf_record_iterator("train.tfrecords")
方法(注意這個是方法)返回所有本地序列化文件迭代器吕朵,然后由Example
調(diào)用ParseFromString()
方法解析,代碼如下:
for serialized_example in tf.python_io.tf_record_iterator("train.tfrecords"):
# 本段代碼來自[TensorFlow高效讀取數(shù)據(jù)](http://ycszen.github.io/2016/08/17/TensorFlow%E9%AB%98%E6%95%88%E8%AF%BB%E5%8F%96%E6%95%B0%E6%8D%AE/)
example = tf.train.Example()
# 進行解析
example.ParseFromString(serialized_example)
# 逐個讀取example對象里封裝的東西
image = example.features.feature['image'].bytes_list.value
label = example.features.feature['label'].int64_list.value
# 可以做一些預(yù)處理之類的
print image, labe
這是最基本的數(shù)據(jù)讀取方式窥突,tf.python_io.tf_record_iterator
方法每次解析一個.tfrecords
文件努溃。而在實際應(yīng)用中,由于數(shù)據(jù)集往往很大阻问,所以往往將數(shù)據(jù)分開保存至多個tfrecords
文件中梧税,在這種情況下,TF提供了其他的接口進行讀取称近,所以正常情況下我們可能不會使用上述的數(shù)據(jù)讀取方式第队,以下才是重點,但必須強調(diào)的是整體的思想是一致的刨秆,都是先獲取序列化文件凳谦,然后解析,只是接口函數(shù)稍有不同衡未。
TF的多線程訓(xùn)練是TF框架重新設(shè)計的尸执,不是簡單地使用python語言多線程來搞得,很多時候TF多線程是和TFRecords配套使用的缓醋,下面介紹的數(shù)據(jù)讀取方法也是多線程訓(xùn)練的數(shù)據(jù)讀取方式如失。十圖詳解tensorflow數(shù)據(jù)讀取機制這篇文章深入淺出>的介紹了TF多線程讀取數(shù)據(jù)和訓(xùn)練的原理,多線程這一塊接口多送粱,也比較難以理解褪贵,下面僅從使用的角度出發(fā)談?wù)勎覀€人的理解,不詳細追究里面的實現(xiàn)原理抗俄。
假設(shè)我們按照上述方式將數(shù)據(jù)保存到了兩個tfrecords
文件中竭鞍,分別為'1.tfrecords'和'2.tfrecords'板惑,保存在DATA_ROOT路徑中,那么我們分幾步讀取數(shù)據(jù)偎快,參考如下代碼:
- 讀取
tfrecords
文件名到隊列中冯乘,使用tf.train.string_input_producer
函數(shù),該函數(shù)可以接收一個文件名列表晒夹,并自動返回一個對應(yīng)的文件名隊列filename_queue
裆馒,之所以用隊列是為了后續(xù)多線程考慮(隊列和多線程經(jīng)常搭配使用)
- 讀取
- 實例化
tf.TFRecordReader()
類生成reader
對象,接收filename_queue
參數(shù)丐怯,并讀取該隊列中文件名對應(yīng)的文件喷好,得到serialized_example
(讀到的就是.tfrecords序列化文件)
- 實例化
- 解析,注意這里的解析不是用的
Example
對象里的函數(shù)读跷,而是tf.parse_single_example
函數(shù)梗搅,該函數(shù)能從serialized_example
中解析出一條數(shù)據(jù),當(dāng)然也可以用tf.parse_example
解析多條數(shù)據(jù)效览,此處暫不贅述无切。這里tf.parse_single_example
函數(shù)傳入?yún)?shù)serialized_example
和features
,其中features
是字典的形式丐枉,指定每個key的解析方式哆键,比如image_raw
使用tf.FixedLenFeature
方法解析,這種解析方式返回一個Tensor瘦锹,大多數(shù)解析方式也都是這種籍嘹,另一種是tf.VarLenFeature
方法,返回SparseTensor
弯院,用于處理稀疏數(shù)據(jù)辱士,不贅述。這里還要注意必須告訴解析函數(shù)以何種數(shù)據(jù)類型解析听绳,這必須與生成TFRecords
文件時指定的數(shù)據(jù)類型一致识补。最后返回features
是一個字典,里面存放了每一項的解析結(jié)果辫红。
- 解析,注意這里的解析不是用的
- 最后只要讀出
features
中的數(shù)據(jù)即可凭涂。比如,features['label']
,features['pixels']
贴妻。但要注意的是切油,此時的image_raw
依然是字符串類型的(可以看寫入代碼中的image_raw
),需要進一步還原成像素數(shù)組名惩,用TF提供的函數(shù)tf.decode_raw
來搞定images = tf.decode_raw(features['image_raw'],tf.uint8)
澎胡。
- 最后只要讀出
至此,就定義好了完成一次數(shù)據(jù)讀取的代碼,有了它攻谁,下面的訓(xùn)練時的多線程方法就有了數(shù)據(jù)來源稚伍,下節(jié)討論。
# 讀取文件戚宦。
filename_queue = tf.train.string_input_producer(["Records/output.tfrecords"])
reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue)
# 解析讀取的樣例个曙。
features = tf.parse_single_example(
serialized_example,
features={
'image_raw':tf.FixedLenFeature([],tf.string),
'pixels':tf.FixedLenFeature([],tf.int64),
'label':tf.FixedLenFeature([],tf.int64)
})
images = tf.decode_raw(features['image_raw'],tf.uint8)
labels = tf.cast(features['label'],tf.int32) #需要用tf.cast做一個類型轉(zhuǎn)換
pixels = tf.cast(features['pixels'],tf.int32)
# 下面的代碼下節(jié)討論
sess = tf.Session()
# 啟動多線程處理輸入數(shù)據(jù)。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
for i in range(10):
image, label, pixel = sess.run([images, labels, pixels])
TF多線程機制
假設(shè)已將數(shù)據(jù)集文件轉(zhuǎn)換成了TFRecords
格式受楼,共兩個文件垦搬,每個文件中存儲兩條數(shù)據(jù),兩個文件如下艳汽,下面用多線程的方式讀取并訓(xùn)練猴贰,分為以下幾個步驟:
/patah/to/data.tfrecords-00000-of-00002
/patah/to/data.tfrecords-00001-of-00002
- 獲取
TFRecords
文件隊列。TF提供了tf.train.match_filenames_once
函數(shù)幫助獲取所有滿足條件的TFRecords
文件河狐,tf.train.match_filenames_once
函數(shù)參數(shù)為正則表達式米绕,返回匹配上的所有文件名集合變量。當(dāng)然馋艺,也可以選擇不用該函數(shù)栅干,用純python也可以匹配,python的話最終返回一個list類型即可丈钙,但正規(guī)起見非驮,還是推薦使用TF提供的方法交汤。然后tf.train.string_input_producer
函數(shù)依此生成文件名隊列filename_queue
雏赦。
- 獲取
files = tf.train.match_filenames_once("/patah/to/data.tfrecords-*") #
filename_queue = tf.train.string_input_producer(files, shuffle=False)
- 解析
TFRecords
文件中的數(shù)據(jù),和上面一樣芙扎,不贅述星岗。
- 解析
# 讀取文件。
reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue)
# 解析讀取的樣例戒洼。
features = tf.parse_single_example(
serialized_example,
features={
'image_raw':tf.FixedLenFeature([],tf.string),
'pixels':tf.FixedLenFeature([],tf.int64),
'label':tf.FixedLenFeature([],tf.int64)
})
decoded_images = tf.decode_raw(features['image_raw'],tf.uint8)
retyped_images = tf.cast(decoded_images, tf.float32)
#pixels = tf.cast(features['pixels'],tf.int32)
# 最后只要labels和images
labels = tf.cast(features['label'],tf.int32)
images = tf.reshape(retyped_images, [784])
- 3)將讀取到的數(shù)據(jù)打包為batch俏橘。上一段代碼得到了
labels
和images
,這是一條數(shù)據(jù)圈浇,訓(xùn)練一次需要一個batch
的數(shù)據(jù)寥掐,怎么搞?難道將上述代碼用for
循環(huán)反復(fù)執(zhí)行batch_size
次磷蜀?這樣做未嘗不可召耘,但效率很低,TF提供了tf.train.shuffle_batch
函數(shù)褐隆,上述解析代碼只要提供一次污它,然后將labels
和images
作為tf.train.shuffle_batch
函數(shù)的參數(shù),tf.train.shuffle_batch
就能自動獲取到一個batch的labels
和images
。tf.train.shuffle_batch
函數(shù)獲取batch
的過程需要生成一個隊列(加入計算圖中)衫贬,然后一個一個入隊labels
和images
德澈,然后出隊組合batch。關(guān)于里面參數(shù)的解釋固惯,batch_size
就是batch
的大小梆造,capacity
指的是隊列的容量,比如capacity
設(shè)為1缝呕,而batch_szie
為3澳窑,那么組成一個batch
的過程中,出隊的操作就會因為數(shù)據(jù)不足而頻繁地被阻塞來等待入隊加入數(shù)據(jù)供常,運行效率很低摊聋。相反,如果capacity
被設(shè)置的很大栈暇,比如設(shè)為1000麻裁,而batch_size
設(shè)置為3,那么入隊操作在空閑時就會頻繁入隊源祈,供過于求并非壞事煎源,糟糕的是這樣會占用很多內(nèi)存資源,而且沒有得到多少效率上的提升香缺。還有一點值得注意手销,當(dāng)使用tf.train.shuffle_batch
時,為了使得shuffle
效果好一點图张,出隊后隊列剩余元素必須得足夠多锋拖,因為太少的話也沒什么必要打亂了,因此tf.train.shuffle_batch
函數(shù)要求提供min_after_dequeue
參數(shù)來保證出隊后隊內(nèi)元素足夠多祸轮,這樣隊列就會等隊內(nèi)元素足夠多時才會出隊兽埃。顯而易見,capacity
必須大于min_after_dequeue
适袜。關(guān)于capacity
和min_after_dequeue
的設(shè)置柄错,參考《TensorFlow 實戰(zhàn)Google深度學(xué)習(xí)框架》,給出了設(shè)置capacity
大小的一種比較科學(xué)的方式苦酱,min_after_dequeue
根據(jù)數(shù)據(jù)集大小和batch_size
綜合考慮售貌,而capacity
則設(shè)置為,在效率和資源占用之間取得平衡疫萤。組合batch_size
的代碼如下:
min_after_dequeue = 10000
batch_size = 100
capacity = min_after_dequeue + 3 * batch_size
image_batch, label_batch = tf.train.shuffle_batch([images, labels],
batch_size=batch_size,
capacity=capacity,
min_after_dequeue=min_after_dequeue)
- 啟動多線程訓(xùn)練模型颂跨。訓(xùn)練過程和單線程的基本一致,唯一的區(qū)別就是多了一個
tf.train.start_queue_runners
函數(shù)给僵,這個函數(shù)中傳入?yún)?shù)sess
,就可以做到多線程訓(xùn)練毫捣,具體地細節(jié)還不是很了解详拙,但照壺畫瓢應(yīng)該沒問題了,有空再深挖下蔓同。
- 啟動多線程訓(xùn)練模型颂跨。訓(xùn)練過程和單線程的基本一致,唯一的區(qū)別就是多了一個
# 前向傳播
y = inference(image_batch)
# 計算交叉熵及其平均值
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=label_batch)
cross_entropy_mean = tf.reduce_mean(cross_entropy)
# 計算最后的損失函數(shù)(加入正則化)
regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
regularaztion = regularizer(weights1) + regularizer(weights2)
loss = cross_entropy_mean + regularaztion
# 優(yōu)化損失函數(shù)
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
# 初始化會話饶辙,并開始訓(xùn)練過程。
with tf.Session() as sess:
# 初始化所有變量
tf.global_variables_initializer().run()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 循環(huán)的訓(xùn)練神經(jīng)網(wǎng)絡(luò)斑粱。
for i in range(TRAINING_STEPS):
if i % 1000 == 0:
print("After %d training step(s), loss is %g " % (i, sess.run(loss)))
sess.run(train_step)
coord.request_stop()
coord.join(threads