utility.py

import os
import math
import time
import datetime
from functools import reduce

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import numpy as np
import scipy.misc as misc
from skimage.restoration import denoise_bilateral

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs

class timer():
    def __init__(self):
        self.acc = 0
        self.tic()
        #print ("2-1-1-checkpoint")

    def tic(self):
        self.t0 = time.time()
        #print ("2-1-2-checkpoint")

    def toc(self):
        return time.time() - self.t0
        #print ("2-1-3-checkpoint")

    def hold(self):
        self.acc += self.toc()
        #print ("2-1-4-checkpoint")

    def release(self):
        ret = self.acc
        self.acc = 0
        #print ("2-1-5-checkpoint")

        return ret

    def reset(self):
        self.acc = 0
        #print ("2-1-6-checkpoint")

class checkpoint():
    def __init__(self, args):
        self.args = args
        self.ok = True
        self.log = torch.Tensor()
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        if args.load == '.':
            if args.save == '.': args.save = now
            self.dir = '../experiment/' + args.save
        else:
            self.dir = '../experiment/' + args.load
            if not os.path.exists(self.dir):
                args.load = '.'
            else:
                self.log = torch.load(self.dir + '/psnr_log.pt')
                print('Continue from epoch {}...'.format(len(self.log)))

        if args.reset:
            os.system('rm -rf ' + self.dir)
            args.load = '.'

        def _make_dir(path):
            if not os.path.exists(path): os.makedirs(path)

        _make_dir(self.dir)
        _make_dir(self.dir + '/model')
        _make_dir(self.dir + '/results')
        _make_dir(self.dir + '/residuals')
        _make_dir(self.dir + '/branches')

        open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
        self.log_file = open(self.dir + '/log.txt', open_type)
        with open(self.dir + '/config.txt', open_type) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')
        #print ("2-2-1-checkpoint")

    def save(self, trainer, epoch, is_best=False):
        trainer.model.save(self.dir, epoch, is_best=is_best)
        trainer.loss.save(self.dir)
        trainer.loss.plot_loss(self.dir, epoch)

        self.plot_psnr(epoch)
        torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
        torch.save(
            trainer.optimizer.state_dict(),
            os.path.join(self.dir, 'optimizer.pt')
        )
        #print ("2-2-2-checkpoint")

    def add_log(self, log):
        self.log = torch.cat([self.log, log])
        #print ("2-2-3-checkpoint")

    def write_log(self, log, refresh=False):
        #print(log)
        self.log_file.write(log + '\n')
        if refresh:
            self.log_file.close()
            self.log_file = open(self.dir + '/log.txt', 'a')
        #print ("2-2-4-checkpoint")

    def done(self):
        self.log_file.close()
        #print ("2-2-5-checkpoint")

    def plot_psnr(self, epoch):
        axis = np.linspace(1, epoch, epoch)
        label = 'SR on {}'.format(self.args.data_test)
        fig = plt.figure()
        plt.title(label)
        for idx_scale, scale in enumerate(self.args.scale):
            plt.plot(
                axis,
                self.log[:, idx_scale].numpy(),
                label='Scale {}'.format(scale)
            )
        plt.legend()
        plt.xlabel('Epochs')
        plt.ylabel('PSNR')
        plt.grid(True)
        plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test))
        plt.close(fig)
        #print ("2-2-5-checkpoint")

    def save_results(self, filename, save_list, scale):
        filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale)
        postfix = ('SR', 'LR', 'HR')
        for v, p in zip(save_list, postfix):
            normalized = v[0].data.mul(255 / self.args.rgb_range)
            ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
            
            if ndarr.shape[-1] == 1: 
                    ndarr = ndarr[:,:,0] 
                    
            misc.imsave('{}{}.png'.format(filename, p), ndarr)
        #print ("2-2-6-checkpoint")

    def save_residuals(self, filename, save_list, scale): 
        filename = '{}/residuals/{}_x{}'.format(self.dir, filename, scale)
        sr, hr = save_list[0], save_list[-1]

        def _prepare(x):
            normalized = x[0].data.mul(1. / self.args.rgb_range)
            out = normalized.permute(1,2,0).cpu().numpy()
            
            if out.shape[-1] == 1: 
                out = out[:,:,0]

            return out 

        ndarr_sr, ndarr_hr = _prepare(sr), _prepare(hr)
        out = np.abs(ndarr_hr - ndarr_sr)
        misc.imsave('{}.png'.format(filename), out)
        #print ("2-2-7-checkpoint")

    def save_branches(self, filename, save_list, scale): 
        filename = '{}/branches/{}_x{}'.format(self.dir, filename, scale)
        
        def _prepare(x, residual):
            normalized = x[0].data.mul(1. / self.args.rgb_range)
            if not residual: 
                out = normalized.permute(1,2,0).cpu().numpy()
            else: 
                out = np.abs(normalized.permute(1,2,0).cpu().numpy())

            if out.shape[-1] == 1: 
                out = out[:,:,0]
            return out 

        for i, branch_output in enumerate(save_list): 
            ndarr = _prepare(branch_output, not (i==0))
            misc.imsave('{}{}.png'.format(filename, '_branch{}'.format(i)), ndarr)
        #print ("2-2-8-checkpoint")
        return 

def get_bilateral(tensor, rgb_range): 
    tensor = tensor.numpy().transpose(0,2,3,1) / rgb_range
    out = np.zeros_like(tensor)

    for i, t in enumerate(tensor): 
        out[i] = denoise_bilateral(t)

    #print ("2-3-checkpoint")
    return torch.Tensor(out.transpose(0,3,1,2)) * rgb_range

def quantize(img, rgb_range):
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
    #print ("2-4-checkpoint")

def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
    diff = (sr - hr).data.div(rgb_range)
    if benchmark:
        shave = scale
        if diff.size(1) > 1:
            convert = diff.new(1, 3, 1, 1)
            convert[0, 0, 0, 0] = 65.738
            convert[0, 1, 0, 0] = 129.057
            convert[0, 2, 0, 0] = 25.064
            diff.mul_(convert).div_(256)
            diff = diff.sum(dim=1, keepdim=True)
    else:
        shave = scale + 6

    valid = diff[:, :, shave:-shave, shave:-shave]
    mse = valid.pow(2).mean()
    #print ("2-5-checkpoint")
    return -10 * math.log10(mse)

def make_optimizer(args, my_model):
    trainable = filter(lambda x: x.requires_grad, my_model.parameters())

    if args.optimizer == 'SGD':
        optimizer_function = optim.SGD
        kwargs = {'momentum': args.momentum}
    elif args.optimizer == 'ADAM':
        optimizer_function = optim.Adam
        kwargs = {
            'betas': (args.beta1, args.beta2),
            'eps': args.epsilon
        }
    elif args.optimizer == 'RMSprop':
        optimizer_function = optim.RMSprop
        kwargs = {'eps': args.epsilon}

    kwargs['lr'] = args.lr
    kwargs['weight_decay'] = args.weight_decay
    #print ("2-6-checkpoint")
    return optimizer_function(trainable, **kwargs)

def make_scheduler(args, my_optimizer):
    if args.decay_type == 'step':
        scheduler = lrs.StepLR(
            my_optimizer,
            step_size=args.lr_decay,
            gamma=args.gamma
        )
    elif args.decay_type.find('step') >= 0:
        milestones = args.decay_type.split('_')
        milestones.pop(0)
        milestones = list(map(lambda x: int(x), milestones))
        scheduler = lrs.MultiStepLR(
            my_optimizer,
            milestones=milestones,
            gamma=args.gamma
        )
    #print ("2-7-checkpoint")
    return scheduler
最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末烟瞧,一起剝皮案震驚了整個(gè)濱河市缆八,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌椭豫,老刑警劉巖豁生,帶你破解...
    沈念sama閱讀 211,561評(píng)論 6 492
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件碧囊,死亡現(xiàn)場(chǎng)離奇詭異帚桩,居然都是意外死亡特漩,警方通過查閱死者的電腦和手機(jī)吧雹,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,218評(píng)論 3 385
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來(lái)涂身,“玉大人雄卷,你說(shuō)我怎么就攤上這事「蚴郏” “怎么了丁鹉?”我有些...
    開封第一講書人閱讀 157,162評(píng)論 0 348
  • 文/不壞的土叔 我叫張陵妒潭,是天一觀的道長(zhǎng)。 經(jīng)常有香客問我揣钦,道長(zhǎng)雳灾,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 56,470評(píng)論 1 283
  • 正文 為了忘掉前任冯凹,我火速辦了婚禮佑女,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘谈竿。我一直安慰自己团驱,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,550評(píng)論 6 385
  • 文/花漫 我一把揭開白布空凸。 她就那樣靜靜地躺著嚎花,像睡著了一般。 火紅的嫁衣襯著肌膚如雪呀洲。 梳的紋絲不亂的頭發(fā)上紊选,一...
    開封第一講書人閱讀 49,806評(píng)論 1 290
  • 那天,我揣著相機(jī)與錄音道逗,去河邊找鬼兵罢。 笑死,一個(gè)胖子當(dāng)著我的面吹牛滓窍,可吹牛的內(nèi)容都是我干的卖词。 我是一名探鬼主播,決...
    沈念sama閱讀 38,951評(píng)論 3 407
  • 文/蒼蘭香墨 我猛地睜開眼吏夯,長(zhǎng)吁一口氣:“原來(lái)是場(chǎng)噩夢(mèng)啊……” “哼此蜈!你這毒婦竟也來(lái)了?” 一聲冷哼從身側(cè)響起噪生,我...
    開封第一講書人閱讀 37,712評(píng)論 0 266
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤裆赵,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后跺嗽,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體战授,經(jīng)...
    沈念sama閱讀 44,166評(píng)論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,510評(píng)論 2 327
  • 正文 我和宋清朗相戀三年桨嫁,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了植兰。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 38,643評(píng)論 1 340
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡瞧甩,死狀恐怖钉跷,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情肚逸,我是刑警寧澤爷辙,帶...
    沈念sama閱讀 34,306評(píng)論 4 330
  • 正文 年R本政府宣布彬坏,位于F島的核電站,受9級(jí)特大地震影響膝晾,放射性物質(zhì)發(fā)生泄漏栓始。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,930評(píng)論 3 313
  • 文/蒙蒙 一血当、第九天 我趴在偏房一處隱蔽的房頂上張望幻赚。 院中可真熱鬧,春花似錦臊旭、人聲如沸落恼。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,745評(píng)論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)佳谦。三九已至,卻和暖如春滋戳,著一層夾襖步出監(jiān)牢的瞬間钻蔑,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 31,983評(píng)論 1 266
  • 我被黑心中介騙來(lái)泰國(guó)打工奸鸯, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留咪笑,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 46,351評(píng)論 2 360
  • 正文 我出身青樓娄涩,卻偏偏與公主長(zhǎng)得像窗怒,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子钝满,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,509評(píng)論 2 348