Tensorflow上使用簡單神經(jīng)網(wǎng)絡
當我們回頭去看Tensorflow線性模型的簡單應用時候,你會發(fā)現(xiàn),他的模型如下:
這長得多像神經(jīng)網(wǎng)絡啊....沒錯,你可以理解這就是一個簡化版本的升級網(wǎng)絡.
看了一下之前的思路,我們只需要在模型選擇上動刀子就好.
1.引用庫
2.加載一系列的數(shù)字圖片
3.Tensorflow圖構(gòu)造
3.1 模型選擇
3.1.1 喂入數(shù)據(jù)準備
3.1.2 等待優(yōu)化的參數(shù)
3.1.3 構(gòu)造初步的模型
3.2 等待優(yōu)化的損失函數(shù)
3.3 創(chuàng)建優(yōu)化器
3.4 評價性能
4.Run
4.1 初始化變量
4.2 裝載數(shù)據(jù)源
4.3 開始run訓練模型
4.4 訓練之后,對模型進行評價
3.1 模型選擇
我們打算使用一個三層的神經(jīng)網(wǎng)絡,其中包含一個輸入層(具有784個節(jié)點,這是由于圖片數(shù)據(jù)是28*28=784),一個隱藏層(500個節(jié)點),輸出層(10個節(jié)點)
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500
3.1.1 喂入數(shù)據(jù)準備
這個步驟目前來看不需要變更
3.1.2 等待優(yōu)化的參數(shù)
我們需要優(yōu)化的參數(shù)是層和層之間的權重以及bias,
所以在 輸入層和hidden layer之間會有weight_1, bias1需要關注;hidden layer和輸出層之間是weight_2, bias2
但是升級網(wǎng)絡的擬合能力實在太強大了,為了防止參數(shù)過多,學習過了,出現(xiàn)過擬合的情況,我們對參數(shù)進行了正則化.所以我們按照下面的方式來創(chuàng)建權重
REGULARAZTION_RATE = 0.0001
def get_weight_variable(shape, regularizer):
weights = tf.get_variable("weights", shape,initializer=tf.truncated_normal_initializer(stddev=0.1)) #生成截斷正態(tài)分布的隨機數(shù),標準差為0.1
if regularizer != None:
tf.add_to_collection('losses', regularizer(weights))
return weights
3.1.3 構(gòu)造初步的模型
#define the forward network
def inference(input_tensor, regularizer):
with tf.variable_scope('layer1'):#聲明第一層神經(jīng)網(wǎng)絡的變量并完成前向傳播過程
weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases) #tf.nn.relu是作為激活函數(shù)
with tf.variable_scope('layer2'):#聲明第二層神經(jīng)網(wǎng)絡的變量并完成前向傳播過程
weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
layer2 = tf.matmul(layer1, weights) + biases
return layer2
regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
logits = inference(x, None) #這個地方暫時傳入None,不考慮正則化
Refer
tensorflow中的關鍵字global_step使用
什么是 L1 L2 正規(guī)化 正則化 Regularization (深度學習 deep learning)
使用TensorFlow實現(xiàn)的神經(jīng)網(wǎng)絡進行MNIST手寫體數(shù)字識別