TensorFlow 保存和加載模型

可以在訓(xùn)練期間和訓(xùn)練后保存模型進(jìn)度刻获。 這意味著模型可以從中斷的地方恢復(fù)蜀涨,并避免長時(shí)間的訓(xùn)練。 保存也意味著您可以共享您的模型将鸵,而其他人可以重新創(chuàng)建您的工作勉盅。 在發(fā)布研究模型和技術(shù)時(shí),大多數(shù)機(jī)器學(xué)習(xí)從業(yè)者分享:

  1. 用于創(chuàng)建模型的代碼
  2. 模型的訓(xùn)練權(quán)重或參數(shù)

共享此數(shù)據(jù)有助于其他人了解模型的工作原理顶掉,并使用新數(shù)據(jù)自行嘗試草娜。

注意:小心不受信任的代碼 - TensorFlow模型是代碼。 有關(guān)詳細(xì)信息痒筒,請(qǐng)參閱安全使用TensorFlow宰闰。

選項(xiàng)

保存TensorFlow模型有多種方法 - 取決于您使用的API茬贵。 本指南使用tf.keras,一個(gè)高級(jí)API移袍,用于在TensorFlow中構(gòu)建和訓(xùn)練模型解藻。 有關(guān)其他方法,請(qǐng)參閱TensorFlow保存和還原指南或保存在急切中葡盗。

安裝

安裝和引用

安裝和導(dǎo)入TensorFlow和依賴項(xiàng)螟左,有下面兩種方式:

  1. 命令行:pip install -q h5py pyyaml
  2. 在Anaconda Navigator中安裝;

下載樣本數(shù)據(jù)集

from __future__ import absolute_import, division, print_function

import os

import tensorflow as tf
from tensorflow import keras

tf.__version__

'1.11.0'

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

定義模型

讓我們構(gòu)建一個(gè)簡單的模型觅够,我們將用它來演示保存和加載權(quán)重胶背。

# Returns a short sequential model
def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation=tf.nn.softmax)
  ])
  
  model.compile(optimizer=tf.keras.optimizers.Adam(), 
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=['accuracy'])
  
  return model


# Create a basic model instance
model = create_model()
model.summary()

在訓(xùn)練期間保存檢查點(diǎn)

主要用例是在訓(xùn)練期間和訓(xùn)練結(jié)束時(shí)自動(dòng)保存檢查點(diǎn)。 通過這種方式喘先,您可以使用訓(xùn)練有素的模型钳吟,而無需重新訓(xùn)練,或者在您離開的地方接受訓(xùn)練 - 以防止訓(xùn)練過程中斷窘拯。

tf.keras.callbacks.ModelCheckpoint是執(zhí)行此任務(wù)的回調(diào)红且。 回調(diào)需要幾個(gè)參數(shù)來配置檢查點(diǎn)。

檢查點(diǎn)回調(diào)使用情況

訓(xùn)練模型并將模型傳遞給ModelCheckpoint:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create checkpoint callback
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, 
 save_weights_only=True,
 verbose=1)

model = create_model()

model.fit(train_images, train_labels,  epochs = 10, 
  validation_data = (test_images,test_labels),
  callbacks = [cp_callback])  # pass callback to training

這將創(chuàng)建一個(gè)TensorFlow檢查點(diǎn)文件集合涤姊,這些文件在每個(gè)時(shí)期結(jié)束時(shí)更新:

!ls {checkpoint_dir}

checkpoint cp.ckpt.data-00000-of-00001 cp.ckpt.index

創(chuàng)建一個(gè)新的未經(jīng)訓(xùn)練的模型暇番。 僅從權(quán)重還原模型時(shí),必須具有與原始模型具有相同體系結(jié)構(gòu)的模型砂轻。 由于它是相同的模型架構(gòu)奔誓,我們可以共享權(quán)重,盡管它是模型的不同實(shí)例搔涝。

現(xiàn)在重建一個(gè)新的未經(jīng)訓(xùn)練的模型厨喂,并在測(cè)試集上進(jìn)行評(píng)估。 未經(jīng)訓(xùn)練的模型將在偶然水平上執(zhí)行(準(zhǔn)確度約為10%):

model = create_model()

loss, acc = model.evaluate(test_images, test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))

然后從檢查點(diǎn)加載權(quán)重庄呈,并重新評(píng)估:

model.load_weights(checkpoint_path)
loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

1000/1000 [==============================] - 0s 40us/step
Restored model, accuracy: 87.60%

檢查點(diǎn)回調(diào)選項(xiàng)

回調(diào)提供了幾個(gè)選項(xiàng)蜕煌,可以為生成的檢查點(diǎn)提供唯一的名稱,并調(diào)整檢查點(diǎn)頻率诬留。

訓(xùn)練一個(gè)新模型斜纪,每5個(gè)時(shí)期保存一次唯一命名的檢查點(diǎn):

# include the epoch in the file name. (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1, save_weights_only=True,
    # Save weights, every 5-epochs.
    period=5)

model = create_model()
model.fit(train_images, train_labels,
  epochs = 50, callbacks = [cp_callback],
  validation_data = (test_images,test_labels),
  verbose=0)

現(xiàn)在,查看生成的檢查點(diǎn)并選擇最新的檢查點(diǎn):

! ls {checkpoint_dir}

checkpoint cp-0030.ckpt.data-00000-of-00001
cp-0005.ckpt.data-00000-of-00001 cp-0030.ckpt.index
cp-0005.ckpt.index cp-0035.ckpt.data-00000-of-00001
cp-0010.ckpt.data-00000-of-00001 cp-0035.ckpt.index
cp-0010.ckpt.index cp-0040.ckpt.data-00000-of-00001
cp-0015.ckpt.data-00000-of-00001 cp-0040.ckpt.index
cp-0015.ckpt.index cp-0045.ckpt.data-00000-of-00001
cp-0020.ckpt.data-00000-of-00001 cp-0045.ckpt.index
cp-0020.ckpt.index cp-0050.ckpt.data-00000-of-00001
cp-0025.ckpt.data-00000-of-00001 cp-0050.ckpt.index
cp-0025.ckpt.index

latest = tf.train.latest_checkpoint(checkpoint_dir)
latest

'training_2/cp-0050.ckpt'

注意:默認(rèn)的tensorflow格式僅保存最近的5個(gè)檢查點(diǎn)文兑。

要測(cè)試盒刚,請(qǐng)重置模型并加載最新的檢查點(diǎn):

model = create_model()
model.load_weights(latest)
loss, acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

1000/1000 [==============================] - 0s 96us/step
Restored model, accuracy: 86.80%

這些文件是什么?

上述代碼將權(quán)重存儲(chǔ)到檢查點(diǎn)格式的文件集合中绿贞,這些文件僅包含二進(jìn)制格式的訓(xùn)練權(quán)重因块。 檢查點(diǎn)包含:*一個(gè)或多個(gè)包含模型權(quán)重的分片。 *索引文件籍铁,指示哪些權(quán)重存儲(chǔ)在哪個(gè)分片中涡上。

如果您只在一臺(tái)機(jī)器上訓(xùn)練模型趾断,那么您將有一個(gè)帶有后綴的分片:.data-00000-of-00001

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

保存整個(gè)模型

整個(gè)模型可以保存到包含權(quán)重值,模型配置甚至優(yōu)化器配置的文件中吩愧。 這允許您檢查模型并稍后從完全相同的狀態(tài)恢復(fù)培訓(xùn) - 無需訪問原始代碼芋酌。

在Keras中保存功能齊全的模型非常有用 - 您可以在TensorFlow.js中加載它們,然后在Web瀏覽器中訓(xùn)練和運(yùn)行它們雁佳。

Keras使用HDF5標(biāo)準(zhǔn)提供基本保存格式脐帝。 出于我們的目的,可以將保存的模型視為單個(gè)二進(jìn)制blob甘穿。

model = create_model()

model.fit(train_images, train_labels, epochs=5)

# Save entire model to a HDF5 file
model.save('my_model.h5')

Epoch 1/5
1000/1000 [==============================] - 0s 395us/step - loss: 1.1260 - acc: 0.6870
Epoch 2/5
1000/1000 [==============================] - 0s 135us/step - loss: 0.4136 - acc: 0.8760
Epoch 3/5
1000/1000 [==============================] - 0s 138us/step - loss: 0.2811 - acc: 0.9280
Epoch 4/5
1000/1000 [==============================] - 0s 153us/step - loss: 0.2078 - acc: 0.9480
Epoch 5/5
1000/1000 [==============================] - 0s 154us/step - loss: 0.1452 - acc: 0.9750

現(xiàn)在從該文件重新創(chuàng)建模型:

# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('my_model.h5')
new_model.summary()

檢查其準(zhǔn)確性:

loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

這項(xiàng)技術(shù)可以保存以下:

  1. 權(quán)重值
  2. 模型的配置(架構(gòu))
  3. 優(yōu)化器配置

Keras通過檢查架構(gòu)來保存模型腮恩。 目前,它無法保存TensorFlow優(yōu)化器(來自tf.train)温兼。 使用這些時(shí),您需要在加載后重新編譯模型武契,并且您將失去優(yōu)化器的狀態(tài)募判。

下一步是什么

這是使用tf.keras保存和加載的快速指南。

tf.keras指南顯示了有關(guān)使用tf.keras保存和加載模型的更多信息咒唆。

請(qǐng)參閱在急切執(zhí)行期間保存以備保存届垫。

“保存和還原”指南包含有關(guān)TensorFlow保存的低級(jí)詳細(xì)信息。

完整代碼:

from __future__ import absolute_import,division,print_function
import os
import tensorflow as tf
from tensorflow import keras

print(tf.__version__)


# Download dataset
(train_images,train_labels),(test_images,test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1,28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1,28 * 28) / 255.0

# Define a model
# Returns a short sequential model
def create_model():
    model = tf.keras.models.Sequential([
    keras.layers.Dense(512,activation=tf.nn.relu,input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10,activation=tf.nn.softmax)
])

model.compile(optimizer=tf.keras.optimizers.Adam(),
  loss=tf.keras.losses.sparse_categorical_crossentropy,
  metrics=['accuracy'])
return model

# Create a basic model instance
model = create_model()
model.summary()

# Checkpoint callback usage
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create checkpoint callback
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
 save_weights_only=True,
 verbose=1)
model = create_model()
model.fit(train_images,train_labels,epochs=10,
  validation_data=(test_images,test_labels),
  callbacks=[cp_callback]) # pass callback to training

# Create a new, untrained model. 
model = create_model()
loss,acc = model.evaluate(test_images,test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))

# Load the weights from chekpoint, and re-evaluate.
model.load_weights(checkpoint_path)
loss,acc = model.evaluate(test_images,test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

# Train a new model, and save uniquely named checkpoints once every 5epochs
# include the epoch in the file name. (uses 'str.format')
checkpoint_path = 'training_2/cp-{epoch:04d}.ckpt'
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1,save_weights_only=True,
    # Save weights, every 5-epochs
    period=5)

model = create_model()
model.fit(train_images,train_labels,
  epochs=50,callbacks = [cp_callback],
  validation_data = (test_images,test_labels),
  verbose=0)


latest = tf.train.latest_checkpoint(checkpoint_dir)
print(latest)


# To test, reset the model and load the latest checkpoint
model = create_model()
model.load_weights(latest)
loss, acc = model.evaluate(test_images,test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

# Manually save weights
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')
# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss, acc = model.evaluate(test_images,test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))


# Save the entire model
model = create_model()
model.fit(train_images,train_labels,
  epochs=5)
# Save entire model to a HDF5 file
model.save('my_model.h5')

# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('my_model.h5')
new_model.summary()

loss, acc = new_model.evaluate(test_images,test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末全释,一起剝皮案震驚了整個(gè)濱河市装处,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌浸船,老刑警劉巖妄迁,帶你破解...
    沈念sama閱讀 217,509評(píng)論 6 504
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異李命,居然都是意外死亡登淘,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,806評(píng)論 3 394
  • 文/潘曉璐 我一進(jìn)店門封字,熙熙樓的掌柜王于貴愁眉苦臉地迎上來黔州,“玉大人,你說我怎么就攤上這事阔籽×髌蓿” “怎么了?”我有些...
    開封第一講書人閱讀 163,875評(píng)論 0 354
  • 文/不壞的土叔 我叫張陵笆制,是天一觀的道長绅这。 經(jīng)常有香客問我,道長项贺,這世上最難降的妖魔是什么君躺? 我笑而不...
    開封第一講書人閱讀 58,441評(píng)論 1 293
  • 正文 為了忘掉前任峭判,我火速辦了婚禮,結(jié)果婚禮上棕叫,老公的妹妹穿的比我還像新娘林螃。我一直安慰自己,他們只是感情好俺泣,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,488評(píng)論 6 392
  • 文/花漫 我一把揭開白布疗认。 她就那樣靜靜地躺著,像睡著了一般伏钠。 火紅的嫁衣襯著肌膚如雪横漏。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,365評(píng)論 1 302
  • 那天熟掂,我揣著相機(jī)與錄音缎浇,去河邊找鬼。 笑死赴肚,一個(gè)胖子當(dāng)著我的面吹牛素跺,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播誉券,決...
    沈念sama閱讀 40,190評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼指厌,長吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了踊跟?” 一聲冷哼從身側(cè)響起踩验,我...
    開封第一講書人閱讀 39,062評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎商玫,沒想到半個(gè)月后箕憾,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,500評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡决帖,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,706評(píng)論 3 335
  • 正文 我和宋清朗相戀三年厕九,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片地回。...
    茶點(diǎn)故事閱讀 39,834評(píng)論 1 347
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡扁远,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出刻像,到底是詐尸還是另有隱情畅买,我是刑警寧澤,帶...
    沈念sama閱讀 35,559評(píng)論 5 345
  • 正文 年R本政府宣布细睡,位于F島的核電站谷羞,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜湃缎,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,167評(píng)論 3 328
  • 文/蒙蒙 一犀填、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧嗓违,春花似錦九巡、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,779評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至偿洁,卻和暖如春撒汉,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背涕滋。 一陣腳步聲響...
    開封第一講書人閱讀 32,912評(píng)論 1 269
  • 我被黑心中介騙來泰國打工睬辐, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人宾肺。 一個(gè)月前我還...
    沈念sama閱讀 47,958評(píng)論 2 370
  • 正文 我出身青樓溉委,卻偏偏與公主長得像,于是被迫代替她去往敵國和親爱榕。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,779評(píng)論 2 354

推薦閱讀更多精彩內(nèi)容

  • 近期做了一些反垃圾的工作坡慌,除了使用常用的規(guī)則匹配過濾等手段黔酥,也采用了一些機(jī)器學(xué)習(xí)方法進(jìn)行分類預(yù)測(cè)。我們使用Tens...
    liuyan731閱讀 12,754評(píng)論 0 19
  • 在這篇tensorflow教程中洪橘,我會(huì)解釋: 1) Tensorflow的模型(model)長什么樣子跪者? 2) 如...
    JunsorPeng閱讀 3,419評(píng)論 1 6
  • 世界這么大,你應(yīng)該去看看 今天剛剛高考完的表妹問起大學(xué)報(bào)考志愿應(yīng)該怎么填熄求,因?yàn)榘l(fā)揮的不太好渣玲,家里人都建議她學(xué)護(hù)理,...
    Miss凌妹妹閱讀 455評(píng)論 6 4
  • 近海風(fēng)云烈弟晚, 征帆拓遠(yuǎn)洲忘衍。 迷霧隱奇?zhèn)ィ?但去必賢優(yōu)。
    村客閱讀 153評(píng)論 0 6
  • 你說你不吃香菜 我說我吃餃子不帶湯 我倆的碗里卻是漂著香菜的餃子湯 在南方第一次吃熱干面卿城,我嫌它太噎人 在北方枚钓,后...
    Crazy麻麻閱讀 346評(píng)論 7 9