目标
训练一个 pix2pix 模型，把模糊的图片转换成清晰的图片。
准备
租了一个极客云上的 GPU 服务器（GTX 1070），速度还行。
数据集准备
数据集是“图片对”的形式：每张图片由两张图片拼接而成，一张是清晰的图片，一张是模糊的图片。
我的原始数据集
我用的是 Kaggle 上的花卉数据集：https://www.kaggle.com/alxmamaev/flowers-recognition
去除错误数据（delete_broken_img.py）
import tensorflow as tf
from glob import glob
import os
import argparse
import logging
from PIL import Image
import traceback
def glob_all(dir_path):
    """Recursively collect every '*.jpg' file below *dir_path*.

    Matches found directly in *dir_path* come first, followed by the
    results of searching each sub-directory (in os.listdir order),
    depth-first.
    """
    found = glob(os.path.join(dir_path, '*.jpg'))
    children = [os.path.join(dir_path, entry) for entry in os.listdir(dir_path)]
    for child in children:
        if os.path.isdir(child):
            found += glob_all(child)
    return found
def parse_args():
    """Read the command line; ``-p/--dir-path`` names the image root (default 'data/')."""
    cli = argparse.ArgumentParser()
    cli.add_argument('-p', '--dir-path', default='data/')
    parsed = cli.parse_args()
    return parsed
if __name__ == '__main__':
    # Scan the image tree and delete every file that TensorFlow cannot decode
    # as a 3-channel JPEG or that PIL cannot open.
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    all_pic_list = glob_all(args.dir_path)
    # Build the decode op once and reuse a single session for every file.
    # The previous version created a new tf.Session per image and never closed
    # it, and it only reset the default graph on the success path.
    jpeg_bytes = tf.placeholder(tf.string)
    decoded = tf.image.decode_jpeg(jpeg_bytes)
    with tf.Session() as sess:
        for i, img_path in enumerate(all_pic_list):
            try:
                with open(img_path, 'rb') as f:
                    img_byte = f.read()
                data = sess.run(decoded, feed_dict={jpeg_bytes: img_byte})
                # pix2pix needs 3-channel images; anything else is treated as broken.
                if data.shape[2] != 3:
                    logging.info('%s has shape %s', img_path, data.shape)
                    raise ValueError('expected 3 channels')
                # PIL must also be able to identify the file. Close the handle
                # immediately so it is not leaked (or held open during os.remove).
                with Image.open(img_path):
                    pass
            except Exception:
                # Deliberate best-effort cleanup: any failure above marks the file unusable.
                logging.warning('%s has broken. Delete it.' % img_path)
                os.remove(img_path)
            if (i + 1) % 1000 == 0:
                logging.info('Processing %d / %d.' % (i + 1, len(all_pic_list)))
运行指令类似如下：
python delete_broken_img.py -p 文件目錄
图像裁剪到统一大小
主要有两个文件：
process.py:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import tempfile
import subprocess
import tensorflow as tf
import numpy as np
import tfimage as im
import threading
import time
import multiprocessing
# Process pool used by edges(); main() creates it only for the "edges" operation.
edge_pool = None
# Command-line interface. It is parsed at import time into the module-global
# `a`, which every operation function below reads.
parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", required=True, help="path to folder containing images")
parser.add_argument("--output_dir", required=True, help="output path")
parser.add_argument("--operation", required=True, choices=["grayscale", "resize", "blank", "combine", "edges", "blur"])
# NOTE(review): values < 1 are not rejected here.
parser.add_argument("--workers", type=int, default=1, help="number of workers")
# options for --operation resize
parser.add_argument("--pad", action="store_true", help="pad instead of crop for resize operation")
parser.add_argument("--size", type=int, default=256, help="size to use for resize operation")
# options for --operation combine
parser.add_argument("--b_dir", type=str, help="path to folder containing B images for combine operation")
a = parser.parse_args()
def resize(src):
    """Make *src* square, then scale it to ``a.size`` x ``a.size``.

    A non-square image is either padded (``--pad``) up to its longer side,
    or centre-cropped down to its shorter side. The square result is then
    downscaled or upscaled only if its side differs from ``a.size``.
    """
    rows, cols, _ = src.shape
    out = src
    if rows != cols:
        if a.pad:
            # grow the short side with a centred border
            side = max(rows, cols)
            out = im.pad(image=out, offset_height=(side - rows) // 2, offset_width=(side - cols) // 2, target_height=side, target_width=side)
        else:
            # keep the centred square of the short side
            side = min(rows, cols)
            out = im.crop(image=out, offset_height=(rows - side) // 2, offset_width=(cols - side) // 2, target_height=side, target_width=side)
    assert(out.shape[0] == out.shape[1])
    side, _, _ = out.shape
    target = [a.size, a.size]
    if side > a.size:
        out = im.downscale(images=out, size=target)
    elif side < a.size:
        out = im.upscale(images=out, size=target)
    return out
def blank(src):
    """Return a copy of the square image *src* with a centred square set to 1.

    The square's side is 30% of the image side (truncated to an int). Every
    channel inside it is set to 1, which is the maximum value for the float
    images produced by im.load().

    Raises:
        Exception: if *src* is not square.
    """
    height, width, _ = src.shape
    if height != width:
        raise Exception("non-square image")
    image_size = width
    size = int(image_size * 0.3)
    offset = int(image_size / 2 - size / 2)
    # Copy first: the old version assigned dst = src and wrote into the
    # caller's array in place.
    dst = src.copy()
    # Broadcasting a scalar works for any channel count (RGB, RGBA, gray),
    # unlike a fixed [size, size, 3] block of ones.
    dst[offset:offset + size, offset:offset + size, :] = 1
    return dst
def combine(src, src_path):
    """Join *src* and its same-named partner from ``a.b_dir`` side by side.

    The partner file is looked up by *src_path*'s base name, trying ".png"
    before ".jpg". Both images are brought to three channels (single-channel
    images are expanded, a fourth channel is dropped). They are then
    concatenated along the width axis, with *src* on the left.

    Raises:
        Exception: if ``a.b_dir`` is unset, no partner exists, or the two
            images differ in height or width.
    """
    if a.b_dir is None:
        raise Exception("missing b_dir")

    # the partner may use a different extension than src
    stem = os.path.splitext(os.path.basename(src_path))[0]
    sibling_path = None
    for ext in (".png", ".jpg"):
        candidate = os.path.join(a.b_dir, stem + ext)
        if os.path.exists(candidate):
            sibling_path = candidate
            break
    if sibling_path is None:
        raise Exception("could not find sibling image for " + src_path)
    sibling = im.load(sibling_path)

    height, width, _ = src.shape
    if (height, width) != (sibling.shape[0], sibling.shape[1]):
        raise Exception("differing sizes")

    def to_rgb(img):
        # expand single-channel images, then strip a trailing alpha channel
        if img.shape[2] == 1:
            img = im.grayscale_to_rgb(images=img)
        if img.shape[2] == 4:
            img = img[:, :, :3]
        return img

    return np.concatenate([to_rgb(src), to_rgb(sibling)], axis=1)
def grayscale(src):
    """Desaturate *src*, returning it with three identical channels again."""
    gray = im.rgb_to_grayscale(images=src)
    return im.grayscale_to_rgb(images=gray)
def blur(src, scale=4):
    """Blur *src* by shrinking it by *scale* (area resize) and enlarging it back (bicubic).

    The reduced size is clamped to at least one pixel per side, so images
    smaller than *scale* no longer request a zero-sized resize.

    Args:
        src: image array of shape (height, width, channels).
        scale: integer shrink factor; must be >= 1.

    Returns:
        An image with the same height and width as *src*.

    Raises:
        ValueError: if *scale* < 1.
    """
    if scale < 1:
        raise ValueError("scale must be >= 1, got %r" % (scale,))
    height, width, _ = src.shape
    small = [max(1, height // scale), max(1, width // scale)]
    dst = im.downscale(images=src, size=small)
    return im.upscale(images=dst, size=[height, width])
# HED network, loaded lazily by run_caffe(). edges() runs run_caffe() through
# the edge_pool process pool, so each worker process keeps its own copy.
net = None


def _load_hed_net():
    """Import caffe and construct the pretrained HED network."""
    # caffe is only needed for edge detection, so it is imported on demand.
    # GLOG_minloglevel is set first to silence caffe's logging.
    os.environ["GLOG_minloglevel"] = "2"
    import caffe
    # Using this requires the docker image (or equivalent dependencies);
    # otherwise adjust these hardcoded paths.
    return caffe.Net("/opt/caffe/examples/hed/deploy.prototxt", "/opt/caffe/hed_pretrained_bsds.caffemodel", caffe.TEST)


def run_caffe(src):
    """Feed *src* through the HED network and return the 2-D "sigmoid-fuse" map.

    *src* is the array prepared by edges(); a leading batch axis of size 1
    is added here when the input blob is reshaped.
    """
    global net
    if net is None:
        net = _load_hed_net()
    data_blob = net.blobs["data"]
    data_blob.reshape(1, *src.shape)
    data_blob.data[...] = src
    net.forward()
    return net.blobs["sigmoid-fuse"].data[0][0, :, :]
def edges(src):
    """Turn *src* into an edge map with the HED caffe model plus Octave post-processing.

    Requires the module-level process pool ``edge_pool`` (created by main()
    for the "edges" operation), scipy, and an ``octave`` executable that
    provides the convTri / edgesNmsMex functions called by the script below.

    Returns:
        The post-processed edge image, read back with im.load() from the PNG
        that Octave writes.
    """
    # based on https://github.com/phillipi/pix2pix/blob/master/scripts/edges/batch_hed.py
    # and https://github.com/phillipi/pix2pix/blob/master/scripts/edges/PostprocessHED.m
    import scipy.io
    # im.load() yields floats in [0, 1]; rescale to the 0-255 range the mean values below assume
    src = src * 255
    border = 128 # put a padding around images since edge detection seems to detect edge of image
    src = src[:, :, :3] # remove alpha channel if present
    src = np.pad(src, ((border, border), (border, border), (0, 0)), "reflect")
    # reverse the channel order (presumably RGB -> BGR, as the caffe model expects)
    src = src[:, :, ::-1]
    # subtract per-channel mean values (in the reversed channel order)
    src -= np.array((104.00698793, 116.66876762, 122.67891434))
    # [height, width, channels] => [channels, height, width]; run_caffe() adds the batch axis
    src = src.transpose((2, 0, 1))
    # run the network in a worker of the edge_pool process pool
    fuse = edge_pool.apply(run_caffe, [src])
    # drop the reflective border added above
    fuse = fuse[border:-border, border:-border]
    # NOTE(review): the temp files are reopened by name while still open, which
    # presumably fails on Windows — confirm if that platform matters.
    with tempfile.NamedTemporaryFile(suffix=".png") as png_file, tempfile.NamedTemporaryFile(suffix=".mat") as mat_file:
        # hand the raw network output to Octave through a .mat file
        scipy.io.savemat(mat_file.name, {"input": fuse})
        # Octave post-processing: invert, resize, invert back, estimate edge
        # orientation, non-maximum suppression, threshold, thin, drop small
        # components, invert (edges dark on white), write an 8-bit PNG.
        octave_code = r"""
E = 1-load(input_path).input;
E = imresize(E, [image_width,image_width]);
E = 1 - E;
E = single(E);
[Ox, Oy] = gradient(convTri(E, 4), 1);
[Oxx, ~] = gradient(Ox, 1);
[Oxy, Oyy] = gradient(Oy, 1);
O = mod(atan(Oyy .* sign(-Oxy) ./ (Oxx + 1e-5)), pi);
E = edgesNmsMex(E, O, 1, 5, 1.01, 1);
E = double(E >= max(eps, threshold));
E = bwmorph(E, 'thin', inf);
E = bwareaopen(E, small_edge);
E = 1 - E;
E = uint8(E * 255);
imwrite(E, output_path);
"""
        # NOTE(review): image_width is fixed at 256, so the edge map is always
        # resized to 256x256 regardless of the input size. For other --size
        # values combine() would then reject it as "differing sizes". Confirm
        # this is intended.
        config = dict(
            input_path="'%s'" % mat_file.name,
            output_path="'%s'" % png_file.name,
            image_width=256,
            threshold=25.0 / 255.0,
            small_edge=5,
        )
        # each setting becomes an "--eval name=value;" that runs before the script
        args = ["octave"]
        for k, v in config.items():
            args.extend(["--eval", "%s=%s;" % (k, v)])
        args.extend(["--eval", octave_code])
        try:
            subprocess.check_output(args, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            # surface Octave's combined stdout/stderr before propagating the failure
            print("octave failed")
            print("returncode:", e.returncode)
            print("output:", e.output)
            raise
        # load while the temporary PNG still exists (inside the with-block)
        return im.load(png_file.name)
def process(src_path, dst_path):
    """Load *src_path*, apply the operation chosen by ``a.operation``, save to *dst_path*.

    Raises:
        Exception: if ``a.operation`` names no known operation.
    """
    src = im.load(src_path)
    operations = {
        "grayscale": grayscale,
        "resize": resize,
        "blank": blank,
        # combine also needs the source path to locate its partner image
        "combine": lambda image: combine(image, src_path),
        "edges": edges,
        "blur": blur,
    }
    operation = operations.get(a.operation)
    if operation is None:
        raise Exception("invalid operation")
    dst = operation(src)
    im.save(dst, dst_path)
# Progress bookkeeping shared by the processing threads; guarded by complete_lock.
complete_lock = threading.Lock()
start = None  # time.time() when processing began; set by main()
num_complete = 0
total = 0  # number of images to process; set by main()


def complete():
    """Record one finished image and print a progress / ETA line.

    The counter update and the print happen under complete_lock. Also
    updates the module globals ``rate`` (images per second) and
    ``last_complete`` (time of this call).
    """
    global num_complete, rate, last_complete
    with complete_lock:
        num_complete += 1
        now = time.time()
        elapsed = now - start
        # A zero interval (coarse clocks) used to raise ZeroDivisionError;
        # a non-positive interval now simply reports a rate of 0.
        rate = num_complete / elapsed if elapsed > 0 else 0.0
        remaining = (total - num_complete) / rate if rate > 0 else 0
        print("%d/%d complete %0.2f images/sec %dm%ds elapsed %dm%ds remaining" % (num_complete, total, rate, elapsed // 60, elapsed % 60, remaining // 60, remaining % 60))
        last_complete = now
def _collect_jobs():
    """Pair each image in ``a.input_dir`` with its ``<name>.png`` path in ``a.output_dir``.

    Returns:
        (src_paths, dst_paths, skipped): parallel lists of work still to do,
        and the number of images skipped because their output already exists.
    """
    src_paths = []
    dst_paths = []
    skipped = 0
    for src_path in im.find(a.input_dir):
        name, _ = os.path.splitext(os.path.basename(src_path))
        dst_path = os.path.join(a.output_dir, name + ".png")
        if os.path.exists(dst_path):
            skipped += 1
        else:
            src_paths.append(src_path)
            dst_paths.append(dst_path)
    return src_paths, dst_paths, skipped


def _as_text(path):
    """Return *path* as text; values fetched from TF string tensors are bytes on Python 3."""
    if isinstance(path, bytes):
        return path.decode("utf-8")
    return path


def _run_serial(src_paths, dst_paths):
    """Process every job on the calling thread inside one default TF session."""
    with tf.Session():
        for src_path, dst_path in zip(src_paths, dst_paths):
            process(src_path, dst_path)
            complete()


def _run_threaded(src_paths, dst_paths, num_workers):
    """Process jobs on *num_workers* threads fed from a one-epoch TF input queue."""
    # list(): on Python 3 zip() is a lazy iterator, which TF cannot convert to a tensor.
    queue = tf.train.input_producer(list(zip(src_paths, dst_paths)), shuffle=False, num_epochs=1)
    dequeue_op = queue.dequeue()
    # init epoch counter for the queue
    local_init_op = tf.local_variables_initializer()

    with tf.Session() as sess:
        coord = tf.train.Coordinator()

        def worker():
            # stop_on_exception() hands any error to the coordinator so that
            # coord.join() re-raises it instead of this thread dying silently.
            with sess.as_default(), coord.stop_on_exception():
                while not coord.should_stop():
                    try:
                        src_path, dst_path = sess.run(dequeue_op)
                    except tf.errors.OutOfRangeError:
                        coord.request_stop()
                        break
                    # Dequeued strings are bytes; load()/combine() compare them against str.
                    process(_as_text(src_path), _as_text(dst_path))
                    complete()

        sess.run(local_init_op)
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for _ in range(num_workers):
            t = threading.Thread(target=worker)
            t.start()
            threads.append(t)
        try:
            coord.join(threads)
        except KeyboardInterrupt:
            coord.request_stop()
            coord.join(threads)


def main():
    """Apply ``a.operation`` to every image in ``a.input_dir``, writing PNGs to ``a.output_dir``.

    Images whose output already exists are skipped. With ``--workers 1`` the
    work runs on this thread; otherwise it is spread over that many threads.
    The "edges" operation also gets a process pool of the same size.

    Raises:
        ValueError: if ``--workers`` is less than 1 (previously this hung).
    """
    global total, start, edge_pool
    if a.workers < 1:
        raise ValueError("--workers must be at least 1, got %d" % a.workers)
    if not os.path.exists(a.output_dir):
        os.makedirs(a.output_dir)

    src_paths, dst_paths, skipped = _collect_jobs()
    print("skipping %d files that already exist" % skipped)
    total = len(src_paths)
    print("processing %d files" % total)
    if total == 0:
        return

    start = time.time()
    if a.operation == "edges":
        # use a multiprocessing pool for this operation so it can use multiple CPUs
        # create the pool before we launch processing threads
        edge_pool = multiprocessing.Pool(a.workers)

    if a.workers == 1:
        _run_serial(src_paths, dst_paths)
    else:
        _run_threaded(src_paths, dst_paths, a.workers)
# Guard the entry point: multiprocessing (used by the "edges" operation)
# re-imports the main module in child processes on spawn-based platforms,
# and those children must not re-run main().
if __name__ == "__main__":
    main()
tfimage.py:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import os
def create_op(func, **placeholders):
    """Build ``func(**placeholders)`` once and return a function that runs it.

    The returned function accepts keyword arguments named like the
    placeholders. It feeds each value into the matching placeholder and
    evaluates the op in ``tf.get_default_session()``. An unknown keyword
    raises KeyError.
    """
    op = func(**placeholders)

    def run(**kwargs):
        feed = {placeholders[name]: value for name, value in kwargs.items()}
        return tf.get_default_session().run(op, feed_dict=feed)

    return run
# Reusable image ops. Each create_op() call below builds its op and its
# placeholders in the default graph exactly once, at import time. The
# resulting functions must be called while a default session is active.
# Keyword arguments that are not placeholders (method, dtype, saturate) are
# fixed when the op is built.

# Shrink using tf.image.ResizeMethod.AREA.
downscale = create_op(
    func=tf.image.resize_images,
    images=tf.placeholder(tf.float32, [None, None, None]),
    size=tf.placeholder(tf.int32, [2]),
    method=tf.image.ResizeMethod.AREA,
)
# Enlarge using tf.image.ResizeMethod.BICUBIC.
upscale = create_op(
    func=tf.image.resize_images,
    images=tf.placeholder(tf.float32, [None, None, None]),
    size=tf.placeholder(tf.int32, [2]),
    method=tf.image.ResizeMethod.BICUBIC,
)
# Decode raw file bytes (as read by load()) into an image.
decode_jpeg = create_op(
    func=tf.image.decode_jpeg,
    contents=tf.placeholder(tf.string),
)
decode_png = create_op(
    func=tf.image.decode_png,
    contents=tf.placeholder(tf.string),
)
# Colour conversions on float images.
rgb_to_grayscale = create_op(
    func=tf.image.rgb_to_grayscale,
    images=tf.placeholder(tf.float32),
)
grayscale_to_rgb = create_op(
    func=tf.image.grayscale_to_rgb,
    images=tf.placeholder(tf.float32),
)
# Encode a uint8 image into file bytes (used by save()).
encode_jpeg = create_op(
    func=tf.image.encode_jpeg,
    image=tf.placeholder(tf.uint8),
)
encode_png = create_op(
    func=tf.image.encode_png,
    image=tf.placeholder(tf.uint8),
)
# Crop / pad to a target box; offsets and sizes are scalar int32 feeds.
crop = create_op(
    func=tf.image.crop_to_bounding_box,
    image=tf.placeholder(tf.float32),
    offset_height=tf.placeholder(tf.int32, []),
    offset_width=tf.placeholder(tf.int32, []),
    target_height=tf.placeholder(tf.int32, []),
    target_width=tf.placeholder(tf.int32, []),
)
pad = create_op(
    func=tf.image.pad_to_bounding_box,
    image=tf.placeholder(tf.float32),
    offset_height=tf.placeholder(tf.int32, []),
    offset_width=tf.placeholder(tf.int32, []),
    target_height=tf.placeholder(tf.int32, []),
    target_width=tf.placeholder(tf.int32, []),
)
# float32 -> uint8 before encoding; saturate=True guards against overflow
# when casting (see tf.image.convert_image_dtype).
to_uint8 = create_op(
    func=tf.image.convert_image_dtype,
    image=tf.placeholder(tf.float32),
    dtype=tf.uint8,
    saturate=True,
)
# uint8 -> float32 after decoding (used by load()).
to_float32 = create_op(
    func=tf.image.convert_image_dtype,
    image=tf.placeholder(tf.uint8),
    dtype=tf.float32,
)
def load(path):
    """Read a JPEG or PNG file and return it as a float32 image.

    The decoded uint8 pixels are converted with to_float32
    (tf.image.convert_image_dtype), which scales them to [0, 1]. The
    extension is matched case-insensitively. ".jpeg" is now accepted
    alongside ".jpg" and ".png".

    Raises:
        Exception: if the extension is not .jpg/.jpeg/.png. This is checked
            before the file is read, so an unsupported file is never opened.
    """
    _, ext = os.path.splitext(path.lower())
    if ext in (".jpg", ".jpeg"):
        decode = decode_jpeg
    elif ext == ".png":
        decode = decode_png
    else:
        raise Exception("invalid image suffix")
    with open(path, "rb") as f:
        contents = f.read()
    return to_float32(image=decode(contents=contents))
def find(d):
    """Return the sorted paths of the .jpg/.png entries (case-insensitive) directly inside *d*."""
    wanted = (".jpg", ".png")
    return sorted(
        os.path.join(d, entry)
        for entry in os.listdir(d)
        if os.path.splitext(entry.lower())[1] in wanted
    )
def save(image, path, replace=False):
    """Encode *image* according to *path*'s extension and write it to disk.

    The image is converted to uint8 and encoded as JPEG or PNG (chosen
    case-insensitively by extension). Missing parent directories are
    created.

    Raises:
        Exception: on an unsupported extension, or when *path* already
            exists and *replace* is false.
    """
    ext = os.path.splitext(path.lower())[1]
    pixels = to_uint8(image=image)
    encoders = {".jpg": encode_jpeg, ".png": encode_png}
    if ext not in encoders:
        raise Exception("invalid image suffix")
    encoded = encoders[ext](image=pixels)

    parent = os.path.dirname(path)
    if parent and not os.path.exists(parent):
        os.makedirs(parent)

    if os.path.exists(path):
        if not replace:
            raise Exception("file already exists at " + path)
        os.remove(path)

    with open(path, "wb") as out:
        out.write(encoded)
运行指令类似下面：
python process.py --input_dir 上步處理完的文件目錄 --operation resize --output_dir 自己定義一個(gè)輸出文件夾
制作符合要求的图片对
代码在：https://github.com/hzy46/Deep-Learning-21-Examples/blob/master/chapter_10/
即第十章的代码，对应的处理代码在 chapter_10/pix2pix-tensorflow/tools 下，需要的两个处理脚本与上面的两个脚本同名，下载下来放到对应文件夹即可。
- 模糊处理命令
python process.py --operation blur --input_dir resize后的文件目錄 --output_dir 自定義一個(gè)輸出文件夾
- 将原始图片和模糊图片合并在一起的命令
python process.py --input_dir resize后的文件夾 --b_dir 上面模糊操作后的輸出文件夾 --operation combine --output_dir 自定義輸出文件夾
- 分为训练集和测试集（split.py 同样在之前的 GitHub 地址）
python split.py --dir 上面的合并后輸出文件夾
最后的生成结果类似下面：
训练模型（pix2pix.py 同样在上面的 GitHub 地址）
python pix2pix.py --mode train --output_dir 自定義模型輸出模型路徑 --max_epochs 20 --input_dir 上面輸出的訓(xùn)練文件夾 --which_direction BtoA
模型迭代
由于我把终端关掉了，就不截图了，最后从云端打包下载下来是这样的形式：
测试模型
python pix2pix.py --mode test --output_dir 自定義輸出文件夾 --input_dir 之前生成的驗(yàn)證數(shù)據(jù)集目錄 --checkpoint 之前自定義的模型輸出文件夾
结果展示
左边是模糊的图片，中间是模型生成的图片，右边是原图。效果不是特别好：看了看我的模型，训练后期收敛得不是很好，另外我的数据集也不太好，但大致就是这个意思，仅作借鉴。