cDCGAN生成MNIST圖片(條件深度卷積生成對抗網(wǎng)絡(luò))

原文歡迎關(guān)注http://blackblog.tech/2018/01/25/cDCGAN生成MNIST圖片/
歡迎關(guān)注我的個人博客 http://blackblog.tech

前幾日和媳,學(xué)校期末作業(yè)要求我們使用機器學(xué)習(xí)的方法解決一個實際問題击喂,思考了很久,嘗試做了很多選題弦悉,最終決定做一個cDCGAN瞳收,即條件深度卷積生成對抗網(wǎng)絡(luò)涮阔。
為什么做這個選題呢叮盘?
生成對抗網(wǎng)絡(luò)這幾年實在是火爆,圖片上色霹俺,視頻去馬賽克柔吼,包括英偉達最近展出的白馬變棕馬,白天變黑夜丙唧,都是使用生成對抗網(wǎng)絡(luò)實現(xiàn)的愈魏。
2014年"Generative Adversarial Nets"這篇論文中所提到的生成對抗網(wǎng)絡(luò)是一個無監(jiān)督的生成對抗網(wǎng)絡(luò),且沒有使用卷積與反卷積操作想际。
今天我們以MNIST手寫集為數(shù)據(jù)集培漏,使用tensorflow實現(xiàn)cDCGAN(條件深度卷積生成對抗網(wǎng)絡(luò))

算法描述

生成對抗網(wǎng)絡(luò)(Generative Adversarial Nets)啟發(fā)自博弈論中的兩人零和博弈,GAN模型中的兩位博弈方分別有生成網(wǎng)絡(luò)(Generator)與判別網(wǎng)絡(luò)(Discriminator)充當(dāng)胡本。當(dāng)生成網(wǎng)絡(luò)G捕捉到樣本數(shù)據(jù)分布牌柄,用服從某一分布的噪聲z生成一個類似真實訓(xùn)練數(shù)據(jù)的樣本,與真實樣本越接近越好侧甫;判別網(wǎng)絡(luò)D一般是一個二分類模型珊佣,在本文中D是一個多分類器,用于估計一個樣本來自于真實數(shù)據(jù)的概率闺骚,如果樣本來自于真實數(shù)據(jù)彩扔,則D輸出大概率,否則輸出小概率僻爽。本文中虫碉,判別網(wǎng)絡(luò)需要在此基礎(chǔ)上實現(xiàn)分類功能。

在訓(xùn)練的過程中胸梆,需要固定一方敦捧,更新另一方的網(wǎng)絡(luò)狀態(tài),如此交替進行碰镜。在整個訓(xùn)練的過程中兢卵,雙方都極力優(yōu)化自己的網(wǎng)絡(luò),從而形成競爭對抗绪颖,知道雙方達到一個動態(tài)的平衡秽荤。此時生成網(wǎng)絡(luò)訓(xùn)練出來的數(shù)據(jù)與真實數(shù)據(jù)的分布幾乎相同,判別網(wǎng)絡(luò)也無法再判斷出真?zhèn)巍?br> 本文中生成對抗網(wǎng)絡(luò)主要分為兩部分柠横,生成網(wǎng)絡(luò)(Generator)與判別網(wǎng)絡(luò)(Discriminator)窃款。向生成網(wǎng)絡(luò)內(nèi)輸入噪聲,通過多次反卷積的方式得到一個28x28x1的圖像作為X_fake牍氛,此時將真實的圖像X_real與生成器生成的X_fake放入判別網(wǎng)絡(luò)晨继,判別網(wǎng)絡(luò)使用多次卷積與Sigmoid函數(shù)并通過交叉熵函數(shù)計算出判別網(wǎng)絡(luò)的損失函數(shù)D_loss,通過判別網(wǎng)絡(luò)的損失函數(shù)D_loss計算得到生成網(wǎng)絡(luò)損失函數(shù)G_loss搬俊。使用G_loss與D_loss對生成網(wǎng)絡(luò)與判別網(wǎng)絡(luò)進行參數(shù)調(diào)整紊扬。

算法流程

1.輸入噪聲z
2.通過生成網(wǎng)絡(luò)G得到X_fake=G(z)
3.從數(shù)據(jù)集中獲取真實數(shù)據(jù)X_real
4.通過判別網(wǎng)絡(luò)D計算D(real logits)=D(X_real)
5.通過判別網(wǎng)絡(luò)D計算D(fake logits)=D(X_fake)
6.使用交叉熵函數(shù)做損失函數(shù)根據(jù)D(real logits)計算D(loss real)
7.使用交叉熵函數(shù)做損失函數(shù)根據(jù)D(fake logits)計算D(loss fake)
8.計算判別網(wǎng)絡(luò)損失函數(shù)D_loss=D(loss real)+ D_(loss fake)
9.使用交叉熵函數(shù)做損失函數(shù)計算生成網(wǎng)絡(luò)損失函數(shù)G_loss
10.使用D_loss對判別網(wǎng)絡(luò)進行參數(shù)調(diào)整蜒茄,使用G_loss對生成網(wǎng)絡(luò)參數(shù)進行調(diào)整

網(wǎng)絡(luò)結(jié)構(gòu)

生成網(wǎng)絡(luò)

[圖片上傳失敗...(image-54149c-1530238358410)]

判別網(wǎng)絡(luò)

[圖片上傳失敗...(image-6bc57d-1530238358410)]

數(shù)據(jù)集

MNIST.....
就不多說啥了

訓(xùn)練環(huán)境

系統(tǒng):Windows 10
框架:tensorflow 1.2
CPU:Intel core i5-4210H
GPU:Nvidia GTX 960M 4G(買不起顯卡........)

上代碼!

一些常量的定義餐屎,包括學(xué)校率檀葛,batch_size,保存的路徑等等

import os, time, random,itertools
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import cv2
# 保存圖片
dirpath = 'GAN/'
model = 'GAN_MINIST'
if not os.path.isdir(dirpath):
    os.mkdir(dirpath)
if not os.path.isdir(dirpath + 'FakeImg'):
    os.mkdir(dirpath + 'FakeImg')
# 初始化
IMAGE_SIZE = 28
onehot = np.eye(10)
noise_ = np.random.normal(0, 1, (10, 1, 1, 100))
fixed_noise_ = noise_
fixed_label_ = np.zeros((10, 1))
#用于最后顯示十組圖像
for i in range(9):
    fixed_noise_ = np.concatenate([fixed_noise_, noise_], 0)
    temp = np.ones((10, 1)) + I
    fixed_label_ = np.concatenate([fixed_label_, temp], 0)
fixed_label_ = onehot[fixed_label_.astype(np.int32)].reshape((100, 1, 1, 10))
batch_size = 100
#一共迭代20次
step = 30
#設(shè)置一個全局的計數(shù)器
global_step = tf.Variable(0, trainable=False)
#設(shè)置學(xué)習(xí)率
lr = tf.train.exponential_decay(0.0002, global_step, 500, 0.95, staircase=True)
#加載數(shù)據(jù)集Batch大小:100
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True, reshape=[])

leaky_relu的定義

def leaky_relu(X, leak=0.2):
    f1 = 0.5 * (1 + leak)
    f2 = 0.5 * (1 - leak)
    return f1 * X + f2 * tf.abs(X)

生成網(wǎng)絡(luò)的定義:

def Generator(x, labels, Training=True, reuse=False):
    with tf.variable_scope('Generator', reuse=reuse):
        #初始化參數(shù)
        W = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
        b = tf.constant_initializer(0.0)
        #把數(shù)據(jù)和標(biāo)簽進行連接
        concat = tf.concat([x, labels], 3)
        #第一次反卷積,卷積核大小為7*7啤挎,輸出維度256
        out_1 = tf.layers.conv2d_transpose(concat, 256, [7, 7], strides=(1, 1), padding='valid', kernel_initializer=W, bias_initializer=b)
        out_1 = tf.layers.batch_normalization(out_1, training=Training)#batch norm
        out_1 = leaky_relu(out_1, 0.2)
         #第二次反卷機驻谆,卷積核大小為5*5,輸出維度128
        out_2 = tf.layers.conv2d_transpose(out_1, 128, [5, 5], strides=(2, 2), padding='same', kernel_initializer=W, bias_initializer=b)
        out_2 = tf.layers.batch_normalization(out_2, training=Training)#batch norm
        out_2 = leaky_relu(out_2, 0.2)
         #第三次反卷機庆聘,卷積核大小5*5,輸出維度1
        out_3 = tf.layers.conv2d_transpose(out_2, 1, [5, 5], strides=(2, 2), padding='same', kernel_initializer=W, bias_initializer=b)
        out_3 = tf.nn.tanh(out_3)
        return out_3

判別網(wǎng)絡(luò)的定義

def Discriminator(x, real, Training=True, reuse=False):
    with tf.variable_scope('Discriminator', reuse=reuse):
        #初始化參數(shù)
        W = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
        b = tf.constant_initializer(0.0)
        #把數(shù)據(jù)和標(biāo)簽進行連接
        concat = tf.concat([x, real], 3)
        #第一次卷積 卷積核為5*5 輸出維度為128
        out_1 = tf.layers.conv2d(concat, 128, [5, 5], strides=(2, 2), padding='same', kernel_initializer=W, bias_initializer=b)
        out_1 = leaky_relu(out_1, 0.2)
        # 第二次卷積 卷積核為5*5 輸出維度256
        out_2 = tf.layers.conv2d(out_1, 256, [5, 5], strides=(2, 2), padding='same', kernel_initializer=W, bias_initializer=b)
        out_2 = tf.layers.batch_normalization(out_2, training=Training)#batch norm
        out_2 = leaky_relu(out_2, 0.2)
        #第三次卷積勺卢,卷積和為7*7伙判,輸出維度為1
        out_3 = tf.layers.conv2d(out_2, 1, [7, 7], strides=(1, 1), padding='valid', kernel_initializer=W)
        logits = tf.nn.sigmoid(out_3)
        return logits, out_3

輸出圖片

def show_result(num_epoch, show = False, save = False, path):
    test_images = sess.run(G_noise, {noise: fixed_noise_, labels: fixed_label_, Training: False})
    size_figure_grid = 10
    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
    for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)
    for k in range(10*10):
        i = k // 10
        j = k % 10
        ax[i, j].cla()
        ax[i, j].imshow(np.reshape(test_images[k], (IMAGE_SIZE, IMAGE_SIZE)), cmap='gray')
    label = 'Step {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')
    if save:
        plt.savefig(path)
    if show:
        plt.show()
    else:
        plt.close()

placeholder

x = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE, 1))
noise = tf.placeholder(tf.float32, shape=(None, 1, 1, 100))
labels = tf.placeholder(tf.float32, shape=(None, 1, 1, 10))
real = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE, 10))
Training = tf.placeholder(dtype=tf.bool)

調(diào)整參數(shù)

# 運行生成網(wǎng)絡(luò)哦
G_noise = Generator(noise, labels, Training)
# 運行判別網(wǎng)絡(luò)
D_real, D_real_logits = Discriminator(x, real, Training)
D_fake, D_fake_logits = Discriminator(G_noise, real, Training, reuse=True)
# 計算每個網(wǎng)絡(luò)的損失函數(shù)
#算判別器真值的損失函數(shù)
Dis_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones([batch_size, 1, 1, 1])))
#算判別器噪聲生成圖片的損失函數(shù)
Dis_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros([batch_size, 1, 1, 1])))
#損失函數(shù)求和
Dis_loss = Dis_loss_real + Dis_loss_fake
#計算生成器的損失函數(shù)
Gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones([batch_size, 1, 1, 1])))
# 提取每個網(wǎng)絡(luò)的變量
tf_vars = tf.trainable_variables()
Dis_vars = [var for var in tf_vars if var.name.startswith('Discriminator')]
Gen_vars = [var for var in tf_vars if var.name.startswith('Generator')]
# 調(diào)整參數(shù) 設(shè)計是用來控制計算流圖的,給圖中的某些計算指定順序
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    optim = tf.train.AdamOptimizer(lr, beta1=0.5)#尋找全局最優(yōu)點的優(yōu)化算法黑忱,引入了二次方梯度校正 衰減率0.5
    D_optim = optim.minimize(Dis_loss, global_step=global_step, var_list=Dis_vars)#優(yōu)化更新訓(xùn)練的模型參數(shù)宴抚,也可以為全局步驟(global step)計數(shù)
    G_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(Gen_loss, var_list=Gen_vars)#尋找全局最優(yōu)點的優(yōu)化算法,引入了二次方梯度校正 衰減率0.5

運行

# 開啟一個session甫煞,
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
#對MNIST做一下處理
train_set = (mnist.train.images - 0.5) / 0.5
train_label = mnist.train.labels

for i in range(step):
    Gen_losses = []
    Dis_losses = []
    i_start_time = time.time()
    index = random.sample(range(0, train_set.shape[0]), train_set.shape[0])
    new_set = train_set[index]
    new_label = train_label[index]
    for j in range(new_set.shape[0] // batch_size):
        #對判別器進行更新
        x_ = new_set[j*batch_size:(j+1)*batch_size]
        label_ = new_label[j*batch_size:(j+1)*batch_size].reshape([batch_size, 1, 1, 10])
        real_ = label_ * np.ones([batch_size, IMAGE_SIZE, IMAGE_SIZE, 10])
        noise_ = np.random.normal(0, 1, (batch_size, 1, 1, 100))
        loss_d_, _ = sess.run([Dis_loss, D_optim], {x: x_, noise: noise_, real: real_, labels: label_, Training: True})
        #對生成器進行更新
        noise_ = np.random.normal(0, 1, (batch_size, 1, 1, 100))
        y_ = np.random.randint(0, 9, (batch_size, 1))
        label_ = onehot[y_.astype(np.int32)].reshape([batch_size, 1, 1, 10])
        real_ = label_ * np.ones([batch_size, IMAGE_SIZE, IMAGE_SIZE, 10])
        loss_g_, _ = sess.run([Gen_loss, G_optim], {noise: noise_, x: x_, real: real_, labels: label_, Training: True})
        #計算訓(xùn)練過程中的損失函數(shù)
        errD_fake = Dis_loss_fake.eval({noise: noise_, labels: label_, real: real_, Training: False})
        errD_real = Dis_loss_real.eval({x: x_, labels: label_, real: real_, Training: False})
        errG = Gen_loss.eval({noise: noise_, labels: label_, real: real_, Training: False})
        Dis_losses.append(errD_fake + errD_real)
        Gen_losses.append(errG)
        if(j%10==0):
            pic = dirpath + 'FakeImg/' + model + str(i *new_set.shape[0] // batch_size + j+1) + '_' +str(i + 1) + '.png'
            show_result((i + 1), save=True, path=pic)
    print('判別器損失函數(shù): %.6f, 生成器損失函數(shù): %.6f' % np.mean(Dis_losses), np.mean(Gen_losses))
    pic = dirpath + 'FakeImg/' + model + str(i + 1) + '.png'
    show_result((i + 1), save=True, path=pic)
sess.close()

生成結(jié)果

迭代了30次
[圖片上傳失敗...(image-854000-1530238441721)]

總體的效果還是可以的菇曲,除了9有點看不清之外,0-8的輪廓還是很清晰的抚吠。
GAN的用途非常廣泛常潮,過幾天在寫一個生成臉部圖片的網(wǎng)絡(luò)。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末楷力,一起剝皮案震驚了整個濱河市喊式,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌萧朝,老刑警劉巖岔留,帶你破解...
    沈念sama閱讀 217,185評論 6 503
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異检柬,居然都是意外死亡献联,警方通過查閱死者的電腦和手機,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,652評論 3 393
  • 文/潘曉璐 我一進店門何址,熙熙樓的掌柜王于貴愁眉苦臉地迎上來里逆,“玉大人,你說我怎么就攤上這事头朱≡吮” “怎么了?”我有些...
    開封第一講書人閱讀 163,524評論 0 353
  • 文/不壞的土叔 我叫張陵项钮,是天一觀的道長班眯。 經(jīng)常有香客問我希停,道長,這世上最難降的妖魔是什么署隘? 我笑而不...
    開封第一講書人閱讀 58,339評論 1 293
  • 正文 為了忘掉前任宠能,我火速辦了婚禮,結(jié)果婚禮上磁餐,老公的妹妹穿的比我還像新娘违崇。我一直安慰自己,他們只是感情好诊霹,可當(dāng)我...
    茶點故事閱讀 67,387評論 6 391
  • 文/花漫 我一把揭開白布羞延。 她就那樣靜靜地躺著,像睡著了一般脾还。 火紅的嫁衣襯著肌膚如雪伴箩。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,287評論 1 301
  • 那天鄙漏,我揣著相機與錄音嗤谚,去河邊找鬼。 笑死怔蚌,一個胖子當(dāng)著我的面吹牛巩步,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播桦踊,決...
    沈念sama閱讀 40,130評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼椅野,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了钞钙?” 一聲冷哼從身側(cè)響起鳄橘,我...
    開封第一講書人閱讀 38,985評論 0 275
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎芒炼,沒想到半個月后瘫怜,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,420評論 1 313
  • 正文 獨居荒郊野嶺守林人離奇死亡本刽,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,617評論 3 334
  • 正文 我和宋清朗相戀三年鲸湃,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片子寓。...
    茶點故事閱讀 39,779評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡暗挑,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出斜友,到底是詐尸還是另有隱情炸裆,我是刑警寧澤,帶...
    沈念sama閱讀 35,477評論 5 345
  • 正文 年R本政府宣布鲜屏,位于F島的核電站烹看,受9級特大地震影響国拇,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜惯殊,卻給世界環(huán)境...
    茶點故事閱讀 41,088評論 3 328
  • 文/蒙蒙 一酱吝、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧土思,春花似錦务热、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,716評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至址愿,卻和暖如春该镣,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背响谓。 一陣腳步聲響...
    開封第一講書人閱讀 32,857評論 1 269
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留省艳,地道東北人娘纷。 一個月前我還...
    沈念sama閱讀 47,876評論 2 370
  • 正文 我出身青樓,卻偏偏與公主長得像跋炕,于是被迫代替她去往敵國和親赖晶。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,700評論 2 354

推薦閱讀更多精彩內(nèi)容