方法一：這種存儲方式在加載模型時需要再次定義網絡結構
模型訓練和存儲
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
# Method 1, training side: fit a softmax-regression classifier on MNIST and
# checkpoint it with tf.train.Saver.

# Load the MNIST data set with one-hot encoded labels.
mnist = input_data.read_data_sets("/home/devops/test/TensorFlowOnSpark/mnist/", one_hot=True)
print(mnist)

# Hyper-parameters.
learning_rate = 0.01
training_epochs = 5
batch_size = 100
display_step = 1

# Model: 784 input features -> 10 classes.
X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])
# W and b carry explicit names so that loading code can fetch them from the
# restored graph as "W:0" and "b:0".
W = tf.Variable(tf.zeros([784, 10]), name="W")
b = tf.Variable(tf.zeros([10]), name="b")
pred = tf.nn.softmax(tf.matmul(X, W) + b)
# Cross-entropy summed over classes, averaged over the batch.
cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(pred), reduction_indices=1))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()

# Set up the saver and the checkpoint location.
saver = tf.train.Saver(max_to_keep=4)
model_path = "./model/lr"
path = os.path.dirname(os.path.abspath(model_path))
# exist_ok avoids the separate isdir() check (and the race between check and create).
os.makedirs(path, exist_ok=True)

with tf.Session() as sess:
    sess.run(init)
    # Number of mini-batches per epoch; constant across epochs.
    total_batch = mnist.train.num_examples // batch_size
    for epoch in range(training_epochs):
        avg_cost = 0
        for _ in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs, Y: batch_ys})
            avg_cost += c / total_batch
        if (epoch + 1) % display_step == 0:
            print("Epoch:", "%04d" % (epoch + 1), "cost=", "{}".format(avg_cost))
        # No global_step is passed, so every save reuses the same "lr" prefix.
        saver.save(sess, model_path, write_meta_graph=True)
    print("Optimization Finished")

    # Evaluate on the first 3000 test images; eval() needs the session above
    # to still be the default session.
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy:", accuracy.eval({X: mnist.test.images[:3000], Y: mnist.test.labels[:3000]}))
加載模型
import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_data
# Method 1, loading side: restore the trained W and b from the checkpoint and
# rebuild the prediction op by hand on top of fresh placeholders.
mnist = input_data.read_data_sets("/home/devops/test/TensorFlowOnSpark/mnist/", one_hot=True)

# New input placeholders; the network structure is re-declared in this script.
X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])

meta_graph_file = "/home/devops/test/TensorFlowOnSpark/examples/mnist/my/curve/model/lr.meta"
checkpoint_dir = "/home/devops/test/TensorFlowOnSpark/examples/mnist/my/curve/model"

with tf.Session() as sess:
    # Import the saved graph, then load the trained values into it.
    saver = tf.train.import_meta_graph(meta_graph_file)
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    saver.restore(sess, latest)

    # Fetch the trained parameters by tensor name.
    graph = tf.get_default_graph()
    W = graph.get_tensor_by_name("W:0")
    b = graph.get_tensor_by_name("b:0")

    # Re-create the softmax-regression output and the accuracy metric.
    logits = tf.matmul(X, W) + b
    pred = tf.nn.softmax(logits)
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    test_feed = {X: mnist.test.images[:3000], Y: mnist.test.labels[:3000]}
    print("Accuracy:", accuracy.eval(test_feed))
方法二：這種存儲方式在加載模型時不用定義網絡結構
模型訓練和存儲
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
# Method 2, training side: same softmax-regression model, but the placeholders
# are named and the prediction op is registered in a collection so that a
# loader does not have to rebuild the network.

# Load the MNIST data set with one-hot encoded labels.
mnist = input_data.read_data_sets("/home/devops/test/TensorFlowOnSpark/mnist/", one_hot=True)
print(mnist)

# Hyper-parameters.
learning_rate = 0.01
training_epochs = 5
batch_size = 100
display_step = 1

# Model: 784 input features -> 10 classes. The placeholders are named "X" and
# "Y" so that loading code can look them up by name in the restored graph.
X = tf.placeholder(tf.float32, [None, 784], name="X")
Y = tf.placeholder(tf.float32, [None, 10], name="Y")
W = tf.Variable(tf.zeros([784, 10]), name="W")
b = tf.Variable(tf.zeros([10]), name="b")
pred = tf.nn.softmax(tf.matmul(X, W) + b)
# Cross-entropy summed over classes, averaged over the batch.
cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(pred), reduction_indices=1))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=4)

# Register the op that the loader will need, before the graph is saved.
tf.add_to_collection("pred", pred)

model_path = "./model/lr"
path = os.path.dirname(os.path.abspath(model_path))
# exist_ok avoids the separate isdir() check (and the race between check and create).
os.makedirs(path, exist_ok=True)

with tf.Session() as sess:
    sess.run(init)
    # Number of mini-batches per epoch; constant across epochs.
    total_batch = mnist.train.num_examples // batch_size
    for epoch in range(training_epochs):
        avg_cost = 0
        for _ in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs, Y: batch_ys})
            avg_cost += c / total_batch
        if (epoch + 1) % display_step == 0:
            print("Epoch:", "%04d" % (epoch + 1), "cost=", "{}".format(avg_cost))
        # No global_step is passed, so every save reuses the same "lr" prefix.
        saver.save(sess, model_path, write_meta_graph=True)
    print("Optimization Finished")

    # Evaluate on the first 3000 test images; eval() needs the session above
    # to still be the default session.
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy:", accuracy.eval({X: mnist.test.images[:3000], Y: mnist.test.labels[:3000]}))
模型加載
import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_data
# Method 2, loading side: everything (placeholders and the prediction op) is
# taken from the restored graph; no network structure is re-declared here.
mnist = input_data.read_data_sets("/home/devops/test/TensorFlowOnSpark/mnist/", one_hot=True)

meta_graph_file = "/home/devops/test/TensorFlowOnSpark/examples/mnist/my/curve/model/lr.meta"
checkpoint_dir = "/home/devops/test/TensorFlowOnSpark/examples/mnist/my/curve/model"

with tf.Session() as sess:
    # Import the saved graph, then load the trained values into it.
    saver = tf.train.import_meta_graph(meta_graph_file)
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    saver.restore(sess, latest)

    # The prediction op was stored under the "pred" collection at training time.
    pred = tf.get_collection("pred")[0]

    # The placeholders are found by the names given to them at training time.
    graph = tf.get_default_graph()
    X = graph.get_tensor_by_name("X:0")
    Y = graph.get_tensor_by_name("Y:0")

    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    test_feed = {X: mnist.test.images[:300], Y: mnist.test.labels[:300]}
    print("Accuracy:", accuracy.eval(test_feed))
1.Tensorflow模型文件的組成
主要包含兩個文件
- 元圖 meta graph
保存完整的圖結構，包含所有的變量、操作等，擴展名為 .meta
- 檢查點文件 checkpoint
二進制文件，包含所有的權重、偏差、梯度和其他所有保存的值，擴展名是 .ckpt；0.11 版本之后不再僅使用一個 .ckpt 文件來表示，而是兩個文件：.data-00000-of-00001 和 .index
其中 .data 是包含訓練變量的文件
此外還有一個名為 checkpoint 的文件，用于保存最新檢查點的記錄
2.如何保存Tensorflow模型
在模型訓練完成后，可調用 tf.train.Saver() 實例來保存所有的參數和計算圖
由于 tensorflow 中的變量只能存在于 session 中，因此需要在 session 中調用 save 將模型存儲
import tensorflow as tf
# Minimal save example: create two named variables and write them to disk.
# `name` is an argument of tf.Variable (not of tf.random_normal), so the
# variables can later be fetched from a restored graph as "w1:0" / "w2:0".
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
# With no arguments the Saver handles every variable in the graph.
saver = tf.train.Saver()
with tf.Session() as sess:
    # Variables must be initialised inside the session before they can be saved.
    sess.run(tf.global_variables_initializer())
    saver.save(sess, '/path/to/save')
運行后可得以下文件:
model/
├── checkpoint
├── my_test_model.data-00000-of-00001
├── my_test_model.index
└── my_test_model.meta
如果想在1000次迭代之后再保存模型，可通過傳遞步數來調用save
# Passing global_step appends the step number to the checkpoint file names.
saver.save(sess, 'model_path', global_step=1000)
如果想每1000次保存一下模型，由于 .meta 文件會在第一次保存時創建，而且圖結構不會再變化，因此只需要保存模型進一步迭代的數據，而不用存儲網絡結構，可調用
# write_meta_graph=False skips rewriting the .meta graph file on this save.
saver.save(sess, 'model_path', global_step=step, write_meta_graph=False)
如果只想保存最新的4個模型參數，并且希望在訓練階段每兩小時保存一個模型，可調用
# Retention policy: keep the 4 most recent checkpoints, and additionally keep
# one checkpoint for every 2 hours of training.
saver = tf.train.Saver(
    max_to_keep=4,
    keep_checkpoint_every_n_hours=2,
)
如果在 tf.train.Saver() 中沒有指定任何東西，那么它會保存模型的所有變量；如果只想保存部分變量，則需要通過列表或字典的形式將變量傳遞進去
import tensorflow as tf
# Save only selected variables by handing them to the Saver explicitly.
# `name` is an argument of tf.Variable (not of tf.random_normal).
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
# Only the variables listed here are written to the checkpoint.
saver = tf.train.Saver([w1, w2])
with tf.Session() as sess:
    # A single initialiser call is enough; the API is
    # tf.global_variables_initializer().
    sess.run(tf.global_variables_initializer())
    saver.save(sess, '/path/to/save')
3.如何導入一個訓練好的模型并進行修改和微調
需要完成兩件事情
1.構建網絡結構
可通過手動編寫代碼創建每一層網絡結構來重構整個網絡
保存模型時也會將網絡結構存儲到 .meta 文件中，可直接調用 tf.train.import_meta_graph() 函數來導入這個模型
saver = tf.train.import_meta_graph('model-1000.meta') 這個操作是將 .meta 文件中的計算圖數據直接附加到當前定義的圖中，但是我們仍然需要去加載計算圖上所有已經訓練好的權重參數
2.加載參數
new_saver.restore(sess, tf.train.latest_checkpoint('./'))，其中 './' 為 checkpoint 文件所在路徑
# Restore a saved model: import the graph structure from the .meta file, then
# load the trained values from the latest checkpoint in the current directory.
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))
    # Read a restored variable by its tensor name.
    print(sess.run('w1:0'))
4.恢復任何預先訓練好的模型用于預測 (工作中的開發方式)
import tensorflow as tf
# Build a tiny graph, (w1 + w2) * bias, run it once and save it.
# The placeholders are named "w1" and "w2" because the loading code fetches
# them from the restored graph as "w1:0" and "w2:0".
w1 = tf.placeholder('float', name='w1')
w2 = tf.placeholder('float', name='w2')
b1 = tf.Variable(2.0, name='bias')
w3 = tf.add(w1, w2)
# Named so that the loading code can fetch it as "op_to_restore:0".
w4 = tf.multiply(w3, b1, name='op_to_restore')
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Fetch the op itself: (4 + 8) * 2.0 -> 24.0
    print(sess.run(w4, feed_dict={w1: 4, w2: 8}))
    # Written with the step number appended, i.e. as test_model-1000.*
    saver.save(sess, 'test_model', global_step=1000)
當需要載入這個模型時，不僅需要恢復所有的計算圖和權重參數，還需要準備一個新的 feed_dict，
用于將新的訓練數據傳送到網絡中進行訓練。可通過 graph.get_tensor_by_name() 來獲得對這些保存的操作和占位符變量的引用
# Look up a saved placeholder and a saved op by "<name>:<output index>".
w1 = graph.get_tensor_by_name('w1:0')
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
使用不同的數據來運行相同的網絡，則需要通過 feed_dict 來傳遞數據
# Restore the saved graph and run the saved op on new input values.
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('test_model-1000.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))

    # Fetch the saved placeholders and build a new feed_dict for them.
    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")
    feed_dict = {w1: 13.0, w2: 17.0}

    # Fetch the saved op and evaluate it with the new inputs.
    op_to_restore = graph.get_tensor_by_name('op_to_restore:0')
    print(sess.run(op_to_restore, feed_dict))
如果想在原來的計算圖基礎上添加更多的操作和圖層，并進行訓練
import tensorflow as tf
# Restore the saved graph and stack a new operation on top of the saved op.
with tf.Session() as sess:
    # The graph that contains "op_to_restore" was saved under the prefix
    # 'test_model' with global_step=1000.
    saver = tf.train.import_meta_graph('test_model-1000.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))

    # Fetch the saved placeholders and build a feed_dict for them.
    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name('w1:0')
    w2 = graph.get_tensor_by_name('w2:0')
    feed_dict = {w1: 13, w2: 17}
    op_to_restore = graph.get_tensor_by_name('op_to_restore:0')

    # Newly added operation: doubles the output of the restored op.
    add_on_op = tf.multiply(op_to_restore, 2)
    print(sess.run(add_on_op, feed_dict))