最近在用 java 改寫一個用 python 編寫的 model,遇到了有關(guān)模型保存與恢復(fù)的問題个榕,發(fā)現(xiàn)網(wǎng)上的資料有些混亂赖捌,在這里做一些記錄。
.ckpt
1. .ckpt 全稱為 checkpoint猫缭,代表著一個檢查點,即為 model 訓(xùn)練過程中的一個快照壹店,可能是在訓(xùn)練開始猜丹,也可能是在訓(xùn)練完成。
2. .ckpt 是由 Saver 調(diào)用 save 產(chǎn)生的:
saver.save(sess,"/tmp/model.ckpt")
3. 由 Saver 調(diào)用 restore 來復(fù)原 model 的數(shù)據(jù):
saver.restore(sess,path)
注意這里硅卢,復(fù)原的只有數(shù)據(jù)射窒,不含 graph 信息。
4. .ckpt 不是單獨的一個文件将塑,而是一系列文件脉顿。
其內(nèi)部包含了:
①checkpoint: .ckpt 的標(biāo)記信息。
②.data: model 中 graph 的數(shù)據(jù)点寥,包括各種變量艾疟,不含常量。
③.index: 索引信息开财。
④.meta: graph 信息汉柒。
在這里要搞明白一點,一個 model 是由 graph(④) + 數(shù)據(jù)(②) 組成的责鳍。
graph 代表著執(zhí)行邏輯碾褂,在 tensorflow 中,每個算子用一個 node 來表示历葛,眾多 node 組合起來便是一張圖(graph)正塌,也就是我們的執(zhí)行邏輯嘀略,而這些執(zhí)行邏輯在 Saver 調(diào)用 save 時,會被存到 .meta 中(不含數(shù)據(jù))乓诽。各個 node 中含有各種參數(shù)(變量帜羊,比如訓(xùn)練的權(quán)重),這些參數(shù)則被存儲到 .data 中鸠天。graph 與數(shù)據(jù)是分別存儲的讼育。
tf.train.import_meta_graph
該方法只能恢復(fù) graph,不恢復(fù)數(shù)據(jù)稠集。
注意與上面提及的 saver.restore 區(qū)分奶段,saver.restore 只恢復(fù)數(shù)據(jù),不恢復(fù) graph剥纷。
recover model
現(xiàn)在我們來討論下痹籍,如何能恢復(fù)一個model。前面已經(jīng)提過了晦鞋,一個 model 由 graph 和 數(shù)據(jù)組成蹲缠,所以只要能恢復(fù)這兩部分就可以了,依據(jù)恢復(fù)的方法不同悠垛,可以分為兩類线定。
①分別恢復(fù) graph 和數(shù)據(jù):
對于數(shù)據(jù)來說,可以用 saver.restore 來恢復(fù)确买。
對于graph來說渔肩,依據(jù)恢復(fù)方法不同可以分為兩種:
A.硬編碼恢復(fù):在調(diào)用方法中,重新書寫 graph 信息拇惋。
B. .meta 恢復(fù):通過調(diào)用 tf.train.import_meta_graph 方法獲得 graph,并配合 get_tensor_by_name 的方法來調(diào)用 model 中特定的算子(node)抹剩。
saver = tf.train.import_meta_graph('~/tmp/model.ckpt-1000.meta')
graph = tf.get_default_graph()
input = graph.get_tensor_by_name('input:0')
② freezing(固化):
該方法將變量(訓(xùn)練的權(quán)重)固化在 graph 中撑帖,即用常量來替換 graph 中的變量,從而達到無需恢復(fù)數(shù)據(jù)澳眷,直接調(diào)用 graph 即可胡嘿。權(quán)重一旦被固化就不能再修改,該方法一般用于生產(chǎn)環(huán)境钳踊。
注:筆者在測試 Java API 時衷敌,其只支持調(diào)用 freezing 后的圖。
References: