import os
import math
import time
import datetime
from functools import reduce
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import scipy.misc as misc
from skimage.restoration import denoise_bilateral
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
class timer():
def __init__(self):
self.acc = 0
self.tic()
#print ("2-1-1-checkpoint")
def tic(self):
self.t0 = time.time()
#print ("2-1-2-checkpoint")
def toc(self):
return time.time() - self.t0
#print ("2-1-3-checkpoint")
def hold(self):
self.acc += self.toc()
#print ("2-1-4-checkpoint")
def release(self):
ret = self.acc
self.acc = 0
#print ("2-1-5-checkpoint")
return ret
def reset(self):
self.acc = 0
#print ("2-1-6-checkpoint")
class checkpoint():
def __init__(self, args):
self.args = args
self.ok = True
self.log = torch.Tensor()
now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
if args.load == '.':
if args.save == '.': args.save = now
self.dir = '../experiment/' + args.save
else:
self.dir = '../experiment/' + args.load
if not os.path.exists(self.dir):
args.load = '.'
else:
self.log = torch.load(self.dir + '/psnr_log.pt')
print('Continue from epoch {}...'.format(len(self.log)))
if args.reset:
os.system('rm -rf ' + self.dir)
args.load = '.'
def _make_dir(path):
if not os.path.exists(path): os.makedirs(path)
_make_dir(self.dir)
_make_dir(self.dir + '/model')
_make_dir(self.dir + '/results')
_make_dir(self.dir + '/residuals')
_make_dir(self.dir + '/branches')
open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
self.log_file = open(self.dir + '/log.txt', open_type)
with open(self.dir + '/config.txt', open_type) as f:
f.write(now + '\n\n')
for arg in vars(args):
f.write('{}: {}\n'.format(arg, getattr(args, arg)))
f.write('\n')
#print ("2-2-1-checkpoint")
def save(self, trainer, epoch, is_best=False):
trainer.model.save(self.dir, epoch, is_best=is_best)
trainer.loss.save(self.dir)
trainer.loss.plot_loss(self.dir, epoch)
self.plot_psnr(epoch)
torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
torch.save(
trainer.optimizer.state_dict(),
os.path.join(self.dir, 'optimizer.pt')
)
#print ("2-2-2-checkpoint")
def add_log(self, log):
self.log = torch.cat([self.log, log])
#print ("2-2-3-checkpoint")
def write_log(self, log, refresh=False):
#print(log)
self.log_file.write(log + '\n')
if refresh:
self.log_file.close()
self.log_file = open(self.dir + '/log.txt', 'a')
#print ("2-2-4-checkpoint")
def done(self):
self.log_file.close()
#print ("2-2-5-checkpoint")
def plot_psnr(self, epoch):
axis = np.linspace(1, epoch, epoch)
label = 'SR on {}'.format(self.args.data_test)
fig = plt.figure()
plt.title(label)
for idx_scale, scale in enumerate(self.args.scale):
plt.plot(
axis,
self.log[:, idx_scale].numpy(),
label='Scale {}'.format(scale)
)
plt.legend()
plt.xlabel('Epochs')
plt.ylabel('PSNR')
plt.grid(True)
plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test))
plt.close(fig)
#print ("2-2-5-checkpoint")
def save_results(self, filename, save_list, scale):
filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale)
postfix = ('SR', 'LR', 'HR')
for v, p in zip(save_list, postfix):
normalized = v[0].data.mul(255 / self.args.rgb_range)
ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
if ndarr.shape[-1] == 1:
ndarr = ndarr[:,:,0]
misc.imsave('{}{}.png'.format(filename, p), ndarr)
#print ("2-2-6-checkpoint")
def save_residuals(self, filename, save_list, scale):
filename = '{}/residuals/{}_x{}'.format(self.dir, filename, scale)
sr, hr = save_list[0], save_list[-1]
def _prepare(x):
normalized = x[0].data.mul(1. / self.args.rgb_range)
out = normalized.permute(1,2,0).cpu().numpy()
if out.shape[-1] == 1:
out = out[:,:,0]
return out
ndarr_sr, ndarr_hr = _prepare(sr), _prepare(hr)
out = np.abs(ndarr_hr - ndarr_sr)
misc.imsave('{}.png'.format(filename), out)
#print ("2-2-7-checkpoint")
def save_branches(self, filename, save_list, scale):
filename = '{}/branches/{}_x{}'.format(self.dir, filename, scale)
def _prepare(x, residual):
normalized = x[0].data.mul(1. / self.args.rgb_range)
if not residual:
out = normalized.permute(1,2,0).cpu().numpy()
else:
out = np.abs(normalized.permute(1,2,0).cpu().numpy())
if out.shape[-1] == 1:
out = out[:,:,0]
return out
for i, branch_output in enumerate(save_list):
ndarr = _prepare(branch_output, not (i==0))
misc.imsave('{}{}.png'.format(filename, '_branch{}'.format(i)), ndarr)
#print ("2-2-8-checkpoint")
return
def get_bilateral(tensor, rgb_range):
tensor = tensor.numpy().transpose(0,2,3,1) / rgb_range
out = np.zeros_like(tensor)
for i, t in enumerate(tensor):
out[i] = denoise_bilateral(t)
#print ("2-3-checkpoint")
return torch.Tensor(out.transpose(0,3,1,2)) * rgb_range
def quantize(img, rgb_range):
pixel_range = 255 / rgb_range
return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
#print ("2-4-checkpoint")
def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
diff = (sr - hr).data.div(rgb_range)
if benchmark:
shave = scale
if diff.size(1) > 1:
convert = diff.new(1, 3, 1, 1)
convert[0, 0, 0, 0] = 65.738
convert[0, 1, 0, 0] = 129.057
convert[0, 2, 0, 0] = 25.064
diff.mul_(convert).div_(256)
diff = diff.sum(dim=1, keepdim=True)
else:
shave = scale + 6
valid = diff[:, :, shave:-shave, shave:-shave]
mse = valid.pow(2).mean()
#print ("2-5-checkpoint")
return -10 * math.log10(mse)
def make_optimizer(args, my_model):
trainable = filter(lambda x: x.requires_grad, my_model.parameters())
if args.optimizer == 'SGD':
optimizer_function = optim.SGD
kwargs = {'momentum': args.momentum}
elif args.optimizer == 'ADAM':
optimizer_function = optim.Adam
kwargs = {
'betas': (args.beta1, args.beta2),
'eps': args.epsilon
}
elif args.optimizer == 'RMSprop':
optimizer_function = optim.RMSprop
kwargs = {'eps': args.epsilon}
kwargs['lr'] = args.lr
kwargs['weight_decay'] = args.weight_decay
#print ("2-6-checkpoint")
return optimizer_function(trainable, **kwargs)
def make_scheduler(args, my_optimizer):
if args.decay_type == 'step':
scheduler = lrs.StepLR(
my_optimizer,
step_size=args.lr_decay,
gamma=args.gamma
)
elif args.decay_type.find('step') >= 0:
milestones = args.decay_type.split('_')
milestones.pop(0)
milestones = list(map(lambda x: int(x), milestones))
scheduler = lrs.MultiStepLR(
my_optimizer,
milestones=milestones,
gamma=args.gamma
)
#print ("2-7-checkpoint")
return scheduler
utility.py
最后編輯于 :
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
- 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)涂身,“玉大人雄卷,你說(shuō)我怎么就攤上這事「蚴郏” “怎么了丁鹉?”我有些...
- 文/不壞的土叔 我叫張陵妒潭,是天一觀的道長(zhǎng)。 經(jīng)常有香客問我揣钦,道長(zhǎng)雳灾,這世上最難降的妖魔是什么? 我笑而不...
- 正文 為了忘掉前任冯凹,我火速辦了婚禮佑女,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘谈竿。我一直安慰自己团驱,他們只是感情好,可當(dāng)我...
- 文/花漫 我一把揭開白布空凸。 她就那樣靜靜地躺著嚎花,像睡著了一般。 火紅的嫁衣襯著肌膚如雪呀洲。 梳的紋絲不亂的頭發(fā)上紊选,一...
- 那天,我揣著相機(jī)與錄音道逗,去河邊找鬼兵罢。 笑死,一個(gè)胖子當(dāng)著我的面吹牛滓窍,可吹牛的內(nèi)容都是我干的卖词。 我是一名探鬼主播,決...
- 文/蒼蘭香墨 我猛地睜開眼吏夯,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼此蜈!你這毒婦竟也來(lái)了?” 一聲冷哼從身側(cè)響起噪生,我...
- 序言:老撾萬(wàn)榮一對(duì)情侶失蹤裆赵,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后跺嗽,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體战授,經(jīng)...
- 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
- 正文 我和宋清朗相戀三年桨嫁,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了植兰。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
- 正文 年R本政府宣布彬坏,位于F島的核電站,受9級(jí)特大地震影響膝晾,放射性物質(zhì)發(fā)生泄漏栓始。R本人自食惡果不足惜,卻給世界環(huán)境...
- 文/蒙蒙 一血当、第九天 我趴在偏房一處隱蔽的房頂上張望幻赚。 院中可真熱鬧,春花似錦臊旭、人聲如沸落恼。這莊子的主人今日做“春日...
- 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)佳谦。三九已至,卻和暖如春滋戳,著一層夾襖步出監(jiān)牢的瞬間钻蔑,已是汗流浹背。 一陣腳步聲響...
- 正文 我出身青樓娄涩,卻偏偏與公主長(zhǎng)得像窗怒,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子钝满,可洞房花燭夜當(dāng)晚...