卷積神經(jīng)網(wǎng)絡(luò)之VGG(附完整代碼)

前言

VGG是Oxford的Visual Geometry Group的組提出的(大家應(yīng)該能看出VGG名字的由來(lái)了)。該網(wǎng)絡(luò)是在ILSVRC 2014上的相關(guān)工作,主要工作是證明了增加網(wǎng)絡(luò)的深度能夠在一定程度上影響網(wǎng)絡(luò)最終的性能概疆。VGG有兩種結(jié)構(gòu)竭鞍,分別是VGG16和VGG19人芽,兩者并沒(méi)有本質(zhì)上的區(qū)別,只是網(wǎng)絡(luò)深度不一樣邓夕。

VGG原理

VGG16相比AlexNet的一個(gè)改進(jìn)是采用連續(xù)的幾個(gè)3x3的卷積核代替AlexNet中的較大卷積核(11x11,7x7阎毅,5x5)焚刚。對(duì)于給定的感受野(與輸出有關(guān)的輸入圖片的局部大小)扇调,采用堆積的小卷積核是優(yōu)于采用大的卷積核矿咕,因?yàn)槎鄬臃蔷€性層可以增加網(wǎng)絡(luò)深度來(lái)保證學(xué)習(xí)更復(fù)雜的模式,而且代價(jià)還比較欣桥ァ(參數(shù)更少)碳柱。

簡(jiǎn)單來(lái)說(shuō),在VGG中熬芜,使用了3個(gè)3x3卷積核來(lái)代替7x7卷積核莲镣,使用了2個(gè)3x3卷積核來(lái)代替5*5卷積核,這樣做的主要目的是在保證具有相同感知野的條件下涎拉,提升了網(wǎng)絡(luò)的深度瑞侮,在一定程度上提升了神經(jīng)網(wǎng)絡(luò)的效果的圆。

比如,3個(gè)步長(zhǎng)為1的3x3卷積核的一層層疊加作用可看成一個(gè)大小為7的感受野(其實(shí)就表示3個(gè)3x3連續(xù)卷積相當(dāng)于一個(gè)7x7卷積)半火,其參數(shù)總量為 3x(9xC^2) 越妈,如果直接使用7x7卷積核,其參數(shù)總量為 49xC^2 钮糖,這里 C 指的是輸入和輸出的通道數(shù)梅掠。很明顯,27xC2小于49xC2藐鹤,即減少了參數(shù)瓤檐;而且3x3卷積核有利于更好地保持圖像性質(zhì)。

這里解釋一下為什么使用2個(gè)3x3卷積核可以來(lái)代替5*5卷積核:

5x5卷積看做一個(gè)小的全連接網(wǎng)絡(luò)在5x5區(qū)域滑動(dòng)娱节,我們可以先用一個(gè)3x3的卷積濾波器卷積挠蛉,然后再用一個(gè)全連接層連接這個(gè)3x3卷積輸出,這個(gè)全連接層我們也可以看做一個(gè)3x3卷積層肄满。這樣我們就可以用兩個(gè)3x3卷積級(jí)聯(lián)(疊加)起來(lái)代替一個(gè) 5x5卷積谴古。

具體如下圖所示:

至于為什么使用3個(gè)3x3卷積核可以來(lái)代替7*7卷積核,推導(dǎo)過(guò)程與上述類(lèi)似稠歉,大家可以自行繪圖理解掰担。

VGG網(wǎng)絡(luò)結(jié)構(gòu)
下面是VGG網(wǎng)絡(luò)的結(jié)構(gòu)(VGG16和VGG19都在):

VGG16包含了16個(gè)隱藏層(13個(gè)卷積層和3個(gè)全連接層),如上圖中的D列所示
VGG19包含了19個(gè)隱藏層(16個(gè)卷積層和3個(gè)全連接層)怒炸,如上圖中的E列所示
VGG網(wǎng)絡(luò)的結(jié)構(gòu)非常一致带饱,從頭到尾全部使用的是3x3的卷積和2x2的max pooling。

如果你想看到更加形象化的VGG網(wǎng)絡(luò)阅羹,可以使用經(jīng)典卷積神經(jīng)網(wǎng)絡(luò)(CNN)結(jié)構(gòu)可視化工具來(lái)查看高清無(wú)碼的VGG網(wǎng)絡(luò)勺疼。

VGG優(yōu)缺點(diǎn)
VGG優(yōu)點(diǎn)
VGGNet的結(jié)構(gòu)非常簡(jiǎn)潔,整個(gè)網(wǎng)絡(luò)都使用了同樣大小的卷積核尺寸(3x3)和最大池化尺寸(2x2)捏鱼。

幾個(gè)小濾波器(3x3)卷積層的組合比一個(gè)大濾波器(5x5或7x7)卷積層好:

驗(yàn)證了通過(guò)不斷加深網(wǎng)絡(luò)結(jié)構(gòu)可以提升性能执庐。

VGG缺點(diǎn)
VGG耗費(fèi)更多計(jì)算資源,并且使用了更多的參數(shù)(這里不是3x3卷積的鍋)导梆,導(dǎo)致更多的內(nèi)存占用(140M)轨淌。其中絕大多數(shù)的參數(shù)都是來(lái)自于第一個(gè)全連接層。VGG可是有3個(gè)全連接層翱茨帷递鹉!

PS:有的文章稱:發(fā)現(xiàn)這些全連接層即使被去除,對(duì)于性能也沒(méi)有什么影響藏斩,這樣就顯著降低了參數(shù)數(shù)量梳虽。

注:很多pretrained的方法就是使用VGG的model(主要是16和19),VGG相對(duì)其他的方法灾茁,參數(shù)空間很大窜觉,最終的model有500多m谷炸,AlexNet只有200m,GoogLeNet更少禀挫,所以train一個(gè)vgg模型通常要花費(fèi)更長(zhǎng)的時(shí)間旬陡,所幸有公開(kāi)的pretrained model讓我們很方便的使用。

關(guān)于感受野:

假設(shè)你一層一層地重疊了3個(gè)3x3的卷積層(層與層之間有非線性激活函數(shù))语婴。在這個(gè)排列下描孟,第一個(gè)卷積層中的每個(gè)神經(jīng)元都對(duì)輸入數(shù)據(jù)體有一個(gè)3x3的視野。

代碼篇:VGG訓(xùn)練與測(cè)試
這里推薦兩個(gè)開(kāi)源庫(kù)砰左,訓(xùn)練請(qǐng)參考tensorflow-vgg匿醒,快速測(cè)試請(qǐng)參考VGG-in TensorFlow。

代碼我就不介紹了缠导,其實(shí)跟上述內(nèi)容一致廉羔,跟著原理看code應(yīng)該會(huì)很快。我快速跑了一下VGG-in TensorFlow僻造,代碼親測(cè)可用憋他,效果很nice,就是model下載比較煩髓削。

70108F1D-3E7A-4591-A036-2E4310E790FA.png
# -- encoding:utf-8 --
"""
Create on 19/5/25 10:06
"""

import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 定義外部傳入的參數(shù)
tf.app.flags.DEFINE_bool(flag_name="is_train",
                         default_value=True,
                         docstring="給定是否是訓(xùn)練操作竹挡,True表示訓(xùn)練,F(xiàn)alse表示預(yù)測(cè)A⑻拧揪罕!")
tf.app.flags.DEFINE_string(flag_name="checkpoint_dir",
                           default_value="./mnist/models/models_vgg",
                           docstring="給定模型存儲(chǔ)的文件夾,默認(rèn)為./mnist/models/models_vgg")
tf.app.flags.DEFINE_string(flag_name="logdir",
                           default_value="./mnist/graph/graph_vgg",
                           docstring="給定模型日志存儲(chǔ)的路徑宝泵,默認(rèn)為./mnist/graph/graph_vgg")
tf.app.flags.DEFINE_integer(flag_name="batch_size",
                            default_value=8,
                            docstring="給定訓(xùn)練的時(shí)候每個(gè)批次的樣本數(shù)目好啰,默認(rèn)為16.")
tf.app.flags.DEFINE_integer(flag_name="store_per_batch",
                            default_value=100,
                            docstring="給定每隔多少個(gè)批次進(jìn)行一次模型持久化的操作,默認(rèn)為100")
tf.app.flags.DEFINE_integer(flag_name="validation_per_batch",
                            default_value=100,
                            docstring="給定每隔多少個(gè)批次進(jìn)行一次模型的驗(yàn)證操作鲁猩,默認(rèn)為100")
tf.app.flags.DEFINE_float(flag_name="learning_rate",
                          default_value=0.001,
                          docstring="給定模型的學(xué)習(xí)率,默認(rèn)0.01")
FLAGS = tf.app.flags.FLAGS


def create_dir_with_not_exits(dir_path):
    """
    如果文件的文件夾路徑不存在罢坝,直接創(chuàng)建
    :param dir_path:
    :return:
    """
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)


def layer_normalization(net, eps=1e-8):
    # 縮放參數(shù)廓握、平移參數(shù)y=gamma * x + beta
    gamma = tf.get_variable('gamma', shape=[],
                            initializer=tf.constant_initializer(1))
    beta = tf.get_variable('beta', shape=[],
                           initializer=tf.constant_initializer(0))

    # 計(jì)算當(dāng)前批次的均值和標(biāo)準(zhǔn)差
    mean, variance = tf.nn.moments(net, axes=(1, 2, 3), keep_dims=True)

    # 執(zhí)行批歸一化操作
    return tf.nn.batch_normalization(net, mean, variance,
                                     offset=beta, scale=gamma, variance_epsilon=eps)


def create_model(input_x, show_image=False):
    """
    構(gòu)建模型(VGG 11)
    :param input_x: 占位符,格式為[None, 784]
    :param show_image:是否可視化圖像
    :return:
    """
    # 定義一個(gè)網(wǎng)絡(luò)結(jié)構(gòu):  conv3-64 -> LRN -> MaxPooling -> conv3-128 -> MaxPooling -> conv3-256 -> conv3-256 -> MaxPooling -> FC1024 -> FC10
    with tf.variable_scope("net",
                           initializer=tf.random_normal_initializer(0.0, 0.0001)):
        with tf.variable_scope("Input"):
            # 這里定義一些圖像的處理方式嘁酿,包括:格式轉(zhuǎn)換隙券、基礎(chǔ)處理(大小、剪切...)
            net = tf.reshape(input_x, shape=[-1, 28, 28, 1])
            print(net.get_shape())

            if show_image:
                # 可視化圖像
                tf.summary.image(name='image', tensor=net, max_outputs=5)

        # 定義一個(gè)網(wǎng)絡(luò)結(jié)構(gòu)
        # layers = [
        #     ["conv", 3, 3, 1, 64, 1, "relu"],
        #     ["lrn"],
        #     ["max_pooling", 2, 2, 2],
        #     ["conv", 3, 3, 1, 128, 1, "relu"],
        #     ["max_pooling", 2, 2, 2],
        #     ["conv", 3, 3, 1, 256, 2, "relu"],
        #     ["max_pooling", 2, 2, 2],
        #     ["reshape"],
        #     ["FC", 1024, "relu"],
        #     ["FC", 10]
        # ]
        # layers = [
        #     ["conv", 3, 3, 1, 64, 1, "relu"],
        #     ["lrn"],
        #     ["max_pooling", 2, 2, 2],
        #     ["conv", 3, 3, 1, 128, 2, "relu"],
        #     ["ln"],
        #     ["max_pooling", 2, 2, 2],
        #     ["conv", 3, 3, 1, 256, 2, "relu"],
        #     ["ln"],
        #     ["max_pooling", 2, 2, 2],
        #     ["reshape"],
        #     ["FC", 1024, "relu"],
        #     ["FC", 10]
        # ]
        layers = [
            ["conv", 3, 3, 1, 32, 2, "relu"],
            ["max_pooling", 2, 2, 2],
            ["conv", 3, 3, 1, 64, 2, "relu"],
            # 第一個(gè)是池化闹司,第二個(gè)窗口高度娱仔,第三個(gè)是窗口的寬度,第四個(gè)是步長(zhǎng)
            ["max_pooling", 2, 2, 2],
            ["reshape"],
            ["FC", 1024, "relu"],
            ["FC", 10]
        ]
        for idx, layer in enumerate(layers):
            shape = net.get_shape()
            name = layer[0]
            if "conv" == name:
                # a. 獲取相關(guān)的參數(shù)
                # ["conv", 3, 3, 1, 64, 1, "relu" ] -> 名稱 窗口高度 窗口寬度 步長(zhǎng)(一個(gè)值) 輸出通道數(shù) 重復(fù)幾個(gè)卷積 激活函數(shù)(None表示不激活)
                filter_height, filter_width, stride, out_channels, num_conv = layer[1:6]
                try:
                    ac = layer[6]
                except:
                    ac = None

                # 遍歷進(jìn)行卷積層的構(gòu)建
                for i in range(num_conv):
                    with tf.variable_scope("CONV_{}_{}".format(idx, i)):
                        # 獲取當(dāng)前卷積的輸入的通道數(shù)
                        shape = net.get_shape()
                        in_channels = shape[-1]
                        # 構(gòu)建變量
                        filter = tf.get_variable(name='w', shape=[filter_height, filter_width,
                                                                  in_channels, out_channels])
                        bias = tf.get_variable(name='b', shape=[out_channels])
                        # 卷積操作
                        net = tf.nn.conv2d(input=net, filter=filter,
                                           strides=[1, stride, stride, 1], padding='SAME')
                        net = tf.nn.bias_add(net, bias)
                        # 做一個(gè)激活操作
                        if ac is not None:
                            if "relu" == ac:
                                net = tf.nn.relu(net)
                            elif "relu6" == ac:
                                net = tf.nn.relu6(net)
                            else:
                                net = tf.nn.sigmoid(net)

                if show_image:
                    # 對(duì)于卷積之后的值做一個(gè)可視化操作
                    shape = net.get_shape()
                    for k in range(shape[-1]):
                        image_tensor = tf.reshape(net[:, :, :, k], shape=[-1, shape[1], shape[2], 1])
                        tf.summary.image(name='image', tensor=image_tensor, max_outputs=5)
            elif "lrn" == name:
                with tf.variable_scope("LRN_{}".format(idx)):
                    # lrn(input, depth_radius=5, bias=1, alpha=1, beta=0.5, name=None)
                    # depth_radius就是ppt上的n游桩,bias就是ppt上的k牲迫,beta就是β耐朴,alpha就是α
                    net = tf.nn.local_response_normalization(input=net, depth_radius=5,
                                                             bias=1, alpha=1, beta=0.5)
            elif "max_pooling" == name:
                with tf.variable_scope("Max_Pooling_{}".format(idx)):
                    ksize_height = layer[1]
                    ksize_width = layer[2]
                    stride = layer[3]
                    net = tf.nn.max_pool(value=net,
                                         ksize=[1, ksize_height, ksize_width, 1],
                                         strides=[1, stride, stride, 1], padding='SAME')
            elif "FC" == name:
                with tf.variable_scope("FC_{}".format(idx)):
                    # 獲取相關(guān)變量,輸入的維度盹憎,輸出的維度大小以及激活函數(shù)
                    dim_size = shape[-1]
                    unit_size = layer[1]
                    try:
                        ac = layer[2]
                    except:
                        ac = None
                    w = tf.get_variable(name='w', shape=[dim_size, unit_size])
                    b = tf.get_variable(name='b', shape=[unit_size])
                    net = tf.matmul(net, w) + b
                    # 做一個(gè)激活操作
                    if ac is not None:
                        if "relu" == ac:
                            net = tf.nn.relu(net)
                        elif "relu6" == ac:
                            net = tf.nn.relu6(net)
                        else:
                            net = tf.nn.sigmoid(net)
            elif "reshape" == name:
                with tf.variable_scope('reshape'):
                    dim_size = shape[1] * shape[2] * shape[3]
                    net = tf.reshape(net, shape=[-1, dim_size])
            elif "ln" == name:
                with tf.variable_scope("LN_{}".format(idx)):
                    net = layer_normalization(net)

        with tf.variable_scope("Prediction"):
            # 每行的最大值對(duì)應(yīng)的下標(biāo)就是當(dāng)前樣本的預(yù)測(cè)值
            predictions = tf.argmax(net, axis=1)

    return net, predictions


def create_loss(labels, logits):
    """
    基于給定的實(shí)際值labels和預(yù)測(cè)值logits進(jìn)行一個(gè)交叉熵?fù)p失函數(shù)的構(gòu)建
    :param labels:  是經(jīng)過(guò)啞編碼之后的Tensor對(duì)象筛峭,形狀為[n_samples, n_class]
    :param logits:  是神經(jīng)網(wǎng)絡(luò)的最原始的輸出,形狀為[n_samples, n_class], 每一行最大值那個(gè)位置對(duì)應(yīng)的就是預(yù)測(cè)類(lèi)別陪每,沒(méi)有經(jīng)過(guò)softmax函數(shù)轉(zhuǎn)換影晓。
    :return:
    """
    with tf.name_scope("loss"):
        # loss = tf.reduce_mean(-tf.log(tf.reduce_sum(labels * tf.nn.softmax(logits))))
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))
        tf.summary.scalar('loss', loss)
    return loss


def create_train_op(loss, learning_rate=0.0001, global_step=None):
    """
    基于給定的損失函數(shù)構(gòu)建一個(gè)優(yōu)化器,優(yōu)化器的目的就是讓這個(gè)損失函數(shù)最小化
    :param loss:
    :param learning_rate:
    :param global_step:
    :return:
    """
    with tf.name_scope("train"):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        train_op = optimizer.minimize(loss, global_step=global_step)
    return train_op


def create_accuracy(labels, predictions):
    """
    基于給定的實(shí)際值和預(yù)測(cè)值檩禾,計(jì)算準(zhǔn)確率
    :param labels:  是經(jīng)過(guò)啞編碼之后的Tensor對(duì)象挂签,形狀為[n_samples, n_class]
    :param predictions: 實(shí)際的預(yù)測(cè)類(lèi)別下標(biāo),形狀為[n_samples,]
    :return:
    """
    with tf.name_scope("accuracy"):
        # 獲取實(shí)際的類(lèi)別下標(biāo)盼产,形狀為[n_samples,]
        y_labels = tf.argmax(labels, 1)
        # 計(jì)算準(zhǔn)確率
        accuracy = tf.reduce_mean(tf.cast(tf.equal(y_labels, predictions), tf.float32))
        tf.summary.scalar('accuracy', accuracy)
    return accuracy


def train():
    # 對(duì)于文件是否存在做一個(gè)檢測(cè)
    create_dir_with_not_exits(FLAGS.checkpoint_dir)
    create_dir_with_not_exits(FLAGS.logdir)

    with tf.Graph().as_default():
        # 一饵婆、執(zhí)行圖的構(gòu)建
        # 0. 相關(guān)輸入Tensor對(duì)象的構(gòu)建
        input_x = tf.placeholder(dtype=tf.float32, shape=[None, 784], name='input_x')
        input_y = tf.placeholder(dtype=tf.float32, shape=[None, 10], name='input_y')
        global_step = tf.train.get_or_create_global_step()

        # 1. 網(wǎng)絡(luò)結(jié)構(gòu)的構(gòu)建
        logits, predictions = create_model(input_x)
        # 2. 構(gòu)建損失函數(shù)
        loss = create_loss(input_y, logits)
        # 3. 構(gòu)建優(yōu)化器
        train_op = create_train_op(loss,
                                   learning_rate=FLAGS.learning_rate,
                                   global_step=global_step)
        # 4. 構(gòu)建評(píng)估指標(biāo)
        accuracy = create_accuracy(input_y, predictions)

        # 二、執(zhí)行圖的運(yùn)行/訓(xùn)練(數(shù)據(jù)加載辆飘、訓(xùn)練啦辐、持久化、可視化蜈项、模型的恢復(fù)....)
        with tf.Session() as sess:
            # a. 創(chuàng)建一個(gè)持久化對(duì)象(默認(rèn)會(huì)將所有的模型參數(shù)全部持久化芹关,因?yàn)椴皇撬械亩夹枰模詈脙H僅持久化的訓(xùn)練的模型參數(shù))
            var_list = tf.trainable_variables()
            # 是因?yàn)間lobal_step這個(gè)變量是不參與模型訓(xùn)練的紧卒,所以模型不會(huì)持久化侥衬,這里加入之后,可以明確也持久化這個(gè)變量跑芳。
            var_list.append(global_step)
            saver = tf.train.Saver(var_list=var_list)

            # a. 變量的初始化操作(所有的非訓(xùn)練變量的初始化 + 持久化的變量恢復(fù))
            # 所有變量初始化(如果有持久化的轴总,后面做了持久化后,會(huì)覆蓋的)
            sess.run(tf.global_variables_initializer())
            # 做模型的恢復(fù)操作
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                print("進(jìn)行模型恢復(fù)操作...")
                # 恢復(fù)模型
                saver.restore(sess, ckpt.model_checkpoint_path)
                # 恢復(fù)checkpoint的管理信息
                saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)

            # 獲取一個(gè)日志輸出對(duì)象
            train_logdir = os.path.join(FLAGS.logdir, 'train')
            validation_logdir = os.path.join(FLAGS.logdir, 'validation')
            train_writer = tf.summary.FileWriter(logdir=train_logdir, graph=sess.graph)
            validation_writer = tf.summary.FileWriter(logdir=validation_logdir, graph=sess.graph)
            # 獲取所有的summary輸出操作
            summary = tf.summary.merge_all()

            # b. 訓(xùn)練數(shù)據(jù)的產(chǎn)生/獲炔└觥(基于numpy隨機(jī)產(chǎn)生<可以先考慮一個(gè)固定的數(shù)據(jù)集>)
            mnist = input_data.read_data_sets(
                train_dir='../datas/mnist',  # 給定本地磁盤(pán)的數(shù)據(jù)存儲(chǔ)路徑
                one_hot=True,  # 給定返回的數(shù)據(jù)中是否對(duì)Y做啞編碼
                validation_size=5000  # 給定驗(yàn)證數(shù)據(jù)集的大小
            )

            # c. 模型訓(xùn)練
            batch_size = FLAGS.batch_size
            step = sess.run(global_step)
            vn_accuracy_ = 0
            while True:
                # 開(kāi)始模型訓(xùn)練
                x_train, y_train = mnist.train.next_batch(batch_size=batch_size)
                _, loss_, accuracy_, summary_ = sess.run([train_op, loss, accuracy, summary], feed_dict={
                    input_x: x_train,
                    input_y: y_train
                })
                print("第{}次訓(xùn)練后模型的損失函數(shù)為:{}, 準(zhǔn)確率:{}".format(step, loss_, accuracy_))
                train_writer.add_summary(summary_, global_step=step)

                # 持久化
                if step % FLAGS.store_per_batch == 0:
                    file_name = 'model_%.3f_%.3f_.ckpt' % (loss_, accuracy_)
                    save_path = os.path.join(FLAGS.checkpoint_dir, file_name)
                    saver.save(sess, save_path=save_path, global_step=step)

                if step % FLAGS.validation_per_batch == 0:
                    vn_loss_, vn_accuracy_, vn_summary_ = sess.run([loss, accuracy, summary],
                                                                   feed_dict={
                                                                       input_x: mnist.validation.images,
                                                                       input_y: mnist.validation.labels
                                                                   })
                    print("第{}次訓(xùn)練后模型在驗(yàn)證數(shù)據(jù)上的損失函數(shù)為:{}, 準(zhǔn)確率:{}".format(step,
                                                                    vn_loss_,
                                                                    vn_accuracy_))
                    validation_writer.add_summary(vn_summary_, global_step=step)

                # 退出訓(xùn)練(要求當(dāng)前的訓(xùn)練數(shù)據(jù)集上的準(zhǔn)確率至少為0.8怀樟,然后最近一次驗(yàn)證數(shù)據(jù)上的準(zhǔn)確率為0.8)
                if accuracy_ > 0.99 and vn_accuracy_ > 0.99:
                    # 退出之前再做一次持久化操作
                    file_name = 'model_%.3f_%.3f_.ckpt' % (loss_, accuracy_)
                    save_path = os.path.join(FLAGS.checkpoint_dir, file_name)
                    saver.save(sess, save_path=save_path, global_step=step)
                    break
                step += 1
            # 關(guān)閉輸出流
            train_writer.close()
            validation_writer.close()


def prediction():
    # TODO: 參考以前的代碼自己把這個(gè)區(qū)域的內(nèi)容填充一下。我下周晚上講盆佣。
    # 做一個(gè)預(yù)測(cè)(預(yù)測(cè)的評(píng)估往堡,對(duì)mnist.test這個(gè)里面的數(shù)據(jù)進(jìn)行評(píng)估效果的查看)
    with tf.Graph().as_default():
        pass


def main(_):
    if FLAGS.is_train:
        # 進(jìn)入訓(xùn)練的代碼執(zhí)行中
        print("開(kāi)始進(jìn)行模型訓(xùn)練運(yùn)行.....")
        train()
    else:
        # 進(jìn)入測(cè)試、預(yù)測(cè)的代碼執(zhí)行中
        print("開(kāi)始進(jìn)行模型驗(yàn)證共耍、測(cè)試代碼運(yùn)行.....")
        prediction()
    print("Done!!!!")


if __name__ == '__main__':
    # 默認(rèn)情況下虑灰,直接調(diào)用當(dāng)前py文件中的main函數(shù)
    tf.app.run()

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市痹兜,隨后出現(xiàn)的幾起案子穆咐,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 206,968評(píng)論 6 482
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件对湃,死亡現(xiàn)場(chǎng)離奇詭異崖叫,居然都是意外死亡,警方通過(guò)查閱死者的電腦和手機(jī)熟尉,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,601評(píng)論 2 382
  • 文/潘曉璐 我一進(jìn)店門(mén)归露,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái),“玉大人斤儿,你說(shuō)我怎么就攤上這事剧包。” “怎么了往果?”我有些...
    開(kāi)封第一講書(shū)人閱讀 153,220評(píng)論 0 344
  • 文/不壞的土叔 我叫張陵疆液,是天一觀的道長(zhǎng)。 經(jīng)常有香客問(wèn)我陕贮,道長(zhǎng)堕油,這世上最難降的妖魔是什么? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 55,416評(píng)論 1 279
  • 正文 為了忘掉前任肮之,我火速辦了婚禮掉缺,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘戈擒。我一直安慰自己眶明,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 64,425評(píng)論 5 374
  • 文/花漫 我一把揭開(kāi)白布筐高。 她就那樣靜靜地躺著搜囱,像睡著了一般。 火紅的嫁衣襯著肌膚如雪柑土。 梳的紋絲不亂的頭發(fā)上蜀肘,一...
    開(kāi)封第一講書(shū)人閱讀 49,144評(píng)論 1 285
  • 那天,我揣著相機(jī)與錄音稽屏,去河邊找鬼扮宠。 笑死,一個(gè)胖子當(dāng)著我的面吹牛狐榔,可吹牛的內(nèi)容都是我干的坛增。 我是一名探鬼主播,決...
    沈念sama閱讀 38,432評(píng)論 3 401
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼荒叼,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼轿偎!你這毒婦竟也來(lái)了典鸡?” 一聲冷哼從身側(cè)響起被廓,我...
    開(kāi)封第一講書(shū)人閱讀 37,088評(píng)論 0 261
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎萝玷,沒(méi)想到半個(gè)月后嫁乘,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體昆婿,經(jīng)...
    沈念sama閱讀 43,586評(píng)論 1 300
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,028評(píng)論 2 325
  • 正文 我和宋清朗相戀三年蜓斧,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了仓蛆。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 38,137評(píng)論 1 334
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡挎春,死狀恐怖看疙,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情直奋,我是刑警寧澤能庆,帶...
    沈念sama閱讀 33,783評(píng)論 4 324
  • 正文 年R本政府宣布,位于F島的核電站脚线,受9級(jí)特大地震影響搁胆,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜邮绿,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,343評(píng)論 3 307
  • 文/蒙蒙 一渠旁、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧船逮,春花似錦顾腊、人聲如沸。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 30,333評(píng)論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)。三九已至冠骄,卻和暖如春伪煤,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背凛辣。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 31,559評(píng)論 1 262
  • 我被黑心中介騙來(lái)泰國(guó)打工抱既, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人扁誓。 一個(gè)月前我還...
    沈念sama閱讀 45,595評(píng)論 2 355
  • 正文 我出身青樓防泵,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親蝗敢。 傳聞我的和親對(duì)象是個(gè)殘疾皇子捷泞,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 42,901評(píng)論 2 345

推薦閱讀更多精彩內(nèi)容