知乎上有個(gè)討論坏挠,說(shuō)學(xué)數(shù)學(xué)的看不起搞深度學(xué)習(xí)的芍躏。曲直對(duì)錯(cuò)不論,他們看不起搞深度學(xué)習(xí)的原因很簡(jiǎn)單降狠,因?yàn)閺臄?shù)學(xué)的角度看对竣,深度學(xué)習(xí)僅僅是一個(gè)最優(yōu)化問(wèn)題而已。比如榜配,被炒的很熱的對(duì)抗式生成網(wǎng)絡(luò)(GAN)否纬,從數(shù)學(xué)看,基本原理很容易就能說(shuō)明白蛋褥,剩下的僅僅是需要計(jì)算資源去優(yōu)化參數(shù)临燃,是個(gè)體力活。
本文的目的就是盡可能簡(jiǎn)單地從數(shù)學(xué)角度解釋清楚GAN的數(shù)學(xué)原理烙心,看清它的廬山真面目膜廊。
1,從生成模型說(shuō)起
機(jī)器學(xué)習(xí)的模型可分為生成模型和判別模型淫茵。
簡(jiǎn)單說(shuō)說(shuō)二者的區(qū)別爪瓜,以二分類問(wèn)題來(lái)講,已知一個(gè)樣本的特征為x匙瘪,我們要去判斷它的類別y(取值為0,1)铆铆。也就是要計(jì)算p(ylx),假設(shè)我們已經(jīng)有了N個(gè)樣本丹喻。
計(jì)算p(ylx)的思路有兩個(gè)薄货,一個(gè)就是直接求,即y關(guān)于x的條件分布驻啤,邏輯回歸就是這樣干的菲驴。凡是按這種思路做的模型統(tǒng)一稱作判別模型
另一個(gè)思路是曲線救國(guó),先求出x,y的聯(lián)合分布p(x,y)骑冗,然后赊瞬,根據(jù)p(ylx)=p(x,y)/p(x)來(lái)計(jì)算p(ylx)先煎,樸素貝葉斯就是這樣干的。凡是按這種思路做的模型統(tǒng)一稱作生成模型巧涧。
那么薯蝎,GAN屬于什么呢?答案是生成模型谤绳,因?yàn)樗彩窃谀M數(shù)據(jù)的分布占锯。
舉個(gè)例子,我們有N張尺寸為50*50的小貓的照片缩筛,先不管這些貓多可愛消略,多漂亮,從數(shù)學(xué)的角度看瞎抛,一張貓的圖片可以理解為2500維空間中的一個(gè)點(diǎn)或者說(shuō)是一個(gè)2500維的向量艺演,N張圖就是N個(gè)點(diǎn)或者說(shuō)向量。如果我們想讓神經(jīng)網(wǎng)絡(luò)自動(dòng)生成一個(gè)小貓的圖片桐臊,所要做的就是假設(shè)N張圖片都是某個(gè)2500維空間中的隨機(jī)分布的樣本胎撤,這個(gè)分布抽樣產(chǎn)生的點(diǎn)就是一張貓的圖片。
這里有一個(gè)問(wèn)題断凶,需要再思考一下伤提,我們假設(shè)的這個(gè)分布存在嗎?答案是存在的认烁,因?yàn)檫@是GAN的理論基礎(chǔ)肿男,沒(méi)有這個(gè)假設(shè),GAN就玩不下去了砚著。具體的證明過(guò)于數(shù)學(xué)了次伶,此刻,我們只要相信它存在就是了稽穆。
假設(shè)圖片背后隱藏的分布是Pd(d就是data)冠王,我們的任務(wù)就是用神經(jīng)網(wǎng)絡(luò)生成一個(gè)分布Pg(g就是generate),只要Pg和Pd很接近舌镶,就可以用Pg生成一張小貓的圖片了柱彻。
2,GAN解決的基本問(wèn)題
根據(jù)上面的討論餐胀,我們要解決兩個(gè)問(wèn)題:
一是如何用神經(jīng)網(wǎng)絡(luò)構(gòu)造一個(gè)模擬分布Pg哟楷,
另一個(gè)是如何衡量Pg和Pd是否相似,并根據(jù)衡量結(jié)果去優(yōu)化Pg
這就是GAN解決的最根本的兩個(gè)問(wèn)題否灾。
3卖擅,GAN是如何解決這兩個(gè)問(wèn)題的?
第一個(gè)問(wèn)題很容易解決,以上面的小貓問(wèn)題為例惩阶,只要做一個(gè)神經(jīng)網(wǎng)絡(luò)挎狸,它的輸入是來(lái)自某個(gè)特定分布的數(shù),為便于說(shuō)明断楷,我們就假設(shè)這個(gè)特定分布是一維的锨匆,也就是它產(chǎn)生的數(shù)就是一個(gè)標(biāo)量,如1冬筒,2.5之類的恐锣,經(jīng)過(guò)多層的映射后,產(chǎn)生一個(gè)2500維的向量舞痰⊥亮瘢可以想象,只要輸入來(lái)自一個(gè)特定的分布响牛,映射產(chǎn)生的2500維向量也會(huì)形成一個(gè)分布鞭衩。這個(gè)分布的概率密度函數(shù)就是輸入分布的概率密度函數(shù)在2500維空間的擴(kuò)展。
我們來(lái)舉個(gè)例子娃善,為了說(shuō)明問(wèn)題,這里假設(shè)神經(jīng)網(wǎng)絡(luò)映射后也是一個(gè)一維的數(shù)瑞佩,不再是上面例子中的2500維向量聚磺,但我們要清楚,其原理是一樣的炬丸。
假設(shè)我們的輸入分布是一個(gè)標(biāo)準(zhǔn)正態(tài)分布瘫寝,其概率密度函數(shù)為
我們的神經(jīng)網(wǎng)絡(luò)就相當(dāng)于一個(gè)函數(shù)y=g(x),現(xiàn)在已知x的概率密度函數(shù)如上所示稠炬,那么y的概率密度函數(shù)會(huì)是什么樣子呢焕阿?
沒(méi)有什么是一個(gè)例子說(shuō)明不了的,如果有首启,那就兩個(gè)暮屡。
假設(shè)g就是一個(gè)等值映射,即g(x)=x毅桃,顯然y與x的概率密度函數(shù)一樣褒纲,也是標(biāo)準(zhǔn)正態(tài)分布。
如果g(x)=x+1钥飞,那么莺掠,y的概率密度函數(shù)就是均值為1,方差為1的正態(tài)分布读宙。
可見彻秆,神經(jīng)網(wǎng)絡(luò)對(duì)某個(gè)特定分布x映射后的y確實(shí)是一個(gè)分布,其概率密度函數(shù)既與x的概率密度函數(shù)相關(guān),又與神經(jīng)網(wǎng)絡(luò)的結(jié)構(gòu)g函數(shù)相關(guān)唇兑。這里說(shuō)的網(wǎng)絡(luò)結(jié)構(gòu)有兩層含義酒朵,一個(gè)網(wǎng)絡(luò)的內(nèi)部構(gòu)造,另一個(gè)就是網(wǎng)絡(luò)的輸出向量維數(shù)幔亥。我們雖然只舉了輸出為1維的情況耻讽,但在輸出是多維時(shí),y的分布會(huì)有一些新的特點(diǎn)帕棉。
比如针肥,一個(gè)有趣的問(wèn)題是,若原始的輸入是一維的正態(tài)分布香伴,其取值空間是整個(gè)一維實(shí)數(shù)空間慰枕,如果輸出的是多維向量分布,那么即纲,它的取值空間還會(huì)是整個(gè)多維空間嗎具帮?
這個(gè)問(wèn)題讓我聯(lián)想到《三體》,太陽(yáng)系遭受二向箔攻擊后低斋,變成了二維空間蜂厅,逃離的人類,開始尋找更高級(jí)的宇宙文明膊畴,期望有一天掘猿,能夠?qū)⒍S化的地球還原成三維的世界。
第二個(gè)問(wèn)題如何解決呢唇跨?
回到剛才小貓的問(wèn)題稠通,假設(shè)我們有一萬(wàn)個(gè)樣本圖片,也有了一個(gè)生成分布的神經(jīng)網(wǎng)絡(luò)G买猖,它的輸入是一個(gè)正態(tài)分布改橘。輸出是2500維的向量。
從數(shù)學(xué)上講玉控,我們的問(wèn)題是飞主,如何衡量Pd,Pg的相似性奸远,以及如何使Pg接近Pd既棺。
具體的方法不用我們費(fèi)心想了,我們直接看現(xiàn)成的就可以懒叛。
即再定義一個(gè)NN丸冕,叫做D,D(x)產(chǎn)生的是0到1之間的一個(gè)值
定義函數(shù)
V(G, D) = E_(x~Pd)(log(D(x)) + E_(x~Pg)(log(1-D(x))
對(duì)于一個(gè)特定的G薛窥,通過(guò)調(diào)整D胖烛,可以得到maxV(G,D)眼姐,它就能衡量Pd和Pg的difference,這個(gè)值越小佩番, 二者差距越小众旗。
為什么maxV(G,D)可以衡量?jī)蓚€(gè)分布的差異?
我們先直接觀察V趟畏,把D看作一個(gè)判別器贡歧,V的第一部分表示D對(duì)來(lái)自真實(shí)分布的數(shù)據(jù)的評(píng)分的期望,第二部分表示D對(duì)來(lái)自G生成的數(shù)據(jù)的評(píng)分與1的差的期望赋秀。
最大化V利朵,就是要使D對(duì)來(lái)自真實(shí)數(shù)據(jù)的評(píng)分盡可能高,對(duì)來(lái)自G生成的數(shù)據(jù)的評(píng)分盡可能低(即讓1-D(x)盡可能高)猎莲。
至此绍弟,我們看到,V越大著洼,表示D對(duì)來(lái)自真實(shí)分布和來(lái)自G的數(shù)據(jù)評(píng)分差異越大樟遣,真實(shí)的樣本越接近1,G產(chǎn)生的樣本越接近0身笤。
通過(guò)調(diào)整D豹悬,得到的maxV,表示在我們能力范圍內(nèi)液荸,我們找到的最強(qiáng)的D屿衅,它能把Pd和Pg產(chǎn)生的數(shù)據(jù)區(qū)分開的程度。自然莹弊,這個(gè)區(qū)分程度就表示了Pd和Pg的差異程度。因?yàn)榭梢韵胂笪谐荆琍d和Pg越接近忍弛,同樣的D,得到的V肯定越小考抄。因?yàn)槎弋a(chǎn)生的數(shù)據(jù)越難區(qū)分開细疚。
以上,是通過(guò)直觀的分析得出的一些認(rèn)識(shí)川梅,我們還可以從數(shù)學(xué)上進(jìn)行一些分析疯兼。
對(duì)V做一些變形,可以得到如下
V = Σ_x(Pd(x)*log(D(x)) + Pg(x)*log(1-D(x))
maxV時(shí)贫途,我們只要看對(duì)于一個(gè)x吧彪,如何最大化這個(gè)值即可:
Pd(x)*log(D(x)) + Pg(x)*log(1-D(x))
這里,除了D是變化的丢早,其他都是已經(jīng)給定的姨裸,
令a=Pd(x)秧倾,D=D(x), b = Pg(x)
a傀缩,b都是一個(gè)確定的數(shù)值那先,因?yàn)閤此時(shí)是確定的,上式進(jìn)一步變成這樣
alogD + blog(1-D)
對(duì)這個(gè)式子赡艰,求導(dǎo)售淡,很容易得出最佳的D值是:
D^* = a/(a+b) 可見,D的值確實(shí)是0到1之間的值
再將D^*帶入最上面的V的定義中慷垮,經(jīng)過(guò)一系列的計(jì)算推導(dǎo)(此處省略了推導(dǎo)過(guò)程揖闸,感興趣的小伙伴可以自己推導(dǎo)一下),我們可以得出maxV實(shí)際表示的是Pd
和(Pd+Pg)/2的KL散度加上Pg和(Pd+Pg)/2的KL散度
如果令g = (Pd+Pg)/2换帜,則
maxV = -2log2 + KL(Pd||g) + KL(Pg||g)
如果有小伙伴對(duì)KL散度不是很理解楔壤,推薦看看我的另一篇文章《機(jī)器學(xué)習(xí)面試之各種混亂的熵》,里面有一個(gè)很通俗易懂的解釋惯驼。
這里有一個(gè)數(shù)學(xué)定義:
JSD(Pg||Pd)=1/2*(KL(Pd||g) + KL(Pg||g))
根據(jù)定義蹲嚣,maxV就變成這樣了:
maxV = -2log2 + 2JSD(Pg||Pd)
JSD全稱是Jessen-Sannon Divergence,它是一個(gè)對(duì)稱化的KL散度祟牲。
從數(shù)學(xué)上可以證明隙畜,JSD最大值是log2,最小值是0
所以maxV最大值是0说贝,最小值是-2log2议惰。
至此,我們的第二個(gè)問(wèn)題已經(jīng)解決了一半乡恕,將Pd和Pg的差異量化成了一個(gè)-2log2到0之間的一個(gè)數(shù)值言询。
另一半問(wèn)題就是如何針對(duì)求出的最大化V的D,更新調(diào)整G傲宜,使得maxV變小一點(diǎn)运杭。
我們將maxV定義為L(zhǎng)(G),即
L(G)=maxV
可見函卒,maxV實(shí)際上就是G的損失函數(shù)辆憔,可以用梯度下降更新神經(jīng)網(wǎng)絡(luò)G的各個(gè)參數(shù)得到G1,使得maxV下降报嵌。
有了G1后虱咧,我們?cè)俅沃貜?fù)上面的最大化V的過(guò)程,得到D1锚国,然后再次運(yùn)用梯度下降腕巡,得到G2,如此循環(huán)下去血筑,可以期待逸雹,Pg將會(huì)越來(lái)越接近Pd营搅。
聲明:上圖來(lái)自李宏毅深度學(xué)習(xí)課程截圖,如侵刪梆砸。
這是從理論上證明的转质,如果你感覺(jué)看得比較暈乎,沒(méi)關(guān)系帖世,我們下面就來(lái)看看實(shí)踐中的做法休蟹。
4,實(shí)踐中的做法
以上的推導(dǎo)都是從理論上進(jìn)行的分析日矫,在工程實(shí)踐中赂弓,有個(gè)問(wèn)題必須解決:
我們并不能真正知道Pd,我們只是有很多Pd產(chǎn)生的樣本哪轿,自然也求不出logD(x)關(guān)于Pd的期望盈魁,對(duì)于Pg的期望也有類似困難。
那么窃诉,實(shí)踐中是怎么做的呢杨耙?
首先,既然直接計(jì)算期望而不可得飘痛,我們只能退而求其次珊膜,從Pd中sample 出m個(gè)樣本,再?gòu)腜g中sample出m個(gè)樣本宣脉,最大化D(x)在這2m個(gè)樣本的V值车柠,如下圖所示:
仔細(xì)觀察上式,最大化V的過(guò)程塑猖,就是將Pd產(chǎn)生的樣本作為正樣本竹祷,Pg產(chǎn)生的樣本作為負(fù)樣本,訓(xùn)練一個(gè)分類器D的過(guò)程羊苟。
完整的算法如下所示:
聲明:該圖同樣來(lái)自李宏毅深度學(xué)習(xí)課程溶褪,如侵刪。
另外践险,在工程實(shí)踐中,還有一個(gè)實(shí)際的調(diào)整吹菱,就是將V蚯蚓中的log(1-D(x))改為-log(D(x))
原因是:
D的值在0到1之間巍虫,
在剛開始時(shí),由于G生成的圖像比較挫鳍刷,D可以很容易分辨出來(lái)占遥,也就是生成的值都遠(yuǎn)小于0.5,V蚯蚓關(guān)于G的參數(shù)θg的梯度等于 log(1-D(x))關(guān)于D的梯度再乘以D關(guān)于θg的梯度输瓜,但是此時(shí)瓦胎,log(1-D(x))在0到0.5之間的梯度比較小芬萍,這就會(huì)導(dǎo)致θg的更新比較慢,但是換成-log(D(x))后搔啊,它在0到0.5之間的梯度是比較大的柬祠,可以快速的更新g的參數(shù)。這符合我們的直覺(jué)负芋,就是一開始學(xué)習(xí)要快一點(diǎn)漫蛔,后面要慢一點(diǎn)。
5旧蛾,無(wú)總結(jié)莽龟,不進(jìn)步
縱觀整個(gè)GAN,本來(lái)我們要求Pg和Pd之間的相似度锨天,由于不能直接求毯盈,轉(zhuǎn)而借助一個(gè)D,maxV求出一個(gè)最佳的D后病袄,maxV就是在衡量Pg和Pd的JS 散度搂赋,然后,最小化這個(gè)散度值陪拘,更新一次Pg厂镇,有了新的Pg后,進(jìn)一步求出最佳的D左刽,然后重復(fù)上面的步驟捺信。