官方guide: https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/saved_model.md
官方API:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/train/Saver
所有pb格式數(shù)據(jù)文件的格式定義文件proto:https://github.com/tensorflow/tensorflow/tree/r1.15/tensorflow/core/protobuf
由于tf2開始冯痢,tf.train.saver被砍了氮昧,于是這里重點介紹tf1.xx版本的模型持久化方式,手動保存與恢復浦楣,tf官方定義為low-level API袖肥。(high-level 使用 estimator)
一、Saver類
tf.reset_default_graph()
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1),name='v1')
v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=2),name='v2')
result = v1 + v2
result2 = v1 * v2
c1 = tf.zeros([2,2], name="c1")
init_op = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=2) # Add ops to save and restore all the variables.
with tf.Session() as sess:
sess.run(init_op)
for x in range(10000):
if x % 1000 == 0:
print('saved:',x,'of 1w')
save_path = saver.save(sess, "Saved_model/test.ckpt", global_step=x)
print(save_path)
print(v1.eval(), v2.eval())
對Saver設置max_to_keep參數(shù)振劳,能自動保存下最新的n個模型椎组。得到結果:
Saver中傳入的文件名都只需要前綴即可,即若要恢復只要傳入test.ckpt-9000历恐,其中-后面的數(shù)字代表了step數(shù)寸癌。00000-of-00001代表device信息专筷,有一個GPU,且在第0個上蒸苇。checkpoint文件保存并維護著模型列表磷蛹,以及最新模型的文件名√钋可使用如下函數(shù)獲取保存模型的最新文件位置:
ckpt = tf.train.get_checkpoint_state(checkpoint_dir='Saved_model')
print(ckpt.model_checkpoint_path) # Saved_model\test.ckpt-9000
Saver中可傳入list或dict弦聂,用來指定保存或恢復時候的變量。默認是所有變量氛什。注意莺葫,一旦傳入,則只會保存或恢復list或dict中的變量枪眉,不管其余變量捺檬。
-
傳入list。
若保存時候的Saver([v1,v2])贸铜,則恢復的時候堡纬,也要這么指定(除非當前graph里只有v1、v2)蒿秦,否則會報錯:變量v3未找到烤镐。即恢復的變量應當是保存時候變量的一個子集。要恢復的變量必須存在棍鳖。 -
傳入dict炮叶。
有時候,保存模型的時候v1變量名name='v1'渡处,但是恢復模型的時候镜悉,graph里v1的變量名設定的是name='v11',又因tf是通過對應的變量名去加載的医瘫,因此會發(fā)生沖突侣肄。此時只要在dict中指定{'v1':v1}即可。dict中的key-value對:<String name: Variable 變量的引用>醇份。案例:若我們使用上面的代碼保存了模型稼锅,變量名為v1和v2,可通過方法查看:
在新文件中僚纷,我們重新構建好graph結構矩距,但是我們此時變量v1的name改了,于是需要設定dict映射畔濒,來將ckpt中的v1加載到新模型的變量v1(name='v11')中:
注意此時,我們新建的graph中锣咒,v3并沒有被初始化侵状,也未被Saver指定恢復數(shù)值赞弥。
那么問題來了,若是模型有100個variable趣兄,新構建的graph中绽左,部分variable的name和之前不一樣,難道我們還需要手動寫dict嗎艇潭?解決方法:通過tf的collection機制拼窥,獲取所有variable,組成dict后手動修改部分key值:
官方Notes:
You can create as many Saver objects as you want if you need to save and restore different subsets of the model variables. The same variable can be listed in multiple saver objects; its value is only changed when the Saver.restore() method is run.
If you only restore a subset of the model variables at the start of a session, you have to run an initialize op for the other variables. See tf.variables_initializer for more information.
二蹋凝、存儲機制詳解
可以發(fā)現(xiàn)鲁纠,存儲后的文件有3種后綴,data鳍寂、index與meta改含。同時,tf還提供了大量接口讓人混淆:
- tf.train.Saver()/saver.restore()
- saver.export_meta_graph()/ tf.train.import_meta_graph
- tf.train.write_graph()/tf.Import_graph_def()
https://yq.aliyun.com/articles/620067
http://www.reibang.com/p/ca637520002f
https://zhuanlan.zhihu.com/p/31308381
2.1 GraphDef 之 tf.Import_graph_def()
學術界適合使用上面所闡述的Saver.save()方法持久化模型迄汛,能方便之后繼續(xù)訓練或測試捍壤。但是工業(yè)界需要通用的模型文件,使得Java/C++也能直接部署鞍爱,調用模型獲得輸出鹃觉。所以工業(yè)界部署模型推薦tf.Import_graph_def()方法。
graph序列化的protobuf叫做graphDef睹逃,就是define graph的意思盗扇,一個graph的定義,包含了計算圖上的節(jié)點信息唯卖。這個graphDef可以用tf.train.write_graph()/tf.Import_graph_def()來寫入和導出粱玲。然而graphDef里面其實是沒有存儲變量具體數(shù)值的,因此無法拿來訓練拜轨,但是可以存常量抽减,就是constant。因此也可將所有session中持有的變量轉constant后(graph.util.convert_variables_to_constants
)橄碾,存儲為pb卵沉,拿來部署做inference。這樣graph結構信息與變量權重就能歸并到一個pb文件中,沒了變量初始化听想、模型保存等輔助節(jié)點后讥蟆,模型文件更小更簡潔,是無視語種的數(shù)據(jù)描述文件停撞,適合工業(yè)部署做predict。
2.2 MetaGraph 之 tf.train.import_meta_graph()
tf.train.import_meta_graph() 方法可以直接從.meta
文件中恢復Graph結構,其包含以下幾種主要成分:
MetaGraph
- MetaInfoDef 這個是存metadata的戈毒,像版本信息啊艰猬,用戶信息,運算方法信息(比如定義了加法埋市、乘法等冠桃,供GraphDef使用)
- GraphDef 上面說的就是這個GraphDef,包含了節(jié)點信息道宅。(節(jié)點使用了哪種運算操作食听、輸入輸出都是什么)
- SaverDef 記錄了所有持久化相關的參數(shù),包括存儲與恢復使用的op的名字污茵、保存頻率等
- CollectionDef 集合名稱到集合內容的映射
- signature_def 記號標記用于saved_model保存pb的時候使用樱报,定義統(tǒng)一的輸入輸出名
- AssetFileDef 記錄外置文件位置
restore只是去restore variable省咨,常量是在MetaGraph的GraphDef里的肃弟。故實驗發(fā)現(xiàn)沒有restore,常量依舊已經(jīng)獲取到了零蓉。
總結來看笤受,saver.save()和saver.restore()保存和讀取的東西不一致,save會保存所有一坨信息敌蜂,而restore只是將data里的variable值恢復到當前graph中的對應節(jié)點里箩兽,graph你得自己新建或使用tf.train.import_meta_graph()。
三章喉、 Saver源碼解析
3.1 Saver([var_list]).init()
當傳入var_list初始化Saver的時候汗贫,若未指定saver_def,則會自動使用build() ---> _build() ---> BaseSaverBuilder() 來創(chuàng)建新的saver_def.
Saver._build():
if not self.saver_def or context.in_eager_mode():
if self._builder is None:
self._builder = BaseSaverBuilder(self._write_version) # 創(chuàng)建BaseSaverBuilder
if self._var_list is None:
self._var_list = variables._all_saveable_objects() # 若未傳入var_list則默認設置為所有variable
self.saver_def = self._builder._build_internal( # 使用BaseSaverBuilder來創(chuàng)建saver_def
self._var_list,
reshape=self._reshape,
sharded=self._sharded,
max_to_keep=self._max_to_keep,
keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
name=self._name,
restore_sequentially=self._restore_sequentially,
filename=checkpoint_path,
build_save=build_save, build_restore=build_restore)
再來看BaseSaverBuilder中返回saver_def的關鍵函數(shù):
def _build_internal(self,
names_to_saveables, # 就是Saver初始化時候的var_list
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
filename="model",
build_save=True,
build_restore=True):
# 首先將var_list轉換成names_to_saveables秸脱,格式為<k,v>dict鍵值對:<op_name:op>
# 隨后將op一個個取出落包,將variable包裝為VariableSaveable后存入list:saveables并返回
saveables = self._ValidateAndSliceInputs(names_to_saveables)
# 創(chuàng)建op的name前綴:save
with ops.name_scope(name, "save",
[saveable.op for saveable in saveables]) as name:
# Add the Constant string tensor for the filename.
filename_tensor = constant_op.constant(filename or "model")
# Add the save ops. 創(chuàng)建保持和恢復的ops【重要】
if sharded:
... ...
else:
if build_save:
# 為每個saveables中的op添加保存op,并對op_list進行組合并返回組合依賴后的輸出tensor
#(通過control_flow_ops.with_dependencies)代表運行此tensor前必須運行全部的保存op
save_tensor = self._AddSaveOps(filename_tensor, saveables)
if build_restore:
restore_op = self._AddRestoreOps(filename_tensor, saveables,
restore_sequentially, reshape)
if context.in_graph_mode():
# 正式構建并返回saver_def
return saver_pb2.SaverDef(
filename_tensor_name=filename_tensor.name,
save_tensor_name=save_tensor.name,
restore_op_name=restore_op.name,
max_to_keep=max_to_keep,
sharded=sharded,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
version=self._write_version)
總結:Saver初始化的時候摊唇,就已經(jīng)根據(jù)傳入的var咐蝇,對每個var添加了對應的保存和恢復的op操作。同時構建了saverDef巷查,是metaGraph重要的一部分有序。通過該saverDef可以將很多記錄和參數(shù)持久化為pb,比如文件名的constant op的name岛请,保存流程后的輸出tensor的name等等旭寿。通過這些信息,就能從metaGraph中還原出Saver實例崇败。(Saver只要知道節(jié)點name即可盅称,就能通過name從graphDef中獲得對應的op,而這些op就是保存/恢復op_list后輸出的op,運行這個op即可運行之前的(依賴著的)所有的保存恢復ops缩膝。)
通過tf.train.export_meta_graph我們可以獲得序列化后的metaGraph:
以上代碼使用了三次export_meta_graph搭幻,分別不同:
- 第一次未使用Saver,tf.train.export_meta_graph直接輸出逞盆。
- 第二次初始化了Saver,tf.train.export_meta_graph直接輸出松申。
- 第三次初始化Saver之后使用saver.export_meta_graph輸出云芦。(默認帶上了saver_def)
從結果json我們發(fā)現(xiàn)符合我們的代碼分析:
- 第一個json不包含save/xxx節(jié)點。
- 第二個json包含了save/xxx節(jié)點贸桶,證明了Saver在初始化了時候就已經(jīng)給圖中的variable加上了保存和恢復的ops舅逸。但是默認改方法不帶saverDef,所以沒有這個結構皇筛。tf.train.import_meta_graph函數(shù)使用后無法重建Saver琉历,所以返回None。
- 第三個json包含了save/xxx節(jié)點與SaverDef水醋,因為saver的export函數(shù)默認傳入了saver初始化時構建好的saverDef旗笔,這樣才能在tf.train.import_meta_graph函數(shù)使用后返回重建的Saver實例,否則返回None拄踪。
3.2 Saver.save() 函數(shù):
def save(self,
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix="meta", # meta_graph默認后綴名.meta
write_meta_graph=True, # 若改False蝇恶,則不會生成.meta
write_state=True): # 若改False,則默認保存所有模型文件且無checkpoint文件
需要傳入sess惶桐,因為當前的session持有著變量相關信息撮弧,而save一定會運行Saver()類初始化時候定義的ops從而持久化變量數(shù)據(jù)(.data與.index)。
write_meta_graph=True代表保存變量數(shù)值(.data與.index)的同時會保存metaGraph姚糊。而上面已經(jīng)介紹過贿衍,metaGraph中持有重新構建圖的所有信息。write_state=True則會默認生成checkpoint文件自動記錄訓練文件名稱救恨。
關鍵的恢復代碼:
if context.in_graph_mode():
model_checkpoint_path = sess.run(
self.saver_def.save_tensor_name,
{self.saver_def.filename_tensor_name: checkpoint_file})
這一步就是運行之前Saver()初始化之后贸辈,創(chuàng)建的saver_def中的op:save_tensor_name。這一個op實際上是graph_def中定義的一個node:"save/control_dependency:0"忿薇。而這個op本身是無意義的裙椭,其實是為了調用它所依賴的variables身上的store op。同時傳入文件名參數(shù)署浩。這樣就執(zhí)行了保持variable數(shù)值的ops揉燃。
但是我們知道save()方法不光光如此,它還會生成.meta
與checkpoint
:
if write_state:
self._RecordLastCheckpoint(model_checkpoint_path)
_update_checkpoint_state(
save_dir=save_path_parent,
model_checkpoint_path=model_checkpoint_path,
all_model_checkpoint_paths=self.last_checkpoints,
latest_filename=latest_filename,
save_relative_paths=self._save_relative_paths)
self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
上述兩個操作分別是更新checkpoint文件以及刪除過時的舊模型文件筋栋。其實checkpoint本身也是pb炊汤,只不過它不影響效率,就使用text_format.MessageToString(ckpt)
將pb message轉換為了text,方便直接打開看和修改抢腐。
以下是生成metaGraph(.meta
文件)的關鍵代碼:
if write_meta_graph:
meta_graph_filename = self._MetaGraphFilename(
checkpoint_file, meta_graph_suffix=meta_graph_suffix)
if context.in_graph_mode():
with sess.graph.as_default():
self.export_meta_graph(meta_graph_filename)
... ...
def saver().export_meta_graph:
return export_meta_graph(
filename=filename,
graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
saver_def=self.saver_def,
collection_list=collection_list,
as_text=as_text,
export_scope=export_scope,
clear_devices=clear_devices,
clear_extraneous_savers=clear_extraneous_savers)
本質調用了Saver().export_meta_graph() ----> tf.train.export_meta_graph() 最關鍵的是加入了該類存儲著的saver_def姑曙,因此輸出的.meta文件里是包含saver_def的,下次可以用來恢復Saver(其實是記錄著restore_all關鍵node的name迈倍,Saver()初始化的時候已經(jīng)向graph_def里添加好了所有的save/restore的node)伤靠。
3.2 Saver.restore() 函數(shù):
這個函數(shù)就非常簡單了,關鍵代碼:
if context.in_graph_mode():
sess.run(self.saver_def.restore_op_name,
{self.saver_def.filename_tensor_name: save_path})
就是導入.meta文件后啼染,里面的saver_def二進制流信息重建并返回了Saver實例宴合,然后就能獲取到restore所有variables的那個op的名字,然后去運行迹鹅。該op:restore_op_name: "save/restore_all"同樣也是依賴于所有variable的assign操作卦洽,即變量賦值。