TensorFlow從0到1 - 14 - 交叉熵損失函數(shù)——防止學習緩慢

TensorFlow從0到1系列回顧

通過上一篇 13 馴獸師:神經(jīng)網(wǎng)絡調(diào)教綜述垂睬,對神經(jīng)網(wǎng)絡的調(diào)教有了一個整體印象缠借,本篇從學習緩慢這一常見問題入手歧譬,引入交叉熵損失函數(shù),并分析它是如何克服學習緩慢問題陷猫。

學習緩慢

“嚴重錯誤”導致學習緩慢

回顧識別MNIST的網(wǎng)絡架構(gòu)秫舌,我們采用了經(jīng)典的S型神經(jīng)元的妖,以及常見的基于均方誤差(MSE)的二次函數(shù)作為損失函數(shù)。殊不知這種組合足陨,在實際輸出與預期偏離較大時嫂粟,會造成學習緩慢。

簡單的說墨缘,如果在初始化權(quán)重和偏置時星虹,故意產(chǎn)生一個背離預期較大的輸出,那么訓練網(wǎng)絡的過程中需要用很多次迭代镊讼,才能抵消掉這種背離宽涌,恢復正常的學習。這種現(xiàn)象與人類學習的經(jīng)驗相悖:對于明顯的錯誤狠毯,人類能進行快速的修正护糖。

為了看清楚這個現(xiàn)象,可以用一個S型神經(jīng)元嚼松,從微觀角度進行重現(xiàn)嫡良。這個神經(jīng)元接受1個固定的輸入“1”,期望經(jīng)過訓練后能輸出“0”献酗,因此待訓練參數(shù)為1個權(quán)重w和1個偏置b寝受,如下圖:

單一神經(jīng)元

先觀察一個“正常”初始化的情況罕偎。

令w=0.6很澄,b=0.9,可認為其符合均值為0颜及,標準差為1的正態(tài)分布甩苛。此時,輸入1俏站,輸出0.82讯蒲。接下來開始使用梯度下降法進行迭代訓練,從Epoch-Cost曲線可以看到“損失”快速降低肄扎,到第100次時就很低了墨林,到第300次迭代時已經(jīng)幾乎為0,符合預期犯祠,如下圖:

正常的學習

接下來換一種初始化策略旭等。

將w和b都賦值為“2.0”。此時衡载,輸入1搔耕,輸出為0.98——比之前的0.82偏離預期值0更遠了。接下來的訓練Epoch-Cost曲線顯示200次迭代后“損失”依然很高月劈,減少緩慢度迂,而最后100次迭代才開始恢復正常的學習藤乙,如下圖:

學習緩慢

學習緩慢原因分析

單個樣本情況下,基于均方誤差的二次損失函數(shù)為:

B-N-F-8

一個神經(jīng)元的情況下就不用反向傳播求導了惭墓,已知a = σ(z)坛梁,z = wx + b,直接使用鏈式求導即可:

B-N-F-11

將唯一的一個訓練樣本(x=1腊凶,y=0)代入划咐,得到:

B-N-F-11-2

觀察σ(z)函數(shù)曲線會發(fā)現(xiàn),當σ接近于1時钧萍,σ曲線特別的平坦褐缠,所以此處σ'(z)是一個非常小的值,由上式可推斷C的梯度也會非常小风瘦,“下降”自然也就會變得緩慢队魏。這種情況也成為神經(jīng)元飽和。這就解釋了前面初始的神經(jīng)元輸出a=0.98万搔,為什么會比a=0.82學習緩慢那么多胡桨。

Sigmoid

交叉熵損失函數(shù)

S型神經(jīng)元,與二次均方誤差損失函數(shù)的組合瞬雹,一旦神經(jīng)元輸出發(fā)生“嚴重錯誤”昧谊,網(wǎng)絡將陷入一種艱難而緩慢的學習“沼澤”中。

對此一個簡單的策略就是更換損失函數(shù)酗捌,使用交叉熵損失函數(shù)可以明顯的改善當發(fā)生“嚴重錯誤”時導致的學習緩慢呢诬,使神經(jīng)網(wǎng)絡的學習更符合人類經(jīng)驗——快速從錯誤中修正。

交叉熵損失函數(shù)定義如下:

交叉熵損失函數(shù)

在證明它真的能避免學習緩慢之前胖缤,有必要先確認它是否至少可以衡量“損失”尚镰,后者并不顯而易見。

一個函數(shù)能夠作為損失函數(shù)哪廓,要符合以下兩個特性:

  • 非負钓猬;
  • 當實際輸出接近預期,那么損失函數(shù)應該接近0撩独。

交叉熵全部符合。首先账月,實際輸出a的取值范圍為(0, 1)综膀,所以無論是lna還是ln(1-a)都是負數(shù),期望值y的取值非0即1局齿,因此中括號里面每項都是負數(shù)剧劝,再加上表達式最前面的一個負號,所以整體為非負抓歼。再者讥此,當預期y為0時拢锹,如果實際輸出a接近0時,C也接近0萄喳;當預期y為1時卒稳,如果實際輸出a接近1,那么C也接近0他巨。

接下來分析為什么交叉熵可以避免學習緩慢充坑,仍然從求C的偏導開始。

單樣本情況下染突,交叉熵損失函數(shù)可以記為:

交叉熵損失函數(shù)

對C求w的偏導數(shù):

B-N-F-12-2

a = σ(z)捻爷,將其代入:

B-N-F-12-3

對于Sigmoid函數(shù),有σ'(z) = σ(z)(1-σ(z)),所以上式中的σ'(z)被抵消了,得到:

B-N-F-12-4

由此可見红碑,C的梯度不再與σ'(z)有關(guān)寓娩,而與a-y相關(guān),其結(jié)果就是:實際輸出與預期偏離越大局蚀,梯度越大,學習越快

對于偏置棵介,同理有:

B-N-F-12-5

更換損失函數(shù)為交叉熵后,回到之前學習緩慢的例子吧史,重新訓練邮辽,Epoch-Cost曲線顯示學習緩慢的情況消失了。

學習緩慢消失

推廣到多神經(jīng)元網(wǎng)絡

前面的有效性證明是基于一個神經(jīng)元所做的微觀分析贸营,將其推廣到多層神經(jīng)元網(wǎng)絡也是很容易的吨述。從分量的角度來看,假設輸出神經(jīng)元的預期值是y = y1钞脂,y2揣云,...,實際輸出aL = aL1冰啃,aL2邓夕,...,那么交叉熵損失函數(shù)計算公式如下:

交叉熵損失函數(shù)

評價交叉熵損失阎毅,注意以下3點:

  • 交叉熵無法改善隱藏層中神經(jīng)元發(fā)生的學習緩慢焚刚。損失函數(shù)定義中的aL是最后一層神經(jīng)元的實際輸出,所以“損失”C針對輸出層神經(jīng)元的權(quán)重wLj求偏導數(shù)扇调,可以產(chǎn)生抵消σ'(zLj)的效果矿咕,從而避免輸出層神經(jīng)元的學習緩慢問題。但是“損失”C對于隱藏層神經(jīng)元的權(quán)重wL-1j求偏導,就無法產(chǎn)生抵消σ'(zL-1j)的效果碳柱。

  • 交叉熵損失函數(shù)只對網(wǎng)絡輸出“明顯背離預期”時發(fā)生的學習緩慢有改善效果捡絮,如果初始輸出背離預期并不明顯,那么應用交叉熵損失函數(shù)也無法觀察到明顯的改善莲镣。從另一個角度看福稳,應用交叉熵損失是一種防御性策略,增加訓練的穩(wěn)定性剥悟。

  • 應用交叉熵損失并不能改善或避免神經(jīng)元飽和灵寺,而是當輸出層神經(jīng)元發(fā)生飽和時,能夠避免其學習緩慢的問題区岗。

小結(jié)

現(xiàn)有神經(jīng)網(wǎng)絡中存在一種風險:由于初始化或其他巧合因素略板,一旦出現(xiàn)輸出與預期偏離過大,就會導致網(wǎng)絡學習緩慢慈缔。本篇分析了該現(xiàn)象出現(xiàn)的原因叮称,引入交叉熵損失函數(shù),并推理證明了其有效性藐鹤。

附完整代碼

代碼基于12 TF構(gòu)建3層NN玩轉(zhuǎn)MNIST中的tf_12_mnist_nn.py瓤檐,修改了損失函數(shù),TensorFlow提供了交叉熵的封裝:

loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=z_3))

import argparse
import sys
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

FLAGS = None


def main(_):
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    # Create the model
    x = tf.placeholder(tf.float32, [None, 784])
    W_2 = tf.Variable(tf.random_normal([784, 30]))
    b_2 = tf.Variable(tf.random_normal([30]))
    z_2 = tf.matmul(x, W_2) + b_2
    a_2 = tf.sigmoid(z_2)

    W_3 = tf.Variable(tf.random_normal([30, 10]))
    b_3 = tf.Variable(tf.random_normal([10]))
    z_3 = tf.matmul(a_2, W_3) + b_3
    a_3 = tf.sigmoid(z_3)

    # Define loss and optimizer
    y_ = tf.placeholder(tf.float32, [None, 10])

    # loss = tf.reduce_mean(tf.norm(y_ - a_3, axis=1)**2) / 2
    loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=z_3))
    train_step = tf.train.GradientDescentOptimizer(3.0).minimize(loss)

    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()

    # Train
    best = 0
    for epoch in range(30):
        for _ in range(5000):
            batch_xs, batch_ys = mnist.train.next_batch(10)
            sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
        # Test trained model
        correct_prediction = tf.equal(tf.argmax(a_3, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_sum(tf.cast(correct_prediction, tf.int32))
        accuracy_currut = sess.run(accuracy, feed_dict={x: mnist.test.images,
                                                        y_: mnist.test.labels})
        print("Epoch %s: %s / 10000" % (epoch, accuracy_currut))
        best = (best, accuracy_currut)[best <= accuracy_currut]

    # Test trained model
    print("best: %s / 10000" % best)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='/MNIST/',
                        help='Directory for storing input data')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

下載 tf_14_mnist_nn_cross_entropy.py娱节。

上一篇 13 AI馴獸師:神經(jīng)網(wǎng)絡調(diào)教綜述
下一篇 15 1/sqrt(n)權(quán)重初始化


共享協(xié)議:署名-非商業(yè)性使用-禁止演繹(CC BY-NC-ND 3.0 CN)
轉(zhuǎn)載請注明:作者黑猿大叔(簡書)

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末挠蛉,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子肄满,更是在濱河造成了極大的恐慌谴古,老刑警劉巖,帶你破解...
    沈念sama閱讀 218,755評論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件稠歉,死亡現(xiàn)場離奇詭異掰担,居然都是意外死亡,警方通過查閱死者的電腦和手機怒炸,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,305評論 3 395
  • 文/潘曉璐 我一進店門带饱,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人阅羹,你說我怎么就攤上這事勺疼。” “怎么了捏鱼?”我有些...
    開封第一講書人閱讀 165,138評論 0 355
  • 文/不壞的土叔 我叫張陵恢口,是天一觀的道長。 經(jīng)常有香客問我穷躁,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,791評論 1 295
  • 正文 為了忘掉前任问潭,我火速辦了婚禮猿诸,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘狡忙。我一直安慰自己梳虽,他們只是感情好,可當我...
    茶點故事閱讀 67,794評論 6 392
  • 文/花漫 我一把揭開白布灾茁。 她就那樣靜靜地躺著窜觉,像睡著了一般。 火紅的嫁衣襯著肌膚如雪北专。 梳的紋絲不亂的頭發(fā)上禀挫,一...
    開封第一講書人閱讀 51,631評論 1 305
  • 那天,我揣著相機與錄音拓颓,去河邊找鬼语婴。 笑死,一個胖子當著我的面吹牛驶睦,可吹牛的內(nèi)容都是我干的砰左。 我是一名探鬼主播,決...
    沈念sama閱讀 40,362評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼场航,長吁一口氣:“原來是場噩夢啊……” “哼缠导!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起溉痢,我...
    開封第一講書人閱讀 39,264評論 0 276
  • 序言:老撾萬榮一對情侶失蹤僻造,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后适室,有當?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體嫡意,經(jīng)...
    沈念sama閱讀 45,724評論 1 315
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,900評論 3 336
  • 正文 我和宋清朗相戀三年捣辆,在試婚紗的時候發(fā)現(xiàn)自己被綠了蔬螟。 大學時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 40,040評論 1 350
  • 序言:一個原本活蹦亂跳的男人離奇死亡汽畴,死狀恐怖旧巾,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情忍些,我是刑警寧澤鲁猩,帶...
    沈念sama閱讀 35,742評論 5 346
  • 正文 年R本政府宣布,位于F島的核電站罢坝,受9級特大地震影響廓握,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點故事閱讀 41,364評論 3 330
  • 文/蒙蒙 一隙券、第九天 我趴在偏房一處隱蔽的房頂上張望男应。 院中可真熱鬧,春花似錦娱仔、人聲如沸沐飘。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,944評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽耐朴。三九已至,卻和暖如春盹憎,著一層夾襖步出監(jiān)牢的瞬間筛峭,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,060評論 1 270
  • 我被黑心中介騙來泰國打工脚乡, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留蜒滩,地道東北人。 一個月前我還...
    沈念sama閱讀 48,247評論 3 371
  • 正文 我出身青樓奶稠,卻偏偏與公主長得像俯艰,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子锌订,可洞房花燭夜當晚...
    茶點故事閱讀 44,979評論 2 355

推薦閱讀更多精彩內(nèi)容