學(xué)習(xí)資料:
https://www.tensorflow.org/get_started/tflearn
今天學(xué)習(xí)用 tf.contrib.learn 來(lái)建立 DNN 對(duì) Iris 數(shù)據(jù)集進(jìn)行分類(lèi).
問(wèn)題:
我們有 Iris 數(shù)據(jù)集,它包含150個(gè)樣本數(shù)據(jù),分別來(lái)自三個(gè)品種腊状,每個(gè)品種有50個(gè)樣本屡穗,每個(gè)樣本具有四個(gè)特征伍派,以及它屬于哪一類(lèi)谣拣,分別由 0鲫竞,1,2 代表三個(gè)品種智绸。
我們將這150個(gè)樣本分為兩份野揪,一份是訓(xùn)練集具有120個(gè)樣本,另一份是測(cè)試集具有30個(gè)樣本瞧栗。
我們要做的就是建立一個(gè)神經(jīng)網(wǎng)絡(luò)分類(lèi)模型對(duì)每個(gè)樣本進(jìn)行分類(lèi)斯稳,識(shí)別它是哪個(gè)品種。
一共有 5 步:
- 導(dǎo)入 CSV 格式的數(shù)據(jù)集
- 建立神經(jīng)網(wǎng)絡(luò)分類(lèi)模型
- 用訓(xùn)練數(shù)據(jù)集訓(xùn)練模型
- 評(píng)價(jià)模型的準(zhǔn)確率
- 對(duì)新樣本數(shù)據(jù)進(jìn)行分類(lèi)
代碼:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import urllib
import numpy as np
import tensorflow as tf
# Data sets
IRIS_TRAINING = "iris_training.csv"
IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv"
IRIS_TEST = "iris_test.csv"
IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"
def main():
# If the training and test sets aren't stored locally, download them.
if not os.path.exists(IRIS_TRAINING):
raw = urllib.urlopen(IRIS_TRAINING_URL).read()
with open(IRIS_TRAINING, "w") as f:
f.write(raw)
if not os.path.exists(IRIS_TEST):
raw = urllib.urlopen(IRIS_TEST_URL).read()
with open(IRIS_TEST, "w") as f:
f.write(raw)
# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=IRIS_TRAINING,
target_dtype=np.int,
features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=IRIS_TEST,
target_dtype=np.int,
features_dtype=np.float32)
# Specify that all features have real-value data
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
# Build 3 layer DNN with 10, 20, 10 units respectively.
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
hidden_units=[10, 20, 10],
n_classes=3,
model_dir="/tmp/iris_model")
# Define the training inputs
def get_train_inputs():
x = tf.constant(training_set.data)
y = tf.constant(training_set.target)
return x, y
# Fit model.
classifier.fit(input_fn=get_train_inputs, steps=2000)
# Define the test inputs
def get_test_inputs():
x = tf.constant(test_set.data)
y = tf.constant(test_set.target)
return x, y
# Evaluate accuracy.
accuracy_score = classifier.evaluate(input_fn=get_test_inputs,
steps=1)["accuracy"]
print("\nTest Accuracy: {0:f}\n".format(accuracy_score))
# Classify two new flower samples.
def new_samples():
return np.array(
[[6.4, 3.2, 4.5, 1.5],
[5.8, 3.1, 5.0, 1.7]], dtype=np.float32)
predictions = list(classifier.predict(input_fn=new_samples))
print(
"New Samples, Class Predictions: {}\n"
.format(predictions))
if __name__ == "__main__":
main()
從代碼可以看出很簡(jiǎn)短的幾行就可以完成之前學(xué)過(guò)的很長(zhǎng)的代碼所做的事情迹恐,用起來(lái)和用 sklearn 相似挣惰。
關(guān)于 tf.contrib.learn
可以查看:
https://www.tensorflow.org/api_guides/python/contrib.learn
可以看到里面也有 kmeans,logistic殴边,linear
等模型:
在上面的代碼中:
- 用
tf.contrib.learn.datasets.base.load_csv_with_header
可以導(dǎo)入 CSV 數(shù)據(jù)集憎茂。 - 分類(lèi)器模型只需要一行代碼,就可以設(shè)置這個(gè)模型具有多少隱藏層锤岸,每個(gè)隱藏層有多少神經(jīng)元竖幔,以及最后分為幾類(lèi)。
- 模型的訓(xùn)練也是只需要一行代碼,輸入指定的數(shù)據(jù)是偷,包括特征和標(biāo)簽拳氢,再指定迭代的次數(shù),就可以進(jìn)行訓(xùn)練蛋铆。
- 獲得準(zhǔn)確率也同樣很簡(jiǎn)單,只需要輸入測(cè)試集,調(diào)用 evaluate馋评。
- 預(yù)測(cè)新的數(shù)據(jù)集,只需要把新的樣本數(shù)據(jù)傳遞給 predict。
關(guān)于代碼里幾個(gè)新的方法:
1. load_csv_with_header()
:
用于導(dǎo)入 CSV刺啦,需要三個(gè)必需的參數(shù):
- filename栗恩,CSV文件的路徑
- target_dtype,數(shù)據(jù)集的目標(biāo)值的numpy數(shù)據(jù)類(lèi)型洪燥。
- features_dtype磕秤,數(shù)據(jù)集的特征值的numpy數(shù)據(jù)類(lèi)型。
在這里捧韵,target 是花的品種市咆,它是一個(gè)從 0-2 的整數(shù),所以對(duì)應(yīng)的numpy數(shù)據(jù)類(lèi)型是np.int
2. tf.contrib.layers.real_valued_column
:
所有的特征數(shù)據(jù)都是連續(xù)的再来,因此用 tf.contrib.layers.real_valued_column,數(shù)據(jù)集中有四個(gè)特征(萼片寬度蒙兰,萼片高度,花瓣寬度和花瓣高度)芒篷,因此 dimension=4 搜变。
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
3. DNNClassifier:
-
feature_columns=feature_columns
, 上面定義的一組特征 -
hidden_units=[10, 20, 10]
,三個(gè)隱藏層分別包含10,20针炉,10個(gè)神經(jīng)元挠他。 -
n_classes=3
,三個(gè)目標(biāo)類(lèi),代表三個(gè) Iris 品種篡帕。 -
model_dir=/tmp/iris_model
,TensorFlow在模型訓(xùn)練期間將保存 checkpoint data殖侵。
在后面會(huì)學(xué)到關(guān)于 TensorFlow 的 logging and monitoring 的章節(jié)贸呢,可以 track 一下訓(xùn)練中的模型: “Logging and Monitoring Basics with tf.contrib.learn”。
推薦閱讀 歷史技術(shù)博文鏈接匯總
http://www.reibang.com/p/28f02bb59fe5
也許可以找到你想要的