class tf.train.Saver
# Saving: a default Saver covers all saveable variables in the graph;
# save() returns the path of the checkpoint it wrote.
saver = tf.train.Saver()
save_path = saver.save(sess, model_path)

# Restoring: matching variables must already be declared in the graph.
# restore() returns None, so there is no return value worth capturing.
saver = tf.train.Saver()
saver.restore(sess, model_path)
- 1、restore方法在實現數據讀取時,它僅僅讀數據,關鍵是得有一些提前聲明好的variables來接收這些數據,因此,當restore讀取數據到sess時,需要提前聲明與數據匹配(名稱和shape一致)的variables,否則程序就報錯了。
- 2、用restore讀取的variables不需要再initialize(restore本身就完成了賦值)。
- 3、目前想到的就這麼多,隨時補充。
import os
import sys

import tensorflow as tf
# Load MNIST from the 'data' directory, with labels one-hot encoded.
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('data', one_hot=True)

# Hyperparameters.
activation = tf.nn.relu
batch_size = 100
iteration = 20000
hidden1_units = 30

# NOTE: this is where the checkpoint is stored.
# sys.path[0] is the directory of the running script, or '' when there is no
# script (e.g. interactive use). os.path.join handles the '' case, writing to
# the current directory instead of the filesystem root ('/simple_mnist.ckpt'),
# which plain string concatenation would produce.
model_path = os.path.join(sys.path[0], 'simple_mnist.ckpt')
# Graph inputs: a batch of examples and their one-hot labels (one_hot=True
# above). The batch dimension is left open (None). 784 / 10 presumably are
# the flattened 28x28 MNIST pixels and the 10 digit classes.
X = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
# Parameters of the 784 -> hidden1_units -> 10 fully connected network.
# NOTE(review): these variables have no explicit `name=`, so the checkpoint
# keys them by TF's auto-generated names, which depend on creation order.
# The restore script declares the same four variables in the same order;
# keep the two in sync if either is edited.
W_fc1 = tf.Variable(tf.truncated_normal([784, hidden1_units], stddev=0.2))
b_fc1 = tf.Variable(tf.zeros([hidden1_units]))
W_fc2 = tf.Variable(tf.truncated_normal([hidden1_units, 10], stddev=0.2))
b_fc2 = tf.Variable(tf.zeros([10]))
def inference(img):
    """Build the forward pass of the two-layer fully connected network.

    Args:
        img: input tensor multiplied by ``W_fc1`` (presumably the ``X``
            placeholder of shape [None, 784]).

    Returns:
        Unscaled logits of the output layer; no softmax is applied here.
    """
    hidden_pre = tf.nn.bias_add(tf.matmul(img, W_fc1), b_fc1)
    hidden = activation(hidden_pre)
    return tf.nn.bias_add(tf.matmul(hidden, W_fc2), b_fc2)
def loss(logits, labels):
    """Return the mean softmax cross-entropy over the batch.

    Args:
        logits: unscaled network outputs (e.g. the result of ``inference``).
        labels: one-hot targets with the same shape as ``logits``.

    Returns:
        Scalar tensor: cross-entropy averaged over all examples.
    """
    # Keyword arguments are required: since TF 1.0 this op's first positional
    # parameter is a private sentinel, and positional logits/labels raise
    # ValueError. Keywords also work with the older (logits, labels) signature.
    per_example = tf.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    # Use a distinct local name so the function name `loss` is not shadowed.
    return tf.reduce_mean(per_example)
def evaluation(logits, labels):
    """Return the fraction of examples whose arg-max prediction is correct.

    Args:
        logits: network outputs; the arg-max along axis 1 is the prediction.
        labels: one-hot targets; the arg-max along axis 1 is the true class.

    Returns:
        Scalar float32 tensor with the mean accuracy over the batch.
    """
    predicted = tf.argmax(logits, 1)
    expected = tf.argmax(labels, 1)
    hits = tf.cast(tf.equal(predicted, expected), tf.float32)
    return tf.reduce_mean(hits)
logits = inference(X)
# Keep the tensor under its own name so the `loss` function is not shadowed.
loss_op = loss(logits, y_)
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss_op)
accuracy = evaluation(logits, y_)

# Instantiate the Saver first. Created after the model and optimizer, it
# covers every variable that exists at this point.
saver = tf.train.Saver()
# NOTE: initialize_all_variables is deprecated in later TF 1.x releases in
# favour of global_variables_initializer; kept for compatibility here.
init = tf.initialize_all_variables()

with tf.Session() as sess:
    # Variables must be initialized before the first training step. The
    # original created `init` but never ran it, so training failed at once.
    sess.run(init)
    for i in range(iteration):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        feed = {X: batch_xs, y_: batch_ys}
        # Report training accuracy every 1000 steps (skipping step 0).
        if i % 1000 == 0 and i:
            train_accuracy = sess.run(accuracy, feed_dict=feed)
            print("step %d, train accuracy %g" % (i, train_accuracy))
        sess.run(train_op, feed_dict=feed)

    print('[+] Test accuracy is %f' % sess.run(
        accuracy, feed_dict={X: mnist.test.images, y_: mnist.test.labels}))

    # Save the trained variables; save() returns the path it wrote.
    save_path = saver.save(sess, model_path)
    print("[+] Model saved in file: %s" % save_path)
import os
import sys

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST from the 'data' directory, with labels one-hot encoded.
mnist = input_data.read_data_sets('data', one_hot=True)

# These must match the architecture used at training time, because the
# restored variables' shapes depend on them.
activation = tf.nn.relu
hidden1_units = 30

# Checkpoint written by the training script. Both scripts build the path from
# sys.path[0] (the running script's directory, or '' when there is none), so
# they are expected to live in the same directory. os.path.join avoids
# producing '/simple_mnist.ckpt' at the filesystem root when sys.path[0] == ''.
model_path = os.path.join(sys.path[0], 'simple_mnist.ckpt')
# Graph inputs, identical to the training script: a batch of examples and
# their one-hot labels, with the batch dimension left open (None).
X = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
# These variables only receive the restored values. Their initializers are
# never run: saver.restore assigns them from the checkpoint.
# NOTE(review): with no explicit `name=`, the names TF assigns depend on
# creation order, so the shapes and order here must match the training script
# exactly, or restore will fail to find matching entries.
W_fc1 = tf.Variable(tf.truncated_normal([784, hidden1_units], stddev=0.2))
b_fc1 = tf.Variable(tf.zeros([hidden1_units]))
W_fc2 = tf.Variable(tf.truncated_normal([hidden1_units, 10], stddev=0.2))
b_fc2 = tf.Variable(tf.zeros([10]))
def inference(img):
    """Rebuild the training-time forward pass on the restored variables.

    Args:
        img: input tensor multiplied by ``W_fc1`` (presumably the ``X``
            placeholder of shape [None, 784]).

    Returns:
        Unscaled logits of the output layer; no softmax is applied here.
    """
    hidden_pre = tf.nn.bias_add(tf.matmul(img, W_fc1), b_fc1)
    hidden = activation(hidden_pre)
    return tf.nn.bias_add(tf.matmul(hidden, W_fc2), b_fc2)
def evaluation(logits, labels):
    """Return the fraction of examples whose arg-max prediction is correct.

    Args:
        logits: network outputs; the arg-max along axis 1 is the prediction.
        labels: one-hot targets; the arg-max along axis 1 is the true class.

    Returns:
        Scalar float32 tensor with the mean accuracy over the batch.
    """
    predicted = tf.argmax(logits, 1)
    expected = tf.argmax(labels, 1)
    hits = tf.cast(tf.equal(predicted, expected), tf.float32)
    return tf.reduce_mean(hits)
logits = inference(X)
accuracy = evaluation(logits, y_)

# The Saver maps the variables declared above onto the checkpoint entries.
saver = tf.train.Saver()

with tf.Session() as sess:
    # Load the previously trained variables. restore() assigns their values,
    # so no initializer op is needed. restore() returns None, so report the
    # path that was passed in: the original printed "restored from None".
    saver.restore(sess, model_path)
    print("[+] Model restored from %s" % model_path)
    print('[+] Test accuracy is %f' % sess.run(
        accuracy, feed_dict={X: mnist.test.images, y_: mnist.test.labels}))