pix2pix圖像超分辨率

目標(biāo)

訓(xùn)練一個(gè)pix2pix模型谜慌,把模糊的圖片換成清晰的圖片

準(zhǔn)備

租了一個(gè)極客云上面的GPU服務(wù)器,GTX1070莺奔,速度還行

數(shù)據(jù)集準(zhǔn)備

數(shù)據(jù)集是“圖片對(duì)”的形式欣范,一個(gè)圖片包含兩張圖片,一張是清晰的圖片,一張是模糊的圖片熙卡。

我的原數(shù)據(jù)集

我用的kaggle上面的花花數(shù)據(jù):https://www.kaggle.com/alxmamaev/flowers-recognition

去除錯(cuò)誤數(shù)據(jù)

import tensorflow as tf
from glob import glob
import os
import argparse
import logging
from PIL import Image
import traceback

def glob_all(dir_path):
    pic_list = glob(os.path.join(dir_path, '*.jpg'))
    inside = os.listdir(dir_path)
    for dir_name in inside:
        if os.path.isdir(os.path.join(dir_path, dir_name)):
            pic_list.extend(glob_all(os.path.join(dir_path, dir_name)))
    return pic_list

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--dir-path', default='data/')
    return parser.parse_args()

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    all_pic_list = glob_all(args.dir_path)
    for i, img_path in enumerate(all_pic_list):
        try:
            sess = tf.Session()
            with open(img_path, 'rb') as f:
                img_byte = f.read()
                img = tf.image.decode_jpeg(img_byte)
                data = sess.run(img)
                if data.shape[2] != 3:
                    print(data.shape)
                    raise Exception
            tf.reset_default_graph()
            img = Image.open(img_path)
        except Exception:
            logging.warning('%s has broken. Delete it.' % img_path)
            os.remove(img_path)
        if (i + 1) % 1000 == 0:
            logging.info('Processing %d / %d.' % (i + 1, len(all_pic_list)))

運(yùn)行指令類(lèi)似如下:

python delete_broken_img.py -p 文件目錄

圖像裁剪到統(tǒng)一大小

主要兩個(gè)文件:
process.py:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import argparse
import os
import tempfile
import subprocess
import tensorflow as tf
import numpy as np
import tfimage as im
import threading
import time
import multiprocessing

edge_pool = None


parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", required=True, help="path to folder containing images")
parser.add_argument("--output_dir", required=True, help="output path")
parser.add_argument("--operation", required=True, choices=["grayscale", "resize", "blank", "combine", "edges", "blur"])
parser.add_argument("--workers", type=int, default=1, help="number of workers")
# resize
parser.add_argument("--pad", action="store_true", help="pad instead of crop for resize operation")
parser.add_argument("--size", type=int, default=256, help="size to use for resize operation")
# combine
parser.add_argument("--b_dir", type=str, help="path to folder containing B images for combine operation")
a = parser.parse_args()


def resize(src):
    height, width, _ = src.shape
    dst = src
    if height != width:
        if a.pad:
            size = max(height, width)
            # pad to correct ratio
            oh = (size - height) // 2
            ow = (size - width) // 2
            dst = im.pad(image=dst, offset_height=oh, offset_width=ow, target_height=size, target_width=size)
        else:
            # crop to correct ratio
            size = min(height, width)
            oh = (height - size) // 2
            ow = (width - size) // 2
            dst = im.crop(image=dst, offset_height=oh, offset_width=ow, target_height=size, target_width=size)

    assert(dst.shape[0] == dst.shape[1])

    size, _, _ = dst.shape
    if size > a.size:
        dst = im.downscale(images=dst, size=[a.size, a.size])
    elif size < a.size:
        dst = im.upscale(images=dst, size=[a.size, a.size])
    return dst


def blank(src):
    height, width, _ = src.shape
    if height != width:
        raise Exception("non-square image")

    image_size = width
    size = int(image_size * 0.3)
    offset = int(image_size / 2 - size / 2)

    dst = src
    dst[offset:offset + size, offset:offset + size, :] = np.ones([size, size, 3])
    return dst


def combine(src, src_path):
    if a.b_dir is None:
        raise Exception("missing b_dir")

    # find corresponding file in b_dir, could have a different extension
    basename, _ = os.path.splitext(os.path.basename(src_path))
    for ext in [".png", ".jpg"]:
        sibling_path = os.path.join(a.b_dir, basename + ext)
        if os.path.exists(sibling_path):
            sibling = im.load(sibling_path)
            break
    else:
        raise Exception("could not find sibling image for " + src_path)

    # make sure that dimensions are correct
    height, width, _ = src.shape
    if height != sibling.shape[0] or width != sibling.shape[1]:
        raise Exception("differing sizes")

    # convert both images to RGB if necessary
    if src.shape[2] == 1:
        src = im.grayscale_to_rgb(images=src)

    if sibling.shape[2] == 1:
        sibling = im.grayscale_to_rgb(images=sibling)

    # remove alpha channel
    if src.shape[2] == 4:
        src = src[:, :, :3]

    if sibling.shape[2] == 4:
        sibling = sibling[:, :, :3]

    return np.concatenate([src, sibling], axis=1)


def grayscale(src):
    return im.grayscale_to_rgb(images=im.rgb_to_grayscale(images=src))


def blur(src, scale=4):
    height, width, _ = src.shape
    height_down = height // scale
    width_down = width // scale
    dst = im.downscale(images=src, size=[height_down, width_down])
    dst = im.upscale(images=dst, size=[height, width])
    return dst

net = None


def run_caffe(src):
    # lazy load caffe and create net
    global net
    if net is None:
        # don't require caffe unless we are doing edge detection
        os.environ["GLOG_minloglevel"] = "2"  # disable logging from caffe
        import caffe
        # using this requires using the docker image or assembling a bunch of dependencies
        # and then changing these hardcoded paths
        net = caffe.Net("/opt/caffe/examples/hed/deploy.prototxt", "/opt/caffe/hed_pretrained_bsds.caffemodel", caffe.TEST)

    net.blobs["data"].reshape(1, *src.shape)
    net.blobs["data"].data[...] = src
    net.forward()
    return net.blobs["sigmoid-fuse"].data[0][0, :, :]


def edges(src):
    # based on https://github.com/phillipi/pix2pix/blob/master/scripts/edges/batch_hed.py
    # and https://github.com/phillipi/pix2pix/blob/master/scripts/edges/PostprocessHED.m
    import scipy.io
    src = src * 255
    border = 128  # put a padding around images since edge detection seems to detect edge of image
    src = src[:, :, :3]  # remove alpha channel if present
    src = np.pad(src, ((border, border), (border, border), (0, 0)), "reflect")
    src = src[:, :, ::-1]
    src -= np.array((104.00698793, 116.66876762, 122.67891434))
    src = src.transpose((2, 0, 1))

    # [height, width, channels] => [batch, channel, height, width]
    fuse = edge_pool.apply(run_caffe, [src])
    fuse = fuse[border:-border, border:-border]

    with tempfile.NamedTemporaryFile(suffix=".png") as png_file, tempfile.NamedTemporaryFile(suffix=".mat") as mat_file:
        scipy.io.savemat(mat_file.name, {"input": fuse})

        octave_code = r"""
E = 1-load(input_path).input;
E = imresize(E, [image_width,image_width]);
E = 1 - E;
E = single(E);
[Ox, Oy] = gradient(convTri(E, 4), 1);
[Oxx, ~] = gradient(Ox, 1);
[Oxy, Oyy] = gradient(Oy, 1);
O = mod(atan(Oyy .* sign(-Oxy) ./ (Oxx + 1e-5)), pi);
E = edgesNmsMex(E, O, 1, 5, 1.01, 1);
E = double(E >= max(eps, threshold));
E = bwmorph(E, 'thin', inf);
E = bwareaopen(E, small_edge);
E = 1 - E;
E = uint8(E * 255);
imwrite(E, output_path);
"""

        config = dict(
            input_path="'%s'" % mat_file.name,
            output_path="'%s'" % png_file.name,
            image_width=256,
            threshold=25.0 / 255.0,
            small_edge=5,
        )

        args = ["octave"]
        for k, v in config.items():
            args.extend(["--eval", "%s=%s;" % (k, v)])

        args.extend(["--eval", octave_code])
        try:
            subprocess.check_output(args, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            print("octave failed")
            print("returncode:", e.returncode)
            print("output:", e.output)
            raise
        return im.load(png_file.name)


def process(src_path, dst_path):
    src = im.load(src_path)

    if a.operation == "grayscale":
        dst = grayscale(src)
    elif a.operation == "resize":
        dst = resize(src)
    elif a.operation == "blank":
        dst = blank(src)
    elif a.operation == "combine":
        dst = combine(src, src_path)
    elif a.operation == "edges":
        dst = edges(src)
    elif a.operation == "blur":
        dst = blur(src)
    else:
        raise Exception("invalid operation")

    im.save(dst, dst_path)


complete_lock = threading.Lock()
start = None
num_complete = 0
total = 0


def complete():
    global num_complete, rate, last_complete

    with complete_lock:
        num_complete += 1
        now = time.time()
        elapsed = now - start
        rate = num_complete / elapsed
        if rate > 0:
            remaining = (total - num_complete) / rate
        else:
            remaining = 0

        print("%d/%d complete  %0.2f images/sec  %dm%ds elapsed  %dm%ds remaining" % (num_complete, total, rate, elapsed // 60, elapsed % 60, remaining // 60, remaining % 60))

        last_complete = now


def main():
    if not os.path.exists(a.output_dir):
        os.makedirs(a.output_dir)

    src_paths = []
    dst_paths = []

    skipped = 0
    for src_path in im.find(a.input_dir):
        name, _ = os.path.splitext(os.path.basename(src_path))
        dst_path = os.path.join(a.output_dir, name + ".png")
        if os.path.exists(dst_path):
            skipped += 1
        else:
            src_paths.append(src_path)
            dst_paths.append(dst_path)

    print("skipping %d files that already exist" % skipped)

    global total
    total = len(src_paths)

    print("processing %d files" % total)

    global start
    start = time.time()

    if a.operation == "edges":
        # use a multiprocessing pool for this operation so it can use multiple CPUs
        # create the pool before we launch processing threads
        global edge_pool
        edge_pool = multiprocessing.Pool(a.workers)

    if a.workers == 1:
        with tf.Session() as sess:
            for src_path, dst_path in zip(src_paths, dst_paths):
                process(src_path, dst_path)
                complete()
    else:
        queue = tf.train.input_producer(zip(src_paths, dst_paths), shuffle=False, num_epochs=1)
        dequeue_op = queue.dequeue()

        def worker(coord):
            with sess.as_default():
                while not coord.should_stop():
                    try:
                        src_path, dst_path = sess.run(dequeue_op)
                    except tf.errors.OutOfRangeError:
                        coord.request_stop()
                        break

                    process(src_path, dst_path)
                    complete()

        # init epoch counter for the queue
        local_init_op = tf.local_variables_initializer()
        with tf.Session() as sess:
            sess.run(local_init_op)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            for i in range(a.workers):
                t = threading.Thread(target=worker, args=(coord,))
                t.start()
                threads.append(t)

            try:
                coord.join(threads)
            except KeyboardInterrupt:
                coord.request_stop()
                coord.join(threads)

main()

tfimage.py:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import os


def create_op(func, **placeholders):
    op = func(**placeholders)

    def f(**kwargs):
        feed_dict = {}
        for argname, argvalue in kwargs.items():
            placeholder = placeholders[argname]
            feed_dict[placeholder] = argvalue
        return tf.get_default_session().run(op, feed_dict=feed_dict)

    return f

downscale = create_op(
    func=tf.image.resize_images,
    images=tf.placeholder(tf.float32, [None, None, None]),
    size=tf.placeholder(tf.int32, [2]),
    method=tf.image.ResizeMethod.AREA,
)

upscale = create_op(
    func=tf.image.resize_images,
    images=tf.placeholder(tf.float32, [None, None, None]),
    size=tf.placeholder(tf.int32, [2]),
    method=tf.image.ResizeMethod.BICUBIC,
)

decode_jpeg = create_op(
    func=tf.image.decode_jpeg,
    contents=tf.placeholder(tf.string),
)

decode_png = create_op(
    func=tf.image.decode_png,
    contents=tf.placeholder(tf.string),
)

rgb_to_grayscale = create_op(
    func=tf.image.rgb_to_grayscale,
    images=tf.placeholder(tf.float32),
)

grayscale_to_rgb = create_op(
    func=tf.image.grayscale_to_rgb,
    images=tf.placeholder(tf.float32),
)

encode_jpeg = create_op(
    func=tf.image.encode_jpeg,
    image=tf.placeholder(tf.uint8),
)

encode_png = create_op(
    func=tf.image.encode_png,
    image=tf.placeholder(tf.uint8),
)

crop = create_op(
    func=tf.image.crop_to_bounding_box,
    image=tf.placeholder(tf.float32),
    offset_height=tf.placeholder(tf.int32, []),
    offset_width=tf.placeholder(tf.int32, []),
    target_height=tf.placeholder(tf.int32, []),
    target_width=tf.placeholder(tf.int32, []),
)

pad = create_op(
    func=tf.image.pad_to_bounding_box,
    image=tf.placeholder(tf.float32),
    offset_height=tf.placeholder(tf.int32, []),
    offset_width=tf.placeholder(tf.int32, []),
    target_height=tf.placeholder(tf.int32, []),
    target_width=tf.placeholder(tf.int32, []),
)

to_uint8 = create_op(
    func=tf.image.convert_image_dtype,
    image=tf.placeholder(tf.float32),
    dtype=tf.uint8,
    saturate=True,
)

to_float32 = create_op(
    func=tf.image.convert_image_dtype,
    image=tf.placeholder(tf.uint8),
    dtype=tf.float32,
)


def load(path):
    with open(path, "rb") as f:
        contents = f.read()
        
    _, ext = os.path.splitext(path.lower())

    if ext == ".jpg":
        image = decode_jpeg(contents=contents)
    elif ext == ".png":
        image = decode_png(contents=contents)
    else:
        raise Exception("invalid image suffix")

    return to_float32(image=image)


def find(d):
    result = []
    for filename in os.listdir(d):
        _, ext = os.path.splitext(filename.lower())
        if ext == ".jpg" or ext == ".png":
            result.append(os.path.join(d, filename))
    result.sort()
    return result


def save(image, path, replace=False):
    _, ext = os.path.splitext(path.lower())
    image = to_uint8(image=image)
    if ext == ".jpg":
        encoded = encode_jpeg(image=image)
    elif ext == ".png":
        encoded = encode_png(image=image)
    else:
        raise Exception("invalid image suffix")

    dirname = os.path.dirname(path)
    if dirname != "" and not os.path.exists(dirname):
        os.makedirs(dirname)

    if os.path.exists(path):
        if replace:
            os.remove(path)
        else:
            raise Exception("file already exists at " + path)

    with open(path, "wb") as f:
        f.write(encoded)

運(yùn)行指令類(lèi)似下面:

python process.py --input_dir 上步處理完的文件目錄 --operation resize --output_dir 自己定義一個(gè)輸出文件夾

制作對(duì)應(yīng)要求的圖片對(duì)

代碼在:https://github.com/hzy46/Deep-Learning-21-Examples/blob/master/chapter_10/
第十章的代碼杖刷,對(duì)應(yīng)的處理代碼在chapter10/pix2pix-tensorflow/tools下,需要的兩個(gè)處理腳本與上面的兩個(gè)腳本同名驳癌。下載下來(lái)放到對(duì)應(yīng)文件夾就好。

  • 模糊處理命令
python process.py --operation blur --input_dir resize后的文件目錄 --output_dir 自定義一個(gè)輸出文件夾
  • 原始圖片和模糊圖片合并在一起命令
python process.py --input_dir resize后的文件夾 --b_dir 上面模糊操作后的輸出文件夾 --operation combine --output_dir 自定義輸出文件夾
  • 分為訓(xùn)練集和測(cè)試集(split.py同樣在之前的GitHub地址)
python split.py --dir 上面的合并后輸出文件夾

最后的生成結(jié)果類(lèi)似下面

image.png

image.png

訓(xùn)練模型(pix2pix.py還是在上面GitHub地址)

python pix2pix.py --mode train --output_dir 自定義模型輸出模型路徑 --max_epochs 20 --input_dir 上面輸出的訓(xùn)練文件夾 --which_direction BtoA

模型迭代

由于我把終端給關(guān)了役听,就不截圖了颓鲜,最后云端打包下來(lái)是這樣的形式:


image.png

測(cè)試模型

python pix2pix.py --mode test --output_dir 自定義輸出文件夾 --input_dir 之前生成的驗(yàn)證數(shù)據(jù)集目錄 --checkpoint 之前自定義的模型輸出文件夾

結(jié)果展示

image.png

左邊是模糊的,中間是模型生成的典予,右邊是原圖甜滨,效果不是特別好,之前看了看我的模型瘤袖,后面收斂的不是很好衣摩,但差不多就這意思了,另外我的數(shù)據(jù)集也不太好捂敌,僅作借鑒

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末艾扮,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子占婉,更是在濱河造成了極大的恐慌泡嘴,老刑警劉巖,帶你破解...
    沈念sama閱讀 216,692評(píng)論 6 501
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件逆济,死亡現(xiàn)場(chǎng)離奇詭異酌予,居然都是意外死亡,警方通過(guò)查閱死者的電腦和手機(jī)奖慌,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 92,482評(píng)論 3 392
  • 文/潘曉璐 我一進(jìn)店門(mén)抛虫,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái),“玉大人简僧,你說(shuō)我怎么就攤上這事建椰。” “怎么了涎劈?”我有些...
    開(kāi)封第一講書(shū)人閱讀 162,995評(píng)論 0 353
  • 文/不壞的土叔 我叫張陵广凸,是天一觀的道長(zhǎng)。 經(jīng)常有香客問(wèn)我蛛枚,道長(zhǎng)谅海,這世上最難降的妖魔是什么? 我笑而不...
    開(kāi)封第一講書(shū)人閱讀 58,223評(píng)論 1 292
  • 正文 為了忘掉前任蹦浦,我火速辦了婚禮扭吁,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘炮捧。我一直安慰自己表鳍,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 67,245評(píng)論 6 388
  • 文/花漫 我一把揭開(kāi)白布看疙。 她就那樣靜靜地躺著枫吧,像睡著了一般浦旱。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上九杂,一...
    開(kāi)封第一講書(shū)人閱讀 51,208評(píng)論 1 299
  • 那天颁湖,我揣著相機(jī)與錄音,去河邊找鬼例隆。 笑死甥捺,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的镀层。 我是一名探鬼主播镰禾,決...
    沈念sama閱讀 40,091評(píng)論 3 418
  • 文/蒼蘭香墨 我猛地睜開(kāi)眼,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼唱逢!你這毒婦竟也來(lái)了吴侦?” 一聲冷哼從身側(cè)響起,我...
    開(kāi)封第一講書(shū)人閱讀 38,929評(píng)論 0 274
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤惶我,失蹤者是張志新(化名)和其女友劉穎妈倔,沒(méi)想到半個(gè)月后,有當(dāng)?shù)厝嗽跇?shù)林里發(fā)現(xiàn)了一具尸體绸贡,經(jīng)...
    沈念sama閱讀 45,346評(píng)論 1 311
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡盯蝴,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 37,570評(píng)論 2 333
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了听怕。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片捧挺。...
    茶點(diǎn)故事閱讀 39,739評(píng)論 1 348
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖尿瞭,靈堂內(nèi)的尸體忽然破棺而出闽烙,到底是詐尸還是另有隱情,我是刑警寧澤声搁,帶...
    沈念sama閱讀 35,437評(píng)論 5 344
  • 正文 年R本政府宣布黑竞,位于F島的核電站,受9級(jí)特大地震影響疏旨,放射性物質(zhì)發(fā)生泄漏很魂。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 41,037評(píng)論 3 326
  • 文/蒙蒙 一檐涝、第九天 我趴在偏房一處隱蔽的房頂上張望遏匆。 院中可真熱鬧法挨,春花似錦、人聲如沸幅聘。這莊子的主人今日做“春日...
    開(kāi)封第一講書(shū)人閱讀 31,677評(píng)論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)帝蒿。三九已至荐糜,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間葛超,已是汗流浹背狞尔。 一陣腳步聲響...
    開(kāi)封第一講書(shū)人閱讀 32,833評(píng)論 1 269
  • 我被黑心中介騙來(lái)泰國(guó)打工, 沒(méi)想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留巩掺,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 47,760評(píng)論 2 369
  • 正文 我出身青樓页畦,卻偏偏與公主長(zhǎng)得像胖替,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子豫缨,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 44,647評(píng)論 2 354

推薦閱讀更多精彩內(nèi)容

  • 外面酷熱的陽(yáng)光炙烤著大地独令,樹(shù)上的知了不停歇地唱著歌,吵的我心神不寧好芭,可是這也怪不得知了燃箭,誰(shuí)叫現(xiàn)在是三伏天呢? 已經(jīng)...
    冰潔媽媽閱讀 245評(píng)論 0 1
  • 留白舍败,是指書(shū)畫(huà)藝術(shù)創(chuàng)作中為使整個(gè)作品畫(huà)面招狸、章法更為協(xié)調(diào)精美而有意留下相應(yīng)的空白,給欣賞者留有想像的空間邻薯。這在南宋馬...
    蒲公英的樣子閱讀 1,437評(píng)論 0 5
  • 不知什么時(shí)候開(kāi)始壹罚,這個(gè)世界開(kāi)始對(duì)中年人充滿惡意。油膩寿羞、猥瑣猖凛、厚臉皮,保溫杯配枸杞…… 年輕人和中年人最大的區(qū)別是什...
    劍齒虎的牙閱讀 1,304評(píng)論 2 0