利用 TensorFlow 和 MNIST 數(shù)據(jù)集演示 GAN 的構(gòu)建

自打關(guān)注深度學(xué)習(xí)這個領(lǐng)域就不時的看到和 Generative Adversarial Network, GAN 相關(guān)的東西匾灶,也一直非常好奇這個被 LeCun 稱為深度學(xué)習(xí)近年來最大的突破的東西到底是什么樣子的。正好在 Udacity 的課堂里遇到了租漂,在完成了通過 GAN 來完成人臉生成的項目后阶女,在這里做一個總結(jié),加深一下對于 GAN 這個網(wǎng)絡(luò)的理解哩治。為了便于本地試驗秃踩,這里展示的是利用 MNIST 數(shù)據(jù)集來訓(xùn)練一個簡單的 GAN 來生成手寫數(shù)字的過程。注意文中代碼和示例圖片來自 Udacity 深度學(xué)習(xí)納米學(xué)位課程业筏,版權(quán)歸 Udacity 所有憔杨。

深度神經(jīng)網(wǎng)絡(luò)最令人詬病一點(diǎn)就在于其決策過程的不可解釋性,你無從知道網(wǎng)絡(luò)中的單元提取了哪些特征來完成了一項分類或識別任務(wù)蒜胖。比如在圖片識別任務(wù)中消别,即便你可以提取隱藏層的 feature map 來可視化出來相應(yīng)層的情況抛蚤,其圖像在人類看來是抽象而詭異甚至有些驚悚的。這一點(diǎn)其實在我看來是十分正常的寻狂,也不應(yīng)該像很多媒體的解讀方式那樣過分的夸大岁经,事實上,人腦的加工過程有誰可以可視化出來呢蛇券?只不過我們對于人類行為的可預(yù)測性是有把握的缀壤,所以不像對于新生技術(shù)那樣容易催生恐懼。

而 GAN 最為聰明之處在于既然人類無法理解網(wǎng)絡(luò)內(nèi)部的生成過程怀读,索性不用人腦和人類對于圖像的理解方式去理解中間過程诉位,而是用另一個類似結(jié)構(gòu)的神經(jīng)網(wǎng)絡(luò),二者的相互理解過程也就是對抗 Adversarial 的過程菜枷。其實現(xiàn)的大致思路是:

  • 作為生成器的一個典型代表苍糠,GAN 的一個典型應(yīng)用是通過模型來生成類似已有數(shù)據(jù)集的圖片來實現(xiàn)數(shù)據(jù)擴(kuò)增,因此可以首先建立一個通過多層神經(jīng)網(wǎng)絡(luò)實現(xiàn)的生成器啤誊,其主要作用是通過對于符合一定分布規(guī)律的原始數(shù)據(jù)進(jìn)行處理岳瞭,進(jìn)而得到一個符合另一特定分布情況的結(jié)果圖像。這里要求這個網(wǎng)絡(luò)至少包含一個隱藏層蚊锹,否則網(wǎng)絡(luò)就不具有足夠的學(xué)習(xí)和泛化能力瞳筏,這個網(wǎng)絡(luò)在 GAN 中被稱為生成器 Generator。例如在下面的示例圖片中牡昆,生成器的輸入是符合某個分布特征的隨機(jī)數(shù)字:在后續(xù)的代碼示例中采用的是 (-1, 1) 之間的均勻分布

  • 在獲得了生成器之后姚炕,還要建立一個類似結(jié)構(gòu)的可以完成圖像識別任務(wù)的分類器,其特殊之處在于這個網(wǎng)絡(luò)的輸出層只對輸入是來自原始數(shù)據(jù)集還是由生成器網(wǎng)絡(luò)生成的結(jié)果做一個真假判斷丢烘,這個網(wǎng)絡(luò)在 GAN 中稱為識別器 Discriminator

High level overview of GAN with MNIST

在看到代碼之前我一直以為 GAN 的實現(xiàn)會比較復(fù)雜柱宦,但真正看到代碼之后就像看到 E = mc2 一樣,發(fā)現(xiàn)其是如此的簡潔播瞳,優(yōu)雅掸刊,直觀,不得不佩服 Ian Goodfellow 強(qiáng)大的思路赢乓。閑話到此為止忧侧,網(wǎng)絡(luò)架構(gòu)和實現(xiàn)代碼如下:

Network Architecture
%matplotlib inline
import pickle as pkl
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# load data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data')

# define the model input for both Generator and Discirminator
def model_inputs(real_dim, z_dim):
    inputs_real = tf.placeholder(tf.float32, (None, real_dim), name='input_real') 
    inputs_z = tf.placeholder(tf.float32, (None, z_dim), name='input_z')
    
    return inputs_real, inputs_z

# define the Generator
def generator(z, out_dim, n_units=128, reuse=False, alpha=0.01):
    with tf.variable_scope('generator', reuse=reuse):
        # Hidden layer
        h1 = tf.layers.dense(z, n_units, activation=None)
        # Leaky ReLU
        h1 = tf.maximum(alpha * h1, h1)
        
        # Logits and tanh output
        logits = tf.layers.dense(h1, out_dim, activation=None)
        out = tf.tanh(logits)
        
        return out

# define the Discriminator
def discriminator(x, n_units=128, reuse=False, alpha=0.01):
    with tf.variable_scope('discriminator', reuse=reuse):
        # Hidden layer
        h1 = tf.layers.dense(x, n_units, activation=None)
        # Leaky ReLU
        h1 = tf.maximum(alpha * h1, h1)
        
        logits = tf.layers.dense(h1, 1, activation=None)
        out = tf.sigmoid(logits)
        
        return out, logits

這里之所以要定義這個 variable_scope 是由于在后續(xù)的訓(xùn)練中,需要分別更新生成器和判別器的參數(shù)牌芋,為了提取參數(shù)而特別設(shè)置的蚓炬。另外值得注意的是,激活函數(shù)需要采用 Leaky ReLU 來保證梯度可以從判別器傳回到生成器躺屁。

# build the network
tf.reset_default_graph()
# Create our input placeholders
input_real, input_z = model_inputs(input_size, z_size)

# Build the model
g_model = generator(input_z, input_size, n_units=g_hidden_size, alpha=alpha)
# g_model is the generator output

d_model_real, d_logits_real = discriminator(input_real, n_units=d_hidden_size, alpha=alpha)
d_model_fake, d_logits_fake = discriminator(g_model, reuse=True, n_units=d_hidden_size, alpha=alpha)

# Calculate losses
d_loss_real = tf.reduce_mean(
                  tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real, 
                                                          labels=tf.ones_like(d_logits_real) * (1 - smooth)))
d_loss_fake = tf.reduce_mean(
                  tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake, 
                                                          labels=tf.zeros_like(d_logits_real)))
d_loss = d_loss_real + d_loss_fake

g_loss = tf.reduce_mean(
             tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                     labels=tf.ones_like(d_logits_fake)))

在這里新引入的一個操作是 label smoothing试吁,其目的在于適度的放低要求以促進(jìn)收斂。而針對損失函數(shù)這部分楼咳,由于希望判別器將真實數(shù)據(jù)識別為 1熄捍, 而將生成器生成的數(shù)據(jù)識別為 0,因此需要分別計算這兩部分的損失函數(shù)母怜。

# Optimizers
learning_rate = 0.002

# Get the trainable_variables, split into G and D parts
t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if var.name.startswith('generator')]
d_vars = [var for var in t_vars if var.name.startswith('discriminator')]

d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)

這一段代碼非常重要余耽,正式因為選擇了間歇性的訓(xùn)練才使得網(wǎng)絡(luò)的對抗得以實現(xiàn)。

# Size of input image to discriminator
input_size = 784
# Size of latent vector to generator
z_size = 100
# Sizes of hidden layers in generator and discriminator
g_hidden_size = 128
d_hidden_size = 128
# Leak factor for leaky ReLU
alpha = 0.01
# Smoothing 
smooth = 0.1

下面代碼部分為比較常見的訓(xùn)練代碼結(jié)構(gòu):

batch_size = 100
epochs = 100
samples = []
losses = []
# Only save generator variables
saver = tf.train.Saver(var_list=g_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for ii in range(mnist.train.num_examples//batch_size):
            batch = mnist.train.next_batch(batch_size)
            
            # Get images, reshape and rescale to pass to D
            batch_images = batch[0].reshape((batch_size, 784))
            batch_images = batch_images*2 - 1
            
            # Sample random noise for G
            batch_z = np.random.uniform(-1, 1, size=(batch_size, z_size))
            
            # Run optimizers
            _ = sess.run(d_train_opt, feed_dict={input_real: batch_images, input_z: batch_z})
            _ = sess.run(g_train_opt, feed_dict={input_z: batch_z})
        
        # At the end of each epoch, get the losses and print them out
        train_loss_d = sess.run(d_loss, {input_z: batch_z, input_real: batch_images})
        train_loss_g = g_loss.eval({input_z: batch_z})
            
        print("Epoch {}/{}...".format(e+1, epochs),
              "Discriminator Loss: {:.4f}...".format(train_loss_d),
              "Generator Loss: {:.4f}".format(train_loss_g))    
        # Save losses to view after training
        losses.append((train_loss_d, train_loss_g))
        
        # Sample from generator as we're training for viewing afterwards
        sample_z = np.random.uniform(-1, 1, size=(16, z_size))
        gen_samples = sess.run(
                       generator(input_z, input_size, n_units=g_hidden_size, reuse=True, alpha=alpha),
                       feed_dict={input_z: sample_z})
        samples.append(gen_samples)
        saver.save(sess, './checkpoints/generator.ckpt')

# Save training generator samples
with open('train_samples.pkl', 'wb') as f:
    pkl.dump(samples, f)

為了監(jiān)控訓(xùn)練苹熏,可以提取訓(xùn)練過程中的參數(shù)來識別訓(xùn)練結(jié)果碟贾。實際上在學(xué)習(xí)過程中可以發(fā)現(xiàn) GAN 的訓(xùn)練對于超參數(shù)的選擇十分敏感,并且在后續(xù)的 DCGAN 學(xué)習(xí)中轨域,作者們甚至通過調(diào)整 Adam 中的指數(shù)加權(quán)平均參數(shù) beta1 來實現(xiàn)較好的訓(xùn)練效果袱耽。Ian Goodfellow 在 Andrew Ng 的訪談里也提到自己現(xiàn)在 40% 的時間話在研究如何 Stablize GAN,當(dāng)時沒理解是什么意思干发,直到自己訓(xùn)練了 DCGAN 之后才知道原來 GAN 的訓(xùn)練對于超參數(shù)是如此的敏感朱巨。

def view_samples(epoch, samples):
    fig, axes = plt.subplots(figsize=(7,7), nrows=4, ncols=4, sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples[epoch]):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')
    
    return fig, axes

rows, cols = 10, 6
fig, axes = plt.subplots(figsize=(7,12), nrows=rows, ncols=cols, sharex=True, sharey=True)

for sample, ax_row in zip(samples[::int(len(samples)/rows)], axes):
    for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
        ax.imshow(img.reshape((28,28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
Generated result as the training goes

參考閱讀

  1. Tips and tricks to make GANs work

  2. Generative Adversarial Networks for beginners

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市枉长,隨后出現(xiàn)的幾起案子冀续,更是在濱河造成了極大的恐慌,老刑警劉巖必峰,帶你破解...
    沈念sama閱讀 219,110評論 6 508
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件洪唐,死亡現(xiàn)場離奇詭異,居然都是意外死亡吼蚁,警方通過查閱死者的電腦和手機(jī)凭需,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,443評論 3 395
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來肝匆,“玉大人粒蜈,你說我怎么就攤上這事∈趸#” “怎么了薪伏?”我有些...
    開封第一講書人閱讀 165,474評論 0 356
  • 文/不壞的土叔 我叫張陵,是天一觀的道長粗仓。 經(jīng)常有香客問我嫁怀,道長,這世上最難降的妖魔是什么借浊? 我笑而不...
    開封第一講書人閱讀 58,881評論 1 295
  • 正文 為了忘掉前任塘淑,我火速辦了婚禮,結(jié)果婚禮上蚂斤,老公的妹妹穿的比我還像新娘存捺。我一直安慰自己,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,902評論 6 392
  • 文/花漫 我一把揭開白布捌治。 她就那樣靜靜地躺著岗钩,像睡著了一般。 火紅的嫁衣襯著肌膚如雪肖油。 梳的紋絲不亂的頭發(fā)上兼吓,一...
    開封第一講書人閱讀 51,698評論 1 305
  • 那天,我揣著相機(jī)與錄音森枪,去河邊找鬼视搏。 笑死,一個胖子當(dāng)著我的面吹牛县袱,可吹牛的內(nèi)容都是我干的浑娜。 我是一名探鬼主播,決...
    沈念sama閱讀 40,418評論 3 419
  • 文/蒼蘭香墨 我猛地睜開眼式散,長吁一口氣:“原來是場噩夢啊……” “哼筋遭!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起杂数,我...
    開封第一講書人閱讀 39,332評論 0 276
  • 序言:老撾萬榮一對情侶失蹤宛畦,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后揍移,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體次和,經(jīng)...
    沈念sama閱讀 45,796評論 1 316
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,968評論 3 337
  • 正文 我和宋清朗相戀三年那伐,在試婚紗的時候發(fā)現(xiàn)自己被綠了踏施。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 40,110評論 1 351
  • 序言:一個原本活蹦亂跳的男人離奇死亡罕邀,死狀恐怖畅形,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情诉探,我是刑警寧澤日熬,帶...
    沈念sama閱讀 35,792評論 5 346
  • 正文 年R本政府宣布,位于F島的核電站肾胯,受9級特大地震影響竖席,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜敬肚,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,455評論 3 331
  • 文/蒙蒙 一毕荐、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧艳馒,春花似錦憎亚、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 32,003評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽蝶锋。三九已至,卻和暖如春斋日,著一層夾襖步出監(jiān)牢的瞬間牲览,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 33,130評論 1 272
  • 我被黑心中介騙來泰國打工恶守, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人贡必。 一個月前我還...
    沈念sama閱讀 48,348評論 3 373
  • 正文 我出身青樓兔港,卻偏偏與公主長得像,于是被迫代替她去往敵國和親仔拟。 傳聞我的和親對象是個殘疾皇子衫樊,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 45,047評論 2 355

推薦閱讀更多精彩內(nèi)容