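Before starting, please make sure you understand how a GAN works. Many bloggers already explain it well, so I will not go into detail here; for videos, I recommend the course by Professor Hung-yi Lee (李宏毅) of National Taiwan University.
A GAN has two main components: a generator and a discriminator. The generator produces fake data to "fool" the discriminator, and the discriminator judges whether its input was produced by the generator. The two are trained in alternation until the generator produces data that can pass for real. Below, the MNIST dataset is used as an example, and a GAN is used to generate handwritten digits.
Building the network model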
1.generator
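A neural network only produces an output when it is given an input, so to obtain generated (fake) data we have to feed the model something. Here the input is a noise vector of shape [100,], and the output is an image of shape [28,28,1].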
def build_generator(self):
    # input shape = [100,]
    # output shape = [np.prod(self.img_shape)]
    model = Sequential()
    model.add(Dense(256, input_dim=self.latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    # image_shape = [28,28,1]
    model.add(Dense(np.prod(self.img_shape), activation='tanh'))  # np.prod() multiplies the shape dimensions: 28*28*1 = 784
    model.add(Reshape(self.img_shape))
    model.summary()
    noise = Input(shape=(self.latent_dim,))
    img = model(noise)
    return Model(noise, img)
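To make the shapes concrete, here is a minimal NumPy-only sketch of what the generator does to a batch of noise. It is not the Keras model above: the single random weight matrix `W` is a hypothetical stand-in for the whole Dense stack, used only to show how (batch, 100) becomes (batch, 28, 28, 1) and why the tanh output stays within [-1, 1].

```python
import numpy as np

latent_dim = 100
img_shape = (28, 28, 1)

rng = np.random.default_rng(0)
noise = rng.normal(0, 1, size=(4, latent_dim))      # a batch of 4 noise vectors

# Hypothetical random weights standing in for the generator's Dense layers
W = rng.normal(0, 0.05, size=(latent_dim, int(np.prod(img_shape))))

flat = np.tanh(noise @ W)                           # shape (4, 784), values in [-1, 1]
imgs = flat.reshape((-1,) + img_shape)              # same role as Reshape(self.img_shape)

print(imgs.shape)
```

The tanh range matches the [-1, 1] scaling applied to the real MNIST images during training, so real and generated images live on the same scale.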
2.discriminator
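The discriminator takes an image of shape [28,28,1] as input, either a real image or a fake one produced by the generator. Its output is a validity score in the range [0,1]: the larger the value, the more likely the discriminator considers the input to be real data; the smaller the value, the less likely.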
def build_discriminator(self):
    model = Sequential()
    model.add(Flatten(input_shape=self.img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    model.summary()
    img = Input(shape=self.img_shape)
    validity = model(img)
    return Model(img, validity)
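The input/output contract of the discriminator can be sketched in plain NumPy. This is only an illustration under assumptions: the weights `w` and bias `b` are hypothetical random values replacing the three Dense layers, and the point is just that flattening gives 784 features per image and the sigmoid keeps every validity score strictly between 0 and 1.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

rng = np.random.default_rng(0)
batch = rng.uniform(-1, 1, size=(4, 28, 28, 1))   # 4 images scaled to [-1, 1]

flat = batch.reshape(batch.shape[0], -1)          # same role as Flatten: (4, 784)

# Hypothetical single-layer stand-in for the discriminator's Dense layers
w = rng.normal(0, 0.05, size=(784, 1))
b = 0.0

validity = sigmoid(flat @ w + b)                  # shape (4, 1), one score per image
print(validity.shape)
```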
3.Building the complete model
optimizer = Adam(0.0002, 0.5)
# Build and compile the discriminator
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
# Build the generator
self.generator = self.build_generator()
# Feed noise to the generator to produce fake images
z = Input(shape=(self.latent_dim,))
img = self.generator(z)
# Freeze the discriminator (only affects the combined model compiled below)
self.discriminator.trainable = False
# Feed the fake images to the discriminator
validity = self.discriminator(img)
# Stack the generator and the discriminator into one model
self.combined = Model(z, validity)
self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
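What "freezing" means for the combined model can be shown with a toy example. The sketch below does not use Keras: the scalar "networks" G(z) = w_g * z and D(x) = sigmoid(w_d * x), the fixed noise batch, and the learning rate are all assumptions chosen for illustration. One gradient step is taken with respect to the generator weight only, so the loss against the label 1 goes down while the discriminator weight stays exactly as it was.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

# Toy scalar "networks" (illustrative assumptions, not the Keras models):
# generator G(z) = w_g * z, discriminator D(x) = sigmoid(w_d * x)
w_g, w_d = 0.1, 1.5
z = np.array([0.5, -1.0, 2.0])   # a fixed "noise" batch
lr = 0.1

def generator_loss(w_g, w_d, z):
    # Binary cross-entropy against the label 1: -log D(G(z))
    return float(np.mean(-np.log(sigmoid(w_d * w_g * z))))

loss_before = generator_loss(w_g, w_d, z)
w_d_before = w_d

# Gradient with respect to w_g only; w_d is treated as a constant ("frozen")
d_out = sigmoid(w_d * w_g * z)
grad_w_g = float(np.mean((d_out - 1.0) * w_d * z))
w_g = w_g - lr * grad_w_g        # update the generator weight only

loss_after = generator_loss(w_g, w_d, z)
print(loss_before, loss_after)
```

The gradient still flows through the discriminator (note the `w_d` factor in `grad_w_g`); freezing only means its weights are not updated.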
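4.Training strategy
First, train the discriminator: feed a batch of real images (label 1) and a batch of images produced by the generator (label 0) into the discriminator separately, then average the losses on the two batches. The purpose of this step is to teach the discriminator to tell real images from generated ones.
Next, train the generator. What is actually trained is the combined model built above, but because the discriminator is frozen, only the generator's weights are updated. The discriminator's prediction on the generated images is compared against the label 1: the closer the prediction is to 1, the better the generated images fool the discriminator. Minimizing this loss pushes the generator to produce images that look increasingly like real ones.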
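The two losses described above can be written out as a small NumPy sketch. The discriminator outputs used here (`d_on_real`, `d_on_fake`, `d_on_better_fake`) are made-up numbers for illustration, and `binary_crossentropy` is a hand-written helper, not the Keras function.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Mean binary cross-entropy; predictions are clipped to avoid log(0)
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(y_pred)
                           + (1 - y_true) * np.log(1 - y_pred))))

valid = np.ones(2)    # label 1 = real
fake = np.zeros(2)    # label 0 = generated

# Hypothetical discriminator outputs
d_on_real = np.array([0.9, 0.8])
d_on_fake = np.array([0.2, 0.1])

# Discriminator step: average of the loss on real and on generated images
d_loss = 0.5 * (binary_crossentropy(valid, d_on_real)
                + binary_crossentropy(fake, d_on_fake))

# Generator step: generated images are scored against the label 1
g_loss_poor = binary_crossentropy(valid, d_on_fake)

# If the generator improves, the discriminator scores its images closer to 1
d_on_better_fake = np.array([0.9, 0.95])
g_loss_good = binary_crossentropy(valid, d_on_better_fake)

print(d_loss, g_loss_poor, g_loss_good)
```

The generator loss drops as the discriminator's score on generated images approaches 1, which is exactly the signal the combined model is trained on.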
def train(self, epochs, batch_size=128, sample_interval=50):
    # Load the dataset
    (X_train, _), (_, _) = mnist.load_data()  # returns (train images, train labels), (test images, test labels) as tuples
    # X_train.shape = (60000, 28, 28)
    # Rescale pixel values from [0, 255] to [-1, 1]
    X_train = X_train / 127.5 - 1.
    X_train = np.expand_dims(X_train, axis=3)  # add a channel axis ---> (60000, 28, 28, 1)
    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    for epoch in range(epochs):
        # ---------------------
        #  Train Discriminator
        # ---------------------
        # Select a random batch of images
        idx = np.random.randint(0, X_train.shape[0], batch_size)  # batch_size random integers in [0, 60000)
        imgs = X_train[idx]  # pick batch_size random images
        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))  # random input of shape (batch_size, 100)
        # Generate a batch of new images
        gen_imgs = self.generator.predict(noise)
        # Train the discriminator
        d_loss_real = self.discriminator.train_on_batch(imgs, valid)  # real images; valid is all ones
        d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)  # generated images; fake is all zeros
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        # ---------------------
        #  Train Generator
        # ---------------------
        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
        # Train the generator (to have the discriminator label samples as valid)
        g_loss = self.combined.train_on_batch(noise, valid)
        # Plot the progress
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
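Two pieces of arithmetic in the training loop are easy to check by hand: the pixel rescaling, and the averaging that produces `d_loss`. The sketch below uses a few sample pixel values and made-up `[loss, accuracy]` pairs (assumptions for illustration); because the discriminator is compiled with `metrics=['accuracy']`, `train_on_batch` returns such a pair, which is why the print statement indexes `d_loss[0]` and `d_loss[1]`.

```python
import numpy as np

# Rescaling: [0, 255] -> [-1, 1], matching the generator's tanh output range
pixels = np.array([0.0, 127.5, 255.0])
scaled = pixels / 127.5 - 1.0
print(scaled)

# The inverse mapping used when plotting samples: [-1, 1] -> [0, 1]
restored = 0.5 * scaled + 0.5
print(restored)

# Hypothetical [loss, accuracy] pairs, shaped like train_on_batch's return value
d_loss_real = [0.6, 0.75]
d_loss_fake = [0.4, 0.85]

# Element-wise average: d_loss[0] is the mean loss, d_loss[1] the mean accuracy
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
print(d_loss)
```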
5.GAN network structure
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_1 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 257       
=================================================================
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_4 (Dense)              (None, 256)               25856     
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 256)               1024      
_________________________________________________________________
dense_5 (Dense)              (None, 512)               131584    
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 512)               2048      
_________________________________________________________________
dense_6 (Dense)              (None, 1024)              525312    
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 1024)              0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 1024)              4096      
_________________________________________________________________
dense_7 (Dense)              (None, 784)               803600    
_________________________________________________________________
reshape_1 (Reshape)          (None, 28, 28, 1)         0         
=================================================================
Total params: 1,493,520
Trainable params: 1,489,936
Non-trainable params: 3,584
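The parameter counts in the two summaries (discriminator first, then generator) can be reproduced by hand. The helper names below are my own, introduced only for this check: a Dense layer has one weight per input-output pair plus one bias per output, and a BatchNormalization layer has four values per feature (gamma, beta, moving mean, moving variance), of which the two moving statistics are not trainable.

```python
def dense_params(n_in, n_out):
    # weights (n_in * n_out) plus biases (n_out)
    return n_in * n_out + n_out

def batchnorm_params(n):
    # gamma, beta, moving mean, moving variance
    return 4 * n

# Discriminator: 784 -> 512 -> 256 -> 1
disc_total = dense_params(784, 512) + dense_params(512, 256) + dense_params(256, 1)

# Generator: 100 -> 256 -> 512 -> 1024 -> 784, with BatchNormalization after
# each of the first three Dense layers
gen_total = (dense_params(100, 256) + batchnorm_params(256)
             + dense_params(256, 512) + batchnorm_params(512)
             + dense_params(512, 1024) + batchnorm_params(1024)
             + dense_params(1024, 784))

# Moving mean and variance are the non-trainable part of each BatchNormalization
gen_non_trainable = 2 * (256 + 512 + 1024)

print(disc_total)                        # 533505
print(gen_total)                         # 1493520
print(gen_total - gen_non_trainable)     # 1489936
```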
Complete code
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import sys
import os
import numpy as np
class GAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100
        optimizer = Adam(0.0002, 0.5)
        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])
        # Build the generator
        self.generator = self.build_generator()
        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)
        # For the combined model we will only train the generator
        self.discriminator.trainable = False
        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)
        # The combined model (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
    # Build the generator
    def build_generator(self):
        # input shape = [100,]
        # output shape = [np.prod(self.img_shape)]
        model = Sequential()
        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # image_shape = [28,28,1]
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))  # np.prod() multiplies the shape dimensions: 28*28*1 = 784
        model.add(Reshape(self.img_shape))
        model.summary()
        noise = Input(shape=(self.latent_dim,))
        img = model(noise)
        return Model(noise, img)
    # Build the discriminator
    def build_discriminator(self):
        model = Sequential()
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()
        img = Input(shape=self.img_shape)
        validity = model(img)
        return Model(img, validity)
    def train(self, epochs, batch_size=128, sample_interval=50):
        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()  # returns (train images, train labels), (test images, test labels) as tuples
        # X_train.shape = (60000, 28, 28)
        # Rescale pixel values from [0, 255] to [-1, 1]
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)  # add a channel axis ---> (60000, 28, 28, 1)
        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))
        for epoch in range(epochs):
            # ---------------------
            #  Train Discriminator
            # ---------------------
            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)  # batch_size random integers in [0, 60000)
            imgs = X_train[idx]  # pick batch_size random images
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))  # random input of shape (batch_size, 100)
            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)
            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)  # real images; valid is all ones
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)  # generated images; fake is all zeros
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            # ---------------------
            #  Train Generator
            # ---------------------
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)
            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)
        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()
if __name__ == '__main__':
    if not os.path.exists("./images"):
        os.makedirs("./images")
    gan = GAN()
    gan.train(epochs=30000, batch_size=32, sample_interval=200)  # sample_interval: how often (in training iterations) sample images are saved
Below are the images produced by the generator at epochs 0, 10000, 20000, and 29800 of training (each "epoch" in this code is one training iteration on a single batch):