Tensorflow 模型持久化

當(dāng)我們使用 tensorflow 訓(xùn)練神經(jīng)網(wǎng)絡(luò)的時候糊治,模型持久化對于我們的訓(xùn)練有很重要的作用折汞。

  • 如果我們的神經(jīng)網(wǎng)絡(luò)比較復(fù)雜咒劲,訓(xùn)練數(shù)據(jù)比較多会放,那么我們的模型訓(xùn)練就會耗時很長饲齐,如果在訓(xùn)練過程中出現(xiàn)某些不可預(yù)計的錯誤,導(dǎo)致我們的訓(xùn)練意外終止咧最,那么我們將會前功盡棄捂人。為了避免這個問題,我們就可以通過模型持久化(保存為CKPT格式)來暫存我們訓(xùn)練過程中的臨時數(shù)據(jù)矢沿。
  • 如果我們訓(xùn)練的模型需要提供給用戶做離線的預(yù)測滥搭,那么我們只需要前向傳播的過程,只需得到預(yù)測值就可以了捣鲸,這個時候我們就可以通過模型持久化(保存為PB格式)只保存前向傳播中需要的變量并將變量的值固定下來瑟匆,這個時候只需用戶提供一個輸入,我們就可以通過模型得到一個輸出給用戶栽惶。

保存為 CKPT 格式的模型

  1. 定義運(yùn)算過程
  2. 聲明并得到一個 Saver
  3. 通過 Saver.save 保存模型
# coding=UTF-8 支持中文編碼格式
import tensorflow as tf
import shutil
import os.path

MODEL_DIR = "model/ckpt"
MODEL_NAME = "model.ckpt"

# if os.path.exists(MODEL_DIR): 刪除目錄
#     shutil.rmtree(MODEL_DIR)
if not tf.gfile.Exists(MODEL_DIR): #創(chuàng)建目錄
    tf.gfile.MakeDirs(MODEL_DIR)

#下面的過程你可以替換成CNN愁溜、RNN等你想做的訓(xùn)練過程,這里只是簡單的一個計算公式
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder") #輸入占位符外厂,并指定名字冕象,后續(xù)模型讀取可能會用的
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
predictions = tf.greater(_y, 50, name="predictions") #輸出節(jié)點(diǎn)名字,后續(xù)模型讀取會用到酣衷,比50大返回true交惯,否則返回false

init = tf.global_variables_initializer()
saver = tf.train.Saver() #聲明saver用于保存模型

with tf.Session() as sess:
    sess.run(init)
    print "predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]}) #輸入一個數(shù)據(jù)測試一下
    saver.save(sess, os.path.join(MODEL_DIR, MODEL_NAME)) #模型保存
    print("%d ops in the final graph." % len(tf.get_default_graph().as_graph_def().node)) #得到當(dāng)前圖有幾個操作節(jié)點(diǎn)

for op in tf.get_default_graph().get_operations(): #打印模型節(jié)點(diǎn)信息
    print (op.name, op.values())

運(yùn)行后生成的文件如下:

model_ckpt
  • checkpoint : 記錄目錄下所有模型文件列表
  • ckpt.data : 保存模型中每個變量的取值
  • ckpt.meta : 保存整個計算圖的結(jié)構(gòu)

保存為 PB 格式模型

  1. 定義運(yùn)算過程
  2. 通過 get_default_graph().as_graph_def() 得到當(dāng)前圖的計算節(jié)點(diǎn)信息
  3. 通過 graph_util.convert_variables_to_constants 將相關(guān)節(jié)點(diǎn)的values固定
  4. 通過 tf.gfile.GFile 進(jìn)行模型持久化
# coding=UTF-8
import tensorflow as tf
import shutil
import os.path
from tensorflow.python.framework import graph_util


# MODEL_DIR = "model/pb"
# MODEL_NAME = "addmodel.pb"

# if os.path.exists(MODEL_DIR): 刪除目錄
#     shutil.rmtree(MODEL_DIR)
#
# if not tf.gfile.Exists(MODEL_DIR): #創(chuàng)建目錄
#     tf.gfile.MakeDirs(MODEL_DIR)

output_graph = "model/pb/add_model.pb"

#下面的過程你可以替換成CNN、RNN等你想做的訓(xùn)練過程穿仪,這里只是簡單的一個計算公式
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
# predictions = tf.greater(_y, 50, name="predictions") #比50大返回true席爽,否則返回false
predictions = tf.add(_y, 10,name="predictions") #做一個加法運(yùn)算

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    print "predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]})
    graph_def = tf.get_default_graph().as_graph_def() #得到當(dāng)前的圖的 GraphDef 部分,通過這個部分就可以完成重輸入層到輸出層的計算過程

    output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化啊片,將變量值固定
        sess,
        graph_def,
        ["predictions"] #需要保存節(jié)點(diǎn)的名字
    )
    with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
        f.write(output_graph_def.SerializeToString())  # 序列化輸出
    print("%d ops in the final graph." % len(output_graph_def.node))
    print (predictions)

# for op in tf.get_default_graph().get_operations(): 打印模型節(jié)點(diǎn)信息
#     print (op.name)

*GraphDef:這個屬性記錄了tensorflow計算圖上節(jié)點(diǎn)的信息只锻。

model_pb
  • add_model.pb : 里面保存了重輸入層到輸出層這個計算過程的計算圖和相關(guān)變量的值,我們得到這個模型后傳入一個輸入紫谷,既可以得到一個預(yù)估的輸出值

CKPT 轉(zhuǎn)換成 PB格式

  1. 通過傳入 CKPT 模型的路徑得到模型的圖和變量數(shù)據(jù)
  2. 通過 import_meta_graph 導(dǎo)入模型中的圖
  3. 通過 saver.restore 從模型中恢復(fù)圖中各個變量的數(shù)據(jù)
  4. 通過 graph_util.convert_variables_to_constants 將模型持久化
# coding=UTF-8
import tensorflow as tf
import os.path
import argparse
from tensorflow.python.framework import graph_util

MODEL_DIR = "model/pb"
MODEL_NAME = "frozen_model.pb"

if not tf.gfile.Exists(MODEL_DIR): #創(chuàng)建目錄
    tf.gfile.MakeDirs(MODEL_DIR)

def freeze_graph(model_folder):
    checkpoint = tf.train.get_checkpoint_state(model_folder) #檢查目錄下ckpt文件狀態(tài)是否可用
    input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路徑
    output_graph = os.path.join(MODEL_DIR, MODEL_NAME) #PB模型保存路徑

    output_node_names = "predictions" #原模型輸出操作節(jié)點(diǎn)的名字
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) #得到圖齐饮、clear_devices :Whether or not to clear the device field for an `Operation` or `Tensor` during import.

    graph = tf.get_default_graph() #獲得默認(rèn)的圖
    input_graph_def = graph.as_graph_def()  #返回一個序列化的圖代表當(dāng)前的圖

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint) #恢復(fù)圖并得到數(shù)據(jù)

        print "predictions : ", sess.run("predictions:0", feed_dict={"input_holder:0": [10.0]}) # 測試讀出來的模型是否正確捐寥,注意這里傳入的是輸出 和輸入 節(jié)點(diǎn)的 tensor的名字,不是操作節(jié)點(diǎn)的名字

        output_graph_def = graph_util.convert_variables_to_constants(  #模型持久化祖驱,將變量值固定
            sess,
            input_graph_def,
            output_node_names.split(",") #如果有多個輸出節(jié)點(diǎn)握恳,以逗號隔開
        )
        with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
            f.write(output_graph_def.SerializeToString()) #序列化輸出
        print("%d ops in the final graph." % len(output_graph_def.node)) #得到當(dāng)前圖有幾個操作節(jié)點(diǎn)

        for op in graph.get_operations():
            print(op.name, op.values())

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("model_folder", type=str, help="input ckpt model dir") #命令行解析捺僻,help是提示符匕坯,type是輸入的類型葛峻,
    # 這里運(yùn)行程序時需要帶上模型ckpt的路徑术奖,不然會報 error: too few arguments
    aggs = parser.parse_args()
    freeze_graph(aggs.model_folder)
    # freeze_graph("model/ckpt") #模型目錄

加載pb模型

1.通過 tf.gfile.GFile 打開模型
2.通過 tf.GraphDef().ParseFromString 得到模型中的圖和變量數(shù)據(jù)
3.通過 tf.import_graph_def 加載目前的圖
4.拿到輸入節(jié)點(diǎn)和輸出節(jié)點(diǎn)tensor并進(jìn)行預(yù)測

# coding=UTF-8

import tensorflow as tf

def load_graph(model_dir):
    with tf.gfile.GFile(model_dir, "rb") as f: #讀取模型數(shù)據(jù)
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read()) #得到模型中的計算圖和數(shù)據(jù)

        with tf.Graph().as_default() as graph: #這里的Graph()要有括號腰耙,不然會報TypeError
            tf.import_graph_def(graph_def, name="michael") #導(dǎo)入模型中的圖到現(xiàn)在這個新的計算圖中,不指定名字的話默認(rèn)是 import
            return graph


if __name__  == "__main__":
    graph = load_graph("model/pb/frozen_model.pb") #這里傳入的是完整的路徑包括pb的名字挺庞,不然會報FailedPreconditionError

    for op in graph.get_operations(): #打印出圖中的節(jié)點(diǎn)信息
        print (op.name, op.values())

    x = graph.get_tensor_by_name('michael/input_holder:0') #得到輸入節(jié)點(diǎn)tensor的名字选侨,記得跟上導(dǎo)入圖時指定的name
    y = graph.get_tensor_by_name('michael/predictions:0') #得到輸出節(jié)點(diǎn)tensor的名字

    with tf.Session(graph=graph) as sess: #創(chuàng)建會話運(yùn)行計算
        y_out = sess.run(y, feed_dict={x: [10.0]})
        print(y_out)
    print ("finish")

部分參考: TensorFlow實(shí)戰(zhàn)Google深度學(xué)習(xí)框架、http://blog.csdn.net/lujiandong1/article/details/53385092

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子答捕,更是在濱河造成了極大的恐慌逝钥,老刑警劉巖艘款,帶你破解...
    沈念sama閱讀 206,602評論 6 481
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件岳枷,死亡現(xiàn)場離奇詭異朱庆,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)剧罩,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,442評論 2 382
  • 文/潘曉璐 我一進(jìn)店門来氧,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人僚楞,你說我怎么就攤上這事揉忘∧啵” “怎么了?”我有些...
    開封第一講書人閱讀 152,878評論 0 344
  • 文/不壞的土叔 我叫張陵,是天一觀的道長涮母。 經(jīng)常有香客問我哈蝇,道長,這世上最難降的妖魔是什么吠勘? 我笑而不...
    開封第一講書人閱讀 55,306評論 1 279
  • 正文 為了忘掉前任,我火速辦了婚禮峭拘,結(jié)果婚禮上鸡挠,老公的妹妹穿的比我還像新娘彭沼。我一直安慰自己姓惑,他們只是感情好于毙,可當(dāng)我...
    茶點(diǎn)故事閱讀 64,330評論 5 373
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般蚤氏。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上于游,一...
    開封第一講書人閱讀 49,071評論 1 285
  • 那天,我揣著相機(jī)與錄音蚌成,去河邊找鬼担忧。 笑死,一個胖子當(dāng)著我的面吹牛惩猫,可吹牛的內(nèi)容都是我干的帆锋。 我是一名探鬼主播皮官,決...
    沈念sama閱讀 38,382評論 3 400
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了馍佑?” 一聲冷哼從身側(cè)響起拭荤,我...
    開封第一講書人閱讀 37,006評論 0 259
  • 序言:老撾萬榮一對情侶失蹤奇徒,失蹤者是張志新(化名)和其女友劉穎罢低,沒想到半個月后奕短,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 43,512評論 1 300
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡日杈,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 35,965評論 2 325
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 38,094評論 1 333
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出杂抽,到底是詐尸還是另有隱情,我是刑警寧澤愚屁,帶...
    沈念sama閱讀 33,732評論 4 323
  • 正文 年R本政府宣布送浊,位于F島的核電站,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜幅疼,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,283評論 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望爽篷。 院中可真熱鬧悴晰,春花似錦、人聲如沸逐工。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,286評論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽钻弄。三九已至佃却,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間窘俺,已是汗流浹背饲帅。 一陣腳步聲響...
    開封第一講書人閱讀 31,512評論 1 262
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留瘤泪,地道東北人灶泵。 一個月前我還...
    沈念sama閱讀 45,536評論 2 354
  • 正文 我出身青樓,卻偏偏與公主長得像对途,于是被迫代替她去往敵國和親赦邻。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 42,828評論 2 345

推薦閱讀更多精彩內(nèi)容