論文鏈接:http://openaccess.thecvf.com/content_CVPR_2019/papers/Huang_Generative_Dual_Adversarial_Network_for_Generalized_Zero-Shot_Learning_CVPR_2019_paper.pdf
創(chuàng)新:論文提出了一個(gè)新穎的模型鼻百,該模型為三種不同的方法提供了統(tǒng)一的框架:視覺(jué)→語(yǔ)義映射,語(yǔ)義→視覺(jué)映射和深度度量學(xué)習(xí)件余。
模型包含一個(gè)生成器網(wǎng)絡(luò)良姆,它能夠生成以類嵌入為條件的圖像特征肠虽;一個(gè)回歸器網(wǎng)絡(luò),它獲取圖像特征并輸出其類嵌入(即語(yǔ)義特征)玛追;以及一個(gè)鑒別器網(wǎng)絡(luò)税课,其將圖像特征和語(yǔ)義特征作為輸入,并輸出一個(gè)分?jǐn)?shù)痊剖,以表明它們彼此的匹配程度韩玩。生成器和回歸器通過(guò)循環(huán)一致性損失相互學(xué)習(xí),而它們兩者也通過(guò)雙重對(duì)抗性損失與鑒別器交互陆馁。三個(gè)網(wǎng)絡(luò)均使用前饋網(wǎng)絡(luò)模型找颓。
背景:如圖1(a)所示,大多數(shù)現(xiàn)有方法將視覺(jué)特征投射到類屬性所跨越的語(yǔ)義空間叮贩。然而击狮,使用語(yǔ)義空間作為共享的潛在空間將遭受hubness problem的困擾佛析,即將高維視覺(jué)特征投影到低維空間將大大減少特征的差異,結(jié)果可能聚集成一個(gè)中心彪蓬。為了減輕這個(gè)問(wèn)題寸莫,一些方法提議將語(yǔ)義特征投影到視覺(jué)空間中,如圖1(b)的左側(cè)所示寞焙。但是储狭,使用確定性方法將類別的語(yǔ)義嵌入映射到視覺(jué)空間仍然是有問(wèn)題的,因?yàn)橐粋€(gè)類別標(biāo)簽具有許多對(duì)應(yīng)的視覺(jué)特征捣郊。相反辽狈,一些最新的著作建議使用生成方法,這些方法可以生成基于語(yǔ)義特征向量的各種視覺(jué)特征呛牲,如圖1(b)的右側(cè)所示刮萌。盡管它們有效,但由于缺乏學(xué)習(xí)視覺(jué)空間和語(yǔ)義空間之間的雙向映射的能力或缺乏對(duì)抗性損失(作為評(píng)估特征相似性的更靈活指標(biāo))的效果受到限制娘扩。代替人工選擇一個(gè)公共的潛在空間着茸,RelationNet 提議學(xué)習(xí)一個(gè)深度度量網(wǎng)絡(luò),該網(wǎng)絡(luò)以一對(duì)視覺(jué)和語(yǔ)義特征作為輸入并輸出它們的相似性琐旁,如圖1(c)所示涮阔。但是,RelationNet 無(wú)法學(xué)習(xí)圖像和類的潛在特征灰殴,也不支持半監(jiān)督學(xué)習(xí)敬特。
三部分網(wǎng)絡(luò)詳細(xì)介紹:
CVAE-G生成器:參考CVAE-GAN
Regressor回歸器:自己有一個(gè) supervised loss,
除此之外牺陶,Regressor與CVAE交互時(shí)會(huì)出現(xiàn)以下循環(huán)一致性損失:
Discriminator判別器:
D有四種輸入伟阔,①(v, s) :真實(shí)特征和其對(duì)應(yīng)的真實(shí)語(yǔ)義? ②(G(s, z), s) :生成特征和真實(shí)語(yǔ)義? ③(v,R(v)):真實(shí)特征和生成語(yǔ)義? ④ (v, s?):是一個(gè)隨機(jī)的語(yǔ)義向量,且掰伸,皱炉,即真實(shí)特征和不對(duì)應(yīng)的真實(shí)語(yǔ)義
式中幾項(xiàng)與上面輸入輸出對(duì)應(yīng),由此狮鸭,CVAE和Regressor的對(duì)抗損失還可以定義為:
訓(xùn)練:
鑒別器是與其他兩個(gè)網(wǎng)絡(luò)分開(kāi)訓(xùn)練的合搅,首先使用公式2進(jìn)行CVAE的預(yù)訓(xùn)練,然后使用公式5和公式8以對(duì)抗的方式訓(xùn)練整個(gè)模型怕篷。
實(shí)驗(yàn):
訓(xùn)練完模型后历筝,為了預(yù)測(cè)未見(jiàn)類別的標(biāo)簽,為每個(gè)未見(jiàn)類別首先生成新樣本廊谓,然后將這些合成樣本與訓(xùn)練數(shù)據(jù)中的其他樣本合并梳猪,之后可以訓(xùn)練任何新的類別基于此新數(shù)據(jù)集,其中包含可見(jiàn)和不可見(jiàn)類的樣本。 為了與其他基準(zhǔn)進(jìn)行公平比較春弥,僅應(yīng)用一個(gè)簡(jiǎn)單的1-NN分類器進(jìn)行測(cè)試呛哟,該分類器用于大多數(shù)基準(zhǔn)。
將GDAN模型在SUN 匿沛,CUB扫责,aPY和AWA2數(shù)據(jù)集上與幾個(gè)baseline方法進(jìn)行比較。采用了廣泛使用的平均每類的top-1準(zhǔn)確性來(lái)評(píng)估每個(gè)模型的性能逃呼,定義如下:
在廣義零鏡頭學(xué)習(xí)設(shè)置中鳖孤,在測(cè)試階段,我們使用可見(jiàn)和不可見(jiàn)類的圖像抡笼,并且標(biāo)簽空間也是可見(jiàn)和不可見(jiàn)類的組合苏揣。 我們希望可見(jiàn)和不可見(jiàn)類的準(zhǔn)確性都盡可能高,因此我們需要一個(gè)可以反映模型整體性能的指標(biāo)推姻。使用諧波平均數(shù)平匈,令A(yù)ccYs和AccYu分別表示可見(jiàn)和不可見(jiàn)類別的圖像的準(zhǔn)確性,因此可見(jiàn)和不可見(jiàn)準(zhǔn)確性的諧波平均值H定義為:
我們將模型的CVAE藏古,回歸器和鑒別器實(shí)現(xiàn)為前饋神經(jīng)網(wǎng)絡(luò)增炭。 CVAE的編碼器具有兩個(gè)分別為1200和600個(gè)單位的隱藏層,而CVAE的生成器和鑒別器是由800個(gè)隱藏單位的一個(gè)隱藏層實(shí)現(xiàn)的拧晕。 回歸器只有600個(gè)單位的隱藏層隙姿。 對(duì)于所有數(shù)據(jù)集,噪聲向量z的維數(shù)均設(shè)置為100厂捞。 我們使用λ1=λ2=λ3= 0.1孟辑,發(fā)現(xiàn)它們通常工作良好。 我們選擇Adam作為優(yōu)化器蔫敲,動(dòng)量設(shè)為(0.9,0.999)。判別器的學(xué)習(xí)率設(shè)為0.00001炭玫,而CVAE和回歸器的學(xué)習(xí)率設(shè)為0.0001奈嘿。 diter和giter設(shè)置為1,這意味著我們模型中的所有模塊都以相同的批次數(shù)量進(jìn)行訓(xùn)練吞加。 我們對(duì)每個(gè)數(shù)據(jù)集訓(xùn)練500個(gè)epoch裙犹,每10個(gè)epoch保存一次模型檢查點(diǎn),然后對(duì)驗(yàn)證集進(jìn)行評(píng)估衔憨,以找到最佳的測(cè)試集叶圃。
結(jié)果:
雙重學(xué)習(xí)
可以同時(shí)訓(xùn)練一個(gè)主要任務(wù)和一個(gè)雙重任務(wù),其中雙重任務(wù)是主要任務(wù)的逆任務(wù)践图。 本文的工作與CycleGAN和DualGAN有關(guān)掺冠,因?yàn)閺乃鼈兡抢锝鑱?lái)了循環(huán)一致性損失。 但是码党,這兩個(gè)模型需要兩個(gè)生成網(wǎng)絡(luò)德崭,這使得它們不能直接應(yīng)用于廣義零鏡頭學(xué)習(xí)斥黑,因?yàn)槊總€(gè)類都有固定的語(yǔ)義表示,并且生成網(wǎng)絡(luò)不適合視覺(jué)→語(yǔ)義映射眉厨,因?yàn)樗赡軙?huì)產(chǎn)生很大的差異 語(yǔ)義特征锌奴。 因此,需要一種新穎的架構(gòu)將循環(huán)一致性納入零擊學(xué)習(xí)中憾股。
CVAE-GAN
本文在CVAE-GAN基礎(chǔ)上添加了回歸模型鹿蜀,輸入生成器的z并非隨機(jī)噪聲而是通過(guò)編碼器編碼的模擬所求類別的特征分布的向量。
Kullback-Leibler Divergence (KL 散度)
DKL(p||q)表示的就是概率 q與概率 p之間的差異服球,很顯然茴恰,散度越小,說(shuō)明 概率 q 與概率 p之間越接近有咨,那么估計(jì)的概率分布于真實(shí)的概率分布也就越接近琐簇。
代碼:http://www.github.com/stevehuanghe/GDAN
### 環(huán)境需要
- Python 3.6, PyTroch 0.4
- sklearn, scipy, numpy, tqdm
### 運(yùn)行步驟
1. 按照要求修改配置文件
2. 預(yù)訓(xùn)練 CVAE,checkpoints存在配置文件里指定的"vae_dir":
運(yùn)行python pretrain_gdan.py --config configs/cub.yml
3. 選擇想用的CVAE checkpoint去初始化GDAN模型 并修改 yaml文件中的"vae_ckpt"變量。
運(yùn)行python train_gdan.py --config configs/cub.yml訓(xùn)練GDAN
4.用驗(yàn)證數(shù)據(jù)決定使用哪個(gè)GDAN? checkpoint(在配置文件yaml中的指定路徑為"ckpt_dir" )
運(yùn)行python valtest_gdan.py --config configs/cub.yml測(cè)試和訓(xùn)練
訓(xùn)練過(guò)程座享,先預(yù)訓(xùn)練CVAE婉商,然后每個(gè)epoch先用G和D對(duì)抗訓(xùn)練D,再用G和R生成的假數(shù)據(jù)通過(guò)D訓(xùn)練G和R