TensorFlow從0到1 - 16 - L2正則化對抗“過擬合”

TensorFlow從0到1系列回顧

前面的14 交叉熵?fù)p失函數(shù)——防止學(xué)習(xí)緩慢15 重新思考神經(jīng)網(wǎng)絡(luò)初始化從學(xué)習(xí)緩慢問題入手,嘗試改進(jìn)神經(jīng)網(wǎng)絡(luò)的學(xué)習(xí)腻贰。本篇討論過擬合問題吁恍,并引入與之相對的L2正則化(Regularization)方法。

overfitting播演,來源:http://blog.algotrading101.com

無處不在的過擬合

模型對于已知數(shù)據(jù)的描述適應(yīng)性過高冀瓦,導(dǎo)致對新數(shù)據(jù)的泛化能力不佳,我們稱模型對于數(shù)據(jù)過擬合(overfitting)写烤。

過擬合無處不在翼闽。

羅素的火雞對自己的末日始料未及,曾真理般存在的牛頓力學(xué)淪為狹義相對論在低速情況下的近似洲炊,次貸危機(jī)破滅了美國買房只漲不跌的神話肄程,血戰(zhàn)鋼鋸嶺的醫(yī)療兵Desmond也并不是懦夫。

凡是基于經(jīng)驗(yàn)的學(xué)習(xí)选浑,都存在過擬合的風(fēng)險蓝厌。動物、人古徒、機(jī)器都不能幸免拓提。

Russell's Turkey,來源:http://chaospet.com/115-russells-turkey/g

誰存在過擬合隧膘?

對于一些離散的二維空間中的樣本點(diǎn)代态,下面兩條曲線誰存在過擬合?

誰存在過擬合疹吃?

遵循奧卡姆剃刀的一派蹦疑,主張“如無必要,勿增實(shí)體”。他們相信相對簡單的模型泛化能力更好:上圖中的藍(lán)色直線萨驶,雖然只有很少的樣本點(diǎn)直接落在它上面歉摧,但是不妨認(rèn)為這些樣本點(diǎn)或多或少包含一些噪聲。基于這種認(rèn)知叁温,可以預(yù)測新樣本也會在這條直線附近出現(xiàn)再悼。

或許很多時候,傾向簡單會占上風(fēng)膝但,但是真實(shí)世界的復(fù)雜性深不可測冲九。雖然在自然科學(xué)中,奧卡姆剃刀被作為啟發(fā)性技巧來使用跟束,幫助科學(xué)家發(fā)展理論模型工具莺奸,但是它并沒有被當(dāng)做邏輯上不可辯駁的定理或者科學(xué)結(jié)論〖窖纾總有簡單模型表達(dá)不了灭贷,只能通過復(fù)雜模型來描述的事物存在。很有可能紅色的曲線才是對客觀世界的真實(shí)反映花鹅。

康德為了對抗奧卡姆剃刀產(chǎn)生的影響,創(chuàng)建了他自己的反剃刀:“存在的多樣性不應(yīng)被粗暴地忽視”枫浙。

阿爾伯特·愛因斯坦告誡:“科學(xué)理論應(yīng)該盡可能簡單刨肃,但不能過于簡單÷嶂悖”

所以僅從上圖來判斷真友,一個理性的回答是:不知道。即使是如此簡單的二維空間情況下紧帕,在沒有更多的新樣本數(shù)據(jù)做出驗(yàn)證之前盔然,不能僅通過模型形式的簡單或復(fù)雜來判定誰存在過擬合。

過擬合的判斷

二維是嗜、三維的模型愈案,本身可以很容易的繪制出來,當(dāng)新的樣本出現(xiàn)后鹅搪,通過觀察即可大致判斷模型是否存在過擬合站绪。

然而現(xiàn)實(shí)情況要復(fù)雜的多。對MNIST數(shù)字識別所采用的3層感知器——輸入層784個神經(jīng)元丽柿,隱藏層30個神經(jīng)元恢准,輸出層10個神經(jīng)元,包含23860個參數(shù)(23860 = 784 x 30 + 30 x 10 + 30 + 10)甫题,靠繪制模型來觀察是不現(xiàn)實(shí)的馁筐。

最有效的方式是通過識別精度判斷模型是否存在過擬合:比較模型對驗(yàn)證集和訓(xùn)練集的識別精度,如果驗(yàn)證集識別精度大幅低于訓(xùn)練集坠非,則可以判斷模型存在過擬合敏沉。

至于為什么是驗(yàn)證集而不是測試集,請復(fù)習(xí)11 74行Python實(shí)現(xiàn)手寫體數(shù)字識別中“驗(yàn)證集與超參數(shù)”一節(jié)。

然而靜態(tài)的比較已訓(xùn)練模型對兩個集合的識別精度無法回答一個問題:過擬合是什么時候發(fā)生的赦抖?

要獲得這個信息舱卡,就需要在模型訓(xùn)練過程中動態(tài)的監(jiān)測每次迭代(Epoch)后訓(xùn)練集和驗(yàn)證集的識別精度,一旦出現(xiàn)訓(xùn)練集識別率繼續(xù)上升而驗(yàn)證集識別率不再提高队萤,就說明過擬合發(fā)生了轮锥。

這種方法還會帶來一個額外的收獲:確定作為超參數(shù)之一的迭代數(shù)(Epoch Number)的量級。更進(jìn)一步要尔,甚至可以不設(shè)置固定的迭代次數(shù)舍杜,以過擬合為信號,一旦發(fā)生就提前停止(early stopping)訓(xùn)練赵辕,避免后續(xù)無效的迭代既绩。

過擬合監(jiān)測

了解了過擬合的概念以及監(jiān)測方法,就可以開始分析我們訓(xùn)練MNIST數(shù)字識別模型是否存在過擬合了还惠。

所用代碼:tf_16_mnist_loss_weight.py饲握。它在12 TensorFlow構(gòu)建3層NN玩轉(zhuǎn)MNIST代碼的基礎(chǔ)上,使用了交叉熵?fù)p失蚕键,以及1/sqrt(nin)權(quán)重初始化:

  • 1個隱藏層救欧,包含30個神經(jīng)元;
  • 學(xué)習(xí)率:3.0锣光;
  • 迭代數(shù):30次笆怠;
  • mini batch:10;

訓(xùn)練過程中誊爹,分別對訓(xùn)練集和驗(yàn)證集的識別精度進(jìn)行了跟蹤蹬刷,如下圖所示,其中紅線代表訓(xùn)練集識別率频丘,藍(lán)線代表測試集識別率办成。圖中顯示,大約在第15次迭代前后搂漠,測試集的識別精度穩(wěn)定在95.5%不再提高诈火,而訓(xùn)練集的識別精度仍然繼續(xù)上升,直到30次迭代全部結(jié)束后達(dá)到了98.5%状答,兩者相差3%冷守。

由此可見,模型存在明顯的過擬合的特征惊科。

訓(xùn)練集和驗(yàn)證集識別精度(基于TensorBoard繪制)

過擬合的對策:L2正則化

對抗過擬合最有效的方法就是增加訓(xùn)練數(shù)據(jù)的完備性拍摇,但它昂貴且有限。另一種思路是減小網(wǎng)絡(luò)的規(guī)模馆截,但它可能會因?yàn)橄拗屏四P偷谋磉_(dá)潛力而導(dǎo)致識別精度整體下降充活。

本篇引入L2正則化(Regularization)蜂莉,可以在原有的訓(xùn)練數(shù)據(jù),以及網(wǎng)絡(luò)架構(gòu)不縮減的情況下混卵,有效避免過擬合映穗。L2正則化即在損失函數(shù)C的表達(dá)式上追加L2正則化項(xiàng)

L2正則化

上式中的C0代表原損失函數(shù),可以替換成均方誤差幕随、交叉熵等任何一種損失函數(shù)表達(dá)式蚁滋。

關(guān)于L2正則化項(xiàng)的幾點(diǎn)說明:

  • 求和∑是對網(wǎng)絡(luò)中的所有權(quán)重進(jìn)行的;
  • λ(lambda)為自定義參數(shù)(超參數(shù))赘淮;
  • n是訓(xùn)練樣本的數(shù)量(注意不是所有權(quán)重的數(shù)量辕录!);
  • L2正則化并沒有偏置參與梢卸;

該如何理解正則化呢走诞?

對于使網(wǎng)絡(luò)達(dá)到最小損失的權(quán)重w,很可能有非常多不同分布的解:有的均值偏大蛤高、有的偏小蚣旱,有的分布均勻,有的稀疏戴陡。那么在這個w的解空間里塞绿,該如何挑選相對更好的呢?正則化通過添加約束的方式猜欺,幫我們找到一個方向位隶。

L2正則化表達(dá)式暗示著一種傾向:訓(xùn)練盡可能的小的權(quán)重拷窜,較大的權(quán)重需要保證能顯著降低原有損失C0才能保留开皿。

至于正則化為何能有效的緩解過擬合,這方面數(shù)學(xué)解釋其實(shí)并不充分篮昧,更多是基于經(jīng)驗(yàn)的認(rèn)知赋荆。

L2正則化的實(shí)現(xiàn)

因?yàn)樵谠袚p失函數(shù)中追加了L2正則化項(xiàng),那么是不是得修改現(xiàn)有反向傳播算法(BP1中有用到C的表達(dá)式)懊昨?答案是不需要窄潭。

C對w求偏導(dǎo)數(shù),可以拆分成原有C0對w求偏導(dǎo)酵颁,以及L2正則項(xiàng)對w求偏導(dǎo)嫉你。前者繼續(xù)利用原有的反向傳播計算方法,而后者可以直接計算得到:

C對于偏置b求偏導(dǎo)保持不變:

基于上述躏惋,就可以得到權(quán)重w和偏置b的更新方法:

TensorFlow實(shí)現(xiàn)L2正則化

TensorFlow的最優(yōu)化方法tf.train.GradientDescentOptimizer包辦了梯度下降幽污、反向傳播,所以基于TensorFlow實(shí)現(xiàn)L2正則化簿姨,并不能按照上節(jié)的算法直接干預(yù)權(quán)重的更新距误,而要使用TensorFlow方式:

tf.add_to_collection(tf.GraphKeys.WEIGHTS, W_2)
tf.add_to_collection(tf.GraphKeys.WEIGHTS, W_3)
regularizer = tf.contrib.layers.l2_regularizer(scale=5.0/50000)
reg_term = tf.contrib.layers.apply_regularization(regularizer)

loss = (tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=z_3)) +
    reg_term)

對上述代碼的一些說明:

  • 將網(wǎng)絡(luò)中所有層中的權(quán)重簸搞,依次通過tf.add_to_collectio加入到tf.GraphKeys.WEIGHTS中;
  • 調(diào)用tf.contrib.layers.l2_regularizer生成L2正則化方法准潭,注意所傳參數(shù)scale=λ/n(n為訓(xùn)練樣本的數(shù)量);
  • 調(diào)用tf.contrib.layers.apply_regularization來生成損失函數(shù)的L2正則化項(xiàng)reg_term趁俊,所傳第一個參數(shù)為上面生成的正則化方法,第二個參數(shù)為none時默認(rèn)值為tf.GraphKeys.WEIGHTS刑然;
  • 最后將L2正則化reg_term項(xiàng)追加到損失函數(shù)表達(dá)式寺擂;

向原有損失函數(shù)追加L2正則化項(xiàng),模型和訓(xùn)練設(shè)置略作調(diào)整:

  • 1個隱藏層闰集,包含100個神經(jīng)元沽讹;
  • 學(xué)習(xí)率:0.5;
  • 迭代數(shù):30次武鲁;
  • mini batch:10爽雄;

重新運(yùn)行訓(xùn)練,跟蹤訓(xùn)練集和驗(yàn)證集的識別精度沐鼠,如下圖所示挚瘟。圖中顯示,在整個30次迭代中饲梭,訓(xùn)練集和驗(yàn)證集的識別率均持續(xù)上升(都超過95%)乘盖,最終兩者的差距控制在0.5%,過擬合程度顯著的減輕了憔涉。

需要注意的是订框,盡管正則化有效降低了驗(yàn)證集上過擬合程度,但是也降低了訓(xùn)練集的識別精度兜叨。所以在實(shí)現(xiàn)L2正則化時增加了隱藏層的神經(jīng)元數(shù)量(從30到100)來抵消識別精度的下降穿扳。

L2正則化(基于TensorBoard繪制)

附完整代碼

import argparse
import sys
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

FLAGS = None


def main(_):
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True,
                                      validation_size=10000)

    # Create the model
    x = tf.placeholder(tf.float32, [None, 784])
    W_2 = tf.Variable(tf.random_normal([784, 100]) / tf.sqrt(784.0))
    '''W_2 = tf.get_variable(
        name="W_2",
        regularizer=regularizer,
        initializer=tf.random_normal([784, 30], stddev=1 / tf.sqrt(784.0)))'''
    b_2 = tf.Variable(tf.random_normal([100]))
    z_2 = tf.matmul(x, W_2) + b_2
    a_2 = tf.sigmoid(z_2)

    W_3 = tf.Variable(tf.random_normal([100, 10]) / tf.sqrt(100.0))
    '''W_3 = tf.get_variable(
        name="W_3",
        regularizer=regularizer,
        initializer=tf.random_normal([30, 10], stddev=1 / tf.sqrt(30.0)))'''
    b_3 = tf.Variable(tf.random_normal([10]))
    z_3 = tf.matmul(a_2, W_3) + b_3
    a_3 = tf.sigmoid(z_3)

    # Define loss and optimizer
    y_ = tf.placeholder(tf.float32, [None, 10])

    tf.add_to_collection(tf.GraphKeys.WEIGHTS, W_2)
    tf.add_to_collection(tf.GraphKeys.WEIGHTS, W_3)
    regularizer = tf.contrib.layers.l2_regularizer(scale=5.0 / 50000)
    reg_term = tf.contrib.layers.apply_regularization(regularizer)

    loss = (tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=z_3)) +
        reg_term)

    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()

    correct_prediction = tf.equal(tf.argmax(a_3, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    scalar_accuracy = tf.summary.scalar('accuracy', accuracy)
    train_writer = tf.summary.FileWriter(
        'MNIST/logs/tf16_reg/train', sess.graph)
    validation_writer = tf.summary.FileWriter(
        'MNIST/logs/tf16_reg/validation')

    # Train
    best = 0
    for epoch in range(30):
        for _ in range(5000):
            batch_xs, batch_ys = mnist.train.next_batch(10)
            sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
        # Test trained model
        accuracy_currut_train = sess.run(accuracy,
                                         feed_dict={x: mnist.train.images,
                                                    y_: mnist.train.labels})

        accuracy_currut_validation = sess.run(
            accuracy,
            feed_dict={x: mnist.validation.images,
                       y_: mnist.validation.labels})

        sum_accuracy_train = sess.run(
            scalar_accuracy,
            feed_dict={x: mnist.train.images,
                       y_: mnist.train.labels})

        sum_accuracy_validation = sess.run(
            scalar_accuracy,
            feed_dict={x: mnist.validation.images,
                       y_: mnist.validation.labels})

        train_writer.add_summary(sum_accuracy_train, epoch)
        validation_writer.add_summary(sum_accuracy_validation, epoch)

        print("Epoch %s: train: %s validation: %s"
              % (epoch, accuracy_currut_train, accuracy_currut_validation))
        best = (best, accuracy_currut_validation)[
            best <= accuracy_currut_validation]

    # Test trained model
    print("best: %s" % best)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='../MNIST/',
                        help='Directory for storing input data')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

下載 tf_16_mnist_loss_weight_reg.py

上一篇 15 重新思考神經(jīng)網(wǎng)絡(luò)初始化
下一篇 17 Step By Step上手TensorBoard


共享協(xié)議:署名-非商業(yè)性使用-禁止演繹(CC BY-NC-ND 3.0 CN)
轉(zhuǎn)載請注明:作者黑猿大叔(簡書)

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市国旷,隨后出現(xiàn)的幾起案子矛物,更是在濱河造成了極大的恐慌,老刑警劉巖跪但,帶你破解...
    沈念sama閱讀 218,755評論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件履羞,死亡現(xiàn)場離奇詭異,居然都是意外死亡屡久,警方通過查閱死者的電腦和手機(jī)忆首,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,305評論 3 395
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來被环,“玉大人,你說我怎么就攤上這事蛤售《侍叮” “怎么了揣钦?”我有些...
    開封第一講書人閱讀 165,138評論 0 355
  • 文/不壞的土叔 我叫張陵雳灾,是天一觀的道長。 經(jīng)常有香客問我冯凹,道長,這世上最難降的妖魔是什么匈庭? 我笑而不...
    開封第一講書人閱讀 58,791評論 1 295
  • 正文 為了忘掉前任,我火速辦了婚禮浑劳,結(jié)果婚禮上阱持,老公的妹妹穿的比我還像新娘。我一直安慰自己衷咽,他們只是感情好蒜绽,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,794評論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著鼎姊,像睡著了一般相赁。 火紅的嫁衣襯著肌膚如雪相寇。 梳的紋絲不亂的頭發(fā)上噪生,一...
    開封第一講書人閱讀 51,631評論 1 305
  • 那天跺嗽,我揣著相機(jī)與錄音页藻,去河邊找鬼。 笑死份帐,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的畜挨。 我是一名探鬼主播,決...
    沈念sama閱讀 40,362評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼毡咏,長吁一口氣:“原來是場噩夢啊……” “哼逮刨!你這毒婦竟也來了呕缭?” 一聲冷哼從身側(cè)響起修己,我...
    開封第一講書人閱讀 39,264評論 0 276
  • 序言:老撾萬榮一對情侶失蹤睬愤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后尤辱,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,724評論 1 315
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡奸鸯,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,900評論 3 336
  • 正文 我和宋清朗相戀三年娄涩,在試婚紗的時候發(fā)現(xiàn)自己被綠了映跟。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 40,040評論 1 350
  • 序言:一個原本活蹦亂跳的男人離奇死亡球恤,死狀恐怖荸镊,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情躬存,我是刑警寧澤,帶...
    沈念sama閱讀 35,742評論 5 346
  • 正文 年R本政府宣布宛逗,位于F島的核電站盾剩,受9級特大地震影響替蔬,放射性物質(zhì)發(fā)生泄漏屎暇。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,364評論 3 330
  • 文/蒙蒙 一快毛、第九天 我趴在偏房一處隱蔽的房頂上張望番挺。 院中可真熱鬧,春花似錦襟衰、人聲如沸粪摘。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,944評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽椎咧。三九已至,卻和暖如春勤讽,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背向臀。 一陣腳步聲響...
    開封第一講書人閱讀 33,060評論 1 270
  • 我被黑心中介騙來泰國打工诸狭, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人三娩。 一個月前我還...
    沈念sama閱讀 48,247評論 3 371
  • 正文 我出身青樓妹懒,卻偏偏與公主長得像双吆,于是被迫代替她去往敵國和親会前。 傳聞我的和親對象是個殘疾皇子匾竿,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,979評論 2 355

推薦閱讀更多精彩內(nèi)容