tensorflow框架下
使用saver=tf.train.Saver()保存模型會輸出以下四種文件
checkpoint 文本文件,記錄了模型文件的路徑信息列表
.ckpt.meta 保存了模型的計算圖結構信息(模型的網絡結構)
.ckpt.data-00000-of-00001 網絡權重信息
.ckpt.index 保存了模型中的變量參數(shù)(權重)信息
模型加載方式
(1)
def restore_model_ckpt(ckpt_file_path):
??? sess =tf.Session()
??? saver =tf.train.import_meta_graph('./ckpt/model.ckpt.meta') # 加載模型結構
??? saver.restore(sess, tf.train.latest_checkpoint('./ckpt')) # 指定目錄就可以恢復所有變量信息
(2)
saver = tf.train.import_meta_graph(path_to_ckpt_meta)
saver.restore(sess, path_to_ckpt_data)
.pb文件是谷歌推薦的保存模型的方式
將模型參數(shù)固化到圖文件中欺缘,里面保存了圖結構+數(shù)據(jù),合并了一些基礎計算和刪除了反向傳播相關計算得到的protobuf協(xié)議文件幅慌,加載模型時只需要這一個文件就好
keras框架下
.h5 保存的模型參數(shù)或者模型
.json .yaml 保存的模型結構
.hdf5 保存的模型參數(shù)
keras(tensorflow backend)中可以通過如下方式加載模型
(1)
loaded_model = model_from_json(open('model_architecture-1.json').read())
loaded_model.load_weights('saved_models/weights-improvement-19-0.98100.hdf5', by_name=True)
#loaded_model.load_weights('my_model_weights.h5', by_name=True)
(2)
model = load_model('my_model.h5')