TensorFlow-4: tf.contrib.learn 快速入門(mén)

學(xué)習(xí)資料:
https://www.tensorflow.org/get_started/tflearn

相應(yīng)的中文翻譯:
http://studyai.site/2017/03/05/%E3%80%90Tensorflow%20r1.0%20%E6%96%87%E6%A1%A3%E7%BF%BB%E8%AF%91%E3%80%91%E3%80%90tf.contrib.learn%E5%BF%AB%E9%80%9F%E5%85%A5%E9%97%A8%E3%80%91/


今天學(xué)習(xí)用 tf.contrib.learn 來(lái)建立 DNN 對(duì) Iris 數(shù)據(jù)集進(jìn)行分類(lèi).

問(wèn)題:
我們有 Iris 數(shù)據(jù)集,它包含150個(gè)樣本數(shù)據(jù),分別來(lái)自三個(gè)品種腊状,每個(gè)品種有50個(gè)樣本屡穗,每個(gè)樣本具有四個(gè)特征伍派,以及它屬于哪一類(lèi)谣拣,分別由 0鲫竞,1,2 代表三個(gè)品種智绸。
我們將這150個(gè)樣本分為兩份野揪,一份是訓(xùn)練集具有120個(gè)樣本,另一份是測(cè)試集具有30個(gè)樣本瞧栗。
我們要做的就是建立一個(gè)神經(jīng)網(wǎng)絡(luò)分類(lèi)模型對(duì)每個(gè)樣本進(jìn)行分類(lèi)斯稳,識(shí)別它是哪個(gè)品種。

一共有 5 步:

  • 導(dǎo)入 CSV 格式的數(shù)據(jù)集
  • 建立神經(jīng)網(wǎng)絡(luò)分類(lèi)模型
  • 用訓(xùn)練數(shù)據(jù)集訓(xùn)練模型
  • 評(píng)價(jià)模型的準(zhǔn)確率
  • 對(duì)新樣本數(shù)據(jù)進(jìn)行分類(lèi)

代碼:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import urllib

import numpy as np
import tensorflow as tf

# Data sets
IRIS_TRAINING = "iris_training.csv"
IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"

IRIS_TEST = "iris_test.csv"
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

def main():
  # If the training and test sets aren't stored locally, download them.
  if not os.path.exists(IRIS_TRAINING):
    raw = urllib.urlopen(IRIS_TRAINING_URL).read()
    with open(IRIS_TRAINING, "w") as f:
      f.write(raw)

  if not os.path.exists(IRIS_TEST):
    raw = urllib.urlopen(IRIS_TEST_URL).read()
    with open(IRIS_TEST, "w") as f:
      f.write(raw)

  # Load datasets.
  training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
      filename=IRIS_TRAINING,
      target_dtype=np.int,
      features_dtype=np.float32)
  test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
      filename=IRIS_TEST,
      target_dtype=np.int,
      features_dtype=np.float32)

  # Specify that all features have real-value data
  feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

  # Build 3 layer DNN with 10, 20, 10 units respectively.
  classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                              hidden_units=[10, 20, 10],
                                              n_classes=3,
                                              model_dir="/tmp/iris_model")
  # Define the training inputs
  def get_train_inputs():
    x = tf.constant(training_set.data)
    y = tf.constant(training_set.target)

    return x, y

  # Fit model.
  classifier.fit(input_fn=get_train_inputs, steps=2000)

  # Define the test inputs
  def get_test_inputs():
    x = tf.constant(test_set.data)
    y = tf.constant(test_set.target)

    return x, y

  # Evaluate accuracy.
  accuracy_score = classifier.evaluate(input_fn=get_test_inputs,
                                       steps=1)["accuracy"]

  print("\nTest Accuracy: {0:f}\n".format(accuracy_score))

  # Classify two new flower samples.
  def new_samples():
    return np.array(
      [[6.4, 3.2, 4.5, 1.5],
       [5.8, 3.1, 5.0, 1.7]], dtype=np.float32)

  predictions = list(classifier.predict(input_fn=new_samples))

  print(
      "New Samples, Class Predictions:    {}\n"
      .format(predictions))

if __name__ == "__main__":
    main()

從代碼可以看出很簡(jiǎn)短的幾行就可以完成之前學(xué)過(guò)的很長(zhǎng)的代碼所做的事情迹恐,用起來(lái)和用 sklearn 相似挣惰。

關(guān)于 tf.contrib.learn 可以查看:
https://www.tensorflow.org/api_guides/python/contrib.learn

可以看到里面也有 kmeans,logistic殴边,linear 等模型:


在上面的代碼中:

  • tf.contrib.learn.datasets.base.load_csv_with_header 可以導(dǎo)入 CSV 數(shù)據(jù)集憎茂。
  • 分類(lèi)器模型只需要一行代碼,就可以設(shè)置這個(gè)模型具有多少隱藏層锤岸,每個(gè)隱藏層有多少神經(jīng)元竖幔,以及最后分為幾類(lèi)。
  • 模型的訓(xùn)練也是只需要一行代碼,輸入指定的數(shù)據(jù)是偷,包括特征和標(biāo)簽拳氢,再指定迭代的次數(shù),就可以進(jìn)行訓(xùn)練蛋铆。
  • 獲得準(zhǔn)確率也同樣很簡(jiǎn)單,只需要輸入測(cè)試集,調(diào)用 evaluate馋评。
  • 預(yù)測(cè)新的數(shù)據(jù)集,只需要把新的樣本數(shù)據(jù)傳遞給 predict。

關(guān)于代碼里幾個(gè)新的方法:

1. load_csv_with_header():

用于導(dǎo)入 CSV刺啦,需要三個(gè)必需的參數(shù):

  • filename栗恩,CSV文件的路徑
  • target_dtype,數(shù)據(jù)集的目標(biāo)值的numpy數(shù)據(jù)類(lèi)型洪燥。
  • features_dtype磕秤,數(shù)據(jù)集的特征值的numpy數(shù)據(jù)類(lèi)型。

在這里捧韵,target 是花的品種市咆,它是一個(gè)從 0-2 的整數(shù),所以對(duì)應(yīng)的numpy數(shù)據(jù)類(lèi)型是np.int

2. tf.contrib.layers.real_valued_column:

所有的特征數(shù)據(jù)都是連續(xù)的再来,因此用 tf.contrib.layers.real_valued_column,數(shù)據(jù)集中有四個(gè)特征(萼片寬度蒙兰,萼片高度,花瓣寬度和花瓣高度)芒篷,因此 dimension=4 搜变。

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

3. DNNClassifier:

  • feature_columns=feature_columns, 上面定義的一組特征
  • hidden_units=[10, 20, 10],三個(gè)隱藏層分別包含10,20针炉,10個(gè)神經(jīng)元挠他。
  • n_classes=3,三個(gè)目標(biāo)類(lèi),代表三個(gè) Iris 品種篡帕。
  • model_dir=/tmp/iris_model,TensorFlow在模型訓(xùn)練期間將保存 checkpoint data殖侵。

在后面會(huì)學(xué)到關(guān)于 TensorFlow 的 logging and monitoring 的章節(jié)贸呢,可以 track 一下訓(xùn)練中的模型: “Logging and Monitoring Basics with tf.contrib.learn”。


推薦閱讀 歷史技術(shù)博文鏈接匯總
http://www.reibang.com/p/28f02bb59fe5
也許可以找到你想要的

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末拢军,一起剝皮案震驚了整個(gè)濱河市楞陷,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌茉唉,老刑警劉巖固蛾,帶你破解...
    沈念sama閱讀 211,743評(píng)論 6 492
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異度陆,居然都是意外死亡魏铅,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,296評(píng)論 3 385
  • 文/潘曉璐 我一進(jìn)店門(mén)坚芜,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)览芳,“玉大人,你說(shuō)我怎么就攤上這事鸿竖〔拙梗” “怎么了?”我有些...
    開(kāi)封第一講書(shū)人閱讀 157,285評(píng)論 0 348
  • 文/不壞的土叔 我叫張陵缚忧,是天一觀的道長(zhǎng)悟泵。 經(jīng)常有香客問(wèn)我,道長(zhǎng)闪水,這世上最難降的妖魔是什么糕非? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 56,485評(píng)論 1 283
  • 正文 為了忘掉前任,我火速辦了婚禮球榆,結(jié)果婚禮上朽肥,老公的妹妹穿的比我還像新娘。我一直安慰自己持钉,他們只是感情好衡招,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,581評(píng)論 6 386
  • 文/花漫 我一把揭開(kāi)白布。 她就那樣靜靜地躺著每强,像睡著了一般始腾。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上空执,一...
    開(kāi)封第一講書(shū)人閱讀 49,821評(píng)論 1 290
  • 那天浪箭,我揣著相機(jī)與錄音,去河邊找鬼辨绊。 笑死奶栖,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播驼抹,決...
    沈念sama閱讀 38,960評(píng)論 3 408
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼桑孩,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼拜鹤!你這毒婦竟也來(lái)了框冀?” 一聲冷哼從身側(cè)響起,我...
    開(kāi)封第一講書(shū)人閱讀 37,719評(píng)論 0 266
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤敏簿,失蹤者是張志新(化名)和其女友劉穎明也,沒(méi)想到半個(gè)月后,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體惯裕,經(jīng)...
    沈念sama閱讀 44,186評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡温数,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,516評(píng)論 2 327
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了蜻势。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片撑刺。...
    茶點(diǎn)故事閱讀 38,650評(píng)論 1 340
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖握玛,靈堂內(nèi)的尸體忽然破棺而出够傍,到底是詐尸還是另有隱情,我是刑警寧澤挠铲,帶...
    沈念sama閱讀 34,329評(píng)論 4 330
  • 正文 年R本政府宣布冕屯,位于F島的核電站,受9級(jí)特大地震影響拂苹,放射性物質(zhì)發(fā)生泄漏安聘。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,936評(píng)論 3 313
  • 文/蒙蒙 一瓢棒、第九天 我趴在偏房一處隱蔽的房頂上張望浴韭。 院中可真熱鬧,春花似錦脯宿、人聲如沸囱桨。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 30,757評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)舍肠。三九已至,卻和暖如春窘面,著一層夾襖步出監(jiān)牢的瞬間翠语,已是汗流浹背。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 31,991評(píng)論 1 266
  • 我被黑心中介騙來(lái)泰國(guó)打工财边, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留肌括,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 46,370評(píng)論 2 360
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像谍夭,于是被迫代替她去往敵國(guó)和親黑滴。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,527評(píng)論 2 349

推薦閱讀更多精彩內(nèi)容