通過上一篇 13 馴獸師:神經(jīng)網(wǎng)絡調(diào)教綜述垂睬,對神經(jīng)網(wǎng)絡的調(diào)教有了一個整體印象缠借,本篇從學習緩慢這一常見問題入手歧譬,引入交叉熵損失函數(shù),并分析它是如何克服學習緩慢問題陷猫。
“嚴重錯誤”導致學習緩慢
回顧識別MNIST的網(wǎng)絡架構(gòu)秫舌,我們采用了經(jīng)典的S型神經(jīng)元的妖,以及常見的基于均方誤差(MSE)的二次函數(shù)作為損失函數(shù)。殊不知這種組合足陨,在實際輸出與預期偏離較大時嫂粟,會造成學習緩慢。
簡單的說墨缘,如果在初始化權(quán)重和偏置時星虹,故意產(chǎn)生一個背離預期較大的輸出,那么訓練網(wǎng)絡的過程中需要用很多次迭代镊讼,才能抵消掉這種背離宽涌,恢復正常的學習。這種現(xiàn)象與人類學習的經(jīng)驗相悖:對于明顯的錯誤狠毯,人類能進行快速的修正护糖。
為了看清楚這個現(xiàn)象,可以用一個S型神經(jīng)元嚼松,從微觀角度進行重現(xiàn)嫡良。這個神經(jīng)元接受1個固定的輸入“1”,期望經(jīng)過訓練后能輸出“0”献酗,因此待訓練參數(shù)為1個權(quán)重w和1個偏置b寝受,如下圖:
先觀察一個“正常”初始化的情況罕偎。
令w=0.6很澄,b=0.9,可認為其符合均值為0颜及,標準差為1的正態(tài)分布甩苛。此時,輸入1俏站,輸出0.82讯蒲。接下來開始使用梯度下降法進行迭代訓練,從Epoch-Cost曲線可以看到“損失”快速降低肄扎,到第100次時就很低了墨林,到第300次迭代時已經(jīng)幾乎為0,符合預期犯祠,如下圖:
接下來換一種初始化策略旭等。
將w和b都賦值為“2.0”。此時衡载,輸入1搔耕,輸出為0.98——比之前的0.82偏離預期值0更遠了。接下來的訓練Epoch-Cost曲線顯示200次迭代后“損失”依然很高月劈,減少緩慢度迂,而最后100次迭代才開始恢復正常的學習藤乙,如下圖:
學習緩慢原因分析
單個樣本情況下,基于均方誤差的二次損失函數(shù)為:
一個神經(jīng)元的情況下就不用反向傳播求導了惭墓,已知a = σ(z)坛梁,z = wx + b,直接使用鏈式求導即可:
將唯一的一個訓練樣本(x=1腊凶,y=0)代入划咐,得到:
觀察σ(z)函數(shù)曲線會發(fā)現(xiàn),當σ接近于1時钧萍,σ曲線特別的平坦褐缠,所以此處σ'(z)是一個非常小的值,由上式可推斷C的梯度也會非常小风瘦,“下降”自然也就會變得緩慢队魏。這種情況也成為神經(jīng)元飽和。這就解釋了前面初始的神經(jīng)元輸出a=0.98万搔,為什么會比a=0.82學習緩慢那么多胡桨。
交叉熵損失函數(shù)
S型神經(jīng)元,與二次均方誤差損失函數(shù)的組合瞬雹,一旦神經(jīng)元輸出發(fā)生“嚴重錯誤”昧谊,網(wǎng)絡將陷入一種艱難而緩慢的學習“沼澤”中。
對此一個簡單的策略就是更換損失函數(shù)酗捌,使用交叉熵損失函數(shù)可以明顯的改善當發(fā)生“嚴重錯誤”時導致的學習緩慢呢诬,使神經(jīng)網(wǎng)絡的學習更符合人類經(jīng)驗——快速從錯誤中修正。
交叉熵損失函數(shù)定義如下:
在證明它真的能避免學習緩慢之前胖缤,有必要先確認它是否至少可以衡量“損失”尚镰,后者并不顯而易見。
一個函數(shù)能夠作為損失函數(shù)哪廓,要符合以下兩個特性:
- 非負钓猬;
- 當實際輸出接近預期,那么損失函數(shù)應該接近0撩独。
交叉熵全部符合。首先账月,實際輸出a的取值范圍為(0, 1)综膀,所以無論是lna還是ln(1-a)都是負數(shù),期望值y的取值非0即1局齿,因此中括號里面每項都是負數(shù)剧劝,再加上表達式最前面的一個負號,所以整體為非負抓歼。再者讥此,當預期y為0時拢锹,如果實際輸出a接近0時,C也接近0萄喳;當預期y為1時卒稳,如果實際輸出a接近1,那么C也接近0他巨。
接下來分析為什么交叉熵可以避免學習緩慢充坑,仍然從求C的偏導開始。
單樣本情況下染突,交叉熵損失函數(shù)可以記為:
對C求w的偏導數(shù):
a = σ(z)捻爷,將其代入:
對于Sigmoid函數(shù),有σ'(z) = σ(z)(1-σ(z)),所以上式中的σ'(z)被抵消了,得到:
由此可見红碑,C的梯度不再與σ'(z)有關(guān)寓娩,而與a-y相關(guān),其結(jié)果就是:實際輸出與預期偏離越大局蚀,梯度越大,學習越快。
對于偏置棵介,同理有:
更換損失函數(shù)為交叉熵后,回到之前學習緩慢的例子吧史,重新訓練邮辽,Epoch-Cost曲線顯示學習緩慢的情況消失了。
推廣到多神經(jīng)元網(wǎng)絡
前面的有效性證明是基于一個神經(jīng)元所做的微觀分析贸营,將其推廣到多層神經(jīng)元網(wǎng)絡也是很容易的吨述。從分量的角度來看,假設輸出神經(jīng)元的預期值是y = y1钞脂,y2揣云,...,實際輸出aL = aL1冰啃,aL2邓夕,...,那么交叉熵損失函數(shù)計算公式如下:
評價交叉熵損失阎毅,注意以下3點:
交叉熵無法改善隱藏層中神經(jīng)元發(fā)生的學習緩慢焚刚。損失函數(shù)定義中的aL是最后一層神經(jīng)元的實際輸出,所以“損失”C針對輸出層神經(jīng)元的權(quán)重wLj求偏導數(shù)扇调,可以產(chǎn)生抵消σ'(zLj)的效果矿咕,從而避免輸出層神經(jīng)元的學習緩慢問題。但是“損失”C對于隱藏層神經(jīng)元的權(quán)重wL-1j求偏導,就無法產(chǎn)生抵消σ'(zL-1j)的效果碳柱。
交叉熵損失函數(shù)只對網(wǎng)絡輸出“明顯背離預期”時發(fā)生的學習緩慢有改善效果捡絮,如果初始輸出背離預期并不明顯,那么應用交叉熵損失函數(shù)也無法觀察到明顯的改善莲镣。從另一個角度看福稳,應用交叉熵損失是一種防御性策略,增加訓練的穩(wěn)定性剥悟。
應用交叉熵損失并不能改善或避免神經(jīng)元飽和灵寺,而是當輸出層神經(jīng)元發(fā)生飽和時,能夠避免其學習緩慢的問題区岗。
小結(jié)
現(xiàn)有神經(jīng)網(wǎng)絡中存在一種風險:由于初始化或其他巧合因素略板,一旦出現(xiàn)輸出與預期偏離過大,就會導致網(wǎng)絡學習緩慢慈缔。本篇分析了該現(xiàn)象出現(xiàn)的原因叮称,引入交叉熵損失函數(shù),并推理證明了其有效性藐鹤。
附完整代碼
代碼基于12 TF構(gòu)建3層NN玩轉(zhuǎn)MNIST中的tf_12_mnist_nn.py瓤檐,修改了損失函數(shù),TensorFlow提供了交叉熵的封裝:
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=z_3))
import argparse
import sys
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
FLAGS = None
def main(_):
# Import data
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
# Create the model
x = tf.placeholder(tf.float32, [None, 784])
W_2 = tf.Variable(tf.random_normal([784, 30]))
b_2 = tf.Variable(tf.random_normal([30]))
z_2 = tf.matmul(x, W_2) + b_2
a_2 = tf.sigmoid(z_2)
W_3 = tf.Variable(tf.random_normal([30, 10]))
b_3 = tf.Variable(tf.random_normal([10]))
z_3 = tf.matmul(a_2, W_3) + b_3
a_3 = tf.sigmoid(z_3)
# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 10])
# loss = tf.reduce_mean(tf.norm(y_ - a_3, axis=1)**2) / 2
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=z_3))
train_step = tf.train.GradientDescentOptimizer(3.0).minimize(loss)
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
# Train
best = 0
for epoch in range(30):
for _ in range(5000):
batch_xs, batch_ys = mnist.train.next_batch(10)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
# Test trained model
correct_prediction = tf.equal(tf.argmax(a_3, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_sum(tf.cast(correct_prediction, tf.int32))
accuracy_currut = sess.run(accuracy, feed_dict={x: mnist.test.images,
y_: mnist.test.labels})
print("Epoch %s: %s / 10000" % (epoch, accuracy_currut))
best = (best, accuracy_currut)[best <= accuracy_currut]
# Test trained model
print("best: %s / 10000" % best)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='/MNIST/',
help='Directory for storing input data')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
下載 tf_14_mnist_nn_cross_entropy.py娱节。
上一篇 13 AI馴獸師:神經(jīng)網(wǎng)絡調(diào)教綜述
下一篇 15 1/sqrt(n)權(quán)重初始化
共享協(xié)議:署名-非商業(yè)性使用-禁止演繹(CC BY-NC-ND 3.0 CN)
轉(zhuǎn)載請注明:作者黑猿大叔(簡書)