tensorflow提供兩種模型格式
- checkpoint:依賴于創(chuàng)建模型的代碼
- SavedModel:與模型代碼無關(guān)
這里盡介紹checkpoint
1. 保存經(jīng)過部分訓練的模型
Estimator自動將如下內(nèi)容寫入磁盤
- checkpoints: 訓練期間所創(chuàng)建的模型版本
- event files: 包含有TensorBoard用于創(chuàng)建可視化圖標的全部信息
要指定模型的頂級存儲目錄飒货,可以使用Estimator構(gòu)造函數(shù)的可選參數(shù)model_dir
丹允,設(shè)置代碼如下所示
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir="./models_dir")
當調(diào)用Estimator的train
方法時藕畔,Estimator會將checkpoint和其他文件保存到model_dir
目錄中逐哈,保存之后荆责,這個目錄中的文件如下所示:
checkpoint
events.out.tfevents.timestamp.hostname
graph.pbtxt
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index
model.ckpt-1.meta
model.ckpt-200.data-00000-of-00001
model.ckpt-200.index
model.ckpt-200.meta
這個目錄存儲的是Estimator在第一步訓練開始和第200不訓練結(jié)束時創(chuàng)建的checkpoints
2. checkpoint頻率
默認情況下于样,Estimator按照如下時間將checkpoint保存到model_dir
中
- 每600秒保存一次
- 在
train
方法開始以及完成時都要保存checkpoint - 在目錄中最多保留5個最近的checkpoints
可以通過如下步驟來更改默認設(shè)置:
- 創(chuàng)建
RunConfig
對象來自定義設(shè)置 - 在實例化Estimator時媳友,將該
RunConfig
對象傳遞個Estimatro的config
參數(shù)
my_checkpointing_config = tf.estimator.RunConfig(
save_checkpoints_secs = 20*60,
keep_checkpoint_max = 10,
)
3. 從checkpoint中恢復模型
在第一次調(diào)用Estimator的train
方法時肥缔,Tensorflow會將checkpoint保存到model_dir
中丛版,隨后每次調(diào)用Estimator的train
巩掺、eval
或者predict
方法時,都會發(fā)生下列情況:
- Esitmator運行
model_fun()
構(gòu)建模型圖 - Estimator根據(jù)最近寫入的checkpoint中存儲的數(shù)據(jù)來初始化新模型的權(quán)重
4. 避免不當恢復
通過checkpoint恢復模型的狀態(tài)必須保證模型和checkpoint保存的兼容才可以硼婿。例如我們訓練了一個DNNClassifier
Estimator锌半,它包含有2個隱藏層且每層都有10個節(jié)點,經(jīng)過訓練兵保存了checkpoint到model_dir
中寇漫。后續(xù)在訓練的時候刊殉,假如將代碼中的隱藏層修改為了每層20個節(jié)點,這樣用這樣的Estimator調(diào)用train
時就回報錯州胳,因為checkpoint保存的模型結(jié)構(gòu)與代碼中的模型是不兼容的记焊。這一點切記。