1. InfoGAN簡介:
? ? ? ?普通的GAN存在無約束、不可控违孝、噪聲信號z很難解釋等問題,2016年發(fā)表在NIPS頂會上的文章InfoGAN:Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets泳赋,提出了InfoGAN的生成對抗網(wǎng)絡(luò)雌桑。InfoGAN 主要特點是對GAN進行了一些改動,成功地讓網(wǎng)絡(luò)學(xué)到了可解釋的特征祖今,網(wǎng)絡(luò)訓(xùn)練完成之后校坑,我們可以通過設(shè)定輸入生成器的隱含編碼來控制生成數(shù)據(jù)的特征。
? ? ? ? 作者將輸入生成器的隨機噪聲分成了兩部分:一部分是隨機噪聲Z千诬, 另一部分是由若干隱變量拼接而成的latent code c耍目。其中,c會有先驗的概率分布徐绑,可以離散也可以連續(xù)邪驮,用來代表生成數(shù)據(jù)的不同特征。例如:對于MNIST數(shù)據(jù)集傲茄,c包含離散部分和連續(xù)部分耕捞,離散部分取值為0~9的離散隨機變量(表示數(shù)字),連續(xù)部分有兩個連續(xù)型隨機變量(分別表示傾斜度和粗細度)烫幕。
? ? ? ? 為了讓隱變量c能夠與生成數(shù)據(jù)的特征產(chǎn)出關(guān)聯(lián)俺抽,作者引入了互信息來對c進行約束,因為c對生成數(shù)據(jù)G(z, c)具有可解釋性较曼,那么c和G(z, c)應(yīng)該具有較高的相關(guān)性磷斧,即它們之間的互信息比較大〗萦蹋互信息是兩個隨機變量之間依賴程度的度量弛饭,互信息越大就說明生成網(wǎng)絡(luò)在根據(jù)c的信息生成數(shù)據(jù)時,隱編碼c的信息損失越低萍歉,即生成數(shù)據(jù)保留的c的信息越多侣颂。因此,我們希望c和G(z, c)之間的互信息I(c; G(z, c))越大越好枪孩,故模型的目標函數(shù)變?yōu)椋?/p>
? ? ? ? 但是由于在c與G(z, c)的互信息的計算中憔晒,真實的P(c|x)難以獲得,因此在具體的優(yōu)化過程中蔑舞,作者采用了變分推斷的思想拒担,引入了變分分布Q(c|x)來逼近P(c|x),它是基于最優(yōu)互信息下界的輪流迭代實現(xiàn)最終的求解攻询,于是InfoGAN的目標函數(shù)變?yōu)椋?/p>
2. InfoGAN的基本結(jié)構(gòu)為:
? ? ? ? 其中从撼,真實數(shù)據(jù)Real_data只是用來跟生成的Fake_data混合在一起進行真假判斷,并根據(jù)判斷的結(jié)果更新生成器和判別器钧栖,從而使生成的數(shù)據(jù)與真實數(shù)據(jù)接近低零。生成數(shù)據(jù)既要參與真假判斷婆翔,還需要和隱變量C_vector求互信息,并根據(jù)互信息更新生成器和判別器掏婶,從而使得生成圖像中保留了更多隱變量C_vector的信息啃奴。
? ? ? ? 因此可以對InfoGAN的基本結(jié)構(gòu)進行如下的拆分,其中判別器D和Q共用所有卷積層气堕,只是最后的全連接層不同。從另一個角度來看畔咧,G-Q聯(lián)合網(wǎng)絡(luò)相當于是一個自編網(wǎng)絡(luò)茎芭,G相當于一個編碼器,而Q相當于一個解碼器誓沸,生成數(shù)據(jù)Fake_data相當于對輸入隱變量C_vector的編碼梅桩。
生成器G的輸入為:(batch_size, noise_dim + discrete_dim + continuous_dim)县钥,其中noise_dim為輸入噪聲的維度注簿,discrete_dim為離散隱變量的維度当悔,continuous_dim為連續(xù)隱變量的維度物臂。生成器G的輸出為(batch_size, channel, img_cols, img_rows)凄鼻。
判別器D的輸入為:(batch_size, channel, img_cols, img_rows)拥褂,判別器D的輸出為:(batch_size, 1)紊册。
判別器Q的輸入為:(batch_size, channel, img_cols, img_rows)放前,Q的輸出為:(batch_size, discrete_dim + continuous_dim)
3. InfoGAN的優(yōu)化目標函數(shù)為:
? ? ? ? InfoGAN的目標函數(shù)變?yōu)椋? ? ?
? ? ? ? 對于判別器D而言干奢,優(yōu)化目標函數(shù)為:? ??
D_real, _, _ = Discriminator(real_imgs)? ? ? ? #? real_imgs為用于訓(xùn)練的真實圖像
gen_imgs = Generator(noise, c_discrete, c_continuous)? ? ? ? #? c_discrete為輸入的離散型隱變量痊焊, c_continuous為輸入的連續(xù)型隱變量??
D_fake, _, _ = Discriminator(gen_imgs)?
D_real_loss = torch.nn.BCELoss(D_real, y_real)? ? ? ? ? #? y_real 真實圖像的標簽,都為1
D_fake_loss = torch.nn.BCELoss(D_fake, y_fake)? ? ? ? ?# y_fake為生成圖像的標簽忿峻,都為0
D_loss = D_real_loss + D_fake_loss
? ? ? ? 對于生成器G而言薄啥,優(yōu)化目標函數(shù)為:?
gen_imgs = Generator(noise, c_discrete, c_continuous)? ? ? ? ? ? ? ? ? ?#? c_discrete為輸入的離散隱變量,c_continuous為輸入的連續(xù)隱變量
D_fake, D_continuous, D_discrete = Discriminator(gen_imgs)?
G_loss = torch.nn.BCELoss( D_fake, y_real)? ? ? ? ? ? ? ??#? y_real 真實圖像的標簽逛尚,都為1
? ? ? ? ?對于G-Q聯(lián)合網(wǎng)絡(luò)而言垄惧,它的優(yōu)化目標函數(shù)為:? , 其中? ? ? ? ? ? ??因此绰寞,
discrete_loss = torch.nn.CELoss(D_discrete, c_discrete)
continuous_loss = torch.nn.MSELoss(D_continuous, c_continuous)
info_loss = discrete_loss + continuous_loss
info_loss.backward()
info_optimizer.step()? ? ? ?# 其中到逊,info_optimizer = optim.Adam(itertools.chain(Generator.parameters(), Discriminator.parameters()), lr = learning_rate, betas=(beta1, beta2))
? ? ? ?簡而言之,InfoGAN中單獨判別器D的優(yōu)化目標函數(shù)只有對抗損失滤钱,單獨生成器G的優(yōu)化目標函數(shù)也只有對抗損失蕾管,生成器G和輔助判別器Q聯(lián)合網(wǎng)絡(luò)的優(yōu)化目標函數(shù)是info損失,包含離散損失和連續(xù)損失兩個部分菩暗。其中掰曾,判別器D和輔助判別器Q共用卷積層,只是最后的全連接層不同停团。
參考鏈接:http://aistudio.baidu.com/aistudio/projectdetail/29156?中山大學(xué)黃濤對論文InfoGAN:Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets的復(fù)現(xiàn)旷坦。