tf.data
模塊包含一系列類提陶,用于加載數(shù)據(jù)儡蔓、操作數(shù)據(jù)并通過管道將數(shù)據(jù)傳送給模型。本文主要介紹之前提到的iris_data.py
中的train_input_fn
函數(shù)。
0 train_input_fn定義
def train_input_fn(features, labels, batch_size):
"""An input function for training"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
# Build the Iterator, and return the read end of the pipeline.
return dataset.make_one_shot_iterator().get_next()
接下來對這個函數(shù)進(jìn)行簡單介紹料扰。
1. Arguments
該函數(shù)需要如下三個參數(shù):
-
features
:包含有原始輸入特征的{"feature_name": array}字典或者DataFrame
-
labels
:包含每個樣本標(biāo)簽的數(shù)組 -
batch_size
:表示所需批次大小的整數(shù)
2. Slices
最簡單的情況,可以使用tf.data.Dataset.from_tensor_slices
接收一個數(shù)組焙蹭,并創(chuàng)建該數(shù)組的slices表示的tf.data.Dataset
晒杈,這個方法根據(jù)數(shù)組的第一維創(chuàng)建對應(yīng)的slices。比如mnist訓(xùn)練數(shù)據(jù)集的形狀是(60000, 28, 28)
孔厉,通過from_tensor_slices
返回的Dataset
對象包含有60000個slices拯钻,其中每一個都是28*28的圖像,具體代碼如下所示:
train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train
mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print mnist_ds
上述代碼打印出如下內(nèi)容撰豺,展示了數(shù)據(jù)集中slices的shapes以及types粪般。需要注意的是,我們并不知道Dataset中的包含有多少個slices污桦。
<TensorSliceDataset shapes: (28, 28), types: tf.uint8>
上述數(shù)據(jù)集表示了一個簡單的數(shù)組亩歹,但是實(shí)際上Dataset可以表示更復(fù)雜的情況。如下所示凡橱,如果feature
是一個標(biāo)準(zhǔn)的python字典捆憎,那么創(chuàng)建的Dataset
的shapes
和types
也將會被保留:
dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print dataset
<TensorSliceDataset
shapes: {
SepalLength: (), PetalWidth: (),
PetalLength: (), SepalWidth: ()},
types: {
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64}
>
同樣的在之前提到的train_input_fn
中,我們傳遞的是一個(dict(features), labels)
這樣的數(shù)據(jù)機(jī)梭纹,那么創(chuàng)建的Dataset
同樣會保留其結(jié)構(gòu)信息躲惰,如下所示:
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)
<TensorSliceDataset
shapes: (
{
SepalLength: (), PetalWidth: (),
PetalLength: (), SepalWidth: ()},
()),
types: (
{
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64},
tf.int64)>
3 manipulation
當(dāng)前創(chuàng)建的Dataset
會按固定順序迭代,并且一次僅生成一個元素变抽。在它被用于訓(xùn)練之前础拨,還需要其他的操作。tf.data.Dataset
類提供了一系列方法來處理數(shù)據(jù)并生成后續(xù)訓(xùn)練可用的數(shù)據(jù)绍载。如下所示:
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
shuffle
方法使用一個固定的緩沖區(qū)诡宗,將Dataset
中的slices進(jìn)行隨即化處理。這里將buffer_size
設(shè)置的比Dataset
中的slices數(shù)要大一些击儡,可以保證數(shù)據(jù)可以完全被隨機(jī)化處理(iris數(shù)據(jù)一共有150條樣本)
repeat
方法會在調(diào)用結(jié)束后重啟Dataset
塔沃,保證后續(xù)訓(xùn)練時這個數(shù)據(jù)集可以使用。
batch
方法會收集樣本阳谍,并將它們放在一起以創(chuàng)建批次(有時候使用樣本進(jìn)行訓(xùn)練是按照batch進(jìn)行訓(xùn)練的蛀柴,例如mini batch mini batch gradient descent優(yōu)化算法),這為Dataset
的shapes增加了一個維度矫夯。如下代碼對之前的mnist Dataset
使用batch
方法鸽疾,生成100個批次的數(shù)據(jù),每一個批次都是包含有多個slices训貌,其中每個slices都是一個28*28的圖像數(shù)據(jù)制肮。
print mnist_ds.batch(100)
<BatchDataset
shapes: (?, 28, 28),
types: tf.uint8>
需要注意的是冒窍,Dataset
中第一維shapes是不確定的,因?yàn)樽詈笠粋€批次所具有的slices數(shù)量是不確定的豺鼻。
在train_input_fn
中综液,經(jīng)過批處理之后,Dataset
的結(jié)構(gòu)如下所示:
print dataset
<TensorSliceDataset
shapes: (
{
SepalLength: (?,), PetalWidth: (?,),
PetalLength: (?,), SepalWidth: (?,)},
(?,)),
types: (
{
SepalLength: tf.float64, PetalWidth: tf.float64,
PetalLength: tf.float64, SepalWidth: tf.float64},
tf.int64)>
4 return
在train_input_fn
中返回的Dataset
包含的是(feature_dict, labels)
對儒飒。在后續(xù)train
意乓、evaluate
使用的都是這種結(jié)構(gòu),但是在predict
中labels
被省略了约素。