GAN的簡介
近年來，GAN（生成對抗式網絡）成功地應用于圖像生成、圖像編輯和表達學習等方面。最小化對抗損失使得生成的圖像看起來真實。GAN的基本原理為：
- 生成器G是生成圖片的網絡，接收一個隨機的噪聲z，生成圖片G(z)。其目標是盡量生成真實的圖片去欺騙判別網絡D。
- 判別器D用來判別一張圖片是否為真實。輸入一張圖片x，輸出D(x)為x是真實圖片的概率。其目的是盡量把生成器生成的圖片和真實的圖片區別出來。
在理想情況下，生成器可以生成足以以假亂真的圖片，而判別器難以辨別生成器生成的圖片是否為真。
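GAN的損失函數為：
$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$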
CycleGAN原理
圖像與圖像之間的變換
在傳統的CNN方法中，圖像與圖像之間的變換是通過CNN來學習轉移參數。
而本文的CycleGAN算法可以直接從一個圖像生成另一個圖像來實現圖像之間的變換。
CycleGAN
目的：學習域X與域Y之間的映射關係。在CycleGAN模型中包括兩個映射：X->Y、Y->X，如下圖所示。
在該網絡中，存在兩個域之間分別轉換的生成器，以及每個生成器對應的判別器。目標函數中包括兩項：
- 對抗損失：用來控制生成的圖像為目標域的圖像。
- cycle loss：為了防止兩個生成器之間相互矛盾。
在本項目中用來實現男女性別兩個域之間的轉換。
代碼解析
### generator
conv(7, 7, 32)
conv(3, 3, 64)
conv(3, 3, 128)
res_block * 6
deconv(3, 3, 64)
deconv(3, 3, 32)
conv(7, 7, 3)
### discriminator
conv(3, 3, 64)
conv(3, 3, 128)
conv(3, 3, 256)
conv(3, 3, 512)
conv(4, 4, 512)
### resnet_block
def build_resnet_block(inputres, dim, name="resnet", padding="REFLECT"):
    """Build a residual block made of two padded 3x3 convolutions.

    The input is padded, convolved ("c1"), padded again and convolved a
    second time ("c2", without ReLU). The original input is then added to
    that result and a ReLU is applied to the sum.

    Args:
        inputres: Input tensor; also used as the residual (skip) term.
            Presumably a 4-D batch/height/width/channel tensor, since the
            padding spec has four entries -- confirm against callers.
        dim: Number of output channels for both convolutions.
        name: Variable scope that wraps the whole block.
        padding: Mode forwarded to ``tf.pad`` (e.g. "REFLECT").

    Returns:
        ``relu(conv2(conv1(inputres)) + inputres)``.
    """
    # Pad axes 1 and 2 by one element on each side; axes 0 and 3 are left
    # untouched. Each padded tensor then feeds a 3x3 "VALID" convolution.
    border = [[0, 0], [1, 1], [1, 1], [0, 0]]

    with tf.variable_scope(name):
        padded_in = tf.pad(inputres, border, padding)
        first = layers.general_conv2d(
            padded_in, dim, 3, 3, 1, 1, 0.02, "VALID", "c1")

        padded_mid = tf.pad(first, border, padding)
        second = layers.general_conv2d(
            padded_mid, dim, 3, 3, 1, 1, 0.02, "VALID", "c2", do_relu=False)

        return tf.nn.relu(second + inputres)
### generator
def build_generator_resnet_9blocks_tf(inputgen, name="generator", skip=False):
    """Build the ResNet-style generator with nine residual blocks.

    Pipeline: padded 7x7 convolution ("c1"), two stride-2 3x3 convolutions
    ("c2", "c3"), nine residual blocks ("r1".."r9"), two stride-2
    transposed convolutions ("c4", "c5"), and a final 7x7 convolution
    ("c6") without normalisation or ReLU, followed by tanh ("t1").

    Args:
        inputgen: Input image tensor.
        name: Variable scope that wraps the whole generator.
        skip: When exactly ``True``, the input is added to the last
            convolution's output before the tanh.

    Returns:
        The tanh-activated output tensor.

    Note:
        Relies on the module-level names ``ngf``, ``BATCH_SIZE`` and
        ``IMG_CHANNELS``. The output shapes handed to the transposed
        convolutions are hard-coded to 128x128 and 256x256.
    """
    with tf.variable_scope(name):
        big_kernel = 7
        small_kernel = 3
        pad_mode = "REFLECT"

        # Pad axes 1 and 2 by `small_kernel` (3) on each side before the
        # 7x7 convolution, which uses general_conv2d's default padding.
        padded = tf.pad(
            inputgen,
            [[0, 0], [small_kernel, small_kernel],
             [small_kernel, small_kernel], [0, 0]],
            pad_mode)

        # Encoder: channel count goes ngf -> 2*ngf -> 4*ngf.
        enc1 = layers.general_conv2d(
            padded, ngf, big_kernel, big_kernel, 1, 1, 0.02, name="c1")
        enc2 = layers.general_conv2d(
            enc1, ngf * 2, small_kernel, small_kernel, 2, 2, 0.02,
            "SAME", "c2")
        enc3 = layers.general_conv2d(
            enc2, ngf * 4, small_kernel, small_kernel, 2, 2, 0.02,
            "SAME", "c3")

        # Transformer: nine chained residual blocks scoped "r1".."r9".
        features = enc3
        for block_idx in range(1, 10):
            features = build_resnet_block(
                features, ngf * 4, "r%d" % block_idx, pad_mode)

        # Decoder: two stride-2 transposed convolutions, then the output conv.
        dec1 = layers.general_deconv2d(
            features, [BATCH_SIZE, 128, 128, ngf * 2], ngf * 2,
            small_kernel, small_kernel, 2, 2, 0.02, "SAME", "c4")
        dec2 = layers.general_deconv2d(
            dec1, [BATCH_SIZE, 256, 256, ngf], ngf,
            small_kernel, small_kernel, 2, 2, 0.02, "SAME", "c5")
        logits = layers.general_conv2d(
            dec2, IMG_CHANNELS, big_kernel, big_kernel, 1, 1, 0.02,
            "SAME", "c6", do_norm=False, do_relu=False)

        # Identity check on purpose: only the literal True enables the skip.
        pre_activation = inputgen + logits if skip is True else logits
        return tf.nn.tanh(pre_activation, "t1")
### discriminator
def discriminator_tf(inputdisc, name="discriminator"):
    """Build the five-layer convolutional discriminator.

    All layers use 4x4 kernels with "SAME" padding. Layers "c1".."c4" use
    a leaky ReLU with factor 0.2; "c1" skips normalisation. The final
    layer "c5" produces a single output channel with neither
    normalisation nor activation.

    Args:
        inputdisc: Input image tensor to be judged.
        name: Variable scope that wraps the whole discriminator.

    Returns:
        The raw (un-activated) single-channel output of layer "c5".

    Note:
        Relies on the module-level name ``ndf`` for the base channel count.
    """
    with tf.variable_scope(name):
        kernel = 4
        leak = 0.2

        hidden = layers.general_conv2d(
            inputdisc, ndf, kernel, kernel, 2, 2, 0.02, "SAME", "c1",
            do_norm=False, relufactor=leak)

        # (scope, channel multiplier, stride) for the normalised layers.
        middle_layers = (("c2", 2, 2), ("c3", 4, 2), ("c4", 8, 1))
        for scope, multiplier, stride in middle_layers:
            hidden = layers.general_conv2d(
                hidden, ndf * multiplier, kernel, kernel, stride, stride,
                0.02, "SAME", scope, relufactor=leak)

        return layers.general_conv2d(
            hidden, 1, kernel, kernel, 1, 1, 0.02, "SAME", "c5",
            do_norm=False, do_relu=False)
### layers.py
import tensorflow as tf
def lrelu(x, leak=0.2, name="lrelu", alt_relu_impl=False):
    """Apply a leaky ReLU to ``x``.

    Args:
        x: Input tensor.
        leak: Slope multiplier used for the ``leak * x`` term.
        name: Variable scope under which the ops are created.
        alt_relu_impl: Selects the formulation. When false, returns
            ``max(x, leak * x)``. When true, returns
            ``0.5 * (1 + leak) * x + 0.5 * (1 - leak) * |x|``, which
            evaluates to ``x`` for non-negative inputs and ``leak * x``
            for negative ones.

    Returns:
        The activated tensor.
    """
    with tf.variable_scope(name):
        if not alt_relu_impl:
            return tf.maximum(x, leak * x)

        linear_coeff = 0.5 * (1 + leak)
        abs_coeff = 0.5 * (1 - leak)
        return linear_coeff * x + abs_coeff * abs(x)
def instance_norm(x):
    """Normalise ``x`` over axes 1 and 2, then apply a learned affine map.

    Mean and variance are computed over axes 1 and 2 with the reduced
    dimensions kept. The normalised tensor is multiplied by a trainable
    ``scale`` (initialised from a truncated normal with mean 1.0 and
    stddev 0.02) and shifted by a trainable ``offset`` (initialised to
    0.0). Both variables have one entry per element of the last axis and
    live in the "instance_norm" variable scope.

    Args:
        x: Input tensor. Presumably 4-D batch/height/width/channel, so
            that the statistics are per sample and per channel -- confirm
            against callers.

    Returns:
        ``scale * (x - mean) / sqrt(var + 1e-5) + offset``.
    """
    with tf.variable_scope("instance_norm"):
        epsilon = 1e-5  # keeps the square root away from zero
        per_channel = [x.get_shape()[-1]]

        mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)

        scale = tf.get_variable(
            'scale', per_channel,
            initializer=tf.truncated_normal_initializer(
                mean=1.0, stddev=0.02))
        offset = tf.get_variable(
            'offset', per_channel,
            initializer=tf.constant_initializer(0.0))

        normalized = tf.div(x - mean, tf.sqrt(var + epsilon))
        return scale * normalized + offset
def general_conv2d(inputconv, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1, stddev=0.02,
                   padding="VALID", name="conv2d", do_norm=True, do_relu=True,
                   relufactor=0):
    """Convolution followed by optional instance norm and (leaky) ReLU.

    Args:
        inputconv: Input tensor.
        o_d: Number of output channels.
        f_h: Kernel height.
        f_w: Kernel width.
        s_h: Stride along the height axis.
        s_w: Stride along the width axis.
        stddev: Standard deviation of the truncated-normal weight
            initialiser.
        padding: Padding mode forwarded to the convolution
            ("VALID" or "SAME").
        name: Variable scope that wraps the layer.
        do_norm: Apply ``instance_norm`` after the convolution.
        do_relu: Apply an activation after the (optional) normalisation.
        relufactor: ``0`` selects a plain ReLU; any other value is used as
            the leak of ``lrelu``.

    Returns:
        The output tensor of the layer.
    """
    with tf.variable_scope(name):
        # Pass both kernel and stride dimensions as [height, width].
        # Previously only f_w and s_w were forwarded, so f_h and s_h were
        # silently ignored and non-square kernels/strides were impossible.
        # Results are unchanged whenever f_h == f_w and s_h == s_w.
        conv = tf.contrib.layers.conv2d(
            inputconv, o_d, [f_h, f_w], [s_h, s_w], padding,
            activation_fn=None,
            weights_initializer=tf.truncated_normal_initializer(
                stddev=stddev),
            biases_initializer=tf.constant_initializer(0.0))

        if do_norm:
            conv = instance_norm(conv)

        if do_relu:
            if relufactor == 0:
                conv = tf.nn.relu(conv, "relu")
            else:
                conv = lrelu(conv, relufactor, "lrelu")

        return conv
def general_deconv2d(inputconv, outshape, o_d=64, f_h=7, f_w=7, s_h=1, s_w=1,
                     stddev=0.02, padding="VALID", name="deconv2d",
                     do_norm=True, do_relu=True, relufactor=0):
    """Transposed convolution with optional instance norm and (leaky) ReLU.

    Args:
        inputconv: Input tensor.
        outshape: Accepted for interface compatibility; this function
            does not read it.
        o_d: Number of output channels.
        f_h: Kernel height.
        f_w: Kernel width.
        s_h: Stride along the height axis.
        s_w: Stride along the width axis.
        stddev: Standard deviation of the truncated-normal weight
            initialiser.
        padding: Padding mode forwarded to the transposed convolution.
        name: Variable scope that wraps the layer.
        do_norm: Apply ``instance_norm`` after the transposed convolution.
        do_relu: Apply an activation after the (optional) normalisation.
        relufactor: ``0`` selects a plain ReLU; any other value is used as
            the leak of ``lrelu``.

    Returns:
        The output tensor of the layer.
    """
    with tf.variable_scope(name):
        kernel_size = [f_h, f_w]
        strides = [s_h, s_w]
        weight_init = tf.truncated_normal_initializer(stddev=stddev)
        bias_init = tf.constant_initializer(0.0)

        deconv = tf.contrib.layers.conv2d_transpose(
            inputconv, o_d, kernel_size, strides, padding,
            activation_fn=None,
            weights_initializer=weight_init,
            biases_initializer=bias_init)

        if do_norm:
            deconv = instance_norm(deconv)

        if not do_relu:
            return deconv
        if relufactor == 0:
            return tf.nn.relu(deconv, "relu")
        return lrelu(deconv, relufactor, "lrelu")
測試的結果：