保留初心丰包,砥礪前行
SeqGAN這篇paper從大半年之前就開始看,斷斷續(xù)續(xù)看到現(xiàn)在壤巷,接下來的工作或許會(huì)與GAN + RL有關(guān)邑彪,因此又把它翻出來,又一次仔細(xì)拜讀了一番胧华。接下來就記錄下我的一點(diǎn)理解寄症。
1. 背景
GAN在之前發(fā)的文章里已經(jīng)說過了宙彪,不了解的同學(xué)點(diǎn)我,雖然現(xiàn)在GAN的變種越來越多有巧,用途廣泛释漆,但是它們的對(duì)抗思想都是沒有變化的。簡(jiǎn)單來說篮迎,就是在生成的過程中加入一個(gè)可以鑒別真實(shí)數(shù)據(jù)和生成數(shù)據(jù)的鑒別器男图,使生成器G和鑒別器D相互對(duì)抗,D的作用是努力地分辨真實(shí)數(shù)據(jù)和生成數(shù)據(jù)甜橱,G的作用是努力改進(jìn)自己從而生成可以迷惑D的數(shù)據(jù)享言。當(dāng)D無法再分別出真假數(shù)據(jù),則認(rèn)為此時(shí)的G已經(jīng)達(dá)到了一個(gè)很優(yōu)的效果渗鬼。
它的諸多優(yōu)點(diǎn)是它如今可以這么火爆的原因:
- 可以生成更好的樣本
- 模型只用到了反向傳播,而不需要馬爾科夫鏈
- 訓(xùn)練時(shí)不需要對(duì)隱變量做推斷
- G的參數(shù)更新不是直接來自數(shù)據(jù)樣本,而是使用來自D的反向傳播
- 理論上,只要是可微分函數(shù)都可以用于構(gòu)建D和G,因?yàn)槟軌蚺c深度神經(jīng)網(wǎng)絡(luò)結(jié)合做深度生成式模型
它的最后一條優(yōu)點(diǎn)也恰恰就是它的局限览露,之前我發(fā)過的文章中也有涉及到,點(diǎn)點(diǎn)點(diǎn)點(diǎn)點(diǎn)我譬胎,在NLP中差牛,數(shù)據(jù)不像圖片處理時(shí)是連續(xù)的,可以微分堰乔,我們?cè)趦?yōu)化生成器的過程中不能找到“中國 + 0.1”這樣的東西代表什么偏化,因此對(duì)于離散的數(shù)據(jù),普通的GAN是無法work的镐侯。
2. 大體思路
這位還在讀本科的作者想到了使用RL來解決這個(gè)問題侦讨。
如上圖(左)所示,仍然是對(duì)抗的思想苟翻,真實(shí)數(shù)據(jù)加上G的生成數(shù)據(jù)來訓(xùn)練D韵卤。但是從前邊背景章節(jié)所述的內(nèi)容中,我們可以知道G的離散輸出崇猫,讓D很難回傳一個(gè)梯度用來更新G沈条,因此需要做一些改變,看上圖(右)诅炉,paper中將policy network當(dāng)做G蜡歹,已經(jīng)存在的紅色圓點(diǎn)稱為現(xiàn)在的狀態(tài)(state),要生成的下一個(gè)紅色圓點(diǎn)稱作動(dòng)作(action)涕烧,因?yàn)镈需要對(duì)一個(gè)完整的序列評(píng)分月而,所以就是用MCTS(蒙特卡洛樹搜索)將每一個(gè)動(dòng)作的各種可能性補(bǔ)全,D對(duì)這些完整的序列產(chǎn)生reward议纯,回傳給G父款,通過增強(qiáng)學(xué)習(xí)更新G。這樣就是用Reinforcement learning的方式,訓(xùn)練出一個(gè)可以產(chǎn)生下一個(gè)最優(yōu)的action的生成網(wǎng)絡(luò)铛漓。
3. 主要內(nèi)容
不論怎么對(duì)抗溯香,目的都是為了更好的生成鲫构,因此我們可以把生成作為切入點(diǎn)浓恶。生成器G的目標(biāo)是生成sequence來最大化reward的期望。
在這里把這個(gè)reward的期望叫做J(θ)结笨。就是在s0和θ的條件下包晰,產(chǎn)生某個(gè)完全的sequence的reward的期望。其中Gθ()部分可以輕易地看出就是Generator Model炕吸。而QDφGθ()(我在這里叫它Q值)在文中被叫做一個(gè)sequence的action-value function 伐憾。因此,我們可以這樣理解這個(gè)式子:G生成某一個(gè)y1的概率乘以這個(gè)y1的Q值赫模,這樣求出所有y1的概率乘Q值树肃,再求和,則得到了這個(gè)J(θ)瀑罗,也就是我們生成模型想要最大化的函數(shù)胸嘴。
所以問題來了,這個(gè)Q值怎么求斩祭?
paper中使用的是REINFORCE algorithm 并且就把這個(gè)Q值看作是鑒別器D的返回值劣像。
因?yàn)椴煌暾能壽E產(chǎn)生的reward沒有實(shí)際意義,因此在原有y_1到y(tǒng)_t-1的情況下摧玫,產(chǎn)生的y_t的Q值并不能在y_t產(chǎn)生后直接計(jì)算耳奕,除非y_t就是整個(gè)序列的最后一個(gè)。paper中想了一個(gè)辦法诬像,使用蒙特卡洛搜索(就我所知“蒙特卡洛”這四個(gè)字可以等同于“隨意”)將y_t后的內(nèi)容進(jìn)行補(bǔ)全屋群。既然是隨意補(bǔ)全就說明會(huì)產(chǎn)生多種情況,paper中將同一個(gè)y_t后使用蒙特卡洛搜索補(bǔ)全的所有可能的sequence全都計(jì)算reward坏挠,然后求平均谓晌。如下圖所示。
就這樣癞揉,我們生成了一些逼真的sequence纸肉。我們就要用如下方式訓(xùn)練D。
這個(gè)式子很容易理解喊熟,最大化D判斷真實(shí)數(shù)據(jù)為真加上D判斷生成數(shù)據(jù)為假柏肪,也就是最小化它們的相反數(shù)。
D訓(xùn)練了一輪或者多輪(因?yàn)镚AN的訓(xùn)練一直是個(gè)難題芥牌,找好G和D的訓(xùn)練輪數(shù)比例是關(guān)鍵)之后烦味,就得到了一個(gè)更優(yōu)秀的D,此時(shí)要用D去更新G。G的更新可以看做是梯度下降谬俄。
其中柏靶,
αh代表學(xué)習(xí)率。
以上就是大概的seqGAN的原理溃论。
4. 算法
首先隨機(jī)初始化G網(wǎng)絡(luò)和D網(wǎng)絡(luò)參數(shù)屎蜓。
通過MLE預(yù)訓(xùn)練G網(wǎng)絡(luò),目的是提高G網(wǎng)絡(luò)的搜索效率钥勋。
使用預(yù)訓(xùn)練的G生成一些數(shù)據(jù)炬转,用來通過最小化交叉熵來預(yù)訓(xùn)練D。
開始生成sequence算灸,并使用方程(4)計(jì)算reward(這個(gè)reward來自于G生成的sequence與D產(chǎn)生的Q值)扼劈。
使用方程(8)更新G的參數(shù)。
更優(yōu)的G生成更好的sequence菲驴,和真實(shí)數(shù)據(jù)一起通過方程(5)訓(xùn)練D荐吵。
以上1,2赊瞬,3循環(huán)訓(xùn)練直到收斂先煎。
5. 實(shí)驗(yàn)
論文的實(shí)驗(yàn)部分就不是本文的重點(diǎn)了,有興趣的話看一下paper就可以了森逮。
后邊說的比較敷衍了榨婆,那...就這樣吧。
參考資料:SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient | 百度&google