tensorflow的圖基本操作和怎么通過restore操作tensorflow的圖

tenosflow的圖操作比較重要

通過圖形操作可以讓對圖有跟進一步了解
上一個簡單的訓(xùn)練的代碼

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

#載入數(shù)據(jù)集
mnist = input_data.read_data_sets(r'E:\python\mnist_data', one_hot=True)
#每個批次100張照片
batch_size = 100
#計算一共有多少個批次
n_batch = mnist.train.num_examples // batch_size

#定義兩個placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])

#創(chuàng)建一個簡單的神經(jīng)網(wǎng)絡(luò)人断,輸入層784個神經(jīng)元,輸出層10個神經(jīng)元
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b)

#二次代價函數(shù)
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

#初始化變量
init = tf.global_variables_initializer()

#結(jié)果存放在一個布爾型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一維張量中最大的值所在的位置
#求準確率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(11):
        for batch in range(n_batch):
            batch_xs,batch_ys =  mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
        
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
    #保存模型
    saver.save(sess,'net/my_net.ckpt')

結(jié)果是:

0.8241

再把圖restore

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file('net/my_net.ckpt', None, True,True)
#載入數(shù)據(jù)集
mnist = input_data.read_data_sets(r'E:\python\mnist_data', one_hot=True)

#每個批次100張照片
batch_size = 100
#計算一共有多少個批次
n_batch = mnist.train.num_examples // batch_size

#定義兩個placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])

#創(chuàng)建一個簡單的神經(jīng)網(wǎng)絡(luò)垒探,輸入層784個神經(jīng)元虫溜,輸出層10個神經(jīng)元
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b)

#二次代價函數(shù)
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

#初始化變量
init = tf.global_variables_initializer()

#結(jié)果存放在一個布爾型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一維張量中最大的值所在的位置
#求準確率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    print(sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
    saver.restore(sess,'net/my_net.ckpt')
    print(sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))

結(jié)果是:

0.8241

高潮來了

上面的restore需要再寫一遍圖就是
with tf.Session() as sess:前面所有的代碼:下面這的不需要

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
reader = tf.train.NewCheckpointReader('net/my_net.ckpt').get_variable_to_shape_map()
for variable  in reader:#遍歷變量的名稱和維度
    print(variable  列肢,reader[variable])
print_tensors_in_checkpoint_file('net/my_net.ckpt', None, True,True)
#載入數(shù)據(jù)集
mnist = input_data.read_data_sets(r'E:\python\mnist_data', one_hot=True)

saver = tf.train.import_meta_graph('net/my_net.ckpt.meta')
with tf.Session() as sess:
    saver.restore(sess,'net/my_net.ckpt')
    g =  tf.get_default_graph()#獲取圖
    op = g.get_operations()#獲取圖中的操作,主要是為了查看參數(shù)名字(為查看未命名參數(shù))
#    print(g,op)
    sofmax = tf.get_default_graph().get_tensor_by_name('Softmax_5:0')
    x = tf.get_default_graph().get_tensor_by_name('Placeholder:0')
    accuracy = tf.get_default_graph().get_tensor_by_name('Mean_1:0')
    y = tf.get_default_graph().get_tensor_by_name('Placeholder_1:0')
    print(sofmax,x)
    print(sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))

結(jié)果是:

0.8241

你需要通過op找到sofmax 员萍,x靡羡,accuracy瞒大,y的tensor的名稱螃征,再 用tf.get_default_graph().get_tensor_by_name('Mean_1:0')得到tensor實體,
這句saver = tf.train.import_meta_graph('net/my_net.ckpt.meta')直接加載圖透敌,這樣極大簡化了restore編寫盯滚。
當(dāng)然你在編寫代碼時給sofmax ,x酗电,accuracy魄藕,y取一個方便的名字會更加方便。
這句print_tensors_in_checkpoint_file('net/my_net.ckpt', None, True,True)是查看checkpoint的參數(shù)撵术,也就是weight背率。

還有一個save成一個文件的方法,保存為pb文件:

看了(公輸睚信)的博客改成了下面的格式

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework import graph_util

#載入數(shù)據(jù)集
mnist = input_data.read_data_sets(r'E:\python\mnist_data', one_hot=True)
#每個批次100張照片
batch_size = 100
#計算一共有多少個批次
n_batch = mnist.train.num_examples // batch_size

#定義兩個placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])

#創(chuàng)建一個簡單的神經(jīng)網(wǎng)絡(luò),輸入層784個神經(jīng)元寝姿,輸出層10個神經(jīng)元
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b) 

#二次代價函數(shù)
# loss = tf.reduce_mean(tf.square(y-prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

#初始化變量
init = tf.global_variables_initializer()

#結(jié)果存放在一個布爾型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一維張量中最大的值所在的位置
#求準確率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

#saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(1):
        for batch in range(n_batch):
            batch_xs,batch_ys =  mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
        
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
    #保存模型
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['Softmax','Placeholder','Mean'] )
    with tf.gfile.GFile('net/my_net.pb', 'wb') as fid:
        serialized_graph = output_graph_def.SerializeToString()
        fid.write(serialized_graph)

restore已經(jīng)保存的pb文件:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
#載入數(shù)據(jù)集
mnist = input_data.read_data_sets(r'E:\python\mnist_data', one_hot=True)
import os

def load_model(path_to_model):
    if not os.path.exists(path_to_model):
        raise ValueError("'path_to_model.pb' is not exist.")

    model_graph = tf.Graph()
    with model_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(path_to_model, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    return model_graph
model_graph = load_model('net/my_net.pb')
with model_graph.as_default():
    with tf.Session(graph=model_graph) as sess:
        sofmax = model_graph.get_tensor_by_name('Softmax:0')
        x= model_graph.get_tensor_by_name('Placeholder:0')
        accuracy= model_graph.get_tensor_by_name('Mean:0')
        y = model_graph.get_tensor_by_name('Placeholder_1:0')
        print(sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末交排,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子饵筑,更是在濱河造成了極大的恐慌个粱,老刑警劉巖,帶你破解...
    沈念sama閱讀 221,198評論 6 514
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件翻翩,死亡現(xiàn)場離奇詭異,居然都是意外死亡稻薇,警方通過查閱死者的電腦和手機嫂冻,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 94,334評論 3 398
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來塞椎,“玉大人桨仿,你說我怎么就攤上這事“负荩” “怎么了服傍?”我有些...
    開封第一講書人閱讀 167,643評論 0 360
  • 文/不壞的土叔 我叫張陵,是天一觀的道長骂铁。 經(jīng)常有香客問我吹零,道長,這世上最難降的妖魔是什么拉庵? 我笑而不...
    開封第一講書人閱讀 59,495評論 1 296
  • 正文 為了忘掉前任灿椅,我火速辦了婚禮,結(jié)果婚禮上钞支,老公的妹妹穿的比我還像新娘茫蛹。我一直安慰自己,他們只是感情好烁挟,可當(dāng)我...
    茶點故事閱讀 68,502評論 6 397
  • 文/花漫 我一把揭開白布婴洼。 她就那樣靜靜地躺著,像睡著了一般撼嗓。 火紅的嫁衣襯著肌膚如雪柬采。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 52,156評論 1 308
  • 那天且警,我揣著相機與錄音警没,去河邊找鬼。 笑死振湾,一個胖子當(dāng)著我的面吹牛杀迹,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 40,743評論 3 421
  • 文/蒼蘭香墨 我猛地睜開眼树酪,長吁一口氣:“原來是場噩夢啊……” “哼浅碾!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起续语,我...
    開封第一講書人閱讀 39,659評論 0 276
  • 序言:老撾萬榮一對情侶失蹤垂谢,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后疮茄,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體滥朱,經(jīng)...
    沈念sama閱讀 46,200評論 1 319
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 38,282評論 3 340
  • 正文 我和宋清朗相戀三年力试,在試婚紗的時候發(fā)現(xiàn)自己被綠了徙邻。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 40,424評論 1 352
  • 序言:一個原本活蹦亂跳的男人離奇死亡畸裳,死狀恐怖缰犁,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情怖糊,我是刑警寧澤帅容,帶...
    沈念sama閱讀 36,107評論 5 349
  • 正文 年R本政府宣布,位于F島的核電站伍伤,受9級特大地震影響并徘,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜扰魂,卻給世界環(huán)境...
    茶點故事閱讀 41,789評論 3 333
  • 文/蒙蒙 一饮亏、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧阅爽,春花似錦路幸、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,264評論 0 23
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至百侧,卻和暖如春砰识,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背佣渴。 一陣腳步聲響...
    開封第一講書人閱讀 33,390評論 1 271
  • 我被黑心中介騙來泰國打工辫狼, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人辛润。 一個月前我還...
    沈念sama閱讀 48,798評論 3 376
  • 正文 我出身青樓膨处,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子真椿,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 45,435評論 2 359

推薦閱讀更多精彩內(nèi)容

  • 這篇文章是針對有tensorflow基礎(chǔ)但是記不住復(fù)雜變量函數(shù)的讀者鹃答,文章列舉了從輸入變量到前向傳播,反向優(yōu)化突硝,數(shù)...
    horsetif閱讀 1,176評論 0 1
  • 在這篇tensorflow教程中测摔,我會解釋: 1) Tensorflow的模型(model)長什么樣子? 2) 如...
    JunsorPeng閱讀 3,429評論 1 6
  • 該文章為轉(zhuǎn)載文章解恰,作者簡介:汪劍锋八,現(xiàn)在在出門問問負責(zé)推薦與個性化。曾在微軟雅虎工作护盈,從事過搜索和推薦相關(guān)工作挟纱。 T...
    名字真的不重要閱讀 5,287評論 0 3
  • 卜算子*小暑逢七七感懷 文/無痕 小暑雨無常, 伏日云多疊黄琼。 又遇盧溝七七恥, 怒火燒心熱整慎。 石獅點點痕脏款, 橋路蕭...
    閑扯者閱讀 597評論 0 1
  • 送友人 土黃的葉子 輕舞飛下 是流年的回首 傳奇了季節(jié)的神話 刻畫你的額頭 梳染你的白發(fā) 多想多想 把殘陽摘下 放...
    酸黃連閱讀 463評論 0 0