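What's the point of training a neural network? Isn't it so that one day it can actually be put to use? But before you can use a trained network somewhere else, you first have to save its parameters (that is, its variables). How do you do that? TensorFlow already provides a convenient API for saving and restoring trained parameters. For the full details, see the rather obscure official API documentation; below I briefly describe my own understanding.
Both saving and restoring are handled by this class: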
class tf.train.Saver
When we need to save data, the following two lines are enough:
saver = tf.train.Saver()
save_path = saver.save(sess, model_path)
To explain: first create a Saver object, then call its save method. save takes two arguments: your training session, and the path to store the file at, for example "/tmp/superNet.ckpt" (the path may include the file name). save returns the path it saved to. save accepts other arguments as well, which I won't cover here.
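To make the two arguments and the return value concrete, here is a toy, pure-Python model of what save conceptually does: write every variable's value, keyed by its name, to the given path, then hand that path back. This is only an analogy, not TensorFlow's implementation (real checkpoints use TensorFlow's own binary format); `ToyVariable`, `toy_save`, and the JSON file are hypothetical names chosen for illustration.

```python
import json
import os
import tempfile


class ToyVariable:
    """Hypothetical stand-in for tf.Variable: a unique name plus a flat list of floats."""

    def __init__(self, name, value):
        self.name = name
        self.value = list(value)


def toy_save(variables, save_path):
    """Toy analogue of Saver.save(sess, save_path).

    Writes each variable's current value, keyed by its name, to save_path
    (which may include a file name) and returns the path that was written.
    """
    data = {v.name: v.value for v in variables}
    with open(save_path, "w") as f:
        json.dump(data, f)
    return save_path


# Usage: "train" two variables, then save them into a temporary directory.
tmp_dir = tempfile.mkdtemp()
w = ToyVariable("W_fc1", [0.1, -0.2, 0.3])
b = ToyVariable("b_fc1", [0.0, 0.0, 0.0])
model_path = os.path.join(tmp_dir, "superNet.json")
returned_path = toy_save([w, b], model_path)
print(returned_path == model_path)  # True
```

The key design point carried over from the real API is that the file stores values by name, which is what lets a different program find them again later.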
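So how do we read the data back? See below:
saver = tf.train.Saver()
saver.restore(sess, model_path)
Almost identical to saving! The one difference: restore does not return a path (it returns None), so there is no point in assigning its result.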
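The restore side of the same toy model is sketched below. Again this is a pure-Python analogy, not TensorFlow code: `ToyVariable`, `toy_restore`, and the JSON checkpoint are hypothetical. It shows the two behaviours described above: values are copied into variables that already exist, and the call returns None, just as Saver.restore does.

```python
import json
import os
import tempfile


class ToyVariable:
    """Hypothetical stand-in for tf.Variable: a name plus a value (None = uninitialized)."""

    def __init__(self, name, value=None):
        self.name = name
        self.value = value


def toy_restore(variables, save_path):
    """Toy analogue of Saver.restore(sess, save_path).

    Reads the saved name -> value mapping and assigns each value into the
    already-declared variable with the same name. Like Saver.restore, it
    returns None, so there is no path to capture.
    """
    with open(save_path) as f:
        data = json.load(f)
    for v in variables:
        v.value = data[v.name]  # KeyError if the checkpoint lacks this name
    return None


# Usage: write a small checkpoint by hand, then restore it into fresh variables.
model_path = os.path.join(tempfile.mkdtemp(), "superNet.json")
with open(model_path, "w") as f:
    json.dump({"W_fc1": [0.1, -0.2, 0.3], "b_fc1": [0.0, 0.0, 0.0]}, f)

w = ToyVariable("W_fc1")
b = ToyVariable("b_fc1")
result = toy_restore([w, b], model_path)
print(result)   # None
print(w.value)  # [0.1, -0.2, 0.3]
```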
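Now for the important part: a few lessons I learned using tf.train.Saver()!
- 1. When restore reads the data, it only reads the values; there must be variables declared in advance to receive them. So before restoring into sess, you have to declare variables that match the saved data (by default Saver matches them by variable name, and the shapes must agree); otherwise the program raises an error.
- 2. Variables loaded by restore do not need to be initialized, because restore itself assigns their values.
- 3. That's all I can think of for now; I'll add more as they come up.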
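Both tips can be illustrated with the same kind of toy model. This is a pure-Python analogy, not TensorFlow code; `ToyVariable`, `toy_restore`, and the JSON file are hypothetical. In real TensorFlow the analogous failures typically surface as a `NotFoundError` (a declared variable is missing from the checkpoint) or an `InvalidArgumentError` (a shape mismatch).

```python
import json
import os
import tempfile


class ToyVariable:
    """Hypothetical stand-in for tf.Variable with a fixed length (its "shape")."""

    def __init__(self, name, length):
        self.name = name
        self.length = length
        self.value = None  # None means "not initialized yet"

    def read(self):
        # Mirrors TF raising an error when an uninitialized variable is used.
        if self.value is None:
            raise RuntimeError("attempting to use uninitialized variable " + self.name)
        return self.value


def toy_restore(variables, save_path):
    """Toy restore: every declared variable must exist in the file with the same length."""
    with open(save_path) as f:
        data = json.load(f)
    for v in variables:
        if v.name not in data:
            raise KeyError(v.name + " not found in checkpoint")
        if len(data[v.name]) != v.length:
            raise ValueError("shape mismatch for " + v.name)
        v.value = data[v.name]


model_path = os.path.join(tempfile.mkdtemp(), "simple_mnist.json")
with open(model_path, "w") as f:
    json.dump({"b_fc1": [0.5, 0.5, 0.5]}, f)

# Tip 2: a matching variable can be read right after restore, with no init step.
b_ok = ToyVariable("b_fc1", 3)
toy_restore([b_ok], model_path)
print(b_ok.read())  # [0.5, 0.5, 0.5]

# Tip 1: a declared variable whose shape differs from the saved data fails.
b_bad = ToyVariable("b_fc1", 30)
try:
    toy_restore([b_bad], model_path)
except ValueError as err:
    print(err)  # shape mismatch for b_fc1
```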
To get a more concrete feel for saving and restoring, I wrote two small experiment programs. The first one trains a network on the MNIST dataset and saves its parameters:
import sys

import tensorflow as tf  # TF 1.x API (tf.placeholder / tf.Session)
# load MNIST data
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data', one_hot=True)

# some hyper parameters
activation = tf.nn.relu
batch_size = 100
iteration = 20000
hidden1_units = 30
# Note! This is the save path (sys.path[0] is the directory of the running script)
model_path = sys.path[0] + '/simple_mnist.ckpt'

X = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])

W_fc1 = tf.Variable(tf.truncated_normal([784, hidden1_units], stddev=0.2))
b_fc1 = tf.Variable(tf.zeros([hidden1_units]))
W_fc2 = tf.Variable(tf.truncated_normal([hidden1_units, 10], stddev=0.2))
b_fc2 = tf.Variable(tf.zeros([10]))

def inference(img):
    fc1 = activation(tf.nn.bias_add(tf.matmul(img, W_fc1), b_fc1))
    logits = tf.nn.bias_add(tf.matmul(fc1, W_fc2), b_fc2)
    return logits

def loss(logits, labels):
    # Since TF 1.0 this function must be called with named arguments
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_mean(cross_entropy)

def evaluation(logits, labels):
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    return accuracy

logits = inference(X)
loss_op = loss(logits, y_)
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss_op)
accuracy = evaluation(logits, y_)

# First instantiate a Saver()
saver = tf.train.Saver()
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for i in range(iteration):
        batch = mnist.train.next_batch(batch_size)
        if i % 1000 == 0 and i:
            train_accuracy = sess.run(accuracy, feed_dict={X: batch[0], y_: batch[1]})
            print("step %d, train accuracy %g" % (i, train_accuracy))
        sess.run(train_op, feed_dict={X: batch[0], y_: batch[1]})
    print('[+] Test accuracy is %f' % sess.run(accuracy, feed_dict={X: mnist.test.images, y_: mnist.test.labels}))
    # Save the trained variables
    save_path = saver.save(sess, model_path)
    print("[+] Model saved in file: %s" % save_path)
Next, restore the data and run a test!
import sys

import tensorflow as tf  # TF 1.x API (tf.placeholder / tf.Session)
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data', one_hot=True)

activation = tf.nn.relu
hidden1_units = 30
model_path = sys.path[0] + '/simple_mnist.ckpt'

X = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])

# Declare the same variables, in the same order, as the training script, so their
# default names (Variable, Variable_1, ...) match the names stored in the checkpoint.
# Their initializers are never run: restore overwrites the values.
W_fc1 = tf.Variable(tf.truncated_normal([784, hidden1_units], stddev=0.2))
b_fc1 = tf.Variable(tf.zeros([hidden1_units]))
W_fc2 = tf.Variable(tf.truncated_normal([hidden1_units, 10], stddev=0.2))
b_fc2 = tf.Variable(tf.zeros([10]))

def inference(img):
    fc1 = activation(tf.nn.bias_add(tf.matmul(img, W_fc1), b_fc1))
    logits = tf.nn.bias_add(tf.matmul(fc1, W_fc2), b_fc2)
    return logits

def evaluation(logits, labels):
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    return accuracy

logits = inference(X)
accuracy = evaluation(logits, y_)

saver = tf.train.Saver()

with tf.Session() as sess:
    # Restore the previously trained variables (no initializer needed)
    saver.restore(sess, model_path)  # restore() returns None, so print model_path instead
    print("[+] Model restored from %s" % model_path)
    print('[+] Test accuracy is %f' % sess.run(accuracy, feed_dict={X: mnist.test.images, y_: mnist.test.labels}))