Reproducing the paper "Globally and Locally Consistent Image Completion"

The code below is an implementation of the paper "Globally and locally consistent image completion". Paper page: http://xueshu.baidu.com/usercenter/paper/show?paperid=ea74830570062151f14abfb1fe89bb33&site=xueshu_se&hitarticle=1
For a quick walkthrough of the paper, see my other article: http://www.reibang.com/p/12da271c8bf8
Dataset: the CelebA face dataset, download: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
Framework: TensorFlow 1.11.0
Image size: 128 x 128 (the original paper uses 256 x 256)
The original authors trained on four K80 GPUs for roughly two months. I currently have no comparable hardware and cannot train at that scale, so although the code below runs, some hyperparameters are set from experience and their effect has not been verified; use with caution.

Since I have already written up the paper's approach elsewhere, I will not explain the implementation ideas again here; if anything is unclear, please refer to the paper and the walkthrough article linked above. The code is commented in as much detail as possible, so it should not be hard to follow.

make_data.py

"""
說(shuō)明:數(shù)據(jù)預(yù)處理,將圖片讀取到npy文件中掺炭,這樣就可避免每次都去讀一個(gè)一個(gè)的圖片數(shù)據(jù)居兆,可以加快讀取數(shù)據(jù)的速度
npy文件——Numpy專用的二進(jìn)制格式
數(shù)據(jù)集地址:http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
"""
import glob
import cv2
import numpy as np


image_size = 128
train_ratio = 0.8

# Collect the paths of all images. Keep the data files outside the project directory, otherwise the IDE may get sluggish while indexing the project.
paths = glob.glob(r'D:\bigdata\img_align_celeba/*.jpg')
x = []  # list of image arrays
# Read the images. To keep training fast, only the first 1000 images are processed.
for img_path in paths[:1000]:
    '''
    If cv2 gives no code-completion hints, uninstall and reinstall it:
    uninstall: pip uninstall opencv-python
    reinstall without the cache: pip --no-cache-dir install opencv-python -i http://pypi.douban.com/simple --trusted-host pypi.douban.com
    install the contrib package: pip --no-cache-dir install opencv-contrib-python -i http://pypi.douban.com/simple --trusted-host pypi.douban.com
    '''
    img = cv2.imread(img_path)  # matrix representation of the image, shape: (218, 178, 3)
    # Resize the image (interpolation)
    img = cv2.resize(img, (image_size, image_size))
    # Convert the color space from BGR to RGB for later operations such as mask generation
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    x.append(img)

x = np.array(x, dtype=np.uint8)  # np.uint8 is used to save storage space
# Shuffle the images
np.random.shuffle(x)

p = int(train_ratio * len(x))
x_train = x[:p]
x_test = x[p:]

np.save(r'D:\demos\image processing\demo\data_my\x_train.npy', x_train)
np.save(r'D:\demos\image processing\demo\data_my\x_test.npy', x_test)
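
As a quick sanity check (not part of the original script), the saved .npy files can be loaded back and their shapes inspected; a minimal sketch assuming the same paths as above:

import numpy as np

x_train = np.load(r'D:\demos\image processing\demo\data_my\x_train.npy')
x_test = np.load(r'D:\demos\image processing\demo\data_my\x_test.npy')

# With 1000 source images and train_ratio = 0.8 we expect roughly
# x_train.shape == (800, 128, 128, 3) and x_test.shape == (200, 128, 128, 3), both uint8
print(x_train.shape, x_train.dtype)
print(x_test.shape, x_test.dtype)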

network_build.py

"""
定義訓(xùn)練網(wǎng)絡(luò)所需的各部分
"""
import tensorflow as tf


class Network:
    def __init__(self, x, mask, local_x, global_completion, local_completion, is_training, batch_size):
        """
        :param x: 輸入
        :param mask: 需要填補(bǔ)修復(fù)的圖像表示譬嚣,缺失區(qū)域的值為1钢颂,其他部分為0,整體大小和完整圖像相同
        :param local_x: 從原圖中摳出來(lái)的那部分
        :param global_completion: 經(jīng)過(guò)補(bǔ)全網(wǎng)絡(luò)后的缺失部分之外的部分圖像
        :param local_completion: 經(jīng)過(guò)補(bǔ)全網(wǎng)絡(luò)之后的補(bǔ)全的部分
        :param is_training:
        :param batch_size:
        """
        self.batch_size = batch_size
        # x * (1 - mask) "digs the hole": values inside the hole become 0, so the generator receives an image with a hole
        self.imitation = self.generator(x * (1 - mask), is_training)
        # In the completed image, the hole is filled with generated content while everything else keeps the original pixel values
        self.completion = self.imitation * mask + x * (1 - mask)
        # Feed the real image data; the variables needed here are created and updated by the discriminator itself
        self.real = self.discriminator(x, local_x, reuse=False)
        # Feed the images produced by the completion network; the discriminator must reuse the variables created for the real images, hence reuse=True
        self.fake = self.discriminator(global_completion, local_completion, reuse=True)
        self.g_loss = self.calc_g_loss(x, self.completion)
        self.d_loss = self.calc_d_loss(self.real, self.fake)
        """
        tf.get_collection(key,scope=None)
        用來(lái)獲取一個(gè)名稱是‘key’的集合中的所有元素趾盐,返回的是一個(gè)列表,列表的順序是按照變量放入集合中的先后;   scope參數(shù)可選,
        表示的是名稱空間(名稱域)救鲤,如果指定久窟,就返回名稱域中所有放入‘key’的變量的列表,不指定則返回所有變量本缠。
        tf.Optimizer默認(rèn)只優(yōu)化tf.GraphKeys.TRAINABLE_VARIABLES中的變量斥扛。
        """
        self.g_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
        self.d_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')

    def conv_layer(self, x, filter_shape, stride):
        filters = tf.get_variable(
            name='weight',
            shape=filter_shape,
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(),
            trainable=True)
        # padding='SAME': output side = input side / stride, rounded up (the input is padded as needed)
        # padding='VALID': output side = (input side - kernel size + 1) / stride, rounded up, with no padding
        """
        tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)
        Besides `name`, which names the op, the five relevant parameters are:

        input: the image batch to convolve, a 4-D Tensor of shape [batch, in_height, in_width, in_channels],
        i.e. [number of images in a batch, image height, image width, number of channels]; float32 or float64.

        filter: the convolution kernels, a 4-D Tensor of shape [filter_height, filter_width, in_channels, out_channels],
        i.e. [kernel height, kernel width, number of input channels, number of kernels]; same dtype as input.
        Note that its third dimension, in_channels, must equal the fourth dimension of `input`.

        strides: the stride of the convolution along each dimension of the input, a 1-D vector of length 4.

        padding: a string, either "SAME" or "VALID", selecting one of the padding schemes described above.

        use_cudnn_on_gpu: bool, whether to use cuDNN acceleration, defaults to True.

        The result is a Tensor, the familiar feature map, again of shape [batch, height, width, channels].
        """
        return tf.nn.conv2d(x, filters, [1, stride, stride, 1], padding='SAME')
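        # Worked size example for the padding rules above (illustrative, with IMAGE_SIZE = 128):
        #   'SAME',  stride 2:             output side = ceil(128 / 2) = 64
        #   'VALID', stride 2, 5x5 kernel: output side = ceil((128 - 5 + 1) / 2) = 62
        # Every convolution in this network uses 'SAME', so the spatial size only changes with the stride.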

    # Dilated (atrous) convolution
    def dilated_conv_layer(self, x, filter_shape, dilation):
        filters = tf.get_variable(
            name='weight',
            shape=filter_shape,
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(),
            trainable=True)
        """
        tf.nn.atrous_conv2d(value,filters,rate,padding,name=None)
        value: 指需要做卷積的輸入圖像,要求是一個(gè)4維Tensor惜姐,具有[batch, height, width, channels]
        filters: 相當(dāng)于CNN中的卷積核犁跪,要求是一個(gè)4維Tensor椿息,具有[filter_height, filter_width, channels, out_channels]
                  這樣的shape歹袁,具體含義是[卷積核的高度坷衍,卷積核的寬度,圖像通道數(shù)或前一次卷積核個(gè)數(shù)条舔,本次卷積核個(gè)數(shù)]
        rate: 即空洞率dilation枫耳,在卷積核中穿插補(bǔ)(rate-1)個(gè)0,rate=1時(shí)孟抗,就沒(méi)有0插入迁杨,此時(shí)這個(gè)函數(shù)就變成了普通卷積。
        """
        return tf.nn.atrous_conv2d(x, filters, dilation, padding='SAME')
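        # Illustrative note: a 3x3 kernel with dilation rate r covers an effective area of
        # (3 + 2*(r - 1)) x (3 + 2*(r - 1)) pixels, e.g. 5x5 for r=2 and 33x33 for r=16, which is how the
        # completion network enlarges its receptive field without further downsampling.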

    # Transposed convolution (deconvolution)
    def deconv_layer(self, x, filter_shape, output_shape, stride):
        filters = tf.get_variable(
            name='weight',
            shape=filter_shape,
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(),
            trainable=True)
        """
        tf.conv2d_transpose(value, filter, output_shape, strides, padding="SAME", data_format="NHWC", name=None)
        第一個(gè)參數(shù)value:指需要做反卷積的輸入圖像凄硼,它要求是一個(gè)Tensor
        第二個(gè)參數(shù)filter:卷積核铅协,它要求是一個(gè)Tensor,具有[filter_height, filter_width, out_channels, in_channels]這樣的shape摊沉,
        具體含義是[卷積核的高度狐史,卷積核的寬度,卷積核個(gè)數(shù)说墨,圖像通道數(shù)或上次卷積核個(gè)數(shù)]
        第三個(gè)參數(shù)output_shape:反卷積操作輸出的shape骏全,普通卷積操作是沒(méi)有這個(gè)參數(shù)的.
        第四個(gè)參數(shù)strides:反卷積時(shí)在圖像每一維的步長(zhǎng),這是一個(gè)一維的向量尼斧,長(zhǎng)度4
        第五個(gè)參數(shù)padding:string類型的量姜贡,只能是"SAME","VALID"其中之一,這個(gè)值決定了不同的卷積方式
        第六個(gè)參數(shù)data_format:string類型的量棺棵,'NHWC'和'NCHW'其中之一楼咳,這是tensorflow新版本中新加的參數(shù),它說(shuō)明了value參數(shù)的
        數(shù)據(jù)格式烛恤。'NHWC'指tensorflow標(biāo)準(zhǔn)的數(shù)據(jù)格式[batch, height, width, in_channels]爬橡,'NCHW'指Theano的數(shù)據(jù)格式,
        [batch, in_channels,height, width]棒动,當(dāng)然默認(rèn)值是'NHWC'
        """
        return tf.nn.conv2d_transpose(x, filters, output_shape, [1, stride, stride, 1])
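        # Illustrative note: with 'SAME' padding (the default used here) and stride 2, conv2d_transpose doubles the
        # spatial size, e.g. a 32x32 feature map becomes the requested 64x64 output_shape in the generator's deconv1.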

    def batch_normalize(self, x, is_training, decay=0.99, epsilon=0.001):
        """
        tf.nn.batch_normalization(x,mean,variance,offset,scale,variance_epsilon,name=None)是一個(gè)低級(jí)的操作函數(shù)糙申,
        調(diào)用者需要自己處理張量的平均值和方差。
        mean:樣本均值
        variance:樣本方差
        offset:樣本偏移船惨,None或一個(gè)向量柜裸,添加到歸一化中
        scale:縮放(默認(rèn)為1),None或一個(gè)向量粱锐,添加到歸一化中
        :param x:
        :param is_training:
        :param decay:
        :param epsilon:為了避免分母為0疙挺,添加的一個(gè)極小值
        :return:
        """

        def bn_train():
            # Compute the mean and variance of the current batch
            batch_mean, batch_var = tf.nn.moments(x, axes=[0, 1, 2])
            # Update the running mean and variance used at inference time
            train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
            train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
            with tf.control_dependencies([train_mean, train_var]):
                # The line below only runs after [train_mean, train_var] have been executed
                return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon)

        def bn_inference():
            return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon)
        """
        tf.shape(x)返回的是一個(gè)tensor怜浅。要想知道是多少铐然,必須通過(guò)sess.run()
        x.get_shape()返回的是元組,需要通過(guò)as_list()的操作轉(zhuǎn)換成list,x必須是tensor
        x:[batch, height, width, channels or kernels],則dim就是channels的值蔬崩,圖像數(shù)據(jù)的第三維
        """
        dim = x.get_shape().as_list()[-1]
        beta = tf.get_variable(
            name='beta',
            shape=[dim],
            dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=0.0),
            trainable=True)
        scale = tf.get_variable(
            name='scale',
            shape=[dim],
            dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=0.1),
            trainable=True)
        pop_mean = tf.get_variable(
            name='pop_mean',
            shape=[dim],
            dtype=tf.float32,
            initializer=tf.constant_initializer(0.0),
            trainable=False)
        pop_var = tf.get_variable(
            name='pop_var',
            shape=[dim],
            dtype=tf.float32,
            initializer=tf.constant_initializer(1.0),
            trainable=False)
        # tf.cond() works like a conditional (ternary) expression
        return tf.cond(is_training, bn_train, bn_inference)
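        # Illustrative note: with decay = 0.99 each training batch moves the stored statistics by 1%,
        # e.g. pop_mean_new = 0.99 * pop_mean + 0.01 * batch_mean, so pop_mean/pop_var converge towards
        # dataset-wide estimates, which bn_inference() then uses when is_training is False.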
    
    def flatten_layer(self, x):
        """
        圖像矩陣轉(zhuǎn)換為一個(gè)向量,有batch_size個(gè)這種向量
        :param x:
        :return:
        """
        input_shape = x.get_shape().as_list()
        dim = input_shape[1] * input_shape[2] * input_shape[3]  # total number of values in one image across its three dimensions
        # Swap the dimensions (NHWC -> NCHW) before flattening
        transposed = tf.transpose(x, (0, 3, 1, 2))
        return tf.reshape(transposed, [-1, dim])

    def full_connection_layer(self, x, out_dim):
        # in_dim is simply the output size of the previous layer
        in_dim = x.get_shape().as_list()[-1]
        W = tf.get_variable(
            name='weight',
            shape=[in_dim, out_dim],
            dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=0.1),
            trainable=True)
        b = tf.get_variable(
            name='bias',
            shape=[out_dim],
            dtype=tf.float32,
            initializer=tf.constant_initializer(0.0),
            trainable=True)
        return tf.add(tf.matmul(x, W), b)

    def generator(self, x, is_training):
        with tf.variable_scope('generator'):
            with tf.variable_scope('conv1'):
                x = self.conv_layer(x, [5, 5, 3, 64], 1)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv2'):
                x = self.conv_layer(x, [3, 3, 64, 128], 2)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv3'):
                x = self.conv_layer(x, [3, 3, 128, 128], 1)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv4'):
                x = self.conv_layer(x, [3, 3, 128, 256], 2)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv5'):
                x = self.conv_layer(x, [3, 3, 256, 256], 1)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv6'):
                x = self.conv_layer(x, [3, 3, 256, 256], 1)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('dilated1'):
                x = self.dilated_conv_layer(x, [3, 3, 256, 256], 2)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('dilated2'):
                x = self.dilated_conv_layer(x, [3, 3, 256, 256], 4)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('dilated3'):
                x = self.dilated_conv_layer(x, [3, 3, 256, 256], 8)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('dilated4'):
                x = self.dilated_conv_layer(x, [3, 3, 256, 256], 16)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv7'):
                x = self.conv_layer(x, [3, 3, 256, 256], 1)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv8'):
                x = self.conv_layer(x, [3, 3, 256, 256], 1)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('deconv1'):
                x = self.deconv_layer(x, [4, 4, 128, 256], [self.batch_size, 64, 64, 128], 2)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv9'):
                x = self.conv_layer(x, [3, 3, 128, 128], 1)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('deconv2'):
                x = self.deconv_layer(x, [4, 4, 64, 128], [self.batch_size, 128, 128, 64], 2)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv10'):
                x = self.conv_layer(x, [3, 3, 64, 32], 1)
                x = self.batch_normalize(x, is_training)
                x = tf.nn.relu(x)
            with tf.variable_scope('conv11'):
                x = self.conv_layer(x, [3, 3, 32, 3], 1)
                x = tf.nn.tanh(x)
        # output image size: 128 x 128
        return x

    def discriminator(self, x, local_x, reuse):
        def global_discriminator(x):
            is_training = tf.constant(True)
            with tf.variable_scope('global'):
                # 因?yàn)槲覀兪褂胕mage_size = 128沥阳,原文是256,所以這里的卷積也少一層
                with tf.variable_scope('conv1'):
                    x = self.conv_layer(x, [5, 5, 3, 64], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('conv2'):
                    x = self.conv_layer(x, [5, 5, 64, 128], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('conv3'):
                    x = self.conv_layer(x, [5, 5, 128, 256], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('conv4'):
                    x = self.conv_layer(x, [5, 5, 256, 512], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('conv5'):
                    x = self.conv_layer(x, [5, 5, 512, 512], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('fc'):
                    x = self.flatten_layer(x)
                    x = self.full_connection_layer(x, 1024)
            return x

        def local_discriminator(x):
            is_training = tf.constant(True)
            with tf.variable_scope('local'):
                # The paper uses LOCAL_SIZE = 128; we use 64, so this branch also has one convolution layer fewer
                with tf.variable_scope('conv1'):
                    x = self.conv_layer(x, [5, 5, 3, 64], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('conv2'):
                    x = self.conv_layer(x, [5, 5, 64, 128], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('conv3'):
                    x = self.conv_layer(x, [5, 5, 128, 256], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('conv4'):
                    x = self.conv_layer(x, [5, 5, 256, 512], 2)
                    x = self.batch_normalize(x, is_training)
                    x = tf.nn.relu(x)
                with tf.variable_scope('fc'):
                    x = self.flatten_layer(x)
                    # The fully connected layer maps the flattened features to a 1024-dimensional vector, with no activation
                    x = self.full_connection_layer(x, 1024)
            return x

        with tf.variable_scope('discriminator', reuse=reuse):
            """
            reuse參數(shù):
            True: 參數(shù)空間使用reuse 模式功炮,即該空間下的所有tf.get_variable()函數(shù)將直接獲取已經(jīng)創(chuàng)建的變量,
            如果參數(shù)不存在tf.get_variable()函數(shù)將會(huì)報(bào)錯(cuò)术唬。
            AUTO_REUSE:若參數(shù)空間的參數(shù)不存在就創(chuàng)建他們薪伏,如果已經(jīng)存在就直接獲取它們。
            None 或者False 這里創(chuàng)建函數(shù)tf.get_variable()函數(shù)只能創(chuàng)建新的變量粗仓,當(dāng)同名變量已經(jīng)存在時(shí)嫁怀,函數(shù)就報(bào)錯(cuò)
            * reuse(重用)標(biāo)志是有繼承性的:如果我們打開(kāi)一個(gè)重用范圍,那么它的所有子范圍也會(huì)重用潦牛。
            """
            global_output = global_discriminator(x)
            local_output = local_discriminator(local_x)
            with tf.variable_scope('concatenation'):
                output = tf.concat((global_output, local_output), 1)
                output = self.full_connection_layer(output, 1)

        return output

    def calc_g_loss(self, x, completion):
        # Loss of the completion network: measures how far the generated image is from the original
        loss = tf.nn.l2_loss(x - completion)
        return tf.reduce_mean(loss)

    def calc_d_loss(self, real, fake):
        # Loss of the discriminator: a binary classification problem
        alpha = 4e-4  # weighting factor balancing the discriminator loss against the completion loss
        # tf.ones_like(real) creates a tensor of the same shape as real, filled with 1.
        d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real, labels=tf.ones_like(real)))
        d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake, labels=tf.zeros_like(fake)))
        return tf.add(d_loss_real, d_loss_fake) * alpha
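
Before moving on to training, it may help to see the hole-digging and compositing arithmetic from __init__ in isolation. The following is a minimal NumPy sketch of the same idea; the array names are illustrative only and are not part of the class:

import numpy as np

image_size = 8                                   # tiny toy size for readability
x = np.ones((image_size, image_size, 3))         # stands in for a real image
mask = np.zeros((image_size, image_size, 1))
mask[2:6, 2:6] = 1                               # 1 inside the hole, 0 elsewhere

holed_input = x * (1 - mask)                     # what the generator sees: the hole is zeroed out
imitation = np.full_like(x, 0.5)                 # stands in for the generator output
completion = imitation * mask + x * (1 - mask)   # generated pixels inside the hole, original pixels outside

print(holed_input[3, 3], completion[3, 3], completion[0, 0])  # -> [0. 0. 0.] [0.5 0.5 0.5] [1. 1. 1.]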

train.py

import numpy as np
import tensorflow as tf
import os
import cv2
import tqdm
from network_build import Network

"""
如果運(yùn)行過(guò)程中出現(xiàn):
An error ocurred while starting the kernel
2019???? 20:27:22.601831: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this 
TensorFlow binary was not compiled to use: AVX2
大概意思是:你的CPU支持AVX擴(kuò)展巴碗,但是你安裝的TensorFlow版本無(wú)法編譯使用朴爬。

那為什么會(huì)出現(xiàn)這種警告呢?
由于tensorflow默認(rèn)分布是在沒(méi)有CPU擴(kuò)展的情況下構(gòu)建的橡淆,例如SSE4.1召噩,SSE4.2,AVX逸爵,AVX2具滴,F(xiàn)MA等。默認(rèn)版本(來(lái)自pip install 
tensorflow的版本)旨在與盡可能多的CPU兼容师倔。另一個(gè)觀點(diǎn)是构韵,即使使用這些擴(kuò)展名,CPU的速度也要比GPU慢很多趋艘,并且期望在GPU上執(zhí)行中型和大型機(jī)器學(xué)習(xí)培訓(xùn)疲恢。

如果你有一個(gè)GPU,你不應(yīng)該關(guān)心AVX的支持瓷胧,因?yàn)榇蠖鄶?shù)昂貴的操作將被分派到一個(gè)GPU設(shè)備上(除非明確地設(shè)置)显拳。在這種情況下,您可以簡(jiǎn)單地忽略此警告:
import os 
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

"""

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

BATCH_SIZE = 10
IMAGE_SIZE = 128
LOCAL_SIZE = 64
HOLE_MIN = 24
HOLE_MAX = 48
LEARNING_RATE = 1e-3
PRETRAIN_EPOCH = 100


def load(dir_=r'D:\demos\image processing\demo\data_my'):
    x_train = np.load(os.path.join(dir_, 'x_train.npy'))
    x_test = np.load(os.path.join(dir_, 'x_test.npy'))
    return x_train, x_test


def get_points():
    points = []
    mask = []
    for i in range(BATCH_SIZE):
        x1, y1 = np.random.randint(0, IMAGE_SIZE - LOCAL_SIZE + 1, 2)
        x2, y2 = np.array([x1, y1]) + LOCAL_SIZE
        points.append([x1, y1, x2, y2])

        w, h = np.random.randint(HOLE_MIN, HOLE_MAX + 1, 2)
        p1 = x1 + np.random.randint(0, LOCAL_SIZE - w)
        q1 = y1 + np.random.randint(0, LOCAL_SIZE - h)
        p2 = p1 + w
        q2 = q1 + h

        m = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.uint8)
        m[q1:q2 + 1, p1:p2 + 1] = 1
        mask.append(m)
    # points describes a LOCAL_SIZE x LOCAL_SIZE region; mask is IMAGE_SIZE x IMAGE_SIZE but only a
    # (q2 - q1) * (p2 - p1) region inside it is 1, everything else is 0, and that region always lies
    # inside the region described by points (see the illustrative example after this function)
    return np.array(points), np.array(mask)
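# Illustrative example of one draw: if x1 = 10, y1 = 20 then (x2, y2) = (74, 84), a 64x64 local window;
# a hole of, say, w = 30, h = 40 is then placed at a random offset inside that window, so the 1-region of
# the mask always stays within the crop described by points.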


def train():
    """
    tf.reset_default_graph函數(shù)用于清除默認(rèn)圖形堆棧并重置全局默認(rèn)圖形揍移。

    注意:默認(rèn)圖形是當(dāng)前線程的一個(gè)屬性次和。該tf.reset_default_graph函數(shù)只適用于當(dāng)前線程。當(dāng)一個(gè)tf.Session或者tf.InteractiveSession
    激活時(shí)調(diào)用這個(gè)函數(shù)會(huì)導(dǎo)致未定義的行為羊精。調(diào)用此函數(shù)后使用任何以前創(chuàng)建的tf.Operation或tf.Tensor對(duì)象將導(dǎo)致未定義的行為斯够。
    可能引發(fā)的異常:
    AssertionError:如果在嵌套圖中調(diào)用此函數(shù)則會(huì)引發(fā)此異常囚玫。
    Clears the default graph stack and resets the global default graph.
    """
    tf.reset_default_graph()
    x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3], name="x")
    mask = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 1], name="mask")
    local_x = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3], name="local_x")
    global_completion = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3], name="global_completion")
    local_completion = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3], name="local_completion")
    is_training = tf.placeholder(tf.bool, [], name="is_training")

    model = Network(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE)
    # global_step is used by moving averages, optimizers, exponential learning-rate decay and so on. Its meaning is
    # simple: it is the global step counter ("at which step should what happen", "how many batches has the network
    # been trained on"), a bit like a clock. It starts at 0, and passing global_step to the optimizer's minimize()
    # makes it increase by 1 for every training batch.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    epoch = tf.Variable(0, name='epoch', trainable=False)

    opt = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
    # tf.train.Optimizer.minimize: adds ops that minimize the loss and update var_list.
    # It simply combines compute_gradients() and apply_gradients().
    # It returns the update op; if global_step is not None, it is incremented as well.
    g_train_op = opt.minimize(model.g_loss, global_step=global_step, var_list=model.g_variables)
    d_train_op = opt.minimize(model.d_loss, global_step=global_step, var_list=model.d_variables)

    # Load the data
    x_train, x_test = load()
    # Normalize every pixel value to the range [-1, 1]
    x_train = np.array([a / 127.5 - 1 for a in x_train])
    x_test = np.array([a / 127.5 - 1 for a in x_test])
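    # e.g. pixel value 0 maps to -1, 127.5 maps to 0 and 255 maps to 1, matching the tanh output range of the
    # generator's last layer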

    # Number of iterations per epoch
    step_num = int(len(x_train) / BATCH_SIZE)

    
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)

        # If a previously trained model exists, load it to speed up training
        if tf.train.get_checkpoint_state('./backup'):
            saver = tf.train.Saver()
            saver.restore(sess, './backup/latest')

        while True:
            # Increment epoch once per pass over the data
            sess.run(tf.assign(epoch, tf.add(epoch, 1)))
            print('epoch: {}'.format(sess.run(epoch)))

            # Reshuffle the training data at the start of every epoch
            np.random.shuffle(x_train)

            # Completion
            # First train the completion network alone for PRETRAIN_EPOCH = 100 epochs
            # Note: to read the value of a tensor variable you have to run it
            if sess.run(epoch) <= PRETRAIN_EPOCH:
                g_loss_value = 0
                points_batch, mask_batch = get_points()
                # tqdm is a Python progress-bar library; it adds a progress display to long loops, usage: tqdm(iterator)
                for i in tqdm.tqdm(range(step_num)):
                    # Each epoch runs step_num iterations, each taking a batch of BATCH_SIZE images from the training set
                    x_batch = x_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]

                    _, g_loss = sess.run([g_train_op, model.g_loss],
                                         feed_dict={x: x_batch, mask: mask_batch, is_training: True})
                    g_loss_value += g_loss
                print("epoch:{}".format(sess.run(epoch)))
                print("Completion loss: {}".format(g_loss_value))

                np.random.shuffle(x_test)
                # The masks were generated BATCH_SIZE at a time, so take BATCH_SIZE test images as well
                x_batch = x_test[:BATCH_SIZE]
                completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})
                print("completion[0].shape:", completion[0].shape)
                # Map the completed image back to [0, 255]
                sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)
                cv2.imwrite('./output1/{}.jpg'.format("{0:06d}".format(sess.run(epoch))),
                            cv2.cvtColor(sample, cv2.COLOR_RGB2BGR))

                saver = tf.train.Saver()
                saver.save(sess, './backup/latest')
                if sess.run(epoch) == PRETRAIN_EPOCH:
                    saver.save(sess, './backup/pretrained')

            # Discrimination
            # Once epoch > PRETRAIN_EPOCH, train the completion network and the discriminator together
            else:
                g_loss_value = 0
                d_loss_value = 0
                points_batch, mask_batch = get_points()
                for i in tqdm.tqdm(range(step_num)):
                    x_batch = x_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]

                    # Train the completion (generator) network
                    _, g_loss, completion = sess.run([g_train_op, model.g_loss, model.completion],
                                                     feed_dict={x: x_batch, mask: mask_batch, is_training: True})
                    g_loss_value += g_loss

                    local_x_batch = []
                    local_completion_batch = []
                    # Crop the local region from the original images and from the completed images in this batch
                    # (use a separate loop variable so the outer batch index i is not clobbered)
                    for j in range(BATCH_SIZE):
                        x1, y1, x2, y2 = points_batch[j]
                        local_x_batch.append(x_batch[j][y1:y2, x1:x2, :])
                        local_completion_batch.append(completion[j][y1:y2, x1:x2, :])
                    local_x_batch = np.array(local_x_batch)
                    local_completion_batch = np.array(local_completion_batch)

                    """
                    d_train_op用到了d_loss牲览,d_loss來(lái)自于calc_d_loss墓陈,calc_d_loss有real和fake兩個(gè)參數(shù),
                    real來(lái)自于discriminator(x, local_x, reuse=False)
                    fake來(lái)自于discriminator(global_completion, local_completion, reuse=True)
                    所以feed_dict的參數(shù)包括x第献、local_x贡必、global_completion、local_completion痊硕,以及mask
                    
                    """
                    _, d_loss = sess.run(
                        [d_train_op, model.d_loss],
                        feed_dict={x: x_batch, mask: mask_batch, local_x: local_x_batch, global_completion: completion,
                                   local_completion: local_completion_batch, is_training: True})
                    d_loss_value += d_loss

                print("epoch:{}".format(sess.run(epoch)))
                print('Completion loss: {}'.format(g_loss_value))
                print('Discriminator loss: {}'.format(d_loss_value))

                np.random.shuffle(x_test)
                x_batch = x_test[:BATCH_SIZE]
                completion = sess.run(model.completion,
                                      feed_dict={x: x_batch, mask: mask_batch, is_training: False})
                sample = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)
                cv2.imwrite('./output2/{}.jpg'.format("{0:06d}".format(sess.run(epoch))),
                            cv2.cvtColor(sample, cv2.COLOR_RGB2BGR))

                saver = tf.train.Saver()
                saver.save(sess, './backup/latest', write_meta_graph=False)


if __name__ == '__main__':
    train()
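
The article only covers training. For completeness, here is a rough sketch of how the saved checkpoint might be used to complete test images; the file name test.py, the fixed central hole, and the output path are my own illustrative choices and do not come from the original code:

# test.py (illustrative sketch, not part of the original article)
import numpy as np
import tensorflow as tf
import cv2
from network_build import Network

BATCH_SIZE = 10
IMAGE_SIZE = 128
LOCAL_SIZE = 64

x = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3])
mask = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 1])
local_x = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3])
global_completion = tf.placeholder(tf.float32, [BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3])
local_completion = tf.placeholder(tf.float32, [BATCH_SIZE, LOCAL_SIZE, LOCAL_SIZE, 3])
is_training = tf.placeholder(tf.bool, [])

model = Network(x, mask, local_x, global_completion, local_completion, is_training, batch_size=BATCH_SIZE)

# Reuse the preprocessed test split and the same [-1, 1] normalization as train.py
x_test = np.load(r'D:\demos\image processing\demo\data_my\x_test.npy')
x_batch = np.array([a / 127.5 - 1 for a in x_test[:BATCH_SIZE]])

# A fixed 32x32 central hole instead of the random hole used during training
mask_batch = np.zeros((BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 1), dtype=np.float32)
mask_batch[:, 48:80, 48:80, :] = 1

with tf.Session() as sess:
    tf.train.Saver().restore(sess, './backup/latest')
    completion = sess.run(model.completion, feed_dict={x: x_batch, mask: mask_batch, is_training: False})
    out = np.array((completion[0] + 1) * 127.5, dtype=np.uint8)
    cv2.imwrite('./completed.jpg', cv2.cvtColor(out, cv2.COLOR_RGB2BGR))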