from __future__ import print_function, division
from sklearn import datasets
import math
import matplotlib.pyplot as plt
import numpy as np
import progressbar
from sklearn.datasets import fetch_mldata
from mlfromscratch.deep_learning.optimizers import Adam
from mlfromscratch.deep_learning.loss_functions import CrossEntropy
from mlfromscratch.deep_learning.layers import Dense, Dropout, Flatten, Activation, Reshape, BatchNormalization
from mlfromscratch.deep_learning import NeuralNetwork
class GAN():
    """A Generative Adversarial Network with deep fully-connected neural nets as
    Generator and Discriminator.

    Training Data: MNIST Handwritten Digits (28x28 images)
    """
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.img_dim = self.img_rows * self.img_cols
        self.latent_dim = 100

        optimizer = Adam(learning_rate=0.0002, b1=0.5)
        loss_function = CrossEntropy

        # Build the discriminator
        self.discriminator = self.build_discriminator(optimizer, loss_function)

        # Build the generator
        self.generator = self.build_generator(optimizer, loss_function)

        # Build the combined model
        self.combined = NeuralNetwork(optimizer=optimizer, loss=loss_function)
        self.combined.layers.extend(self.generator.layers)
        self.combined.layers.extend(self.discriminator.layers)

        print()
        self.generator.summary(name="Generator")
        self.discriminator.summary(name="Discriminator")
    def build_generator(self, optimizer, loss_function):
        model = NeuralNetwork(optimizer=optimizer, loss=loss_function)
        model.add(Dense(256, input_shape=(self.latent_dim,)))
        model.add(Activation('leaky_relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(Activation('leaky_relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(Activation('leaky_relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(self.img_dim))
        model.add(Activation('tanh'))

        return model
    def build_discriminator(self, optimizer, loss_function):
        model = NeuralNetwork(optimizer=optimizer, loss=loss_function)
        model.add(Dense(512, input_shape=(self.img_dim,)))
        model.add(Activation('leaky_relu'))
        model.add(Dropout(0.5))
        model.add(Dense(256))
        model.add(Activation('leaky_relu'))
        model.add(Dropout(0.5))
        model.add(Dense(2))
        model.add(Activation('softmax'))

        return model
    def train(self, n_epochs, batch_size=128, save_interval=50):

        mnist = fetch_mldata('MNIST original')

        X = mnist.data
        y = mnist.target

        # Rescale [-1, 1]
        X = (X.astype(np.float32) - 127.5) / 127.5

        half_batch = int(batch_size / 2)

        for epoch in range(n_epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            self.discriminator.set_trainable(True)

            # Select a random half batch of images
            idx = np.random.randint(0, X.shape[0], half_batch)
            imgs = X[idx]

            # Sample noise to use as generator input
            noise = np.random.normal(0, 1, (half_batch, self.latent_dim))

            # Generate a half batch of images
            gen_imgs = self.generator.predict(noise)

            # Valid = [1, 0], Fake = [0, 1]
            valid = np.concatenate((np.ones((half_batch, 1)), np.zeros((half_batch, 1))), axis=1)
            fake = np.concatenate((np.zeros((half_batch, 1)), np.ones((half_batch, 1))), axis=1)

            # Train the discriminator
            d_loss_real, d_acc_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake, d_acc_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * (d_loss_real + d_loss_fake)
            d_acc = 0.5 * (d_acc_real + d_acc_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # We only want to train the generator for the combined model
            self.discriminator.set_trainable(False)

            # Sample noise and use as generator input
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # The generator wants the discriminator to label the generated samples as valid
            valid = np.concatenate((np.ones((batch_size, 1)), np.zeros((batch_size, 1))), axis=1)

            # Train the generator
            g_loss, g_acc = self.combined.train_on_batch(noise, valid)

            # Display the progress
            print("%d [D loss: %f, acc: %.2f%%] [G loss: %f, acc: %.2f%%]" % (epoch, d_loss, 100 * d_acc, g_loss, 100 * g_acc))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch)
    def save_imgs(self, epoch):
        r, c = 5, 5  # Grid size
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        # Generate images and reshape to image shape
        gen_imgs = self.generator.predict(noise).reshape((-1, self.img_rows, self.img_cols))

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        plt.suptitle("Generative Adversarial Network")
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("mnist_%d.png" % epoch)
        plt.close()
if __name__ == '__main__':
    gan = GAN()
    gan.train(n_epochs=200000, batch_size=64, save_interval=400)
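Note: scikit-learn has since removed fetch_mldata (the mldata.org service it relied on is gone), so the data-loading step in train() fails on current releases. Below is a minimal sketch of a replacement loader built on fetch_openml; the helper name load_mnist is hypothetical and not part of the original mlfromscratch example, and the rescaling simply mirrors what train() already does.

import numpy as np
from sklearn.datasets import fetch_openml

def load_mnist():
    # Hypothetical stand-in for fetch_mldata('MNIST original').
    # as_frame=False keeps the result as a NumPy array rather than a DataFrame.
    mnist = fetch_openml('mnist_784', version=1, as_frame=False)
    X = mnist.data      # shape (70000, 784), pixel values in [0, 255]
    y = mnist.target    # digit labels as strings ('0'..'9'); unused by the GAN
    # Same rescaling as in GAN.train: map pixels to [-1, 1] to match the generator's tanh output
    X = (X.astype(np.float32) - 127.5) / 127.5
    return X, y

With this helper, the fetch_mldata call and the rescaling block at the top of train() can be replaced by a single X, y = load_mnist(); the rest of the training loop stays unchanged.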
Source: [Machine Learning From Scratch] - unsupervised_learning/generative_adversarial_network
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
- 文/潘曉璐 我一進(jìn)店門(mén)碰辅,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)懂昂,“玉大人,你說(shuō)我怎么就攤上這事没宾×璞颍” “怎么了?”我有些...
- 文/不壞的土叔 我叫張陵循衰,是天一觀的道長(zhǎng)铲敛。 經(jīng)常有香客問(wèn)我,道長(zhǎng)羹蚣,這世上最難降的妖魔是什么原探? 我笑而不...
- 正文 為了忘掉前任,我火速辦了婚禮顽素,結(jié)果婚禮上咽弦,老公的妹妹穿的比我還像新娘。我一直安慰自己胁出,他們只是感情好型型,可當(dāng)我...
- 文/花漫 我一把揭開(kāi)白布。 她就那樣靜靜地躺著全蝶,像睡著了一般闹蒜。 火紅的嫁衣襯著肌膚如雪寺枉。 梳的紋絲不亂的頭發(fā)上,一...
- 那天绷落,我揣著相機(jī)與錄音姥闪,去河邊找鬼。 笑死砌烁,一個(gè)胖子當(dāng)著我的面吹牛筐喳,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播函喉,決...
- 文/蒼蘭香墨 我猛地睜開(kāi)眼避归,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了管呵?” 一聲冷哼從身側(cè)響起梳毙,我...
- 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎捐下,沒(méi)想到半個(gè)月后账锹,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體,經(jīng)...
- 正文 獨(dú)居荒郊野嶺守林人離奇死亡蔑担,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
- 正文 我和宋清朗相戀三年牌废,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了咽白。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片啤握。...
- 正文 年R本政府宣布蹲蒲,位于F島的核電站,受9級(jí)特大地震影響侵贵,放射性物質(zhì)發(fā)生泄漏届搁。R本人自食惡果不足惜,卻給世界環(huán)境...
- 文/蒙蒙 一窍育、第九天 我趴在偏房一處隱蔽的房頂上張望卡睦。 院中可真熱鬧,春花似錦漱抓、人聲如沸表锻。這莊子的主人今日做“春日...
- 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)瞬逊。三九已至显歧,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間确镊,已是汗流浹背士骤。 一陣腳步聲響...
- 正文 我出身青樓娜亿,卻偏偏與公主長(zhǎng)得像束铭,于是被迫代替她去往敵國(guó)和親廓块。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
推薦閱讀更多精彩內(nèi)容
- Unsupervised Representation Learning with Deep Convolutio...