前面幾章學習了使用JAX進行計算機視覺、自然語言處理等方面深度學習任務吱型,可以看到基于JAX的深度學習框架能夠較好地完成這些常見任務。本章講學習一種較為特殊的網絡,生成對抗網絡(Generative Adverbial Network痴施,GAN)。
生成對抗網絡究流,是一種包含兩個網絡的神經網絡結辣吃,將一個網絡與另一個網絡相互對立(因此稱為“對抗”)。
從目前對GAN的研究和應用來看芬探,GAN的潛力巨大神得,因為它能學習和模仿任何數據分布,因此偷仿,GAB能被教導在任何領域創(chuàng)造類似于真實世界的東西哩簿,比如圖像、音樂酝静、演講节榜、散文等。在某種意義上别智,GAN可以被視為一個機器人藝術家全跨,它而“創(chuàng)作”出令人印象深刻,甚至打動人類的作品亿遂。
GAN工作原理
為了理解GAN浓若,需要知道GAN是如何工作的渺杉。實際上GAN的組成和工作原理非常簡單,
其中挪钓,
- 生成器是越。生成器學習如何生成看似合理的數據。對判別器來說碌上,這些生成的實例會變成負面訓練樣本倚评。
- 判別器。判別器學習如何通過真實數據的學習來辨別出生成器生成的假數據馏予。判別器將懲罰生成器生成的“不可理(假)”的數據結果天梧。
GAN訓練的整個步驟如下,
- 當訓練開始霞丧,生成器生成一個些很明顯的假數據呢岗,判別器能快速地學習如何識別出是不是假數據,
- 隨著訓練穩(wěn)步推進后豫,生成器更接近于能生成蒙騙判別器的輸出數據。
- 最終挫酿,如果生成器訓練得當,在識別真實和虛假方面愕难,判別器變得差強人意早龟,而且將開始把假數據分類為真實數據,識別的準確率降低猫缭。
GAN是一種生成式的對抗網絡。具體來說饵骨,就是通過對抗的方式去學習數據分布的生成式模型翘悉。所謂對抗茫打,指的是生成網絡和判別網絡的互相對抗居触。生成網絡盡可能生成逼真的樣本,判別網絡這盡可能地去判別該樣本是真實的樣本還是生成的假樣本老赤。
生成對抗網絡整個架構如下所示轮洋,
生成器和判別器兩者都是神經網絡。生成器的輸出直接連接到判別器的輸入抬旺。通過反向傳播弊予,判別器的分類結果給提供了一個信號讓生成器更新其權重。
生成器和判別器共同構成GAN
生成器(Generator)與判別器(Discriminator)共同構成一個GAN开财。再介紹GAN之前先對生成器和辨別器的作用做一個詳細的解釋汉柒。
判別器
對于判別器來說误褪,給它一幅畫,判別器中的判別算法能夠判別這幅畫是不是由真人完成的碾褂。畫的真假是給予判別器的標簽一致兽间,而這幅畫本身的向量特征就組成了輸入的特征向量。
把上述語句用數學形式表示出來正塌,標簽被定義成為y嘀略,特征向量被定義成x,那么判別器的判定公式為乓诽,
也就是在輸入的x特征向量的基礎上定義出y的概率帜羊。在這個判別器的例子中,輸入向量也就是畫的特征被定義成x鸠天,而判別器對畫的判定則是y讼育,即判別器對這幅畫判定真?zhèn)蔚母怕省R虼肆竿穑袆e算法將特征映射為概率窥淆,判別器只關心其中的特征是夠滿足概率生成的條件。
生成器
與判別器的做法正好相反巍杈,它不關心向量是什么形式和內容忧饭,它只關心標簽信息,嘗試由給定的標簽內容去生成特征筷畦。同樣以畫為例词裤,生成器需要考慮大的是,假定這個畫是由真實畫家完成的鳖宾,那么這個畫包含哪些畫家的特征信息吼砂,這些信息又是什么樣的,怎么將其展示出來讓“別人(判別器)”認為這幅畫是畫家本人的真跡鼎文。這和人類思考的過程相類似渔肩。判別器關心的是由x‘判斷出y,而生成器關心的是如何生成一個x去滿足y的判定拇惋。用公式表示如下周偎,
兩者區(qū)別總結如下,
- 判別器撑帖,學習不同類別和標簽之間的區(qū)分界限蓉坎。
- 生成器,學習標簽中的某一類的概率分布并進行建模胡嘿。
GAN是如何工作的
判別器如何工作
判別器在生成對抗網絡中蛉艾,簡單來說是一個分類器。該分類器嘗試從生成器生成的假數據中識別真實數據。它可以使用任何適用于數據分類的網絡架構勿侯。判別器在訓練中使用誤差反向傳播機制來計算損失和更新權重參數拓瞪,
從上圖可以看到,判別器的訓練數據有兩處來源助琐,
- 真實數據吴藻。真實數據實例,比如人的照片弓柱。在訓練中沟堡,判別器把這些實例用作正面樣本。
- 虛假數據矢空。生成器生成的實例航罗。在訓練中,判別器把這些實例用作負面樣本屁药。
上圖中兩個Sample的框就是這兩種輸入到判別器的樣本粥血。注意,在判別器訓練時酿箭,生成器不會訓練复亏,即在生成器為判別器生成示例數據時,生成器的權重保持恒定缭嫡。
在訓練判別器時缔御,判別器連接到兩個損失函數。在訓練時妇蛀,判別器忽略生成器的損失而只使用判別器損失耕突。在訓練過程中,
- 判別器對真實數據和來自生成器生成的假數據進行分類评架。
- 判別器的損失函數將懲罰由判別器產生的誤判眷茁,比如把真實實例判定成假,或者把假的實例判定為真纵诞。
- 判別器通過對來自于判別器網損失函數計算的損失進行反向傳播上祈。如上圖。
下面介紹為什么生成器的損失函數直接連接到判別器浙芙。
生成器如何工作
生成對抗網絡里的生成器登刺,通過接受來自于判別器的反饋來學習如何創(chuàng)建假數據。生成器學習如何讓(欺騙)判別器把它的輸出歸類為真實數據茁裙。
相對于判別器的訓練塘砸,生成器的訓練要求生成器與判別器有更加緊密的集成节仿。生成器訓練包含晤锥,
隨機輸入。神經網絡需要某種形式的輸入。通常矾瘾,為了達到某種目的而輸入數據女轿,比如一個輸入的實例用來進行分類任務或者預測。但當希望輸出一整個全新的數據實例壕翩,用什么樣輸入數據呢蛉迹?
最常見基礎形式里,GAN使用隨機噪音作為它的輸入放妈。然后北救,生成器將把隨機噪音轉換成有意義的輸出。通過引入噪音芜抒,可以從不同分布形式的不同空間采樣珍策,讓GAN生成一個寬域的數據,
實驗結果表明宅倒,不同噪音的分布不會產生太大影響攘宙。因此,可以選擇相對較易的采樣來源拐迁,比如蹭劈,均勻分布。方便起見线召,噪音采樣空間的維度一般小于輸出空間的維度铺韧。
注意,有些GAN變種不使用隨機輸入來形成輸出缓淹。
生成器網絡祟蚀,負責把隨機輸入轉換成數據實例。
判別器網絡割卖,負責把上一步生成的數據歸類前酿。
判別器輸出。
生成器的損失函數鹏溯,負責懲罰企圖蒙騙判別器失敗的情況(即生成器生成的假數據罢维,被識別器成功識破)。
使用判別器訓練生成器
要訓練神經網絡丙挽,通過修改網絡的權重來減少誤差或者輸出的損失肺孵。在GAN里卻不同,生成器不直接連接到損失函數來試圖影響損失颜阐,而是把生成的數據輸出到判別器平窘,而判別器會制造影響誤差損失的輸出。當生成器生成的數據被判別器成功識別成仿冒時凳怨,生成器損失函數會懲罰生成器瑰艘。
另外是鬼,反向傳播里也包含網絡的額外處理。反向傳播通過計算對輸出的影響——更改后的權重在多大程度上影響輸紫新,來調整每個權重以使其在正確的方向上均蜜。但,生成器權重的影響取決于直接輸出到判別器的權重的影響芒率。因此囤耳,反向傳播始于輸出且穿過判別器回流到生成器。
在生成器訓練時偶芍,不希望判別器更改充择,就像嘗試擊中一個移動目標,會讓一個本身就麻煩的問題變得更加困難匪蟀。所以聪铺,在訓練生成器時使用如下流程,
- 隨機噪音采樣作為輸入萄窜。
- 生成器從采樣的隨機噪音采樣里生成輸出铃剔。
- 讓判斷器判斷上述輸出是“真”或“假”,以此作為生成器的輸出查刻。
- 從判別器的分類輸出計算誤差損失键兜。
- 穿過判別器和生成器的反向傳播,從而獲得梯度穗泵。
- 使用梯度來更新生成器的權重普气。
這個流程是生成器訓練的一個迭代。下面會玩轉整個生成器和判別器佃延。
訓練GAN
因為GAN包含兩個單獨經訓練的網絡现诀,GAN的訓練算法必須解決兩個難題,
- GAN必須能勝任來個不同的訓練(生成器和判別器)履肃。
- GAN的趨同難以識別仔沿。
交替訓練
生成器和判別器有不同的訓練流程,那么如何才能作為一個整體來訓練GAN呢尺棋?GAN的訓練有交替階段封锉,
- 判別器訓練一個或者多個迭代。
- 生成器訓練一個或者多個迭代膘螟。
- 不斷重復1和2步來訓練生成器和判別器成福。
在判別器訓練的階段,保持生成器不變荆残。因為判別器訓練會嘗試從仿冒數據里分辨出真實數據奴艾,判別器必須學習如何識別生成器的缺陷。這就是經過完整訓練的生成器和只能生成隨機輸出的未訓練生成器的不同之處内斯。
類似地蕴潦,在生成器訓練的階段像啼,保持判別器不變。否則品擎,生成器像嘗試擊中移動目標一樣,可能永遠無法收斂备徐。
這種往復訓練使得GAN能夠處理另外一些棘手的數據生成問題萄传。開始于相對較簡單的分類任務問題,從而獲得一個解決生成難題的立足點蜜猾。相反地秀菱,如果不能訓練一個分類器來識別真實數據和生成數據的區(qū)別,甚至無法識別與隨機初始化輸出的區(qū)別蹭睡,那么GAN的訓練根本無法開始衍菱。
收斂
隨著訓練進行,生成器不斷改善肩豁,相對地脊串,由于不能再輕易地識別出真實數據和假冒數據的,判別器的表現(xiàn)越來越差清钥。當生成器完美地成功生成時琼锋,判別器的正確率只有50%∷钫眩基本上和拋一枚硬幣來預測正反一樣的概率一樣缕坎。
這種進度展現(xiàn)了作為整體GAN的一個問題,判別器的反饋隨著時間的推移越來越不具有意義篡悟。過了這個節(jié)點之后谜叹,即判別器完全給出了隨機反饋,如果繼續(xù)訓練GAN搬葬,那么生成器將使用判別器給出的無效反饋進行訓練荷腊,那么生成器的質量可能會崩塌。
對于GAN來說急凰,收斂往往是一個閃現(xiàn)的點停局,而不是牢固的、穩(wěn)態(tài)的香府。
如何理解生成對抗網絡
簡單來來說董栽,GAN的功罪哦原理就是使用生成器去生成新的惡具有一定特征的向量內容,并且將生成的向量內容輸出入到判別器中去對去進行驗證企孩,評估這些向量內容為真或者假的概率锭碳。
比如假鈔制造,以及假鈔的識別勿璃,如下圖擒抛,
另外推汽,信用結構,比如銀行的業(yè)務中歧沪,手寫字作為交易的依據是最常見的一種存根方式歹撒,而往往有人就是通過模仿別人的手寫數字進行詐騙,特別是在銀行領域诊胞,毛領支票的事件層出不窮暖夭,
在上面兩種場景中,生成器的作用就是根據標簽的類別進行特征生成撵孤,最終生成具有真實特征迈着,比如紙幣或手寫特征的一系列圖片,即向量數據邪码,而判別器的目標就是當其被展示一個紙幣或者手寫字是能識別出真或者假裕菠。
在這個過程中,GAN所采取的步驟如下闭专,
- 生成器使用隨機數生成一幅圖奴潘。
- 這幅圖和真實數據集的圖片流一起被送到判別器。
- 判別器接收真是的和仿冒的圖片影钉,然后返回概率萤彩。
可以把GAN想象成貓鼠游戲里的偽造者和警察的角色,偽造者不斷學習假冒票據斧拍,警察在學習如何檢測它們雀扶。雙方都是動態(tài)的,警察也在訓練肆汹,并且雙方在不斷升級中學習對方的方法愚墓。
需要強調的是,在這個過程中生成器和判別器是一個循環(huán)過程昂勉,隨著生成器和判別器能力的提升浪册,其對應的生成和判別能力也越來越強。這樣實際上就構成了一個反饋鏈:
- 判別器和圖片標簽構成一個反饋岗照。
- 生成器和判別器構成一個反饋村象。
結論
本章介紹了生成對抗網絡的原理和工作機制,特別是生成器和判別器的原理及訓練過程攒至。