[0.2] Tensorflow踩坑記之頭疼的tf.data

今天嘗試總結(jié)一下 tf.data 這個API的一些用法吧。之所以會用到這個API遏暴,是因為需要處理的數(shù)據(jù)量很大见咒,而且數(shù)據(jù)均是分布式的存儲在多臺服務(wù)器上,所以沒有辦法采用傳統(tǒng)的喂數(shù)據(jù)方式低匙,而是運用了 tf.data 對數(shù)據(jù)進(jìn)行了相應(yīng)的預(yù)處理,并且最近正趕上總結(jié)需要碳锈,嘗試寫一下關(guān)于 tf.data 的一些用法顽冶,有錯誤的地方一定告訴我哈。

Tensorflow的數(shù)據(jù)讀取

先來看一下Tensorflow的數(shù)據(jù)讀取機制吧

這一篇文章對于 tensorflow的數(shù)據(jù)讀取機制 講解得很不錯售碳,大噶可以先看一下强重,有一個了解。

Dataset API是怎么用的呢

雖然上面的資料關(guān)于 tf.data 講解得都很好贸人,但是我沒有找到一個很完整滴運用 tf.data.TextLineDataset()tf.data.TFRecordDataset() 的例子间景,所以才想嘗試寫一寫這篇總結(jié)。

MNIST的經(jīng)典例子

本篇博客結(jié)合 mnist 的經(jīng)典例子艺智,針對不同的源數(shù)據(jù):csv數(shù)據(jù)和tfrecord數(shù)據(jù)倘要,分別運用 tf.data.TextLineDataset()tf.data.TFRecordDataset() 創(chuàng)建不同的 Dataset 并運用四種不同的 Iterator ,分別是 單次力惯,可初始化碗誉,可重新初始化召嘶,以及可饋送迭代器 的方式實現(xiàn)對源數(shù)據(jù)的預(yù)處理工作父晶。

我將相關(guān)的資料放在了瀾子的Github 上,歡迎互粉哇(星星眼)弄跌。其中包括了所需的 后綴名為csv和tfrecords的源數(shù)據(jù) (data的文件夾)甲喝,以及在 jupyter notebook實現(xiàn)的具體代碼 (tf_dataset_learn.ipynb)。

如果有需要的同學(xué)可以直接
git clone https://github.com/lanhongvp/tensorflow_dataset_learn.git
然后用 jupyter 跑一跑看看輸出铛只,這樣可以有一個比較直觀的認(rèn)識埠胖。關(guān)于 Git和Github 的使用糠溜,大噶可以看我VSCODE_GIT這一篇博客啦。接下來直撤,針對MNIST例子做一個簡單的說明吧非竿。

tf.data.TFRecordDataset() & make_one_shot_iterator()

tf.data.TFRecordDataset() 輸入?yún)?shù)直接是后綴名為tfrecords的文件路徑,正因如此谋竖,即可解決數(shù)據(jù)量過大红柱,導(dǎo)致無法單機訓(xùn)練的問題。本篇博客中蓖乘,文件路徑即為/Users/honglan/Desktop/train_output.tfrecords锤悄,此處是我自己電腦上的路徑,大家可以 根據(jù)自己的需要修改為對應(yīng)的文件路徑嘉抒。
make_one_shot_iterator() 即為單次迭代器零聚,是最簡單的迭代器形式,僅支持對數(shù)據(jù)集進(jìn)行一次迭代些侍,不需要顯式初始化隶症。
配合 MNIST數(shù)據(jù)集以及tf.data.TFRecordDataset(),實現(xiàn)代碼如下岗宣。

# Validate tf.data.TFRecordDataset() using make_one_shot_iterator()
import tensorflow as tf
import numpy as np

num_epochs = 2
num_class = 10
sess = tf.Session()

# Use `tf.parse_single_example()` to extract data from a `tf.Example`
# protocol buffer, and perform any additional per-record preprocessing.
def parser(record):
    keys_to_features = {
        "image_raw": tf.FixedLenFeature((), tf.string, default_value=""),
        "pixels": tf.FixedLenFeature((), tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
        "label": tf.FixedLenFeature((), tf.int64,
                                    default_value=tf.zeros([], dtype=tf.int64)),
    }
    parsed = tf.parse_single_example(record, keys_to_features)

    # Parse the string into an array of pixels corresponding to the image
    images = tf.decode_raw(parsed["image_raw"],tf.uint8)
    images = tf.reshape(images,[28,28,1])
    labels = tf.cast(parsed['label'], tf.int32)
    labels = tf.one_hot(labels,num_class)
    pixels = tf.cast(parsed['pixels'], tf.int32)
    print("IMAGES",images)
    print("LABELS",labels)
    
    return {"image_raw": images}, labels


filenames = ["/Users/honglan/Desktop/train_output.tfrecords"] 
# replace the filenames with your own path
dataset = tf.data.TFRecordDataset(filenames)
print("DATASET",dataset)

# Use `Dataset.map()` to build a pair of a feature dictionary and a label
# tensor for each example.
dataset = dataset.map(parser)
print("DATASET_1",dataset)
dataset = dataset.shuffle(buffer_size=10000)
print("DATASET_2",dataset)
dataset = dataset.batch(32)
print("DATASET_3",dataset)
dataset = dataset.repeat(num_epochs)
print("DATASET_4",dataset)
iterator = dataset.make_one_shot_iterator()

# `features` is a dictionary in which each value is a batch of values for
# that feature; `labels` is a batch of labels.
features, labels = iterator.get_next()

print("FEATURES",features)
print("LABELS",labels)
print("SESS_RUN_LABELS \n",sess.run(labels))

tf.data.TFRecordDataset() & Initializable iterator

make_initializable_iterator() 為可初始化迭代器沿腰,運用此迭代器首先需要先運行顯式 iterator.initializer 操作,然后才能使用狈定。并且颂龙,可運用 可初始化迭代器實現(xiàn)訓(xùn)練集和驗證集的切換
配合 MNIST數(shù)據(jù)集 實現(xiàn)代碼如下纽什。

# Validate tf.data.TFRecordDataset() using make_initializable_iterator()
# In order to switch between train and validation data
num_epochs = 2
num_class = 10

def parser(record):
    keys_to_features = {
        "image_raw": tf.FixedLenFeature((), tf.string, default_value=""),
        "pixels": tf.FixedLenFeature((), tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
        "label": tf.FixedLenFeature((), tf.int64,
                                    default_value=tf.zeros([], dtype=tf.int64)),
    }
    parsed = tf.parse_single_example(record, keys_to_features)
    
    # Parse the string into an array of pixels corresponding to the image
    images = tf.decode_raw(parsed["image_raw"],tf.uint8)
    images = tf.reshape(images,[28,28,1])
    labels = tf.cast(parsed['label'], tf.int32)
    labels = tf.one_hot(labels,10)
    pixels = tf.cast(parsed['pixels'], tf.int32)
    print("IMAGES",images)
    print("LABELS",labels)
    
    return {"image_raw": images}, labels


filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(parser) # Parse the record into tensors
# print("DATASET",dataset)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
print("DATASET",dataset)
iterator = dataset.make_initializable_iterator()
features, labels = iterator.get_next()
print("ITERATOR",iterator)
print("FEATURES",features)
print("LABELS",labels)


# Initialize `iterator` with training data.
training_filenames = ["/Users/honglan/Desktop/train_output.tfrecords"] 
# replace the filenames with your own path
sess.run(iterator.initializer,feed_dict={filenames: training_filenames})
print("TRAIN\n",sess.run(labels))
# print(sess.run(features))

# Initialize `iterator` with validation data.
validation_filenames = ["/Users/honglan/Desktop/val_output.tfrecords"] 
# replace the filenames with your own path
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
print("VAL\n",sess.run(labels))

tf.data.TextLineDataset() & Reinitializable iterator

tf.data.TextLineDataset()措嵌,輸入?yún)?shù)可以是后綴名為csv或者是txt的源數(shù)據(jù)的文件路徑。
此處用的迭代器是 Reinitializable iterator芦缰,即為可重新初始化迭代器企巢。官方定義如下。配合 MNIST數(shù)據(jù)集 實現(xiàn)代碼見第二部分让蕾。

可重新初始化迭代器可以通過多個不同的 Dataset 對象進(jìn)行初始化浪规。例如,您可能有一個訓(xùn)練輸入管道探孝,它會對輸入圖片進(jìn)行隨機擾動來改善泛化笋婿;還有一個驗證輸入管道,它會評估對未修改數(shù)據(jù)的預(yù)測顿颅。這些管道通常會使用不同的 Dataset 對象缸濒,這些對象具有相同的結(jié)構(gòu)(即每個組件具有相同類型和兼容形狀)。

# validate tf.data.TextLineDataset() using Reinitializable iterator
# In order to switch between train and validation data

def decode_line(line):
    # Decode the line to tensor
    record_defaults = [[1.0] for col in range(785)]
    items = tf.decode_csv(line, record_defaults)
    features = items[1:785]
    label = items[0]

    features = tf.cast(features, tf.float32)
    features = tf.reshape(features,[28,28,1])
    label = tf.cast(label, tf.int64)
    label = tf.one_hot(label,num_class)
    return features,label


def create_dataset(filename, batch_size=32, is_shuffle=False, n_repeats=0):
    """create dataset for train and validation dataset"""
    dataset = tf.data.TextLineDataset(filename).skip(1)
    if n_repeats > 0:
        dataset = dataset.repeat(n_repeats)         # for train
    # dataset = dataset.map(decode_line).map(normalize)  
    dataset = dataset.map(decode_line)    
    # decode and normalize
    if is_shuffle:
        dataset = dataset.shuffle(10000)            # shuffle
    dataset = dataset.batch(batch_size)
    return dataset


training_filenames = ["/Users/honglan/Desktop/train.csv"] 
# replace the filenames with your own path
validation_filenames = ["/Users/honglan/Desktop/val.csv"] 
# replace the filenames with your own path

# Create different datasets
training_dataset = create_dataset(training_filenames, batch_size=32, \
                                  is_shuffle=True, n_repeats=num_epochs) # train_filename
validation_dataset = create_dataset(validation_filenames, batch_size=32, \
                                  is_shuffle=True, n_repeats=num_epochs) # val_filename

# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                           training_dataset.output_shapes)
features, labels = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

# Using reinitializable iterator to alternate between training and validation.
sess.run(training_init_op)
print("TRAIN\n",sess.run(labels))
# print(sess.run(features))

# Reinitialize `iterator` with validation data.
sess.run(validation_init_op)
print("VAL\n",sess.run(labels))

tf.data.TextLineDataset() & Feedable iterator.

數(shù)據(jù)集讀取方式同上一部分一樣,運用tf.data.TextLineDataset()此處運用的迭代器是 可饋送迭代器庇配,其可以與 tf.placeholder 一起使用斩跌,通過熟悉的 feed_dict 機制選擇每次調(diào)用 tf.Session.run 時所使用的 Iterator。并使用 tf.data.Iterator.from_string_handle定義一個可讓在兩個數(shù)據(jù)集之間切換的可饋送迭代器捞慌,結(jié)合 MNIST數(shù)據(jù)集 的代碼如下

# validate tf.data.TextLineDataset() using two different iterator
# In order to switch between train and validation data

def decode_line(line):
    # Decode the line to tensor
    record_defaults = [[1.0] for col in range(785)]
    items = tf.decode_csv(line, record_defaults)
    features = items[1:785]
    label = items[0]

    features = tf.cast(features, tf.float32)
    features = tf.reshape(features,[28,28])
    label = tf.cast(label, tf.int64)
    label = tf.one_hot(label,num_class)
    return features,label


def create_dataset(filename, batch_size=32, is_shuffle=False, n_repeats=0):
    """create dataset for train and validation dataset"""
    dataset = tf.data.TextLineDataset(filename).skip(1)
    if n_repeats > 0:
        dataset = dataset.repeat(n_repeats)         # for train
    # dataset = dataset.map(decode_line).map(normalize)  
    dataset = dataset.map(decode_line)    
    # decode and normalize
    if is_shuffle:
        dataset = dataset.shuffle(10000)            # shuffle
    dataset = dataset.batch(batch_size)
    return dataset


training_filenames = ["/Users/honglan/Desktop/train.csv"] 
# replace the filenames with your own path
validation_filenames = ["/Users/honglan/Desktop/val.csv"] 
# replace the filenames with your own path

# Create different datasets
training_dataset = create_dataset(training_filenames, batch_size=32, \
                                  is_shuffle=True, n_repeats=num_epochs) # train_filename
validation_dataset = create_dataset(validation_filenames, batch_size=32, \
                                  is_shuffle=True, n_repeats=num_epochs) # val_filename

# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
features, labels = iterator.get_next()

# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

# Using different handle to alternate between training and validation.
print("TRAIN\n",sess.run(labels, feed_dict={handle: training_handle}))
# print(sess.run(features))

# Initialize `iterator` with validation data.
sess.run(validation_iterator.initializer)
print("VAL\n",sess.run(labels, feed_dict={handle: validation_handle}))

小結(jié)

  • 運用tfrecords處理數(shù)據(jù)的速度明顯加快
  • 可以根據(jù)自身需要選擇不同的iterator方式對源數(shù)據(jù)進(jìn)行預(yù)處理
  • 單機訓(xùn)練時也可以采用 tf.data中API的相應(yīng)處理方式
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末耀鸦,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子啸澡,更是在濱河造成了極大的恐慌揭糕,老刑警劉巖,帶你破解...
    沈念sama閱讀 216,372評論 6 498
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件锻霎,死亡現(xiàn)場離奇詭異著角,居然都是意外死亡,警方通過查閱死者的電腦和手機旋恼,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,368評論 3 392
  • 文/潘曉璐 我一進(jìn)店門吏口,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人冰更,你說我怎么就攤上這事产徊。” “怎么了蜀细?”我有些...
    開封第一講書人閱讀 162,415評論 0 353
  • 文/不壞的土叔 我叫張陵舟铜,是天一觀的道長。 經(jīng)常有香客問我奠衔,道長谆刨,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,157評論 1 292
  • 正文 為了忘掉前任归斤,我火速辦了婚禮痊夭,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘脏里。我一直安慰自己她我,他們只是感情好,可當(dāng)我...
    茶點故事閱讀 67,171評論 6 388
  • 文/花漫 我一把揭開白布迫横。 她就那樣靜靜地躺著番舆,像睡著了一般。 火紅的嫁衣襯著肌膚如雪矾踱。 梳的紋絲不亂的頭發(fā)上恨狈,一...
    開封第一講書人閱讀 51,125評論 1 297
  • 那天,我揣著相機與錄音介返,去河邊找鬼拴事。 笑死,一個胖子當(dāng)著我的面吹牛圣蝎,可吹牛的內(nèi)容都是我干的刃宵。 我是一名探鬼主播,決...
    沈念sama閱讀 40,028評論 3 417
  • 文/蒼蘭香墨 我猛地睜開眼徘公,長吁一口氣:“原來是場噩夢啊……” “哼牲证!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起关面,我...
    開封第一講書人閱讀 38,887評論 0 274
  • 序言:老撾萬榮一對情侶失蹤坦袍,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后等太,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體捂齐,經(jīng)...
    沈念sama閱讀 45,310評論 1 310
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,533評論 2 332
  • 正文 我和宋清朗相戀三年缩抡,在試婚紗的時候發(fā)現(xiàn)自己被綠了奠宜。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 39,690評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡瞻想,死狀恐怖压真,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情蘑险,我是刑警寧澤滴肿,帶...
    沈念sama閱讀 35,411評論 5 343
  • 正文 年R本政府宣布,位于F島的核電站佃迄,受9級特大地震影響泼差,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜呵俏,卻給世界環(huán)境...
    茶點故事閱讀 41,004評論 3 325
  • 文/蒙蒙 一拴驮、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧柴信,春花似錦套啤、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,659評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至绪氛,卻和暖如春唆鸡,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背枣察。 一陣腳步聲響...
    開封第一講書人閱讀 32,812評論 1 268
  • 我被黑心中介騙來泰國打工争占, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留燃逻,地道東北人。 一個月前我還...
    沈念sama閱讀 47,693評論 2 368
  • 正文 我出身青樓臂痕,卻偏偏與公主長得像伯襟,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子握童,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,577評論 2 353

推薦閱讀更多精彩內(nèi)容