Original article: http://blackblog.tech/2018/01/25/cDCGAN生成MNIST圖片/
Feel free to follow my personal blog: http://blackblog.tech
A few days ago, a final assignment at school asked us to solve a practical problem with machine learning. After a lot of thought and several abandoned topic ideas, I finally settled on a cDCGAN, i.e. a conditional deep convolutional generative adversarial network.
Why this topic?
Generative adversarial networks have been extremely popular in recent years: image colorization, video de-mosaicing, and NVIDIA's recent demos such as turning a white horse brown or day into night are all built on GANs.
The network described in the 2014 paper "Generative Adversarial Nets" is an unsupervised GAN and does not use convolution or transposed convolution.
Today we take the MNIST handwritten digit set as the dataset and implement a cDCGAN (conditional deep convolutional generative adversarial network) with TensorFlow.
Algorithm Description
Generative adversarial networks (Generative Adversarial Nets) are inspired by the two-player zero-sum game from game theory; the two players in the GAN model are the generator network (Generator) and the discriminator network (Discriminator). The generator G tries to capture the distribution of the sample data and, starting from noise z drawn from some distribution, produces samples that resemble the real training data as closely as possible. The discriminator D is normally a binary classifier that estimates the probability that a sample comes from the real data: it should output a high probability for real samples and a low probability for generated ones. In this article, D additionally receives the digit label as a condition, so that its real/fake judgment is made per class.
During training, one side is held fixed while the other side's network is updated, alternating back and forth. Throughout training both sides strive to optimize their own network, forming an adversarial competition, until the two reach a dynamic equilibrium. At that point the distribution of the data produced by the generator is almost identical to that of the real data, and the discriminator can no longer tell real from fake.
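As a point of reference (this is the standard objective from the 2014 paper, not something introduced in this post), the game the two networks play can be written as a minimax problem; the conditional variant used here simply feeds the label y to both networks:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\log D(x \mid y)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z \mid y) \mid y)\big)\big]$$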
The GAN in this article has two main parts, the generator network (Generator) and the discriminator network (Discriminator). Noise is fed into the generator and, through several transposed convolutions, turned into a 28x28x1 image X_fake. The real image X_real and the generated X_fake are then both fed into the discriminator, which applies several convolutions plus a sigmoid and computes the discriminator loss D_loss with a cross-entropy loss function; the generator loss G_loss is computed from the discriminator's output on the generated images. G_loss and D_loss are then used to update the parameters of the generator and the discriminator respectively.
Algorithm Flow
1. Feed in the noise z
2. Run the generator G to obtain X_fake = G(z)
3. Draw real data X_real from the dataset
4. Run the discriminator D to compute the real logits D_real_logits = D(X_real)
5. Run the discriminator D to compute the fake logits D_fake_logits = D(X_fake)
6. Using cross-entropy as the loss function, compute D_loss_real from D_real_logits
7. Using cross-entropy as the loss function, compute D_loss_fake from D_fake_logits
8. Compute the discriminator loss D_loss = D_loss_real + D_loss_fake
9. Using cross-entropy as the loss function, compute the generator loss G_loss (the concrete formulas are written out below this list)
10. Update the discriminator parameters with D_loss and the generator parameters with G_loss
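Written out explicitly (my own addition; this matches the sigmoid cross-entropy calls used in the code below), the losses from steps 6-9 are

$$L_D = -\,\mathbb{E}\big[\log D(X_{\mathrm{real}})\big] - \mathbb{E}\big[\log\big(1 - D(X_{\mathrm{fake}})\big)\big], \qquad L_G = -\,\mathbb{E}\big[\log D(X_{\mathrm{fake}})\big]$$

where D(·) denotes the sigmoid probability output of the discriminator and the expectations are averages over the minibatch.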
Network Structure
Generator
[Figure: generator network structure (original image not available)]
Discriminator
[Figure: discriminator network structure (original image not available)]
Dataset
MNIST... not much needs to be said about it.
Training Environment
OS: Windows 10
Framework: TensorFlow 1.2
CPU: Intel Core i5-4210H
GPU: NVIDIA GTX 960M 4GB (can't afford a better card...)
On to the code!
First, the definitions of some constants, including the learning rate, batch_size, the save paths, and so on.
import os, time, random, itertools
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import cv2
# directories for saving generated images
dirpath = 'GAN/'
model = 'GAN_MINIST'
if not os.path.isdir(dirpath):
    os.mkdir(dirpath)
if not os.path.isdir(dirpath + 'FakeImg'):
    os.mkdir(dirpath + 'FakeImg')
# initialization
IMAGE_SIZE = 28
onehot = np.eye(10)
noise_ = np.random.normal(0, 1, (10, 1, 1, 100))
fixed_noise_ = noise_
fixed_label_ = np.zeros((10, 1))
# fixed noise/labels used to display ten groups of images at the end:
# each row of the final 10x10 grid shows one digit (0 at the top through 9 at the bottom),
# and the ten columns reuse the same ten noise vectors
for i in range(9):
    fixed_noise_ = np.concatenate([fixed_noise_, noise_], 0)
    temp = np.ones((10, 1)) + i
    fixed_label_ = np.concatenate([fixed_label_, temp], 0)
fixed_label_ = onehot[fixed_label_.astype(np.int32)].reshape((100, 1, 1, 10))
batch_size = 100
# train for 30 epochs in total
step = 30
# a global step counter
global_step = tf.Variable(0, trainable=False)
# learning rate with exponential decay
lr = tf.train.exponential_decay(0.0002, global_step, 500, 0.95, staircase=True)
# load the dataset; batch size is 100
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True, reshape=[])
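A quick note on the learning-rate schedule: with staircase=True the learning rate is 0.0002 * 0.95^floor(global_step / 500), i.e. it is multiplied by 0.95 after every 500 discriminator updates (global_step is only incremented by D_optim further down).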
Definition of leaky_relu:
def leaky_relu(X, leak=0.2):
    f1 = 0.5 * (1 + leak)
    f2 = 0.5 * (1 - leak)
    return f1 * X + f2 * tf.abs(X)
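This is just Leaky ReLU written without a conditional: for X >= 0 we get (f1 + f2) * X = X, and for X < 0 we get (f1 - f2) * X = leak * X, which is exactly max(X, leak * X).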
Definition of the generator network:
def Generator(x, labels, Training=True, reuse=False):
    with tf.variable_scope('Generator', reuse=reuse):
        # parameter initializers
        W = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
        b = tf.constant_initializer(0.0)
        # concatenate the noise and the label along the channel axis
        concat = tf.concat([x, labels], 3)
        # first transposed convolution, 7x7 kernel, 256 output channels
        out_1 = tf.layers.conv2d_transpose(concat, 256, [7, 7], strides=(1, 1), padding='valid', kernel_initializer=W, bias_initializer=b)
        out_1 = tf.layers.batch_normalization(out_1, training=Training)  # batch norm
        out_1 = leaky_relu(out_1, 0.2)
        # second transposed convolution, 5x5 kernel, 128 output channels
        out_2 = tf.layers.conv2d_transpose(out_1, 128, [5, 5], strides=(2, 2), padding='same', kernel_initializer=W, bias_initializer=b)
        out_2 = tf.layers.batch_normalization(out_2, training=Training)  # batch norm
        out_2 = leaky_relu(out_2, 0.2)
        # third transposed convolution, 5x5 kernel, 1 output channel
        out_3 = tf.layers.conv2d_transpose(out_2, 1, [5, 5], strides=(2, 2), padding='same', kernel_initializer=W, bias_initializer=b)
        out_3 = tf.nn.tanh(out_3)
        return out_3
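For reference, the shapes along the way are: the 1x1x100 noise concatenated with the 1x1x10 label gives 1x1x110; the 'valid' 7x7 transposed convolution produces 7x7x256; the two stride-2 'same' transposed convolutions then give 14x14x128 and finally 28x28x1. The tanh output lies in [-1, 1], which is why the MNIST images are rescaled to that range before training.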
Definition of the discriminator network:
def Discriminator(x, real, Training=True, reuse=False):
    with tf.variable_scope('Discriminator', reuse=reuse):
        # parameter initializers
        W = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
        b = tf.constant_initializer(0.0)
        # concatenate the image and the label map along the channel axis
        concat = tf.concat([x, real], 3)
        # first convolution, 5x5 kernel, 128 output channels
        out_1 = tf.layers.conv2d(concat, 128, [5, 5], strides=(2, 2), padding='same', kernel_initializer=W, bias_initializer=b)
        out_1 = leaky_relu(out_1, 0.2)
        # second convolution, 5x5 kernel, 256 output channels
        out_2 = tf.layers.conv2d(out_1, 256, [5, 5], strides=(2, 2), padding='same', kernel_initializer=W, bias_initializer=b)
        out_2 = tf.layers.batch_normalization(out_2, training=Training)  # batch norm
        out_2 = leaky_relu(out_2, 0.2)
        # third convolution, 7x7 kernel, 1 output channel (the raw logit)
        out_3 = tf.layers.conv2d(out_2, 1, [7, 7], strides=(1, 1), padding='valid', kernel_initializer=W)
        prob = tf.nn.sigmoid(out_3)
        # return both the sigmoid probability and the raw logit (the losses use the logit)
        return prob, out_3
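Mirroring the generator, the discriminator input is the 28x28x1 image concatenated with a 28x28x10 label map (11 channels in total); the two stride-2 convolutions reduce this to 14x14x128 and then 7x7x256, and the final 'valid' 7x7 convolution yields a single 1x1x1 logit per sample.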
Displaying and saving the generated images:
def show_result(num_epoch, show=False, save=False, path=''):
    # generate a 10x10 grid from the fixed noise and labels defined above
    test_images = sess.run(G_noise, {noise: fixed_noise_, labels: fixed_label_, Training: False})
    size_figure_grid = 10
    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
    for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)
    for k in range(10 * 10):
        i = k // 10
        j = k % 10
        ax[i, j].cla()
        ax[i, j].imshow(np.reshape(test_images[k], (IMAGE_SIZE, IMAGE_SIZE)), cmap='gray')
    label = 'Step {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')
    if save:
        plt.savefig(path)
    if show:
        plt.show()
    else:
        plt.close()
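Because this function uses the global session and placeholders, it can only be called once the graph and session below exist. A call then looks like the following (the file name here is purely illustrative):

show_result(5, save=True, path=dirpath + 'FakeImg/epoch_5.png')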
Placeholders:
x = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE, 1))
noise = tf.placeholder(tf.float32, shape=(None, 1, 1, 100))
labels = tf.placeholder(tf.float32, shape=(None, 1, 1, 10))
real = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE, 10))
Training = tf.placeholder(dtype=tf.bool)
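A note on these placeholders: x holds the real images, noise the 1x1x100 noise vectors, and labels the one-hot digit labels reshaped to 1x1x10. real is the same one-hot label broadcast to a 28x28x10 map, which is what gets concatenated onto the image channels inside the discriminator (see how real_ is built in the training loop below).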
Losses and parameter updates:
# run the generator
G_noise = Generator(noise, labels, Training)
# run the discriminator on real and on generated images (reusing the same variables)
D_real, D_real_logits = Discriminator(x, real, Training)
D_fake, D_fake_logits = Discriminator(G_noise, real, Training, reuse=True)
# compute each network's loss
# discriminator loss on real images (target label 1)
Dis_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones([batch_size, 1, 1, 1])))
# discriminator loss on generated images (target label 0)
Dis_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros([batch_size, 1, 1, 1])))
# total discriminator loss
Dis_loss = Dis_loss_real + Dis_loss_fake
# generator loss: the generator wants the fakes to be classified as real
Gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones([batch_size, 1, 1, 1])))
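Note that Gen_loss uses target labels of 1 on the fake logits, i.e. the generator minimizes -log D(X_fake) rather than log(1 - D(X_fake)). This is the non-saturating form of the generator loss recommended in the original GAN paper, and it matches the G_loss formula written out earlier.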
# collect each network's trainable variables
tf_vars = tf.trainable_variables()
Dis_vars = [var for var in tf_vars if var.name.startswith('Discriminator')]
Gen_vars = [var for var in tf_vars if var.name.startswith('Generator')]
# parameter updates; the control dependency fixes the execution order so the
# batch-norm update ops run before each training step (see the note below)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    optim = tf.train.AdamOptimizer(lr, beta1=0.5)  # Adam with first-moment decay rate 0.5
    D_optim = optim.minimize(Dis_loss, global_step=global_step, var_list=Dis_vars)  # also increments global_step
    G_optim = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(Gen_loss, var_list=Gen_vars)
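The tf.control_dependencies(...UPDATE_OPS...) wrapper matters here because tf.layers.batch_normalization only refreshes its moving mean and variance through ops placed in the UPDATE_OPS collection; without this dependency those statistics would never be updated, and the Training=False passes (used when drawing the sample grids) would rely on stale values.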
Training:
# open a session
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
# rescale MNIST from [0, 1] to [-1, 1] to match the generator's tanh output
train_set = (mnist.train.images - 0.5) / 0.5
train_label = mnist.train.labels
for i in range(step):
    Gen_losses = []
    Dis_losses = []
    i_start_time = time.time()
    # shuffle the training set at the start of each epoch
    index = random.sample(range(0, train_set.shape[0]), train_set.shape[0])
    new_set = train_set[index]
    new_label = train_label[index]
    for j in range(new_set.shape[0] // batch_size):
        # update the discriminator
        x_ = new_set[j*batch_size:(j+1)*batch_size]
        label_ = new_label[j*batch_size:(j+1)*batch_size].reshape([batch_size, 1, 1, 10])
        real_ = label_ * np.ones([batch_size, IMAGE_SIZE, IMAGE_SIZE, 10])
        noise_ = np.random.normal(0, 1, (batch_size, 1, 1, 100))
        loss_d_, _ = sess.run([Dis_loss, D_optim], {x: x_, noise: noise_, real: real_, labels: label_, Training: True})
        # update the generator with fresh noise and random labels (digits 0-9; randint's upper bound is exclusive)
        noise_ = np.random.normal(0, 1, (batch_size, 1, 1, 100))
        y_ = np.random.randint(0, 10, (batch_size, 1))
        label_ = onehot[y_.astype(np.int32)].reshape([batch_size, 1, 1, 10])
        real_ = label_ * np.ones([batch_size, IMAGE_SIZE, IMAGE_SIZE, 10])
        loss_g_, _ = sess.run([Gen_loss, G_optim], {noise: noise_, x: x_, real: real_, labels: label_, Training: True})
        # record the losses during training
        errD_fake = Dis_loss_fake.eval({noise: noise_, labels: label_, real: real_, Training: False})
        errD_real = Dis_loss_real.eval({x: x_, labels: label_, real: real_, Training: False})
        errG = Gen_loss.eval({noise: noise_, labels: label_, real: real_, Training: False})
        Dis_losses.append(errD_fake + errD_real)
        Gen_losses.append(errG)
        if j % 10 == 0:
            pic = dirpath + 'FakeImg/' + model + str(i * new_set.shape[0] // batch_size + j + 1) + '_' + str(i + 1) + '.png'
            show_result((i + 1), save=True, path=pic)
    print('Discriminator loss: %.6f, Generator loss: %.6f' % (np.mean(Dis_losses), np.mean(Gen_losses)))
    pic = dirpath + 'FakeImg/' + model + str(i + 1) + '.png'
    show_result((i + 1), save=True, path=pic)
sess.close()
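Once training has finished (and before the sess.close() call above), the conditional part can be exercised directly by asking the generator for a specific digit. This is just an illustrative sketch, not part of the original script; it reuses the graph and placeholders defined above:

# sample ten images of the digit 7 from the trained generator (illustrative)
digit = 7
sample_noise = np.random.normal(0, 1, (10, 1, 1, 100))
sample_label = onehot[np.full((10, 1), digit)].reshape((10, 1, 1, 10))
samples = sess.run(G_noise, {noise: sample_noise, labels: sample_label, Training: False})
# samples has shape (10, 28, 28, 1) with values in [-1, 1]; rescale with (samples + 1) / 2 before display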
Generated Results
After 30 epochs:
[Figure: generated digit samples after 30 epochs (original image not available)]
Overall the results are decent: apart from the 9s being a bit hard to make out, the outlines of 0 through 8 are quite clear.
GANs have an extremely wide range of applications; in a few days I will write up a network that generates face images.