前 言
這是關(guān)于使用tensorflow來實(shí)現(xiàn)Goodfellow的生成對抗網(wǎng)絡(luò)論文的教程夫壁。對抗網(wǎng)絡(luò)是一個可以使用大約80行的python代碼就可以實(shí)現(xiàn)的一個有趣的小深度學(xué)習(xí)練習(xí),這將使你進(jìn)入深度學(xué)習(xí)的一個活躍領(lǐng)域:生成式模型。
對抗網(wǎng)絡(luò)論文地址:https://arxiv.org/abs/1406.2661
Github上的源碼地址:https://github.com/ericjang/genadv_tutorial/blob/master/genadv1.ipynb
章節(jié)目錄
情景:假幣
背景:判別模型vs生成模型
生成對抗網(wǎng)絡(luò)
實(shí)現(xiàn)
流形對齊
預(yù)處理判別模型
其他的棘手問題建議
結(jié)果
附錄
情景:假幣
為了更好地解釋這篇論文的動機(jī)丽惭,這里提供一個假設(shè)場景: Danielle是一個銀行的出納員陨享,她的工作職責(zé)之一就是辨別真幣與假幣叮盘。George是一個制造假幣的騙子厢汹,因?yàn)槊赓M(fèi)的錢相當(dāng)激進(jìn)。
讓我們簡化一下:假定貨幣的唯一顯著特征就是印在每個鈔票上的唯一編號X塌西。這些編碼是一個概率分布的隨機(jī)抽樣他挎,其中密度函數(shù)pdata只有國家財(cái)政部知道(這意味著Danielle與George都不知道)。方便起見捡需,這個教程使用pdata同時指代這個分布與它的概率密度函數(shù)(盡管從本質(zhì)上說概率分布與它的密度函數(shù)并不相同)办桨。
George的目標(biāo)是從pdata生成樣例x′,所以他制造的假幣與真幣難以區(qū)分站辉。你可能會問:George事先并不知道pdata呢撞,他怎么能從pdata采樣损姜?
我們可以在不知道真實(shí)的潛在生成過程的情況下制造出計(jì)算不可區(qū)分的樣例[1]。這個潛在的生成過程是財(cái)政部所使用的生成樣例X的方法-也許是從pdata
抽樣的一些有效算法殊霞,這些算法依賴于概率密度分布的解析式摧阅。
我們可以將這種算法看做“自然(函數(shù))基”,財(cái)政部將使用這種直接的方法來印制我們假設(shè)的鈔票绷蹲。然而棒卷,一個(連續(xù))函數(shù)可以用一系列不同的基函數(shù)來表征;George 可以使用“神經(jīng)網(wǎng)絡(luò)基”祝钢,“傅里葉基”或者其它的能用來構(gòu)建近似器的基來表示相同的抽樣算法比规。從局外人的角度來看,這些抽樣器是計(jì)算上不可區(qū)分的拦英,然而 George的模型并沒有將pdata的自然抽樣基或者解析式泄露給他蜒什。
背景:判別模型vs生成模型
我們使用X和Y代表“觀測”和“目標(biāo)”隨機(jī)變量。X和Y的聯(lián)合分布為P(X,Y)疤估,我們可以將其看做兩變量(可能相關(guān))的概率密度函數(shù)灾常。
一個判別式模型可以用來評估條件概率P(Y|X)。例如铃拇,給定一個代表像素點(diǎn)的向量x钞瀑,那么Y=6的概率是多少?(6代表是虎斑貓的類別標(biāo)簽)锚贱。 MNIST LeNet, AlexNet和其他的分類器都是判別式模型的實(shí)例。
另外一方面关串,一個生成式模型可以用來估計(jì)聯(lián)合分布P(X,Y)拧廊。這意味著我們可以選取(X,Y)值對,然后使用舍取抽樣法來從P(X,Y)來獲得樣例x,y晋修。使用正確的生成模型的另外一種方式吧碾,我們可以將一些分布在[0,1]上的隨機(jī)值轉(zhuǎn)化為一個兔子圖。這會很有趣墓卦。
當(dāng)然倦春,生成模型比判別模型更難構(gòu)建,這兩者都是統(tǒng)計(jì)學(xué)與機(jī)器學(xué)習(xí)研究的熱點(diǎn)領(lǐng)域落剪。
生成對抗網(wǎng)絡(luò)
條件隨機(jī)場(Conditional Random Field睁本,簡稱CRF)是一種判別式無向圖模型。生成式模型是直接對聯(lián)合分布進(jìn)行建模忠怖,而判別式模型則是對條件分布進(jìn)行建模呢堰。前面介紹的隱馬爾可夫模型和馬爾可夫隨機(jī)場都是生成式模型,而條件隨機(jī)場是判別式模型凡泣。
Goodfellow的論文提出了一個優(yōu)雅的方式來將神經(jīng)網(wǎng)絡(luò)訓(xùn)練成一個可以表示任何(連續(xù))概率密度函數(shù)的生成模型枉疼。我們構(gòu)建兩個神經(jīng)網(wǎng)絡(luò)皮假,分別是D(Danielle)和G(George),然后使用它們來玩一個對抗式的貓捉老鼠的游戲:G是一個生成器骂维,它嘗試著從pdata生成偽樣例惹资;而D是一個決策器,它試著不會被騙航闺。我們同時訓(xùn)練它們褪测,所以它們將在相互抗?fàn)幹谢ハ嗟玫教岣摺.?dāng)收斂時来颤,我們希望G能夠?qū)W會完全從pdata抽樣汰扭,此時D(x)=0.5(對于真假二分類來講,這個概率等于瞎猜)福铅。
對抗網(wǎng)絡(luò)已經(jīng)成功地用來憑空合成下列類型的圖片:
- 貓
-
教堂
在這個教程里萝毛,我們不會做任何很神奇的東西,但是希望你將會對對抗網(wǎng)絡(luò)有一個更基本的了解滑黔。
實(shí)現(xiàn)
我們將訓(xùn)練一個神經(jīng)網(wǎng)絡(luò)用來從簡單的一維正態(tài)分布N(?1,1)中抽樣
這里D,G都是小的3層感知機(jī)笆包,每層總共有稀薄的11個隱含單元。G的輸入是一個噪音分布z~uniform(0,1)中的單個樣例略荡。我們想使用G來將點(diǎn)z1,z2,...zM映射為x1,x2,...xM庵佣,這樣映射的點(diǎn)xi=G(zi)在pdata(X)密集的地方會密集聚集。因此汛兜,在G中輸入z將生成偽數(shù)據(jù)x′巴粪。
同時,判別器D粥谬,以x為輸入肛根,然后輸出該輸入屬于pdata的可能性。令D1和D2為D的副本(它們共享參數(shù)漏策,那么D1(x)=D2(x))派哲。D1的輸入是從合法的數(shù)據(jù)分布x~pdata中得到的單個樣例,所以當(dāng)優(yōu)化判別器時我們想使D1(x)最大化掺喻。D2以x′(G生成的偽數(shù)據(jù))為輸入芭届,所以當(dāng)優(yōu)化D時,我們想使D2(x)最小化感耙。D的損失函數(shù)為: log(D1(x))+log(1?D2(G(z)))
里是Python代碼:
batch = tf.Variable(0)
obj_d = tf.reduce_mean(tf.log(D1)+tf.log(1-D2))
opt_d = tf.train.GradientDescentOptimizer(0.01)
.minimize(1-obj_d,global_step=batch,var_list=theta_d)
我們之所以要不厭其煩地指定D的兩個副本D1和D2褂乍,是因?yàn)樵趖ensorflow中,我們需要D的一個副本以x為輸入即硼,而另外一個副本以G(z)為輸入树叽;計(jì)算圖的相同部分不能被重用于不同的輸入。
當(dāng)優(yōu)化G時谦絮,我們想使D2(X′)最大化(成功騙過D)题诵。G的損失函數(shù)為: log(D2(G(z)))
batch=tf.Variable(0)
obj_g=tf.reduce_mean(tf.log(D2))
opt_g=tf.train.GradientDescentOptimizer(0.01)
.minimize(1-obj_g,global_step=batch,var_list=theta_g)
在優(yōu)化時我們不是僅在某一刻輸入一個值對(x,z)洁仗,而是同時計(jì)算M個不同的值對(x,z)的損失梯度,然后用其平均值來更新梯度性锭。從一個小批量樣本中估計(jì)的隨機(jī)梯度與整個訓(xùn)練樣本的真實(shí)梯度非常接近赠潦。
訓(xùn)練的循環(huán)過程是非常簡單的:
# Algorithm 1, GoodFellow et al. 2014for i in range(TRAIN_ITERS):
x= np.random.normal(mu,sigma,M) # sample minibatch from p_data
z= np.random.random(M) # sample minibatch from noise prior
sess.run(opt_d, {x_node: x, z_node: z}) # update discriminator D
z= np.random.random(M) # sample noise prior
sess.run(opt_g, {z_node: z}) # update generator G
流形對齊
簡單地用上面的方法并不能得到好結(jié)果,因?yàn)槊看蔚形覀兪仟?dú)立地從pdata和uniform(0,1)中抽樣草冈。這并不能使得Z范圍中的鄰近點(diǎn)能夠映射到X范圍中的鄰近點(diǎn)她奥;在某一小批量訓(xùn)練中,我們可能在訓(xùn)練G中發(fā)生下面的映射:0.501→?1.1怎棱,0.502→0.01 和0.503→?1.11哩俭。映射線相互交叉很多,這將使轉(zhuǎn)化非常不平穩(wěn)拳恋。更糟糕的是凡资,接下來的小批量訓(xùn)練中,可能發(fā)生不同的映射:0.5015→1.1谬运,0.5025→?1.1 和0.504→1.01隙赁。這表明G進(jìn)行了一個與前面的小批量訓(xùn)練中完全不同的映射,因此優(yōu)化器不會得到收斂梆暖。
為了解決這個問題伞访,我們想最小化從Z到X的映射線的總長,因?yàn)檫@將使轉(zhuǎn)換盡可能的平順轰驳,而且更加容易學(xué)習(xí)厚掷。 另外一種說法是中將Z轉(zhuǎn)化到X的“向量叢”在小批量訓(xùn)練中要相互關(guān)聯(lián)。
首先级解,我們將Z的區(qū)域拉伸到與X區(qū)域的大小相同冒黑。以?1為中心點(diǎn)的正態(tài)分布其主要概率分布在[?5,5]范圍內(nèi),所以我們應(yīng)該從uniform[?5,5]來抽樣Z蠕趁。這樣處理后G模型就不需要學(xué)習(xí)如何將[0,1]區(qū)域拉伸10倍薛闪。G模型需要學(xué)習(xí)的越少辛馆,越好俺陋。接下來,我們將通過由低到高排序的方式使每個小批量中的Z與X對齊昙篙。
這里我們不是采用 np.random.random.sort()的方法來抽樣Z腊状,而是采用分層抽樣的方式-我們在抽樣范圍內(nèi)產(chǎn)生M個等距點(diǎn),然后隨機(jī)擾動它們苔可。這樣處理得到的樣本不僅保證其大小順序缴挖,而且可以增加在整個訓(xùn)練空間的代表性。我們接著匹配之前的分層焚辅,即排序的Z樣本對其排序的X樣本映屋。
當(dāng)然苟鸯,對于高維問題,由于在二維或者更高維空間里面對點(diǎn)排序并無意義棚点,所以對其輸入空間Z與目標(biāo)空間X并不容易早处。然而,最小化Z與X流形之間的轉(zhuǎn)化距離仍然有意義[2]瘫析。
修改的算法如下:
for i in range(TRAIN_ITERS):
x= np.random.normal(mu,sigma,M).sort()
z= np.linspace(-5.,5.,M)+np.random.random(M)*.01 # stratified
sess.run(opt_d, {x_node: x, z_node: z})
z= np.linspace(-5.,5.,M)+np.random.random(M)*.01
sess.run(opt_g, {z_node: z})
這是使這個例子有效的很關(guān)鍵一步:當(dāng)使用隨機(jī)噪音作為輸入時砌梆,未能正確地對齊轉(zhuǎn)化映射線將會產(chǎn)生一系列其它問題,如過大的梯度很早地關(guān)閉ReLU神經(jīng)元贬循,目標(biāo)函數(shù)停滯咸包,或者性能不能隨著批量大小縮放。
預(yù)處理判別模型
在原始的算法中杖虾,GAN是每次通過梯度下降訓(xùn)練D模型k步烂瘫,然后訓(xùn)練G一步。但是這里發(fā)現(xiàn)在訓(xùn)練對抗網(wǎng)絡(luò)之前亏掀,先對D預(yù)訓(xùn)練很多步更有用忱反,這里使用二次代價函數(shù)對D進(jìn)行預(yù)訓(xùn)練使其適應(yīng)pdata。這個代價函數(shù)相比對數(shù)似然代價函數(shù)更容易優(yōu)化(后者還要處理來自G的生成樣本)滤愕。很顯然pdata就是其自身分布的最優(yōu)可能性決定邊界温算。
這里是初始的決定邊界:
預(yù)訓(xùn)練之后:
已經(jīng)非常接近了,竊喜间影!
其他的棘手問題建議
模型過大容易導(dǎo)致過擬合注竿,但是在這個例子中,網(wǎng)絡(luò)過大在極小極大目標(biāo)下甚至不會收斂-神經(jīng)元在很大的梯度下很快達(dá)到飽和魂贬。從淺層的小網(wǎng)絡(luò)始巩割,除非你覺得有必要再去增加額外的神經(jīng)元或者隱含層。
剛開始我使用的是ReLU神經(jīng)元付燥,但是這種神經(jīng)元一直處于飽和狀態(tài)(也許由于流形對齊問題)宣谈。Tanh激活函數(shù)好像更有效。
我必須要調(diào)整學(xué)習(xí)速率才能得到很好的結(jié)果键科。
結(jié)果
下面是訓(xùn)練之前的pdata闻丑,預(yù)訓(xùn)練后的D的決定邊界以及生成分布pg:
這是代價函數(shù)在訓(xùn)練迭代過程中的變化曲線:
訓(xùn)練之后,pg接近pdata勋颖,判別器也基對所有X一視同仁(D=0.5):
這事就完成了訓(xùn)練過程嗦嗡。G已經(jīng)學(xué)會如何從pdata中近似抽樣,以至于D已經(jīng)無法偽數(shù)據(jù)中分離出真數(shù)據(jù)饭玲。
附言
這里是一個關(guān)于計(jì)算不可分性的更生動例子:假設(shè)我們在訓(xùn)練一個超級大的神經(jīng)網(wǎng)絡(luò)來從貓臉分布中抽樣侥祭。真實(shí)貓臉的隱含(生成)數(shù)據(jù)分布包含:1)一只正在出生的貓,2)某人最終拍下了這個貓的照片。顯然矮冬,我們的神經(jīng)網(wǎng)絡(luò)并不是要學(xué)習(xí)這個特殊的生成過程谈宛,因?yàn)檫@個過程并沒有涉及真實(shí)的貓。然而胎署,如果我們的網(wǎng)絡(luò)能夠產(chǎn)生無法與真實(shí)的貓圖片相區(qū)分的圖片(在多項(xiàng)式時間計(jì)算資源內(nèi))那么從某種意義上說這些照片與正常的貓照片一樣合法入挣。在圖靈測試,密碼學(xué)與假勞力士的背景下硝拧,這值得深思径筏。
可以從過擬合的角度看待過量的映射線交叉,學(xué)習(xí)到的預(yù)測或者判別函數(shù)已經(jīng)被樣本數(shù)據(jù)以一種“矛盾”的方式扭曲了(例如障陶,一張貓的照片被分類為狗)滋恬。正則化方法可以間接地防止過多的“映射交叉”,但是沒有顯式地使用排序算法來確保學(xué)習(xí)到的從Z空間到X空間的映射轉(zhuǎn)化是連續(xù)或者對齊的抱究。這種排序機(jī)制也許對于提升訓(xùn)練速度非常有效恢氯。。鼓寺。