GAN: Generative Adversarial Network. Informally, it consists of two networks: a generator (G) network that produces images, and a discriminator (D) network that judges them.
The goal of a GAN is to generate an image on its own. For example, after processing a series of cat pictures, the G network can "draw" a cat picture by itself, as realistic as possible.
The D network does the judging: it is given both a real image and an image produced by the G network, and is trained until it can reliably pick out the "fake" image produced by G.
Back to the two networks: G keeps improving so that it can fool D, while D keeps improving so that it can spot the fakes more accurately. This mutually reinforcing, mutually adversarial relationship is what makes the network "adversarial".
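The pull in opposite directions can be made concrete with the standard GAN losses. A minimal NumPy sketch (the helper names `d_loss`/`g_loss` are mine, not from this post): D is penalized when it scores real images low or fakes high, while G is penalized when D scores its fakes low.

```python
import numpy as np

def d_loss(d_real, d_fake):
    # discriminator wants d_real -> 1 and d_fake -> 0
    return -np.log(d_real) - np.log(1 - d_fake)

def g_loss(d_fake):
    # generator wants the discriminator to output 1 on its fakes
    return -np.log(d_fake)

# A confident, correct discriminator has a small loss...
print(d_loss(0.9, 0.1))
# ...while one that cannot tell real from fake has a large one.
print(d_loss(0.5, 0.5))
```

Improving either network raises the other network's loss, which is exactly the adversarial dynamic described above.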
We can implement this with the MNIST handwritten-digit data available through TensorFlow.
The idea is as follows:
An image of random pixels passes through a fully connected layer and then a Leaky ReLU; to reduce overfitting it then goes through dropout, and finally through another fully connected layer with a tanh activation, producing a "fake" image:
def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
    with tf.variable_scope("generator", reuse=reuse):
        hidden1 = tf.layers.dense(noise_img, n_units)   # fully connected layer
        hidden1 = tf.maximum(alpha * hidden1, hidden1)  # Leaky ReLU
        # training=True so dropout is actually applied (it defaults to a no-op)
        hidden1 = tf.layers.dropout(hidden1, rate=0.2, training=True)
        logits = tf.layers.dense(hidden1, out_dim)
        outputs = tf.tanh(logits)                       # values in [-1, 1]
        return logits, outputs
The image to be judged goes through: fully connected layer --> Leaky ReLU --> fully connected layer --> sigmoid, yielding a probability between 0 and 1 that the image is real.
def get_discriminator(img, n_units, reuse=False, alpha=0.01):
    with tf.variable_scope("discriminator", reuse=reuse):
        hidden1 = tf.layers.dense(img, n_units)         # fully connected layer
        hidden1 = tf.maximum(alpha * hidden1, hidden1)  # Leaky ReLU
        logits = tf.layers.dense(hidden1, 1)
        outputs = tf.sigmoid(logits)                    # probability the image is real
        return logits, outputs
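Both networks implement Leaky ReLU with the `tf.maximum(alpha * h, h)` trick rather than a built-in op. A quick NumPy check of why this works (valid whenever `alpha` < 1): the maximum keeps positive values unchanged and scales negative values by `alpha`.

```python
import numpy as np

alpha = 0.01
x = np.array([-2.0, -0.5, 0.0, 1.0, 3.0])
# max(alpha*x, x): x wins when x > 0, alpha*x wins when x < 0
leaky = np.maximum(alpha * x, x)
print(leaky)  # [-0.02  -0.005  0.  1.  3.]
```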
When implementing this, we can first extract the images labeled 0 from the MNIST training data and store them in a list.
i = j = 0
while i < 5000:
    if mnist.train.labels[j] == 0:
        samples.append(mnist.train.images[j])
        i += 1
    j += 1
This way, training only ever sees images labeled 0.
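As an aside, when the labels are loaded as integers (which is the case here, since `read_data_sets` is called without `one_hot=True`), the same filtering can be done in one step with a NumPy boolean mask. A sketch with synthetic stand-in arrays in place of the MNIST tensors:

```python
import numpy as np

# Synthetic stand-ins for mnist.train.images / mnist.train.labels
images = np.random.rand(100, 784).astype(np.float32)
labels = np.random.randint(0, 10, size=100)

zeros = images[labels == 0]  # selects every label-0 row at once
print(zeros.shape)
```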
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np

mnist = input_data.read_data_sets("D:/python/MNIST_data/")
img = mnist.train.images[50]

def get_inputs(real_size, noise_size):
    real_img = tf.placeholder(tf.float32, [None, real_size], name="real_img")
    noise_img = tf.placeholder(tf.float32, [None, noise_size], name="noise_img")
    return real_img, noise_img

# generator
def get_generator(noise_img, n_units, out_dim, reuse=False, alpha=0.01):
    with tf.variable_scope("generator", reuse=reuse):
        hidden1 = tf.layers.dense(noise_img, n_units)   # fully connected layer
        hidden1 = tf.maximum(alpha * hidden1, hidden1)  # Leaky ReLU
        # training=True so dropout is actually applied (it defaults to a no-op)
        hidden1 = tf.layers.dropout(hidden1, rate=0.2, training=True)
        logits = tf.layers.dense(hidden1, out_dim)
        outputs = tf.tanh(logits)                       # values in [-1, 1]
        return logits, outputs

# discriminator
def get_discriminator(img, n_units, reuse=False, alpha=0.01):
    with tf.variable_scope("discriminator", reuse=reuse):
        hidden1 = tf.layers.dense(img, n_units)         # fully connected layer
        hidden1 = tf.maximum(alpha * hidden1, hidden1)  # Leaky ReLU
        logits = tf.layers.dense(hidden1, 1)
        outputs = tf.sigmoid(logits)                    # probability the image is real
        return logits, outputs
img_size = mnist.train.images[0].shape[0]  # 784
noise_size = 100
g_units = 128
d_units = 128
alpha = 0.01           # Leaky ReLU slope
learning_rate = 0.001
smooth = 0.1           # label-smoothing factor

tf.reset_default_graph()
real_img, noise_img = get_inputs(img_size, noise_size)

g_logits, g_outputs = get_generator(noise_img, g_units, img_size)
d_logits_real, d_outputs_real = get_discriminator(real_img, d_units)
d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, d_units, reuse=True)

# discriminator loss: push real images toward 1, generated images toward 0
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_real, labels=tf.ones_like(d_logits_real)) * (1 - smooth))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))
d_loss = tf.add(d_loss_real, d_loss_fake)

# generator loss: make the discriminator output 1 on generated images
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)) * (1 - smooth))
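The losses above rely on `tf.nn.sigmoid_cross_entropy_with_logits`, which computes `-z*log(sigmoid(x)) - (1-z)*log(1-sigmoid(x))` in a numerically stable form; the `(1 - smooth)` factor then scales down the loss on targets of 1, a simple variant of one-sided label smoothing. A NumPy check of the stable formula (my own re-implementation, for illustration only):

```python
import numpy as np

def sigmoid_ce(logits, labels):
    # stable form: max(x, 0) - x*z + log(1 + exp(-|x|))
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

x, z = 2.0, 1.0
sig = 1.0 / (1.0 + np.exp(-x))
naive = -z * np.log(sig) - (1 - z) * np.log(1 - sig)
print(sigmoid_ce(x, z), naive)  # the two values agree
```

Unlike the naive expression, the stable form never exponentiates a large positive number, so it does not overflow for extreme logits.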
train_vars = tf.trainable_variables()
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)

epochs = 5000
samples = []
n_sample = 10
losses = []

# collect 5000 training images labeled 0
i = j = 0
while i < 5000:
    if mnist.train.labels[j] == 0:
        samples.append(mnist.train.images[j])
        i += 1
    j += 1
print(len(samples))
size = samples[0].size
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    for e in range(epochs):
        batch_images = samples[e] * 2 - 1  # rescale [0, 1] pixels to [-1, 1] to match tanh
        batch_noise = np.random.uniform(-1, 1, size=noise_size)
        _ = sess.run(d_train_opt, feed_dict={real_img: [batch_images], noise_img: [batch_noise]})
        _ = sess.run(g_train_opt, feed_dict={noise_img: [batch_noise]})
    # sample one image from the trained generator (reuse=True reuses the trained weights)
    sample_noise = np.random.uniform(-1, 1, size=noise_size)
    g_logit, g_output = sess.run(get_generator(noise_img, g_units, img_size, reuse=True),
                                 feed_dict={noise_img: [sample_noise]})
    print(g_logit.size)
    g_output = (g_output + 1) / 2          # back to [0, 1] for display
    plt.imshow(g_output.reshape([28, 28]), cmap='Greys_r')
    plt.show()
Result:
As you can see, after 5000 iterations the image generated by the G network already roughly takes the shape of a 0.