TensorFlow編程框架
TensorFlow深度學習框架大致分為4層辩块,結構如下
在使用TensorFlow做訓練模型的時候趴久,官方推薦使用API
Estimators:訓練模型API融蹂,它提供了訓練模型凸舵,評估模型和使用模型進行預測的方法
Datasets:數(shù)據(jù)集API疆前,它提供了獲取數(shù)據(jù)以及對訓練模型進行數(shù)據(jù)輸入的方法,它與Estimators能夠很好的協(xié)調工作
鳶尾花分類:概述
如下圖有三種鳶尾花余耽,分別是清風藤戴而、雜色鳶尾和維爾吉妮卡(這里翻譯不一定準砾肺,但不影響理解),通過萼片和花瓣的長度和寬度我們可以分辨出它們屬于哪個品種
數(shù)據(jù)集
鳶尾花數(shù)據(jù)集包含4個特征集和1個標簽集防嗡,如下:
特征集变汪,與鳶尾花的生物特征相關
⊙ sepal length?萼片長度
⊙ sepal width 萼片寬度
⊙ petal length 花瓣長度
⊙ petal width 花瓣寬度
標簽集,是鳶尾花的分類標識
⊙ Iris setosa (0)?清風藤
⊙ Iris versicolor (1)?雜色鳶尾
⊙ Iris virginica (2)?維爾吉妮卡
算法
深層神經網絡分類模型的算法圖如下:
⊙ 2層隱層
⊙ 每層有10個節(jié)點
推斷
通過訓練好的鳶尾花分類模型蚁趁,我們輸入一個未經過人工分類的鳶尾花特征數(shù)據(jù)裙盾,能得出類似如下的推斷結果:
⊙ 0.03 for Iris Setosa
⊙ 0.95 for Iris Versicolor
⊙ 0.02 for Iris Virginica
這就是通過模型識別為某種鳶尾花的概率,它們的和為1
Estimators的使用
Estimator是TensorFlow的高層訓練模型API他嫡,它屏蔽了數(shù)據(jù)初始化番官、日志、模型保存和恢復等細節(jié)钢属,令你可以專心訓練你的模型徘熔,使用默認的Estimator步驟如下:
⊙ 創(chuàng)建一個或多個輸入函數(shù)(input function)
⊙ 定義模型特征集
⊙ 實例化Estimator,并傳入特征集和超參數(shù)
⊙ 使用特定的輸入函數(shù)作為參數(shù)調用Estimator的方法
創(chuàng)建輸入函數(shù)
輸入函數(shù)為模型訓練淆党、模型評估和數(shù)據(jù)預測等操作提供數(shù)據(jù)輸入酷师,它的返回值一般是個二元組:
features元組:一個map
????⊙ key是特征的名字
? ??⊙ values是包含所有特征值的數(shù)組
labels元組:一個包含所有標簽值的數(shù)組
一個簡單的輸入函數(shù)實現(xiàn)如下:
def input_evaluation_set(): ? ?
????features = {'SepalLength': np.array([6.4, 5.0]), ? ? ? ? ? ? ? ?
? ? ? ? ? ? ? ? ? ? ? ?'SepalWidth': ?np.array([2.8, 2.3]), ? ? ? ? ? ? ? ?
? ? ? ? ? ? ? ? ? ? ? ?'PetalLength': np.array([5.6, 3.3]), ? ? ? ? ? ? ? ?
? ? ? ? ? ? ? ? ? ? ? ?'PetalWidth': ?np.array([2.2, 1.0])}? ?
????labels = np.array([2, 1]) ? ? return features, labels
TensorFlow建議使用Dataset API,它對解析一些數(shù)據(jù)輸入源非常有幫助染乌,其API層次如下:
Dataset:數(shù)據(jù)集API的基類山孔,包含創(chuàng)建和傳輸數(shù)據(jù)集的接口
TextLineDataset:從文本文件讀取數(shù)據(jù)集
TFRecordDataset:從TFRecord文件讀取數(shù)據(jù)集
FixedLengthRecordDataset:從二進制文件讀取數(shù)據(jù)集
Iterator:數(shù)據(jù)集迭代器,通過它可以遍歷整個數(shù)據(jù)集
一個使用Dataset API的輸入函數(shù)例子:
def train_input_fn(features, labels, batch_size):
????"""An input function for training"""
????# Convert the inputs to a Dataset.
????dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
????# Shuffle, repeat, and batch the examples.
????dataset =? dataset.shuffle(1000).repeat().batch(batch_size)
????# Build the Iterator, and return the read end of the pipeline.
????return dataset.make_one_shot_iterator().get_next()
定義特征列
特征列告訴Estimator都輸入哪些特征荷憋。
例如鳶尾花的分類台颠,有4種特征,生成特征列代碼如下:
my_feature_columns = []
for key in train_x.keys():
????my_feature_columns.append(tf.feature_column.numeric_column(key=key))
實例化Estimator
鳶尾花分類是個典型的分類問題勒庄,TensorFlow內置了幾種分類器Estimator模型:
⊙?tf.estimator.DNNClassifier:面向多類分類的深度學習模型
⊙?tf.estimator.DNNLinearCombinedClassifier:面向wide-n-deep模型
⊙?tf.estimator.LinearClassifier:面向線性分類模型
就鳶尾花分類問題蓉媳,最合適的是tf.estimator.DNNClassifier
# Build 2 hidden layer DNN with 10, 10 units respectively.
classifier = tf.estimator.DNNClassifier(
????feature_columns = my_feature_columns,
????# Two hidden layers of 10 nodes each.
????hidden_units = [10, 10],
????# The model must choose between 3 classes.
????n_classes = 3)
訓練、評估和預測
我們已經擁有了一個Estimator的實例锅铅,于是我們執(zhí)行以下步驟:
⊙ 訓練該模型
⊙ 評估訓練好的模型
⊙ 用訓練好的模型做預測
訓練模型
調用Estimator的train方法開始訓練模型:
# Train the Model.
classifier.train( ? ?
????input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),
????steps=args.train_steps)
這里我們用lambda表達式對我們的輸入函數(shù)做一個封裝,以捕獲輸入函數(shù)的參數(shù)减宣,steps參數(shù)告訴該方法訓練多少步后停止盐须。
評估訓練好的模型
模型訓練好后,我們需要評估它的準確性漆腌,以下是評估訓練模型的代碼片段:
# Evaluate the model.
eval_result = classifier.evaluate( ? ?
????input_fn=lambda:iris_data.eval_input_fn(test_x, test_y, args.batch_size))
print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
運行代碼輸出如下:
Test set accuracy: 0.967
使用訓練好的模型做預測
模型訓練好后贼邓,我們就可以用它來做預測了,我們輸入未經過分類的鳶尾花特征數(shù)據(jù)闷尿,然后調用predict做預測:
# Generate predictions from the model
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = { ? ?
????'SepalLength': [5.1, 5.9, 6.9], ? ?
????'SepalWidth': [3.3, 3.0, 3.1], ? ?
????'PetalLength': [1.7, 4.2, 5.4], ? ?
????'PetalWidth': [0.5, 1.5, 2.1],
}
predictions = classifier.predict(input_fn = lambda:iris_data.eval_input_fn(predict_x, batch_size = args.batch_size))
使用迭代器獲取predictions集合的數(shù)據(jù):
for pred_dict, expec in zip(predictions, expected): ? ?
????template = ('\nPrediction is "{}" ({:.1f}%), expected "{}"') ? ?
????class_id = pred_dict['class_ids'][0] ? ?
????probability = pred_dict['probabilities'][class_id] ? ?
????print(template.format(iris_data.SPECIES[class_id], 100 * probability, expec))
運行代碼輸出如下:
Prediction is "Setosa" (99.6%), expected "Setosa"
Prediction is "Versicolor" (99.8%), expected "Versicolor"
Prediction is "Virginica" (97.9%), expected "Virginica"