Target
- 二分類任務(wù)樣本制作成標(biāo)準(zhǔn)TF格式(TFRecords)
- 讀入TFRecords并展示圖片
- 高效多線程讀入TFRecords
Introduction
tensorflow/core/example/example.proto 中有詳細(xì)的例子說(shuō)明.
message Example {
Features features = 1;
};
簡(jiǎn)要說(shuō)就是每個(gè)樣本變成了一個(gè)key-value的字典形式, key為字符串,value可以是字符串, 整型, 浮點(diǎn)型. 不同類型的標(biāo)簽在TF中會(huì)非常容易的實(shí)現(xiàn). 但是在Caffe中(-.-)..
Value的格式(注:Int為Int64)
- BytesList
- FloatList
- Int64List
tensorflow/core/example/feature.proto 中定義
// Containers to hold repeated fundamental values.
message BytesList {
repeated bytes value = 1;
}
message FloatList {
repeated float value = 1 [packed = true];
}
message Int64List {
repeated int64 value = 1 [packed = true];
}
// Containers for non-sequential data.
message Feature {
// Each feature can be exactly one kind.
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
message Features {
// Map from feature name to feature.
map<string, Feature> feature = 1;
};
[個(gè)人理解]
說(shuō)明一下, 一條數(shù)據(jù)或一個(gè)樣本其實(shí)就是一個(gè)Example(其實(shí)還有SequenceExample,本文只說(shuō)明Example), 數(shù)據(jù)中會(huì)存在許多的屬性, 所有的屬性信息都存儲(chǔ)在Features類中, 并且每個(gè)屬性由Key-Value來(lái)實(shí)現(xiàn) map<string, Feature> feature. 所有的Key都是字符型, 但是數(shù)據(jù)可以是bytes,float或int64類型.
官方例子:
//features {
// feature {
// key: "age"
// value { float_list {
// value: 29.0
// }}
// }
// feature {
// key: "movie"
// value { bytes_list {
// value: "The Shawshank Redemption"
// value: "Fight Club"
// }}
// }
//}
了解了存儲(chǔ)格式代碼就比較好弄了
制作TFRecords
#_*_ coding:utf-8 _*_
import os
import numpy as np
import tensorflow as tf
import cv2
'''
Example : fileName
Train/1.jpg 0
Train/2.jpg 1
Train/3.jpg 1
Train/4.jpg 1
'''
# 因?yàn)閯?chuàng)建Feature需要list
def toList(value):
if type(value) == list:
return value
else:
return [value]
# 創(chuàng)建不用類型的Feature數(shù)據(jù)
def _int64_feature(value):
value = toList(value)
value = [int(x) for x in value]
return tf.train.Feature(int64_list=tf.train.Int64List(value = value))
def _float_feature(value):
value = toList(value)
value = [float(x) for x in value]
return tf.train.Feature(float_list=tf.train.FloatList(value = value))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value = toList(value)))
# Make TFRecords
def MakeTFRecord(fileName,tfrecords,imageRoot):
# 創(chuàng)建一個(gè)TFRecordWriter來(lái)進(jìn)行寫(xiě)入數(shù)據(jù)
writer = tf.python_io.TFRecordWriter(tfrecords)
fp = open(fileName,'r')
lines = fp.readlines()
for line in lines:
line = line.strip().split()
imagePath = os.path.join(imageRoot,line[0])
print imagePath
## PIL Image 通道順序RBG
#img = Image.open(imagePath)
#img_raw = img.tobytes()
## cv2 通道順序BGR, 本例采取灰度圖
img = cv2.imread(imagePath,0)
# 注意resize之后要賦值回img
img = cv2.resize(img,(128,128))
#cv2.imshow("d",img)
#cv2.waitKey(0)
# 把圖像數(shù)據(jù)變成二進(jìn)制,節(jié)省空間
img_raw = img.tostring()
# 創(chuàng)建一個(gè)Example
example = tf.train.Example(
# 創(chuàng)建一個(gè)Features
features=tf.train.Features(
# 填寫(xiě)不同類型的key-value
feature={
"img_raw":_bytes_feature(img_raw),
"label": _int64_feature(line[1:])
}
))
# 把example進(jìn)行序列化,Serializes the protocol message to a binary string
writer.write(example.SerializeToString())
writer.close()
讀取并顯示 TFRecords
其實(shí)上面的制作理解了,讀其實(shí)只是一個(gè)逆過(guò)程,一個(gè)解序列化的過(guò)程
# 創(chuàng)建tfrecords迭代器,每個(gè)樣本都是序列化的
serialized_ex_it = tf.python_io.tf_record_iterator(tfrecords)
for serialized_ex in serialized_ex_it:
# 創(chuàng)建Example對(duì)象
example = tf.train.Example()
# 進(jìn)行解序列化(注:解析的是Example對(duì)象)
example.ParseFromString(serialized_ex)
# 輸出所有信息, 如果想知道TFRecords中屬性可以輸出example
print example
# 取出正確的key并且正確的類型的value值,錯(cuò)一個(gè)都會(huì)取不出值
image = example.features.feature['img_raw'].bytes_list.value
label = example.features.feature['label'].int64_list.value
print image, label
雖然上述可以讀取數(shù)據(jù),但是每個(gè)樣本都需要解析一次,往往我們的數(shù)據(jù)都是結(jié)構(gòu)化的,能不能一次就讀入許多數(shù)據(jù),并且在訓(xùn)練的時(shí)候數(shù)據(jù)是需要反復(fù)輸入到訓(xùn)練集中的.
# Read TFRecords using queue structs
def ReadTFRecord(tfrecords):
# 可以把多個(gè)tfrecords排成一個(gè)queue,這樣可以方便的使用多個(gè)tfrecords文件
record_queue = tf.train.string_input_producer([tfrecords])
# 讀取TFRecords器
reader = tf.TFRecordReader()
# 一個(gè)數(shù)據(jù)一個(gè)數(shù)據(jù)的讀返回key-value值,都保存在serialized_ex中
# 注意: 這里面keys是序列化的副產(chǎn)物,命名為tfrecords+random(),表示唯一的ID,沒(méi)有作用,可以設(shè)置為_(kāi)
#keys, serialized_ex = reader.read(record_queue)
_, serialized_ex = reader.read(record_queue)
# 直接解析出features數(shù)據(jù),并且使用固定特征長(zhǎng)度,及每個(gè)Example中一定會(huì)存在一個(gè)image和一個(gè)label
# 并不是輸入的圖片大小不同就使用VarLenFeature.
features = tf.parse_single_example(serialized_ex,
features={
# 取出key為img_raw和label的數(shù)據(jù),尤其是int位數(shù)一定不能錯(cuò)!!!
'img_raw': tf.FixedLenFeature([],tf.string),
'label': tf.FixedLenFeature([], tf.int64)
})
img = tf.decode_raw(features['img_raw'], tf.uint8)
# 注意定義的為int多少位就轉(zhuǎn)換成多少位,否則容易出錯(cuò)!!
label = tf.cast(features['label'], tf.int64)
return img, label
imgs,labels = ReadTFRecord(tfrecords)
sess = tf.Session()
# 多線程調(diào)節(jié)器
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
# 輸出10個(gè)樣本
for i in range(10):
image,label = sess.run([imgs,labels])
print image.shape,'label:', label