訓(xùn)練WGAN的時(shí)候,有幾個(gè)方面可以調(diào)參:
? a. 調(diào)節(jié)Generator loss中GAN loss的權(quán)重拗引。 G loss和Gan loss在一個(gè)尺度上或者G loss比Gan loss大一個(gè)尺度。但是千萬(wàn)不能讓Gan loss占主導(dǎo)地位, 這樣整個(gè)網(wǎng)絡(luò)權(quán)重會(huì)被帶偏幌衣。
? b. 調(diào)節(jié)Generator和Discrimnator的訓(xùn)練次數(shù)比矾削。一般來(lái)說(shuō),Discrimnator要訓(xùn)練的比Genenrator多豁护。比如訓(xùn)練五次Discrimnator哼凯,再訓(xùn)練一次Genenrator(WGAN論文 是這么干的)。
? c. 調(diào)節(jié)learning rate择镇,這個(gè)學(xué)習(xí)速率不能過(guò)大挡逼。一般要比Genenrator的速率小一點(diǎn)。
? d. Optimizer的選擇不能用基于動(dòng)量法的腻豌,如Adam和momentum家坎。可使用RMSProp或者SGD吝梅。
? e. Discrimnator的結(jié)構(gòu)可以改變虱疏。如果用WGAN,判別器的最后一層需要去掉sigmoid苏携。但是用原始的GAN做瞪,需要用sigmoid,因?yàn)槠鋖oss function里面需要取log,所以值必須在[0,1]装蓬。這里用的是鄧煒的critic模型當(dāng)作判別器著拭。之前twitter的論文里面的判別器即使去掉了sigmoid也不好訓(xùn)練。
? f. Generator loss的誤差曲線走向牍帚。因?yàn)镚enerator的loss定義為:
?? G_loss = -tf.reduce_mean(D_fake)
? ? Generator_loss = gen_loss + lamda*G_loss
其中g(shù)en_loss為Generator的loss儡遮,G_loss為Discrimnator的loss,目標(biāo)是使Generator_loss不斷變小暗赶。所以理想的Generator loss的誤差曲線應(yīng)該是不斷往0靠的下降的拋物線鄙币。
? g. Discrimnator loss的誤差曲線走向。因?yàn)镈iscrimnator的loss定義為:
? ? ? D_loss = tf.reduce_mean(D_real) - tf.reduce_mean(D_fake)
這個(gè)是一個(gè)和Generator抗衡的loss蹂随。目標(biāo)就是使判別器分不清哪個(gè)是生成器的輸出哪個(gè)是真實(shí)的label十嘿。所以理想的Discrimnator loss的誤差曲線應(yīng)該是最終在0附近振蕩,即傻傻分不清岳锁。換言之绩衷,就是判別器有50%的概率判斷你是真的,50%概率判斷你是假的激率。
? h. 之前的想法是就算判別器不訓(xùn)練唇聘,那么它判斷這個(gè)圖片是真是假的概率都是50%,那D_loss = tf.reduce_mean(D_real) - tf.reduce_mean(D_fake)不就已經(jīng)在0附近了嗎柱搜?
其實(shí)不是這樣的。如果是wgan的話剥险,判別器的輸出是一個(gè)負(fù)無(wú)窮到正無(wú)窮的數(shù)值聪蘸,那么要讓它對(duì)兩個(gè)不同的輸入產(chǎn)生相似的輸出是很難的。同理表制,對(duì)于gan的話健爬,判別器的輸出是介于[0,1]之間的,產(chǎn)生兩個(gè)相似的輸出也是很困難的么介。如果判別器的輸出是0或者1的話娜遵,那就是上面說(shuō)的情況。所以壤短,網(wǎng)絡(luò)要經(jīng)過(guò)學(xué)習(xí)设拟,使得 輸出盡可能相似,那就達(dá)到了傻傻分不清的狀態(tài)了久脯。