#coding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_lenet5_forward
import os
import numpy as np
# ---- Training hyper-parameters -------------------------------------------
BATCH_SIZE = 100  # number of images fed to the network per training step
LEARNING_RATE_BASE = 0.005  # initial learning rate
LEARNING_RATE_DECAY = 0.99  # multiplicative learning-rate decay factor
REGULARIZER = 0.0001  # regularization weight passed to mnist_lenet5_forward.forward
STEPS = 50000  # number of training iterations per run
MOVING_AVERAGE_DECAY = 0.99  # decay of the exponential moving average of weights
# ---- Checkpointing -------------------------------------------------------
MODEL_SAVE_PATH = "./model/"  # directory where checkpoints are written/read
MODEL_NAME = "mnist_model"  # checkpoint file-name prefix inside MODEL_SAVE_PATH
def backward(mnist):
    """Train the LeNet-5 model on MNIST and periodically checkpoint it.

    Builds the training graph on top of ``mnist_lenet5_forward.forward``:
    mean cross-entropy plus the terms in the ``'losses'`` collection, an
    exponentially decayed learning rate, plain gradient descent, and an
    exponential moving average of all trainable variables. If a checkpoint
    exists in ``MODEL_SAVE_PATH`` it is restored first. Then ``STEPS``
    training iterations are run, and a checkpoint is saved every 100
    iterations.

    Args:
        mnist: dataset object as returned by
            ``input_data.read_data_sets(..., one_hot=True)``; this function
            only uses ``mnist.train.num_examples`` and
            ``mnist.train.next_batch``.
    """
    # Shape of one input batch: (batch, rows, cols, channels).
    # Shared by the placeholder and the per-step reshape so they cannot drift.
    image_shape = (
        BATCH_SIZE,
        mnist_lenet5_forward.IMAGE_SIZE,
        mnist_lenet5_forward.IMAGE_SIZE,
        mnist_lenet5_forward.NUM_CHANNELS,
    )
    x = tf.placeholder(tf.float32, list(image_shape))
    y_ = tf.placeholder(tf.float32, [None, mnist_lenet5_forward.OUTPUT_NODE])
    # Forward pass; the second argument is True here (training graph).
    y = mnist_lenet5_forward.forward(x, True, REGULARIZER)
    # Step counter, incremented by the optimizer; not trained itself.
    global_step = tf.Variable(0, trainable=False)

    # Labels are one-hot, so argmax turns them into the class indices that the
    # sparse cross-entropy op expects.
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cem = tf.reduce_mean(ce)  # mean cross-entropy over the batch
    # Assumes forward() adds its regularization terms to the 'losses'
    # collection when given REGULARIZER -- TODO confirm in mnist_lenet5_forward.
    loss = cem + tf.add_n(tf.get_collection('losses'))

    # Multiply the learning rate by LEARNING_RATE_DECAY every
    # num_examples / BATCH_SIZE steps (one pass over the training set);
    # staircase=True makes the decay step-wise instead of continuous.
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # Keep moving-average shadow copies of every trainable variable.
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())
    # A single op that runs both the gradient step and the EMA update.
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()
    # Saver.save fails when the parent directory is missing, so create it.
    if not os.path.exists(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH)
    checkpoint_prefix = os.path.join(MODEL_SAVE_PATH, MODEL_NAME)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Resume from the most recent checkpoint, if one exists.
        ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        for i in range(STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            # Reshape the batch to the 4-D shape of placeholder x
            # (presumably next_batch returns flattened images -- verify).
            reshaped_xs = np.reshape(xs, image_shape)
            _, loss_value, step = sess.run(
                [train_op, loss, global_step],
                feed_dict={x: reshaped_xs, y_: ys})
            if i % 100 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                saver.save(sess, checkpoint_prefix, global_step=global_step)
def main():
    """Load MNIST from ./data/ with one-hot labels and train on it."""
    backward(input_data.read_data_sets("./data/", one_hot=True))
if __name__ == '__main__':
    # Start training only when run as a script, not when imported.
    main()