本文是2017年發(fā)表于AAAI的論文《SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient》的閱讀筆記舱权。
想要解決的問題(動(dòng)機(jī))
GAN模型的兩個(gè)缺點(diǎn):
- GAN是設(shè)計(jì)用來生成實(shí)值連續(xù)數(shù)據(jù)的槽驶,在直接生成離散序列方面有些困難,表現(xiàn)不好;
- GAN只能給生成的完整序列打分鸯旁,對(duì)于生成的部分序列,平衡現(xiàn)在的好壞和將來作為整個(gè)序列的得分是很重的磺陡。
原因
:我們知道沼本,GAN中生成網(wǎng)絡(luò)參數(shù)的優(yōu)化是需要依賴于判別網(wǎng)絡(luò)的,如果輸出是離散序列的話薪捍,就很難將梯度更新從判別網(wǎng)絡(luò)傳給生成網(wǎng)絡(luò)笼痹。
怎么理解這個(gè)離散序列?即輸出不是一個(gè)值域里所有的值(例如
)酪穿,而是可枚舉的數(shù)值(例如
)凳干。
由于生成模型的核心是擬合數(shù)據(jù)的密度分布,而只有連續(xù)的數(shù)據(jù)才有密度分布被济,離散的數(shù)據(jù)是沒有密度分布的救赐。
解決的方法
- 問題:梯度更新無法從判別網(wǎng)絡(luò)傳給生成網(wǎng)絡(luò)
- 解決方法:將序列生成過程看成序列決策過程,引入強(qiáng)化學(xué)習(xí)只磷,將生成網(wǎng)絡(luò)看成強(qiáng)化學(xué)習(xí)機(jī)经磅,所生成的部分離散序列看成是當(dāng)前的狀態(tài)(state),動(dòng)作(action)就是決定下一個(gè)字符是什么钮追。
- 這樣一來预厌,生成網(wǎng)絡(luò)只需要接受由判別網(wǎng)絡(luò)對(duì)生成的輸出序列打分的分?jǐn)?shù)作為獎(jiǎng)勵(lì)(reward)來進(jìn)行參數(shù)調(diào)整,而不需要通過梯度計(jì)算畏陕。
- 問題:GAN只能給生成的完整序列打分
- 解決方法:使用Monte Carlo Search(蒙特卡洛搜索)方法來從當(dāng)前的序列狀態(tài)搜索得到完整序列配乓,然后由判別網(wǎng)絡(luò)打分,得到的分值作為獎(jiǎng)勵(lì)(reward)
上述兩種方案的結(jié)合,就是所謂的SeqGAN
模型架構(gòu)犹芹。
模型
目標(biāo)函數(shù)
- 生成網(wǎng)絡(luò)的目標(biāo)函數(shù)
由于生成網(wǎng)絡(luò)現(xiàn)在使用的是強(qiáng)化學(xué)習(xí)來實(shí)現(xiàn)的崎页,而強(qiáng)化學(xué)習(xí)的目標(biāo)就是使得獎(jiǎng)勵(lì)(reward)盡可能大。于是腰埂,生成網(wǎng)絡(luò)的目標(biāo)函數(shù)如下:
這里飒焦,action-value函數(shù)(即上式中的Q函數(shù))也至關(guān)重要。我們知道屿笼,reward是將完整序列輸入到判別網(wǎng)絡(luò)D中所得到的牺荠,因此:
這里只是討論了完整序列的情況,還有非完整序列的情況驴一,此時(shí)就需要用到Monte Carlo Search(蒙特卡洛搜索)來從當(dāng)前非完整序列狀態(tài)搜索得到完整的序列狀態(tài)休雌,其搜索過程定義如下(一共搜索了N次,得到了N個(gè)完整的序列):
中間狀態(tài)序列的reward是這N個(gè)搜索得到的完整序列的reward均值肝断,從而得到完整的action-value函數(shù)定義:
- 判別網(wǎng)絡(luò)的目標(biāo)函數(shù)
判別網(wǎng)絡(luò)的目標(biāo)不變(還是GAN中的那一套)杈曲,還是最小化交叉熵:
這里的目標(biāo)函數(shù)并不像GAN中的一樣可以結(jié)合成為一個(gè)整體的目標(biāo)函數(shù),因?yàn)閮蓚€(gè)目標(biāo)函數(shù)的梯度是不互通的
算法
圖中的公式(8)如下:
由于作者發(fā)現(xiàn)來自預(yù)訓(xùn)練的判別器的監(jiān)督信號(hào)能夠很有效地調(diào)整生成網(wǎng)絡(luò)胸懈,在實(shí)現(xiàn)過程中担扑,先預(yù)訓(xùn)練了判別網(wǎng)絡(luò);而且為了使得生成的負(fù)樣本更可靠趣钱,在預(yù)訓(xùn)練判別器之前涌献,使用極大似然估計(jì)(MLE)先預(yù)訓(xùn)練了生成器。
評(píng)價(jià)生成模型
當(dāng)訓(xùn)練得到最終的生成模型首有,我們就需要評(píng)價(jià)其性能燕垃。
在本文中,采用了一個(gè)能夠捕獲到字符依賴關(guān)系的語(yǔ)言模型作為真實(shí)數(shù)據(jù)的模型(稱為oracle)绞灼,使用該模型生成SeqGAN的訓(xùn)練數(shù)據(jù)利术,并且能夠用該模型評(píng)價(jià)SeqGAN中生成模型的性能。
評(píng)價(jià)采用的指標(biāo)是負(fù)對(duì)數(shù)似然:
其中 和
分別表示生成模型和oracle模型低矮。
上式其實(shí)是一個(gè)熵印叁,當(dāng)生成模型生成一個(gè)序列
的概率是
時(shí),我們計(jì)算一下oracle模型生成該序列的概率是
军掂,然后計(jì)算一下熵
轮蜕,所有可能生成的序列的熵之和為
,即為
蝗锥,當(dāng)熵越小跃洛,說明p和q的散度越小,p和q的分布就越相似终议。
越小就說明生成模型性能越高汇竭,越接近數(shù)據(jù)的真實(shí)分布葱蝗。
討論
本文還對(duì)算法中一些超參(g-step,d-step细燎,k等)的設(shè)置两曼。
實(shí)驗(yàn)結(jié)果如下:
從這個(gè)實(shí)驗(yàn)中可以得出:
- 設(shè)置g-step比d-step和k值大很多(花更多時(shí)間用來訓(xùn)練生成網(wǎng)絡(luò))時(shí),能夠快速收斂玻驻,但是它會(huì)導(dǎo)致判別網(wǎng)絡(luò)沒訓(xùn)練好悼凑,從而會(huì)誤導(dǎo)生成網(wǎng)絡(luò)進(jìn)行更新璧瞬,收斂性會(huì)不穩(wěn)定;(圖3.a)
- 增大d-step和k的設(shè)置(d-step相較于g-step比值有提升)時(shí)渔欢,收斂性會(huì)穩(wěn)定很多;(圖3.b)
- 將d-step
k 設(shè)置的略大于g-step時(shí)档冬,意味著我們一直保證判別網(wǎng)絡(luò)的準(zhǔn)確性(在判別網(wǎng)絡(luò)被混淆之前就使用更真實(shí)的負(fù)例來訓(xùn)練它)膘茎,此時(shí)SeqGAN會(huì)比較穩(wěn)定;(圖3.c)
- 將d-step
k進(jìn)一步增大酷誓,發(fā)現(xiàn)其收斂的更快;(圖3.d)
上面的實(shí)驗(yàn)證明了:在訓(xùn)練過程中將判別網(wǎng)絡(luò)的性能訓(xùn)練得略高于生成網(wǎng)絡(luò)是有利于模型的收斂态坦。