一文初探Tensorflow高級API使用(初學(xué)者篇)

筆記整理者:王小草
筆記整理時(shí)間:2017年2月26日
對應(yīng)的官方文檔地址:https://www.tensorflow.org/get_started/tflearn
官方文檔上次更新時(shí)間:2017年2月15日


今天我們要向Tensorflow高級API的學(xué)習(xí)門檻邁進(jìn)一步啤挎。別聽到高級API就覺得是難度高的意思蔗草,其實(shí)高級API恰恰是為了降低大家的編碼難度而設(shè)置的洒扎。Tensorflow更高層的API使得配置,訓(xùn)練,評估多種多樣的機(jī)器學(xué)習(xí)模型更簡單方便了。

本文將使用高層API:tf.contrib.learn 來構(gòu)建一個(gè)分類神經(jīng)網(wǎng)絡(luò),將它放在“鳶尾花數(shù)據(jù)集”上進(jìn)行訓(xùn)練,并且估計(jì)模型神僵,使得模型能根據(jù)特征(萼片和花瓣幾何形狀)預(yù)測出花的種類。

1. 加載鳶尾花數(shù)據(jù)集到Tensorflow上

首先介紹一下我們今天要使用的數(shù)據(jù)集:

鳶尾花數(shù)據(jù)集:Iris data set 由150個(gè)樣本組成覆劈。其中保礼,總共有3個(gè)類別:山鳶尾(Iris setosa)沛励,虹膜錦葵(Iris virginica),變色鳶尾 (Iris versicolor) 炮障,每個(gè)類別50個(gè)樣本目派。

下圖,從左到右分別是 Iris setosa , Iris versicolor, and Iris virginica三類花的圖片:


image_1b9suplqg7v91rva6qj1cda87u13.png-573.3kB

數(shù)據(jù)的每一行(也就是每個(gè)樣本)包含了樣本的特征與類別標(biāo)簽胁赢。
特征有:萼片的長度企蹭,萼片的寬度,花瓣的長度智末,花瓣的寬度谅摄。
類別標(biāo)簽用整型數(shù)字表示:0表示萼片,1表示Iris versicolor系馆,2表示Iris virginica
數(shù)據(jù)格式如下:


image_1b9sruij71dl3eco10cf1laa1bgm.png-34.3kB

在機(jī)器學(xué)習(xí)的建模中送漠,我們一般將數(shù)據(jù)集拆分成訓(xùn)練集與測試集,訓(xùn)練集用來訓(xùn)練模型由蘑,測試集用來測試模型的泛化能力闽寡。所以此處,也將150個(gè)樣本的數(shù)據(jù)集隨機(jī)地拆分成兩個(gè)部分:
(1)訓(xùn)練集包含120個(gè)樣本(放在iris_training.csv文件中)
(2)測試集包含30個(gè)樣本(放在iris_test.csv文件中)
在開始寫程序之前尼酿,要先下載好這兩個(gè)數(shù)據(jù)集哦~

現(xiàn)在我們已經(jīng)了解了數(shù)據(jù)集大概的樣子了爷狈,于是開始上代碼嘍~

首先,還是先導(dǎo)入要用的庫

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

接著裳擎,把下載好的訓(xùn)練集與測試集根據(jù)它們的路徑加載的dataset中淆院,使用的是learn.datasets.base中的load_csv_with_header()這個(gè)方法。這個(gè)方法需要傳入3個(gè)參數(shù):
(1)filename:文件路徑/文件名
(2)target_dtype:標(biāo)簽類別的數(shù)據(jù)類型
(3)features_dtype:特征的數(shù)據(jù)類型

# 定義數(shù)據(jù)集的路徑
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"

# 加載數(shù)據(jù)集
# # 加載訓(xùn)練集
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TRAINING,
    target_dtype=np.int,
    features_dtype=np.float32)

# # 加載測試集  
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TEST,
    target_dtype=np.int,
    features_dtype=np.float32)

注意句惯,加載建立后的Dataset是命名元組,可以使用training_set.data調(diào)用訓(xùn)練數(shù)據(jù)集的特征數(shù)據(jù)支救,使用training_set.target調(diào)用訓(xùn)練數(shù)據(jù)集的類別標(biāo)簽數(shù)據(jù)抢野。對test_set的測試數(shù)據(jù)集也是同理。

2. 構(gòu)建深度神經(jīng)網(wǎng)絡(luò)分類模型

tf.contrib.learn提供了多種多樣的預(yù)定義模型各墨,叫做Estimators(估計(jì)器)指孤,這些Estimator在你擬運(yùn)行訓(xùn)練與評估模型的操作的時(shí)候可以實(shí)現(xiàn)開箱即用,也就是說贬堵,當(dāng)你要使用某個(gè)模型的時(shí)候恃轩,不再需要去寫他的內(nèi)部邏輯,直接調(diào)用這個(gè)模型的接口黎做,用一句代碼搞定即可叉跛。

于是,這里我們就來使用tf.contrib.learn配置一個(gè)深層神經(jīng)網(wǎng)絡(luò)的分類模型蒸殿,只需要了了幾行代碼~

# Specify that all features have real-value data
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

# Build 3 layer DNN with 10, 20, 10 units respectively.
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=3,
                                            model_dir="/tmp/iris_model")

以上代碼首先定義了模型的特征列筷厘,并且指定了特征數(shù)據(jù)的數(shù)據(jù)類型鸣峭。在上一節(jié)中我們看到所有的特征都是連續(xù)型變量,所以tf.contrib.layers.real_valued_column這個(gè)函數(shù)被用來構(gòu)建特征列酥艳。另外摊溶,我們的數(shù)據(jù)集中有4個(gè)特征,故傳入?yún)?shù)dimension=4.

接著充石,以上代碼使用了tf.contrib.learn.DNNClassifier這個(gè)函數(shù)來直接構(gòu)建DNN模型莫换。(記得前面兩個(gè)筆記,無論是講簡單的分類模型softmax regression還是稍微復(fù)雜的卷積神經(jīng)網(wǎng)絡(luò)骤铃,都是自己一層一層地去寫模型的邏輯結(jié)構(gòu)拉岁,相當(dāng)繁瑣,看劲厌!高級的API已經(jīng)為我們封裝好了這些模型膛薛,我們只需要直接調(diào)用方法就行)
DNNClassifier這個(gè)方法需要傳入4個(gè)參數(shù):
(1)feature_columns=feature_columns,將剛剛預(yù)先定義好的特征列傳給參數(shù)feature_columns补鼻。
(2)hidden_units=[10, 20, 10]哄啄,設(shè)置隱藏層中的神經(jīng)元個(gè)數(shù),這里表示共有3個(gè)隱藏層风范,依次的神經(jīng)元個(gè)數(shù)為10,20,10咨跌。
(3)n_classes=3,設(shè)置目標(biāo)分類的個(gè)數(shù)硼婿,這個(gè)是3類锌半,分成3種鳶尾花。
(4)model_dir=/tmp/iris_model寇漫,這是保存模型訓(xùn)練過程中的checkpoint檢查點(diǎn)的數(shù)據(jù)的路徑

3. 模型擬合真實(shí)數(shù)據(jù)進(jìn)行訓(xùn)練

上面一步建立了一個(gè)模型刊殉,現(xiàn)在你可以將鳶尾花的訓(xùn)練數(shù)據(jù)集利用fit()這個(gè)方法來擬合進(jìn)模型。主要是通過傳入?yún)?shù)的方式州胳,將訓(xùn)練集中的特征傳給x,將訓(xùn)練集中的標(biāo)簽傳給y记焊,并且定義了訓(xùn)練的次數(shù)(比如這里是2000次):

# Fit model
classifier.fit(x=training_set.data, y=training_set.target, steps=2000)

注意的是,模型的狀態(tài)會在訓(xùn)練中被緩存在分類器中classifier栓撞,所以你可以按照自己的喜好來分開迭代遍膜,例如,上面代碼等同于下面兩句代碼:

classifier.fit(x=training_set.data, y=training_set.target, steps=1000)
classifier.fit(x=training_set.data, y=training_set.target, steps=1000)

4.評估模型的精度

第1步導(dǎo)入了數(shù)據(jù)瓤湘,第2步構(gòu)建了模型瓢颅,第3步在訓(xùn)練集上進(jìn)行了訓(xùn)練學(xué),現(xiàn)在第4步弛说,我們要去評估訓(xùn)練好的模型了挽懦。

評估模型的時(shí)候使用的是測試集,與.fit()方法相似木人,評估模型調(diào)用.evaluate()方法巾兆,并且將測試集的特征傳入給x猎物,測試集的標(biāo)簽傳入給y,并且指定計(jì)算的是accuracy角塑。

accuracy_score = classifier.evaluate(x=test_set.data, y=test_set.target)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))

運(yùn)行以上的所有代碼蔫磨,會打印出最后的精度:

Accuracy: 0.966667

每次訓(xùn)練的accuracy可能會有點(diǎn)不相同,但都應(yīng)該是在90%之上的哈~

5.預(yù)測新的數(shù)據(jù)

模型建好了圃伶,也通過了評估堤如,現(xiàn)在終于到了用武之時(shí)呢~我們要用模型與預(yù)測新的數(shù)據(jù)。


image_1b9tb7ha6oti156pplb23d128k1g.png-11.7kB

比如窒朋,現(xiàn)在新來了兩條未知的數(shù)據(jù)搀罢,至知道這兩朵花的4個(gè)特征,卻不知道它們的種類侥猩,于是調(diào)用.predict()方法進(jìn)行預(yù)測:

# 新的兩個(gè)樣本
new_samples = np.array(
    [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)

# 預(yù)測  
y = list(classifier.predict(new_samples, as_iterable=True))

# 打印
print('Predictions: {}'.format(str(y)))

.predict()返回的是一個(gè)數(shù)組榔至,預(yù)測的結(jié)果打印出來應(yīng)是如下,第一個(gè)樣本為1類欺劳,第二哥贗本為二類唧取。

Prediction: [1 2]


將以上代碼所有整合在一起如下:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

# Data sets
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"

# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TRAINING,
    target_dtype=np.int,
    features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TEST,
    target_dtype=np.int,
    features_dtype=np.float32)

# Specify that all features have real-value data
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

# Build 3 layer DNN with 10, 20, 10 units respectively.
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=3,
                                            model_dir="/tmp/iris_model")

# Fit model.
classifier.fit(x=training_set.data,
               y=training_set.target,
               steps=2000)

# Evaluate accuracy.
accuracy_score = classifier.evaluate(x=test_set.data,
                                     y=test_set.target)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))

# Classify two new flower samples.
new_samples = np.array(
    [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = list(classifier.predict(new_samples, as_iterable=True))
print('Predictions: {}'.format(str(y)))


tf.contrib.learn包括了各種類型的深度學(xué)習(xí)和機(jī)器學(xué)習(xí)的算法。它是從Tensorflow官方Scikit Flow直接遷移過來的划提,其使用的風(fēng)格與Scikit-learn相似(用python寫機(jī)器學(xué)習(xí)的小伙伴應(yīng)該很熟悉)枫弟。
從Tensorflowv0.9版本時(shí)候,tf.learn已經(jīng)能夠無縫與其他contrib模型結(jié)合起來使用啦~

原文: 一文初探Tensorflow高級API使用(初學(xué)者篇)

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末鹏往,一起剝皮案震驚了整個(gè)濱河市淡诗,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌伊履,老刑警劉巖韩容,帶你破解...
    沈念sama閱讀 217,277評論 6 503
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異唐瀑,居然都是意外死亡宙攻,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,689評論 3 393
  • 文/潘曉璐 我一進(jìn)店門介褥,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人递惋,你說我怎么就攤上這事柔滔。” “怎么了萍虽?”我有些...
    開封第一講書人閱讀 163,624評論 0 353
  • 文/不壞的土叔 我叫張陵睛廊,是天一觀的道長。 經(jīng)常有香客問我杉编,道長超全,這世上最難降的妖魔是什么咆霜? 我笑而不...
    開封第一講書人閱讀 58,356評論 1 293
  • 正文 為了忘掉前任,我火速辦了婚禮嘶朱,結(jié)果婚禮上蛾坯,老公的妹妹穿的比我還像新娘。我一直安慰自己疏遏,他們只是感情好脉课,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,402評論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著财异,像睡著了一般倘零。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上戳寸,一...
    開封第一講書人閱讀 51,292評論 1 301
  • 那天呈驶,我揣著相機(jī)與錄音,去河邊找鬼疫鹊。 笑死袖瞻,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的订晌。 我是一名探鬼主播虏辫,決...
    沈念sama閱讀 40,135評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼锈拨!你這毒婦竟也來了砌庄?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 38,992評論 0 275
  • 序言:老撾萬榮一對情侶失蹤奕枢,失蹤者是張志新(化名)和其女友劉穎娄昆,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體缝彬,經(jīng)...
    沈念sama閱讀 45,429評論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡萌焰,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,636評論 3 334
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了谷浅。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片扒俯。...
    茶點(diǎn)故事閱讀 39,785評論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖一疯,靈堂內(nèi)的尸體忽然破棺而出撼玄,到底是詐尸還是另有隱情,我是刑警寧澤墩邀,帶...
    沈念sama閱讀 35,492評論 5 345
  • 正文 年R本政府宣布掌猛,位于F島的核電站,受9級特大地震影響眉睹,放射性物質(zhì)發(fā)生泄漏荔茬。R本人自食惡果不足惜废膘,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,092評論 3 328
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望慕蔚。 院中可真熱鬧丐黄,春花似錦、人聲如沸坊萝。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,723評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽十偶。三九已至菩鲜,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間惦积,已是汗流浹背接校。 一陣腳步聲響...
    開封第一講書人閱讀 32,858評論 1 269
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留狮崩,地道東北人蛛勉。 一個(gè)月前我還...
    沈念sama閱讀 47,891評論 2 370
  • 正文 我出身青樓,卻偏偏與公主長得像睦柴,于是被迫代替她去往敵國和親诽凌。 傳聞我的和親對象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,713評論 2 354

推薦閱讀更多精彩內(nèi)容