The code below is an implementation of the paper "Globally and Locally Consistent Image Completion". Paper link: http://xueshu.baidu.com/usercenter/paper/show?paperid=ea74830570062151f14abfb1fe89bb33&site=xueshu_se&hitarticle=1
A quick read-through of the paper is available in my other article: http://www.reibang.com/p/12da271c8bf8
The dataset is the CelebA face dataset, which can be downloaded from: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
Framework: TensorFlow 1.11.0
Image size: 128 * 128 (the paper uses 256 * 256)
The authors trained on four K80 GPUs for two months. I have no comparable hardware and cannot run training at that scale, so although the code below runs, some parameters are set from experience rather than validated experiments; use them with caution.
Since I have already written about the ideas behind the paper, I will not walk through the implementation rationale again here. If anything is unclear, please refer to the paper and the article linked above. The code is commented as thoroughly as possible, so it should not be hard to follow.
make_data.py
"""
Data preprocessing: read the images into .npy files so that the individual image files do not have to be
read one by one every time, which speeds up data loading considerably.
.npy is NumPy's binary array format.
Dataset: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
"""
import glob
import cv2
import numpy as np
image_size = 128
train_ratio = 0.8
# Collect the paths of all images. Do not keep the data inside the project directory, or the IDE may become sluggish while indexing it.
paths = glob.glob(r'D:\bigdata\img_align_celeba/*.jpg')
x = []  # list of image arrays
# Read the images; to keep training fast, only the first 1000 images are processed
for img_path in paths[:1000]:
'''
If cv2 gives no code completion, uninstall and reinstall it:
uninstall: pip uninstall opencv-python
reinstall without cache: pip --no-cache-dir install opencv-python -i http://pypi.douban.com/simple --trusted-host pypi.douban.com
install the contrib package: pip --no-cache-dir install opencv-contrib-python -i http://pypi.douban.com/simple --trusted-host pypi.douban.com
'''
img = cv2.imread(img_path)  # image as a matrix, shape: (218, 178, 3)
# resize the image (interpolation)
img = cv2.resize(img, (image_size, image_size))
# convert the colour space (BGR -> RGB) for later steps such as generating mask images
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
x.append(img)
x = np.array(x, dtype=np.uint8)  # uint8 is used to save storage space
# shuffle the images
np.random.shuffle(x)
p = int(train_ratio * len(x))
x_train = x[:p]
x_test = x[p:]
np.save(r'D:\demos\image processing\demo\data_my\x_train.npy', x_train)
np.save(r'D:\demos\image processing\demo\data_my\x_test.npy', x_test)
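As a quick sanity check (a minimal sketch using the same paths and the 1000-image / 0.8 split assumed above), the saved arrays can be loaded back and their shapes verified before moving on to training:
import numpy as np
x_train = np.load(r'D:\demos\image processing\demo\data_my\x_train.npy')
x_test = np.load(r'D:\demos\image processing\demo\data_my\x_test.npy')
# With 1000 source images and train_ratio = 0.8 this should print
# (800, 128, 128, 3) uint8 and (200, 128, 128, 3) uint8.
print(x_train.shape, x_train.dtype)
print(x_test.shape, x_test.dtype)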
network_build.py
"""
Defines all the components needed to build and train the network
"""
import tensorflow as tf
class Network:
def __init__(self, x, mask, local_x, global_completion, local_completion, is_training, batch_size):
"""
:param x: the input images
:param mask: marks the region to repair; the missing region is 1 and everything else is 0, same spatial size as the full image
:param local_x: the patch cropped from the original image
:param global_completion: the completed image produced by the completion network (fed to the global discriminator)
:param local_completion: the completed patch produced by the completion network (fed to the local discriminator)
:param is_training:
:param batch_size:
"""
self.batch_size = batch_size
# x * (1 - mask) "digs the hole" in the input image: pixels inside the hole become 0, so the generator receives an image with a hole
self.imitation = self.generator(x * (1 - mask), is_training)
# the completed image uses the generated pixels inside the hole and the original pixels everywhere else
self.completion = self.imitation * mask + x * (1 - mask)
# with real image data as input, the discriminator creates (and later updates) its own variables
self.real = self.discriminator(x, local_x, reuse=False)
# with images from the completion network as input, the discriminator must reuse the variables created for the real images, hence reuse=True
self.fake = self.discriminator(global_completion, local_completion, reuse=True)
self.g_loss = self.calc_g_loss(x, self.completion)
self.d_loss = self.calc_d_loss(self.real, self.fake)
"""
tf.get_collection(key, scope=None)
returns a list of all elements in the collection named `key`, in the order in which they were added.
The optional scope argument restricts the result to the given name scope; without it, every variable in the collection is returned.
tf.Optimizer by default only optimizes the variables in tf.GraphKeys.TRAINABLE_VARIABLES.
"""
self.g_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
self.d_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
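# For example, a variable created below inside tf.variable_scope('generator') -> tf.variable_scope('conv1')
# is named 'generator/conv1/weight' and therefore lands in self.g_variables, while variables created under
# the 'discriminator' scope land in self.d_variables; each optimizer then only updates its own network.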
def conv_layer(self, x, filter_shape, stride):
filters = tf.get_variable(
name='weight',
shape=filter_shape,
dtype=tf.float32,
initializer=tf.contrib.layers.xavier_initializer(),
trainable=True)
# padding='SAME': output size = ceil(input size / stride), i.e. the input is padded where needed
# padding='VALID': output size = ceil((input size - filter size + 1) / stride), no padding
"""
tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)
Apart from name, which names the op, there are five relevant arguments:
input: the images to convolve, a 4-D Tensor with shape [batch, in_height, in_width, in_channels],
i.e. [number of images per batch, image height, image width, number of channels], of type float32 or float64
filter: the convolution kernel, a 4-D Tensor with shape [filter_height, filter_width, in_channels, out_channels],
i.e. [kernel height, kernel width, number of input channels, number of kernels], same type as input;
note that its third dimension, in_channels, must equal the fourth dimension of input
strides: the stride of the convolution along each dimension of the input, a 1-D vector of length 4
padding: a string, either "SAME" or "VALID", selecting the padding scheme
use_cudnn_on_gpu: bool, whether to use cuDNN acceleration, defaults to True
The result is a Tensor, the familiar feature map, again with shape [batch, height, width, channels].
"""
return tf.nn.conv2d(x, filters, [1, stride, stride, 1], padding='SAME')
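# Example (for illustration): with padding='SAME', a 128x128x3 input convolved with a [5, 5, 3, 64] kernel
# at stride 1 stays 128x128 (64 channels); at stride 2 it becomes ceil(128 / 2) = 64x64.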
# dilated (atrous) convolution
def dilated_conv_layer(self, x, filter_shape, dilation):
filters = tf.get_variable(
name='weight',
shape=filter_shape,
dtype=tf.float32,
initializer=tf.contrib.layers.xavier_initializer(),
trainable=True)
"""
tf.nn.atrous_conv2d(value, filters, rate, padding, name=None)
value: the input images, a 4-D Tensor with shape [batch, height, width, channels]
filters: the convolution kernel, a 4-D Tensor with shape [filter_height, filter_width, channels, out_channels],
i.e. [kernel height, kernel width, number of input channels (or kernels of the previous layer), number of kernels]
rate: the dilation rate; (rate - 1) zeros are inserted between the kernel elements, so with rate=1 this reduces to an ordinary convolution.
"""
return tf.nn.atrous_conv2d(x, filters, dilation, padding='SAME')
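# Example (for illustration): a 3x3 kernel with rate=4 covers a (3 - 1) * 4 + 1 = 9x9 region without adding
# parameters; atrous_conv2d uses stride 1, so with padding='SAME' the spatial size is unchanged.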
# transposed convolution (deconvolution)
def deconv_layer(self, x, filter_shape, output_shape, stride):
filters = tf.get_variable(
name='weight',
shape=filter_shape,
dtype=tf.float32,
initializer=tf.contrib.layers.xavier_initializer(),
trainable=True)
"""
tf.nn.conv2d_transpose(value, filter, output_shape, strides, padding="SAME", data_format="NHWC", name=None)
value: the input to the transposed convolution, a Tensor
filter: the kernel, a Tensor with shape [filter_height, filter_width, out_channels, in_channels],
i.e. [kernel height, kernel width, number of kernels, number of input channels (or kernels of the previous layer)]
output_shape: the shape of the output of the transposed convolution; ordinary convolutions have no such argument
strides: the stride along each dimension of the input, a 1-D vector of length 4
padding: a string, either "SAME" or "VALID", selecting the padding scheme
data_format: a string, 'NHWC' or 'NCHW', a newer TensorFlow argument describing the layout of value:
'NHWC' is TensorFlow's standard layout [batch, height, width, in_channels], 'NCHW' is Theano's layout
[batch, in_channels, height, width]; the default is 'NHWC'
"""
return tf.nn.conv2d_transpose(x, filters, output_shape, [1, stride, stride, 1])
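# Example (for illustration): in the generator below, deconv_layer(x, [4, 4, 128, 256],
# [batch_size, 64, 64, 128], 2) upsamples a [batch_size, 32, 32, 256] tensor to [batch_size, 64, 64, 128].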
def batch_normalize(self, x, is_training, decay=0.99, epsilon=0.001):
"""
tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None) is a low-level op;
the caller has to compute the mean and variance of the tensor itself.
mean: sample mean
variance: sample variance
offset: sample offset (beta), None or a vector, added after normalization
scale: scale factor (gamma, defaults to 1), None or a vector, multiplied after normalization
:param x:
:param is_training:
:param decay:
:param epsilon: a small constant added to avoid division by zero
:return:
"""
def bn_train():
# mean and variance of the current batch
batch_mean, batch_var = tf.nn.moments(x, axes=[0, 1, 2])
# running mean and variance that are updated during training and used at inference time
train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
with tf.control_dependencies([train_mean, train_var]):
# the op below only runs after [train_mean, train_var] have been executed
return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon)
def bn_inference():
return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon)
"""
tf.shape(x) returns a tensor; to get its value you have to sess.run() it.
x.get_shape() returns a tuple (convert it to a list with as_list()); x must be a tensor.
With x of shape [batch, height, width, channels or kernels], dim is the number of channels, i.e. the last dimension.
"""
dim = x.get_shape().as_list()[-1]
beta = tf.get_variable(
name='beta',
shape=[dim],
dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=0.0),
trainable=True)
scale = tf.get_variable(
name='scale',
shape=[dim],
dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=0.1),
trainable=True)
pop_mean = tf.get_variable(
name='pop_mean',
shape=[dim],
dtype=tf.float32,
initializer=tf.constant_initializer(0.0),
trainable=False)
pop_var = tf.get_variable(
name='pop_var',
shape=[dim],
dtype=tf.float32,
initializer=tf.constant_initializer(1.0),
trainable=False)
# tf.cond() works like a ternary (condition ? a : b) expression
return tf.cond(is_training, bn_train, bn_inference)
def flatten_layer(self, x):
"""
Flattens each image tensor into a single vector, giving batch_size such vectors.
:param x:
:return:
"""
input_shape = x.get_shape().as_list()
dim = input_shape[1] * input_shape[2] * input_shape[3]  # total number of values in one image over its three dimensions
# reorder the dimensions (NHWC -> NCHW) before flattening
transposed = tf.transpose(x, (0, 3, 1, 2))
return tf.reshape(transposed, [-1, dim])
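# Example (for illustration): in the local discriminator the 64x64x3 input becomes [batch, 4, 4, 512]
# after four stride-2 convolutions, so flatten_layer returns a [batch, 4 * 4 * 512] = [batch, 8192] tensor.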
def full_connection_layer(self, x, out_dim):
# in_dim is simply the output size of the previous layer
in_dim = x.get_shape().as_list()[-1]
W = tf.get_variable(
name='weight',
shape=[in_dim, out_dim],
dtype=tf.float32,
initializer=tf.truncated_normal_initializer(stddev=0.1),
trainable=True)
b = tf.get_variable(
name='bias',
shape=[out_dim],
dtype=tf.float32,
initializer=tf.constant_initializer(0.0),
trainable=True)
return tf.add(tf.matmul(x, W), b)
def generator(self, x, is_training):
with tf.variable_scope('generator'):
with tf.variable_scope('conv1'):
x = self.conv_layer(x, [5, 5, 3, 64], 1)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv2'):
x = self.conv_layer(x, [3, 3, 64, 128], 2)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv3'):
x = self.conv_layer(x, [3, 3, 128, 128], 1)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv4'):
x = self.conv_layer(x, [3, 3, 128, 256], 2)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv5'):
x = self.conv_layer(x, [3, 3, 256, 256], 1)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv6'):
x = self.conv_layer(x, [3, 3, 256, 256], 1)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('dilated1'):
x = self.dilated_conv_layer(x, [3, 3, 256, 256], 2)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('dilated2'):
x = self.dilated_conv_layer(x, [3, 3, 256, 256], 4)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('dilated3'):
x = self.dilated_conv_layer(x, [3, 3, 256, 256], 8)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('dilated4'):
x = self.dilated_conv_layer(x, [3, 3, 256, 256], 16)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv7'):
x = self.conv_layer(x, [3, 3, 256, 256], 1)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv8'):
x = self.conv_layer(x, [3, 3, 256, 256], 1)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('deconv1'):
x = self.deconv_layer(x, [4, 4, 128, 256], [self.batch_size, 64, 64, 128], 2)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv9'):
x = self.conv_layer(x, [3, 3, 128, 128], 1)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('deconv2'):
x = self.deconv_layer(x, [4, 4, 64, 128], [self.batch_size, 128, 128, 64], 2)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv10'):
x = self.conv_layer(x, [3, 3, 64, 32], 1)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv11'):
x = self.conv_layer(x, [3, 3, 32, 3], 1)
x = tf.nn.tanh(x)
# output image size: 128 * 128
return x
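# Note: the final tanh keeps the generator output in [-1, 1], matching the input normalization used in
# train.py (pixel / 127.5 - 1); completed images are mapped back with (value + 1) * 127.5 before saving.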
def discriminator(self, x, local_x, reuse):
def global_discriminator(x):
is_training = tf.constant(True)
with tf.variable_scope('global'):
# 因?yàn)槲覀兪褂胕mage_size = 128沥阳,原文是256,所以這里的卷積也少一層
with tf.variable_scope('conv1'):
x = self.conv_layer(x, [5, 5, 3, 64], 2)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv2'):
x = self.conv_layer(x, [5, 5, 64, 128], 2)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv3'):
x = self.conv_layer(x, [5, 5, 128, 256], 2)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv4'):
x = self.conv_layer(x, [5, 5, 256, 512], 2)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv5'):
x = self.conv_layer(x, [5, 5, 512, 512], 2)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('fc'):
x = self.flatten_layer(x)
x = self.full_connection_layer(x, 1024)
return x
def local_discriminator(x):
is_training = tf.constant(True)
with tf.variable_scope('local'):
# the paper uses LOCAL_SIZE = 128; we use 64, so this branch also has one convolution layer fewer
with tf.variable_scope('conv1'):
x = self.conv_layer(x, [5, 5, 3, 64], 2)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv2'):
x = self.conv_layer(x, [5, 5, 64, 128], 2)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv3'):
x = self.conv_layer(x, [5, 5, 128, 256], 2)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('conv4'):
x = self.conv_layer(x, [5, 5, 256, 512], 2)
x = self.batch_normalize(x, is_training)
x = tf.nn.relu(x)
with tf.variable_scope('fc'):
x = self.flatten_layer(x)
# the fully connected layer maps the flattened 4 * 4 * 512 = 8192 features to a 1024-dimensional vector, with no activation
x = self.full_connection_layer(x, 1024)
return x
with tf.variable_scope('discriminator', reuse=reuse):
"""
The reuse argument:
True: the scope is in reuse mode, i.e. every tf.get_variable() inside it returns an already created
variable and raises an error if the variable does not exist.
AUTO_REUSE: create variables that do not exist yet, otherwise return the existing ones.
None or False: tf.get_variable() only creates new variables and raises an error if a variable with the same name already exists.
* The reuse flag is inherited: if a scope is opened for reuse, all of its sub-scopes are reused as well.
"""
global_output = global_discriminator(x)
local_output = local_discriminator(local_x)
with tf.variable_scope('concatenation'):
output = tf.concat((global_output, local_output), 1)
output = self.full_connection_layer(output, 1)
return output
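# Note: the concatenated vector is 1024 (global) + 1024 (local) = 2048-dimensional and is mapped to a single
# logit; the sigmoid is applied inside tf.nn.sigmoid_cross_entropy_with_logits in calc_d_loss below.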
def calc_g_loss(self, x, completion):
# Loss of the completion network: measures how far the completed image is from the original
loss = tf.nn.l2_loss(x - completion)
return tf.reduce_mean(loss)
def calc_d_loss(self, real, fake):
# Loss of the discriminator network: a binary classification problem
alpha = 4e-4  # weighting factor from the paper that balances the discriminator loss against the completion loss
# tf.ones_like(real) creates a tensor of the same shape as real with every element set to 1
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real, labels=tf.ones_like(real)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake, labels=tf.zeros_like(fake)))
return tf.add(d_loss_real, d_loss_fake) * alpha
train.py
import numpy as np
import tensorflow as tf
import os
import cv2
import tqdm
from network_build import Network
"""
If the following appears while running:
An error ocurred while starting the kernel
2019???? 20:27:22.601831: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this
TensorFlow binary was not compiled to use: AVX2
it roughly means: your CPU supports the AVX extensions, but the installed TensorFlow build was not compiled to use them.
Why does this warning appear?
The default TensorFlow distribution (the one from pip install tensorflow) is built without CPU extensions such as
SSE4.1, SSE4.2, AVX, AVX2 and FMA so that it is compatible with as many CPUs as possible. Besides, even with these
extensions the CPU is much slower than a GPU, and medium or large training jobs are expected to run on a GPU anyway.
If you have a GPU you need not care about AVX support, because most of the expensive ops will be dispatched to the GPU
(unless placed elsewhere explicitly). In that case you can simply silence the warning with:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
"""
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
BATCH_SIZE = 10
IMAGE_SIZE = 128
LOCAL_SIZE = 64
HOLE_MIN = 24
HOLE_MAX = 48
LEARNING_RATE = 1e-3
PRETRAIN_EPOCH = 100
def load(dir_=r'D:\demos\image processing\demo\data_my'):
x_train = np.load(os.path.join(dir_, 'x_train.npy'))
x_test = np.load(os.path.join(dir_, 'x_test.npy'))
return x_train, x_test
def get_points():
points = []
mask = []
for i in range(BATCH_SIZE):
x1, y1 = np.random.randint(0, IMAGE_SIZE - LOCAL_SIZE + 1, 2)
x2, y2 = np.array([x1, y1]) + LOCAL_SIZE
points.append([x1, y1, x2, y2])
w, h = np.random.randint(HOLE_MIN, HOLE_MAX + 1, 2)
p1 = x1 + np.random.randint(0, LOCAL_SIZE - w)
q1 = y1 + np.random.randint(0, LOCAL_SIZE - h)
p2 = p1 + w
q2 = q1 + h
m = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.uint8)
m[q1:q2 + 1, p1:p2 + 1] = 1
mask.append(m)
# points describes a LOCAL_SIZE * LOCAL_SIZE crop; mask has size IMAGE_SIZE * IMAGE_SIZE but only a block of
# roughly (q2 - q1) * (p2 - p1) pixels is set to 1, the rest is 0, and that block lies inside the points crop
return np.array(points), np.array(mask)
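# Example (for illustration): with IMAGE_SIZE = 128 and LOCAL_SIZE = 64, points[i] might be [20, 35, 84, 99]
# (a 64x64 crop), and mask[i] is a 128x128x1 array whose only non-zero entries form a w x h block
# (24 <= w, h <= 48) lying inside that crop.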
def train():
"""
tf.reset_default_graph() clears the default graph stack and resets the global default graph.
Note: the default graph is a property of the current thread, so this function only applies to the current
thread. Calling it while a tf.Session or tf.InteractiveSession is active leads to undefined behaviour,
as does using any previously created tf.Operation or tf.Tensor object after the call.
Possible exception:
AssertionError: raised if the function is called from inside a nested graph.
"""
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3], name="x")
mask = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 1], name="mask")
local_x = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3], name="local_x")
global_completion = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3], name="global_completion")
local_completion = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3], name="local_completion")
is_training = tf.placeholder(tf.bool, [], name="is_training")
model = Network(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE)
# global_step shows up in moving averages, optimizers, exponential learning-rate decay and so on. Its meaning is
# simply the global step count (at which step to perform some action, how many batches have been trained, etc.),
# much like a clock. It starts at 0; passing global_step to the optimizer's minimize() increments it by 1 per batch.
global_step = tf.Variable(0, name='global_step', trainable=False)
epoch = tf.Variable(0, name='epoch', trainable=False)
opt = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
# tf.train.Optimizer.minimize: adds ops that minimize loss by updating var_list.
# It simply combines compute_gradients() and apply_gradients().
# It returns the training op; if global_step is not None, it is incremented as part of the update.
g_train_op = opt.minimize(model.g_loss, global_step=global_step, var_list=model.g_variables)
d_train_op = opt.minimize(model.d_loss, global_step=global_step, var_list=model.d_variables)
# load the data
x_train, x_test = load()
# normalize every pixel value to [-1, 1]
x_train = np.array([a / 127.5 - 1 for a in x_train])
x_test = np.array([a / 127.5 - 1 for a in x_test])
# number of steps per epoch
step_num = int(len(x_train) / BATCH_SIZE)
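# e.g. with the 1000-image subset prepared by make_data.py (800 training images) and BATCH_SIZE = 10,
# step_num is 80 steps per epoch.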
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
# if a previously trained model exists, load it to resume and speed up training
if tf.train.get_checkpoint_state('./backup'):
saver = tf.train.Saver()
saver.restore(sess, './backup/latest')
while True:
# increment epoch by 1 at the start of every pass over the data
sess.run(tf.assign(epoch, tf.add(epoch, 1)))
print('epoch: {}'.format(sess.run(epoch)))
# reshuffle the training data once per epoch
np.random.shuffle(x_train)
# Completion
# first train only the image completion network for PRETRAIN_EPOCH = 100 epochs
# note: reading the current value of a tensor/variable requires a sess.run()
if sess.run(epoch) <= PRETRAIN_EPOCH:
g_loss_value = 0
points_batch, mask_batch = get_points()
# tqdm is a Python progress-bar library; wrapping an iterator as tqdm(iterator) shows a progress bar for long loops
for i in tqdm.tqdm(range(step_num)):
# each epoch runs step_num steps; every step takes one batch of BATCH_SIZE images from the training set
x_batch = x_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
_, g_loss = sess.run([g_train_op, model.g_loss],
feed_dict={x: x_batch, mask: mask_batch, is_training: True})
g_loss_value += g_loss
print("epoch:{}".format(sess.run(epoch)))
print("Completion loss: {}".format(g_loss_value))
np.random.shuffle(x_test)
# 因?yàn)樵谥谱鱩ask的時(shí)候曹动,選擇一次制作的數(shù)量是BATCH_SIZE,所以從測(cè)試集中取出BATCH_SIZE個(gè)數(shù)據(jù)進(jìn)行測(cè)試
x_batch = x_test[:BATCH_SIZE]
completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})
print("completion[0].shape:", completion[0].shape)
# map the image back from [-1, 1] to [0, 255]
sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)
cv2.imwrite('./output1/{}.jpg'.format("{0:06d}".format(sess.run(epoch))),
cv2.cvtColor(sample, cv2.COLOR_RGB2BGR))
saver = tf.train.Saver()
saver.save(sess, './backup/latest')
if sess.run(epoch) == PRETRAIN_EPOCH:
saver.save(sess, './backup/pretrained')
# Discrimination
# after PRETRAIN_EPOCH (100) epochs, the completion network and the discriminator are trained together
else:
g_loss_value = 0
d_loss_value = 0
points_batch, mask_batch = get_points()
for i in tqdm.tqdm(range(step_num)):
x_batch = x_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
# train the completion (generator) network
_, g_loss, completion = sess.run([g_train_op, model.g_loss, model.completion],
feed_dict={x: x_batch, mask: mask_batch, is_training: True})
g_loss_value += g_loss
local_x_batch = []
local_completion_batch = []
# crop the local region from the original images and from the completed images of this batch
for i in range(BATCH_SIZE):
x1, y1, x2, y2 = points_batch[i]
local_x_batch.append(x_batch[i][y1:y2, x1:x2, :])
local_completion_batch.append(completion[i][y1:y2, x1:x2, :])
local_x_batch = np.array(local_x_batch)
local_completion_batch = np.array(local_completion_batch)
"""
d_train_op depends on d_loss, which comes from calc_d_loss. calc_d_loss takes real and fake:
real comes from discriminator(x, local_x, reuse=False),
fake comes from discriminator(global_completion, local_completion, reuse=True),
so feed_dict must provide x, local_x, global_completion, local_completion, and mask.
"""
_, d_loss = sess.run(
[d_train_op, model.d_loss],
feed_dict={x: x_batch, mask: mask_batch, local_x: local_x_batch, global_completion: completion,
local_completion: local_completion_batch, is_training: True})
d_loss_value += d_loss
print("epoch:{}".format(sess.run(epoch)))
print('Completion loss: {}'.format(g_loss_value))
print('Discriminator loss: {}'.format(d_loss_value))
np.random.shuffle(x_test)
x_batch = x_test[:BATCH_SIZE]
completion = sess.run(model.completion,
feed_dict={x: x_batch, mask: mask_batch, is_training: False})
sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)
cv2.imwrite('./output2/{}.jpg'.format("{0:06d}".format(sess.run(epoch))),
cv2.cvtColor(sample, cv2.COLOR_RGB2BGR))
saver = tf.train.Saver()
saver.save(sess, './backup/latest', write_meta_graph=False)
if __name__ == '__main__':
train()
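Once training has produced ./backup/latest, the completion network can be tried on a few test images. The snippet below is only a minimal sketch of such an inference script (it is not part of the original code; the file name inference.py and the output directory ./output_test are made up here), reusing the graph construction and pre/post-processing from train.py:
inference.py
import os
import numpy as np
import tensorflow as tf
import cv2
from network_build import Network
from train import BATCH_SIZE, IMAGE_SIZE, LOCAL_SIZE, load, get_points
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3], name="x")
mask = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 1], name="mask")
local_x = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3], name="local_x")
global_completion = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3], name="global_completion")
local_completion = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3], name="local_completion")
is_training = tf.placeholder(tf.bool, [], name="is_training")
model = Network(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE)
# prepare one batch of test images and random masks exactly as train.py does
_, x_test = load()
x_test = np.array([a / 127.5 - 1 for a in x_test])
_, mask_batch = get_points()
x_batch = x_test[:BATCH_SIZE]
os.makedirs('./output_test', exist_ok=True)  # hypothetical output directory
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, './backup/latest')  # checkpoint written by train.py
    completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})
    for k in range(BATCH_SIZE):
        img = np.array((completion[k] + 1) * 127.5, dtype=np.uint8)
        cv2.imwrite('./output_test/{:02d}.jpg'.format(k), cv2.cvtColor(img, cv2.COLOR_RGB2BGR))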