TensorFlow 低階API派哲,手動訓練一個小型回歸模型侥猬。
定義數(shù)據(jù)
我們首先來定義一些輸入值 x笛园,以及每個輸入值的預期輸出值 y_true:
x = tf.constant([[1], [2], [3], [4]], dtype=tf.float32)
y_true = tf.constant([[0], [-1], [-2], [-3]], dtype=tf.float32)
定義模型
接下來懦胞,建立一個簡單的線性模型,其輸出值只有 1 個:
linear_model = tf.layers.Dense(units=1)
y_pred = linear_model(x)
您可以如下評估預測值:
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(y_pred))
該模型尚未接受訓練脊串,因此這里的‘預測’值并不理想辫呻。
損失
要優(yōu)化模型,您首先需要定義損失琼锋,我們將使用均方誤差印屁,這是回歸問題的標準損失。
loss = tf.losses.mean_squared_error(labels=y_true, predictions=y_pred)
print(sess.run(loss))
這里會生成一個損失值:
2.23962
訓練
TensorFlow 提供了執(zhí)行標準優(yōu)化算法的優(yōu)化器斩例。這些優(yōu)化器被實現(xiàn)為 tf.train.Optimizer 的子類雄人。它們會逐漸改變每個變量,以便將損失最小化。最簡單的優(yōu)化算法是梯度下降法础钠,由 tf.train.GradientDescentOptimizer 實現(xiàn)恰力。它會根據(jù)損失相對于變量的導數(shù)大小來修改各個變量。例如:
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
該代碼構(gòu)建了優(yōu)化所需的所有圖組件旗吁,并返回一個訓練指令踩萎。該訓練指令在運行時會更新圖中的變量。您可以按以下方式運行該指令:
for i in range(100):
_, loss_value = sess.run((train, loss))
print(loss_value)
由于 train 是一個指令而不是張量很钓,因此它在運行時不會返回一個值香府。為了查看訓練期間損失的進展,我們會同時運行損失張量码倦,生成如下所示的輸出值:
1.35659
1.00412
0.759167
0.588829
0.470264
0.387626
0.329918
0.289511
0.261112
0.241046
...
完整程序如下:
from __future__ import absolute_import, division, print_function
import tensorflow as tf
# TF 手動訓練一個小型回歸模型
# 1. 定義數(shù)據(jù)
x = tf.constant([[1], [2], [3], [4]],dtype = tf.float32)
y_true = tf.constant([[0], [-1], [-2], [-3]],dtype = tf.float32)
# 2. 定義模型
linear_model = tf.layers.Dense(units = 1) # 定義一個簡單的線性模型企孩,只有1個輸出值
y_pred = linear_model(x)
# 3. 損失
loss = tf.losses.mean_squared_error(labels = y_true, predictions = y_pred)
# 4. 訓練
optimizer = tf.train.GradientDescentOptimizer(0.01) # 學習率為0.01的梯度下降優(yōu)化器
train = optimizer.minimize(loss)
init = tf.global_variables_initializer() # 層包含的變量必須先初始化,然后才能使用
sess = tf.Session() # 創(chuàng)建會話:要評估張量袁稽,需要實例化一個 tf.Session 對象
sess.run(init)
for i in range(100):
_, loss_value = sess.run((train, loss)) # 執(zhí)行層
print('Loss at step {} : {:.3f}'.format(i,loss_value))
print(sess.run(y_pred))