[Machine Learning From Scratch]-unsupervised_learning-generative_adversarial_network

from __future__ import print_function, division
from sklearn import datasets
import math
import matplotlib.pyplot as plt
import numpy as np
import progressbar

from sklearn.datasets import fetch_mldata

from mlfromscratch.deep_learning.optimizers import Adam
from mlfromscratch.deep_learning.loss_functions import CrossEntropy
from mlfromscratch.deep_learning.layers import Dense, Dropout, Flatten, Activation, Reshape, BatchNormalization
from mlfromscratch.deep_learning import NeuralNetwork


class GAN():
    """A Generative Adversarial Network with deep fully-connected neural nets as
    Generator and Discriminator.
    Training Data: MNIST Handwritten Digits (28x28 images)
    """
    def __init__(self):
        self.img_rows = 28 
        self.img_cols = 28
        self.img_dim = self.img_rows * self.img_cols
        self.latent_dim = 100

        optimizer = Adam(learning_rate=0.0002, b1=0.5)
        loss_function = CrossEntropy

        # Build the discriminator
        self.discriminator = self.build_discriminator(optimizer, loss_function)

        # Build the generator
        self.generator = self.build_generator(optimizer, loss_function)

        # Build the combined model
        self.combined = NeuralNetwork(optimizer=optimizer, loss=loss_function)
        self.combined.layers.extend(self.generator.layers)
        self.combined.layers.extend(self.discriminator.layers)

        print ()
        self.generator.summary(name="Generator")
        self.discriminator.summary(name="Discriminator")

    def build_generator(self, optimizer, loss_function):
        
        model = NeuralNetwork(optimizer=optimizer, loss=loss_function)

        model.add(Dense(256, input_shape=(self.latent_dim,)))
        model.add(Activation('leaky_relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(Activation('leaky_relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(Activation('leaky_relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(self.img_dim))
        model.add(Activation('tanh'))

        return model

    def build_discriminator(self, optimizer, loss_function):
        
        model = NeuralNetwork(optimizer=optimizer, loss=loss_function)

        model.add(Dense(512, input_shape=(self.img_dim,)))
        model.add(Activation('leaky_relu'))
        model.add(Dropout(0.5))
        model.add(Dense(256))
        model.add(Activation('leaky_relu'))
        model.add(Dropout(0.5))
        model.add(Dense(2))
        model.add(Activation('softmax'))

        return model

    def train(self, n_epochs, batch_size=128, save_interval=50):

        mnist = fetch_mldata('MNIST original')

        X = mnist.data
        y = mnist.target

        # Rescale [-1, 1]
        X = (X.astype(np.float32) - 127.5) / 127.5

        half_batch = int(batch_size / 2)

        for epoch in range(n_epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            self.discriminator.set_trainable(True)

            # Select a random half batch of images
            idx = np.random.randint(0, X.shape[0], half_batch)
            imgs = X[idx]

            # Sample noise to use as generator input
            noise = np.random.normal(0, 1, (half_batch, self.latent_dim))

            # Generate a half batch of images
            gen_imgs = self.generator.predict(noise)

            # Valid = [1, 0], Fake = [0, 1]
            valid = np.concatenate((np.ones((half_batch, 1)), np.zeros((half_batch, 1))), axis=1)
            fake = np.concatenate((np.zeros((half_batch, 1)), np.ones((half_batch, 1))), axis=1)

            # Train the discriminator
            d_loss_real, d_acc_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake, d_acc_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * (d_loss_real + d_loss_fake)
            d_acc = 0.5 * (d_acc_real + d_acc_fake)


            # ---------------------
            #  Train Generator
            # ---------------------

            # We only want to train the generator for the combined model
            self.discriminator.set_trainable(False)

            # Sample noise and use as generator input
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # The generator wants the discriminator to label the generated samples as valid
            valid = np.concatenate((np.ones((batch_size, 1)), np.zeros((batch_size, 1))), axis=1)

            # Train the generator
            g_loss, g_acc = self.combined.train_on_batch(noise, valid)

            # Display the progress
            print ("%d [D loss: %f, acc: %.2f%%] [G loss: %f, acc: %.2f%%]" % (epoch, d_loss, 100*d_acc, g_loss, 100*g_acc))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch):
        r, c = 5, 5 # Grid size
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        # Generate images and reshape to image shape
        gen_imgs = self.generator.predict(noise).reshape((-1, self.img_rows, self.img_cols))

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        plt.suptitle("Generative Adversarial Network")
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt,:,:], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("mnist_%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    gan = GAN()
    gan.train(n_epochs=200000, batch_size=64, save_interval=400)
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末誉简,一起剝皮案震驚了整個(gè)濱河市州藕,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖锈津,帶你破解...
    沈念sama閱讀 222,865評(píng)論 6 518
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件数初,死亡現(xiàn)場(chǎng)離奇詭異叹括,居然都是意外死亡枯芬,警方通過(guò)查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 95,296評(píng)論 3 399
  • 文/潘曉璐 我一進(jìn)店門(mén)碰辅,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)懂昂,“玉大人,你說(shuō)我怎么就攤上這事没宾×璞颍” “怎么了?”我有些...
    開(kāi)封第一講書(shū)人閱讀 169,631評(píng)論 0 364
  • 文/不壞的土叔 我叫張陵循衰,是天一觀的道長(zhǎng)铲敛。 經(jīng)常有香客問(wèn)我,道長(zhǎng)羹蚣,這世上最難降的妖魔是什么原探? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 60,199評(píng)論 1 300
  • 正文 為了忘掉前任,我火速辦了婚禮顽素,結(jié)果婚禮上咽弦,老公的妹妹穿的比我還像新娘。我一直安慰自己胁出,他們只是感情好型型,可當(dāng)我...
    茶點(diǎn)故事閱讀 69,196評(píng)論 6 398
  • 文/花漫 我一把揭開(kāi)白布。 她就那樣靜靜地躺著全蝶,像睡著了一般闹蒜。 火紅的嫁衣襯著肌膚如雪寺枉。 梳的紋絲不亂的頭發(fā)上,一...
    開(kāi)封第一講書(shū)人閱讀 52,793評(píng)論 1 314
  • 那天绷落,我揣著相機(jī)與錄音姥闪,去河邊找鬼。 笑死砌烁,一個(gè)胖子當(dāng)著我的面吹牛筐喳,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播函喉,決...
    沈念sama閱讀 41,221評(píng)論 3 423
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼避归,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來(lái)了管呵?” 一聲冷哼從身側(cè)響起梳毙,我...
    開(kāi)封第一講書(shū)人閱讀 40,174評(píng)論 0 277
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎捐下,沒(méi)想到半個(gè)月后账锹,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 46,699評(píng)論 1 320
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡蔑担,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 38,770評(píng)論 3 343
  • 正文 我和宋清朗相戀三年牌废,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了咽白。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片啤握。...
    茶點(diǎn)故事閱讀 40,918評(píng)論 1 353
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖晶框,靈堂內(nèi)的尸體忽然破棺而出排抬,到底是詐尸還是另有隱情,我是刑警寧澤授段,帶...
    沈念sama閱讀 36,573評(píng)論 5 351
  • 正文 年R本政府宣布蹲蒲,位于F島的核電站,受9級(jí)特大地震影響侵贵,放射性物質(zhì)發(fā)生泄漏届搁。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 42,255評(píng)論 3 336
  • 文/蒙蒙 一窍育、第九天 我趴在偏房一處隱蔽的房頂上張望卡睦。 院中可真熱鬧,春花似錦漱抓、人聲如沸表锻。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 32,749評(píng)論 0 25
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)瞬逊。三九已至显歧,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間确镊,已是汗流浹背士骤。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 33,862評(píng)論 1 274
  • 我被黑心中介騙來(lái)泰國(guó)打工林束, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留速勇,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 49,364評(píng)論 3 379
  • 正文 我出身青樓娜亿,卻偏偏與公主長(zhǎng)得像束铭,于是被迫代替她去往敵國(guó)和親廓块。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,926評(píng)論 2 361

推薦閱讀更多精彩內(nèi)容