如何使用Tensorflow保存或者加載模型(一)

1.背景

在深度學習的開源框架中碳抄,Tensorflow是最熱門的框架之一草添,相信很多同學已經(jīng)有了不同程度的學習和了解酪劫。但站長在平時的溝通發(fā)現(xiàn)豆瘫,很多同學反應(yīng)不知道怎么使用自己訓練好的模型進行預(yù)測珊蟀,不知道怎么繼續(xù)接著之前訓練了多個輪次的模型進行訓練,不知道怎么生成工業(yè)化場景里可上線的模型文件等等。 因此育灸,站長會寫一個針對Tensorflow的模型保存和加載的系列文章腻窒,為大家解決相關(guān)問題。

1.1 模型文件介紹

Tensorflow保存模型的時候會生成三個文件磅崭,分別是meta file儿子,index filedata file砸喻。

meta file 這個文件是描述圖結(jié)構(gòu)柔逼,包括GraphDef, SaverDef等。值得注意的是割岛,在Tensorfow中圖和變量是分開的愉适,關(guān)于圖結(jié)構(gòu)的信息主要保存在meta file中。
index file 這個文件是關(guān)于tensor的索引文件癣漆,key就是tensor的名字维咸,value就是序列化后的BundleEntryProto。
data file 這個文件保存了所有變量的值惠爽。

1.2 模型的保存示例代碼一

這里先簡單地實現(xiàn)了一個例子癌蓖,實現(xiàn)的是初始化了隨機變量v1和v2,并將變量v1和v2相加獲得變量v3婚肆。注意:這里保存模型的目錄是result而不是model.ckpt费坊。而這里的model.ckpt是模型文件的前綴,如果想要在同一個目錄下保存多個模型的話旬痹,可以通過修改這個前綴達成目的附井。

# -*- coding: utf-8 -*-
import tensorflow as tf

v1 = tf.get_variable("v1", shape=[10], initializer=tf.random_normal_initializer)
v2 = tf.get_variable("v2", shape=[10], initializer=tf.random_normal_initializer)
v3 = tf.add(v1,v2, name="v3")

init_op = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:
  print("Start initialing model parameters")
  sess.run(init_op)
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())
  print("v3 : %s" % v3.eval())
  # Save the variables to disk.
  save_path = saver.save(sess, "./result/model.ckpt")
  print("Model saved in path: %s" % save_path)

1.3 模型的加載示例代碼一

當模型已經(jīng)輸出后,我們就可以通過saver去加載所有的變量两残,并執(zhí)行相關(guān)運算操作永毅。注意,模型加載的路徑是模型文件前綴是model.ckpt而不是其中一個文件名或者整個目錄名

# -*- coding: utf-8 -*-
import tensorflow as tf

###2.Restore Variables###
tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[10])
v2 = tf.get_variable("v2", shape=[10])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "./result/model.ckpt")
  print("Model restored.")
  # Check the values of the variables
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())

當然人弓,如果你只希望加載部分的變量的時候沼死,可以在創(chuàng)建Saver的時候,只傳入部分變量崔赌。例如:

saver = tf.train.Saver([v1,v2])

如果你只想保存最后3個epochs的模型和每兩個小時保存一個模型的話意蛀,可以這么設(shè)置。

saver = tf.train.Saver(max_to_keep=3, keep_checkpoint_every_n_hours=2)

1.4 模型的保存示例代碼二

由于上面的示例比較簡單健芭,我們用一個線性回歸的例子作為示例吧县钥。

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np

##1.創(chuàng)建PlaceHolder和初始化參數(shù)##
X = tf.placeholder("float", name="X")
Y = tf.placeholder("float", name="Y")

W = tf.Variable(np.random.randn(), name= "W")
b = tf.Variable(np.random.randn(), name= "b")

learning_rate = 0.02
epochs = 100

data_x = np.linspace(0, 50, 50)
data_y = np.linspace(0, 50, 50)

##2.實現(xiàn)梯度下降##
y_pred = tf.add(tf.multiply(X, W), b, name="y_pred")
loss = tf.reduce_sum(tf.pow(y_pred-Y, 2)) / (2 * len(data_x))
opt = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
##初始化變量##
init = tf.global_variables_initializer()

##創(chuàng)建Saver##
saver = tf.train.Saver()

##3.構(gòu)建Tensorflow Session##
with tf.Session() as sess:
    sess.run(init)

    for epoch in range(epochs):
        for (batch_x, batch_y) in zip(data_x, data_y):
            sess.run(opt, feed_dict={X: batch_x, Y: batch_y})

        if (epoch + 1) % 10 == 0:
            cost = sess.run(loss, feed_dict={X:data_x, Y:data_y})
            print("Epoch", (epoch + 1), ": cost =", cost, "W =", sess.run(W), "b =", sess.run(b))


    # 存儲必須的變量#
    training_cost = sess.run(loss, feed_dict={X:data_x, Y:data_y})
    weight = sess.run(W)
    bias = sess.run(b)

    # 用變量進行預(yù)測#
    predictions = weight * X + bias
    print("Training Cost =", training_cost, "Weight =", weight, "bias =", bias, '\n')
    print("預(yù)測結(jié)果:", weight * 0.01 + bias)

    # 保存模型#
    save_path = saver.save(sess, "./result/model.ckpt", global_step=epochs)
    print("Model saved in path: %s" % save_path)

訓練線性模型后的結(jié)果如下:

Epoch 10 : cost = 8.91454e-07 W = 0.99995387 b = 0.0023029544
Epoch 20 : cost = 8.285824e-07 W = 0.99995565 b = 0.0022180977
Epoch 30 : cost = 7.692257e-07 W = 0.9999573 b = 0.0021363394
Epoch 40 : cost = 7.133759e-07 W = 0.9999589 b = 0.0020576443
Epoch 50 : cost = 6.604228e-07 W = 0.9999603 b = 0.0019818414
Epoch 60 : cost = 6.1390205e-07 W = 0.99996185 b = 0.0019088342
Epoch 70 : cost = 5.691646e-07 W = 0.9999632 b = 0.0018385095
Epoch 80 : cost = 5.2784395e-07 W = 0.9999646 b = 0.0017707513
Epoch 90 : cost = 4.897609e-07 W = 0.9999659 b = 0.0017055188
Epoch 100 : cost = 4.5440865e-07 W = 0.99996716 b = 0.0016426622
Training Cost = 4.5440865e-07 Weight = 0.99996716 bias = 0.0016426622 

預(yù)測結(jié)果: 0.011642333795316517
Model saved in path: ./result/model.ckpt-100

Process finished with exit code 0

1.5 模型的加載示例代碼二

讀取本地模型文件,并開始預(yù)測新樣本

# -*- coding: utf-8 -*-
import tensorflow as tf

with tf.Session() as sess:
    ##加載meta的圖結(jié)構(gòu)和權(quán)重
    ##這里要用meta文件的文件名而不是路徑名
    saver = tf.train.import_meta_graph('./result/model.ckpt-100.meta')
    ##這里要用路徑名
    saver.restore(sess, tf.train.latest_checkpoint('./result'))

    ##加載圖
    graph = tf.get_default_graph()
    X = graph.get_tensor_by_name("X:0")
    ##輸入數(shù)據(jù)點
    feed_dict = {X: 0.01}

    ##打印預(yù)測結(jié)果
    y_pred = graph.get_tensor_by_name("y_pred:0")
    print("預(yù)測結(jié)果是:", sess.run(y_pred, feed_dict=feed_dict))

加載模型后慈迈,預(yù)測的結(jié)果如下:

##與訓練結(jié)束時打印的結(jié)果一致
預(yù)測結(jié)果是: 0.011642333

2.總結(jié)

本文介紹了基本的Tensorflow模型保存和加載方式和相應(yīng)的示例若贮,可是,如果公司要求要在Java環(huán)境下使用,要怎么做呢谴麦?下一篇文章會介紹一種新的模型保存和加載方法蠢沿,會更簡單,兼容性更強匾效,支持Java調(diào)用舷蟀。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市面哼,隨后出現(xiàn)的幾起案子野宜,更是在濱河造成了極大的恐慌,老刑警劉巖精绎,帶你破解...
    沈念sama閱讀 221,548評論 6 515
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件速缨,死亡現(xiàn)場離奇詭異锌妻,居然都是意外死亡代乃,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,497評論 3 399
  • 文/潘曉璐 我一進店門仿粹,熙熙樓的掌柜王于貴愁眉苦臉地迎上來搁吓,“玉大人,你說我怎么就攤上這事吭历《樽校” “怎么了?”我有些...
    開封第一講書人閱讀 167,990評論 0 360
  • 文/不壞的土叔 我叫張陵晌区,是天一觀的道長摩骨。 經(jīng)常有香客問我,道長朗若,這世上最難降的妖魔是什么恼五? 我笑而不...
    開封第一講書人閱讀 59,618評論 1 296
  • 正文 為了忘掉前任,我火速辦了婚禮哭懈,結(jié)果婚禮上灾馒,老公的妹妹穿的比我還像新娘。我一直安慰自己遣总,他們只是感情好睬罗,可當我...
    茶點故事閱讀 68,618評論 6 397
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著旭斥,像睡著了一般容达。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上垂券,一...
    開封第一講書人閱讀 52,246評論 1 308
  • 那天董饰,我揣著相機與錄音,去河邊找鬼。 笑死卒暂,一個胖子當著我的面吹牛啄栓,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播也祠,決...
    沈念sama閱讀 40,819評論 3 421
  • 文/蒼蘭香墨 我猛地睜開眼昙楚,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了诈嘿?” 一聲冷哼從身側(cè)響起堪旧,我...
    開封第一講書人閱讀 39,725評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎奖亚,沒想到半個月后淳梦,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 46,268評論 1 320
  • 正文 獨居荒郊野嶺守林人離奇死亡昔字,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 38,356評論 3 340
  • 正文 我和宋清朗相戀三年爆袍,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片作郭。...
    茶點故事閱讀 40,488評論 1 352
  • 序言:一個原本活蹦亂跳的男人離奇死亡陨囊,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出夹攒,到底是詐尸還是另有隱情蜘醋,我是刑警寧澤,帶...
    沈念sama閱讀 36,181評論 5 350
  • 正文 年R本政府宣布咏尝,位于F島的核電站压语,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏编检。R本人自食惡果不足惜胎食,卻給世界環(huán)境...
    茶點故事閱讀 41,862評論 3 333
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望蒙谓。 院中可真熱鬧斥季,春花似錦、人聲如沸累驮。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,331評論 0 24
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽谤专。三九已至躁锡,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間置侍,已是汗流浹背映之。 一陣腳步聲響...
    開封第一講書人閱讀 33,445評論 1 272
  • 我被黑心中介騙來泰國打工拦焚, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人杠输。 一個月前我還...
    沈念sama閱讀 48,897評論 3 376
  • 正文 我出身青樓赎败,卻偏偏與公主長得像,于是被迫代替她去往敵國和親蠢甲。 傳聞我的和親對象是個殘疾皇子僵刮,可洞房花燭夜當晚...
    茶點故事閱讀 45,500評論 2 359

推薦閱讀更多精彩內(nèi)容