tf.estimator Quickstart

setosa--versicolor--virginica

文章類型:翻譯
文章內(nèi)容:

  • 1斤彼、完整的神經(jīng)網(wǎng)絡(luò)源代碼
  • 2、加載 Iris CSV數(shù)據(jù)到Tensorflow
  • 3骄恶、構(gòu)建深度神經(jīng)網(wǎng)絡(luò)分類器
  • 4、數(shù)據(jù)輸入管道
  • 5匕垫、利用 Iris data擬合神經(jīng)網(wǎng)絡(luò)分類器
  • 6僧鲁、評(píng)估神經(jīng)網(wǎng)絡(luò)分類器的準(zhǔn)確性
  • 7、對(duì)新樣本進(jìn)行分類
  • 8年缎、其他資源

前沿


Tensorflow的高級(jí)機(jī)器學(xué)習(xí)API(tf.estimator)使得配置悔捶、訓(xùn)練和評(píng)估各種機(jī)器學(xué)習(xí)模型更加的簡(jiǎn)單。在本教程中单芜,你將使用tf.estimator構(gòu)建一個(gè)神經(jīng)網(wǎng)絡(luò)分類器,用于訓(xùn)練 Iris data犁柜,構(gòu)建一個(gè)預(yù)測(cè) Iris flower的模型洲鸠,并預(yù)測(cè)新的Iris flower。你將編寫代碼完成以下五個(gè)步驟:

  • 1馋缅、將包含訓(xùn)練集和測(cè)試集的CSV數(shù)據(jù)加載到Tensorflow Dataset
    2扒腕、構(gòu)建神經(jīng)網(wǎng)絡(luò)模型分類器
    3、利用訓(xùn)練數(shù)據(jù)訓(xùn)練模型
    4萤悴、評(píng)估模型的精度
    5瘾腰、分類新的樣本
    

注意:在學(xué)習(xí)本教程之前,請(qǐng)?jiān)谀臋C(jī)器上安裝Tensorflow覆履。

一蹋盆、完整的神經(jīng)網(wǎng)絡(luò)源代碼


以下是神經(jīng)網(wǎng)絡(luò)分類器的完整代碼:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import urllib

import numpy as np
import tensorflow as tf

# Data sets
IRIS_TRAINING = "iris_training.csv"
IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"

IRIS_TEST = "iris_test.csv"
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

def main():
  # If the training and test sets aren't stored locally, download them.
  if not os.path.exists(IRIS_TRAINING):
    raw = urllib.urlopen(IRIS_TRAINING_URL).read()
    with open(IRIS_TRAINING, "w") as f:
      f.write(raw)

  if not os.path.exists(IRIS_TEST):
    raw = urllib.urlopen(IRIS_TEST_URL).read()
    with open(IRIS_TEST, "w") as f:
      f.write(raw)

  # Load datasets.
  training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
      filename=IRIS_TRAINING,
      target_dtype=np.int,
      features_dtype=np.float32)
  test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
      filename=IRIS_TEST,
      target_dtype=np.int,
      features_dtype=np.float32)

  # Specify that all features have real-value data
  feature_columns = [tf.feature_column.numeric_column("x", shape=[4])]

  # Build 3 layer DNN with 10, 20, 10 units respectively.
  classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                          hidden_units=[10, 20, 10],
                                          n_classes=3,
                                          model_dir="/tmp/iris_model")
  # Define the training inputs
  train_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={"x": np.array(training_set.data)},
      y=np.array(training_set.target),
      num_epochs=None,
      shuffle=True)

  # Train model.
  classifier.train(input_fn=train_input_fn, steps=2000)

  # Define the test inputs
  test_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={"x": np.array(test_set.data)},
      y=np.array(test_set.target),
      num_epochs=1,
      shuffle=False)

  # Evaluate accuracy.
  accuracy_score = classifier.evaluate(input_fn=test_input_fn)["accuracy"]

  print("\nTest Accuracy: {0:f}\n".format(accuracy_score))

  # Classify two new flower samples.
  new_samples = np.array(
      [[6.4, 3.2, 4.5, 1.5],
       [5.8, 3.1, 5.0, 1.7]], dtype=np.float32)
  predict_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={"x": new_samples},
      num_epochs=1,
      shuffle=False)

  predictions = list(classifier.predict(input_fn=predict_input_fn))
  predicted_classes = [p["classes"] for p in predictions]

  print("New Samples, Class Predictions:    {}\n".format(predicted_classes))

if __name__ == "__main__":
    main()

二、加載Iris CSV數(shù)據(jù)到Tensorflow


在這個(gè)教程中硝全,Iris 數(shù)據(jù)被隨機(jī)分成兩部分:

部分?jǐn)?shù)據(jù)集

開(kāi)始伟众,要加載必要的模塊析藕,然后定義在哪里去下載并保存數(shù)據(jù)集

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import urllib

import tensorflow as tf
import numpy as np

IRIS_TRAINING = "iris_training.csv"
IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"

IRIS_TEST = "iris_test.csv"
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

然后,如果訓(xùn)練集和測(cè)試集并不存在本地凳厢,那么就下載它們账胧。

if not os.path.exists(IRIS_TRAINING):
  raw = urllib.urlopen(IRIS_TRAINING_URL).read()
  with open(IRIS_TRAINING,'w') as f:
    f.write(raw)

if not os.path.exists(IRIS_TEST):
  raw = urllib.urlopen(IRIS_TEST_URL).read()
  with open(IRIS_TEST,'w') as f:
    f.write(raw)

接下來(lái)竞慢,使用模塊learn.datasets.base中的load_csv_with_header()加載訓(xùn)練集和測(cè)試集到 Datasets,該模塊包含三個(gè)必須的參數(shù):

  •  1治泥、filename筹煮,它將文件路徑轉(zhuǎn)化為CSV文件。
     2车摄、target_dtype寺谤,獲取數(shù)據(jù)集目標(biāo)值的numpy datatype。
     3吮播、feature_dtype变屁,獲取訓(xùn)練集特征值的numpy datatype。
    

此處意狠,target表示花的種類粟关,它是[0, 2]之間的整數(shù),所以 target_dtypenp.int环戈。

# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TRAINING,
    target_dtype=np.int,
    features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TEST,
    target_dtype=np.int,
    features_dtype=np.float32)

Datasets in tf.contrib.learn are named tuples闷板;你可以通過(guò)data and target field 訪問(wèn)特征數(shù)據(jù)和目標(biāo)數(shù)據(jù)。此處院塞,training_set.datatraining_set.target包含了訓(xùn)練集的特征數(shù)據(jù)和目標(biāo)數(shù)據(jù)遮晚;test_set.datatest_set.target包含測(cè)試集的特征數(shù)據(jù)和目標(biāo)數(shù)據(jù)。
稍后拦止,在利用 Iris data擬合神經(jīng)網(wǎng)絡(luò)分類器中看到training_set.datatraining_set.target訓(xùn)練你的模型县遣。在評(píng)估神經(jīng)網(wǎng)絡(luò)分類器的準(zhǔn)確性中,你將使用test_set.datatest_set.target汹族。但是首先萧求,在下一節(jié)中你要構(gòu)建你的模型。

三顶瞒、構(gòu)建深度神經(jīng)網(wǎng)絡(luò)分類器


tf.estimator提供了很多各種預(yù)定義的模型夸政,稱之為“Estimator”,利用它你可以在數(shù)據(jù)集之上進(jìn)行訓(xùn)練和評(píng)估榴徐。在此部分守问,你將配置深度神經(jīng)網(wǎng)絡(luò)分類器模型去擬合Iris數(shù)據(jù)。使用tf.estimator箕速,你可以實(shí)例化tf.estimator.DNNClassifier酪碘,只需要幾行代碼就可以搞定。

# Specify that all features have real-value data
feature_columns = [tf.feature_column.numeric_column("x", shape=[4])]

# Build 3 layer DNN with 10, 20, 10 units respectively
classifier = tf.estimator.DNNClassifier(feature_columns = feature_columns,
                                       hidden_units = [10, 20, 10],
                                       n_classes = 3,
                                       model_dir = "/tmp/iris_model")

上面的代碼首先定義了模型的特征列盐茎,為數(shù)據(jù)集中特征指定了數(shù)據(jù)類型兴垦。所有的特征數(shù)據(jù)都是連續(xù)的,所以 tf.feature_column.numeric_column(該函數(shù)返回一個(gè)實(shí)數(shù)列)是構(gòu)建特征數(shù)據(jù)非常合適的函數(shù)。在數(shù)據(jù)集中總共有四個(gè)特征(sepal width探越,sepal height狡赐,petal width,petal height)钦幔,所以形狀必須被設(shè)置為 [4] 來(lái)保存所有的數(shù)據(jù)枕屉。
然后,創(chuàng)建DNNClassifier模型需要用到以下參數(shù):

  •   1鲤氢、feature_columns = feature_columns搀擂。設(shè)置特征列。
      2卷玉、hidden_units = [10, 20, 10]哨颂。設(shè)置隱含層。   
      3相种、n_classes = 3威恼。表示三個(gè)目標(biāo)分類
      4、model_dir = "/tmp/iris_model" 寝并,存儲(chǔ)checkpoint數(shù)據(jù)和TensorBoard summaries數(shù)據(jù)的目錄   
    

四箫措、數(shù)據(jù)輸入管道


tf.estimator API使用input 函數(shù),這個(gè)將創(chuàng)建一個(gè)Tensorflow操作用于為模型產(chǎn)生數(shù)據(jù)衬潦。我們使用 tf.estimator.inputs.numpy_input_fn 產(chǎn)生 input 的管道斤蔓。

# Define the training inputs
train_input_fn = tf.estimator.inputs.numpy_input_fn(
                    x = {"x": np.array(training_set.data)},
                    y = np.array(training_set.target),
                    num_epochs = None, 
                    shuffle = True)

五、利用 Iris data擬合神經(jīng)網(wǎng)絡(luò)分類器


現(xiàn)在你可以配置DNN 分類器的模型了镀岛,你可以把模型放在訓(xùn)練集上進(jìn)行訓(xùn)練附迷。訓(xùn)練的步數(shù)為2000次

# Train model
classifier.train(input_fn = train_input_fn, steps = 2000)

模型的狀態(tài)保存在分類器中,也就是說(shuō)哎媚,如果你喜歡的話,你可以反復(fù)訓(xùn)練喊儡。例如拨与,以下的訓(xùn)練方式是等效的。

classifier.train(input_fn = train_input_fn, steps = 1000)
classifier.tarin(input_fn = train_input_fn, steps = 1000)

不管怎么樣艾猜,如果你想在訓(xùn)練過(guò)程中追蹤模型买喧,你可以需要使用TensorFlow SessionRunHook來(lái)執(zhí)行日志操作。

六匆赃、評(píng)估神經(jīng)網(wǎng)絡(luò)分類器的準(zhǔn)確性


以下代碼表示在測(cè)試集上評(píng)估模型的精度淤毛。

# Define the test inputs
test_input_fn = tf.estimator.inputs.numpy_input_fn(
                x = {"x": np.array(test_set.data)},
                y = np.array(test_set.target),
                num_epochs = 1,
                shuffle = False)
# Evaluate accuracy
accuracy_score = classifier.evaluate(input_fn = test_input_fn)["accuracy"]
print "\nTest Accuracy: {0:f}\n".format(accuracy_score)

注意:這里的num_epochs = 1的參數(shù)非常重要。test_input_fn將迭代數(shù)據(jù)集一次算柳,然后發(fā)送 OutOfRangeError低淡。這個(gè)錯(cuò)誤信號(hào)標(biāo)志著分類器停止評(píng)估,所以它會(huì)對(duì)輸入進(jìn)行一次評(píng)估。
當(dāng)你運(yùn)行整個(gè)腳本的時(shí)候蔗蹋,會(huì)打印如下數(shù)據(jù):

Test Accuracy: 0.966667

你的準(zhǔn)確性可能會(huì)有所不同何荚,但應(yīng)該會(huì)高于90%。這對(duì)于一個(gè)較小的數(shù)據(jù)集而言并不壞猪杭。

七餐塘、對(duì)新樣本進(jìn)行分類


使用estimatorpredict() 方法可以分類新的樣本。例如皂吮,以下有兩個(gè)新樣本戒傻。

帶分類樣本

使用predict()函數(shù)會(huì)返回一個(gè)dicts,它可以很簡(jiǎn)單的轉(zhuǎn)化為list蜂筹,下面的代碼檢索并打印出結(jié)果需纳。

# Classify two new flower samples
new_samples = np.array([[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype = np.float32)
predict_input_fn = tf.estimator.inputs.numpy_input_fn(
                    x = {"x": new_samples},
                    num_epochs = 1,
                    shuffle = False)
predictions = list(classifier.predict(input_fn = predict_input_fn))
predicted_classes = [p["classes"] for p in predictions]

print "New Samples, Class Preditions: {}\n".format(predicted_classes)

所得的結(jié)果如下:

New Samples, Class Predictions: [1 2]

因此該模型預(yù)測(cè)的第一個(gè)樣本為Iris versicolor,第二個(gè)樣本是Iris virginica狂票。

八候齿、其他資源


最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末灭必,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子乃摹,更是在濱河造成了極大的恐慌禁漓,老刑警劉巖,帶你破解...
    沈念sama閱讀 211,817評(píng)論 6 492
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件孵睬,死亡現(xiàn)場(chǎng)離奇詭異播歼,居然都是意外死亡,警方通過(guò)查閱死者的電腦和手機(jī)掰读,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,329評(píng)論 3 385
  • 文/潘曉璐 我一進(jìn)店門秘狞,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái),“玉大人蹈集,你說(shuō)我怎么就攤上這事烁试。” “怎么了拢肆?”我有些...
    開(kāi)封第一講書人閱讀 157,354評(píng)論 0 348
  • 文/不壞的土叔 我叫張陵减响,是天一觀的道長(zhǎng)靖诗。 經(jīng)常有香客問(wèn)我,道長(zhǎng)辩蛋,這世上最難降的妖魔是什么呻畸? 我笑而不...
    開(kāi)封第一講書人閱讀 56,498評(píng)論 1 284
  • 正文 為了忘掉前任,我火速辦了婚禮悼院,結(jié)果婚禮上伤为,老公的妹妹穿的比我還像新娘。我一直安慰自己据途,他們只是感情好绞愚,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,600評(píng)論 6 386
  • 文/花漫 我一把揭開(kāi)白布。 她就那樣靜靜地躺著颖医,像睡著了一般位衩。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上熔萧,一...
    開(kāi)封第一講書人閱讀 49,829評(píng)論 1 290
  • 那天糖驴,我揣著相機(jī)與錄音,去河邊找鬼佛致。 笑死贮缕,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的俺榆。 我是一名探鬼主播感昼,決...
    沈念sama閱讀 38,979評(píng)論 3 408
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼罐脊!你這毒婦竟也來(lái)了定嗓?” 一聲冷哼從身側(cè)響起,我...
    開(kāi)封第一講書人閱讀 37,722評(píng)論 0 266
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤萍桌,失蹤者是張志新(化名)和其女友劉穎宵溅,沒(méi)想到半個(gè)月后,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體上炎,經(jīng)...
    沈念sama閱讀 44,189評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡层玲,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,519評(píng)論 2 327
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了反症。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 38,654評(píng)論 1 340
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡畔派,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情烦绳,我是刑警寧澤,帶...
    沈念sama閱讀 34,329評(píng)論 4 330
  • 正文 年R本政府宣布惧眠,位于F島的核電站暮顺,受9級(jí)特大地震影響捶码,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,940評(píng)論 3 313
  • 文/蒙蒙 一旬蟋、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦言秸、人聲如沸凳枝。這莊子的主人今日做“春日...
    開(kāi)封第一講書人閱讀 30,762評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)合是。三九已至,卻和暖如春锭环,著一層夾襖步出監(jiān)牢的瞬間聪全,已是汗流浹背。 一陣腳步聲響...
    開(kāi)封第一講書人閱讀 31,993評(píng)論 1 266
  • 我被黑心中介騙來(lái)泰國(guó)打工辅辩, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留难礼,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 46,382評(píng)論 2 360
  • 正文 我出身青樓玫锋,卻偏偏與公主長(zhǎng)得像蛾茉,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子撩鹿,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,543評(píng)論 2 349

推薦閱讀更多精彩內(nèi)容