卷積神經(jīng)網(wǎng)絡(luò)之AlexNet(附完整代碼)

由于受到計算機性能的影響,雖然LeNet在圖像分類中取得了較好的成績潜慎,但是并沒有引起很多的關(guān)注抹剩。 知道2012年,Alex等人提出的AlexNet網(wǎng)絡(luò)在ImageNet大賽上以遠超第二名的成績奪冠绪励,卷積神經(jīng)網(wǎng)絡(luò)乃至深度學(xué)習(xí)重新引起了廣泛的關(guān)注肿孵。

AlexNet特點

AlexNet是在LeNet的基礎(chǔ)上加深了網(wǎng)絡(luò)的結(jié)構(gòu),學(xué)習(xí)更豐富更高維的圖像特征疏魏。AlexNet的特點:

1.更深的網(wǎng)絡(luò)結(jié)構(gòu)
2.使用層疊的卷積層停做,即卷積層+卷積層+池化層來提取圖像的特征
3.使用Dropout抑制過擬合
4.使用數(shù)據(jù)增強Data Augmentation抑制過擬合
5.使用Relu替換之前的sigmoid的作為激活函數(shù)
6.多GPU訓(xùn)練
7.ReLu作為激活函數(shù)

在最初的感知機模型中,輸入和輸出的關(guān)系如下:
??=∑??????????+??

只是單純的線性關(guān)系大莫,這樣的網(wǎng)絡(luò)結(jié)構(gòu)有很大的局限性:即使用很多這樣結(jié)構(gòu)的網(wǎng)絡(luò)層疊加蛉腌,其輸出和輸入仍然是線性關(guān)系,無法處理有非線性關(guān)系的輸入輸出只厘。因此烙丛,對每個神經(jīng)元的輸出做個非線性的轉(zhuǎn)換也就是,將上面就加權(quán)求和
∑??????????+?? 的結(jié)果輸入到一個非線性函數(shù)羔味,也就是激活函數(shù)中河咽。 這樣,由于激活函數(shù)的引入介评,多個網(wǎng)絡(luò)層的疊加就不再是單純的線性變換库北,而是具有更強的表現(xiàn)能力爬舰。


image

在最初, ?????????????? 和 ???????函數(shù)最常用的激活函數(shù)寒瓦。


image

在網(wǎng)絡(luò)層數(shù)較少時情屹,?????????????? 函數(shù)的特性能夠很好的滿足激活函數(shù)的作用:它把一個實數(shù)壓縮至0到1之間,當(dāng)輸入的數(shù)字非常大的時候杂腰,結(jié)果會接近1垃你;當(dāng)輸入非常大的負(fù)數(shù)時,則會得到接近0的結(jié)果喂很。這種特性惜颇,能夠很好的模擬神經(jīng)元在受刺激后,是否被激活向后傳遞信息(輸出為0少辣,幾乎不被激活凌摄;輸出為1,完全被激活)漓帅。??????????????一個很大的問題就是梯度飽和锨亏。 觀察??????????????函數(shù)的曲線,當(dāng)輸入的數(shù)字較大(或較忻Ω伞)時器予,其函數(shù)值趨于不變,其導(dǎo)數(shù)變的非常的小捐迫。這樣乾翔,在層數(shù)很多的的網(wǎng)絡(luò)結(jié)構(gòu)中,進行反向傳播時施戴,由于很多個很小的??????????????導(dǎo)數(shù)累成反浓,導(dǎo)致其結(jié)果趨于0,權(quán)值更新較慢暇韧。

  1. ReLu


    image

    針對 ?????????????? 梯度飽和導(dǎo)致訓(xùn)練收斂慢的問題勾习,在AlexNet中引入了ReLU。ReLU是一個分段線性函數(shù)懈玻,小于等于0則輸出為0巧婶;大于0的則恒等輸出。相比于 ?????????????? 涂乌,ReLU有以下有點:

    • 計算開銷下艺栈。 ?????????????? 的正向傳播有指數(shù)運算,倒數(shù)運算湾盒,而ReLu是線性輸出湿右;反向傳播中, ?????????????? 有指數(shù)運算罚勾,而ReLU有輸出的部分毅人,導(dǎo)數(shù)始終為1.
    • 梯度飽和問題
    • 稀疏性吭狡。Relu會使一部分神經(jīng)元的輸出為0,這樣就造成了網(wǎng)絡(luò)的稀疏性丈莺,并且減少了參數(shù)的相互依存關(guān)系划煮,緩解了過擬合問題的發(fā)生。

這里有個問題缔俄,前面提到弛秋,激活函數(shù)要用非線性的,是為了使網(wǎng)絡(luò)結(jié)構(gòu)有更強的表達的能力俐载。那這里使用ReLU本質(zhì)上卻是個線性的分段函數(shù)蟹略,是怎么進行非線性變換的。

這里把神經(jīng)網(wǎng)絡(luò)看著一個巨大的變換矩陣 ?? 遏佣,其輸入為所有訓(xùn)練樣本組成的矩陣 ?? 挖炬,輸出為矩陣 ?? 。
??=?????

這里的 ?? 是一個線性變換的話状婶,則所有的訓(xùn)練樣本 ?? 進行了線性變換輸出為 ?? 茅茂。

那么對于ReLU來說,由于其是分段的太抓,0的部分可以看著神經(jīng)元沒有激活,不同的神經(jīng)元激活或者不激活令杈,其神經(jīng)玩過組成的變換矩陣是不一樣的走敌。
設(shè)有兩個訓(xùn)練樣本??1,??2 其訓(xùn)練時神經(jīng)網(wǎng)絡(luò)組成的變換矩陣為??1,??2 由于??1變換對應(yīng)的神經(jīng)網(wǎng)絡(luò)中激活神經(jīng)元和??2是不一樣的,這樣??1,??2實際上是兩個不同的線性變換逗噩。也就是說掉丽,每個訓(xùn)練樣本使用的線性變換矩陣???? 是不一樣的,在整個訓(xùn)練樣本空間來說异雁,其經(jīng)歷的是非線性變換捶障。**

簡單來說,不同訓(xùn)練樣本中的同樣的特征纲刀,在經(jīng)過神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)時项炼,流經(jīng)的神經(jīng)元是不一樣的(激活函數(shù)值為0的神經(jīng)元不會被激活)。這樣示绊,最終的輸出實際上是輸入樣本的非線性變換锭部。

單個訓(xùn)練樣本是線性變換,但是每個訓(xùn)練樣本的線性變換是不一樣的面褐,這樣整個訓(xùn)練樣本集來說拌禾,就是非線性的變換。

數(shù)據(jù)增強

神經(jīng)網(wǎng)絡(luò)由于訓(xùn)練的參數(shù)多展哭,表能能力強湃窍,所以需要比較多的數(shù)據(jù)量闻蛀,不然很容易過擬合。當(dāng)訓(xùn)練數(shù)據(jù)有限時您市,可以通過一些變換從已有的訓(xùn)練數(shù)據(jù)集中生成一些新的數(shù)據(jù)觉痛,以快速地擴充訓(xùn)練數(shù)據(jù)。對于圖像數(shù)據(jù)集來說墨坚,可以對圖像進行一些形變操作:

  • 翻轉(zhuǎn)
  • 隨機裁剪
  • 平移秧饮,顏色光照的變換
  • ...

AlexNet中對數(shù)據(jù)做了以下操作:

  1. 隨機裁剪,對256×256的圖片進行隨機裁剪到227×227泽篮,然后進行水平翻轉(zhuǎn)盗尸。
  2. 測試的時候,對左上帽撑、右上泼各、左下、右下亏拉、中間分別做了5次裁剪扣蜻,然后翻轉(zhuǎn),共10個裁剪及塘,之后對結(jié)果求平均莽使。
  3. 對RGB空間做PCA(主成分分析)提佣,然后對主成分做一個(0, 0.1)的高斯擾動勾笆,也就是對顏色、光照作變換荆残,結(jié)果使錯誤率又下降了1%肋层。

層疊池化

在LeNet中池化是不重疊的亿笤,即池化的窗口的大小和步長是相等的,如下


image

在AlexNet中使用的池化(Pooling)卻是可重疊的栋猖,也就是說净薛,在池化的時候,每次移動的步長小于池化的窗口長度蒲拉。AlexNet池化的大小為3×3的正方形肃拜,每次池化移動步長為2,這樣就會出現(xiàn)重疊全陨。重疊池化可以避免過擬合爆班,這個策略貢獻了0.3%的Top-5錯誤率。與非重疊方案??=2辱姨,??=2相比柿菩,輸出的維度是相等的,并且能在一定程度上抑制過擬合雨涛。

局部相應(yīng)歸一化

ReLU具有讓人滿意的特性枢舶,它不需要通過輸入歸一化來防止飽和懦胞。如果至少一些訓(xùn)練樣本對ReLU產(chǎn)生了正輸入,那么那個神經(jīng)元上將發(fā)生學(xué)習(xí)凉泄。然而躏尉,我們?nèi)匀话l(fā)現(xiàn)接下來的局部響應(yīng)歸一化有助于泛化。??????,??表示神經(jīng)元激活后众,通過在(??,??)位置應(yīng)用核??胀糜,然后應(yīng)用ReLU非線性來計算,響應(yīng)歸一化激活??????,??通過下式給定:
??????,??=??????,??/(??+??∑??=??????(0,?????/2)??????(???1,??+??/2)(??????,??)2)??

其中蒂誉,??是卷積核的個數(shù)教藻,也就是生成的FeatureMap的個數(shù);??,??,??,??是超參數(shù)右锨,論文中使用的值是??=2,??=5,??=10?4,??=0.75

輸出??????,??和輸入??????,??的上標(biāo)表示的是當(dāng)前值所在的通道括堤,也即是疊加的方向是沿著通道進行。將要歸一化的值??????,??所在附近通道相同位置的值的平方累加起來∑??????(???1,??+??/2)??=??????(0,?????/2)(??????,??)2

Dropout
這個是比較常用的抑制過擬合的方法了绍移。
引入Dropout主要是為了防止過擬合悄窃。在神經(jīng)網(wǎng)絡(luò)中Dropout通過修改神經(jīng)網(wǎng)絡(luò)本身結(jié)構(gòu)來實現(xiàn),對于某一層的神經(jīng)元蹂窖,通過定義的概率將神經(jīng)元置為0轧抗,這個神經(jīng)元就不參與前向和后向傳播,就如同在網(wǎng)絡(luò)中被刪除了一樣瞬测,同時保持輸入層與輸出層神經(jīng)元的個數(shù)不變鸦致,然后按照神經(jīng)網(wǎng)絡(luò)的學(xué)習(xí)方法進行參數(shù)更新。在下一次迭代中涣楷,又重新隨機刪除一些神經(jīng)元(置為0),直至訓(xùn)練結(jié)束抗碰。
Dropout應(yīng)該算是AlexNet中一個很大的創(chuàng)新狮斗,現(xiàn)在神經(jīng)網(wǎng)絡(luò)中的必備結(jié)構(gòu)之一。Dropout也可以看成是一種模型組合弧蝇,每次生成的網(wǎng)絡(luò)結(jié)構(gòu)都不一樣碳褒,通過組合多個模型的方式能夠有效地減少過擬合,Dropout只需要兩倍的訓(xùn)練時間即可實現(xiàn)模型組合(類似取平均)的效果看疗,非常高效沙峻。
如下圖:

image

Alex網(wǎng)絡(luò)結(jié)構(gòu)

image

上圖中的輸入是224×224,不過經(jīng)過計算(224?11)/4=54.75并不是論文中的55×55两芳,而使用227×227作為輸入摔寨,則(227?11)/4=55

網(wǎng)絡(luò)包含8個帶權(quán)重的層;前5層是卷積層怖辆,剩下的3層是全連接層是复。最后一層全連接層的輸出是1000維softmax的輸入删顶,softmax會產(chǎn)生1000類標(biāo)簽的分布網(wǎng)絡(luò)包含8個帶權(quán)重的層;前5層是卷積層淑廊,剩下的3層是全連接層逗余。最后一層全連接層的輸出是1000維softmax的輸入,softmax會產(chǎn)生1000類標(biāo)簽的分布季惩。

卷積層C1
該層的處理流程是: 卷積-->ReLU-->池化-->歸一化录粱。
卷積,輸入是227×227画拾,使用96個11×11×3的卷積核啥繁,得到的FeatureMap為55×55×96。
ReLU碾阁,將卷積層輸出的FeatureMap輸入到ReLU函數(shù)中输虱。
池化,使用3×3步長為2的池化單元(重疊池化脂凶,步長小于池化單元的寬度)宪睹,輸出為27×27×96((55?3)/2+1=27)
局部響應(yīng)歸一化,使用??=2,??=5,??=10?4,??=0.75進行局部歸一化蚕钦,輸出的仍然為27×27×96亭病,輸出分為兩組,每組的大小為27×27×48

卷積層C2
該層的處理流程是:卷積-->ReLU-->池化-->歸一化
卷積嘶居,輸入是2組27×27×48罪帖。使用2組,每組128個尺寸為5×5×48的卷積核邮屁,并作了邊緣填充padding=2整袁,卷積的步長為1. 則輸出的FeatureMap為2組,每組的大小為27×27 ??????????128. ((27+2?2?5)/1+1=27)
ReLU佑吝,將卷積層輸出的FeatureMap輸入到ReLU函數(shù)中
池化運算的尺寸為3×3坐昙,步長為2,池化后圖像的尺寸為(27?3)/2+1=13芋忿,輸出為13×13×256
局部響應(yīng)歸一化炸客,使用??=2,??=5,??=10?4,??=0.75進行局部歸一化,輸出的仍然為13×13×256戈钢,輸出分為2組痹仙,每組的大小為13×13×128

卷積層C3
該層的處理流程是: 卷積-->ReLU
卷積,輸入是13×13×256殉了,使用2組共384尺寸為3×3×256的卷積核开仰,做了邊緣填充padding=1,卷積的步長為1.則輸出的FeatureMap為13×13 ??????????384
ReLU,將卷積層輸出的FeatureMap輸入到ReLU函數(shù)中
卷積層C4
該層的處理流程是: 卷積-->ReLU
該層和C3類似抖所。
卷積梨州,輸入是13×13×384,分為兩組田轧,每組為13×13×192.使用2組暴匠,每組192個尺寸為3×3×192的卷積核,做了邊緣填充padding=1傻粘,卷積的步長為1.則輸出的FeatureMap為13×13 ??????????384每窖,分為兩組,每組為13×13×192
ReLU弦悉,將卷積層輸出的FeatureMap輸入到ReLU函數(shù)中
卷積層C5
該層處理流程為:卷積-->ReLU-->池化
卷積窒典,輸入為13×13×384,分為兩組稽莉,每組為13×13×192瀑志。使用2組,每組為128尺寸為3×3×192的卷積核污秆,做了邊緣填充padding=1劈猪,卷積的步長為1.則輸出的FeatureMap為13×13×256
ReLU,將卷積層輸出的FeatureMap輸入到ReLU函數(shù)中
池化良拼,池化運算的尺寸為3×3战得,步長為2,池化后圖像的尺寸為 (13?3)/2+1=6,即池化后的輸出為6×6×256
全連接層FC6
該層的流程為:(卷積)全連接 -->ReLU -->Dropout
卷積->全連接: 輸入為6×6×256,該層有4096個卷積核庸推,每個卷積核的大小為6×6×256常侦。由于卷積核的尺寸剛好與待處理特征圖(輸入)的尺寸相同,即卷積核中的每個系數(shù)只與特征圖(輸入)尺寸的一個像素值相乘贬媒,一一對應(yīng)聋亡,因此,該層被稱為全連接層际乘。由于卷積核與特征圖的尺寸相同杀捻,卷積運算后只有一個值,因此蚓庭,卷積后的像素層尺寸為4096×1×1,即有4096個神經(jīng)元仅仆。
ReLU,這4096個運算結(jié)果通過ReLU激活函數(shù)生成4096個值
Dropout,抑制過擬合器赞,隨機的斷開某些神經(jīng)元的連接或者是不激活某些神經(jīng)元
全連接層FC7
流程為:全連接-->ReLU-->Dropout
全連接,輸入為4096的向量
ReLU,這4096個運算結(jié)果通過ReLU激活函數(shù)生成4096個值
Dropout,抑制過擬合墓拜,隨機的斷開某些神經(jīng)元的連接或者是不激活某些神經(jīng)元
輸出層
第七層輸出的4096個數(shù)據(jù)與第八層的1000個神經(jīng)元進行全連接港柜,經(jīng)過訓(xùn)練后輸出1000個float型的值,這就是預(yù)測結(jié)果。
AlexNet參數(shù)數(shù)量
卷積層的參數(shù) = 卷積核的數(shù)量 * 卷積核 + 偏置

C1: 96個11×11×3的卷積核夏醉,96×11×11×3+96=34848
C2: 2組爽锥,每組128個5×5×48的卷積核,(128×5×5×48+128)×2=307456
C3: 384個3×3×256的卷積核畔柔,3×3×256×384+384=885120
C4: 2組氯夷,每組192個3×3×192的卷積核,(3×3×192×192+192)×2=663936
C5: 2組靶擦,每組128個3×3×192的卷積核腮考,(3×3×192×128+128)×2=442624
FC6: 4096個6×6×256的卷積核,6×6×256×4096+4096=37752832
FC7: 4096?4096+4096=16781312
output: 4096?1000=4096000
卷積層 C2,C4,C5中的卷積核只和位于同一GPU的上一層的FeatureMap相連玄捕。從上面可以看出踩蔚,參數(shù)大多數(shù)集中在全連接層,在卷積層由于權(quán)值共享枚粘,權(quán)值參數(shù)較少馅闽。
完整代碼

# -- encoding:utf-8 --
 
Create on 19/5/25 10:06
 

import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 定義外部傳入的參數(shù)
tf.app.flags.DEFINE_bool(flag_name="is_train",
                         default_value=True,
                         docstring="給定是否是訓(xùn)練操作,True表示訓(xùn)練馍迄,F(xiàn)alse表示預(yù)測8R病!")
tf.app.flags.DEFINE_string(flag_name="checkpoint_dir",
                           default_value="./mnist/models/models_alext",
                           docstring="給定模型存儲的文件夾柬姚,默認(rèn)為./mnist/models/models_alext")
tf.app.flags.DEFINE_string(flag_name="logdir",
                           default_value="./mnist/graph/graph_alext",
                           docstring="給定模型日志存儲的路徑拟杉,默認(rèn)為./mnist/graph/graph_alext")
tf.app.flags.DEFINE_integer(flag_name="batch_size",
                            default_value=16,
                            docstring="給定訓(xùn)練的時候每個批次的樣本數(shù)目,默認(rèn)為16.")
tf.app.flags.DEFINE_integer(flag_name="store_per_batch",
                            default_value=100,
                            docstring="給定每隔多少個批次進行一次模型持久化的操作量承,默認(rèn)為100")
tf.app.flags.DEFINE_integer(flag_name="validation_per_batch",
                            default_value=100,
                            docstring="給定每隔多少個批次進行一次模型的驗證操作搬设,默認(rèn)為100")
tf.app.flags.DEFINE_float(flag_name="learning_rate",
                          default_value=0.01,
                          docstring="給定模型的學(xué)習(xí)率,默認(rèn)0.01")
FLAGS = tf.app.flags.FLAGS


def create_dir_with_not_exits(dir_path):
    """
    如果文件的文件夾路徑不存在撕捍,直接創(chuàng)建
    :param dir_path:
    :return:
    """
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)


def conv2d(name, net, filter_height, filter_width, output_channels, stride=1, pandding='SAME'):
    with tf.variable_scope(name):
        input_channels = net.get_shape()[-1]
        filter = tf.get_variable(name='w', shape=[filter_height, filter_width, input_channels, output_channels])
        bias = tf.get_variable(name='b', shape=[output_channels])
        net = tf.nn.conv2d(input=net, filter=filter, strides=[1, stride, stride, 1], padding=pandding)
        net = tf.nn.bias_add(net, bias)
    return net


def relu(name, net):
    with tf.variable_scope(name):
        return tf.nn.relu(net)


def max_pool(name, net, height, width, stride, padding='SAME'):
    with tf.variable_scope(name):
        return tf.nn.max_pool(value=net, ksize=[1, height, width, 1],
                              strides=[1, stride, stride, 1], padding=padding)


def fc(name, net, units):
    with tf.variable_scope(name):
        input_units = net.get_shape()[-1]
        w = tf.get_variable(name='w', shape=[input_units, units])
        b = tf.get_variable(name='b', shape=[units])
        net = tf.matmul(net, w) + b
    return net


def create_model(input_x, show_image=False):
    """
    構(gòu)建模型
    :param input_x: 占位符拿穴,格式為[None, 784]
    :return:
    """
    # 定義一個網(wǎng)絡(luò)結(jié)構(gòu): input -> conv -> relu -> pooling -> conv -> relu -> pooling -> FC -> relu -> FC
    with tf.variable_scope("net"):
        with tf.variable_scope("Input"):
            # 這里定義一些圖像的處理方式,包括:格式轉(zhuǎn)換忧风、基礎(chǔ)處理(大小默色、剪切...)
            x = tf.reshape(input_x, shape=[-1, 28, 28, 1])
            # 這里做resize只是為了符合論文中的輸入大小
            x = tf.image.resize_nearest_neighbor(x, size=(224, 224))

        with tf.device('/GPU:0'), tf.variable_scope("net11"):
            net1 = conv2d('conv1', x, 11, 11, 48, 4)
            net1 = relu('relu1', net1)
            net1 = max_pool('pool1', net1, 2, 2, 2)
            net1 = conv2d('conv2', net1, 5, 5, 128, 1)
            net1 = relu('relu2', net1)
            net1 = max_pool('pool2', net1, 2, 2, 2)
        with tf.device('/GPU:1'), tf.variable_scope("net21"):
            net2 = conv2d('conv1', x, 11, 11, 48, 4)
            net2 = relu('relu1', net2)
            net2 = max_pool('pool1', net2, 2, 2, 2)
            net2 = conv2d('conv2', net2, 5, 5, 128, 1)
            net2 = relu('relu2', net2)
            net2 = max_pool('pool2', net2, 2, 2, 2)

        # 合并兩個網(wǎng)絡(luò)的輸出
        net = tf.concat([net1, net2], axis=-1)

        with tf.device('/GPU:0'), tf.variable_scope("net12"):
            net1 = conv2d('conv3', net, 3, 3, 192, 1)
            net1 = relu('relu3', net1)
            net1 = conv2d('conv4', net1, 3, 3, 192, 1)
            net1 = relu('relu4', net1)
            net1 = conv2d('conv5', net1, 3, 3, 128, 1)
            net1 = relu('relu5', net1)
            net1 = max_pool('pool3', net1, 2, 2, 2)
        with tf.device('/GPU:1'), tf.variable_scope("net22"):
            net2 = conv2d('conv3', net, 3, 3, 192, 1)
            net2 = relu('relu3', net2)
            net2 = conv2d('conv4', net2, 3, 3, 192, 1)
            net2 = relu('relu4', net2)
            net2 = conv2d('conv5', net2, 3, 3, 128, 1)
            net2 = relu('relu5', net2)
            net2 = max_pool('pool3', net2, 2, 2, 2)

        # 合并兩個網(wǎng)絡(luò)的輸出
        net = tf.concat([net1, net2], axis=-1)
        shape = net.get_shape()
        net = tf.reshape(net, shape=[-1, shape[1] * shape[2] * shape[3]])

        # 做全連接操作
        with tf.device('/GPU:0'), tf.variable_scope("net13"):
            net1 = fc('fc1', net, 2048)
            net1 = relu('relu6', net1)
        with tf.device('/GPU:1'), tf.variable_scope("net23"):
            net2 = fc('fc1', net, 2048)
            net2 = relu('relu6', net2)

        # 合并兩個網(wǎng)絡(luò)的輸出
        net = tf.concat([net1, net2], axis=-1)

        # 做全連接操作
        with tf.device('/GPU:0'), tf.variable_scope("net14"):
            net1 = fc('fc2', net, 2048)
            net1 = relu('relu7', net1)
        with tf.device('/GPU:1'), tf.variable_scope("net24"):
            net2 = fc('fc2', net, 2048)
            net2 = relu('relu7', net2)

        # 合并兩個網(wǎng)絡(luò)的輸出
        net = tf.concat([net1, net2], axis=-1)

        # 全連接
        logits = fc('fc3', net, 10)

        with tf.variable_scope("Prediction"):
            # 每行的最大值對應(yīng)的下標(biāo)就是當(dāng)前樣本的預(yù)測值
            predictions = tf.argmax(logits, axis=1)

    return logits, predictions


def create_loss(labels, logits):
    """
    基于給定的實際值labels和預(yù)測值logits進行一個交叉熵?fù)p失函數(shù)的構(gòu)建
    :param labels:  是經(jīng)過啞編碼之后的Tensor對象,形狀為[n_samples, n_class]
    :param logits:  是神經(jīng)網(wǎng)絡(luò)的最原始的輸出狮腿,形狀為[n_samples, n_class], 每一行最大值那個位置對應(yīng)的就是預(yù)測類別腿宰,沒有經(jīng)過softmax函數(shù)轉(zhuǎn)換。
    :return:
    """
    with tf.name_scope("loss"):
        # loss = tf.reduce_mean(-tf.log(tf.reduce_sum(labels * tf.nn.softmax(logits))))
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))
        tf.summary.scalar('loss', loss)
    return loss


def create_train_op(loss, learning_rate=0.01, global_step=None):
    """
    基于給定的損失函數(shù)構(gòu)建一個優(yōu)化器缘厢,優(yōu)化器的目的就是讓這個損失函數(shù)最小化
    :param loss:
    :param learning_rate:
    :param global_step:
    :return:
    """
    with tf.name_scope("train"):
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        train_op = optimizer.minimize(loss, global_step=global_step)
    return train_op


def create_accuracy(labels, predictions):
    """
    基于給定的實際值和預(yù)測值吃度,計算準(zhǔn)確率
    :param labels:  是經(jīng)過啞編碼之后的Tensor對象,形狀為[n_samples, n_class]
    :param predictions: 實際的預(yù)測類別下標(biāo)贴硫,形狀為[n_samples,]
    :return:
    """
    with tf.name_scope("accuracy"):
        # 獲取實際的類別下標(biāo)椿每,形狀為[n_samples,]
        y_labels = tf.argmax(labels, 1)
        # 計算準(zhǔn)確率
        accuracy = tf.reduce_mean(tf.cast(tf.equal(y_labels, predictions), tf.float32))
        tf.summary.scalar('accuracy', accuracy)
    return accuracy


def train():
    # 對于文件是否存在做一個檢測
    create_dir_with_not_exits(FLAGS.checkpoint_dir)
    create_dir_with_not_exits(FLAGS.logdir)

    with tf.Graph().as_default():
        # 一伊者、執(zhí)行圖的構(gòu)建
        # 0. 相關(guān)輸入Tensor對象的構(gòu)建
        input_x = tf.placeholder(dtype=tf.float32, shape=[None, 784], name='input_x')
        input_y = tf.placeholder(dtype=tf.float32, shape=[None, 10], name='input_y')
        global_step = tf.train.get_or_create_global_step()

        # 1. 網(wǎng)絡(luò)結(jié)構(gòu)的構(gòu)建
        logits, predictions = create_model(input_x)
        # 2. 構(gòu)建損失函數(shù)
        loss = create_loss(input_y, logits)
        # 3. 構(gòu)建優(yōu)化器
        train_op = create_train_op(loss,
                                   learning_rate=FLAGS.learning_rate,
                                   global_step=global_step)
        # 4. 構(gòu)建評估指標(biāo)
        accuracy = create_accuracy(input_y, predictions)

        # 二、執(zhí)行圖的運行/訓(xùn)練(數(shù)據(jù)加載间护、訓(xùn)練亦渗、持久化、可視化汁尺、模型的恢復(fù)....)
        with tf.Session() as sess:
            # 獲取一個日志輸出對象
            train_logdir = os.path.join(FLAGS.logdir, 'train')
            validation_logdir = os.path.join(FLAGS.logdir, 'validation')
            train_writer = tf.summary.FileWriter(logdir=train_logdir, graph=sess.graph)
            validation_writer = tf.summary.FileWriter(logdir=validation_logdir, graph=sess.graph)

            # a. 創(chuàng)建一個持久化對象(默認(rèn)會將所有的模型參數(shù)全部持久化法精,因為不是所有的都需要的,最好僅僅持久化的訓(xùn)練的模型參數(shù))
            var_list = tf.trainable_variables()
            # 是因為global_step這個變量是不參與模型訓(xùn)練的均函,所以模型不會持久化亿虽,這里加入之后,可以明確也持久化這個變量苞也。
            var_list.append(global_step)
            saver = tf.train.Saver(var_list=var_list)

            # a. 變量的初始化操作(所有的非訓(xùn)練變量的初始化 + 持久化的變量恢復(fù))
            # 所有變量初始化(如果有持久化的洛勉,后面做了持久化后,會覆蓋的)
            sess.run(tf.global_variables_initializer())
            # 做模型的恢復(fù)操作
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                print("進行模型恢復(fù)操作...")
                # 恢復(fù)模型
                saver.restore(sess, ckpt.model_checkpoint_path)
                # 恢復(fù)checkpoint的管理信息
                saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)


            # 獲取所有的summary輸出操作
            summary = tf.summary.merge_all()

            # b. 訓(xùn)練數(shù)據(jù)的產(chǎn)生/獲热绯佟(基于numpy隨機產(chǎn)生<可以先考慮一個固定的數(shù)據(jù)集>)
            mnist = input_data.read_data_sets(
                train_dir='../datas/mnist',  # 給定本地磁盤的數(shù)據(jù)存儲路徑
                one_hot=True,  # 給定返回的數(shù)據(jù)中是否對Y做啞編碼
                validation_size=5000  # 給定驗證數(shù)據(jù)集的大小
            )

            # c. 模型訓(xùn)練
            batch_size = FLAGS.batch_size
            step = sess.run(global_step)
            vn_accuracy_ = 0
            while True:
                # 開始模型訓(xùn)練
                x_train, y_train = mnist.train.next_batch(batch_size=batch_size)
                _, loss_, accuracy_, summary_ = sess.run([train_op, loss, accuracy, summary], feed_dict={
                    input_x: x_train,
                    input_y: y_train
                })
                print("第{}次訓(xùn)練后模型的損失函數(shù)為:{}, 準(zhǔn)確率:{}".format(step, loss_, accuracy_))
                train_writer.add_summary(summary_, global_step=step)

                # 持久化
                if step % FLAGS.store_per_batch == 0:
                    file_name = 'model_%.3f_%.3f_.ckpt' % (loss_, accuracy_)
                    save_path = os.path.join(FLAGS.checkpoint_dir, file_name)
                    saver.save(sess, save_path=save_path, global_step=step)

                if step % FLAGS.validation_per_batch == 0:
                    vn_loss_, vn_accuracy_, vn_summary_ = sess.run([loss, accuracy, summary],
                                                                   feed_dict={
                                                                       input_x: mnist.validation.images,
                                                                       input_y: mnist.validation.labels
                                                                   })
                    print("第{}次訓(xùn)練后模型在驗證數(shù)據(jù)上的損失函數(shù)為:{}, 準(zhǔn)確率:{}".format(step,
                                                                    vn_loss_,
                                                                    vn_accuracy_))
                    validation_writer.add_summary(vn_summary_, global_step=step)

                # 退出訓(xùn)練(要求當(dāng)前的訓(xùn)練數(shù)據(jù)集上的準(zhǔn)確率至少為0.8收毫,然后最近一次驗證數(shù)據(jù)上的準(zhǔn)確率為0.8)
                if accuracy_ > 0.99 and vn_accuracy_ > 0.99:
                    # 退出之前再做一次持久化操作
                    file_name = 'model_%.3f_%.3f_.ckpt' % (loss_, accuracy_)
                    save_path = os.path.join(FLAGS.checkpoint_dir, file_name)
                    saver.save(sess, save_path=save_path, global_step=step)
                    break
                step += 1
            # 關(guān)閉輸出流
            train_writer.close()
            validation_writer.close()


def prediction():
    # TODO: 參考以前的代碼自己把這個區(qū)域的內(nèi)容填充一下。我下周晚上講殷勘。
    # 做一個預(yù)測(預(yù)測的評估此再,對mnist.test這個里面的數(shù)據(jù)進行評估效果的查看)
    with tf.Graph().as_default():
        pass


def main(_):
    if FLAGS.is_train:
        # 進入訓(xùn)練的代碼執(zhí)行中
        print("開始進行模型訓(xùn)練運行.....")
        train()
    else:
        # 進入測試、預(yù)測的代碼執(zhí)行中
        print("開始進行模型驗證玲销、測試代碼運行.....")
        prediction()
    print("Done!!!!")


if __name__ == '__main__':
    # 默認(rèn)情況下输拇,直接調(diào)用當(dāng)前py文件中的main函數(shù)
    tf.app.run()
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市贤斜,隨后出現(xiàn)的幾起案子策吠,更是在濱河造成了極大的恐慌,老刑警劉巖瘩绒,帶你破解...
    沈念sama閱讀 218,682評論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件猴抹,死亡現(xiàn)場離奇詭異,居然都是意外死亡锁荔,警方通過查閱死者的電腦和手機蟀给,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,277評論 3 395
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來阳堕,“玉大人跋理,你說我怎么就攤上這事√褡埽” “怎么了前普?”我有些...
    開封第一講書人閱讀 165,083評論 0 355
  • 文/不壞的土叔 我叫張陵,是天一觀的道長越驻。 經(jīng)常有香客問我,道長,這世上最難降的妖魔是什么缀旁? 我笑而不...
    開封第一講書人閱讀 58,763評論 1 295
  • 正文 為了忘掉前任记劈,我火速辦了婚禮,結(jié)果婚禮上并巍,老公的妹妹穿的比我還像新娘目木。我一直安慰自己,他們只是感情好懊渡,可當(dāng)我...
    茶點故事閱讀 67,785評論 6 392
  • 文/花漫 我一把揭開白布刽射。 她就那樣靜靜地躺著,像睡著了一般剃执。 火紅的嫁衣襯著肌膚如雪誓禁。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,624評論 1 305
  • 那天肾档,我揣著相機與錄音摹恰,去河邊找鬼。 笑死怒见,一個胖子當(dāng)著我的面吹牛俗慈,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播遣耍,決...
    沈念sama閱讀 40,358評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼闺阱,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了舵变?” 一聲冷哼從身側(cè)響起酣溃,我...
    開封第一講書人閱讀 39,261評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎棋傍,沒想到半個月后救拉,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,722評論 1 315
  • 正文 獨居荒郊野嶺守林人離奇死亡瘫拣,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,900評論 3 336
  • 正文 我和宋清朗相戀三年亿絮,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片麸拄。...
    茶點故事閱讀 40,030評論 1 350
  • 序言:一個原本活蹦亂跳的男人離奇死亡派昧,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出拢切,到底是詐尸還是另有隱情蒂萎,我是刑警寧澤,帶...
    沈念sama閱讀 35,737評論 5 346
  • 正文 年R本政府宣布淮椰,位于F島的核電站五慈,受9級特大地震影響纳寂,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜泻拦,卻給世界環(huán)境...
    茶點故事閱讀 41,360評論 3 330
  • 文/蒙蒙 一毙芜、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧争拐,春花似錦腋粥、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,941評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至绑雄,卻和暖如春展辞,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背绳慎。 一陣腳步聲響...
    開封第一講書人閱讀 33,057評論 1 270
  • 我被黑心中介騙來泰國打工纵竖, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人杏愤。 一個月前我還...
    沈念sama閱讀 48,237評論 3 371
  • 正文 我出身青樓靡砌,卻偏偏與公主長得像,于是被迫代替她去往敵國和親珊楼。 傳聞我的和親對象是個殘疾皇子通殃,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,976評論 2 355