Tensorflow 模型的保存和導(dǎo)入

我們用上一節(jié)讀取tfrecord文件得到的數(shù)據(jù)來建立一個(gè)簡(jiǎn)單的二分類神經(jīng)網(wǎng)絡(luò)染服。進(jìn)而介紹模型的保存和導(dǎo)入方法

# 簡(jiǎn)單的二分類神經(jīng)網(wǎng)絡(luò):?jiǎn)坞[藏層
# 每個(gè)樣本有5個(gè)features
# 隱藏層有10個(gè)神經(jīng)元,一個(gè)輸出




# -*- coding: UTF-8 -*-
#!/usr/bin/python3

# Env: python3.6

import tensorflow as tf
import numpy as np
import os

# path
data_filename = 'data/data_train.txt'
size = (10000, 5)
tfrecord_path = 'data/test_data.tfrecord'
# tfrecord_path2 = 'data/test_data2.tfrecord'
# generate data 10000*5, label: 0 or 1
# generate tfrecord named test_data.tfrecord.
def generate_data(data_filename = data_filename, size=size):
    if not os.path.exists(data_filename):
        np.random.seed(9)
        x_data = np.random.randint(0, 10, size = size)
        y1_data = np.ones((size[0]//2, 1), int)
        y2_data = np.zeros((size[0]//2, 1), int)
        y_data = np.append(y1_data, y2_data)
        np.random.shuffle(y_data)

        # stitching together x and y in one file
        xy_data = str('')
        for xy_row in range(len(x_data)):
            x_str = str('')
            for xy_col in range(len(x_data[0])):
                if not xy_col == (len(x_data[0])-1):
                    x_str =x_str+str(x_data[xy_row, xy_col])+' '
                else:
                    x_str = x_str + str(x_data[xy_row, xy_col])
            y_str = str(y_data[xy_row])
            xy_data = xy_data+(x_str+'/'+y_str + '\n')
        #print(xy_data[1])

        # write to txt
        write_txt = open(data_filename, 'w')
        write_txt.write(xy_data)
        write_txt.close()
    return

# obtain data from txt
# every line of data is just as follow: 1 2 3 4 5/1. train data: 1 2 3 4 5, label: 1
def txt_to_tfrecord(txt_filename = data_filename, tfrecord_path = tfrecord_path):
    # 第一步:生成TFRecord Writer
    writer = tf.python_io.TFRecordWriter(tfrecord_path)

    # 第二步:讀取TXT數(shù)據(jù),并分割出樣本數(shù)據(jù)和標(biāo)簽
    file = open(txt_filename)
    for data_line in file.readlines(): # 每一行
        data_line = data_line.strip('\n') # 去掉換行符
        sample = []
        spls = data_line.split('/', 1)[0]# 樣本
        for m in spls.split(' '):
            sample.append(int(m))
        label = data_line.split('/', 1)[1]# 標(biāo)簽
        label = int(label)
        print('sample:', sample, 'labels:', label)

        # 第三步: 建立feature字典,tf.train.Feature()對(duì)單一數(shù)據(jù)編碼成feature
        feature = {'sample': tf.train.Feature(int64_list=tf.train.Int64List(value=sample)),
                   'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}
        # 第四步:可以理解為將內(nèi)層多個(gè)feature的字典數(shù)據(jù)再編碼腻扇,集成為features
        features = tf.train.Features(feature = feature)
        # 第五步:將features數(shù)據(jù)封裝成特定的協(xié)議格式
        example = tf.train.Example(features=features)
        # 第六步:將example數(shù)據(jù)序列化為字符串
        Serialized = example.SerializeToString()
        # 第七步:將序列化的字符串?dāng)?shù)據(jù)寫入?yún)f(xié)議緩沖區(qū)
        writer.write(Serialized)
    # 記得關(guān)閉writer和open file的操作
    writer.close()
    file.close()
    return
#txt_to_tfrecord(data_filename, tfrecord_path2)

#  read tfrecord
def _parse_function(example_proto):
    dics = {  # 這里沒用default_value情竹,隨后的都是None
        'sample': tf.FixedLenFeature([5], tf.int64),  # 如果不是標(biāo)量届榄,一定要在這里說明數(shù)組的長(zhǎng)度
        'label': tf.FixedLenFeature([], tf.int64)}
    # 把序列化樣本和解析字典送入函數(shù)里得到解析的樣本
    parsed_example = tf.parse_single_example(example_proto, dics)

    parsed_example['sample'] = tf.cast(parsed_example['sample'], tf.float32)
    parsed_example['label'] = tf.cast(parsed_example['label'], tf.float32)

    # 返回所有feature
    return parsed_example

def read_dataset(tfrecord_path = tfrecord_path):
    # 聲明閱讀器
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    # 建立解析函數(shù)
    new_dataset = dataset.map(_parse_function)
    # 打亂樣本順序
    shuffle_dataset = new_dataset.shuffle(buffer_size=20000)
    # batch輸出
    batch_dataset = shuffle_dataset.batch(2)
    # 建立迭代器
    iterator = batch_dataset.make_one_shot_iterator()
    # 獲得下一個(gè)樣本
    next_element = iterator.get_next()
    x_samples = next_element['sample']
    y_labels = next_element['label']
    return x_samples, y_labels

def weight_bias_variable(weight_shape, bias_shape):
    weight = tf.get_variable('weight', weight_shape, initializer=tf.random_normal_initializer(mean=0, stddev=1))
    bias = tf.get_variable('bias', bias_shape, initializer=tf.random_normal_initializer(mean=0, stddev=1))
    return weight, bias

# neural network:
# input layer: 5 features with on sample
# one hidden layer: 10 neuron
# output: y_out

################      fetch data    ####################
with tf.variable_scope('input_data'):
    x_samples, y_labels = read_dataset()

with tf.variable_scope('hidden_layer1', reuse=tf.AUTO_REUSE):
    w1, b1 = weight_bias_variable(weight_shape=[5, 10], bias_shape=[10])
    y_hidden = tf.nn.relu(tf.matmul(x_samples, w1) + b1)
    tf.summary.histogram('w1', w1)
    tf.summary.histogram('b1', b1)

with tf.variable_scope('output_layer', reuse=tf.AUTO_REUSE):
    w2, b2 = weight_bias_variable(weight_shape=[10, 1], bias_shape=[1])
    y_out = tf.matmul(y_hidden, w2) + b2
    y_out = tf.reshape(y_out, [-1])
    tf.summary.histogram('w2', w2)
    tf.summary.histogram('b2', b2)
with tf.variable_scope('loss_function'):
    # ################     Loss Function
    #  這里的sigmoid是對(duì)y_out的激活函數(shù)
    loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=y_out, labels=y_labels, name=None)
    loss_mean = tf.reduce_mean(loss, 0)
    tf.summary.scalar('loss_mean', loss_mean)

    ################## BackPropagation
    # 創(chuàng)建基于梯度下降算法的Optimizer
    optimizer = tf.train.GradientDescentOptimizer(0.01)
    # 添加操作節(jié)點(diǎn)莉兰,用于最小化loss挑围,并更新var_list
    # 該函數(shù)是簡(jiǎn)單的合并了compute_gradients()與apply_gradients()函數(shù)
    # 返回為一個(gè)優(yōu)化更新后的var_list
    train = optimizer.minimize(loss_mean)
save_path = 'data/save/b2.txt'
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    # 建立tensorbord
    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter('data/tfboard', sess.graph)

    saver = tf.train.Saver()
    for i in range(2000):
        sess.run(train)
        summary = sess.run(merged)
        writer.add_summary(summary, i)

        if i % 1000 == 0:
            print(' ############### step = %d   ############     ' %i)
            print('b2: ', sess.run(b2))
            # 用官網(wǎng)介紹的checkpoint方式保存模型
            # 創(chuàng)建saver對(duì)象,默認(rèn)max_to_keep=5糖荒,保存最近5次的模型杉辙。
            saver.save(sess, 'data/tmp/model', global_step=1000) # 保存第1000步的模型

    # 將變量保存到文件(這里也可以創(chuàng)建字典,將所有變量寫成tfrecord文件
    # 用sess.run就是將tensor數(shù)據(jù)轉(zhuǎn)為python數(shù)據(jù)捶朵,然后進(jìn)行保存
    b2_save = sess.run(b2)
    print('TXT b2 save:', b2_save)
    np.savetxt(save_path, b2_save)

    writer.close()

    coord.request_stop()
    coord.join(threads)

##### checkpoint 恢復(fù)模型


# 為了區(qū)分蜘矢,我們?cè)俳⒁粋€(gè)session

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    # 恢復(fù)模型,這是一個(gè) protocol buffer保存了完整的Tensorflow圖综看,即所有變量品腹、操作和集合等。擁有一個(gè).meta
    last_ckpt = saver.last_checkpoints  # 得到保存模型的路徑
    saver_restore = tf.train.import_meta_graph(os.path.join(last_ckpt[0] + '.meta'))
    # 用 checkpoint 恢復(fù)模型參數(shù)
    saver_restore.restore(sess, last_ckpt[0]) # method 1
    print('methond1:ckpt: ', sess.run(b2))  # 要知道參數(shù)名
    saver.restore(sess, last_ckpt[0])  # method 2
    print('methond2:ckpt: ', sess.run(b2))  # 要知道參數(shù)名

    # 讀取TXT文檔恢復(fù)參數(shù)  # method 3
    b2_restore = np.loadtxt(save_path)
    b2_restore = tf.cast(b2_restore, tf.float32)  # numpy默認(rèn)float64而不是float32红碑,而TF中默認(rèn)時(shí)float32舞吭,才能用TF.RESHAPE()
    b2_restore = tf.reshape(b2_restore, [-1])  # b2是標(biāo)量,shape為[]析珊。要求tensor時(shí)必須給標(biāo)量擴(kuò)維度
    print('TXT b2_restore:', sess.run(b2.assign(b2_restore)))
    # 或者 寫成:
    # print('TXT b2_restore:', sess.run(tf.assign(b2, b2_restore)))
    print(sess.run(b2))

tf.summary()可以利用Tensorboard將網(wǎng)絡(luò)可視化羡鸥,在終端輸入summary的路徑:

$ tensorboard --logdir=/Users/username/PycharmProjects/firsttensorflow/readtf/data/tfboard

終端輸出為:TensorBoard 1.9.0 at http://pc-171-10-100-190.cm.vtr.net:6006 (Press CTRL+C to quit)

我們需要在網(wǎng)頁(yè)中輸入:http://localhost:6006
tensorboard默認(rèn)是在scalars(標(biāo)量)界面:


這里顯示的是損失函數(shù)在2000次訓(xùn)練過程中的變化趨勢(shì)。(因?yàn)槭请S機(jī)生成的數(shù)據(jù)忠寻,這里網(wǎng)絡(luò)性能并沒有考慮)

切換到GRAPHS界面惧浴,可以看到網(wǎng)絡(luò)結(jié)構(gòu):

可以雙擊每個(gè)scope查看里面的tensor flow和operation。

傳送:

  1. 神經(jīng)網(wǎng)絡(luò)激活函數(shù)
  2. tensorflow中常用的神將網(wǎng)絡(luò)函數(shù)
  3. 梯度下降算法
  4. checkpoint方法保存加載模型的Blog
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末锡溯,一起剝皮案震驚了整個(gè)濱河市赶舆,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌祭饭,老刑警劉巖芜茵,帶你破解...
    沈念sama閱讀 218,755評(píng)論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異倡蝙,居然都是意外死亡九串,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,305評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門寺鸥,熙熙樓的掌柜王于貴愁眉苦臉地迎上來猪钮,“玉大人,你說我怎么就攤上這事胆建】镜停” “怎么了?”我有些...
    開封第一講書人閱讀 165,138評(píng)論 0 355
  • 文/不壞的土叔 我叫張陵笆载,是天一觀的道長(zhǎng)扑馁。 經(jīng)常有香客問我涯呻,道長(zhǎng),這世上最難降的妖魔是什么腻要? 我笑而不...
    開封第一講書人閱讀 58,791評(píng)論 1 295
  • 正文 為了忘掉前任复罐,我火速辦了婚禮,結(jié)果婚禮上雄家,老公的妹妹穿的比我還像新娘效诅。我一直安慰自己,他們只是感情好趟济,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,794評(píng)論 6 392
  • 文/花漫 我一把揭開白布乱投。 她就那樣靜靜地躺著,像睡著了一般咙好。 火紅的嫁衣襯著肌膚如雪篡腌。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,631評(píng)論 1 305
  • 那天勾效,我揣著相機(jī)與錄音,去河邊找鬼叛甫。 笑死层宫,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的其监。 我是一名探鬼主播萌腿,決...
    沈念sama閱讀 40,362評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼抖苦!你這毒婦竟也來了毁菱?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,264評(píng)論 0 276
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤锌历,失蹤者是張志新(化名)和其女友劉穎贮庞,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體究西,經(jīng)...
    沈念sama閱讀 45,724評(píng)論 1 315
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡窗慎,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,900評(píng)論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了卤材。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片遮斥。...
    茶點(diǎn)故事閱讀 40,040評(píng)論 1 350
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖扇丛,靈堂內(nèi)的尸體忽然破棺而出术吗,到底是詐尸還是另有隱情,我是刑警寧澤帆精,帶...
    沈念sama閱讀 35,742評(píng)論 5 346
  • 正文 年R本政府宣布较屿,位于F島的核電站材蹬,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏吝镣。R本人自食惡果不足惜堤器,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,364評(píng)論 3 330
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望末贾。 院中可真熱鬧闸溃,春花似錦、人聲如沸拱撵。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,944評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)拴测。三九已至乓旗,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間集索,已是汗流浹背屿愚。 一陣腳步聲響...
    開封第一講書人閱讀 33,060評(píng)論 1 270
  • 我被黑心中介騙來泰國(guó)打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留务荆,地道東北人妆距。 一個(gè)月前我還...
    沈念sama閱讀 48,247評(píng)論 3 371
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像函匕,于是被迫代替她去往敵國(guó)和親娱据。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,979評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容

  • 文章主要分為:一盅惜、深度學(xué)習(xí)概念中剩;二、國(guó)內(nèi)外研究現(xiàn)狀抒寂;三结啼、深度學(xué)習(xí)模型結(jié)構(gòu);四蓬推、深度學(xué)習(xí)訓(xùn)練算法妆棒;五、深度學(xué)習(xí)的優(yōu)點(diǎn)...
    艾剪疏閱讀 21,835評(píng)論 0 58
  • 與 TensorFlow 的初次相遇 https://jorditorres.org/wp-content/upl...
    布客飛龍閱讀 3,948評(píng)論 2 89
  • 如果叫你說岀一件最需要意志力的事沸伏,你第一個(gè)想到的是什么糕珊?,對(duì)大多數(shù)人來說毅糟,最大的考驗(yàn)?zāi)^于扺制誘惑红选,扺制來自甜甜圈...
    鏟屎官88閱讀 196評(píng)論 0 0
  • 周一食譜 周二食譜 周三食譜 周四食譜 周五食譜
    高冷楠閱讀 100評(píng)論 0 0