Tensorflow 訓(xùn)練好的模型保存和載入

方法一 這種存儲(chǔ)方式在加載模型時(shí)需要再次定義網(wǎng)絡(luò)結(jié)構(gòu)

模型訓(xùn)練和存儲(chǔ)

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os

mnist = input_data.read_data_sets("/home/devops/test/TensorFlowOnSpark/mnist/",one_hot=True)
print (mnist)

learning_rate = 0.01
training_epochs = 5
batch_size = 100
display_step = 1

X = tf.placeholder(tf.float32,[None,784])
Y = tf.placeholder(tf.float32,[None,10])

W = tf.Variable(tf.zeros([784,10]),name="W")
b = tf.Variable(tf.zeros([10]),name="b")

pred = tf.nn.softmax(tf.matmul(X,W) + b)
cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(pred), reduction_indices =1))

optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()

##初始化存儲(chǔ)器和存儲(chǔ)路徑
saver = tf.train.Saver(max_to_keep=4)
model_path = "./model/lr"
path = os.path.dirname(os.path.abspath(model_path))
if os.path.isdir(path) is False:
    os.makedirs(path)

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(training_epochs):
        avg_cost = 0
        total_batch = int(mnist.train.num_examples/batch_size)
        for i in range(total_batch):
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            _,c = sess.run([optimizer,cost],feed_dict={X:batch_xs,Y:batch_ys})
            avg_cost += c / total_batch
        if (epoch + 1) % display_step == 0:
            print ("Epoch:","%04d" % (epoch + 1),"cost=","{}".format(avg_cost))
        saver.save(sess,model_path,write_meta_graph=True)
    print ("Optimization Finished")

    correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(Y,1))
    accuracy  = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    print ("Accuracy:",accuracy.eval({X:mnist.test.images[:3000],Y:mnist.test.labels[:3000]}))

加載模型

import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_data


mnist = input_data.read_data_sets("/home/devops/test/TensorFlowOnSpark/mnist/",one_hot=True)

X = tf.placeholder(tf.float32,[None,784])
Y = tf.placeholder(tf.float32,[None,10])

with tf.Session() as sess:
    saver = tf.train.import_meta_graph("/home/devops/test/TensorFlowOnSpark/examples/mnist/my/curve/model/lr.meta")
    saver.restore(sess,tf.train.latest_checkpoint("/home/devops/test/TensorFlowOnSpark/examples/mnist/my/curve/model"))
    graph = tf.get_default_graph()
    W = graph.get_tensor_by_name("W:0")
    b = graph.get_tensor_by_name("b:0")
    pred = tf.nn.softmax(tf.matmul(X,W) + b)

    correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(Y,1))

    accuracy  = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

    print ("Accuracy:",accuracy.eval({X:mnist.test.images[:3000],Y:mnist.test.labels[:3000]}))

方法二 這種存儲(chǔ)方式在加載模型時(shí)不用定義網(wǎng)絡(luò)結(jié)構(gòu)

模型訓(xùn)練和存儲(chǔ)

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os

mnist = input_data.read_data_sets("/home/devops/test/TensorFlowOnSpark/mnist/",one_hot=True)
print (mnist)

learning_rate = 0.01
training_epochs = 5
batch_size = 100
display_step = 1

X = tf.placeholder(tf.float32,[None,784],name="X")
Y = tf.placeholder(tf.float32,[None,10],name="Y")

W = tf.Variable(tf.zeros([784,10]),name="W")
b = tf.Variable(tf.zeros([10]),name="b")

pred = tf.nn.softmax(tf.matmul(X,W) + b)
cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(pred), reduction_indices =1))

optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()

saver = tf.train.Saver(max_to_keep=4)

##把要加載的對(duì)象提前加入集合
tf.add_to_collection("pred",pred)

model_path = "./model/lr"
path = os.path.dirname(os.path.abspath(model_path))
if os.path.isdir(path) is False:
    os.makedirs(path)

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(training_epochs):
        avg_cost = 0
        total_batch = int(mnist.train.num_examples/batch_size)
        for i in range(total_batch):
            batch_xs,batch_ys = mnist.train.next_batch(batch_size)
            _,c = sess.run([optimizer,cost],feed_dict={X:batch_xs,Y:batch_ys})
            avg_cost += c / total_batch
        if (epoch + 1) % display_step == 0:
            print ("Epoch:","%04d" % (epoch + 1),"cost=","{}".format(avg_cost))
        saver.save(sess,model_path,write_meta_graph=True)
    print ("Optimization Finished")

    correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(Y,1))
    accuracy  = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    print ("Accuracy:",accuracy.eval({X:mnist.test.images[:3000],Y:mnist.test.labels[:3000]}))

模型加載

import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/home/devops/test/TensorFlowOnSpark/mnist/",one_hot=True)

with tf.Session() as sess:
    saver = tf.train.import_meta_graph("/home/devops/test/TensorFlowOnSpark/examples/mnist/my/curve/model/lr.meta")
    saver.restore(sess,tf.train.latest_checkpoint("/home/devops/test/TensorFlowOnSpark/examples/mnist/my/curve/model"))

    pred = tf.get_collection("pred")[0]
    graph = tf.get_default_graph()

    X = graph.get_operation_by_name("X").outputs[0]
    Y = graph.get_operation_by_name("Y").outputs[0]

    correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(Y,1))
    accuracy  = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    print ("Accuracy:",accuracy.eval({X:mnist.test.images[:300],Y:mnist.test.labels[:300]}))

1.Tensorflow模型文件的組成

    主要包含兩個(gè)文件
  1. 元圖 meta graph

保存完整的圖結(jié)構(gòu) 包含所有的變量 操作等 擴(kuò)展名為meta

2.檢查點(diǎn)文件 checkpoint

二進(jìn)制文件 包含所有的權(quán)重 偏差 梯度和其他所有保存的值 擴(kuò)展名是.ckpt , 0.11版本之后不再僅使用一個(gè).ckpt文件來表示了 而是兩個(gè)文件 .data-00000-of-00001 和.index

其中.data 是包含訓(xùn)練變量的文件

此外還有一個(gè)名為checkpoint的文件 用于保存最新檢查點(diǎn)的記錄

2.如何保存Tensorflow模型

在模型訓(xùn)練完成后 可調(diào)用tf.train.Saver()實(shí)例來保存所有的參數(shù)和計(jì)算圖

由于tensorflow中的變量只能存在于session中,因此需要在session中調(diào)用save 將模型存儲(chǔ)


import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2],name='w1’))

w2 = tf.Variable(tf.random_normal(shape=[5]),name=‘w2’)

saver = tf.train.Saver()

with tf.Session() as sess:

    sess.run(tf.global_variable_initialize())

saver.save(sess,’/path/to/save') 

運(yùn)行后可得以下文件:

model/

├── checkpoint

├── my_test_model.data-00000-of-00001

├── my_test_model.index

└── my_test_model.meta

如果想在1000次迭代之后再保存模型押赊,可通過傳遞步數(shù)來調(diào)用save

saver.save(sess,’model_path’,global_step=1000)

image.png

如果想每1000次保存一下模型,由于.meta文件會(huì)在第一次保存時(shí)創(chuàng)建 而且圖結(jié)構(gòu)不會(huì)再變化,因此只需要保存模型進(jìn)一步迭代的數(shù)據(jù) 而不用存儲(chǔ)網(wǎng)絡(luò)結(jié)構(gòu) 可調(diào)用

saver.save(sess,’model_path’,global_step=step,write_meta_graph=False)

如果只想保存最新的4個(gè)模型參數(shù)梭姓,并且希望在訓(xùn)練階段每?jī)尚r(shí)保存一個(gè)模型误辑,可調(diào)用

saver = tf.train.Saver(max_to_keep=4,keep_checkpoint_every_n_hours=2)

如果在tf.train.Saver() 中沒有指定任何東西崭孤,那么他會(huì)保存模型的所有變量胀溺,如果只想保存部分變量則需要通過列表或字典的形式將變量傳遞進(jìn)去

import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2],name='w1’))

w2 = tf.Variable(tf.random_normal(shape=[5]),name=‘w2’)

saver = tf.train.Saver([w1,w2])

with tf.Session() as sess:

    sess.run(tf.global_variable_initializer())

    sess.run(tf.global_variable_initialize())

    saver.save(sess,’/path/to/save')

3.如何導(dǎo)入一個(gè)訓(xùn)練好的模型并進(jìn)行修改和微調(diào)

需要完成兩件事情

1.構(gòu)建網(wǎng)絡(luò)結(jié)構(gòu)

可通過手動(dòng)編寫代碼創(chuàng)建每一層網(wǎng)絡(luò)結(jié)構(gòu)來重構(gòu)整個(gè)網(wǎng)絡(luò)

保存模型時(shí)也會(huì)將網(wǎng)絡(luò)結(jié)構(gòu)存儲(chǔ)到.meta文件中厘唾,可直接調(diào)用tf.train.import()函數(shù)來導(dǎo)入這個(gè)模型

saver = tf.train.import_meta_graph(‘model-1000.meta’) 這個(gè)操作是將.meta文件中的計(jì)算圖數(shù)據(jù)直接附加到當(dāng)前定義的圖中介杆,但是我們?nèi)匀恍枰ゼ虞d計(jì)算圖上所有已經(jīng)訓(xùn)練好的權(quán)重參數(shù)

2.加載參數(shù)

new_saver.restore(sess,tf.train.latest_checkpoint('./‘))   checkpoint文件所在路徑

with tf.Session() as sess:

new_saver = tf.train.import_meta_graph(‘my_test_model-1000.meta’) 

new_saver.restore(sess,tf.train.latest_checkpoint('./')) 
#讀取參數(shù) 
print(sess.run(‘w1:0')) 

4.恢復(fù)任何預(yù)先訓(xùn)練好的模型用于預(yù)測(cè) (工作中的開發(fā)方式)

import tensorflow as tf

w1 = tf.placeholder(‘float’,name='w1’)

w2 = tf.placeholder(‘float’,name=‘w2ww’)

b1 = tf.Variable(2.0,name=‘bias’)

w3 = tf.add(w1,w2)

w4 = tf.multiply(w3,b1,name=‘op_to_restore’)

saver = tf.train.Saver()

with tf.Session() as sess:

  sess.run(tf.global_variables_initializer()) 

  print (sess.run(24,feed_dict={w1:4,w2:8})) 

  saver.save(sess,’test_model’,global_step=1000) 

當(dāng)需要載入這個(gè)模型時(shí)鹃操,不僅需要恢復(fù)所有的計(jì)算圖和權(quán)重參數(shù) 還需要準(zhǔn)備一個(gè)新的feed_dict

用于將新的訓(xùn)練數(shù)據(jù)傳送到網(wǎng)絡(luò)中進(jìn)行訓(xùn)練,可通過graph.get_tensor_by_name() 來獲得對(duì)這些保存的操作和占位符變量的引用

w1 = graph.get_tensor_by_name(‘w1:0’)

op_to_restore = graph.get_tensor_by_name(“op_to_restore:0”)

使用不同的數(shù)據(jù)來運(yùn)行相同的網(wǎng)絡(luò) 則需要通過feed_dict來傳遞數(shù)據(jù)

with tf.Session() as sess:

  saver = tf.train.import_meta_graph(’test_model-1000.meta’)

  saver.restore(sess,tf.train.latest_checkpoint(‘./‘)) 

  graph = tf.get_default_graph() 

  w1 = graph.get_tensor_by_name(“w1:0”) 

  w2 = graph.get_tensor_by_name(“w2:0”) 

  feed_dict= {w1:13.0,w2:17.0} 

  op_to_restore = graph.get_tensor_by_name(‘op_to_restore:0’) 

  print (sess.run(op_to_restore,feed_dict)) 

如果想在原來的計(jì)算圖基礎(chǔ)上添加更多的操作和圖層春哨,并進(jìn)行訓(xùn)練

import tensorflow as tf

with tf.Session() as sess:

    saver = tf.train.import_meta_graph(‘my_test_model-1000.meta’)

    saver.restore(sess,tf.train.latest_checkpoint(‘./‘)) 

    graph = tf.get_default_graph() 

    w1 = graph.get_tensor_by_name(‘w1:0’) 

    w2 = graph.get_tensor_by_name(‘w2:0’) 

    feed_dict = {w1:13,w2:17} 

    op_to_restore = graph.get_tensor_by_name(‘op_to_restore:0’) 

    #新添加操作 

    add_on_op = tf.multiply(op_to_restore,2) 

    print (sess.run(add_on_op,feed_dict))
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末荆隘,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子悲靴,更是在濱河造成了極大的恐慌臭胜,老刑警劉巖莫其,帶你破解...
    沈念sama閱讀 216,544評(píng)論 6 501
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異耸三,居然都是意外死亡乱陡,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,430評(píng)論 3 392
  • 文/潘曉璐 我一進(jìn)店門仪壮,熙熙樓的掌柜王于貴愁眉苦臉地迎上來憨颠,“玉大人,你說我怎么就攤上這事积锅∷” “怎么了?”我有些...
    開封第一講書人閱讀 162,764評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵缚陷,是天一觀的道長(zhǎng)适篙。 經(jīng)常有香客問我,道長(zhǎng)箫爷,這世上最難降的妖魔是什么嚷节? 我笑而不...
    開封第一講書人閱讀 58,193評(píng)論 1 292
  • 正文 為了忘掉前任,我火速辦了婚禮虎锚,結(jié)果婚禮上硫痰,老公的妹妹穿的比我還像新娘。我一直安慰自己窜护,他們只是感情好效斑,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,216評(píng)論 6 388
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著柱徙,像睡著了一般缓屠。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上护侮,一...
    開封第一講書人閱讀 51,182評(píng)論 1 299
  • 那天藏研,我揣著相機(jī)與錄音,去河邊找鬼概行。 笑死蠢挡,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的凳忙。 我是一名探鬼主播业踏,決...
    沈念sama閱讀 40,063評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼涧卵!你這毒婦竟也來了勤家?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 38,917評(píng)論 0 274
  • 序言:老撾萬榮一對(duì)情侶失蹤柳恐,失蹤者是張志新(化名)和其女友劉穎伐脖,沒想到半個(gè)月后热幔,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,329評(píng)論 1 310
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡讼庇,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,543評(píng)論 2 332
  • 正文 我和宋清朗相戀三年绎巨,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片蠕啄。...
    茶點(diǎn)故事閱讀 39,722評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡场勤,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出歼跟,到底是詐尸還是另有隱情和媳,我是刑警寧澤,帶...
    沈念sama閱讀 35,425評(píng)論 5 343
  • 正文 年R本政府宣布哈街,位于F島的核電站留瞳,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏骚秦。R本人自食惡果不足惜撼港,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,019評(píng)論 3 326
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望骤竹。 院中可真熱鬧,春花似錦往毡、人聲如沸蒙揣。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,671評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)懒震。三九已至,卻和暖如春嗤详,著一層夾襖步出監(jiān)牢的瞬間个扰,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 32,825評(píng)論 1 269
  • 我被黑心中介騙來泰國(guó)打工葱色, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留递宅,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 47,729評(píng)論 2 368
  • 正文 我出身青樓苍狰,卻偏偏與公主長(zhǎng)得像办龄,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子淋昭,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,614評(píng)論 2 353

推薦閱讀更多精彩內(nèi)容

  • Swift1> Swift和OC的區(qū)別1.1> Swift沒有地址/指針的概念1.2> 泛型1.3> 類型嚴(yán)謹(jǐn) 對(duì)...
    cosWriter閱讀 11,097評(píng)論 1 32
  • 書名貪婪的大腦:為何人類會(huì)無止境地尋求意義作者(英)丹尼爾·博爾(Daniel Bor)譯者林旭文豆瓣http:/...
    xuwensheng閱讀 15,226評(píng)論 8 54
  • 窗外俐填,下著雨。 今天找了一個(gè)很好的理由溜出去聽講座了翔忽,沒有學(xué)到什么東西英融,反而是看透了某些東西盏檐。我又...
    書言菡語(yǔ)閱讀 150評(píng)論 0 0
  • 被羅胖安利了好幾本書胡野,好些沒看完從去年拖到了今年,最近在看《聯(lián)盟》撩银,對(duì)書中提到的公司不是家的觀念深以為然给涕。 “我們...
    超級(jí)瑪小麗閱讀 9,520評(píng)論 56 171