去年寫的文章,從notion的博客搬到這邊來(lái)發(fā)一下(本來(lái)想搬到微信公眾號(hào)的耸彪,但是那個(gè)格式真的反人類就作罷了)鳍侣,原文請(qǐng)到這里看mewimpetus.以后文章都會(huì)再這邊先發(fā)抡医。
引言
擴(kuò)散模型是今年AI領(lǐng)域最熱門的研究方向英融。由其引發(fā)的AI繪畫的產(chǎn)業(yè)變革正在如火如荼的進(jìn)行癣防,大有淘汰一大票初中級(jí)畫師的勢(shì)頭趴俘,目前主流的(諸如OpenAI的DALL-E 2;Google的ImageGen;以及已經(jīng)商業(yè)化的MidJourney;注重二次元的NovelAI寥闪;開源引爆這波熱潮的stable-diffusion)圖像生成模型效果已經(jīng)讓人驚艷太惠,若是再發(fā)展幾年,它帶來(lái)的影響將不可估量疲憋,可以說(shuō)整個(gè)繪畫產(chǎn)業(yè)正在經(jīng)歷著一場(chǎng)百年未有之大變局凿渊。而這些功能強(qiáng)大的繪畫模型,無(wú)疑都與Denoising Diffusion Probabilistic Models 擺脫不了關(guān)系缚柳,它的原始論文由Google Brain在2020年發(fā)表埃脏。 這篇博文主要帶大家一起來(lái)探究一下DDPM的工作原理和實(shí)現(xiàn)細(xì)節(jié)。
擴(kuò)散模型的基本流程
其實(shí)擴(kuò)散模型的基本思路同GAN以及VAE并無(wú)二致秋忙,都是試圖從一個(gè)簡(jiǎn)單分布的隨機(jī)噪聲出發(fā)彩掐,經(jīng)過(guò)一系列的轉(zhuǎn)換,轉(zhuǎn)變成類似于真實(shí)數(shù)據(jù)的數(shù)據(jù)樣本翰绊。
它主要包含前向加噪聲和反向去噪聲兩個(gè)過(guò)程:
- 從真實(shí)的數(shù)據(jù)分布中隨機(jī)采樣一個(gè)圖片佩谷,然后通過(guò)一個(gè)固定的過(guò)程逐步往上面添加高斯隨機(jī)噪聲,直到圖片變成一個(gè)純粹的噪聲
- 構(gòu)建一個(gè)神經(jīng)網(wǎng)絡(luò)监嗜,去學(xué)習(xí)一個(gè)去噪的過(guò)程谐檀,從一個(gè)純粹的噪聲出發(fā),逐步還原回一個(gè)真實(shí)的圖像裁奇。
接下來(lái)我們用數(shù)學(xué)形式來(lái)表達(dá)上面的兩個(gè)過(guò)程桐猬。
前向擴(kuò)散
我們將真實(shí)數(shù)據(jù)的分布定義為,然后可以從這個(gè)分布中隨機(jī)采樣一個(gè)”真圖“
刽肠,于是我們就可以定義一個(gè)前向擴(kuò)散的遞推過(guò)程
為每個(gè)時(shí)間步
添加少量高斯噪聲并執(zhí)行
步溃肪。DDPM作者將
定義為這樣一個(gè)條件高斯分布(其中的
是一個(gè)既定的遞增表):
顯然,當(dāng)時(shí)刻的圖像為
的條件下音五,
時(shí)刻的圖像
服從一個(gè)均值
惫撰,方差
的各項(xiàng)同性高斯分布。我們?cè)儆^察一下這個(gè)遞推式躺涝,因?yàn)?img class="math-inline" src="https://math.jianshu.com/math?formula=1-%5Cbeta" alt="1-\beta" mathimg="1">和
都小于1厨钻,顯然
的均值會(huì)比
更加趨向于
,方差也更趨向于
,因此如果設(shè)計(jì)合適的
序列,最終的
將趨近于標(biāo)準(zhǔn)的高斯分布
夯膀。根據(jù)高斯分布的性質(zhì)1:
如果
且
與
都是實(shí)數(shù)诗充,那么
。
上述的條件高斯分布顯然可以通過(guò)從標(biāo)準(zhǔn)高斯分布的線性變換得到诱建,我們定義,那么只要讓
,那么第
個(gè)時(shí)間步的圖像
蝴蜓。
為了更好的計(jì)算任意時(shí)刻的條件分布,我們根據(jù)上面的遞推式逐步推導(dǎo)到
,為了方便推導(dǎo)俺猿,我們令
茎匠,
則有了推導(dǎo)1:
上式中第3行到第4行的推導(dǎo)用到了上述的性質(zhì)1,以及高斯分布的另一個(gè)性質(zhì)2:
如果
與
是獨(dú)立統(tǒng)計(jì)的高斯隨機(jī)變量辜荠,那么汽抚,它們的和也滿足高斯分布
。
由性質(zhì)1可知伯病,,而
,再根據(jù)性質(zhì)2,就可得
, 再根據(jù)性質(zhì)1寫回到多項(xiàng)式的形式即得到推導(dǎo)的結(jié)果否过。
基于這個(gè)最終的推導(dǎo)結(jié)果午笛,因?yàn)?img class="math-inline" src="https://math.jianshu.com/math?formula=%5Calpha_t" alt="\alpha_t" mathimg="1">是事先已經(jīng)定義好的,我們只需要給出初始真實(shí)分布采樣苗桂,即可以計(jì)算出任何第
步的樣本
药磺,而不需要每次都從
開始一步步計(jì)算。
反向去噪
有了前向的過(guò)程煤伟,我們反過(guò)來(lái)想癌佩,既然前向擴(kuò)散是一個(gè)馬爾可夫過(guò)程,那么它的逆過(guò)程顯然也是馬爾可夫過(guò)程便锨,如果我們可以構(gòu)造一個(gè)相反的條件分布,那不就可以從最終的
開始一步步地去噪围辙,從而反推回初始的
了嗎? 但是我們并不知道反向條件高斯分布的均值和方差放案。不過(guò)姚建,在這個(gè)深度學(xué)習(xí)的時(shí)代,我們可以從真實(shí)數(shù)據(jù)集
出發(fā)吱殉,通過(guò)前向過(guò)程生成一系列的
的真實(shí)擴(kuò)散序列掸冤,然后設(shè)計(jì)一個(gè)神經(jīng)網(wǎng)絡(luò)從這些序列中來(lái)近似學(xué)習(xí)一個(gè)分布
使其接近真實(shí)的
,其中的
是這個(gè)神經(jīng)網(wǎng)絡(luò)需要學(xué)習(xí)的參數(shù)友雳,于是從
變換到
的概率可以表示成:
當(dāng)我們前向過(guò)程所定義的足夠小時(shí)稿湿,反向過(guò)程也滿足高斯分布,因此我們可以假設(shè)神經(jīng)網(wǎng)絡(luò)要學(xué)習(xí)的這個(gè)分布是高斯分布押赊,這意味著它需要去學(xué)習(xí)其均值
和方差
,換成與上述前向過(guò)程相同的表示則有遞推公式:
借助這個(gè)公式饺藤,我們就可以完成去噪過(guò)程了,接下來(lái)的任務(wù)變成了如何訓(xùn)練這個(gè)神經(jīng)網(wǎng)絡(luò)。
如何訓(xùn)練
基本思路
不知大家又沒(méi)有覺(jué)得這個(gè)加噪聲和去噪聲的過(guò)程和VAE的編碼和解碼的過(guò)程十分類似策精,那么是否可以從VAE的訓(xùn)練方式中得到一些啟發(fā)呢舰始?實(shí)際上作者就是這么想的。
顯然询刹,如果直接使用與
的對(duì)比誤差會(huì)導(dǎo)致模型過(guò)擬合成AE一樣的無(wú)生成能力的模型谜嫉。因此,我們使用與VAE類似的變分推斷的方法凹联,希望網(wǎng)絡(luò)輸出的
盡量接近由真實(shí)
變化而來(lái)的
的分布沐兰,即最小化似然
與真實(shí)的
的
。于是每一個(gè)時(shí)間步驟
的誤差可以定義為:
而當(dāng)時(shí)蔽挠,因?yàn)?img class="math-inline" src="https://math.jianshu.com/math?formula=q(%5Cmathbf%7Bx%7D*0)" alt="q(\mathbf{x}*0)" mathimg="1">是確定的住闯,因此可以忽略這部分,故而
,因此
于是整個(gè)去噪過(guò)程的誤差就是: 澳淑。實(shí)際訓(xùn)練時(shí)比原,我們并沒(méi)有使用整體的誤差
,而是通過(guò)均勻隨機(jī)選擇
杠巡,來(lái)最小化
量窘。
目標(biāo)函數(shù)
要直接計(jì)算上面的KL散度是困難的,但是正如前面所說(shuō)的氢拥, 是一個(gè)高斯分布蚌铜,于是根據(jù)貝葉斯公式有:
其中 代表所有剩余與
無(wú)關(guān)的項(xiàng)。
根據(jù)高斯分布的基本方程:
與上述的推導(dǎo)結(jié)果位置依次對(duì)應(yīng)可得其方差和均值為:
根據(jù)上面的推導(dǎo)1可得 ,帶入上式可得:
最小化上述的KL散度嫩海,可以轉(zhuǎn)化為計(jì)算神經(jīng)網(wǎng)絡(luò)的預(yù)測(cè)的均值方差與上述均值方差的L2損失:
DDPM的論文作者在論文中說(shuō)他使用一個(gè)固定的方差取得了差不多的效果冬殃,因此他的神經(jīng)網(wǎng)絡(luò)只去學(xué)習(xí)了均值,而把方差設(shè)置成了
或者是
出革,因此我們接下來(lái)的推導(dǎo)也只考慮均值造壮。后來(lái)Improved diffusion models 這篇論文將其改進(jìn)后就讓神經(jīng)網(wǎng)絡(luò)同時(shí)去學(xué)習(xí)均值和方差了,有興趣的同學(xué)可以自行去了解骂束。
觀察上面的耳璧,除了
,其余項(xiàng)均為固定值與
無(wú)關(guān),于是我們不妨將神經(jīng)網(wǎng)絡(luò)的學(xué)習(xí)目標(biāo)從高斯分布的均值轉(zhuǎn)變?yōu)?img class="math-inline" src="https://math.jianshu.com/math?formula=%5Cepsilon_%5Ctheta" alt="\epsilon_\theta" mathimg="1"> 展箱,即去預(yù)測(cè)每個(gè)事件步的噪聲量而非高斯分布的均值旨枯,因此我們最終的目標(biāo)函數(shù)就變成了:
然后,整個(gè)訓(xùn)練算法便是這樣一個(gè)過(guò)程:
- 從真實(shí)的復(fù)雜未知分布
隨機(jī)抽取一個(gè)樣本
- 從
到
均勻采樣一個(gè)時(shí)間步
- 從均值為
方差為
的標(biāo)準(zhǔn)高斯分布中隨機(jī)采樣一個(gè)
- 計(jì)算隨機(jī)梯度
混驰,并通過(guò)隨機(jī)梯度下降優(yōu)化
- 重復(fù)上述過(guò)程直到收斂
采樣生成
當(dāng)上述的神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)好 , 就可以計(jì)算出均值
,于是我們就可以從一個(gè)隨機(jī)高斯噪聲
攀隔,通過(guò)條件去噪概率
進(jìn)行采樣生成皂贩,逐步從
到
。
具體來(lái)說(shuō)Sampling是這樣一個(gè)過(guò)程:
隨機(jī)采樣一個(gè)
-
令
昆汹,依次執(zhí)行:
返回最終的
網(wǎng)絡(luò)結(jié)構(gòu)
雖然有了訓(xùn)練的方案明刷,但是如何來(lái)設(shè)計(jì)這個(gè)神經(jīng)網(wǎng)絡(luò)才能讓我們這個(gè)擴(kuò)散和反擴(kuò)散的過(guò)程取得較好的效果呢?DDPM的作者選擇了U-Net 满粗,并且在實(shí)驗(yàn)中取得了很好的效果辈末。
這個(gè)用于學(xué)習(xí)的U-Net網(wǎng)絡(luò)十分復(fù)雜,由一系列的諸如下采樣映皆、上采樣挤聘、殘差、位置Embedding捅彻、ResNet/ConvNeXT block组去、注意力模塊、Group Normalization等組件組合而成步淹,為了讓大家了解整個(gè)網(wǎng)絡(luò)各個(gè)組件的具體結(jié)構(gòu)和連接方式从隆,我繪制了一個(gè)詳細(xì)的網(wǎng)絡(luò)圖:
根據(jù)這個(gè)圖,我們可以用tensorflow或者pytorch非常輕松的實(shí)現(xiàn)這個(gè)網(wǎng)絡(luò)贤旷。不過(guò)顯然這個(gè)網(wǎng)絡(luò)很大广料,特別是圖片很大時(shí)占用的顯存會(huì)很高,而且采樣步驟多推理也很慢幼驶,因此后面有很多對(duì)于DDPM的改進(jìn),篇幅關(guān)系韧衣,關(guān)于對(duì)DDPM的改進(jìn)我們下篇文章再講盅藻。