1.CGAN的簡介
為了解決帶標(biāo)簽的數(shù)據(jù)生成問題挽封,研究者們提出了條件生成對抗網(wǎng)絡(luò)(CGAN)的概念摹蘑。
CGAN的結(jié)構(gòu)如上圖所示,與GAN的主要區(qū)別是生成器和判別器的輸入數(shù)據(jù)中都加入類別標(biāo)簽向量(C_vector)荒典,生成器的優(yōu)化目標(biāo)函數(shù)基本上沒有變化驶兜。
總的來說CGAN在GAN上的改動并不大舀凛,但是普通的GAN所生成的內(nèi)容是隨機(jī)的俊扳,CGAN實(shí)現(xiàn)了根據(jù)輸入標(biāo)簽生成指定類別的內(nèi)容。
2.CGAN的實(shí)現(xiàn)
目前CGAN的實(shí)現(xiàn)由多種形式猛遍,主要的區(qū)別是C_vector的形式馋记,目前主要有以下三種形式:
第一中形式:
將輸入Generator的C_vector進(jìn)行One-hot編碼,然后與noise進(jìn)行拼接懊烤,此時C_vector為(batch_size, class_num) 梯醒,noise為(batch_size, latent_dim),將拼接之后大小為(batch_size, latent_dim+class_num)作為生成器的輸入腌紧。
將輸入Discrimintor的C_vector首先進(jìn)行One-hot編碼茸习,然后通過expand()方法進(jìn)行維度擴(kuò)展,此時的C_vector為 (batch_size, class_num, cols, rows) 壁肋, Real_data 和 Fake_data為(batch_size, channel, cols, rows)号胚,最后將轉(zhuǎn)換后的C_vector和Real_data或者Fake_data進(jìn)行拼接,將拼接之后大小為(batch_size, channel+class_num, cols, rows)的張量作為判別器的輸入浸遗。
第二種形式:
將輸入Generator的C_vector通過Embedding方法進(jìn)詞嵌入猫胁,并進(jìn)行Flatten操作,從而將C_vector轉(zhuǎn)換成為與noise大小相同的張量(batch_size, latent_dim), 然后將noise 和 C_vector 進(jìn)行mulitiply()操作(即對應(yīng)位置上的元素相乘跛锌,該運(yùn)輸不改變張量的大衅选),將最終得到的(batch_size, latent_dim)的張量作為生成器的輸入髓帽。
將輸入Discriminator的C_vector通過Embedding方法進(jìn)行詞嵌入菠赚,并進(jìn)行Flatten操作,從而將C_vector轉(zhuǎn)換為(batch_size, channel*rows*cols)郑藏,接著對Real_data和Fake_data進(jìn)行Flatten操作衡查,將其轉(zhuǎn)換為(batch_size, channel*rows*cols),然后將轉(zhuǎn)換后的C_vector和Real_data或者Fake_data進(jìn)行multiply()操作译秦,將最終得到的(batch_size, channel*rows*cols)張量作為判別器的輸入峡捡。
第三種形式:
將輸入Generator的C_vector進(jìn)行One-hot編碼,然后與noise進(jìn)行拼接筑悴,此時C_vector為(batch_size, class_num) 们拙,noise為(batch_size, latent_dim),最后將拼接后大小為(batch_size,?latent_dim+class_num)作為生成器的輸入阁吝。
將輸入Discriminator的C_vector進(jìn)行One-hot編碼砚婆,然后與經(jīng)過Flatten()處理之后的Real_data或者Fake_data進(jìn)行拼接,此時Real_data和Fake_data為(batch_size, channel*rows*cols),C_vector為(batch_size, num_class)装盯,最后將拼接之后大小為(batch_size, channel*rows*cols + num_class)的張量作為判別器的輸入坷虑。
損失函數(shù):
在具體實(shí)現(xiàn)上,CGAN的損失函數(shù)和GAN基本相同埂奈。