背景
反向傳播訓練(Backpropagation)一個神經(jīng)網(wǎng)絡是一種常見的方法萄金。網(wǎng)上并不缺少介紹反向傳播是如何工作的論文。但很少包括一個用實際數(shù)字的例子。這篇文章是我試圖解釋它是如何工作的和一個具體的例子, 大家可以對比自己的計算,以確保他們正確理解反向傳播噪窘。
Python 實現(xiàn)反向傳播算法
您可以到 Github 嘗試我寫的一個反向傳播算法Python腳本。
反向傳播算法可視化
一個交互式可視化顯示神經(jīng)網(wǎng)絡學習過程, 可以看看我的神經(jīng)網(wǎng)絡可視化網(wǎng)站效扫。
額外的資源
果你發(fā)現(xiàn)本教程有用,想繼續(xù)學習神經(jīng)網(wǎng)絡及其應用,我強烈推薦看看Adrian Rosebrock的優(yōu)秀教程Getting Started with Deep Learning and Python
概述
對于本教程,我們將使用一個有 2 個輸入神經(jīng)元倔监、2 個隱藏的神經(jīng)元和 2 個輸出神經(jīng)元的神經(jīng)網(wǎng)絡直砂。此外,隱藏層和輸出層將包括一個 偏差神經(jīng)元(Bias)。
這里的基本結(jié)構(gòu):
為了一些數(shù)字,這是初始權(quán)重,偏差,和訓練輸入/輸出:
反向傳播的目標是優(yōu)化神經(jīng)網(wǎng)絡的權(quán)重,這樣神經(jīng)網(wǎng)絡可以學習如何正確將任意輸入映射到輸出浩习。
本教程的剩余部分我們要處理一個訓練集:給定輸入0.05和0.10,我們希望神經(jīng)網(wǎng)絡輸出0.01和0.99静暂。
前向傳播
讓我們看看目前神經(jīng)網(wǎng)絡給定的偏差、權(quán)重和輸入的0.05和0.10瘦锹。為此我們要養(yǎng)活這些輸入提前雖然網(wǎng)絡籍嘹。
我們算出每個隱藏神經(jīng)元的總輸入,再利用總輸入作為激活函數(shù)(這里我們使用 Sigmoid 函數(shù))的變量弯院,然后在輸出層神經(jīng)元重復這一步驟。
這是我們?nèi)绾斡嬎?code>h1總輸入:
然后使用 Sigmoid 函數(shù)計算h1
輸出:
同理得h2
輸出:
我們對輸出層神經(jīng)元重復這個過程泪掀,使用隱層神經(jīng)元的輸出作為輸入听绳。
這是o1
的輸出:
同理得o2
輸出:
計算總誤差
我們現(xiàn)在可以計算每個輸出神經(jīng)元平方誤差和:
例如,o1
預期輸出為 0.01,但實際輸出為0.75136507异赫,因此他的誤差是:
重復這個過程得到o2
(預期輸出是0.99)的誤差是
因此椅挣,神經(jīng)網(wǎng)絡的總誤差為
反向傳播過程
反向傳播的目標是更新連接的權(quán)重以使每個神經(jīng)元的實際輸出更加接近預期輸出,從而減少每個神經(jīng)元以及整個網(wǎng)絡的誤差塔拳。
輸出層
考慮一下ω5
鼠证,我們希望知道ω5
的改變對誤差的影響有大多,稱為
(誤差對
ω5
求偏導數(shù))根據(jù)我們所知道的鏈式法則得出:
可視化我們所做的事情
我們需要弄清楚這個等式的每一部分靠抑。
首先量九,
o1
的輸出變化對總誤差的影響有多大?
我們用總誤差對
求偏導數(shù)時颂碧,的值變?yōu)?0 荠列,因為不會影響o2
的誤差。
下一步载城,o1
總輸入的變化對于o1
的輸出的影響有多大肌似?
最后,計算 ω5
的變化對o1
總輸入的影響有多大诉瓦?
將這三者放在一起:
Delta規(guī)則——權(quán)值的修正量等于誤差乘以輸入
我們也可以將這個計算過程組合成 δ規(guī)則 的形式:
(1)
令
(2)
因為
所以(3)
聯(lián)立(1)(2)(3)得
為了減少誤差川队,我們從當前權(quán)重減去這個值(學習率可自定義,這里我們設置為0.5):
重復這個過程睬澡,我們可以得到權(quán)重 ω6
, ω7
, 和 ω8
:
我們在得到新的隱藏層神經(jīng)元的輸入權(quán)重之后再更新 ω6
, ω7
, 和 ω8
(也就是說固额,在進行反向傳播的時候我們使用舊的權(quán)重值)
隱藏層
接下來,我們將繼續(xù)向后傳播,計算新值ω1
, ω2
, ω3
, 和 ω4
猴贰。
全局來說对雪,我們需要計算
可視化:
我們要用類似計算輸出層那樣的過程,但略有不同的是:每個隱層神經(jīng)元的輸出會對多個輸出神經(jīng)元的輸出和誤差產(chǎn)生印象。我們知道
out_h1
將同時影響out_o1
和out_o2
(為方便表示米绕,這里用下劃線表示下標瑟捣,下同)馋艺。因此out_h1
對每個輸出神經(jīng)元的影響:
開始:
我們之前計算過
然后
ω5
,因為:得:
同理得:
因此迈套,
捐祠。
然后我們計算:
接下來我們計算h1
的總輸入對ω1
求偏導數(shù):
綜上所述,
你也可以這么寫
現(xiàn)在我們可以更新ω1
了:
重復該過程計算 ω1
, ω2
, 和 ω3
:
最后,我們已經(jīng)更新所有的權(quán)重! 我們最初提出 0.05 和 0.1 的輸入,網(wǎng)絡上的誤差為 0.298371109 桑李。第一輪反向傳播之后,現(xiàn)在總誤差降至 0.291027924 踱蛀。它可能看起來沒有調(diào)整太多。但是在這個過程重復 10000 次之后贵白,比如說率拒,誤差降到0.000035085。在這一時刻禁荒,當我們輸入0.05和0.1時猬膨,兩個輸出神經(jīng)元分別輸出0.015912196 ( vs 預期 0.01) and 0.984065734 (vs 預期 0.99) 。
如果你做到這一步呛伴,發(fā)現(xiàn)任何錯誤或者能想到更通俗易懂的說明方法勃痴,請加我公眾號 jinkey-love 交流。