Building a GAN with Keras

Before starting, please make sure you understand how GANs work. Many bloggers have already explained the theory well, so I will not go into much detail here; for video lectures I recommend Prof. Hung-yi Lee's course at National Taiwan University.

A GAN consists of two main components: a generator and a discriminator. The generator produces fake data to "fool" the discriminator, while the discriminator judges whether its input was produced by the generator. The two are trained against each other iteratively until the generator can produce data that passes for real. Below we use the MNIST dataset as an example and build a GAN that generates handwritten digits.
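
For reference, this adversarial setup is usually written as the minimax objective from the original GAN paper, where G is the generator, D the discriminator, x a real sample and z the input noise:

    \min_G \max_D \; V(D, G) =
        \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
      + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]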

Building the network models

1. Generator

A neural network needs an input in order to produce an output, so to obtain generated data we have to feed the model something. Here the input is a noise vector of shape [100,] and the output is a matrix of shape [28,28,1].

    def build_generator(self):
        # input shape = [100,]
        # output shape = [np.prod(self.img_shape)]
        
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # image_shape = [28,28,1]
        model.add(Dense(np.prod(self.img_shape), activation='tanh')) # np.prod() multiplies the shape entries (28*28*1 = 784)
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)
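
As a quick sanity check (a minimal sketch, not part of the original post, assuming the complete GAN class listed at the end), you can feed the generator random noise and confirm the output shape and value range:

    import numpy as np

    gan = GAN()                                 # complete class from the "Complete code" section
    noise = np.random.normal(0, 1, (16, 100))   # 16 latent vectors of shape [100,]
    fake_imgs = gan.generator.predict(noise)
    print(fake_imgs.shape)                      # (16, 28, 28, 1)
    print(fake_imgs.min(), fake_imgs.max())     # tanh output, so values lie in (-1, 1)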

2. Discriminator

The discriminator's input is an image of shape [28,28,1] (real or generated) and its output is a validity score in [0,1]: the larger the value, the more likely the discriminator believes the input is real data; the smaller the value, the more likely it believes the input is generated.

    def build_discriminator(self):

        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)
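
Similarly (again a sketch that assumes the complete class below), the discriminator maps any [28,28,1] input to a single score:

    gan = GAN()
    random_img = np.random.uniform(-1, 1, (1, 28, 28, 1))   # noise "image" in the generator's output range
    validity = gan.discriminator.predict(random_img)
    print(validity.shape, float(validity))                  # (1, 1); a value in [0, 1], meaningless before training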

3. Building the combined model

The snippet below (taken from the __init__ constructor) compiles the discriminator on its own, then freezes it and stacks it behind the generator, so that the combined model trains only the generator:

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

        # 構(gòu)建生成器
        self.generator = self.build_generator()

        # Feed noise to the generator to produce fake images
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # Freeze the discriminator (this only affects models compiled afterwards, i.e. the combined model below)
        self.discriminator.trainable = False

        # Feed the fake images to the discriminator
        validity = self.discriminator(img)

        # Stack the generator and discriminator into a single model
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
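
One detail worth spelling out: Keras reads the trainable flag at compile time. The discriminator was compiled before the flag was set to False, so it still learns when trained on its own, while inside combined (compiled after the flag was flipped) it stays frozen. A minimal sketch to verify this, again assuming the complete class below:

    gan = GAN()
    w_before = gan.discriminator.get_weights()[0].copy()

    # one generator step through the combined model: discriminator weights must not move
    noise = np.random.normal(0, 1, (32, gan.latent_dim))
    gan.combined.train_on_batch(noise, np.ones((32, 1)))
    print(np.allclose(w_before, gan.discriminator.get_weights()[0]))   # True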

4. Training strategy

  1. First train the discriminator: feed it real images (labeled 1) and fake images produced by the generator (labeled 0) in separate batches, and average the losses of the two batches. The goal of this step is to teach the discriminator to distinguish real images from generated ones.

  2. Then train the generator. What is actually trained is the combined model built above, but since the discriminator is frozen inside it, only the generator's weights are updated. The generator's output is scored against the target label 1: the closer the discriminator's prediction gets to 1, the better the generator is at fooling it, and minimizing this loss pushes the generated images toward the real ones (the corresponding loss functions are written out below).
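
In terms of the binary cross-entropy used in the code, the two steps above minimize the following losses (a standard formulation; the symbols x, z, D, G do not appear in the code itself):

    % discriminator step: real images labeled 1, generated images labeled 0, losses averaged
    L_D = -\tfrac{1}{2}\,\mathbb{E}_{x}\big[\log D(x)\big]
          -\tfrac{1}{2}\,\mathbb{E}_{z}\big[\log\big(1 - D(G(z))\big)\big]

    % generator step: generated images scored against the target label 1 (non-saturating loss)
    L_G = -\,\mathbb{E}_{z}\big[\log D(G(z))\big]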

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data() # returns (train images, train labels), (test images, test labels) as tuples
        # X_train.shape = (60000, 28, 28)
        
        # Rescale from [0, 255] to [-1, 1] to match the generator's tanh output
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3) # add a channel dimension ---> (60000, 28, 28, 1)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size) # batch_size random integers in [0, 60000)
            imgs = X_train[idx] # randomly select batch_size images

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) # random noise input of shape (batch_size, 100)

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)  # real images, labels all 1
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) # generated images, labels all 0
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

5. GAN network structure

Discriminator summary (printed first, since the discriminator is built first in __init__):

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_1 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 257       
=================================================================
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
_________________________________________________________________

Generator summary:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_4 (Dense)              (None, 256)               25856     
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 256)               1024      
_________________________________________________________________
dense_5 (Dense)              (None, 512)               131584    
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 512)               2048      
_________________________________________________________________
dense_6 (Dense)              (None, 1024)              525312    
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 1024)              0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 1024)              4096      
_________________________________________________________________
dense_7 (Dense)              (None, 784)               803600    
_________________________________________________________________
reshape_1 (Reshape)          (None, 28, 28, 1)         0         
=================================================================
Total params: 1,493,520
Trainable params: 1,489,936
Non-trainable params: 3,584

Complete code

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys
import os

import numpy as np

class GAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model  (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    # Build the generator
    def build_generator(self):
        # input shape = [100,]
        # output shape = [np.prod(self.img_shape)]
        
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # image_shape = [28,28,1]
        model.add(Dense(np.prod(self.img_shape), activation='tanh')) # np.prod() multiplies the shape entries (28*28*1 = 784)
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)
    
    # Build the discriminator
    def build_discriminator(self):

        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data() # returns (train images, train labels), (test images, test labels) as tuples
        # X_train.shape = (60000, 28, 28)
        
        # Rescale from [0, 255] to [-1, 1] to match the generator's tanh output
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3) # add a channel dimension ---> (60000, 28, 28, 1)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size) # batch_size random integers in [0, 60000)
            imgs = X_train[idx] # randomly select batch_size images

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) # random noise input of shape (batch_size, 100)

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)  # real images, labels all 1
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) # generated images, labels all 0
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()

if __name__ == '__main__':
    if not os.path.exists("./images"):
        os.makedirs("./images")
    gan = GAN()
    gan.train(epochs=30000, batch_size=32, sample_interval=200) # sample_interval => how often (in epochs) to save sample images
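
A small optional follow-up (not in the original script): once training finishes, the generator can be saved with Keras' built-in model saving, so new digits can be sampled later without retraining:

    # hypothetical extension: persist the trained generator
    gan.generator.save("generator.h5")
    # later: from keras.models import load_model; g = load_model("generator.h5")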

Below are the images produced by the generator after 0, 10000, 20000 and 29800 training epochs:


(sample grids saved by sample_images: images/0.png, images/10000.png, images/20000.png, images/29800.png)