生成對(duì)抗網(wǎng)絡(luò)原理與實(shí)戰(zhàn)

原創(chuàng):李孟啟

1. 前言

在生成對(duì)抗網(wǎng)絡(luò)(Generative Adversarial Network,簡(jiǎn)稱 GAN)發(fā)明之前峭跳,變分自編碼器被認(rèn)為是理論完備,實(shí)現(xiàn)簡(jiǎn)單,使用神經(jīng)網(wǎng)絡(luò)訓(xùn)練起來很穩(wěn)定误证,生成的圖片逼近度也較高,但是人眼還是可以很輕易地分辨出真實(shí)圖片與機(jī)器生成的圖片修壕。

2014 年愈捅,Université de Montréal 大學(xué) Yoshua Bengio(2019 年圖靈獎(jiǎng)獲得者)的學(xué)生 Ian Goodfellow 提出了生成對(duì)抗網(wǎng)絡(luò) GAN,從而開辟了深度學(xué)習(xí)最炙手可熱的研究方向之一慈鸠。從 2014 年到 2019 年蓝谨,GAN 的研究穩(wěn)步推進(jìn),研究捷報(bào)頻傳青团,最新的 GAN 算法在圖片生成上的效果甚至達(dá)到了肉眼難辨的程度譬巫,著實(shí)令人振奮。由于 GAN 的發(fā)明督笆,Ian Goodfellow 榮獲 GAN 之父稱號(hào)芦昔,并獲得 2017 年麻省理工科技評(píng)論頒發(fā)的 35 Innovators Under 35 獎(jiǎng)項(xiàng)。圖 1 展示了從 2014 年到 2018 年胖腾,GAN 模型取得了圖書生成的效果烟零,可以看到不管是圖片大小瘪松,還是圖片逼真度,都有了巨大的提升锨阿。

圖1 GAN模型2014~2018年的圖片生成效果

2. 博弈學(xué)實(shí)例

接下來宵睦,我們將從生活中博弈學(xué)習(xí)的實(shí)例出發(fā),一步步引出 GAN 算法的設(shè)計(jì)思想和模型結(jié)構(gòu)墅诡。我們用一個(gè)漫畫家的成長(zhǎng)軌跡來形象介紹生成對(duì)抗網(wǎng)絡(luò)的思想壳嚎。考慮一對(duì)雙胞胎兄弟末早,分別稱為老二 G 和老大 D烟馅,G 學(xué)習(xí)如何繪制漫畫,D 學(xué)習(xí)如何鑒賞畫作然磷。還在娃娃時(shí)代的兩兄弟郑趁,尚且只學(xué)會(huì)了如何使用畫筆和紙張,G 繪制了一張不明所以的畫作姿搜,如圖2(a)所示寡润,由于此時(shí) D 鑒別能力不高,覺得 G 的作品還行舅柜,但是人物主體不夠鮮明梭纹。在 D 的指引和鼓勵(lì)下,G 開始嘗試學(xué)習(xí)如何繪制主體輪廓和使用簡(jiǎn)單的色彩搭配致份。一年后变抽,G 提升了繪畫的基本功,D 也通過分析名作和初學(xué)者 G 的作品氮块,初步掌握了鑒別作品的能力绍载。此時(shí) D 覺得 G 的作品人物主體有了,如圖 2(b)滔蝉,但是色彩的運(yùn)用還不夠成熟逛钻。數(shù)年后,G 的繪畫基本功已經(jīng)很扎實(shí)了锰提,可以輕松繪制出主體鮮明、顏色搭配合適和逼真度較高的畫作芳悲,如圖 2(c)立肘,但是 D 同樣通過觀察 G 和其它名作的差別,提升了畫作鑒別能力名扛,覺得 G 的畫作技藝已經(jīng)趨于成熟谅年,但是對(duì)生活的觀察尚且不夠,作品沒有傳達(dá)神情且部分細(xì)節(jié)不夠完美肮韧。又過了數(shù)年融蹂,G 的繪畫功力達(dá)到了爐火純青的地步旺订,繪制的作品細(xì)節(jié)完美、風(fēng)格迥異超燃、惟妙惟肖区拳,宛如大師級(jí)水準(zhǔn),如圖 2(d)意乓,即便此時(shí)的D 鑒別功力也相當(dāng)出色樱调,亦很難將 G 和其他大師級(jí)的作品區(qū)分開來。

上述畫家的成長(zhǎng)歷程其實(shí)是一個(gè)生活中普遍存在的學(xué)習(xí)過程届良,通過雙方的博弈學(xué)習(xí)笆凌,相互提高,最終達(dá)到一個(gè)平衡點(diǎn)士葫。GAN 網(wǎng)絡(luò)借鑒了博弈學(xué)習(xí)的思想乞而,分別設(shè)立了兩個(gè)子網(wǎng)絡(luò):負(fù)責(zé)生成樣本的生成器 G 和負(fù)責(zé)鑒別真?zhèn)蔚蔫b別器 D。類比到畫家的例子慢显,生成器 G就是老二爪模,鑒別器 D 就是老大。鑒別器 D 通過觀察真實(shí)的樣本和生成器 G 產(chǎn)生的樣本之間的區(qū)別鳍怨,學(xué)會(huì)如何鑒別真假呻右,其中真實(shí)的樣本為真,生成器 G 產(chǎn)生的樣本為假鞋喇。而生成器 G 同樣也在學(xué)習(xí)声滥,它希望產(chǎn)生的樣本能夠獲得鑒別器 D 的認(rèn)可,即在鑒別器 D 中鑒別為真侦香,因此生成器 G 通過優(yōu)化自身的參數(shù)落塑,嘗試使得自己產(chǎn)生的樣本在鑒別器 D 中判別為真。生成器 G 和鑒別器 D 相互博弈罐韩,共同提升憾赁,直至達(dá)到平衡點(diǎn)。此時(shí)生成器 G 生成的樣本非常逼真散吵,使得鑒別器 D 真假難分龙考。

圖2 畫家的成長(zhǎng)軌跡示意圖

在原始的 GAN 論文中,Ian Goodfellow 使用了另一個(gè)形象的比喻來介紹 GAN 模型:生成器網(wǎng)絡(luò) G 的功能就是產(chǎn)生一系列非常逼真的假鈔試圖欺騙鑒別器 D矾睦,而鑒別器 D 通過學(xué)習(xí)真鈔和生成器 G 生成的假鈔來掌握鈔票的鑒別方法晦款。這兩個(gè)網(wǎng)絡(luò)在相互博弈的過程中間同步提升,直到生成器 G 產(chǎn)生的假鈔非常的逼真枚冗,連鑒別器 D 都真假難辨缓溅。

這種博弈學(xué)習(xí)的思想使得 GAN 的網(wǎng)絡(luò)結(jié)構(gòu)和訓(xùn)練過程與之前的網(wǎng)絡(luò)模型略有不同,下面我們來詳細(xì)介紹 GAN 的網(wǎng)絡(luò)結(jié)構(gòu)和算法原理赁温。

3. GAN原理

一個(gè)典型的生成對(duì)抗網(wǎng)絡(luò)模型大概如圖3所示坛怪。

圖3 對(duì)抗生成網(wǎng)絡(luò)模型

我們先來理解下GAN的兩個(gè)模型要做什么淤齐。首先判別模型(鑒別器),就是圖3中右半部分的網(wǎng)絡(luò)袜匿,直觀來看就是一個(gè)簡(jiǎn)單的神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)更啄,輸入就是一副圖像,輸出就是一個(gè)概率值(其實(shí)是個(gè)二分類問題)沉帮,用于判斷真假使用(概率值大于0.5那就是真锈死,小于0.5那就是假),真假也不過是人們定義的概率而已穆壕。其次是生成模型待牵,生成模型要做什么呢,同樣也可以看成是一個(gè)神經(jīng)網(wǎng)絡(luò)模型喇勋,輸入是一組隨機(jī)數(shù)Z缨该,輸出是一個(gè)圖像,不再是一個(gè)數(shù)值而已川背。從圖3中可以看到贰拿,會(huì)存在兩個(gè)數(shù)據(jù)集,一個(gè)是真實(shí)數(shù)據(jù)集熄云,另一個(gè)是假的數(shù)據(jù)集膨更,那這個(gè)數(shù)據(jù)集就是有生成網(wǎng)絡(luò)造出來的數(shù)據(jù)集。根據(jù)圖3我們?cè)賮砝斫庖幌翯AN的目標(biāo)是要做什么:

判別網(wǎng)絡(luò)的目的:就是能判別出來輸入的一張圖它是來自真實(shí)樣本集還是假樣本集缴允。假如輸入的是真樣本荚守,網(wǎng)絡(luò)輸出就接近1,輸入的是假樣本练般,網(wǎng)絡(luò)輸出接近0矗漾,那么很完美,達(dá)到了很好判別的目的薄料。

生成網(wǎng)絡(luò)的目的:生成網(wǎng)絡(luò)是造樣本的敞贡,它的目的就是使得自己造樣本的能力盡可能強(qiáng),強(qiáng)到什么程度呢摄职,你判別網(wǎng)絡(luò)沒法判斷我是真樣本還是假樣本誊役。因此辨別網(wǎng)絡(luò)的作用就是對(duì)噪音生成的數(shù)據(jù)辨別他為假的,對(duì)真實(shí)的數(shù)據(jù)辨別他為真的谷市。而生成網(wǎng)絡(luò)的損失函數(shù)就是使得對(duì)于噪音數(shù)據(jù)势木,經(jīng)過辨別網(wǎng)絡(luò)之后的辨別結(jié)果是真的,這樣就能達(dá)到生成真實(shí)圖像的目的歌懒。這里會(huì)感覺比較饒,這也是生成對(duì)抗網(wǎng)絡(luò)的難點(diǎn)所在溯壶,理解了這點(diǎn)及皂,整個(gè)生成對(duì)抗網(wǎng)絡(luò)模型也就理解了甫男。

4. DCGAN實(shí)戰(zhàn)

這里我們拿DCGAN來舉例子,DCGAN是GAN的一個(gè)變體验烧,DCGAN就是將CNN和原始的GAN結(jié)合到一起板驳,生成網(wǎng)絡(luò)和鑒別網(wǎng)絡(luò)都運(yùn)用到了深度卷積神經(jīng)網(wǎng)絡(luò)。DCGAN提高了基礎(chǔ)GAN的穩(wěn)定性和生成結(jié)果質(zhì)量碍拆。

該項(xiàng)目使用的是mnist手寫字?jǐn)?shù)據(jù)集若治,深度學(xué)習(xí)框架為tensorflow。你也可以直接跳過下面代碼直接git clone本項(xiàng)目感混,項(xiàng)目的github鏈接https://github.com/limengqigithub/DCGAN-mnist-master.git端幼。

4.1 DCGAN模型代碼

import tensorflow as tf
from tensorflow import keras

# 生成網(wǎng)絡(luò)
class Generator(keras.Model):

    def __init__(self):
        super(Generator, self).__init__()

        self.n_f = 512
        self.n_k = 4

        # input z vector is [None, 100]
        self.dense1 = keras.layers.Dense(3 * 3 * self.n_f)
        self.conv2 = keras.layers.Conv2DTranspose(self.n_f // 2, 3, 2, 'valid')
        self.bn2 = keras.layers.BatchNormalization()
        self.conv3 = keras.layers.Conv2DTranspose(self.n_f // 4, self.n_k, 2, 'same')
        self.bn3 = keras.layers.BatchNormalization()
        self.conv4 = keras.layers.Conv2DTranspose(1, self.n_k, 2, 'same')
        return

    def call(self, inputs, training=None):
        # [b, 100] => [b, 3, 3, 512]
        x = tf.nn.leaky_relu(tf.reshape(self.dense1(inputs), shape=[-1, 3, 3, self.n_f]))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))
        x = tf.tanh(self.conv4(x))
        return x

# 判別網(wǎng)絡(luò)
class Discriminator(keras.Model):

    def __init__(self):
        super(Discriminator, self).__init__()

        self.n_f = 64
        self.n_k = 4

        # input image is [-1, 28, 28, 1]
        self.conv1 = keras.layers.Conv2D(self.n_f, self.n_k, 2, 'same')
        self.conv2 = keras.layers.Conv2D(self.n_f * 2, self.n_k, 2, 'same')
        self.bn2 = keras.layers.BatchNormalization()
        self.conv3 = keras.layers.Conv2D(self.n_f * 4, self.n_k, 2, 'same')
        self.bn3 = keras.layers.BatchNormalization()
        self.flatten4 = keras.layers.Flatten()
        self.dense4 = keras.layers.Dense(1)
        return

    def call(self, inputs, training=None):
        x = tf.nn.leaky_relu(self.conv1(inputs))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))
        x = self.dense4(self.flatten4(x))
        return x

4.2 損失函數(shù)實(shí)現(xiàn)

# shorten sigmoid cross entropy loss calculation
def celoss_ones(logits, smooth=0.0):
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                                  labels=tf.ones_like(logits) * (1.0 - smooth)))


def celoss_zeros(logits, smooth=0.0):

    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                                  labels=tf.zeros_like(logits) * (1.0 - smooth)))


def d_loss_fn(generator, discriminator, input_noise, real_image, is_trainig):
    # 生成模型根據(jù)噪聲輸入生成圖片,把根據(jù)噪聲生成的圖片與真實(shí)的圖片輸進(jìn)糾錯(cuò)模型中弧满,然后做交叉熵的loss計(jì)算
    fake_image = generator(input_noise, is_trainig)
    d_real_logits = discriminator(real_image, is_trainig)
    d_fake_logits = discriminator(fake_image, is_trainig)

    d_loss_real = celoss_ones(d_real_logits, smooth=0.1)
    d_loss_fake = celoss_zeros(d_fake_logits, smooth=0.0)
    loss = d_loss_real + d_loss_fake
    return loss


def g_loss_fn(generator, discriminator, input_noise, is_trainig):
    fake_image = generator(input_noise, is_trainig)
    d_fake_logits = discriminator(fake_image, is_trainig)
    loss = celoss_ones(d_fake_logits, smooth=0.1)
    return loss

4.3 保存生成網(wǎng)絡(luò)的生成結(jié)果

def save_result(val_out, val_block_size, image_fn, color_mode):
    def preprocess(img):
        img = ((img + 1.0) * 127.5).astype(np.uint8)
        return img

    preprocesed = preprocess(val_out)
    final_image = np.array([])
    single_row = np.array([])
    for b in range(val_out.shape[0]):
        # concat image into a row
        if single_row.size == 0:
            single_row = preprocesed[b, :, :, :]
        else:
            single_row = np.concatenate((single_row, preprocesed[b, :, :, :]), axis=1)

        # concat image row to final_image
        if (b + 1) % val_block_size == 0:
            if final_image.size == 0:
                final_image = single_row
            else:
                final_image = np.concatenate((final_image, single_row), axis=0)

            # reset single row
            single_row = np.array([])

    if final_image.shape[2] == 1:
        final_image = np.squeeze(final_image, axis=2)
    Image.fromarray(final_image, mode=color_mode).save(image_fn)

4.4 主函數(shù)部分

def main():
    tf.random.set_seed(22)
    np.random.seed(22)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    assert tf.__version__.startswith('2.')

    # hyper parameters  超參數(shù)
    z_dim = 100
    epochs = 3000000
    # epochs = 30
    batch_size = 128
    learning_rate = 0.0002
    is_training = True

    # for validation purpose
    assets_dir = './images'
    if not os.path.isdir(assets_dir):
        os.makedirs(assets_dir)
    val_block_size = 10
    val_size = val_block_size * val_block_size

    # load mnist data
    # x_train shape (60000, 28, 28) numpy.ndarray
    # x_test shape (10000, 28, 28) numpy.ndarray
    (x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
    x_train = x_train.astype(np.float32) / 255.  # 歸一到(0,1)區(qū)間內(nèi)
    db = tf.data.Dataset.from_tensor_slices(x_train).shuffle(batch_size * 4).batch(batch_size).repeat()
    db_iter = iter(db)
    inputs_shape = [-1, 28, 28, 1]

    # create generator & discriminator
    generator = Generator()
    generator.build(input_shape=(batch_size, z_dim))
    generator.summary()
    discriminator = Discriminator()
    discriminator.build(input_shape=(batch_size, 28, 28, 1))
    discriminator.summary()
    # prepare optimizer
    d_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    g_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    for epoch in range(epochs):

        # no need labels
        batch_x = next(db_iter)

        # rescale images to -1 ~ 1
        batch_x = tf.reshape(batch_x, shape=inputs_shape)
        # -1 - 1
        batch_x = batch_x * 2.0 - 1.0

        # Sample random noise for G
        batch_z = tf.random.uniform(shape=[batch_size, z_dim], minval=-1., maxval=1.)

        with tf.GradientTape() as tape:
            d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd loss:', float(d_loss), 'g loss:', float(g_loss))

            # validation results at every epoch
            val_z = np.random.uniform(-1, 1, size=(val_size, z_dim))
            fake_image = generator(val_z, training=False)
            image_fn = os.path.join('images', 'gan-val-{:03d}.png'.format(epoch + 1))
            save_result(fake_image.numpy(), val_block_size, image_fn, color_mode='L')

4.5 生成圖片結(jié)果展示

圖4

參考文獻(xiàn):

1.《TensorFlow深度學(xué)習(xí)》——深入理解人工智能算法設(shè)計(jì)

2.https://my.oschina.net/u/778683/blog/3100336

3.https://github.com/limengqigithub/DCGAN-mnist-master.git

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末婆跑,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子庭呜,更是在濱河造成了極大的恐慌滑进,老刑警劉巖,帶你破解...
    沈念sama閱讀 218,525評(píng)論 6 507
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件募谎,死亡現(xiàn)場(chǎng)離奇詭異扶关,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī)数冬,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,203評(píng)論 3 395
  • 文/潘曉璐 我一進(jìn)店門节槐,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人吉执,你說我怎么就攤上這事疯淫。” “怎么了戳玫?”我有些...
    開封第一講書人閱讀 164,862評(píng)論 0 354
  • 文/不壞的土叔 我叫張陵熙掺,是天一觀的道長(zhǎng)。 經(jīng)常有香客問我咕宿,道長(zhǎng)币绩,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 58,728評(píng)論 1 294
  • 正文 為了忘掉前任府阀,我火速辦了婚禮缆镣,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘试浙。我一直安慰自己董瞻,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,743評(píng)論 6 392
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著钠糊,像睡著了一般挟秤。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上抄伍,一...
    開封第一講書人閱讀 51,590評(píng)論 1 305
  • 那天艘刚,我揣著相機(jī)與錄音,去河邊找鬼截珍。 笑死攀甚,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的岗喉。 我是一名探鬼主播秋度,決...
    沈念sama閱讀 40,330評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼沈堡!你這毒婦竟也來了静陈?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 39,244評(píng)論 0 276
  • 序言:老撾萬榮一對(duì)情侶失蹤诞丽,失蹤者是張志新(化名)和其女友劉穎鲸拥,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體僧免,經(jīng)...
    沈念sama閱讀 45,693評(píng)論 1 314
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡刑赶,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,885評(píng)論 3 336
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了懂衩。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片撞叨。...
    茶點(diǎn)故事閱讀 40,001評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖浊洞,靈堂內(nèi)的尸體忽然破棺而出牵敷,到底是詐尸還是另有隱情,我是刑警寧澤法希,帶...
    沈念sama閱讀 35,723評(píng)論 5 346
  • 正文 年R本政府宣布枷餐,位于F島的核電站,受9級(jí)特大地震影響苫亦,放射性物質(zhì)發(fā)生泄漏毛肋。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,343評(píng)論 3 330
  • 文/蒙蒙 一屋剑、第九天 我趴在偏房一處隱蔽的房頂上張望润匙。 院中可真熱鬧,春花似錦唉匾、人聲如沸孕讳。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,919評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽厂财。三九已至油啤,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間蟀苛,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,042評(píng)論 1 270
  • 我被黑心中介騙來泰國(guó)打工逮诲, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留帜平,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 48,191評(píng)論 3 370
  • 正文 我出身青樓梅鹦,卻偏偏與公主長(zhǎng)得像裆甩,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子齐唆,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,955評(píng)論 2 355

推薦閱讀更多精彩內(nèi)容