簡(jiǎn)介
生成對(duì)抗網(wǎng)絡(luò)(以下簡(jiǎn)稱(chēng)GAN)是通過(guò)讓兩個(gè)神經(jīng)網(wǎng)絡(luò)相互博弈的方式進(jìn)行學(xué)習(xí)沿量,可以根據(jù)原有的數(shù)據(jù)集生成以假亂真的新的數(shù)據(jù)乌助,舉個(gè)不是很恰當(dāng)?shù)睦釉搪拢?lèi)似于造假鞋,莆田藝術(shù)家通過(guò)觀(guān)察真鞋乳讥,模仿真鞋的特點(diǎn)造出假鞋并賣(mài)給消費(fèi)者,消費(fèi)者收到鞋子后將它與網(wǎng)上的真鞋信息進(jìn)行對(duì)比找瑕疵廓俭,并給出反饋云石,比如標(biāo)不正,氣墊彈性不好研乒,莆田藝術(shù)家根據(jù)消費(fèi)者給出的反饋積極地改進(jìn)工藝汹忠,經(jīng)過(guò)不懈努力后最終造出了可以忽悠消費(fèi)者的假鞋。
在上述情景中雹熬,莆田藝術(shù)家相當(dāng)于生成器宽菜,消費(fèi)者相當(dāng)于辨別器,在造假的過(guò)程中竿报,生成器和判別器一直處于對(duì)抗?fàn)顟B(tài)铅乡。
我們把上述情景抽象為神經(jīng)網(wǎng)絡(luò)。首先仰楚,通過(guò)對(duì)生成器輸入一個(gè)分布的數(shù)據(jù)隆判,生成器通過(guò)神經(jīng)網(wǎng)絡(luò)模仿生成出一個(gè)輸出(假鞋),將假鞋與真鞋的信息共同輸入到判別器中僧界。然后侨嘀,判別器通過(guò)神經(jīng)網(wǎng)絡(luò)學(xué)著分辨兩者的差異,做一個(gè)分類(lèi)判斷出這雙鞋是真鞋還是假鞋捂襟。
這樣咬腕,生成器不斷訓(xùn)練為了以假亂真,判別器不斷訓(xùn)練為了區(qū)分二者葬荷。最終涨共,生成器真能完全模擬出與真實(shí)的數(shù)據(jù)一模一樣的輸出纽帖,判別器已經(jīng)無(wú)力判斷【俜矗基于伊恩·古德費(fèi)洛最早對(duì) GAN 的定義懊直,GAN 實(shí)際上是在完成這樣一個(gè)優(yōu)化任務(wù):
式中, 表示生成器火鼻; 表示判別器室囊; 是定義的價(jià)值函數(shù),代表判別器的判別性能魁索,該數(shù)值越大性能越好融撞; 表示真實(shí)的數(shù)據(jù)分布; 表示生成器的輸入數(shù)據(jù)分布粗蔚; 表示期望尝偎。
第一項(xiàng) 是依據(jù)真實(shí)數(shù)據(jù)的對(duì)數(shù)函數(shù)損失而構(gòu)建的。具體可以理解為鹏控,最理想的情況是致扯,判別器 能夠?qū)谡鎸?shí)數(shù)據(jù)的分布數(shù)據(jù)給出 1
的判斷。所以牧挣,通過(guò)優(yōu)化 最大化這一項(xiàng)可以使 急前。其中, 服從 分布瀑构。
第二項(xiàng)裆针,,是相對(duì)生成器的生成數(shù)據(jù)而言的寺晌。我們希望世吨,當(dāng)喂給判別器的數(shù)據(jù)是生成器的生成數(shù)據(jù)時(shí),判別器能輸出 0
呻征。由于 的輸出是耘婚,輸入數(shù)據(jù)是真實(shí)數(shù)據(jù)的概率,那么 是陆赋,輸入數(shù)據(jù)是生成器生成數(shù)據(jù)的概率沐祷,通過(guò)優(yōu)化 最大化這一項(xiàng),則可以使 攒岛。其中赖临, 服從 ,也就是生成器的生成數(shù)據(jù)分布灾锯。
生成器與判別器是對(duì)抗的關(guān)系兢榨,價(jià)值函數(shù)代表了判別器的判別性能。那么,通過(guò)優(yōu)化 能夠在第二項(xiàng) 上迷惑判別器吵聪,讓判別器對(duì)于 這個(gè)輸入凌那,盡可能地得到 。本質(zhì)上吟逝,生成器就是在最小化這一項(xiàng)帽蝶,也就是在最小化價(jià)值函數(shù)。
散度
為了界定兩個(gè)數(shù)據(jù)分布澎办,也就是真實(shí)數(shù)據(jù)和生成器生成數(shù)據(jù)之間的差異嘲碱,需要引入 散度。
散度具有非負(fù)性局蚀。
當(dāng)且僅當(dāng) , 在離散型變量下是相同的分布時(shí)恕稠,即 琅绅,。
散度衡量了兩個(gè)分布差異的程度鹅巍,經(jīng)常被視為兩種分布間的距離千扶。
要注意的是,骆捧,即 散度沒(méi)有對(duì)稱(chēng)性澎羞。
最優(yōu)判別器
將價(jià)值函數(shù)里的生成器固定不動(dòng),將期望寫(xiě)成積分的形式有:
整個(gè)式子中敛苇,只有一個(gè)變量 妆绞。次數(shù),對(duì)被積函數(shù)枫攀,令 括饶,,来涨,图焰, 均為常數(shù)。那么蹦掐,被積函數(shù)變?yōu)椋?br>
為了找到最優(yōu)值 技羔,需要對(duì)上式求一階導(dǎo)數(shù)。而且卧抗,在 的情況下有:
驗(yàn)證 的二階導(dǎo)數(shù) ,則 這個(gè)點(diǎn)為極大值颗味,這個(gè)事實(shí)給出了最優(yōu)判別器的存在可能性超陆。
盡管在實(shí)踐中我們并不知道 ,也就是真實(shí)的數(shù)據(jù)的分布。但我們?cè)诶蒙疃葘W(xué)習(xí)訓(xùn)練判別器時(shí)时呀,可以讓 向這個(gè)目標(biāo)逐漸逼近张漂。
最優(yōu)生成器
若最優(yōu)的判別器為:
我們將其代入 ,此時(shí)價(jià)值函數(shù)里只有 這一個(gè)變量:
此時(shí)谨娜,通過(guò)變換航攒,我們可以得到下面的式子:
這個(gè)變換比較復(fù)雜,大家可以檢驗(yàn)步與步之間的恒等性判斷趴梢。根據(jù)對(duì)數(shù)的一些基本變換漠畜,可以得到:
最終得到:
因?yàn)? 散度的非負(fù)性,那么就可以知道 就是 的最小值坞靶,而且最小值是在當(dāng)且僅當(dāng) 時(shí)取得憔狞。這其實(shí)就是真實(shí)數(shù)據(jù)分布等于生成器的生成數(shù)據(jù)分布,可以從數(shù)學(xué)理論上證明了它的存在性和唯一性彰阴。
GAN的實(shí)現(xiàn)過(guò)程
生成器的輸入:即上面的 瘾敢,我們當(dāng)然不能讓這個(gè)分布任意化,一般會(huì)設(shè)為常見(jiàn)的分布類(lèi)型尿这,如高斯分布簇抵、均勻分布等等,然后生成器基于這個(gè)分布產(chǎn)生的數(shù)據(jù)生成自己的偽造數(shù)據(jù)來(lái)迷惑判別器射众。
期望如何模擬:實(shí)踐中碟摆,我們是沒(méi)有辦法利用積分求數(shù)學(xué)期望的,所以一般只能從無(wú)窮的真實(shí)數(shù)據(jù)和無(wú)窮的生成器中采樣以逼近真實(shí)的數(shù)學(xué)期望叨橱。
近似價(jià)值函數(shù):若給定生成器 典蜕,并希望計(jì)算 以求得判別器 。那么雏逾,首先需要從真實(shí)的數(shù)據(jù)分布 中采樣 個(gè)樣本 {}嘉裤。并從生成器的輸入,即 中采樣 個(gè)樣本 {}栖博。因此屑宠,最大化價(jià)值函數(shù) 就可以使用以下表達(dá)式近似替代:
可以把 GAN 的訓(xùn)練過(guò)程總結(jié)為:
- 從真實(shí)數(shù)據(jù) 采樣 個(gè)樣本 {};
- 從生成器的輸入仇让,即噪聲數(shù)據(jù) 采樣 個(gè)樣本 {}典奉;
- 將噪聲樣本 {} 投入到生成器中生成{};
- 通過(guò)梯度上升的方法丧叽,極大化價(jià)值函數(shù)卫玖,更新判別器的參數(shù);
- 從生成器的輸入踊淳,即噪聲數(shù)據(jù) 另外采樣 個(gè)樣本{};
- 將噪聲樣本 {} 投入到生成器中生成 {};
- 通過(guò)梯度下降的方法假瞬,極小化價(jià)值函數(shù)陕靠,更新生成器的參數(shù)。
利用PyTorch搭建GAN生成手寫(xiě)識(shí)別數(shù)據(jù)
安裝GPU版本PyTorch
-
打開(kāi)終端脱茉,在conda 配置中添加清華源
-
編輯~/.condarc剪芥,將- defaults整行刪除
- 安裝PyTouch GPU版本
使用conda安裝,不用自己額外配置依賴(lài)包和版本兼容問(wèn)題琴许,conda會(huì)自動(dòng)配置好税肪,而且可以直接在jupyter中調(diào)用,非常方便榜田。
一般需要等待很長(zhǎng)時(shí)間益兄,而且會(huì)經(jīng)常中斷,中斷直接再重復(fù)運(yùn)行安裝命令即可箭券,會(huì)繼續(xù)安裝之前沒(méi)裝上的
得益于國(guó)內(nèi)無(wú)與倫比的網(wǎng)絡(luò)環(huán)境净捅,100Mb的寬帶完全失靈,下載了大概一個(gè)小時(shí)辩块,中途中斷了三四次灸叼,終于裝好了!!我感覺(jué)天快亮了... ...
訓(xùn)練GAN
為了方便可視化庆捺,直接用jupyter notebook
-
首先,導(dǎo)入需要用的模塊
-
下載并解壓mnist數(shù)據(jù)集
transform
函數(shù)允許我們把導(dǎo)入的數(shù)據(jù)集按照一定規(guī)則改變結(jié)構(gòu)屁魏,我們?cè)谶@里引入了Normalize
將會(huì)把Tensor
正則化滔以。即:Normalized_image=(image-mean)/std
。這樣做的目的是便于后續(xù)的訓(xùn)練氓拼。
-
接下來(lái)你画,搭建深度學(xué)習(xí)模型,用于構(gòu)建判別器和生成器桃漾。這里通過(guò)引入
nn.Module
基類(lèi)的方法來(lái)搭建
判別器構(gòu)建過(guò)程坏匪,遵照 PyTorch 的 Sequential 網(wǎng)絡(luò)搭建法。我們用4
層網(wǎng)絡(luò)結(jié)構(gòu)撬统,并把每層都使用全連接配上LeakyReLU
激活再帶上dropout
防止過(guò)擬合适滓。最后一層,用sigmoid
保證輸出值是一個(gè)0
到1
之間的概率值恋追。設(shè)計(jì)前饋過(guò)程函數(shù)時(shí)凭迹,注意把每個(gè)樣本大小 的輸入矩陣先轉(zhuǎn)換為784
的向量用于全連接。
-
接下來(lái)構(gòu)建生成器苦囱。本模型中的設(shè)定生成器的每個(gè)輸入樣本是大小為
100
的向量嗅绸,通過(guò)全連接層配上LeakyReLU
激活搭建,最后一層用tanh
激活撕彤,且保證每個(gè)樣本輸出是一個(gè)784
的向量鱼鸠。
-
接下來(lái)實(shí)例化生成器與判別器,設(shè)定學(xué)習(xí)率和損失函數(shù)。價(jià)值函數(shù)按照定義是:
PyTorch 中蚀狰,BCELoss 表示二項(xiàng) Cross Entropy愉昆,它的展開(kāi)形式是:
其中y
是label
,x
是輸出造锅。那么撼唾,對(duì)于0
和1
這兩種label
而言,當(dāng) 哥蔚,上式第一項(xiàng)不存在倒谷,就剩下 的第二項(xiàng)。當(dāng) 糙箍,上式第二項(xiàng)不存在渤愁,就剩下 的第一項(xiàng)。那么 BCELoss 的結(jié)構(gòu)就與損失函數(shù) 相同深夯,只不過(guò)我們定義的損失函數(shù)有對(duì)真實(shí)數(shù)據(jù)與對(duì)生成器生成的數(shù)據(jù)兩種情況的輸出抖格。
-
接下來(lái),就可以定義如何訓(xùn)練判別器了咕晋。值得注意的是雹拄,這里需要設(shè)置
zero_grad()
來(lái)消除之前的梯度,以免造成梯度疊加掌呜。此外滓玖,我們通過(guò)將真實(shí)數(shù)據(jù)的損失和偽造數(shù)據(jù)的損失兩部分相加,作為最終的損失函數(shù)质蕉。然后势篡,通過(guò)后向傳播,用之前的判定器優(yōu)化器優(yōu)化模暗,通過(guò)降低 BCELoss 來(lái)增大價(jià)值函數(shù)的值禁悠。
-
同樣,接下來(lái)需要定義生成器的訓(xùn)練方法兑宇。注意碍侦,這里的
real_labels
在之后將設(shè)為1
。因?yàn)閷?duì)于所有的生成器輸出顾孽,我們希望它向真實(shí)的數(shù)據(jù)分布學(xué)習(xí)祝钢,那么BCELoss
此時(shí)為 。最終若厚,我們希望判別器的輸出 接近于1
拦英,即判別器判斷該數(shù)據(jù)為真實(shí)數(shù)據(jù)的概率越大。所以测秸,這里依舊是在減少BCELoss
疤估,則直接調(diào)用criterion
就可以設(shè)定好生成器的損失函數(shù)灾常。
100
大小的向量,這里就將生成器的輸入產(chǎn)生一個(gè)100
大小铃拇,且服從標(biāo)準(zhǔn)正態(tài)分布的向量钞瀑。
-
一切準(zhǔn)備就緒,開(kāi)始 GAN 的訓(xùn)練慷荔。
以下是剛開(kāi)始產(chǎn)生的圖片
GAN的改進(jìn)
相比起卷積神經(jīng)網(wǎng)絡(luò)之于計(jì)算機(jī)視覺(jué)贷岸,循環(huán)神經(jīng)網(wǎng)絡(luò)之于自然語(yǔ)言處理,GAN 尚且沒(méi)有一個(gè)特別適合的應(yīng)用場(chǎng)景磷雇。主要原因是 GAN 目前還存在諸多問(wèn)題偿警。例如:
- 不收斂問(wèn)題:GAN 是兩個(gè)神經(jīng)網(wǎng)絡(luò)之間的博弈。試想唯笙,如果判別器提前學(xué)到了非常強(qiáng)的螟蒸,那么生成器很容易出現(xiàn)梯度消失而無(wú)法繼續(xù)學(xué)習(xí)。所有 GAN 的收斂性一直是個(gè)問(wèn)題崩掘,這樣也導(dǎo)致 GAN 在實(shí)際搭建過(guò)程中對(duì)各種超參數(shù)都非常敏感七嫌,需要精心設(shè)計(jì)才能完成一次訓(xùn)練任務(wù);
- 崩潰問(wèn)題:GAN 模型被定義為一個(gè)極小極大問(wèn)題苞慢,可以說(shuō)抄瑟,GAN 沒(méi)有一個(gè)清晰的目標(biāo)函數(shù)。這樣會(huì)非常容易導(dǎo)致枉疼,生成器在學(xué)習(xí)的過(guò)程中開(kāi)始退化,總是生成相同的樣本點(diǎn)鞋拟,而這也進(jìn)一步導(dǎo)致判別器總是被喂給相同的樣本點(diǎn)而無(wú)法繼續(xù)學(xué)習(xí)骂维,整個(gè)模型崩潰;
- 模型過(guò)于自由: 理論上贺纲,我們希望 GAN 能夠模擬出任意的真實(shí)數(shù)據(jù)分布航闺,但事實(shí)上,由于我們沒(méi)有對(duì)模型進(jìn)行事先建模猴誊,再加上「真實(shí)分布與生成分布的樣本空間并不完全重合」是一個(gè)極大概率事件潦刃。那么,對(duì)于較大的圖片懈叹,如果像素一旦過(guò)多乖杠,GAN 就會(huì)變得越來(lái)越不可控,訓(xùn)練難度非常大澄成。