生成對(duì)抗網(wǎng)絡(luò)(GAN)是一種強(qiáng)大的生成模型载城,但是自從2014年Ian Goodfellow提出以來,GAN就存在訓(xùn)練不穩(wěn)定的問題戚宦。最近提出的 Wasserstein GAN(WGAN)在訓(xùn)練穩(wěn)定性上有極大的進(jìn)步个曙,但是在某些設(shè)定下仍存在生成低質(zhì)量的樣本,或者不能收斂等問題受楼。
近日,蒙特利爾大學(xué)的研究者們?cè)赪GAN的訓(xùn)練上又有了新的進(jìn)展呼寸,他們將論文《Improved Training of Wasserstein GANs》發(fā)布在了arXiv上艳汽。研究者們發(fā)現(xiàn)失敗的案例通常是由在WGAN中使用權(quán)重剪枝來對(duì)critic實(shí)施Lipschitz約束導(dǎo)致的。在本片論文中对雪,研究者們提出了一種替代權(quán)重剪枝實(shí)施Lipschitz約束的方法:懲罰critic對(duì)輸入的梯度河狐。該方法收斂速度更快,并能夠生成比權(quán)重剪枝的WGAN更高質(zhì)量的樣本瑟捣。
以下為雷鋒網(wǎng)AI科技評(píng)論據(jù)論文內(nèi)容進(jìn)行的部分編譯馋艺。
論文摘要
生成對(duì)抗網(wǎng)絡(luò)(GAN)將生成問題當(dāng)作兩個(gè)對(duì)抗網(wǎng)絡(luò)的博弈:生成網(wǎng)絡(luò)從給定噪聲中產(chǎn)生合成數(shù)據(jù),判別網(wǎng)絡(luò)分辨生成器的的輸出和真實(shí)數(shù)據(jù)迈套。GAN可以生成視覺上吸引人的圖片捐祠,但是網(wǎng)絡(luò)通常很難訓(xùn)練。前段時(shí)間桑李,Arjovsky等研究者對(duì)GAN值函數(shù)的收斂性進(jìn)行了深入的分析踱蛀,并提出了Wasserstein GAN(WGAN),利用Wasserstein距離產(chǎn)生一個(gè)比Jensen-Shannon發(fā)散值函數(shù)有更好的理論上的性質(zhì)的值函數(shù)贵白。但是仍然沒能完全解決GAN訓(xùn)練穩(wěn)定性的問題率拒。
雷鋒網(wǎng)(公眾號(hào):雷鋒網(wǎng))了解到,在該論文中禁荒,蒙特利爾大學(xué)的研究者對(duì)WGAN進(jìn)行改進(jìn)猬膨,提出了一種替代WGAN判別器中權(quán)重剪枝的方法,下面是他們所做的工作:
通過小數(shù)據(jù)集上的實(shí)驗(yàn)呛伴,概述了判別器中的權(quán)重剪枝是如何導(dǎo)致影響穩(wěn)定性和性能的病態(tài)行為的勃痴。
提出具有梯度懲罰的WGAN(WGAN with gradient penalty),從而避免同樣的問題磷蜀。
展示該方法相比標(biāo)準(zhǔn)WGAN擁有更快的收斂速度召耘,并能生成更高質(zhì)量的樣本。
展示該方法如何提供穩(wěn)定的GAN訓(xùn)練:幾乎不需要超參數(shù)調(diào)參褐隆,成功訓(xùn)練多種針對(duì)圖片生成和語言模型的GAN架構(gòu)
WGAN的critic函數(shù)對(duì)輸入的梯度相比于GAN的更好污它,因此對(duì)生成器的優(yōu)化更簡(jiǎn)單。另外,WGAN的值函數(shù)是與生成樣本的質(zhì)量相關(guān)的衫贬,這個(gè)性質(zhì)是GAN所沒有的德澈。WGAN的一個(gè)問題是如何高效地在critic上應(yīng)用Lipschitz約束,Arjovsky提出了權(quán)重剪枝的方法固惯。但權(quán)重剪枝會(huì)導(dǎo)致最優(yōu)化困難梆造。在權(quán)重剪枝約束下,大多數(shù)神經(jīng)網(wǎng)絡(luò)架構(gòu)只有在學(xué)習(xí)極其簡(jiǎn)單地函數(shù)時(shí)才能達(dá)到k地最大梯度范數(shù)葬毫。因此镇辉,通過權(quán)重剪枝來實(shí)現(xiàn)k-Lipschitz約束將會(huì)導(dǎo)致critic偏向更簡(jiǎn)單的函數(shù)。如下圖所示贴捡,在小型數(shù)據(jù)集上忽肛,權(quán)重剪枝不能捕捉到數(shù)據(jù)分布的高階矩。
由于在WGAN中使用權(quán)重剪枝可能會(huì)導(dǎo)致不良結(jié)果烂斋,研究者考慮在訓(xùn)練目標(biāo)上使用Lipschitz約束的一種替代方法:一個(gè)可微的函數(shù)是1-Lipschitz屹逛,當(dāng)且僅當(dāng)它的梯度具有小于或等于1的范數(shù)時(shí)。因此汛骂,可以直接約束critic函數(shù)對(duì)其輸入的梯度范數(shù)罕模。新的critic函數(shù)為:
實(shí)驗(yàn)結(jié)果 圖&表
研究者們?cè)贑IFAR-10數(shù)據(jù)集上將梯度懲罰的WGAN與權(quán)重剪枝的WGAN的訓(xùn)練進(jìn)行了對(duì)比。其中橙色曲線的梯度懲罰WGAN使用了與權(quán)重剪枝WGAN相同的優(yōu)化器(RMSProp)和相同的學(xué)習(xí)率帘瞭。綠色曲線是使用了Adam優(yōu)化器和更高學(xué)習(xí)率的梯度懲罰WGAN淑掌。可以看到图张,即使使用了同樣的優(yōu)化器锋拖,該論文中的方法也能更快的收斂并得到更高的最終分?jǐn)?shù)。使用Adam優(yōu)化器能進(jìn)一步提高性能兽埃。
為了展示該方法訓(xùn)練過程中的穩(wěn)定性,研究者在LSUN臥室訓(xùn)練集上訓(xùn)練了多種不同的GAN架構(gòu)适袜,除了DCGAN外,研究者還選擇了另外六種較難訓(xùn)練的架構(gòu)苦酱,如下圖所示:
對(duì)于每種架構(gòu)售貌,研究者都使用了四種不同的GAN過程:梯度懲罰的WGAN,權(quán)重剪枝的WGAN疫萤,DCGAN颂跨,以及最小二乘GAN。對(duì)于每種方法扯饶,都使用了推薦的優(yōu)化器超參數(shù)默認(rèn)設(shè)置:
WGAN with gradient penalty: Adam (α = .0001, β1 = .5, β2 = .9)
WGAN with weight clipping: RMSProp (α = .00005)
DCGAN: Adam (α = .0002, β1 = .5)
LSGAN: RMSProp (α = .0001) [chosen by search over α = .001, .0002, .0001]
上圖顯示的樣本都是經(jīng)過200k次迭代的結(jié)果恒削。目前為止池颈,梯度懲罰的WGAN是唯一一種使用同一種默認(rèn)超參數(shù),并在每個(gè)架構(gòu)下都成功訓(xùn)練的方法钓丰。而所有其他方法躯砰,都在一些架構(gòu)下不穩(wěn)定。
使用GAN構(gòu)建語言模型是一項(xiàng)富有挑戰(zhàn)的任務(wù)携丁,很大程度上是因?yàn)樯善髦须x散的輸入輸出序列很難進(jìn)行反向傳播琢歇。先前的GAN語言模型通常憑借預(yù)訓(xùn)練或者與監(jiān)督最大似然方法聯(lián)合訓(xùn)練。相比之下梦鉴,使用該論文的方法李茫,不需采用復(fù)雜的通過離散變量反向傳播的方法,也不需要最大似然訓(xùn)練或fine-tune結(jié)構(gòu)尚揣。該方法在Google Billion Word數(shù)據(jù)集上訓(xùn)練了一個(gè)字符級(jí)的GAN語言模型涌矢。生成器是一個(gè)簡(jiǎn)單的CNN架構(gòu),通過1D卷積將latent vector轉(zhuǎn)換為32個(gè)one-hot字符向量的序列快骗。
下圖展示了模型的一個(gè)例子。目前為止塔次,這是第一個(gè)完全使用對(duì)抗方法進(jìn)行訓(xùn)練方篮,而沒有使用監(jiān)督的最大似然損失的生成語言模型。其中有一些拼寫上的錯(cuò)誤励负,這可能是由于模型是每個(gè)字符獨(dú)立輸出的藕溅。
該文提供了一種訓(xùn)練GAN的穩(wěn)定的算法,能夠更好的探索哪種架構(gòu)能夠得到最好的生成模型性能继榆。該方法也打開了使用大規(guī)模圖像或語言數(shù)據(jù)集訓(xùn)練以得到更強(qiáng)的模型性能的大門巾表。
本論文在github上開源了代碼:github
本論文同時(shí)也提供了詳細(xì)的數(shù)學(xué)證明,以及更多的示例略吨,進(jìn)一步了解請(qǐng)閱讀原論文:Improved Training