TensorFlow高階API Estimator自定義模型解決圖像分類問題

在之前的文章中利职,我們利用silm工具和谷歌訓(xùn)練好的inception-v3模型完成了一個(gè)花朵圖像分類問題,但代碼還是比較繁瑣。為了更精簡(jiǎn)的代碼和提高可讀性作箍,這一次我們利用TensorFlow提供的高階API Estimator來解決同樣的問題坑质。同時(shí)合武,在最后临梗,我們會(huì)把訓(xùn)練過程中的參數(shù)變化通過TensorBoard展示出來。

Estimator

Estimator是TensorFlow官方提供的一個(gè)高層API稼跳,它更好的整合了原生態(tài)TensorFlow提供的功能盟庞。它可以極大簡(jiǎn)化機(jī)器學(xué)習(xí)編程。下面來看一下TensorFlow API結(jié)構(gòu):


API Architecture

在官方文檔中汤善,有這么一句話:

We strongly recommend writing TensorFlow programs with the following APIs:

  • Estimators, which represent a complete model. The Estimator API provides methods to train the model, to judge the model's accuracy, and to generate predictions.
  • Datasets for Estimators, which build a data input pipeline. The Dataset API has methods to load and manipulate data, and feed it into your model. The Dataset API meshes well with the Estimators API.

可以看到Estimator和Dataset這兩個(gè)API是官方強(qiáng)烈推薦的什猖。Estimator提供了預(yù)創(chuàng)建的DNN模型,使用起來非常方便红淡。具體怎么使用Estimator預(yù)創(chuàng)建模型不狮,官方文檔里面也有寫,有興趣的可以去看Estimator官方在旱。
但是預(yù)先定義的Estimator功能有限摇零,比如目前無法很好的實(shí)現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)和循環(huán)神經(jīng)網(wǎng)絡(luò),也沒有辦法支持自定義的損失函數(shù)桶蝎,所以為了更好的使用Estimator驻仅,這篇文章會(huì)教大家怎么用Estimator自定義CNN模型,以及如何配合Dataset讀取圖片數(shù)據(jù)登渣。

數(shù)據(jù)準(zhǔn)備

在這里我們可以使用之前的谷歌提供的花朵分類數(shù)據(jù)集噪服,也可以使用其它的。為了區(qū)分上次結(jié)果這次我們使用新的數(shù)據(jù)集胜茧。在這里我使用百度挑桃分類數(shù)據(jù)集芯咧。下載解壓后可以看到是這樣的目錄:
數(shù)據(jù)集

數(shù)據(jù)集已經(jīng)幫我們劃分好了是訓(xùn)練還是測(cè)試。每一個(gè)文件夾代表一種桃子竹揍,總共有4種桃子(這個(gè)數(shù)據(jù)集肉眼很難辨別敬飒,可能是因?yàn)槲也粔驅(qū)I(yè)-_-)。

數(shù)據(jù)預(yù)處理

我們還是像之前一樣對(duì)數(shù)據(jù)預(yù)處理芬位。在工程目錄下新建select_peach_data.py文件无拗。跟之前處理花朵分類的時(shí)候一樣所以這里直接粘貼代碼:

import glob
import os.path
import numpy as np
import tensorflow as tf

from tensorflow.python.platform import gfile

#輸入圖片地址
INPUT_ALL_DATA = './select_peach'
INPUT_TRAIN_DATA = './select_peach/train'
INPUT_TEST_DATA = './select_peach/test'
OUTPUT_TRAIN_FILE = './path/to/output_train.tfrecords'
OUTPUT_TEST_FILE = './path/to/output_test.tfrecords'

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

#生成字符串的屬性
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

#檢索目錄并提取目錄圖片文件生成TFRecords
def get_img_data(sub_dirs,writer,INPUT_DATA,sess):
    current_label = 0
    is_root_dir = True
    print("文件地址: "+INPUT_DATA)
    for sub_dir in sub_dirs:
        if is_root_dir:
            is_root_dir = False
            continue
        file_list = []
        dir_name = os.path.basename(sub_dir)

        file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + "png")
        # extend合并兩個(gè)數(shù)組
        # glob模塊的主要方法就是glob,該方法返回所有匹配的文件路徑列表(list)
        # 比如:glob.glob(r’c:*.txt’) 這里就是獲得C盤下的所有txt文件
        file_list.extend(glob.glob(file_glob))
        if not file_list: continue
        # print('file_list',current_label)
        # 處理圖片數(shù)據(jù)
        index = 0
        for file_name in file_list:
            # 讀取并解析圖片 講圖片轉(zhuǎn)化成299*299方便模型處理
            image_raw_data = gfile.FastGFile(file_name, 'rb').read()
            image = tf.image.decode_png(image_raw_data)
            if image.dtype != tf.float32:
                image = tf.image.convert_image_dtype(image, dtype=tf.float32)
            image = tf.image.resize_images(image, [299, 299])
            image_value = sess.run(image)
            pixels = image_value.shape[1]
            image_raw = image_value.tostring()
            # 存到features
            example = tf.train.Example(features=tf.train.Features(feature={
                'pixels': _int64_feature(pixels),
                'label': _int64_feature(current_label),
                'image_raw': _bytes_feature(image_raw)
            }))
            chance = np.random.randint(100)
            # 寫入訓(xùn)練集
            writer.write(example.SerializeToString())
            index = index + 1
            if index == 400:
                break
            print("處理文件索引%d index%d"%(current_label,index))
        current_label += 1
#讀取數(shù)據(jù)并將數(shù)據(jù)分割成訓(xùn)練數(shù)據(jù)、驗(yàn)證數(shù)據(jù)和測(cè)試數(shù)據(jù)
def create_image_lists(sess):

    #首先處理訓(xùn)練數(shù)據(jù)集
    sub_dirs = [x[0] for x in os.walk(INPUT_TRAIN_DATA)]
    writer_train = tf.python_io.TFRecordWriter(OUTPUT_TRAIN_FILE)
    get_img_data(sub_dirs,writer_train,INPUT_TRAIN_DATA,sess)

    sub_test_dirs = [x[0] for x in os.walk(INPUT_TEST_DATA)]
    writer_test = tf.python_io.TFRecordWriter(OUTPUT_TEST_FILE)
    get_img_data(sub_test_dirs,writer_test,INPUT_TEST_DATA,sess)

    writer_train.close()
    writer_test.close()

def main():
    with tf.Session() as sess:
        create_image_lists(sess)
        print('success')

if __name__ == '__main__':
    main()

這里因?yàn)閠est和train已經(jīng)在文件夾上作了區(qū)分昧碉,所以這里我利用兩個(gè)TFRecordWriter來把數(shù)據(jù)分別寫入兩個(gè)TFRecord英染。為了節(jié)省時(shí)間在這里我并沒有利用全部的訓(xùn)練數(shù)據(jù),只是加載了其中的400份被饿。當(dāng)然在真實(shí)的訓(xùn)練場(chǎng)景下你是需要加載全部的數(shù)據(jù)的四康。
代碼沒有詳盡的注釋,因?yàn)楹椭暗奶幚泶蟛糠侄际且粯拥南廖眨磺宄目梢匀タ次抑暗奈恼隆?a href="http://www.reibang.com/p/fc77879d3591" target="_blank">inception-v3闪金。

自定義Estimator

下面我們開始步入主題。先看一張Estimator類組成圖。

Estimator

以下源自官方文檔的一段話:
Pre-made Estimators are fully baked. Sometimes though, you need more control over an Estimator's behavior. That's where custom Estimators come in. You can create a custom Estimator to do just about anything. If you want hidden layers connected in some unusual fashion, write a custom Estimator. If you want to calculate a unique metric for your model, write a custom Estimator. Basically, if you want an Estimator optimized for your specific problem, write a custom Estimator.

A model function (or model_fn) implements the ML algorithm. The only difference between working with pre-made Estimators and custom Estimators is:

  • With pre-made Estimators, someone already wrote the model function for you.
  • With custom Estimators, you must write the model function.

Your model function could implement a wide range of algorithms, defining all sorts of hidden layers and metrics. Like input functions, all model functions must accept a standard group of input parameters and return a standard group of output values. Just as input functions can leverage the Dataset API, model functions can leverage the Layers API and the Metrics API.

大概意思是:預(yù)創(chuàng)建的 Estimator 是 tf.estimator.Estimator 基類的子類哎垦,而自定義 Estimator 是 tf.estimator.Estimator 的實(shí)例囱嫩。
Pre-made Estimators和custom Estimators差異主要在于tensorflow中是否有它們可以直接使用的模型函數(shù)(model function or model_fn)的實(shí)現(xiàn)。對(duì)于前者漏设,tensorflow中已經(jīng)有寫好的model function墨闲,因而直接調(diào)用即可;而后者的model function需要自己編寫郑口。因此鸳碧,Pre-made Estimators使用方便,但使用范圍小犬性,靈活性差杆兵;custom Estimators則正好相反。

總體來說仔夺,模型是由三部分構(gòu)成:Input functions琐脏、Model functions 和Estimators(評(píng)估控制器,main function)缸兔。

  • Input functions:主要是由Dataset API組成日裙,可以分為train_input_fn和eval_input_fn。前者的任務(wù)(行為)是接受參數(shù)惰蜜,輸出數(shù)據(jù)訓(xùn)練數(shù)據(jù)昂拂,后者的任務(wù)(行為)是接受參數(shù),并輸出驗(yàn)證數(shù)據(jù)和測(cè)試數(shù)據(jù)抛猖。
  • Model functions:是由模型(the Layers API )和監(jiān)控模塊( the Metrics API)組成格侯,主要是實(shí)現(xiàn)模型的訓(xùn)練、測(cè)試(驗(yàn)證)和監(jiān)控顯示模型參數(shù)狀況的功能财著。
  • Estimators:在模型中的作用類似于計(jì)算機(jī)中的操作系統(tǒng)联四。它將各個(gè)部分“粘合”起來,控制數(shù)據(jù)在模型中的流動(dòng)與變換撑教,同時(shí)控制模型的的各種行為(運(yùn)算)朝墩。

在得知以上知識(shí)以后,我們可以開始動(dòng)手編碼起來伟姐。通過以上內(nèi)容得知收苏,首先我們需要先創(chuàng)建自定義的Model functions。下面新建my_estimator文件愤兵。
由于我們這里是實(shí)現(xiàn)自定義的model_fn函數(shù)鹿霸,而model_fn主要功能是定義模型的結(jié)構(gòu),損失函數(shù)以及優(yōu)化器秆乳。還會(huì)對(duì)預(yù)測(cè)和評(píng)測(cè)進(jìn)行處理懦鼠。綜上我們來完成model_fn的編寫。

自定義model_fn

#導(dǎo)入相關(guān)庫
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
# 加載通過TensorFlow-Silm定義好的 inception_v3模型
import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3

#圖片數(shù)據(jù)地址
TRAIN_DATA = './path/to/output_train.tfrecords'
TEST_DATA = './path/to/output_test.tfrecords'

shuffle_buffer = 10000
BATCH = 64
#打開 estimator 日志
tf.logging.set_verbosity(tf.logging.INFO)

#自定義模型
#這里我們提供了兩種方案。一種是直接通過slim工具定義已有模型
#另一種是通過tf.layer更加靈活地定義神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)
def inception_v3_model(image,is_training):
    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        predictions,_ = inception_v3.inception_v3(image,num_classes=5)
        return predictions
#定義lenet5模型
def lenet5(x,is_training):
    net = tf.layers.conv2d(x,32,5,activation=tf.nn.relu)
    net = tf.layers.max_pooling2d(net,2,2)
    net = tf.layers.conv2d(net,64,3,activation=tf.nn.relu)
    net = tf.layers.max_pooling2d(net,2,2)
    net = tf.contrib.layers.flatten(net)
    net = tf.layers.dense(net,1024)
    net = tf.layers.dropout(net,rate=0.4,training=is_training)
    return tf.layers.dense(net,5)
#自定義Estimator中使用的模型葛闷。定義的函數(shù)有4個(gè)收入憋槐,
#features給出在輸入函數(shù)中會(huì)提供的輸入層張量双藕。這是個(gè)字典
#字典通過input_fn提供淑趾。如果是系統(tǒng)的輸入
#系統(tǒng)會(huì)提供tf.estimator.inputs.numpy_input_fn中的x參數(shù)指定內(nèi)容
#labels是正確答案,通過numpy_input_fn的y參數(shù)給出
#在這里我們用dataset來自定義輸入函數(shù)忧陪。
#mode取值有3種可能扣泊,分別對(duì)應(yīng)Estimator的train,evaluate,predict這三個(gè)函數(shù)
#mode參數(shù)可以判斷當(dāng)前是訓(xùn)練,預(yù)測(cè)還是驗(yàn)證模式嘶摊。
#最有一個(gè)參數(shù)param也是字典延蟹,里面是有關(guān)于這個(gè)模型的相關(guān)任何超參數(shù)(學(xué)習(xí)率)
def model_fn(features,labels,mode,params):
    predict = lenet5(features,mode == tf.estimator.ModeKeys.TRAIN)
    #如果是預(yù)測(cè)模式,直接返回結(jié)果
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={"result":tf.argmax(predict,1)}
        )
  #定義損失函數(shù)叶堆,這里使用tf.losses可以直接從tf.losses.get_total_loss()拿到損失
    tf.losses.softmax_cross_entropy(tf.one_hot(labels, 5), predict, weights=1.0)

    #優(yōu)化器
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=params["learning_rate"])
    #定義訓(xùn)練過程阱飘。傳入global_step的目的,為了在TensorBoard中顯示圖像的橫坐標(biāo)
    train_op = optimizer.minimize(
        loss=tf.losses.get_total_loss(),
        global_step=tf.train.get_global_step()
    )

    #定義評(píng)測(cè)標(biāo)準(zhǔn)
    #這個(gè)函數(shù)會(huì)在調(diào)用Estimator.evaluate的時(shí)候調(diào)用
    accuracy = tf.metrics.accuracy(
            predictions=tf.argmax(predict,1),
            labels=labels,
            name="acc_op"
    )
    eval_metric_ops = {
        "my_metric":accuracy
    }
    #用于向TensorBoard輸出準(zhǔn)確率圖像
    #如果你不需要使用TensorBoard可以不添加這行代碼
    tf.summary.scalar('accuracy', accuracy[1])
    #model_fn會(huì)返回一個(gè)EstimatorSpec
    #EstimatorSpec必須包含模型損失虱颗,訓(xùn)練函數(shù)沥匈。其它為可選項(xiàng)
    #eval_metric_ops用于定義調(diào)用Estimator.evaluate()時(shí)候所指定的函數(shù)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=tf.losses.get_total_loss(),
        train_op=train_op,
        eval_metric_ops=eval_metric_ops
    )

自定義Input functions

定義完了model functions接下來我們通過Dataset API來定義input functions:

#解析tfrecords
def parse(record):
    features = tf.parse_single_example(
        record,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
            'pixels': tf.FixedLenFeature([], tf.int64)
        }
    )
    decoded_image = tf.decode_raw(features['image_raw'], tf.float16)
    label = features['label']
    return decoded_image, label
#從dataset中讀取訓(xùn)練數(shù)據(jù),這里和之前處理花朵分類的時(shí)候一樣
def my_input_fn(file):
    dataset = tf.data.TFRecordDataset([file])
    dataset = dataset.map(parse)
    dataset = dataset.shuffle(shuffle_buffer).batch(BATCH)
    dataset = dataset.repeat(10)
    iterator = dataset.make_one_shot_iterator()
    batch_img,batch_labels = iterator.get_next()
    with tf.Session() as sess:
        batch_sess_img,batch_sess_labels = sess.run([batch_img,batch_labels])
        #這里需要特別注意 由于batch_sess_img這里是轉(zhuǎn)成了string后在原有長(zhǎng)度上增加了8倍
        #所以在這里我們要先轉(zhuǎn)成numpy然后再reshape要不然會(huì)報(bào)錯(cuò)
        batch_sess_img = np.fromstring(batch_sess_img, dtype=np.float32)
        #numpy轉(zhuǎn)換成Tensor
        batch_sess_img = tf.reshape(batch_sess_img, [BATCH, 299, 299, 3])
    return batch_sess_img,batch_sess_labels

在這里要注意忘渔,Estimator輸入函數(shù)要求每次被調(diào)用可以得到一個(gè)batch的數(shù)據(jù)高帖,包括所有的輸入層數(shù)據(jù)和正確答案標(biāo)注。而且my_input_fn函數(shù)并不能帶有參數(shù)畦粮。稍后我們會(huì)用lambda表達(dá)式解決這個(gè)問題散址。

最后我們通過main函數(shù)來啟動(dòng)訓(xùn)練過程:

def main():
    #定義超參數(shù)
    model_params = {"learning_rate":0.001}
    #定義訓(xùn)練的相關(guān)配置參數(shù)
    #keep_checkpoint_max=1表示在只在目錄下保存一份模型文件
    #log_step_count_steps=50表示每訓(xùn)練50次輸出一次損失的值
    run_config = tf.estimator.RunConfig(keep_checkpoint_max=1,log_step_count_steps=50)
    #通過tf.estimator.Estimator來生成自定義模型
    #把我們自定義的model_fn和超參數(shù)傳進(jìn)去
    #這里我們還傳入了持久化模型的目錄
    #estimator會(huì)自動(dòng)幫我們把模型持久化到這個(gè)目錄下
    estimator = tf.estimator.Estimator(model_fn=model_fn,params=model_params,model_dir="./path/model",config=run_config)
    #開始訓(xùn)練模型,這里說一下lambda表達(dá)式
    #lambda表達(dá)式會(huì)把函數(shù)原本的輸入?yún)?shù)變成0個(gè)或它指定的參數(shù)宣赔≡铮可以理解為函數(shù)的默認(rèn)值
    #這里傳入自定義輸入函數(shù),和訓(xùn)練的輪數(shù)
    estimator.train(input_fn=lambda :my_input_fn(TRAIN_DATA),steps=300)
    #訓(xùn)練完后進(jìn)行驗(yàn)證儒将,這里傳入我們的測(cè)試數(shù)據(jù)
    test_result = estimator.evaluate(input_fn=lambda :my_input_fn(TEST_DATA))
    #輸出測(cè)試驗(yàn)證結(jié)果
    accuracy_score = test_result["my_metric"]
    print("\nTest accuracy:%g %%"%(accuracy_score*100))

if __name__ == '__main__':
    main()

運(yùn)行程序师崎,可以看到如下輸出。因?yàn)槲疫@里是從367步以后繼續(xù)訓(xùn)練椅棺,所以我們?cè)谌罩局锌吹轿疫@里是直接加載了第367步保存的模型犁罩。
每隔一定時(shí)間,Estimator會(huì)自動(dòng)創(chuàng)建模型文件两疚。另外如果訓(xùn)練中斷床估,下一次再啟動(dòng)訓(xùn)練的話,Estimator會(huì)自動(dòng)從模型目錄下加載最新的模型并且用于訓(xùn)練诱渤,非常方便丐巫。這就是為什么谷歌推薦我們用Estimator來訓(xùn)練模型,因?yàn)樗庋b了很多開發(fā)者并不需要關(guān)心的操作,大大提升了我們的開發(fā)效率递胧。

INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from ./path/model/model.ckpt-367
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 368 into ./path/model/model.ckpt.
INFO:tensorflow:loss = 0.2994086, step = 368
INFO:tensorflow:global_step/sec: 0.116191
INFO:tensorflow:loss = 0.2086069, step = 418 (430.326 sec)
INFO:tensorflow:Saving checkpoints for 438 into ./path/model/model.ckpt.
INFO:tensorflow:global_step/sec: 0.115405
INFO:tensorflow:loss = 0.17857286, step = 468 (433.259 sec)
INFO:tensorflow:Saving checkpoints for 506 into ./path/model/model.ckpt.
INFO:tensorflow:global_step/sec: 0.111342
INFO:tensorflow:loss = 0.107850984, step = 518 (449.065 sec)
INFO:tensorflow:global_step/sec: 0.115999
INFO:tensorflow:loss = 0.08592671, step = 568 (431.040 sec)
INFO:tensorflow:Saving checkpoints for 575 into ./path/model/model.ckpt.
INFO:tensorflow:global_step/sec: 0.112465
INFO:tensorflow:loss = 0.05861471, step = 618 (444.587 sec)
INFO:tensorflow:Saving checkpoints for 643 into ./path/model/model.ckpt.

TensorBoard

為了更加直觀的看到訓(xùn)練過程碑韵,接下來我們將使用谷歌提供的一個(gè)工具TensorBoard來可視化我們的訓(xùn)練過程。
要啟動(dòng)TensorBoard缎脾,執(zhí)行下面的命令:

#PATH替換為你模型保存的目錄祝闻。要注意在這里用的是絕對(duì)路徑。
tensorboard --logdir=PATH

執(zhí)行命令后可以看到如下信息遗菠,說明TensorBoard已經(jīng)跑起來了联喘。

TensorBoard 1.8.0 at http://bogon:6006 (Press CTRL+C to quit)
W0817 16:14:27.129659 Reloader tf_logging.py:121] Found more than one graph event per run, or there was a metagraph containing a graph_def, as well as one or more graph events.  Overwriting the graph with the newest event.
W0817 16:14:27.650306 Reloader tf_lo

所有預(yù)創(chuàng)建的 Estimator 都會(huì)自動(dòng)將大量信息記錄到 TensorBoard 上。不過辙纬,對(duì)于自定義 Estimator豁遭,TensorBoard 只提供一個(gè)默認(rèn)日志(損失圖)以及您明確告知 TensorBoard 要記錄的信息。對(duì)于我們剛剛創(chuàng)建的自定義 Estimator贺拣,并且明確說明要繪制正確率的圖蓖谢,所以TensorBoard 會(huì)生成以下內(nèi)容:


TensorBoard.png

TensorBoard生成了三個(gè)圖。分別表示正確率譬涡,訓(xùn)練處理的批次闪幽,訓(xùn)練輪數(shù)所對(duì)應(yīng)的損失值

簡(jiǎn)而言之,下面是三張圖顯示的內(nèi)容:

  • global_step/sec:這是一個(gè)性能指標(biāo)昂儒,顯示我們?cè)谶M(jìn)行模型訓(xùn)練時(shí)每秒處理的批次數(shù)(梯度更新)沟使。
  • loss:所報(bào)告的損失。
  • accuracy:準(zhǔn)確率由下列兩行記錄:
    • eval_metric_ops={'my_accuracy': accuracy}(評(píng)估期間)渊跋。
    • tf.summary.scalar('accuracy', accuracy[1])(訓(xùn)練期間)腊嗡。
      這些 Tensorboard 圖是務(wù)必要將 global_step 傳遞給優(yōu)化器的 minimize 方法的主要原因之一。如果沒有它拾酝,模型就無法記錄這些圖的 x 坐標(biāo)燕少。

我們來看下TensorBoard的輸出≥锒冢可以看到隨著訓(xùn)練步驟的增加客们,loss在相應(yīng)的減少,accuracy也在慢慢增加材诽。這是一個(gè)健康的訓(xùn)練過程底挫。可以看到LeNet5在這個(gè)數(shù)據(jù)集上的正確率達(dá)到了95%左右脸侥。

eval

因?yàn)槲易远x的Estimator在訓(xùn)練結(jié)束之后并沒有輸出正確率(暫時(shí)沒找到原因)建邓,所以這里我們另外寫一個(gè)程序來測(cè)試這個(gè)模型的正確率。這里我們命名為eval.py睁枕。

import tensorflow as tf
import Estimator1
import numpy as np

TEST_DATA = './path/to/output_test.tfrecords'
CKPT_PATH = './path/model'
EVAL_BATCH = 20
def getValidationData():
   dataset = tf.data.TFRecordDataset([TEST_DATA])
   dataset = dataset.map(Estimator1.parse)
   dataset = dataset.batch(EVAL_BATCH)
   iterator = dataset.make_one_shot_iterator()
   batch_img, batch_labels = iterator.get_next()

   # batch_img作處理
   return batch_img, batch_labels
def my_eval():
   #estimator的eval方法不好使 用傳統(tǒng)方法試試
   batch_img,batch_labels = getValidationData()

   x = tf.placeholder(tf.float32, [None, 299,299,3], name='x-input')
   y_ = tf.placeholder(tf.int64, [None], name='y-input')
   y = Estimator1.lenet5(x, False)
   correct_prediction = tf.equal(tf.argmax(y, 1), y_)
   accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
   saver = tf.train.Saver()
   with tf.Session() as sess:
       while True:
           try:
               ckpt = tf.train.get_checkpoint_state(CKPT_PATH)
               if ckpt and ckpt.model_checkpoint_path:
                   saver.restore(sess,ckpt.model_checkpoint_path)
                   #通過文件名得到模型保存時(shí)迭代的輪數(shù)
                   global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                   batch_sess_img, batch_sess_labels = sess.run([batch_img, batch_labels])
                   batch_sess_img = np.fromstring(batch_sess_img, dtype=np.float32)
                   batch_sess_img = tf.reshape(batch_sess_img, [EVAL_BATCH, 299, 299, 3])
                   batch_sess_img = sess.run(batch_sess_img)
                   print(sess.run([tf.argmax(y,1),y_],feed_dict={x:batch_sess_img,y_:batch_sess_labels}))
                   accuracy_score = sess.run(accuracy,feed_dict={x:batch_sess_img,y_:batch_sess_labels})
                   print("After %s training step(s),validation accuracy = %g"%(global_step,accuracy_score))
               else:
                   print('No checkpoint file found')
                   return
           except tf.errors.OutOfRangeError:
               break
def main():
   my_eval()

if __name__ == '__main__':
   main()

這個(gè)程序大概的作用是:
1.讀取測(cè)試數(shù)據(jù)官边,把測(cè)試數(shù)據(jù)打包成batch沸手。然后定義神經(jīng)網(wǎng)絡(luò)輸入變量x和正確答案的標(biāo)簽y_。
2.把x通過神經(jīng)網(wǎng)絡(luò)得到的前向傳播結(jié)果y和y_作比較來計(jì)算正確率注簿。
3.讀取之前訓(xùn)練好的模型契吉。
4.用一個(gè)while循環(huán)來輸出在訓(xùn)練好的模型上每一個(gè)batch的正確率,直到數(shù)據(jù)讀取完畢诡渴。
運(yùn)行這個(gè)程序可以得到以下輸出:

INFO:tensorflow:Restoring parameters from ./path/model/model.ckpt-643
[array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])]
After 643 training step(s),validation accuracy = 1
INFO:tensorflow:Restoring parameters from ./path/model/model.ckpt-643
[array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2]), array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])]
After 643 training step(s),validation accuracy = 1
INFO:tensorflow:Restoring parameters from ./path/model/model.ckpt-643
[array([2, 2, 2, 2, 2, 2, 2, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]), array([2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])]
After 643 training step(s),validation accuracy = 0.95
INFO:tensorflow:Restoring parameters from ./path/model/model.ckpt-643

嗯捐晶。60個(gè)數(shù)據(jù)中只有1個(gè)判斷錯(cuò)誤,也符合我們之前得到的正確率玩徊。

寫在最后

Estimator是TensorFlow官方強(qiáng)烈推薦的API租悄,通過上述程序大家也能看到相比傳統(tǒng)的TensorFlow API谨究,Estimator封裝了大部分與業(yè)務(wù)邏輯無關(guān)的操作恩袱,然而通過Custom Estimator,Estimator也不失靈活性胶哲。

我們之前還通過slim定義了一個(gè)inception-v3模型畔塔,但是由于inception-v3結(jié)構(gòu)比較復(fù)雜,訓(xùn)練的時(shí)間比較久所以這里我們就以LeNet-5作演示了鸯屿。但是在復(fù)雜的圖像分類問題上澈吨,比如ImageNet數(shù)據(jù)集中,LeNet-5的分類效果就不是很好寄摆。如果是復(fù)雜的圖像分類問題谅辣,就要選擇更加復(fù)雜的神經(jīng)網(wǎng)絡(luò)模型來訓(xùn)練才能達(dá)到較高的準(zhǔn)確率。

另外這篇文章主要是以使用Estimator為主婶恼,對(duì)于其中的一些細(xì)節(jié)沒有很好的闡述桑阶。之后的文章會(huì)對(duì)一些技術(shù)細(xì)節(jié)做探究。

歡迎廣大喜歡AI的開發(fā)者互相交流勾邦,有問題也可以在評(píng)論區(qū)里留言蚣录,大家互相討論,一起進(jìn)步眷篇。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末萎河,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子蕉饼,更是在濱河造成了極大的恐慌虐杯,老刑警劉巖,帶你破解...
    沈念sama閱讀 211,042評(píng)論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件昧港,死亡現(xiàn)場(chǎng)離奇詭異擎椰,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)慨飘,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 89,996評(píng)論 2 384
  • 文/潘曉璐 我一進(jìn)店門确憨,熙熙樓的掌柜王于貴愁眉苦臉地迎上來译荞,“玉大人,你說我怎么就攤上這事休弃⊥碳撸” “怎么了?”我有些...
    開封第一講書人閱讀 156,674評(píng)論 0 345
  • 文/不壞的土叔 我叫張陵塔猾,是天一觀的道長(zhǎng)篙骡。 經(jīng)常有香客問我,道長(zhǎng)丈甸,這世上最難降的妖魔是什么糯俗? 我笑而不...
    開封第一講書人閱讀 56,340評(píng)論 1 283
  • 正文 為了忘掉前任,我火速辦了婚禮睦擂,結(jié)果婚禮上得湘,老公的妹妹穿的比我還像新娘。我一直安慰自己顿仇,他們只是感情好淘正,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,404評(píng)論 5 384
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著臼闻,像睡著了一般鸿吆。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上述呐,一...
    開封第一講書人閱讀 49,749評(píng)論 1 289
  • 那天惩淳,我揣著相機(jī)與錄音,去河邊找鬼乓搬。 笑死思犁,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的缤谎。 我是一名探鬼主播抒倚,決...
    沈念sama閱讀 38,902評(píng)論 3 405
  • 文/蒼蘭香墨 我猛地睜開眼,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼坷澡!你這毒婦竟也來了托呕?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 37,662評(píng)論 0 266
  • 序言:老撾萬榮一對(duì)情侶失蹤频敛,失蹤者是張志新(化名)和其女友劉穎项郊,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體斟赚,經(jīng)...
    沈念sama閱讀 44,110評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡着降,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,451評(píng)論 2 325
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了拗军。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片任洞。...
    茶點(diǎn)故事閱讀 38,577評(píng)論 1 340
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡蓄喇,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出交掏,到底是詐尸還是另有隱情妆偏,我是刑警寧澤,帶...
    沈念sama閱讀 34,258評(píng)論 4 328
  • 正文 年R本政府宣布盅弛,位于F島的核電站钱骂,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏挪鹏。R本人自食惡果不足惜见秽,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,848評(píng)論 3 312
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望讨盒。 院中可真熱鬧解取,春花似錦、人聲如沸催植。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,726評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽创南。三九已至,卻和暖如春省核,著一層夾襖步出監(jiān)牢的瞬間稿辙,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 31,952評(píng)論 1 264
  • 我被黑心中介騙來泰國(guó)打工气忠, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留邻储,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 46,271評(píng)論 2 360
  • 正文 我出身青樓旧噪,卻偏偏與公主長(zhǎng)得像吨娜,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子淘钟,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,452評(píng)論 2 348

推薦閱讀更多精彩內(nèi)容