加載數據
TensorFlow 作為符號編程框架妖泄,需要先構建數據流圖丹喻,再讀取數據,隨后進行模型訓練螃征。
- 預加載數據(preloaded data):在 TensorFlow 圖中定義常量或變量來保存所有數據。這種方式的缺點在于略步,將數據直接嵌在數據流圖中描扯,當訓練數據較大時,很消耗內存趟薄。
- 填充數據(feeding): 使用 sess.run()中的 feed_dict 參數绽诚,將 Python 產生的數據填充給后端。Python 產生數據竟趾,再把數據填充后端憔购。填充的方式也有數據量大、消耗內存等缺點岔帽。
- 從文件讀取數據(reading from file):從文件中直接讀取玫鸟,讓隊列管理器從文件中讀取數據。這是最推薦的方式犀勒,讓 TensorFlow 自己從文件中讀取數據屎飘,并解碼成可使用的樣本集。
import tensorflow as tf
# 第二種方式:填充數據
a1 = tf.placeholder(tf.int16)
a2 = tf.placeholder(tf.int16)
b = tf.add(x1, x2)
# 用 Python 產生數據
li1 = [2, 3, 4]
li2 = [4, 0, 1]
# 打開一個會話贾费,將數據填充給后端
with tf.Session() as sess:
print sess.run(b, feed_dict={a1: li1, a2: li2})
TFRecords 是一種二進制文件钦购,能更好地利用內存,更方便地復制和移動褂萧,并且不需要單獨的標記文件押桃。
從文件讀取數據分為如下兩個步驟:
(1)把樣本數據寫入 TFRecords 二進制文件;
(2)再從隊列中讀取导犹。
把樣本數據寫入 TFRecords 二進制文件
- 將數據填入到 tf.train.Example 的協議緩沖區(qū)(protocolbuffer)中
example=tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(rows),
'width': _int64_feature(cols),
'depth': _int64_feature(depth),
'label': _int64_feature(int(labels[i].tolist)),
'image_raw': _bytes_feature(image_raw)
}))
- 將協議緩沖區(qū)序列化為一個字符串唱凯,通過 tf.python_io.TFRecordWriter 寫入 TFRecords文件
#定義一個writer
filename=os.path.join(os.getcwd(),name+'.tfrecords')
writer= tf.python_io.TFRecordWriter(filename)
......
#對于for i in range(num_example)中的每個example,寫入文件
writer.write(example.SerializerToString())
- 最后關閉writer
writer.close()
從隊列中讀取
一旦生成了 TFRecords 文件谎痢,接下來就可以使用隊列讀取數據了磕昼。主要分為 3 步:
(1)創(chuàng)建張量,從二進制文件讀取一個樣本节猿;
(2)創(chuàng)建張量票从,從二進制文件隨機讀取一個 mini-batch;
(3)把每一批張量傳入網絡作為輸入節(jié)點滨嘱。