今天寫測試程序的時(shí)候發(fā)現(xiàn)預(yù)測結(jié)果錯(cuò)到離譜系奉,眼看又要哭暈在廁所的我,又仔細(xì)檢查了一遍訓(xùn)練程序惭嚣,發(fā)現(xiàn)是模型保存錯(cuò)了 -_-||| 遵湖,把saver放在了循環(huán)的外面,這就很尷尬了晚吞。延旧。。改完又可以給自己放個(gè)小長假槽地,讓程序自己慢慢重跑一次吧啦啦啦小魔仙全身變迁沫。。捌蚊。
一集畅、模型的保存
分兩步。
1.在計(jì)算圖之后(所有變量節(jié)點(diǎn)都創(chuàng)建好之后)缅糟,定義一個(gè) saver 對象挺智。
2.開啟 Session ,利用 saver 保存模型溺拱。
首先逃贝,在定義計(jì)算圖之后,開啟會話之前迫摔,定義一個(gè) saver 對象沐扳。
saver = tf.train.Saver()
Saver 類在初始化時(shí),有一些常用的參數(shù):
- var_list 默認(rèn)為 None句占,即保存所有可保存的對象沪摄。
- reshape為 True 時(shí),表示從一個(gè) checkpoint 中恢復(fù)參數(shù)時(shí)允許參數(shù)shape發(fā)生變化。(當(dāng)我們r(jià)eshape了一個(gè)變量又希望加載舊模型時(shí)杨拐,該操作就很有用祈餐。)
- max_to_keep 自動(dòng)保存 max_to_keep 個(gè)模型,默認(rèn)值為 5哄陶。(也就是說帆阳,盡管程序每個(gè) step 保存一次模型,但實(shí)際上只會保存最近的5次屋吨。)
- keep_checkpoint_every_n_hours 用于指定保留 Checkpoints 文件的時(shí)間蜒谤,默認(rèn)為 10000 小時(shí)。
然后至扰,在開啟 Session 會話后鳍徽,利用 saver 保存模型:
# 開啟會話
with tf.Session() as sess:
sess.run(init)
***省略代碼***
#保存模型
# 注意:路徑最后一項(xiàng)是模型名字,加載時(shí)模型路徑應(yīng)該為‘save/model/’
saver.save(sess,'save/model/model',global_step=step)
- 第一個(gè)參數(shù) sess 是定義的會話敢课,記錄了這次訓(xùn)練中所有變量的值阶祭。
- 第二個(gè)參數(shù)是模型保存的路徑和名字。
- 第三個(gè)參數(shù)用于把訓(xùn)練時(shí)的迭代次數(shù)加入文件名直秆。
例如:
# 模型的文件名:my_model-1
saver.save(sess,'save/model',global_step=1)
# 模型的文件名:my_model-1000
saver.save(sess,'save/model',global_step=1000)
保存之前要記住濒募,saver自動(dòng)保存max_to_keep個(gè)模型(默認(rèn)為5個(gè)),多了也沒用圾结,會自動(dòng)忽略噠~
下面是幾種常用的使用情況:
使用1 每次迭代保存一個(gè)模型
for i in range(2000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
saver.save(sess, './model/model', global_step=i+1)
使用2 每100次迭代保存一個(gè)模型
# 一共迭代num_step次
for i in range(num_step):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
if i%100 == 0:
saver.save(sess, './model/model', global_step=i+1)
使用3 保存結(jié)果最好的模型
# 一共迭代num_step次
max_acc = 0
for i in range(num_step):
batch_xs, batch_ys = mnist.train.next_batch(100)
val_loss,val_acc=sess.run([loss,acc], feed_dict={x: batch_xs, y_: batch_ys})
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
if val_acc>max_acc:
max_acc = val_acc
saver.save(sess, './model/model', global_step=i+1)
使用4 保存結(jié)果最好的3個(gè)模型
saver = tf.train.Saver(max_to_keep=3)
***省略代碼***
# 一共迭代num_step次
max_acc = 0
for i in range(num_step):
batch_xs, batch_ys = mnist.train.next_batch(100)
val_loss,val_acc=sess.run([loss,acc], feed_dict={x: batch_xs, y_: batch_ys})
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
if val_acc>max_acc:
max_acc = val_acc
saver.save(sess, './model/model', global_step=i+1)
模型路徑下會出現(xiàn)4個(gè)文件:
checkpoint 保存目錄下所有模型的文件列表
.index / .data 保存模型所有參數(shù)
.meta 保存計(jì)算圖

二萨咳、模型的加載
模型恢復(fù)用的是restore(sess, save_path)
函數(shù),它需要兩個(gè)參數(shù)疫稿,sess表示當(dāng)前會話培他,之前保存的結(jié)果將被加載入這個(gè)會話,save_path指的是保存的模型路徑遗座。如:
# 加載模型參數(shù)
saver.restore(sess, "model/model-xxxx") # xxxx是指定的加載模型舀凛,注意這里不用加模型的后綴名
注意:這里只加載了模型的所有參數(shù),需要重新定義計(jì)算圖途蒋。如果不想重新定義計(jì)算圖猛遍,也可以直接加載持久化的計(jì)算圖:
# 加載計(jì)算圖
saver =tf.train.import_meta_graph("Model/model.ckpt.meta")
若不指定加載模型,可以直接獲得訓(xùn)練過程中最后保存的模型号坡,以下兩種方法可以實(shí)現(xiàn)獲得最近一次保存的模型:
獲得最近一次保存的模型 方法一
我們可以使用tf.train.latest_checkpoint()
函數(shù)來自動(dòng)獲取最后一次保存的模型懊烤。如:
model = tf.train.latest_checkpoint('model/') # 保存模型所在的路徑
print(model)
# ./model\model.ckpt-47557
saver.restore(sess,model)
獲得最近一次保存的模型 方法二
我們可以使用tf.train.get_checkpoint_state()
函數(shù)來自動(dòng)獲取最后一次保存的模型。如:
ckpt = tf.train.get_checkpoint_state('./model')
print(ckpt)
# model_checkpoint_path: "./model\\model.ckpt-47557"
# all_model_checkpoint_paths: "./model\\model.ckpt-40992"
# all_model_checkpoint_paths: "./model\\model.ckpt-45218"
# all_model_checkpoint_paths: "./model\\model.ckpt-47557"
print(ckpt.model_checkpoint_path)
# './model\\model.ckpt-47557'
saver.restore(sess, ckpt.model_checkpoint_path)
Reference
Tensorflow模型的保存與恢復(fù)
tensorflow模型保存與加載
TensorFlow模型保存和提取方法