Tensorflow 數(shù)據(jù)讀取

TF官網(wǎng)上給出了三種讀取數(shù)據(jù)的方式:

  1. Preloaded data: 預(yù)加載數(shù)據(jù)
  2. Feeding: Python 產(chǎn)生數(shù)據(jù)唆垃,再把數(shù)據(jù)喂給后端
  3. Reading from file:從文件中直接讀取
    (Ps: 此處參考博客 詳解TF數(shù)據(jù)讀取有三種方式(next_batch))
    (Pps: 文中的代碼均基于Python3.6版本)

TF的核心是用C++寫(xiě)的遵湖,運(yùn)行快叽躯,但是調(diào)用不靈活。結(jié)合Python和TF趟卸,將計(jì)算的核心算子和運(yùn)行框架用C++寫(xiě)携龟,然后以API的形式提供給Python調(diào)用崇堰。Python的主要工作是設(shè)計(jì)計(jì)算圖(模型及數(shù)據(jù))亿虽,將設(shè)計(jì)好的Graph提供給后端執(zhí)行。簡(jiǎn)而言之它呀,TF是Run螺男,Pyhton的角色是Design。

一. Preloaded Data

  • constant钟些,常量
  • variable烟号,初始化或者后面更新均可

這種數(shù)據(jù)讀取方式只適合小數(shù)據(jù),通常在程序中定義某固定值政恍,如循環(huán)次數(shù)等汪拥,而很少用來(lái)讀取訓(xùn)練數(shù)據(jù)。

import tensorflow as tf
# 設(shè)計(jì)Graph
a = tf.constant([1, 2, 3])
b = tf.Variable([1, 2, 4])
c = tf.add(a, b)


二. Feeding

Feeding的方式在設(shè)計(jì)Graph的時(shí)候留占位符篙耗,在真正Run的時(shí)候向占位符中傳遞數(shù)據(jù)迫筑,喂給后端訓(xùn)練。

#!/usr/bin/env python3
# _*_coding:utf-8 _*_

import tensorflow as tf
# 設(shè)計(jì)Graph
a = tf.placeholder(tf.int16) 
b = tf.placeholder(tf.int16)
c = tf.add(a, b)
# 用Python產(chǎn)生數(shù)據(jù)
li1 = [2, 3, 4] # li1:<type:'list'>: [2, 3, 4]
li2 = [4, 0, 1]
# 打開(kāi)一個(gè)session --> 喂數(shù)據(jù) --> 計(jì)算y
with tf.Session() as sess:
  print(sess.run(c, feed_dict={a: li1, b: li2})) # [6, 3, 5]

這里tf.placeholder代表占位符宗弯,先定一下變量a的類(lèi)型脯燃。在實(shí)際運(yùn)行的時(shí)候,通過(guò)feed_dict來(lái)指定a在計(jì)算中的實(shí)際值蒙保。

這種數(shù)據(jù)讀取方式非常靈活辕棚,而且易于理解,但是在讀取大數(shù)據(jù)時(shí)會(huì)非常吃力。



三. Read from file

官網(wǎng)上給出的例子是從csv等文件中讀取數(shù)據(jù)逝嚎,這里都會(huì)涉及到隊(duì)列的概念扁瓢, 我們首先簡(jiǎn)單介紹一下Queue讀取數(shù)據(jù)的原理,便于后面代碼的理解补君。(參考 Blog

讀取數(shù)據(jù)其實(shí)是為了后續(xù)的計(jì)算引几,以圖片為例,假設(shè)我們的硬盤(pán)中有一個(gè)圖片數(shù)據(jù)集0001.jpg挽铁,0002.jpg伟桅,0003.jpg……我們只需要把它們讀取到內(nèi)存中,然后提供給GPU或是CPU進(jìn)行計(jì)算就可以了叽掘。這聽(tīng)起來(lái)很容易楣铁,但事實(shí)遠(yuǎn)沒(méi)有那么簡(jiǎn)單。事實(shí)上够掠,我們必須要把數(shù)據(jù)先讀入后才能進(jìn)行計(jì)算民褂,假設(shè)讀入用時(shí)0.1s,計(jì)算用時(shí)0.9s疯潭,那么就意味著每過(guò)1s,GPU都會(huì)有0.1s無(wú)事可做面殖,這就大大降低了運(yùn)算的效率竖哩。

隊(duì)列的存在就是為了使計(jì)算的速度不完全受限于數(shù)據(jù)讀取的速度,保證有足夠多的數(shù)據(jù)喂給計(jì)算脊僚。如圖所示相叁,將數(shù)據(jù)的讀入和計(jì)算分別放在兩個(gè)線程中,讀入的數(shù)據(jù)保存為內(nèi)存中的一個(gè)隊(duì)列辽幌,負(fù)責(zé)計(jì)算的線程可以源源不斷地從內(nèi)存隊(duì)列中讀取數(shù)據(jù)增淹。這樣就解決了GPU因?yàn)镮O而空閑的問(wèn)題。

Tensorflow中在內(nèi)存隊(duì)列之前又添加了一個(gè)文件名隊(duì)列乌企,這是因?yàn)闄C(jī)器學(xué)習(xí)中一般會(huì)設(shè)定epoch虑润。對(duì)于一個(gè)數(shù)據(jù)集來(lái)說(shuō),運(yùn)行一個(gè)epoch就是將這個(gè)數(shù)據(jù)集中的樣本數(shù)據(jù)全部計(jì)算一遍加酵。如圖所示拳喻,當(dāng)數(shù)據(jù)集結(jié)束后可以做一個(gè)標(biāo)注,以此來(lái)告訴計(jì)算機(jī)這個(gè)epoch結(jié)束了猪腕。

文件名隊(duì)列冗澈,我們用tf.train.string_input_producer()函數(shù)創(chuàng)建文件名隊(duì)列。

tf.train.string_input_producer(
    string_tensor,     # 文件名列表
    num_epochs=None,   # epoch的個(gè)數(shù)陋葡,None代表無(wú)限循環(huán)
    shuffle=True,      # 一個(gè)epoch內(nèi)的樣本(文件)順序是否打亂
    seed=None,         # 當(dāng)shuffle=True時(shí)才用亚亲,應(yīng)該是指定一個(gè)打亂順序的入口
    capacity=32,       # 設(shè)置隊(duì)列的容量
    shared_name=None,
    name=None,
    cancel_op=None)

ps: 在Tensorflow中,內(nèi)存隊(duì)列不需要我們自己建立,后續(xù)只需要使用reader從文件名隊(duì)列中讀取數(shù)據(jù)就可以捌归。

tf.train.string_input_produecer()會(huì)將一個(gè)隱含的QueueRunner添加到全局圖中(類(lèi)似的操作還有tf.train.shuffle_batch()等)肛响。由于沒(méi)有顯式地返回QueueRunner()來(lái)調(diào)用create_threads()啟動(dòng)線程,這里使用了tf.train.start_queue_runners()方法直接啟動(dòng)tf.GraphKeys.QUEUE_RUNNERS集合中的所有隊(duì)列線程陨溅。

在我們使用tf.train.string_input_producer創(chuàng)建文件名隊(duì)列后终惑,整個(gè)系統(tǒng)其實(shí)還是處于“停滯狀態(tài)”的,也就是說(shuō)门扇,我們文件名并沒(méi)有真正被加入到隊(duì)列中(如下圖所示)雹有。此時(shí)如果我們開(kāi)始計(jì)算,因?yàn)閮?nèi)存隊(duì)列中什么也沒(méi)有臼寄,計(jì)算單元就會(huì)一直等待霸奕,導(dǎo)致整個(gè)系統(tǒng)被阻塞。

而使用tf.train.start_queue_runners()之后吉拳,才會(huì)啟動(dòng)填充隊(duì)列的線程质帅,這時(shí)系統(tǒng)就不再“停滯”。此后計(jì)算單元就可以拿到數(shù)據(jù)并進(jìn)行計(jì)算留攒,整個(gè)程序也就跑起來(lái)了煤惩,這就是函數(shù)tf.train.start_queue_runners的用處。

在讀取文件的整個(gè)過(guò)程中會(huì)涉及到:

  • 文件名隊(duì)列創(chuàng)建: tf.train.string_input_producer()
  • 文件閱讀器: tf.TFRecordReader()
  • 文件解析器:tf.parse_single_example() 或者decode_csv()
  • Batch_size:tf.train.shuffle_batch()
  • 填充進(jìn)程:tf.train.start_queue_runners()

下面我們用python生成數(shù)據(jù)炼邀,并將數(shù)據(jù)轉(zhuǎn)換成tfrecord格式魄揉,然后讀取tfrecord文件。在這過(guò)程中拭宁,我們會(huì)介紹幾種不同的從文件讀取數(shù)據(jù)的方法洛退。

生成數(shù)據(jù):

#!/usr/bin/env python3 
# _*_coding:utf-8 _*_

import os
import numpy as np
'''
二分類(lèi)問(wèn)題,樣本數(shù)據(jù)是形如1杰标,2兵怯,5,8腔剂,9(1*5)的隨機(jī)數(shù)媒区,對(duì)應(yīng)標(biāo)簽是0或1
arg:
    data_filename: 路徑下的文件名 'data/data_train.txt'
    size: 設(shè)定生成樣本數(shù)據(jù)的size=(10000, 5),其中10000是樣本個(gè)數(shù)桶蝎,5是單個(gè)樣本的特征驻仅。
'''
gene_data = 'data/data_train.txt'
size = (100000, 5)
def generate_data(gene_data, size):
    if not os.path.exists(gene_data):
        np.random.seed(9)
        x_data = np.random.randint(0, 10, size=size)
        # 這里設(shè)置標(biāo)簽值一半樣本是0,一半樣本是1
        y1_data = np.ones((size[0]//2, 1), int) # 這里需要注意python3和python2的區(qū)別登渣。
        y2_data = np.zeros((size[0]//2, 1), int) # python2用/得到整數(shù)噪服,python3要用//。否則會(huì)報(bào)錯(cuò)“'float' object cannot be interpreted as an integer”
        y_data = np.append(y1_data, y2_data)
        np.random.shuffle(y_data)

        # 將樣本和標(biāo)簽以1 2 3 6 8/1的形式來(lái)保存
        xy_data = str('')
        for xy_row in range(len(x_data)):
            x_str = str('')
            for xy_col in range(len(x_data[0])):
                if not xy_col == (len(x_data[0])-1):
                    x_str =x_str+str(x_data[xy_row, xy_col])+' '
                else:
                    x_str = x_str + str(x_data[xy_row, xy_col])
            y_str = str(y_data[xy_row])
            xy_data = xy_data+(x_str+'/'+y_str + '\n')
        #print(xy_data[1])

        # write to txt 保存成txt格式
        write_txt = open(gene_data, 'w')
        write_txt.write(xy_data)
        write_txt.close()
    return
# generate_data(gene_data=gene_data, size=size) # 取消注釋后可以直接生成數(shù)據(jù)

從txt文件中讀取數(shù)據(jù)胜茧,并轉(zhuǎn)換成TFrecord格式

tfrecord數(shù)據(jù)文件是一種將數(shù)據(jù)和標(biāo)簽統(tǒng)一存儲(chǔ)的二進(jìn)制文件粘优,能更好的利用內(nèi)存仇味,在tensorflow中快速的復(fù)制,移動(dòng)雹顺,讀取丹墨,存儲(chǔ)等。

TFRecord 文件中的數(shù)據(jù)是通過(guò) tf.train.Example() 以 Protocol Buffer(協(xié)議緩沖區(qū)) 的格式存儲(chǔ)嬉愧。Protocol Buffer是Google的一種數(shù)據(jù)交換的格式贩挣,他獨(dú)立于語(yǔ)言,獨(dú)立于平臺(tái)没酣,以二進(jìn)制的形式存在王财,能更好的利用內(nèi)存,方便復(fù)制和移動(dòng)裕便。
tf.train.Example()包含F(xiàn)eatures字段绒净,通過(guò)feature將數(shù)據(jù)和label進(jìn)行統(tǒng)一封裝, 然后將example協(xié)議內(nèi)存塊轉(zhuǎn)化為字符串偿衰。tf.train.Features()是字典結(jié)構(gòu)挂疆,包括字符串格式的key,可以自己定義key下翎。與key對(duì)應(yīng)的是value值缤言,這里需要注意的是,feature的value值只支持列表视事,可以是字符串(Byteslist)墨闲,浮點(diǎn)數(shù)列表(Floatlist)和整型數(shù)列表(int64list),所以郑口,在給value賦值時(shí)一定要注意類(lèi)型將數(shù)據(jù)轉(zhuǎn)換為這三種類(lèi)型的列表。

  • 類(lèi)型為標(biāo)量:如0盾鳞,1標(biāo)簽犬性,轉(zhuǎn)為列表。 tf.train.Int64List(value=[label])
  • 類(lèi)型為數(shù)組:sample = [1, 2, 3]腾仅,tf.train.Int64List(value=sample)
  • 類(lèi)型為矩陣:sample = [[1, 2, 3], [1, 2 ,3]]乒裆,
    兩種方式:
    轉(zhuǎn)成list類(lèi)型:將張量fatten成list(向量)
    轉(zhuǎn)成string類(lèi)型:將張量用.tostring()轉(zhuǎn)換成string類(lèi)型。
    同時(shí)要記得保存形狀信息推励,在讀取后恢復(fù)shape鹤耍。
'''
讀取txt中的數(shù)據(jù),并將數(shù)據(jù)保存成tfrecord文件
arg:
    txt_filename: 是txt保存的路徑+文件名 'data/data_train.txt'
    tfrecord_path:tfrecord文件將要保存的路徑及名稱 'data/test_data.tfrecord'
'''
def txt_to_tfrecord(txt_filename=gene_data, tfrecord_path=tfrecord_path):
    # 第一步:生成TFRecord Writer
    writer = tf.python_io.TFRecordWriter(tfrecord_path)

    # 第二步:讀取TXT數(shù)據(jù)验辞,并分割出樣本數(shù)據(jù)和標(biāo)簽
    file = open(txt_filename)
    for data_line in file.readlines(): # 每一行
        data_line = data_line.strip('\n') # 去掉換行符
        sample = []
        spls = data_line.split('/', 1)[0]# 樣本
        for m in spls.split(' '):
            sample.append(int(m))
        label = data_line.split('/', 1)[1]# 標(biāo)簽
        label = int(label)
        # print('sample:', sample, 'labels:', label)

        # 第三步: 建立feature字典稿黄,tf.train.Feature()對(duì)單一數(shù)據(jù)編碼成feature
        feature = {'sample': tf.train.Feature(int64_list=tf.train.Int64List(value=sample)),
                   'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}
        # 第四步:可以理解為將內(nèi)層多個(gè)feature的字典數(shù)據(jù)再編碼,集成為features
        features = tf.train.Features(feature = feature)
        # 第五步:將features數(shù)據(jù)封裝成特定的協(xié)議格式
        example = tf.train.Example(features=features)
        # 第六步:將example數(shù)據(jù)序列化為字符串
        Serialized = example.SerializeToString()
        # 第七步:將序列化的字符串?dāng)?shù)據(jù)寫(xiě)入?yún)f(xié)議緩沖區(qū)
        writer.write(Serialized)
    # 記得關(guān)閉writer和open file的操作
    writer.close()
    file.close()
    return
# txt_to_tfrecord(txt_filename=gene_data, tfrecord_path=tfrecord_path)

所以在上面的程序中我們涉及到了讀取txt文本數(shù)據(jù)跌造,并將數(shù)據(jù)寫(xiě)成tfrecord文件杆怕。在網(wǎng)絡(luò)訓(xùn)練過(guò)程中數(shù)據(jù)的讀取通常是對(duì)tfrecord文件的操作族购。

TF讀取tfrecord文件有兩種方式:一種是Queue方式,就是上面介紹的隊(duì)列陵珍,另外一種是用dataset來(lái)讀取寝杖。先介紹Queue讀取文件數(shù)據(jù)的方法

1. Queue方式

Queue讀取數(shù)據(jù)可以分為兩種:tf.parse_single_example()和tf.parse_example()

(1). tf.parse_single_example()讀取數(shù)據(jù)

tf.parse_single_example(
    serialized,  # 張量
    features,  # 對(duì)應(yīng)寫(xiě)入的features
    name=None,
    example_names=None)
'''
用tf.parse_single_example()讀取并解析tfrecord文件
args: 
      filename_queue: 文件名隊(duì)列
      shuffle_batch: 判斷在batch的時(shí)候是否要打亂順序
      if_enq_many: 設(shè)定batch中的參數(shù)enqueue_many,評(píng)估該參數(shù)的作用
'''
# 第一步: 建立文件名隊(duì)列互纯,可設(shè)置Epoch次數(shù)
filename_queue = tf.train.string_input_producer([tfrecord_path], num_epochs=3)

def read_single(filename_queue, shuffle_batch, if_enq_many):
    # 第二步: 建立閱讀器
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # 第三步:根據(jù)寫(xiě)入時(shí)的格式建立相對(duì)應(yīng)的讀取features
    features = {
        'sample': tf.FixedLenFeature([5], tf.int64),# 如果不是標(biāo)量瑟幕,一定要在這里說(shuō)明數(shù)組的長(zhǎng)度
        'label': tf.FixedLenFeature([], tf.int64)
    }
    # 第四步: 用tf.parse_single_example()解析單個(gè)EXAMPLE PROTO
    Features = tf.parse_single_example(serialized_example, features)

    # 第五步:對(duì)數(shù)據(jù)進(jìn)行后處理
    sample = tf.cast(Features['sample'], tf.float32)
    label = tf.cast(Features['label'], tf.float32)
    # 第六步:生成Batch數(shù)據(jù) generate batch
    if shuffle_batch:  # 打亂數(shù)據(jù)順序,隨機(jī)取樣
        sample_single, label_single = tf.train.shuffle_batch([sample, label],
                                                 batch_size=2,
                                                 capacity=200000,
                                                 min_after_dequeue=10000,
                                                 num_threads=1,
                                                 enqueue_many=if_enq_many)# 主要是為了評(píng)估enqueue_many的作用
    else:  # # 如果不打亂順序則用tf.train.batch(), 輸出隊(duì)列按順序組成Batch輸出
        sample_single, label_single = tf.train.batch([sample, label],
                                                batch_size=2,
                                                capacity=200000,
                                                min_after_dequeue=10000,
                                                num_threads=1,
                                                enqueue_many = if_enq_many)
    return sample_single, label_single
x1_samples, y1_labels = read_single(filename_queue=filename_queue, 
shuffle_batch=False, if_enq_many=False)
x2_samples, y2_labels = read_single(filename_queue=filename_queue, 
shuffle_batch=True, if_enq_many=False)
print(x1_samples, y1_labels) # 因?yàn)槭莟ensor留潦,這里還處于構(gòu)造tensorflow計(jì)算圖的過(guò)程只盹,輸出僅僅是shape等,不會(huì)是具體的數(shù)值愤兵。
# 如果想得到具體的數(shù)值鹿霸,必須建立session,是tensor在計(jì)算圖中流動(dòng)起來(lái)秆乳,也就是用session.run()的方式得到具體的數(shù)值懦鼠。
# 定義初始化變量范圍
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)  # 初始化
    # 如果tf.train.string_input_producer([tfrecord_path], num_epochs=3)中num_epochs不為空的化,必須要初始化local變量
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()  # 管理線程
    threads = tf.train.start_queue_runners(coord=coord)  # 文件名開(kāi)始進(jìn)入文件名隊(duì)列和內(nèi)存
    for i in range(1):
        # Queue + tf.parse_single_example()讀取tfrecord文件
        X1, Y1 = sess.run([x1_samples, y1_labels])
        print('X1: ', X1, 'Y1: ', Y1) # 這里就可以得到tensor具體的數(shù)值
        X2, Y2 = sess.run([x2_samples, y2_labels])
        print('X2: ', X2, 'Y2: ', Y2) # 這里就可以得到tensor具體的數(shù)值
    coord.request_stop()
    coord.join(threads)

Ps: 如果建立文件名tf.train.string_input_producer([tfrecord_path], num_epochs=3)時(shí)屹堰, 設(shè)置num_epochs為具體的值(不是None)肛冶。在初始化的時(shí)候必須對(duì)local_variables進(jìn)行初始化sess.run(tf.local_variables_initializer())。否則會(huì)報(bào)錯(cuò):
OutOfRangeError (see above for traceback): RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 2, current size 0)

上面第六步batch前取到的是單個(gè)樣本數(shù)據(jù)扯键,在實(shí)際訓(xùn)練中通常用批量數(shù)據(jù)來(lái)更新參數(shù)睦袖,設(shè)置批量讀取數(shù)據(jù)的時(shí)候有按順序讀取數(shù)據(jù)的tf.train.batch()和打亂數(shù)據(jù)出列順序的tf.train.shuffle_batch()。假設(shè)文本中的數(shù)據(jù)如圖所示:

設(shè)置batch_size=2, shuffle_batch=True和False時(shí)的輸出分別為:

X11:  [[5. 6. 8. 6. 1.] [6. 4. 8. 1. 8.]] Y11:  [1. 1.] #用tf.train.batch()
X21:  [[0. 4. 3. 7. 8.] [5. 0. 2. 8. 7.]] Y21:  [0. 1.] # 用tf.train.shuffle_batch()

這里需要對(duì)tf.train.shuffle_batch()和tf.train.batch()的參數(shù)進(jìn)行說(shuō)明

tf.train.shuffle_batch(
    tensors,
    batch_size, # 設(shè)置batch_size的大小
    capacity,  # 設(shè)置隊(duì)列中最大的數(shù)據(jù)量荣刑,容量馅笙。一般要求capacity > min_after_dequeue + num_threads*batch_size
    min_after_dequeue, # 隊(duì)列中最小的數(shù)據(jù)量作為隨機(jī)取樣的緩沖區(qū)。越大厉亏,數(shù)據(jù)混合越充分董习,認(rèn)為采樣到的數(shù)據(jù)更具有隨機(jī)性。
    # 但是這個(gè)值設(shè)置太大在初始啟動(dòng)時(shí)爱只,需要給隊(duì)列喂足夠多的數(shù)據(jù)皿淋,啟動(dòng)慢,而且占用內(nèi)存恬试。
    num_threads=1, # 設(shè)置線程數(shù)
    seed=None,
    enqueue_many=False, # Whether each tensor in tensor_list is a single example. 在下面單獨(dú)說(shuō)明
    shapes=None,
    allow_smaller_final_batch=False, # (Optional) Boolean. If True, allow the final batch to be smaller if there are insufficient items left in the queue.
    shared_name=None,
    name=None)
tf.train.batch(
    tensors,
    batch_size,
    num_threads=1,
    capacity=32,
    enqueue_many=False,
    shapes=None,
    dynamic_pad=False,
    allow_smaller_final_batch=False,
    shared_name=None,
    name=None)  # 注意:這里沒(méi)有min_after_dequeue這個(gè)參數(shù)

讀取數(shù)據(jù)的目的是為了訓(xùn)練網(wǎng)絡(luò)窝趣,而使用Batch訓(xùn)練網(wǎng)絡(luò)的原因可以解釋為:

深度學(xué)習(xí)的優(yōu)化說(shuō)白了就是梯度下降。每次的參數(shù)更新有兩種方式训柴。

  • 第一種哑舒,遍歷全部數(shù)據(jù)集算一次損失函數(shù),然后算函數(shù)對(duì)各個(gè)參數(shù)的梯度畦粮,更新梯度散址。這種方法每更新一次參數(shù)都要把數(shù)據(jù)集里的所有樣本都看一遍乖阵,計(jì)算量開(kāi)銷(xiāo)大,計(jì)算速度慢预麸,不支持在線學(xué)習(xí)瞪浸,這稱為Batch gradient descent,批梯度下降吏祸。
  • 另一種对蒲,每看一個(gè)數(shù)據(jù)就算一下?lián)p失函數(shù),然后求梯度更新參數(shù)贡翘,這個(gè)稱為隨機(jī)梯度下降蹈矮,stochastic gradient descent。這個(gè)方法速度比較快鸣驱,但是收斂性能不太好泛鸟,可能在最優(yōu)點(diǎn)附近晃來(lái)晃去,hit不到最優(yōu)點(diǎn)踊东。兩次參數(shù)的更新也有可能互相抵消掉北滥,造成目標(biāo)函數(shù)震蕩的比較劇烈。
    為了克服兩種方法的缺點(diǎn)闸翅,現(xiàn)在一般采用的是一種折中手段再芋,mini-batch gradient decent,小批的梯度下降坚冀,這種方法把數(shù)據(jù)分為若干個(gè)批济赎,按批來(lái)更新參數(shù),這樣记某,一個(gè)批中的一組數(shù)據(jù)共同決定了本次梯度的方向司训,下降起來(lái)就不容易跑偏,減少了隨機(jī)性液南。另一方面因?yàn)榕臉颖緮?shù)與整個(gè)數(shù)據(jù)集相比小了很多豁遭,計(jì)算量也不是很大。

個(gè)人理解:大Batch_size一是會(huì)受限于計(jì)算機(jī)硬件贺拣,另一方面將會(huì)降低梯度下降的隨機(jī)性。 而小Batch_size收斂速度慢

這里用代碼對(duì)enqueue_many這個(gè)參數(shù)進(jìn)行理解

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np

tensor_list = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]

with tf.Session() as sess:
    x1 = tf.train.batch(tensor_list, batch_size=3, enqueue_many=False)
    x2 = tf.train.batch(tensor_list, batch_size=3, enqueue_many=True)
    x3 = tf.train.shuffle_batch(tensor_list, batch_size=3, capacity = 1000, min_after_dequeue=100, num_threads=1, enqueue_many=False)
    x4 = tf.train.shuffle_batch(tensor_list, batch_size=3, capacity = 1000, min_after_dequeue=100, num_threads=1, enqueue_many=True)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    print("x1 batch:" + "-" * 10)
    print(sess.run(x1))

    print("x2 batch:" + "-" * 10)
    print(sess.run(x2))
    print("x2 batch:" + "-" * 10)
    print(sess.run(x2))

    print("x3 batch:" + "-" * 10)
    print(sess.run(x3))

    print("x4 batch:" + "-" * 10)
    print(sess.run(x4))

    coord.request_stop()
    coord.join(threads)

輸出如下:

由以上輸出可以看出捂蕴,當(dāng)enqueue_many=False(默認(rèn)值)時(shí)譬涡,輸出為batch_size*tensor.shape,把輸入tensors看作一個(gè)樣本啥辨,Batch就是對(duì)第一個(gè)維度的數(shù)據(jù)進(jìn)行重復(fù)采樣涡匀,將tensor擴(kuò)展一個(gè)維度。
當(dāng)enqueue_many=True時(shí)溉知,tensor是一個(gè)樣本陨瘩,batch_size只是調(diào)整樣本中的維度腕够。這里tensor的維度保持不變,只是在最后一個(gè)維度上根據(jù)batch_size調(diào)整了大小舌劳。而最后一個(gè)維度內(nèi)的順序是亂序的帚湘。
對(duì)于shuffle_batch,注意到甚淡,第1維(矩陣每一行)上的數(shù)據(jù)是打亂的大诸,所以從[1, 2, 3, 4]中取到了[2, 4, 4]。
如果輸入的樣本是一個(gè)3x6的矩陣贯卦。設(shè)置batch_size=5资柔,enqueue_many = False時(shí),tensor會(huì)被擴(kuò)展為3x6x5的張量, 并且撵割。當(dāng)enqueue_many = True時(shí)贿堰,tensor是3x5,第二個(gè)維度上截取size啡彬。
這里比較疑惑的是shuffle在這里感覺(jué)沒(méi)有任何作用羹与??外遇?

(2). tf.parse_example()讀取數(shù)據(jù)

'''
用tf.parse_example()批量讀取數(shù)據(jù)注簿,據(jù)說(shuō)比tf.parse_single_exaple()讀取數(shù)據(jù)的速度快(沒(méi)有驗(yàn)證)
args:
      filename_queue: 文件名隊(duì)列
      shuffle_batch: 是否批量讀取數(shù)據(jù)
      if_enq_many: batch時(shí)enqueue_many參數(shù)的設(shè)定,這里主要用于評(píng)估該參數(shù)的作用
'''
# 第一步: 建立文件名隊(duì)列
filename_queue = tf.train.string_input_producer([tfrecord_path])
def read_parse(filename_queue, shuffle_batch, if_enq_many):
    # 第二步: 建立閱讀器
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # 第三步: 設(shè)置shuffle_batch
    if shuffle_batch:
        batch = tf.train.shuffle_batch([serialized_example],
                               batch_size=3,
                               capacity=10000,
                               min_after_dequeue=1000,
                               num_threads=1,
                               enqueue_many=if_enq_many)# 主要是為了評(píng)估enqueue_many的作用

    else:
        batch = tf.train.batch([serialized_example],
                               batch_size=3,
                               capacity=10000,
                               num_threads=1,
                               enqueue_many=if_enq_many)
        # 第四步:根據(jù)寫(xiě)入時(shí)的格式建立相對(duì)應(yīng)的讀取features
    features = {
        'sample': tf.FixedLenFeature([5], tf.int64),  # 如果不是標(biāo)量跳仿,一定要在這里說(shuō)明數(shù)組的長(zhǎng)度
        'label': tf.FixedLenFeature([], tf.int64)
    }
    # 第五步: 用tf.parse_example()解析多個(gè)EXAMPLE PROTO
    Features = tf.parse_example(batch, features)

    # 第六步:對(duì)數(shù)據(jù)進(jìn)行后處理
    samples_parse= tf.cast(Features['sample'], tf.float32)
    labels_parse = tf.cast(Features['label'], tf.float32)
    return samples_parse, labels_parse

x2_samples, y2_labels = read_parse(filename_queue=filename_queue, shuffle_batch=True, if_enq_many=False)
print(x2_samples, y2_labels)
# 定義初始化變量范圍
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)  # 初始化
    coord = tf.train.Coordinator()  # 管理線程
    threads = tf.train.start_queue_runners(coord=coord)  # 文件名開(kāi)始進(jìn)入文件名隊(duì)列和內(nèi)存
    for i in range(1):
        X2, Y2 = sess.run([x2_samples, y2_labels])
        print('X2: ', X2, 'Y2: ', Y2)

    coord.request_stop()
    coord.join(threads)

調(diào)試的時(shí)候這里碰到一個(gè)bug诡渴,提示:return處local variable 'samples_parse' referenced before assignment。網(wǎng)上給的解決辦法基本是python在自上而下執(zhí)行的時(shí)候無(wú)法區(qū)分變量是全局變量還是局部變量菲语。實(shí)際上是我在寫(xiě)第四步/第五步的時(shí)候多了縮進(jìn)妄辩,導(dǎo)致沒(méi)有定義features。(??:python對(duì)縮進(jìn)敏感)

?? 閱讀器 + 樣本

根據(jù)以上例子山上,假設(shè)txt中的數(shù)據(jù)只有2個(gè)樣本眼耀,如下圖所示:

在建立文件名隊(duì)列時(shí),加入這兩個(gè)txt文檔的文件名

# 第一步: 建立文件名隊(duì)列
filename_queue = tf.train.string_input_producer([tfrecord_path, tfrecord_path1])

(1). 單個(gè)閱讀器 + 單個(gè)樣本

batch_size=1 (注意:這里先將num_threads設(shè)置為1)

sample_single, label_single = tf.train.batch([sample, label],
                                                 batch_size=1,
                                                 capacity=10000,     
                                                 num_threads=1,
                                                 enqueue_many=if_enq_many)
    for i in range(5):
        X14, Y14 = sess.run([x14_samples, y14_labels])
        print('X14: ', X14, 'Y14: ', Y14)

打印輸出結(jié)果為:

('X14: ', array([[8., 2., 6., 8., 1.]], dtype=float32), 'Y14: ', array([0.], dtype=float32))
('X14: ', array([[8., 3., 5., 3., 6.]], dtype=float32), 'Y14: ', array([0.], dtype=float32))
('X14: ', array([[5., 6., 8., 6., 1.]], dtype=float32), 'Y14: ', array([1.], dtype=float32))
('X14: ', array([[6., 4., 8., 1., 8.]], dtype=float32), 'Y14: ', array([1.], dtype=float32))
('X14: ', array([[8., 2., 6., 8., 1.]], dtype=float32), 'Y14: ', array([0.], dtype=float32))

(2). 單個(gè)閱讀器 + 多個(gè)樣本

batch_size = 3
輸出結(jié)果為:

('X14: ', array([[8., 2., 6., 8., 1.],[8., 3., 5., 3., 6.],[5., 6., 8., 6., 1.]], dtype=float32), 'Y14: ', array([0., 0., 1.], dtype=float32))
('X14: ', array([[6., 4., 8., 1., 8.],[5., 6., 8., 6., 1.],[6., 4., 8., 1., 8.]], dtype=float32), 'Y14: ', array([1., 1., 1.], dtype=float32))
('X14: ', array([[8., 2., 6., 8., 1.],[8., 3., 5., 3., 6.],[8., 2., 6., 8., 1.]], dtype=float32), 'Y14: ', array([0., 0., 0.], dtype=float32))
('X14: ', array([[8., 3., 5., 3., 6.],[5., 6., 8., 6., 1.],[6., 4., 8., 1., 8.]], dtype=float32), 'Y14: ', array([0., 1., 1.], dtype=float32))
('X14: ', array([[8., 2., 6., 8., 1.],[8., 3., 5., 3., 6.],[5., 6., 8., 6., 1.]], dtype=float32), 'Y14: ', array([0., 0., 1.], dtype=float32))

(3). 多個(gè)閱讀器 + 多個(gè)樣本

多閱讀器需要用tf.train.batch_join()或者tf.train.shuffle_batch_join()佩憾,對(duì)程序作稍微的修改

example_list = [[sample, label] for _ in range(2)]  # Reader設(shè)置為2
sample_single, label_single = tf.train.batch_join(example_list, batch_size=3)

輸出結(jié)果為:

('X14: ', array([[5., 6., 8., 6., 1.],[6., 4., 8., 1., 8.],[8., 2., 6., 8., 1.]], dtype=float32), 'Y14: ', array([1., 1., 0.], dtype=float32))
('X14: ', array([[8., 3., 5., 3., 6.],[8., 2., 6., 8., 1.],[8., 3., 5., 3., 6.]], dtype=float32), 'Y14: ', array([0., 0., 0.], dtype=float32))
('X14: ', array([[5., 6., 8., 6., 1.],[6., 4., 8., 1., 8.],[8., 2., 6., 8., 1.]], dtype=float32), 'Y14: ', array([1., 1., 0.], dtype=float32))
('X14: ', array([[8., 3., 5., 3., 6.],[5., 6., 8., 6., 1.],[6., 4., 8., 1., 8.]], dtype=float32), 'Y14: ', array([0., 1., 1.], dtype=float32))
('X14: ', array([[8., 2., 6., 8., 1.],[8., 3., 5., 3., 6.],[5., 6., 8., 6., 1.]], dtype=float32), 'Y14: ', array([0., 0., 1.], dtype=float32))

從輸出結(jié)果來(lái)看哮伟,單個(gè)閱讀器+多個(gè)樣本多個(gè)閱讀器+多個(gè)樣本在結(jié)果呈現(xiàn)時(shí)并沒(méi)有什么區(qū)別,至于對(duì)運(yùn)行速度的影響還有待驗(yàn)證妄帘。

附上對(duì)閱讀器進(jìn)行測(cè)試的完整代碼:

# -*- coding: UTF-8 -*-
# !/usr/bin/python3
# Env: python3.6
import tensorflow as tf
import numpy as np
import os

data_filename1 = 'data/data_train1.txt'  # 生成txt數(shù)據(jù)保存路徑
data_filename2 = 'data/data_train2.txt'  # 生成txt數(shù)據(jù)保存路徑
tfrecord_path1 = 'data/test_data1.tfrecord'  # tfrecord1文件保存路徑
tfrecord_path2 = 'data/test_data2.tfrecord'  # tfrecord2文件保存路徑

##############################  讀取txt文件楞黄,并轉(zhuǎn)為tfrecord文件 ###########################
# every line of data is just as follow: 1 2 3 4 5/1. train data: 1 2 3 4 5, label: 1
def txt_to_tfrecord(txt_filename, tfrecord_path):
    # 第一步:生成TFRecord Writer
    writer = tf.python_io.TFRecordWriter(tfrecord_path)

    # 第二步:讀取TXT數(shù)據(jù),并分割出樣本數(shù)據(jù)和標(biāo)簽
    file = open(txt_filename)
    for data_line in file.readlines():  # 每一行
        data_line = data_line.strip('\n')  # 去掉換行符
        sample = []
        spls = data_line.split('/', 1)[0]  # 樣本
        for m in spls.split(' '):
            sample.append(int(m))
        label = data_line.split('/', 1)[1]  # 標(biāo)簽
        label = int(label)

        # 第三步: 建立feature字典抡驼,tf.train.Feature()對(duì)單一數(shù)據(jù)編碼成feature
        feature = {'sample': tf.train.Feature(int64_list=tf.train.Int64List(value=sample)),
                   'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}
        # 第四步:可以理解為將內(nèi)層多個(gè)feature的字典數(shù)據(jù)再編碼鬼廓,集成為features
        features = tf.train.Features(feature=feature)
        # 第五步:將features數(shù)據(jù)封裝成特定的協(xié)議格式
        example = tf.train.Example(features=features)
        # 第六步:將example數(shù)據(jù)序列化為字符串
        Serialized = example.SerializeToString()
        # 第七步:將序列化的字符串?dāng)?shù)據(jù)寫(xiě)入?yún)f(xié)議緩沖區(qū)
        writer.write(Serialized)
    # 記得關(guān)閉writer和open file的操作
    writer.close()
    file.close()
    return
txt_to_tfrecord(txt_filename=data_filename1, tfrecord_path=tfrecord_path1)
txt_to_tfrecord(txt_filename=data_filename2, tfrecord_path=tfrecord_path2)


# 第一步: 建立文件名隊(duì)列
filename_queue = tf.train.string_input_producer([tfrecord_path1, tfrecord_path2])
def read_single(filename_queue, shuffle_batch, if_enq_many):
    # 第二步: 建立閱讀器
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # 第三步:根據(jù)寫(xiě)入時(shí)的格式建立相對(duì)應(yīng)的讀取features
    features = {
        'sample': tf.FixedLenFeature([5], tf.int64),  # 如果不是標(biāo)量,一定要在這里說(shuō)明數(shù)組的長(zhǎng)度
        'label': tf.FixedLenFeature([], tf.int64)
    }
    # 第四步: 用tf.parse_single_example()解析單個(gè)EXAMPLE PROTO
    Features = tf.parse_single_example(serialized_example, features)

    # 第五步:對(duì)數(shù)據(jù)進(jìn)行后處理
    sample = tf.cast(Features['sample'], tf.float32)
    label = tf.cast(Features['label'], tf.float32)

    # 第六步:生成Batch數(shù)據(jù) generate batch
    if shuffle_batch:  # 打亂數(shù)據(jù)順序致盟,隨機(jī)取樣
        sample_single, label_single = tf.train.shuffle_batch([sample, label],
                                                             batch_size=1,
                                                             capacity=10000,
                                                             min_after_dequeue=1000,
                                                             num_threads=1,
                                                             enqueue_many=if_enq_many)  # 主要是為了評(píng)估enqueue_many的作用
    else:  # # 如果不打亂順序則用tf.train.batch(), 輸出隊(duì)列按順序組成Batch輸出

        ###################### multi reader, multi samples, please code as below     ###############################
        '''
        example_list = [[sample,label] for _ in range(2)]  # Reader設(shè)置為2

        sample_single, label_single = tf.train.batch_join(example_list, batch_size=3)
        '''
        #######################  single reader, single sample,  please set batch_size = 1   #########################
        #######################  single reader, multi samples,  please set batch_size = batch_size    ###############
        sample_single, label_single = tf.train.batch([sample, label],
                                                     batch_size=1,
                                                     capacity=10000,
                                                     num_threads=1,
                                                     enqueue_many=if_enq_many)

    return sample_single, label_single

x1_samples, y1_labels = read_single(filename_queue, shuffle_batch=False, if_enq_many=False)

# 定義初始化變量范圍
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)  # 初始化
    # 如果tf.train.string_input_producer([tfrecord_path], num_epochs=30)中num_epochs不為空的化碎税,必須要初始化local變量
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()  # 管理線程
    threads = tf.train.start_queue_runners(coord=coord)  # 文件名開(kāi)始進(jìn)入文件名隊(duì)列和內(nèi)存
    for i in range(5):
        # Queue + tf.parse_single_example()讀取tfrecord文件
        X1, Y1 = sess.run([x1_samples, y1_labels])
        print('X1: ', X1, 'Y1: ', Y1)
        # Queue + tf.parse_example()讀取tfrecord文件

    coord.request_stop()
    coord.join(threads)

2. Dataset + TFrecrods讀取數(shù)據(jù)

這是目前官網(wǎng)上比較推薦的一種方式尤慰,相對(duì)于隊(duì)列讀取文件的方法,更為簡(jiǎn)單雷蹂。
Dataset API:將數(shù)據(jù)直接放在graph中進(jìn)行處理伟端,整體對(duì)數(shù)據(jù)集進(jìn)行上述數(shù)據(jù)操作,使代碼更加簡(jiǎn)潔

Dataset直接導(dǎo)入比較簡(jiǎn)單萎河,這里只是簡(jiǎn)單介紹:

dataset = tf.data.Dataset.from_tensor_slices([1,2,3]) # 輸入必須是list

我們重點(diǎn)看dataset讀取tfrecord文件的過(guò)程 (關(guān)于pipeline的相關(guān)信息可以參見(jiàn)博客)

def _parse_function(example_proto): # 解析函數(shù)
    # 創(chuàng)建解析字典
    dics = {  
        'sample': tf.FixedLenFeature([5], tf.int64),  # 如果不是標(biāo)量荔泳,一定要在這里說(shuō)明數(shù)組的長(zhǎng)度
        'label': tf.FixedLenFeature([], tf.int64)}
    # 把序列化樣本和解析字典送入函數(shù)里得到解析的樣本
    parsed_example = tf.parse_single_example(example_proto, dics)
    # 對(duì)樣本數(shù)據(jù)類(lèi)型的變換
    # 這里得到的樣本數(shù)據(jù)都是向量,如果寫(xiě)數(shù)據(jù)的時(shí)候?qū)?shù)據(jù)進(jìn)行過(guò)reshape操作虐杯,可以在這里根據(jù)保存的reshape信息玛歌,對(duì)數(shù)據(jù)進(jìn)行還原。
    parsed_example['sample'] = tf.cast(parsed_example['sample'], tf.float32)
    parsed_example['label'] = tf.cast(parsed_example['label'], tf.float32)

    # 返回所有feature
    return parsed_example
'''
read_dataset:
arg: tfrecord_path是需要讀取的tfrecord文件路徑擎椰,如tfrecord_path = ['test.tfrecord', 'test2.tfrecord']支子,同上面Queue方式相同,可以同時(shí)讀取多個(gè)文件
'''
def read_dataset(tfrecord_path = tfrecord_path):
    # 第一步:聲明 tf.data.TFRecordDataset
    # The tf.data.TFRecordDataset class enables you to stream over the contents of one or more TFRecord files as part of an input pipeline
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    # 第二步:解析樣本數(shù)據(jù)达舒。 tfrecord文件記錄的是序列化的樣本值朋,因此需要對(duì)樣本進(jìn)行解析。
    # 個(gè)人理解:這個(gè)解析的過(guò)程巩搏,是通過(guò)上面_parse_function函數(shù)建立feature的字典昨登。
    # 而dataset.map()是對(duì)dataset的統(tǒng)一操作,map操作可以理解為在每一個(gè)元素上應(yīng)用一個(gè)函數(shù)贯底,所以其輸入是一個(gè)函數(shù)丰辣。
    new_dataset = dataset.map(_parse_function)
    # 創(chuàng)建獲取數(shù)據(jù)集中樣本的迭代器
    iterator = new_dataset.make_one_shot_iterator()
    # 獲得下一個(gè)樣本
    next_element = iterator.get_next()
    return next_element

next_element = read_dataset()
# 建立session,打印輸出禽捆,查看數(shù)據(jù)是否正確
# 定義初始化變量范圍
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init) # 初始化
    coord = tf.train.Coordinator() # 管理線程
    threads = tf.train.start_queue_runners(coord=coord) # 文件名開(kāi)始進(jìn)入文件名隊(duì)列和內(nèi)存
    for i in range(5):
        print('dataset:', sess.run([next_element['sample'],
                                    next_element['label']]))

    coord.request_stop()
    coord.join(threads)

輸出結(jié)果如下:

('dataset:', [array([5., 6., 8., 6., 1.], dtype=float32), 1.0])
('dataset:', [array([6., 4., 8., 1., 8.], dtype=float32), 1.0])
('dataset:', [array([5., 1., 0., 8., 8.], dtype=float32), 0.0])
('dataset:', [array([8., 2., 6., 8., 1.], dtype=float32), 0.0])
('dataset:', [array([8., 3., 5., 3., 6.], dtype=float32), 0.0])

PS: 這里需要特別特別注意的是當(dāng)sample 或者 label不是標(biāo)量笙什,而且長(zhǎng)度事先無(wú)法獲得的時(shí)候怎么創(chuàng)建解析函數(shù)。
此時(shí) tf.FixedLenFeature(shape=(), dtype=tf.float32)的 shape 無(wú)法指定胚想。

舉例來(lái)說(shuō): sample.shape=[2,3], 在寫(xiě)入tfrecord的時(shí)候要對(duì)矩陣reshape琐凭,同時(shí)保存值和shape. 如果已經(jīng)知道sample的長(zhǎng)度,在解析函數(shù)中可以用上面的tf.FixedLenFeature([6,1], dtype=tf.float32)來(lái)解析浊服。一定一定不能用tf.FixedLenFeature([6], dtype=tf.float32)统屈。這樣無(wú)法還原sample的值,而且會(huì)報(bào)出各種奇葩錯(cuò)誤牙躺。如果不知道sample的shape鸿吆,可以用tf.VarLenFeature(dtype=tf.float32)。由于變長(zhǎng)得到的是稀疏矩陣述呐,解析后需要進(jìn)行轉(zhuǎn)為密集矩陣的處理。

parsed_example['sample'] = tf.sparse_tensor_to_dense(parsed_example['sample'])

上面的代碼輸出是每次取一個(gè)樣本蕉毯,按順序一個(gè)樣本一個(gè)樣本出列乓搬。如果需要打亂順序思犁,用.shuffle(buffer_size= ) 來(lái)打亂順序。其中buffer_size設(shè)置成大于數(shù)據(jù)集匯總樣本數(shù)量的值进肯,以保證樣本順序充分打亂激蹲。

打亂樣本出列順序

def read_dataset(tfrecord_path = tfrecord_path):
    # 聲明讀tfrecord文件
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    # 建立解析函數(shù)
    new_dataset = dataset.map(_parse_function)
    # 打亂樣本順序
    shuffle_dataset = new_dataset.shuffle(buffer_size=20000)
    # 數(shù)據(jù)提前進(jìn)入隊(duì)列
    prefetch_dataset = batch_dataset.prefetch(2000) # 會(huì)快很多
    # 建立迭代器
    iterator = prefetch_dataset.make_one_shot_iterator()
    # 獲得下一個(gè)樣本
    next_element = iterator.get_next()
    return next_element

輸出的結(jié)果是:

('dataset:', [array([5., 1., 1., 7., 5.], dtype=float32), 0.0])
('dataset:', [array([8., 0., 8., 2., 7.], dtype=float32), 1.0])
('dataset:', [array([6., 5., 9., 1., 2.], dtype=float32), 1.0])
('dataset:', [array([9., 9., 4., 0., 5.], dtype=float32), 0.0])
('dataset:', [array([1., 9., 9., 2., 9.], dtype=float32), 0.0])

再運(yùn)行一次,取到的數(shù)據(jù)也完全不一樣江掩。已打亂順序学辱,單樣本輸出。

批量輸出樣本:.batch( batch_size )

def read_dataset(tfrecord_path = tfrecord_path):
    # 聲明閱讀器
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    # 建立解析函數(shù)
    new_dataset = dataset.map(_parse_function)
    # 打亂樣本順序
    shuffle_dataset = new_dataset.shuffle(buffer_size=20000)
    # batch輸出
    batch_dataset = shuffle_dataset.batch(2)
    # 數(shù)據(jù)提前進(jìn)入隊(duì)列
    prefetch_dataset = batch_dataset.prefetch(2000)
    # 建立迭代器
    iterator = prefetch_dataset.make_one_shot_iterator()
    # 獲得下一個(gè)樣本
    next_element = iterator.get_next()
    return next_element

輸出結(jié)果如下:

('dataset:', [array([[1., 4., 6., 2., 5.], [3., 7., 6., 6., 9.]], dtype=float32), array([0., 0.], dtype=float32)])
('dataset:', [array([[8., 2., 2., 6., 3.], [7., 5., 3., 0., 3.]], dtype=float32), array([0., 1.], dtype=float32)])
('dataset:', [array([[2., 8., 9., 5., 7.], [0., 5., 1., 5., 5.]], dtype=float32), array([1., 0.], dtype=float32)])
('dataset:', [array([[0., 8., 1., 6., 0.], [7., 3., 8., 8., 1.]], dtype=float32), array([0., 0.], dtype=float32)])
('dataset:', [array([[2., 4., 9., 8., 9.], [3., 5., 9., 6., 0.]], dtype=float32), array([1., 0.], dtype=float32)])

Epoch: 使用.repeat(num_epochs) 來(lái)指定遍歷幾遍數(shù)據(jù)集
關(guān)于Epoch次數(shù)环形,在Queue讀取文件的方式中策泣,是在創(chuàng)建文件名隊(duì)列時(shí)設(shè)定的

filename_queue = tf.train.string_input_producer([tfrecord_path], num_epochs=3)

根據(jù)博客中的實(shí)驗(yàn)可知,先取出(樣本總數(shù)??num_Epoch)的數(shù)據(jù)抬吟,打亂順序萨咕,按照batch_size,無(wú)放回的取樣火本,保證每個(gè)樣本都被訪問(wèn)num_Epoch次危队。

三種讀取方式的完整代碼

# -*- coding: UTF-8 -*-
# !/usr/bin/python3
# Env: python3.6
import tensorflow as tf
import numpy as np
import os

# path
data_filename = 'data/data_train.txt'  # 生成txt數(shù)據(jù)保存路徑
size = (10000, 5)
tfrecord_path = 'data/test_data.tfrecord'  # tfrecord文件保存路徑

#################### 生成txt數(shù)據(jù) 10000個(gè)樣本。########################
def generate_data(data_filename=data_filename, size=size):
    if not os.path.exists(data_filename):
        np.random.seed(9)
        x_data = np.random.randint(0, 10, size=size)
        y1_data = np.ones((size[0] // 2, 1), int)  # 一半標(biāo)簽是0钙畔,一半是1
        y2_data = np.zeros((size[0] // 2, 1), int)
        y_data = np.append(y1_data, y2_data)
        np.random.shuffle(y_data)

        xy_data = str('')
        for xy_row in range(len(x_data)):
            x_str = str('')
            for xy_col in range(len(x_data[0])):
                if not xy_col == (len(x_data[0]) - 1):
                    x_str = x_str + str(x_data[xy_row, xy_col]) + ' '
                else:
                    x_str = x_str + str(x_data[xy_row, xy_col])
            y_str = str(y_data[xy_row])
            xy_data = xy_data + (x_str + '/' + y_str + '\n')

        # write to txt
        write_txt = open(data_filename, 'w')
        write_txt.write(xy_data)
        write_txt.close()
    return

################  讀取txt文件茫陆,并轉(zhuǎn)為tfrecord文件 ###########################
# every line of data is just as follow: 1 2 3 4 5/1. train data: 1 2 3 4 5, label: 1
def txt_to_tfrecord(txt_filename=data_filename, tfrecord_path=tfrecord_path):
    # 第一步:生成TFRecord Writer
    writer = tf.python_io.TFRecordWriter(tfrecord_path)

    # 第二步:讀取TXT數(shù)據(jù),并分割出樣本數(shù)據(jù)和標(biāo)簽
    file = open(txt_filename)
    for data_line in file.readlines():  # 每一行
        data_line = data_line.strip('\n')  # 去掉換行符
        sample = []
        spls = data_line.split('/', 1)[0]  # 樣本
        for m in spls.split(' '):
            sample.append(int(m))
        label = data_line.split('/', 1)[1]  # 標(biāo)簽
        label = int(label)
        print('sample:', sample, 'labels:', label)

        # 第三步: 建立feature字典擎析,tf.train.Feature()對(duì)單一數(shù)據(jù)編碼成feature
        feature = {'sample': tf.train.Feature(int64_list=tf.train.Int64List(value=sample)),
                   'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}
        # 第四步:可以理解為將內(nèi)層多個(gè)feature的字典數(shù)據(jù)再編碼簿盅,集成為features
        features = tf.train.Features(feature=feature)
        # 第五步:將features數(shù)據(jù)封裝成特定的協(xié)議格式
        example = tf.train.Example(features=features)
        # 第六步:將example數(shù)據(jù)序列化為字符串
        Serialized = example.SerializeToString()
        # 第七步:將序列化的字符串?dāng)?shù)據(jù)寫(xiě)入?yún)f(xié)議緩沖區(qū)
        writer.write(Serialized)
    # 記得關(guān)閉writer和open file的操作
    writer.close()
    file.close()
    return


###############   用Queue方式中的tf.parse_single_example解析tfrecord  #########################

# 第一步: 建立文件名隊(duì)列
filename_queue = tf.train.string_input_producer([tfrecord_path], num_epochs=30)


def read_single(filename_queue, shuffle_batch, if_enq_many):
    # 第二步: 建立閱讀器
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # 第三步:根據(jù)寫(xiě)入時(shí)的格式建立相對(duì)應(yīng)的讀取features
    features = {
        'sample': tf.FixedLenFeature([5], tf.int64),  # 如果不是標(biāo)量,一定要在這里說(shuō)明數(shù)組的長(zhǎng)度
        'label': tf.FixedLenFeature([], tf.int64)
    }
    # 第四步: 用tf.parse_single_example()解析單個(gè)EXAMPLE PROTO
    Features = tf.parse_single_example(serialized_example, features)

    # 第五步:對(duì)數(shù)據(jù)進(jìn)行后處理
    sample = tf.cast(Features['sample'], tf.float32)
    label = tf.cast(Features['label'], tf.float32)

    # 第六步:生成Batch數(shù)據(jù) generate batch
    if shuffle_batch:  # 打亂數(shù)據(jù)順序叔锐,隨機(jī)取樣
        sample_single, label_single = tf.train.shuffle_batch([sample, label],
                                                             batch_size=2,
                                                             capacity=10000,
                                                             min_after_dequeue=1000,
                                                             num_threads=1,
                                                             enqueue_many=if_enq_many)  # 主要是為了評(píng)估enqueue_many的作用
    else:  # # 如果不打亂順序則用tf.train.batch(), 輸出隊(duì)列按順序組成Batch輸出
        '''
        example_list = [[sample,label] for _ in range(2)]  # Reader設(shè)置為2

        sample_single, label_single = tf.train.batch_join(example_list, batch_size=1)
        '''

        sample_single, label_single = tf.train.batch([sample, label],
                                                     batch_size=1,
                                                     capacity=10000,
                                                     num_threads=1,
                                                     enqueue_many=if_enq_many)

    return sample_single, label_single


#############   用Queue方式中的tf.parse_example解析tfrecord  ##################################

def read_parse(filename_queue, shuffle_batch, if_enq_many):
    # 第二步: 建立閱讀器
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # 第三步: 設(shè)置shuffle_batch
    if shuffle_batch:
        batch = tf.train.shuffle_batch([serialized_example],
                                       batch_size=3,
                                       capacity=10000,
                                       min_after_dequeue=1000,
                                       num_threads=1,
                                       enqueue_many=if_enq_many)  # 主要是為了評(píng)估enqueue_many的作用

    else:
        batch = tf.train.batch([serialized_example],
                               batch_size=3,
                               capacity=10000,
                               num_threads=1,
                               enqueue_many=if_enq_many)
        # 第四步:根據(jù)寫(xiě)入時(shí)的格式建立相對(duì)應(yīng)的讀取features
    features = {
        'sample': tf.FixedLenFeature([5], tf.int64),  # 如果不是標(biāo)量挪鹏,一定要在這里說(shuō)明數(shù)組的長(zhǎng)度
        'label': tf.FixedLenFeature([], tf.int64)
    }
    # 第五步: 用tf.parse_example()解析多個(gè)EXAMPLE PROTO
    Features = tf.parse_example(batch, features)

    # 第六步:對(duì)數(shù)據(jù)進(jìn)行后處理
    samples_parse = tf.cast(Features['sample'], tf.float32)
    labels_parse = tf.cast(Features['label'], tf.float32)
    return samples_parse, labels_parse


############### 用Dataset讀取tfrecord文件  ###############################################

# 定義解析函數(shù)
def _parse_function(example_proto):
    dics = {  # 這里沒(méi)用default_value,隨后的都是None
        'sample': tf.FixedLenFeature([5], tf.int64),  # 如果不是標(biāo)量愉烙,一定要在這里說(shuō)明數(shù)組的長(zhǎng)度
        'label': tf.FixedLenFeature([], tf.int64)}
    # 把序列化樣本和解析字典送入函數(shù)里得到解析的樣本
    parsed_example = tf.parse_single_example(example_proto, dics)

    parsed_example['sample'] = tf.cast(parsed_example['sample'], tf.float32)
    parsed_example['label'] = tf.cast(parsed_example['label'], tf.float32)
    # 返回所有feature
    return parsed_example


def read_dataset(tfrecord_path=tfrecord_path):
    # 聲明閱讀器
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    # 建立解析函數(shù)讨盒,其中num_parallel_calls指定并行線程數(shù)
    new_dataset = dataset.map(_parse_function, num_parallel_calls=4)
    # 打亂樣本順序
    shuffle_dataset = new_dataset.shuffle(buffer_size=20000)
    # 設(shè)置epoch次數(shù)為10,這里需要注意的是目前看來(lái)只支持先shuffle再repeat的方式
    repeat_dataset = shuffle_dataset.repeat(10) 
    # batch輸出
    batch_dataset = repeat_dataset.batch(2)
    # 數(shù)據(jù)提前進(jìn)入隊(duì)列
    prefetch_dataset = batch_dataset.prefetch(2000)
    # 建立迭代器
    iterator = prefetch_dataset.make_one_shot_iterator()
    # 獲得下一個(gè)樣本
    next_element = iterator.get_next()
    return next_element


##################   建立graph ####################################

# 生成數(shù)據(jù)
# generate_data()
# 讀取數(shù)據(jù)轉(zhuǎn)為tfrecord文件
# txt_to_tfrecord()
# Queue + tf.parse_single_example()讀取tfrecord文件
x1_samples, y1_labels = read_single(filename_queue, shuffle_batch=True, if_enq_many=False)
# Queue + tf.parse_example()讀取tfrecord文件
x2_samples, y2_labels = read_parse(filename_queue, shuffle_batch=True, if_enq_many=False)
# Dataset讀取數(shù)據(jù)
next_element = read_dataset()

# 定義初始化變量范圍
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)  # 初始化
    # 如果tf.train.string_input_producer([tfrecord_path], num_epochs=30)中num_epochs不為空的化步责,必須要初始化local變量
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()  # 管理線程
    threads = tf.train.start_queue_runners(coord=coord)  # 文件名開(kāi)始進(jìn)入文件名隊(duì)列和內(nèi)存
    for i in range(1):
        # Queue + tf.parse_single_example()讀取tfrecord文件
        X1, Y1 = sess.run([x1_samples, y1_labels])
        print('X1: ', X1, 'Y1: ', Y1)
        # Queue + tf.parse_example()讀取tfrecord文件
        X2, Y2 = sess.run([x2_samples, y2_labels])
        print('X2: ', X2, 'Y2: ', Y2)
        # Dataset讀取數(shù)據(jù)
        print('dataset:', sess.run([next_element['sample'],
                                    next_element['label']]))
        #這里需要注意返顺,每run一次,迭代器會(huì)取下一個(gè)樣本蔓肯。
        # 如果是 a= sess.run(next_element['sample'])
        #             b = sess.run(next_element['label'])遂鹊,
        # 則a樣本對(duì)應(yīng)的標(biāo)簽值不是b,b是下一個(gè)樣本對(duì)應(yīng)的標(biāo)簽值蔗包。

    coord.request_stop()
    coord.join(threads)

另外秉扑,關(guān)于dataset加速的用法,可以參見(jiàn)官網(wǎng)說(shuō)明

Dataset+TFRecord讀取變長(zhǎng)數(shù)據(jù)

使用dataset中的padded_batch方法來(lái)進(jìn)行

padded_batch(
    batch_size,
    padded_shapes,
    padding_values=None    #默認(rèn)使用各類(lèi)型數(shù)據(jù)的默認(rèn)值,一般使用時(shí)可忽略該項(xiàng)
)

參數(shù)padded_shapes 指明每條記錄中各成員要pad成的形狀舟陆,成員若是scalar误澳,則用[ ],若是list秦躯,則用[mx_length]忆谓,若是array,則用[d1,...,dn]踱承,假如各成員的順序是scalar數(shù)據(jù)倡缠、list數(shù)據(jù)、array數(shù)據(jù)茎活,則padded_shapes=([], [mx_length], [d1,...,dn])昙沦;
例如tfrecord文件中的key是fea, e.g.fea.shape=[568, 366], 二維,長(zhǎng)度變化妙色。fea_shape=[568,366]桅滋,一維, label=[1, 0, 2,0,3,0]一維,長(zhǎng)度變化身辨。
再讀取變長(zhǎng)數(shù)據(jù)的時(shí)候映射函數(shù)應(yīng)為:

def _parse_function(example_proto):
    dics = {
        'fea': tf.VarLenFeature(dtype=tf.float32),
        'fea_shape': tf.FixedLenFeature(shape=(2,), dtype=tf.int64),
        'label': tf.VarLenFeature(dtype=tf.float32)}

    parsed_example = tf.parse_single_example(example_proto, dics)
    parsed_example['fea'] = tf.sparse_tensor_to_dense(parsed_example['fea'])
    parsed_example['label'] = tf.sparse_tensor_to_dense(parsed_example['label'])
    parsed_example['label'] = tf.cast(parsed_example['label'], tf.int32)
    parsed_example['fea'] = tf.reshape(parsed_example['fea'], parsed_example['fea_shape'])
    return parsed_example

利用tf.VarLenFeature()代替tf.FixedLenFeature()丐谋,在后處理中要注意用tf.sparse_tensor_to_dense()將讀取的變長(zhǎng)數(shù)據(jù)轉(zhuǎn)為稠密矩陣。

def dataset():
    tf_lst = get_tf_list(tf_file_lst)
    dataset = tf.data.TFRecordDataset(tf_lst)
    new_dataset = dataset.map(_parse_function)
    shuffle_dataset = new_dataset.shuffle(buffer_size=20000)
    repeat_dataset = shuffle_dataset.repeat(10)
    prefetch_dataset = repeat_dataset.prefetch(2000)
    batch_dataset = prefetch_dataset.padded_batch(2, padded_shapes={'fea': [None, None], 'fea_shape': [None], 'label': [None]})
    iterator = batch_dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

    return next_element

這里padded_shapes={'fea': [None, None], 'fea_shape': [None], 'label': [None]}
如果報(bào)錯(cuò) All elements in a batch must have the same rank as the padded shape for component1: expected rank 2 but got element with rank 1請(qǐng)仔細(xì)查看padded_shapes中設(shè)置的維度是否正確煌珊。如果padded_shapes={'fea': [None, None], 'fea_shape': [None, None], 'label': [None]}即fea_shape本來(lái)的rank應(yīng)該是1号俐,但是在pad的時(shí)候設(shè)置了2,所以報(bào)錯(cuò)定庵。

如果報(bào)錯(cuò)The two structures don't have the same sequence type. Input structure has type <class 'tuple'>, while shallow structure has type <class 'dict'>.吏饿,則可能是padded_shapes定義的格式不對(duì),如定義成了padded_shapes=([None, None],[None],[None])蔬浙,請(qǐng)按照字典格式定義pad的方式猪落。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市畴博,隨后出現(xiàn)的幾起案子笨忌,更是在濱河造成了極大的恐慌,老刑警劉巖俱病,帶你破解...
    沈念sama閱讀 212,383評(píng)論 6 493
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件官疲,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡亮隙,警方通過(guò)查閱死者的電腦和手機(jī)途凫,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,522評(píng)論 3 385
  • 文/潘曉璐 我一進(jìn)店門(mén),熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)溢吻,“玉大人维费,你說(shuō)我怎么就攤上這事。” “怎么了犀盟?”我有些...
    開(kāi)封第一講書(shū)人閱讀 157,852評(píng)論 0 348
  • 文/不壞的土叔 我叫張陵噪漾,是天一觀的道長(zhǎng)。 經(jīng)常有香客問(wèn)我且蓬,道長(zhǎng),這世上最難降的妖魔是什么题翰? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 56,621評(píng)論 1 284
  • 正文 為了忘掉前任恶阴,我火速辦了婚禮,結(jié)果婚禮上豹障,老公的妹妹穿的比我還像新娘冯事。我一直安慰自己,他們只是感情好血公,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,741評(píng)論 6 386
  • 文/花漫 我一把揭開(kāi)白布昵仅。 她就那樣靜靜地躺著,像睡著了一般累魔。 火紅的嫁衣襯著肌膚如雪摔笤。 梳的紋絲不亂的頭發(fā)上,一...
    開(kāi)封第一講書(shū)人閱讀 49,929評(píng)論 1 290
  • 那天垦写,我揣著相機(jī)與錄音吕世,去河邊找鬼傅瞻。 笑死斋荞,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的恤批。 我是一名探鬼主播分蓖,決...
    沈念sama閱讀 39,076評(píng)論 3 410
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼尔艇,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了么鹤?” 一聲冷哼從身側(cè)響起终娃,我...
    開(kāi)封第一講書(shū)人閱讀 37,803評(píng)論 0 268
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎午磁,沒(méi)想到半個(gè)月后尝抖,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 44,265評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡迅皇,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,582評(píng)論 2 327
  • 正文 我和宋清朗相戀三年昧辽,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片登颓。...
    茶點(diǎn)故事閱讀 38,716評(píng)論 1 341
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡搅荞,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情咕痛,我是刑警寧澤痢甘,帶...
    沈念sama閱讀 34,395評(píng)論 4 333
  • 正文 年R本政府宣布,位于F島的核電站茉贡,受9級(jí)特大地震影響塞栅,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜腔丧,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 40,039評(píng)論 3 316
  • 文/蒙蒙 一放椰、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧愉粤,春花似錦砾医、人聲如沸。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 30,798評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)。三九已至影暴,卻和暖如春错邦,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背坤检。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 32,027評(píng)論 1 266
  • 我被黑心中介騙來(lái)泰國(guó)打工兴猩, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人早歇。 一個(gè)月前我還...
    沈念sama閱讀 46,488評(píng)論 2 361
  • 正文 我出身青樓倾芝,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親箭跳。 傳聞我的和親對(duì)象是個(gè)殘疾皇子晨另,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,612評(píng)論 2 350

推薦閱讀更多精彩內(nèi)容

  • 讀取機(jī)制 Tensorflow中數(shù)據(jù)讀取機(jī)制可見(jiàn)下圖 關(guān)于這張圖,這篇文章已經(jīng)介紹的非常詳細(xì)谱姓,簡(jiǎn)而言之借尿,Tenso...
    cheerss閱讀 5,379評(píng)論 6 4
  • 文件隊(duì)列 參考了這篇博客的內(nèi)容為了實(shí)現(xiàn)數(shù)據(jù)讀入和數(shù)據(jù)處理的管線化,tensorflow使用文件隊(duì)列來(lái)獨(dú)立處理數(shù)據(jù)讀...
  • 在學(xué)習(xí)tensorflow的過(guò)程中屉来,有很多小伙伴反映讀取數(shù)據(jù)這一塊很難理解路翻。確實(shí)這一塊官方的教程比較簡(jiǎn)略,網(wǎng)上也找...
    yalesaleng閱讀 281評(píng)論 0 0
  • 引子:我們?cè)谌粘I钪薪?jīng)常會(huì)有許多用手機(jī)拍照的機(jī)會(huì)茄靠,但由于設(shè)備等各種原因茂契,這樣拍攝出的照片通常都不會(huì)太好,當(dāng)做生活...
    極浦閱讀 1,788評(píng)論 0 20
  • 顏淵第十二(主要講孔子教育弟子如何實(shí)行仁德慨绳,如何為政和處世) 每日《論語(yǔ)》編輯:曹友寶 【原文】 12.17季康子...
    曹友寶閱讀 302評(píng)論 0 0