硬核NeruIPS 2018最佳論文帖蔓,一個神經(jīng)了的常微分方程

姓名:謝童? 學(xué)號:16020188008? 轉(zhuǎn)自微信公眾號 機(jī)器之心

姓名:謝童? 學(xué)號:16020188008? 轉(zhuǎn)自微信公眾號 機(jī)器之心這是一篇神奇的論文瞬矩,以前一層一層疊加的神經(jīng)網(wǎng)絡(luò)似乎突然變得連續(xù)了徒欣,反向傳播也似乎不再需要一點(diǎn)一點(diǎn)往前傳、一層一層更新參數(shù)了泳桦。

在最近結(jié)束的 NeruIPS 2018 中汤徽,來自多倫多大學(xué)的陳天琦等研究者成為最佳論文的獲得者。他們提出了一種名為神經(jīng)常微分方程的模型灸撰,這是新一類的深度神經(jīng)網(wǎng)絡(luò)谒府。神經(jīng)常微分方程不拘于對已有架構(gòu)的修修補(bǔ)補(bǔ),它完全從另外一個角度考慮如何以連續(xù)的方式借助神經(jīng)網(wǎng)絡(luò)對數(shù)據(jù)建模浮毯。在陳天琦的講解下完疫,機(jī)器之心將向各位讀者介紹這一令人興奮的神經(jīng)網(wǎng)絡(luò)新家族。

在與機(jī)器之心的訪談中债蓝,陳天琦的導(dǎo)師 David Duvenaud 教授談起這位學(xué)生也是贊不絕口壳鹤。Duvenaud 教授認(rèn)為陳天琦不僅是位理解能力超強(qiáng)的學(xué)生,鉆研起問題來也相當(dāng)認(rèn)真透徹饰迹。他說:「天琦很喜歡提出新想法芳誓,他有時會在我提出建議一周后再反饋:『老師你之前建議的方法不太合理。但是我研究出另外一套合理的方法啊鸭,結(jié)果我也做出來了锹淌。』」Ducenaud 教授評價道赠制,現(xiàn)如今人工智能熱度有增無減赂摆,教授能找到優(yōu)秀博士生基本如同「雞生蛋還是蛋生雞」的問題,頂尖學(xué)校的教授通常能快速地招納到博士生钟些,「我很幸運(yùn)地能在事業(yè)起步階段就遇到陳天琦如此優(yōu)秀的學(xué)生烟号。」

本文主要介紹神經(jīng)常微分方程背后的細(xì)想與直觀理解政恍,很多延伸的概念并沒有詳細(xì)解釋汪拥,例如大大降低計(jì)算復(fù)雜度的連續(xù)型流模型和官方 PyTorch 代碼實(shí)現(xiàn)等。這一篇文章重點(diǎn)對比了神經(jīng)常微分方程(ODEnet)與殘差網(wǎng)絡(luò)抚垃,我們不僅能通過這一部分了解如何從熟悉的 ResNet 演化到 ODEnet,同時還能還有新模型的前向傳播過程和特點(diǎn)。

其次文章比較關(guān)注 ODEnet 的反向傳播過程鹤树,即如何通過解常微分方程直接把梯度求出來铣焊。這一部分與傳統(tǒng)的反向傳播有很多不同,因此先理解反向傳播再看源碼可能是更好的選擇罕伯。值得注意的是曲伊,ODEnet 的反傳只有常數(shù)級的內(nèi)存占用成本。

ODEnet 的 PyTorch 實(shí)現(xiàn)地址:https://github.com/rtqichen/torchdiffeq

ODEnet 論文地址:https://arxiv.org/abs/1806.07366

如下展示了文章的主要結(jié)構(gòu):

常微分方程

從殘差網(wǎng)絡(luò)到微分方程

? ? 從微分方程到殘差網(wǎng)絡(luò)

? ? 網(wǎng)絡(luò)對比

神經(jīng)常微分方程

反向傳播

? ? 反向傳播怎么做

連續(xù)型的歸一化流

? ? 變量代換定理

常微分方程

其實(shí)初讀這篇論文追他,還是有一些疑惑的坟募,因?yàn)楹芏喔拍疃疾皇俏覀兯熘摹R虼巳绻胍私膺@個模型邑狸,那么同學(xué)們懈糯,你們首先需要回憶高數(shù)上的微分方程。有了這樣的概念后单雾,我們就能愉快地連續(xù)化神經(jīng)網(wǎng)絡(luò)層級赚哗,并構(gòu)建完整的神經(jīng)常微分方程。

常微分方程即只包含單個自變量 x硅堆、未知函數(shù) f(x) 和未知函數(shù)的導(dǎo)數(shù) f'(x) 的等式屿储,所以說 f'(x) = 2x 也算一個常微分方程。但更常見的可以表示為 df(x)/dx = g(f(x), x)渐逃,其中 g(f(x), x) 表示由 f(x) 和 x 組成的某個表達(dá)式够掠,這個式子是擴(kuò)展一般神經(jīng)網(wǎng)絡(luò)的關(guān)鍵,我們在后面會討論這個式子怎么就連續(xù)化了神經(jīng)網(wǎng)絡(luò)層級茄菊。

一般對于常微分方程疯潭,我們希望解出未知的 f(x),例如 f'(x) = 2x 的通解為 f(x)=x^2 +C买羞,其中 C 表示任意常數(shù)袁勺。而在工程中更常用數(shù)值解,即給定一個初值 f(x_0)畜普,我們希望解出末值 f(x_1)期丰,這樣并不需要解出完整的 f(x),只需要一步步逼近它就行了吃挑。

現(xiàn)在回過頭來討論我們熟悉的神經(jīng)網(wǎng)絡(luò)钝荡,本質(zhì)上不論是全連接、循環(huán)還是卷積網(wǎng)絡(luò)舶衬,它們都類似于一個非常復(fù)雜的復(fù)合函數(shù)埠通,復(fù)合的次數(shù)就等于層級的深度。例如兩層全連接網(wǎng)絡(luò)可以表示為 Y=f(f(X, θ1), θ2)逛犹,因此每一個神經(jīng)網(wǎng)絡(luò)層級都類似于萬能函數(shù)逼近器端辱。

因?yàn)檎w是復(fù)合函數(shù)梁剔,所以很容易接受復(fù)合函數(shù)的求導(dǎo)方法:鏈?zhǔn)椒▌t,并將梯度從最外一層的函數(shù)一點(diǎn)點(diǎn)先向里面層級的函數(shù)傳遞舞蔽,并且每傳到一層函數(shù)荣病,就可以更新該層的參數(shù) θ。現(xiàn)在問題是渗柿,我們前向傳播過后需要保留所有層的激活值个盆,并在沿計(jì)算路徑反傳梯度時利用這些激活值。這對內(nèi)存的占用非常大朵栖,因此也就限制了深度模型的訓(xùn)練過程颊亮。

神經(jīng)常微分方程走了另一條道路,它使用神經(jīng)網(wǎng)絡(luò)參數(shù)化隱藏狀態(tài)的導(dǎo)數(shù)陨溅,而不是如往常那樣直接參數(shù)化隱藏狀態(tài)终惑。這里參數(shù)化隱藏狀態(tài)的導(dǎo)數(shù)就類似構(gòu)建了連續(xù)性的層級與參數(shù),而不再是離散的層級声登。因此參數(shù)也是一個連續(xù)的空間狠鸳,我們不需要再分層傳播梯度與更新參數(shù)∶跎ぃ總而言之件舵,神經(jīng)微分方程在前向傳播過程中不儲存任何中間結(jié)果,因此它只需要近似常數(shù)級的內(nèi)存成本脯厨。

從殘差網(wǎng)絡(luò)到微分方程

殘差網(wǎng)絡(luò)是一類特殊的卷積網(wǎng)絡(luò)铅祸,它通過殘差連接而解決了梯度反傳問題,即當(dāng)神經(jīng)網(wǎng)絡(luò)層級非常深時合武,梯度仍然能有效傳回輸入端临梗。下圖為原論文中殘差模塊的結(jié)構(gòu),殘差塊的輸出結(jié)合了輸入信息與內(nèi)部卷積運(yùn)算的輸出信息稼跳,這種殘差連接或恒等映射表示深層模型至少不能低于淺層網(wǎng)絡(luò)的準(zhǔn)確度盟庞。這樣的殘差模塊堆疊幾十上百個就是非常深的殘差神經(jīng)網(wǎng)絡(luò)。

如果我們將上面的殘差模塊更加形式化地表示為以下方程:

其中 h_t 是第 t 層隱藏單元的輸出值汤善,f 為通過θ_t 參數(shù)化的神經(jīng)網(wǎng)絡(luò)什猖。該方程式表示上圖的整個殘差模塊,如果我們其改寫為殘差的形式红淡,即 h_t+1 - h_t = f(h_t, θ_t )不狮。那么我們可以看到神經(jīng)網(wǎng)絡(luò) f 參數(shù)化的是隱藏層之間的殘差,f 同樣不是直接參數(shù)化隱藏層在旱。

ResNet 假設(shè)層級的離散的摇零,第 t 層到第 t+1 層之間是無定義的。那么如果這中間是有定義的呢桶蝎?殘差項(xiàng) h_t0 - h_t1 是不是就應(yīng)該非常小驻仅,以至于接近無窮辛鲁?這里我們少考慮了分母噪服,即殘差項(xiàng)應(yīng)該表示為 (h_t+1 - h_t )/1铃彰,分母的 1 表示兩個離散的層級之間相差 1。所以再一次考慮層級間有定義芯咧,我們會發(fā)現(xiàn)殘差項(xiàng)最終會收斂到隱藏層對 t 的導(dǎo)數(shù),而神經(jīng)網(wǎng)絡(luò)實(shí)際上參數(shù)化的就是這個導(dǎo)數(shù)竹揍。

所以若我們在層級間加入更多的層敬飒,且最終趨向于添加了無窮層時,神經(jīng)網(wǎng)絡(luò)就連續(xù)化了芬位∥揶郑可以說殘差網(wǎng)絡(luò)其實(shí)就是連續(xù)變換的歐拉離散化,是一個特例昧碉,我們可以將這種連續(xù)變換形式化地表示為一個常微分方程:

如果從導(dǎo)數(shù)定義的角度來看英染,當(dāng) t 的變化趨向于無窮小時,隱藏狀態(tài)的變化 dh(t) 可以通過神經(jīng)網(wǎng)絡(luò)建模被饿。當(dāng) t 從初始一點(diǎn)點(diǎn)變化到終止四康,那么 h(t) 的改變最終就代表著前向傳播結(jié)果。這樣利用神經(jīng)網(wǎng)絡(luò)參數(shù)化隱藏層的導(dǎo)數(shù)狭握,就確確實(shí)實(shí)連續(xù)化了神經(jīng)網(wǎng)絡(luò)層級闪金。

現(xiàn)在若能得出該常微分方程的數(shù)值解,那么就相當(dāng)于完成了前向傳播论颅。具體而言哎垦,若 h(0)=X 為輸入圖像,那么終止時刻的隱藏層輸出 h(T) 就為推斷結(jié)果恃疯。這是一個常微分方程的初值問題漏设,可以直接通過黑箱的常微分方程求解器(ODE Solver)解出來。而這樣的求解器又能控制數(shù)值誤差今妄,因此我們總能在計(jì)算力和模型準(zhǔn)確度之間做權(quán)衡郑口。

形式上來說,現(xiàn)在就需要變換方程 (2) 以求出數(shù)值解蛙奖,即給定初始狀態(tài) h(t_0) 和神經(jīng)網(wǎng)絡(luò)的情況下求出終止?fàn)顟B(tài) h(t_1):

如上所示潘酗,常微分方程的數(shù)值解 h(t_1) 需要求神經(jīng)網(wǎng)絡(luò) f 從 t_0 到 t_1 的積分。我們完全可以利用 ODE solver 解出這個值雁仲,這在數(shù)學(xué)物理領(lǐng)域已經(jīng)有非常成熟的解法仔夺,我們只需要將其當(dāng)作一個黑盒工具使用就行了。

從微分方程到殘差網(wǎng)絡(luò)

前面提到過殘差網(wǎng)絡(luò)是神經(jīng)常微分方程的特例攒砖,可以說殘差網(wǎng)絡(luò)是歐拉方法的離散化缸兔。兩三百年前解常微分方程的歐拉法非常直觀日裙,即 h(t +Δt) = h(t) + Δt×f(h(t), t)。每當(dāng)隱藏層沿 t 走一小步Δt惰蜜,新的隱藏層狀態(tài) h(t +Δt) 就應(yīng)該近似在已有的方向上邁一小步昂拂。如果這樣一小步一小步從 t_0 走到 t_1,那么就求出了 ODE 的數(shù)值解抛猖。

如果我們令 Δt 每次都等于 1格侯,那么離散化的歐拉方法就等于殘差模塊的表達(dá)式 h(t+1) = h(t) + f(h(t), t)。但是歐拉法只是解常微分方程最基礎(chǔ)的方法财著,它每走一步都會產(chǎn)生一點(diǎn)誤差联四,且誤差會累積起來。近百年來撑教,數(shù)學(xué)家構(gòu)建了很多現(xiàn)代 ODE 求解方法朝墩,它們不僅能保證收斂到真實(shí)解,同時還能控制誤差水平伟姐。

陳天琦等研究者構(gòu)建的 ODE 網(wǎng)絡(luò)就使用了一種適應(yīng)性的 ODE solver收苏,它不像歐拉法移動固定的步長,相反它會根據(jù)給定的誤差容忍度選擇適當(dāng)?shù)牟介L逼近真實(shí)解愤兵。如下圖所示鹿霸,左邊的殘差網(wǎng)絡(luò)定義有限轉(zhuǎn)換的離散序列,它從 0 到 1 再到 5 是離散的層級數(shù)秆乳,且在每一層通過激活函數(shù)做一次非線性轉(zhuǎn)換杜跷。此外,黑色的評估位置可以視為神經(jīng)元矫夷,它會對輸入做一次轉(zhuǎn)換以修正傳遞的值葛闷。而右側(cè)的 ODE 網(wǎng)絡(luò)定義了一個向量場,隱藏狀態(tài)會有一個連續(xù)的轉(zhuǎn)換双藕,黑色的評估點(diǎn)也會根據(jù)誤差容忍度自動調(diào)整淑趾。

網(wǎng)絡(luò)對比

在 David 的 Oral 演講中,他以兩段偽代碼展示了 ResNet 與 ODEnet 之間的差別忧陪。如下展示了 ResNet 的主要過程扣泊,其中 f 可以視為卷積層,ResNet 為整個模型架構(gòu)嘶摊。在卷積層 f 中延蟹,h 為上一層輸出的特征圖,t 確定目前是第幾個卷積層叶堆。ResNet 中的循環(huán)體為殘差連接阱飘,因此該網(wǎng)絡(luò)一共 T 個殘差模塊,且最終返回第 T 層的輸出值。

deff(h,?t,?θ):

returnnnet(h,?θ_t)

defresnet(h):

fortin[1:T]:

h?=?h?+?f(h,?t,?θ)

returnh

相比常見的 ResNet沥匈,下面的偽代碼就比較新奇了蔗喂。首先 f 與前面一樣定義的是神經(jīng)網(wǎng)絡(luò),不過現(xiàn)在它的參數(shù)θ是一個整體高帖,同時 t 作為獨(dú)立參數(shù)也需要饋送到神經(jīng)網(wǎng)絡(luò)中缰儿,這表明層級之間也是有定義的,它是一種連續(xù)的網(wǎng)絡(luò)散址。而整個 ODEnet 不需要通過循環(huán)搭建離散的層級乖阵,它只要通過 ODE solver 求出 t_1 時刻的 h 就行了。

deff(h,?t,?θ):

returnnnet([h,?t],?θ)

defODEnet(h,?θ):

returnODESolver(f,?h,?t_0,?t_1,?θ)

除了計(jì)算過程不一樣预麸,陳天琦等研究者還在 MNSIT 測試了這兩種模型的效果义起。他們使用帶有 6 個殘差模塊的 ResNet,以及使用一個 ODE Solver 代替這些殘差模塊的 ODEnet师崎。以下展示了不同網(wǎng)絡(luò)在 MNSIT 上的效果、參數(shù)量椅棺、內(nèi)存占用量和計(jì)算復(fù)雜度犁罩。

其中單個隱藏層的 MLP 引用自 LeCun 在 1998 年的研究,其隱藏層只有 300 個神經(jīng)元两疚,但是 ODEnet 在有相似參數(shù)量的情況下能獲得顯著更好的結(jié)果床估。上表中 L 表示神經(jīng)網(wǎng)絡(luò)的層級數(shù),L tilde 表示 ODE Solver 中的評估次數(shù)诱渤,它可以近似代表 ODEnet 的「層級深度」丐巫。值得注意的是,ODEnet 只有常數(shù)級的內(nèi)存占用勺美,這表示不論層級的深度如何增加递胧,它的內(nèi)存占用基本不會有太大的變化。

神經(jīng)常微分方程

在與 ResNet 的類比中赡茸,我們基本上已經(jīng)了解了 ODEnet 的前向傳播過程缎脾。首先輸入數(shù)據(jù) Z(t_0),我們可以通過一個連續(xù)的轉(zhuǎn)換函數(shù)(神經(jīng)網(wǎng)絡(luò))對輸入進(jìn)行非線性變換占卧,從而得到 f遗菠。隨后 ODESolver 對 f 進(jìn)行積分,再加上初值就可以得到最后的推斷結(jié)果华蜒。如下所示辙纬,殘差網(wǎng)絡(luò)只不過是用一個離散的殘差連接代替 ODE Solver。

在前向傳播中叭喜,ODEnet 還有幾個非常重要的性質(zhì)贺拣,即模型的層級數(shù)與模型的誤差控制。首先因?yàn)槭沁B續(xù)模型,其并沒有明確的層級數(shù)纵柿,因此我們只能使用相似的度量確定模型的「深度」蜈抓,作者在這篇論文中采用 ODE Solver 評估的次數(shù)作為深度。

其次昂儒,深度與誤差控制有著直接的聯(lián)系沟使,ODEnet 通過控制誤差容忍度能確定模型的深度。因?yàn)?ODE Solver 能確保在誤差容忍度之內(nèi)逼近常微分方程的真實(shí)解渊跋,改變誤差容忍度就能改變神經(jīng)網(wǎng)絡(luò)的行為腊嗡。一般而言,降低 ODE Solver 的誤差容忍度將增加函數(shù)的評估的次數(shù)拾酝,因此類似于增加了模型的「深度」燕少。調(diào)整誤差容忍度能允許我們在準(zhǔn)確度與計(jì)算成本之間做權(quán)衡,因此我們在訓(xùn)練時可以采用高準(zhǔn)確率而學(xué)習(xí)更好的神經(jīng)網(wǎng)絡(luò)蒿囤,在推斷時可以根據(jù)實(shí)際計(jì)算環(huán)境調(diào)整為較低的準(zhǔn)確度客们。

如原論文的上圖所示,a 圖表示模型能保證在誤差范圍為內(nèi)材诽,且隨著誤差降低底挫,前向傳播的函數(shù)評估數(shù)增加。b 圖展示了評估數(shù)與相對計(jì)算時間的關(guān)系脸侥。d 圖展示了函數(shù)評估數(shù)會隨著訓(xùn)練的增加而自適應(yīng)地增加建邓,這表明隨著訓(xùn)練的進(jìn)行,模型的復(fù)雜度會增加睁枕。

c 圖比較有意思官边,它表示前向傳播的函數(shù)評估數(shù)大致是反向傳播評估數(shù)的一倍,這恰好表示反向傳播中的 adjoint sensitivity 方法不僅內(nèi)存效率高外遇,同時計(jì)算效率也比直接通過積分器的反向傳播高注簿。這主要是因?yàn)?adjoint sensitivity 并不需要依次傳遞到前向傳播中的每一個函數(shù)評估,即梯度不通過模型的深度由后向前一層層傳跳仿。

反向傳播

師從同門的 Jesse Bettencourt 向機(jī)器之心介紹道滩援,「天琦最擅長的就是耐心講解∷遥」當(dāng)他遇到任何無論是代碼問題玩徊,理論問題還是數(shù)學(xué)問題,一旦是問了同桌的天琦谨究,對方就一定會慢慢地花時間把問題講清楚恩袱、講透徹。而 ODEnet 的反向傳播胶哲,就是這樣一種需要耐心講解的問題畔塔。

ODEnet 的反向傳播與常見的反向傳播有一些不同,我們可能需要仔細(xì)查閱原論文與對應(yīng)的附錄證明才能有較深的理解。此外澈吨,作者給出了 ODEnet 的 PyTorch 實(shí)現(xiàn)把敢,我們也可以通過它了解實(shí)現(xiàn)細(xì)節(jié)。

正如作者而言谅辣,訓(xùn)練一個連續(xù)層級網(wǎng)絡(luò)的主要技術(shù)難點(diǎn)在于令梯度穿過 ODE Solver 的反向傳播修赞。其實(shí)如果令梯度沿著前向傳播的計(jì)算路徑反傳回去是非常直觀的,但是內(nèi)存占用會比較大而且數(shù)值誤差也不能控制桑阶。作者的解決方案是將前向傳播的 ODE Solver 視為一個黑箱操作柏副,梯度很難或根本不需要傳遞進(jìn)去,只需要「繞過」就行了蚣录。

作者采用了一種名為 adjoint method 的梯度計(jì)算方法來「繞過」前向傳播中的 ODE Solver割择,即模型在反傳中通過第二個增廣 ODE Solver 算出梯度,其可以逼近按計(jì)算路徑從 ODE Solver 傳遞回的梯度萎河,因此可用于進(jìn)一步的參數(shù)更新荔泳。這種方法如上圖 c 所示不僅在計(jì)算和內(nèi)存非常有優(yōu)勢,同時還能精確地控制數(shù)值誤差虐杯。

具體而言玛歌,若我們的損失函數(shù)為 L(),且它的輸入為 ODE Solver 的輸出:

我們第一步需要求 L 對 z(t) 的導(dǎo)數(shù)厦幅,或者說模型損失的變化如何取決于隱藏狀態(tài) z(t) 的變化。其中損失函數(shù) L 對 z(t_1) 的導(dǎo)數(shù)可以為整個模型的梯度計(jì)算提供入口慨飘。作者將這一個導(dǎo)數(shù)稱為 adjoint a(t) = -dL/z(t)确憨,它其實(shí)就相當(dāng)于隱藏層的梯度。

在基于鏈?zhǔn)椒▌t的傳統(tǒng)反向傳播中瓤的,我們需要從后一層對前一層求導(dǎo)以傳遞梯度休弃。而在連續(xù)化的 ODEnet 中,我們需要將前面求出的 a(t) 對連續(xù)的 t 進(jìn)行求導(dǎo)圈膏,由于 a(t) 是損失 L 對隱藏狀態(tài) z(t) 的導(dǎo)數(shù)塔猾,這就和傳統(tǒng)鏈?zhǔn)椒▌t中的傳播概念基本一致。下式展示了 a(t) 的導(dǎo)數(shù)稽坤,它能將梯度沿著連續(xù)的 t 向前傳丈甸,附錄 B.1 介紹了該式具體的推導(dǎo)過程。

在獲取每一個隱藏狀態(tài)的梯度后尿褪,我們可以再求它們對參數(shù)的導(dǎo)數(shù)睦擂,并更新參數(shù)。同樣在 ODEnet 中杖玲,獲取隱藏狀態(tài)的梯度后顿仇,再對參數(shù)求導(dǎo)并積分后就能得到損失對參數(shù)的導(dǎo)數(shù),這里之所以需要求積分是因?yàn)椤笇蛹墶箃 是連續(xù)的。這一個方程式可以表示為:

綜上臼闻,我們對 ODEnet 的反傳過程主要可以直觀理解為三步驟鸿吆,即首先求出梯度入口伴隨 a(t_1),再求 a(t) 的變化率 da(t)/dt述呐,這樣就能求出不同時刻的 a(t)惩淳。最后借助 a(t) 與 z(t),我們可以求出損失對參數(shù)的梯度市埋,并更新參數(shù)黎泣。當(dāng)然這里只是簡要的直觀理解,更完整的反傳過程展示在原論文的算法 1缤谎。

反向傳播怎么做

在算法 1 中抒倚,陳天琦等研究者展示了如何借助另一個 OED Solver 一次性求出反向傳播的各種梯度和更新量。要理解算法 1坷澡,首先我們要熟悉 ODESolver 的表達(dá)方式托呕。例如在 ODEnet 的前向傳播中,求解過程可以表示為 ODEsolver(z(t_0), f, t_0, t_1, θ)频敛,我們可以理解為從 t_0 時刻開始令 z(t_0) 以變化率 f 進(jìn)行演化项郊,這種演化即 f 在 t 上的積分,ODESolver 的目標(biāo)是通過積分求得 z(t_1)斟赚。

同樣我們能以這種方式理解算法 1着降,我們的目的是利用 ODESolver 從 z(t_1) 求出 z(t_0)、從 a(t_1) 按照方程 4 積出 a(t_0)拗军、從 0 按照方程 5 積出 dL/dθ任洞。最后我們只需要使用 dL/dθ 更新神經(jīng)網(wǎng)絡(luò) f(z(t), t, θ) 就完成了整個反向傳播過程。

如上所示发侵,若初始給定參數(shù)θ交掏、前向初始時刻 t_0 和終止時刻 t_1、終止?fàn)顟B(tài) z(t_1) 和梯度入口 ?L/?z(t_1)刃鳄。接下來我們可以將三個積分都并在一起以一次性解出所有量盅弛,因此我們可以定義初始狀態(tài) s_0,它們是解常微分方程的初值叔锐。

注意第一個初值 z(t_1)挪鹏,其實(shí)在前向傳播中,從 z(t_0) 到 z(t_1) 都已經(jīng)算過一遍了愉烙,但是模型并不會保留計(jì)算結(jié)果狰住,因此也就只有常數(shù)級的內(nèi)存成本。此外齿梁,在算 a(t) 時需要知道對應(yīng)的 z(t)催植,例如 ?L/?z(t_0) 就要求知道 z(t_0) 的值肮蛹。如果我們不能保存中間狀態(tài)的話,那么也可以從 z(t_1) 到 z(t_0) 反向再算一遍中間狀態(tài)创南。這個計(jì)算過程和前向過程基本一致伦忠,即從 z(t_1) 開始以變化率 f 進(jìn)行演化而推出 z(t_0)。

定義 s_0 后稿辙,我們需要確定初始狀態(tài)都是怎樣「演化」到終止?fàn)顟B(tài)的昆码,定義這些演化的即前面方程 (3)、(4) 和 (5) 的被積函數(shù)邻储,也就是算法 1 中 aug_dynamics() 函數(shù)所定義的赋咽。

其中 f(z(t), t, θ) 從 t_1 到 t_0 積出來為 z(t_0),這第一個常微分方程是為了給第二個提供條件吨娜。而-a(t)*?L/?z(t) 從 t_1 到 t_0 積出來為 a(t_0)脓匿,它類似于傳統(tǒng)神經(jīng)網(wǎng)絡(luò)中損失函數(shù)對第一個隱藏層的導(dǎo)數(shù),整個 a(t) 就相當(dāng)于隱藏層的梯度宦赠。只有獲取積分路徑中所有隱藏層的梯度陪毡,我們才有可能進(jìn)一步解出損失函數(shù)對參數(shù)的梯度。

因此反向傳播中的第一個和第二個常微分方程 都是為第三個微分方程提供條件勾扭,即 a(t) 和 z(t)毡琉。最后,從 t_1 到 t_0 積分 -a(t)*?f(z(t), t, θ)/?θ 就能求出 dL/dθ妙色。只需要一個積分桅滋,我們不再一層層傳遞梯度并更新該層特定的參數(shù)。

如下偽代碼所示身辨,完成反向傳播的步驟很簡單丐谋。先定義各變量演化的方法,再結(jié)合將其結(jié)合初始化狀態(tài)一同傳入 ODESolver 就行了栅表。

deff_and_a([z,?a],?t):

return[f,?-a*df/da,?-a*df/dθ]

[z0,?dL/dx,?dL/dθ]?=

ODESolver([z(t1),?dL/dz(t),0],?f_and_a,?t1,?t0)

連續(xù)型的歸一化流

這種連續(xù)型轉(zhuǎn)換有一個非常重要的屬性笋鄙,即流模型中最基礎(chǔ)的變量代換定理可以便捷快速地計(jì)算得出师枣。在論文的第四節(jié)中怪瓶,作者根據(jù)這樣的推導(dǎo)結(jié)果構(gòu)建了一個新型可逆密度模型,它能克服 Glow 等歸一化流模型的缺點(diǎn)践美,并直接通過最大似然估計(jì)訓(xùn)練洗贰。

變量代換定理

對于概率密度估計(jì)中的變量代換定理,我們可以從單變量的情況開始陨倡。若給定一個隨機(jī)變量 z 和它的概率密度函數(shù) z~π(z)敛滋,我們希望使用映射函數(shù) x=f(z) 構(gòu)建一個新的隨機(jī)變量。函數(shù) f 是可逆的兴革,即 z=g(x)绎晃,其中 f 和 g 互為逆函數(shù)∶弁伲現(xiàn)在問題是如何推斷新變量的未知概率密度函數(shù) p(x)?

通過定義庶艾,積分項(xiàng) ∫π(z)dz 表示無限個無窮小的矩形面積之和袁余,其中積分元Δz 為積分小矩形的寬,小矩形在位置 z 的高為概率密度函數(shù) π(z) 定義的值咱揍。若使用 f^?1(x) 表示 f(x) 的逆函數(shù)颖榜,當(dāng)我們替換變量的時候,z=f^?1(x) 需要服從 Δz/Δx=(f^?1(x))′煤裙。多變量的變量代換定理可以從單變量推廣而出掩完,其中 det ?f/?z 為函數(shù) f 的雅可比行列式:

一般使用變量代換定理需要計(jì)算雅可比矩陣?f/?z 的行列式,這是主要的限制硼砰,最近的研究工作都在權(quán)衡歸一化流模型隱藏層的表達(dá)能力與計(jì)算成本且蓬。但是研究者發(fā)現(xiàn),將離散的層級替換為連續(xù)的轉(zhuǎn)換夺刑,可以簡化計(jì)算缅疟,我們只需要算雅可比矩陣的跡就行了。核心的定理 1 如下所示:

在普通的變量代換定理中遍愿,分布的變換函數(shù) f(或神經(jīng)網(wǎng)絡(luò))必須是可逆的存淫,而且要制作可逆的神經(jīng)網(wǎng)絡(luò)也很復(fù)雜。在陳天琦等研究者定理里沼填,不論 f 是什么樣的神經(jīng)網(wǎng)絡(luò)都沒問題桅咆,它天然可逆,所以這種連續(xù)化的模型對流模型的應(yīng)用應(yīng)該非常方便坞笙。

如下所示岩饼,隨機(jī)變量 z(t_0) 及其分布可以通過一個連續(xù)的轉(zhuǎn)換演化到 z(t_1) 及其分布:

此外,連續(xù)型流模型還有很多性質(zhì)與優(yōu)勢薛夜,但這里并不展開籍茧。變量代換定理 1 在附錄 A 中有完整的證明,感興趣的讀者可查閱原論文了解細(xì)節(jié)梯澜。

最后寞冯,神經(jīng)常微分方程是一種全新的框架,除了流模型外晚伙,很多方法在連續(xù)變換的改變下都有新屬性吮龄,這些屬性可能在離散激活的情況下很難獲得。也許未來會有很多的研究關(guān)注這一新模型咆疗,連續(xù)化的神經(jīng)網(wǎng)絡(luò)也會變得多種多樣

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末漓帚,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子午磁,更是在濱河造成了極大的恐慌尝抖,老刑警劉巖毡们,帶你破解...
    沈念sama閱讀 212,383評論 6 493
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異昧辽,居然都是意外死亡漏隐,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,522評論 3 385
  • 文/潘曉璐 我一進(jìn)店門奴迅,熙熙樓的掌柜王于貴愁眉苦臉地迎上來青责,“玉大人,你說我怎么就攤上這事取具〔绷ィ” “怎么了?”我有些...
    開封第一講書人閱讀 157,852評論 0 348
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經(jīng)常有香客問我江解,道長,這世上最難降的妖魔是什么构蹬? 我笑而不...
    開封第一講書人閱讀 56,621評論 1 284
  • 正文 為了忘掉前任,我火速辦了婚禮悔据,結(jié)果婚禮上庄敛,老公的妹妹穿的比我還像新娘。我一直安慰自己科汗,他們只是感情好藻烤,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,741評論 6 386
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著头滔,像睡著了一般怖亭。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上坤检,一...
    開封第一講書人閱讀 49,929評論 1 290
  • 那天兴猩,我揣著相機(jī)與錄音,去河邊找鬼早歇。 笑死倾芝,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的缺前。 我是一名探鬼主播蛀醉,決...
    沈念sama閱讀 39,076評論 3 410
  • 文/蒼蘭香墨 我猛地睜開眼悬襟,長吁一口氣:“原來是場噩夢啊……” “哼衅码!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起脊岳,我...
    開封第一講書人閱讀 37,803評論 0 268
  • 序言:老撾萬榮一對情侶失蹤逝段,失蹤者是張志新(化名)和其女友劉穎垛玻,沒想到半個月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體奶躯,經(jīng)...
    沈念sama閱讀 44,265評論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡帚桩,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,582評論 2 327
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了嘹黔。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片账嚎。...
    茶點(diǎn)故事閱讀 38,716評論 1 341
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖儡蔓,靈堂內(nèi)的尸體忽然破棺而出郭蕉,到底是詐尸還是另有隱情,我是刑警寧澤喂江,帶...
    沈念sama閱讀 34,395評論 4 333
  • 正文 年R本政府宣布召锈,位于F島的核電站,受9級特大地震影響获询,放射性物質(zhì)發(fā)生泄漏涨岁。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 40,039評論 3 316
  • 文/蒙蒙 一吉嚣、第九天 我趴在偏房一處隱蔽的房頂上張望梢薪。 院中可真熱鬧,春花似錦尝哆、人聲如沸沮尿。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,798評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽畜疾。三九已至,卻和暖如春印衔,著一層夾襖步出監(jiān)牢的瞬間啡捶,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 32,027評論 1 266
  • 我被黑心中介騙來泰國打工奸焙, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留瞎暑,地道東北人。 一個月前我還...
    沈念sama閱讀 46,488評論 2 361
  • 正文 我出身青樓与帆,卻偏偏與公主長得像了赌,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子玄糟,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,612評論 2 350

推薦閱讀更多精彩內(nèi)容