saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
saver = tf.train.Saver()
默認(rèn)是保存默認(rèn)圖上的Variable數(shù)據(jù)。當(dāng)然也可以指定保存那些Variable數(shù)據(jù)也切,tf.train.Saver([var_list])
敢伸。
模型的加載
loader = tf.train.Saver()
loader.restore(sess,model_dir)
-
Saver
的第一個(gè)參數(shù)是var_list
用來指定需要存儲(chǔ)或者保存哪些變量胸私。如果不指定的話那么默認(rèn)保存和加載全部的可保存的對象酿傍。
v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')
# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})
# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
表示的意思是需要加載的變量是embedding
def setup_loader(self):
self.loader = tf.train.Saver(self.var_list)
def load_session(self, itr):
self.loader.restore(self.sess, self.model_name + "_weights/" + self.dataset + "/" + itr + ".ckpt")
-----------------------TransE model中的self.var_list---------------------
self.rel_emb = tf.get_variable(name="rel_emb", initializer=tf.random_uniform(shape=[self.num_rel, self.params.emb_size], minval=-sqrt_size, maxval=sqrt_size))
self.ent_emb = tf.get_variable(name="ent_emb", initializer=tf.random_uniform(shape=[self.num_ent, self.params.emb_size], minval=-sqrt_size, maxval=sqrt_size))
self.var_list = [self.rel_emb, self.ent_emb]
模型的保存
saver = tf.train.Saver(max_to_keep=0)
saver.save(self.sess, filename)
-
os.mkdir()
只對路徑的最后一級目錄進(jìn)行創(chuàng)建弛秋,如果前幾級目錄不存在线召,會(huì)報(bào)錯(cuò)铺韧!而os.makedirs()
可以創(chuàng)建多級目錄,如果路徑的目錄都不存在缓淹,都可以創(chuàng)建出來哈打。 - 按照模型和數(shù)據(jù)集合進(jìn)行分文件夾的保存。
-
max_to_keep
參數(shù):這個(gè)是用來設(shè)置保存模型的個(gè)數(shù)讯壶,默認(rèn)為5料仗,即 max_to_keep=5,保存最近的5個(gè)模型伏蚊。如果想要保存模型的數(shù)量不受限制立轧,則可以將 max_to_keep設(shè)置為None或者0,如果你只想保存最后一代的模型躏吊,則只需要將max_to_keep設(shè)置為1即可氛改。 -
saver.save(sess,filename,global_step=step)
還有最后一個(gè)參數(shù)global_step
,表示保存模型名字的后綴是step比伏。
def setup_saver(self):
self.saver = tf.train.Saver(max_to_keep=0)
def save_model(self, itr):
filename = self.model_name + "_weights/" + self.dataset + "/" + str(itr) + ".ckpt"
if not os.path.exists(os.path.dirname(filename)):
os.makedirs(os.path.dirname(filename))
self.saver.save(self.sess, filename)
例子:保存模型
# construct graph
v1 = tf.Variable([0], name='v1')
v2 = tf.Variable([0], name='v2')
# run graph
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.save(sess, 'ckp')
with tf.Session() as sess:
saver = tf.import_meta_graph('ckp.meta')
saver.restore(sess, 'ckp')
當(dāng)執(zhí)行Saver.saver操作的時(shí)候胜卤,在文件系統(tǒng)中生成如下文件:
- index:文件保存了一個(gè)不可變的表數(shù)據(jù),記錄Tensor元數(shù)據(jù)的信息赁项,比如tensor存儲(chǔ)在那個(gè)數(shù)據(jù)data文件中葛躏,以及在數(shù)據(jù)文件中的偏移,校驗(yàn)和其他信息肤舞。
- 數(shù)據(jù)(data) :文件記錄了所有變量(Variable) 的值紫新,當(dāng)restore 某個(gè)變量時(shí),首先從索引文件中找到相應(yīng)變量在哪個(gè)數(shù)據(jù)文件李剖,然后根據(jù)索引直接獲取變量的值,從而實(shí)現(xiàn)變量數(shù)據(jù)的恢復(fù)囤耳。
- 元文件(meta) :保存了MetaGraphDef 的持久化數(shù)據(jù)篙顺,它包括GraphDef, SaverDef 等元數(shù)據(jù)偶芍。就是描述了圖結(jié)構(gòu)的信息。這也是在上例中德玫,在調(diào)用Saver.restore 之前匪蟀,得先調(diào)用tf.import_meta_graph 的真正原因;否則宰僧,缺失計(jì)算圖的實(shí)例材彪,就無法談及恢復(fù)數(shù)據(jù)到圖實(shí)例中了。
-
狀態(tài)文件checkpoint:文件會(huì)記錄最近一次的斷點(diǎn)文件的前綴琴儿,根據(jù)前綴找到對應(yīng)的索引和數(shù)據(jù)文件段化,當(dāng)調(diào)用
tf.train.latest_checkpoint
,可以快速找到最近一次的斷點(diǎn)文件造成,此外显熏,Checkpoint 文件也記錄了所有的斷點(diǎn)文件列表,并且文件列表按照由舊至新的時(shí)間依次排序晒屎。當(dāng)訓(xùn)練任務(wù)時(shí)間周期非常長喘蟆,斷點(diǎn)檢查將持續(xù)進(jìn)行,必將導(dǎo)致磁盤空間被耗盡鼓鲁。為了避免這個(gè)問題蕴轨,存在兩種基本的方法:設(shè)置max_to_keep
: 配置最近有效文件的最大數(shù)目,當(dāng)新的斷點(diǎn)文件生成時(shí)骇吭,且文件數(shù)目超過max_to_keep
橙弱,則刪除最舊的斷點(diǎn)文件;其中绵跷,max_to_keep
默認(rèn)值為5膘螟,keep_checkpoint_every_n_hours
: 在訓(xùn)練過程中每n 小時(shí)做一次斷點(diǎn)檢查,保證只有一個(gè)斷點(diǎn)文件碾局;其中荆残,該選項(xiàng)默認(rèn)是關(guān)閉的。
├── checkpoint
├── ckp.data-00000-of-00001
├── ckp.index
├── ckp.meta