前一陣子铆遭,偶然看到一個(gè)換臉的視頻,覺(jué)得實(shí)在是很神奇沿猜,于是饒有興致的去了解一下?lián)Q臉?biāo)惴度佟T瓉?lái)背后有一個(gè)極為有意思的算法思想——對(duì)抗生成。今天筆者斗膽來(lái)介紹一下在學(xué)術(shù)界大名鼎鼎的GAN(Generative Adversarial Networks ),此網(wǎng)絡(luò)結(jié)構(gòu)由Ian J. Goodfellow大神在2014年提出啼肩,一經(jīng)推出橄妆,就引爆了學(xué)術(shù)界。
隨后各種各樣的GAN算法以指數(shù)級(jí)增長(zhǎng)的方式涌現(xiàn)出來(lái)祈坠,比如WGAN(Wasserstein GAN)呼畸,CGAN(condition gan),SRGAN(super resolution gan)等颁虐。據(jù)說(shuō)后來(lái)提出的GAN在取名字簡(jiǎn)稱的時(shí)候——XXGAN蛮原,其中GAN的前面的XX,26個(gè)英文字母兩兩排列組合都快不夠用了另绩,這足以見(jiàn)得這個(gè)算法最近幾年的熱度儒陨。而GAN也有很多應(yīng)用場(chǎng)景:
- 高清圖片生成。
- 消除馬賽克笋籽。
- 側(cè)臉轉(zhuǎn)正等等蹦漠。
由于筆者只在稍微了解過(guò)圖像領(lǐng)域的GAN算法,所以只能說(shuō)出以上具體的應(yīng)用場(chǎng)景车海。不過(guò)據(jù)了解在自然語(yǔ)言處理領(lǐng)域GAN也可用了訓(xùn)練聊天機(jī)器人(chatbot)笛园。總之筆者感覺(jué)GAN這個(gè)算法如果用對(duì)了地方侍芝,還是能夠發(fā)揮出它的潛力的研铆。
GAN算法簡(jiǎn)介
GAN的結(jié)構(gòu)
首先我們簡(jiǎn)單了解一下最原始的GAN網(wǎng)絡(luò)結(jié)構(gòu),如下圖,主要只看分為黃色長(zhǎng)方形和粉紅色長(zhǎng)方形州叠,這兩部分為Network的部分:
- 一個(gè)生成器(粉紅的generator)棵红,
- 一個(gè)判別器(黃色的discriminator),
接下來(lái)注意了咧栗,我們仔細(xì)研究下這兩個(gè)網(wǎng)絡(luò)的輸入和輸出逆甜,同時(shí)了解一下這兩個(gè)網(wǎng)絡(luò)的關(guān)系:
- 生成器的輸入是隨機(jī)生成的噪聲向量,輸出是一張圖片(2維或者3維矩陣)
- 判別器的輸入是真實(shí)的圖片(2維或者3維矩陣)和生成器生成的圖片(2維或者3維矩陣)致板,輸出是0或者1交煞。
按照農(nóng)場(chǎng)文的普遍的講法,整個(gè)GAN做的事情就是類似于假畫(huà)師和鑒畫(huà)師之間的博弈斟或,是不是現(xiàn)在有點(diǎn)對(duì)抗(Adversarial )的意思了素征,其整個(gè)過(guò)程分為以下兩部分: - 生成器的訓(xùn)練(假畫(huà)師提高自己畫(huà)假畫(huà)的水平):生成的圖片能夠以假亂真欺騙判別器。
- 判別器的訓(xùn)練(鑒畫(huà)師提高自己鑒別假畫(huà)的水平):能夠鑒別出生成器生成的假圖片。
最后結(jié)果可想而之稚茅,這兩方在互相博弈之間,都得到了極大的提升平斩。判別器鑒別能力越來(lái)越強(qiáng)亚享,而生成器生成的圖片越來(lái)越像真的。最終我們拿到訓(xùn)練好的生成器绘面,隨機(jī)輸入一個(gè)噪聲向量給它欺税,它也能輸出一張以假亂真的圖片。
GAN的原理
筆者在這里不想講太多的原理部分揭璃,大家感興趣的可以去訪問(wèn)我的參考文獻(xiàn)部分晚凿,其中臺(tái)灣大學(xué)的李宏毅老師視頻和蘇劍林大神的博客中,將GAN講得很通俗易懂瘦馍。在數(shù)學(xué)理論方面筆者只強(qiáng)調(diào)一句話:GAN的訓(xùn)練目的是希望生成器生成的數(shù)據(jù)分布和真實(shí)數(shù)據(jù)里的分布越像越好歼秽,如下圖所示。
這里特別推薦一下李宏毅老師的課程情组,不需要你懂太高深的數(shù)學(xué)燥筷,也可以了解GAN和WGAN原理部分的精髓。
DCGAN的實(shí)戰(zhàn)部分
DCGAN
實(shí)戰(zhàn)部分筆者采用的是DCGAN院崇,這個(gè)DCGAN架構(gòu)規(guī)定一些搭建GAN網(wǎng)絡(luò)時(shí)的規(guī)則:
- 在生成網(wǎng)絡(luò)和判別網(wǎng)絡(luò)上必須使用批處理規(guī)范化肆氓。
- 對(duì)于更深的架構(gòu)移除全連接隱藏層。
- 在生成網(wǎng)絡(luò)的所有層上必須使用ReLU激活函數(shù)底瓣,除了輸出層使用Tanh激活函數(shù)谢揪。
- 在判別網(wǎng)絡(luò)的所有層上必須使用LeakyReLU激活函數(shù)。
這些規(guī)定主要是為了使優(yōu)化效果變得更好捐凭,沒(méi)有特別好的數(shù)學(xué)解釋拨扶。既然DCGAN做圖像生成效果好,那我們用起來(lái)吧茁肠。
載入數(shù)據(jù)
數(shù)據(jù)載入部分其實(shí)很簡(jiǎn)單屈雄,就是一堆開(kāi)通人物人臉的圖片独令,沒(méi)有l(wèi)abel缎谷。從下方代碼中可以看到,筆者本次實(shí)驗(yàn)圖片是528張shape為(96, 96, 3)的圖片屑埋,所以生成器需要生成的圖片矩陣維度就必須是(96奶赔,96惋嚎,3)。
import numpy as np
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import keras.backend as K
import matplotlib.pyplot as plt
import sys
from PIL import Image
import os
pic_list = os.listdir('./anime-faces/1boy')
pic_arr_list = []
###read anime-faces data from folder
for i in range(len(pic_list)):
t = Image.open("./anime-faces/1boy/{}".format(pic_list[I]))
t = np.array(t)
t = t/127.5 - 1
pic_arr_list.append(t)
### convert the picture data to array
train_data = np.array(pic_arr_list)
train_data.shape#(528, 96, 96, 3)
定義generator
注意DCGAN中生成網(wǎng)絡(luò)的所有層使用ReLU的激活函數(shù)站刑,除了輸出層使用Tanh激活函數(shù)另伍。而且每層不要忘了加BN層(作用是批處理規(guī)范化)。同時(shí)定義好輸入隨機(jī)噪聲的維度,這里筆者定義的是100維摆尝。輸出維度就是生成圖片的維度(96堕汞,96,3)讯检。
def build_generator():
model = Sequential()
model.add(Dense(128 * 24 * 24, activation="relu", input_dim=100))
model.add(Reshape((24, 24, 128)))
model.add(UpSampling2D())
model.add(Conv2D(128, kernel_size=3, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(UpSampling2D())
model.add(Conv2D(64, kernel_size=3, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(Activation("relu"))
model.add(Conv2D(3, kernel_size=3, padding="same"))
model.add(Activation("tanh"))
model.summary()
noise = Input(shape=(100,))
img = model(noise)
return Model(noise, img)
從上方的代碼和下方的生成器的結(jié)構(gòu)可視化中可以看出人灼,整個(gè)網(wǎng)絡(luò)的結(jié)構(gòu),以及輸入維度和輸出維度段磨。
定義discriminator
在DCGAN中定義discriminator時(shí)债蜜,判別網(wǎng)絡(luò)的所有層使用LeakyReLU的激活函數(shù)精耐。而且每層也需要加BN層進(jìn)行批處理規(guī)范化處理向胡。而判別器的輸入時(shí)一張(96小槐,96,3)的矩陣骡显,輸出則是0或者1遭殉。
def build_discriminator():
model = Sequential()
model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=(96,96,3), padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
model.add(ZeroPadding2D(padding=((0,1),(0,1))))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(256, kernel_size=3, strides=1, padding="same"))
model.add(BatchNormalization(momentum=0.8))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(1, activation='sigmoid'))
依然從上方代碼和下方模型可視化輸出中可以清晰看到蛔糯,判別器的網(wǎng)絡(luò)結(jié)構(gòu)淮逻,以及輸入筛严,輸出檬输。總之判別器就是為了判斷輸入圖片是真實(shí)圖片還是生成圖片昼浦。
聯(lián)系生成器和判別器
這部分就是定義GAN的最關(guān)鍵部分关噪,我們需要讓生成器和判別器聯(lián)系起來(lái)鸟蟹。下面部分代碼有兩點(diǎn)需注意:
- 將生成器生成的圖片輸入給判別器建钥,
- 此時(shí)判別器不做訓(xùn)練镐依,只訓(xùn)練生成器灼卢。
到這一步有的同學(xué)就會(huì)問(wèn)了鞋真,為啥不訓(xùn)練判別器呢崇堰?別急,對(duì)抗的過(guò)程(判別器訓(xùn)練一步檩互,生成器訓(xùn)練一步)在GAN的訓(xùn)練中才會(huì)體現(xiàn)出來(lái)特幔。
optimizer = Adam(0.0002, 0.5)
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
generator = build_generator()
z = Input(shape=(100,))
#feed the random noise to the generator
img = generator(z)
# For the combined model we will only train the generator
discriminator.trainable = False
# The discriminator takes generated images as input and determines validity
valid_g = discriminator(img)
# The combined model (stacked generator and discriminator)
# Trains the generator to fool the discriminator
combined = Model(z, valid_g)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)
combined.summary()
訓(xùn)練GAN
接下來(lái)接可以開(kāi)始DCGAN的訓(xùn)練了,這里代碼的含義是先訓(xùn)練一步判別器闸昨,在訓(xùn)練一步生成器蚯斯,二者互相博弈薄风,互相進(jìn)步。
for epoch in range(epochs):
# ---------------------
# Train Discriminator
# ---------------------
# Select a random half of images
idx = np.random.randint(0, train_data.shape[0], batch_size)
imgs = train_data[idx]
# Sample noise and generate a batch of new images
noise = np.random.normal(0, 1, (batch_size, 100))
gen_imgs = generator.predict(noise)
# Train the discriminator (real classified as ones and generated as zeros)
d_loss_real = discriminator.train_on_batch(imgs, valid_d)
d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
# ---------------------
# Train Generator
# ---------------------
# Train the generator (wants discriminator to mistake images as real)
g_loss = combined.train_on_batch(noise, valid_d)
# Plot the progress
print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
# If at save interval => save generated image samples
if epoch % save_interval == 0:
r, c = 5, 5
noise_save = np.random.normal(0, 1, (r * c, 100))
gen_imgs = generator.predict(noise_save)
# gen_imgs = 0.5* gen_imgs + 0.5
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i,j].imshow(gen_imgs[cnt, :,:,0], )
axs[i,j].axis('off')
cnt += 1
fig.savefig("/mnt/disk2/data/wp/test/images/boy_%d.png" % epoch)
plt.close()
從下圖模型輸出的損失函數(shù)中我們可以看到判別器和生成器博弈的過(guò)程拍嵌,判別器的D loss先下降后上升(當(dāng)然筆者的實(shí)驗(yàn)不是很明顯)遭赂,代表著判別器本先變強(qiáng)導(dǎo)致D loss下降,之后生成器開(kāi)始發(fā)力横辆,生成質(zhì)量更好的圖片撇他,使得D loss上升。判別器的D loss應(yīng)該是一個(gè)跌宕起伏的曲線狈蚤。這里筆者的D loss 很小困肩,說(shuō)明判別器太強(qiáng)大了,其實(shí)在GAN訓(xùn)練的過(guò)程中脆侮,任意一方太強(qiáng)锌畸,都會(huì)導(dǎo)致模型訓(xùn)練效果不好,比較GAN是個(gè)相互進(jìn)步他嚷,相互促進(jìn)的過(guò)程蹋绽,任何一方太強(qiáng)都會(huì)導(dǎo)致大家無(wú)法進(jìn)步芭毙。
筆者在這里輸出了模型跑了500個(gè)epoch和5000個(gè)epoch之后生成器生成的圖像效果對(duì)比筋蓖。
從兩張圖的生成效果上來(lái)說(shuō),5000個(gè)epoch時(shí)退敦,生成器生成的圖片質(zhì)量更好一些粘咖,已經(jīng)能夠可看出卡通人物臉清晰的輪廓曲線了。
使用generator生成圖片
pic = generator.predict(np.random.normal(0, 1, (1, 100)))
plt.imshow(np.squeeze(pic[0]))
最終筆者在模型訓(xùn)練好之后侈百,運(yùn)行上方代碼瓮下,給生成器隨機(jī)輸入一個(gè)100維的向量,生成下方那個(gè)綠頭發(fā)的卡通人臉钝域》砘担看起來(lái)效果還不錯(cuò)耶。是不是很神奇
結(jié)語(yǔ)
GAN確實(shí)是個(gè)很有趣的結(jié)構(gòu)例证,對(duì)抗生成的思想很像我們?nèi)祟惿鐣?huì)中的棋逢對(duì)手的情況路呜。在足球界梅西和C羅,正是因?yàn)閷?duì)方的存在而促使對(duì)方努力织咧,互相進(jìn)步胀葱,形成絕代雙驕的局面,在金庸的武俠世界老頑童周伯通發(fā)明來(lái)雙手互搏來(lái)提升功力笙蒙,而GAN正是使用這種對(duì)抗的方式學(xué)習(xí)進(jìn)步抵屿。在人類世界獨(dú)孤求敗有時(shí)候也是很悲哀的一種情景,這也暗合了在train GAN時(shí)一定要保持判別器和生成器實(shí)力相當(dāng)捅位,不然你trian出來(lái)的GAN肯定很糟轧葛。
參考
https://spaces.ac.cn/archives/6240
http://speech.ee.ntu.edu.tw/~tlkagk/courses_MLDS18.html
https://arxiv.org/pdf/1511.06434.pdf
https://arxiv.org/abs/1406.2661
https://github.com/eriklindernoren/Keras-GAN