實現(xiàn)黑白圖片自動變成彩色圖片
如果你有一幅黑白圖片,你該如何上色讓他變成彩色的呢?通常做法可能是使用PS工具來進(jìn)行上色滔金。那么屎蜓,有沒有什么辦法進(jìn)行自動上色呢痘昌?自動將黑白圖片變成彩色圖片?答案是有的,使用深度學(xué)習(xí)中的Pix2Pix網(wǎng)絡(luò)就可以實現(xiàn)這一功能辆苔。
如圖所示笔诵,我們可以將黑白動漫圖片,通過網(wǎng)絡(luò)學(xué)習(xí)姑子,自動變成彩色乎婿。對這個Pix2Pix網(wǎng)絡(luò)是如何實現(xiàn)的,想要進(jìn)一步了解網(wǎng)絡(luò)和代碼的話街佑,可以點擊這個
課程鏈接
下面谢翎,對這個網(wǎng)絡(luò)進(jìn)行一點簡要介紹。
Pix2Pix網(wǎng)絡(luò)介紹
pix2pix算是cGAN的一種沐旨,但是和cGAN又略有不同森逮,而且,在pix2pix這篇論文中磁携,首次提出了PatchGAN的概念褒侧,初次接觸到的人可能會略有疑惑。這篇文章谊迄,我們就一起來探討一下闷供,pix2pix中的判別器是如何設(shè)計的。
cGAN
提到pix2pix就一定要提一下统诺,他的思想源泉歪脏,cGAN。最初我們所熟知的GAN的概念粮呢,當(dāng)屬造假鈔和驗假鈔的對抗過程(誕生了DCGAN)婿失,造假鈔造出來的假鈔越來越像真鈔,驗假鈔的越來越能夠識別假鈔啄寡。我們從這個具體故事里面抽象出來豪硅,其實就是說,生成器生成的圖片夠真挺物,就可以騙過判別器懒浮。至于這個生成器生成的圖片真的是我們想要的?就不一定了姻乓。
[圖片上傳失敗...(image-256cfe-1576999041338)]
另外還有一個問題嵌溢,比如上面這幅圖。如果我有一堆火車圖片蹋岩。有正面的也有側(cè)面的赖草,我們都知道這是火車。但是生成對抗網(wǎng)絡(luò)其實并不理解剪个。如果用最基本的GAN(比如DCGAN)來做的話秧骑,很有可能最后就會得到一個normal的圖片,就是正面和側(cè)面火車平均之后的一個圖片。就會導(dǎo)致訓(xùn)練之后的圖片結(jié)果很模糊乎折。
cGAN就是來解決這個問題的绒疗。c表示conditional,是控制骂澄。我想讓生成器生成小狗的圖片吓蘑,他就不能生成火車的圖片。此時我們的D和G不再是單獨的一個輸入坟冲,而是兩種輸入磨镶。
[圖片上傳失敗...(image-927613-1576999041338)]
在生成器部分,我們不僅輸入normal distribution健提,還輸入一個條件c(比如cat或者train)琳猫。我們在判別器部分,也輸入兩個私痹,一個是條件c脐嫂,另外一個是x(生成的數(shù)據(jù)或者真實的數(shù)據(jù))。這里判別器的目的不僅僅要看生成的x數(shù)據(jù)是否和真實數(shù)據(jù)分布接近紊遵。還要看和條件c是否一致账千。對于判別器而言,生成的圖片不好癞蚕,還有生成的圖片和c不匹配蕊爵,都要給它低分。
pix2pix的判別器
在pix2pix中我們的判別器構(gòu)造和cGAN思想基本一致桦山,但稍有不同。
[圖片上傳失敗...(image-c797a8-1576999041338)]
這里醋旦,我們的判別器輸入兩張圖像恒水,一張是G的input圖像,一張是G的output圖像饲齐。也就是說钉凌,對于判別器而言,不只是輸出高質(zhì)量的圖像就可以騙過判別器捂人,必須要兩張圖像有對應(yīng)關(guān)系才可以御雕。
pix2pix的判別器訓(xùn)練代碼
下面,我們從代碼詳細(xì)的看一下滥搭,pix2pix是如何對判別器進(jìn)行計算的酸纲。
real_a, real_b = batch[0].to(device), batch[1].to(device)
fake_b = net_g(real_a)
optimizer_d.zero_grad()
# 判別器對虛假數(shù)據(jù)進(jìn)行訓(xùn)練
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab.detach())
loss_d_fake = criterionGAN(pred_fake, False)
# 判別器對真實數(shù)據(jù)進(jìn)行訓(xùn)練
real_ab = torch.cat((real_a, real_b), 1)
pred_real = net_d.forward(real_ab)
loss_d_real = criterionGAN(pred_real, True)
# 判別器損失
loss_d = (loss_d_fake + loss_d_real) * 0.5
loss_d.backward()
optimizer_d.step()
從代碼中我們可以看到,對判別器而言瑟匆,輸入數(shù)據(jù)需要通過cat來連接之后一起輸入闽坡。real_a和fake_b的結(jié)合數(shù)據(jù)為假。real_a和real_b結(jié)合的數(shù)據(jù)為真。關(guān)于代碼中為什么D有detach而G沒有detach可以看我寫的[2]疾嗅。
我們來比較一下DCGAN是怎么做的外厂,下面是DCGAN的代碼:
# 訓(xùn)練判別器
optimizer_d.zero_grad()
## 盡可能把真圖片判別為正
output = netd(real_img)
error_d_real = criterion(output, true_labels)
error_d_real.backward()
## 盡可能把假圖片判斷為錯誤
noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
# 使用detach來關(guān)閉G求梯度,加速訓(xùn)練
fake_img = netg(noises).detach()
output = netd(fake_img)
error_d_fake = criterion(output, fake_labels)
error_d_fake.backward()
optimizer_d.step()
error_d = error_d_fake + error_d_real
errord_meter.add(error_d.item())
DCGAN和cGAN不太一樣的地方就是輸入數(shù)據(jù)不需要concatenate代承,也就是沒有條件c的意思汁蝶。pix2pix中判別器有兩個輸入是要求,兩個圖片必須匹配才算是正確的论悴。
如果對optimzer掖棉,loss等流程不太清楚,可以看參考[3]
PatchGAN
pix2pix判別器另外一個設(shè)計點意荤,就在PatchGAN了啊片。我們先來看一下PatchGAN的網(wǎng)絡(luò)結(jié)構(gòu)。
[圖片上傳失敗...(image-3bbd42-1576999041338)]
[圖片上傳失敗...(image-8b4bf8-1576999041338)]
下面是對應(yīng)代碼部分:
class NLayerDiscriminator(nn.Module):
"""
定義PatchGAN判別器
"""
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
"""
構(gòu)建PatchGAN判別器
參數(shù):
input_nc --輸入圖片通道數(shù)
ndf --最后一個卷積層過濾器的數(shù)量
n_layers --判別器卷積層的數(shù)量
norm_layer --標(biāo)準(zhǔn)化層
use_sigmoid --是否使用sigmoid函數(shù)
"""
super(NLayerDiscriminator, self).__init__()
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
kw = 4 # kernel size
padw = 1 # padding
sequence = [
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True)
]
nf_mult = 1
nf_mult_prev = 1
# 逐漸增加過濾器的數(shù)量
for n in range(1, n_layers):
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2**n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
if use_sigmoid:
sequence += [nn.Sigmoid()]
self.model = nn.Sequential(*sequence)
def forward(self, input):
return self.model(input)
從網(wǎng)絡(luò)結(jié)構(gòu)中可以看到玖像,并且結(jié)合之前torch.cat我們可以看到紫谷,輸入的shape是6*256*256,然后輸出的shape是1*30*30捐寥。
論文中稱PatchGAN是一種馬爾科夫判別器笤昨。關(guān)于PatchGAN的理解可以看[6],之前我們說了PatchGAN輸出的是一個1*30*30的矩陣握恳。這和我們普通的GAN里面輸出一個預(yù)測值完全不同瞒窒。一個矩陣怎么做預(yù)測呢?我們的做法是把預(yù)測值也擴(kuò)展成一個1*30*30的矩陣乡洼。之后對二者使用最小二乘損失崇裁。這相當(dāng)于對1*30*30的矩陣的每一個點都對應(yīng)一個label。
通過對圖像進(jìn)行卷積操作束昵,后面的輸出矩陣拔稳,對前面部分有了更大的感受野(如果不明白感受野,可以看一下這里)锹雏。那么巴比,最后輸出的30*30的每一個點,相當(dāng)于最初輸入圖像的一個Patch礁遵,所以命名為PatchGAN轻绞。根據(jù)論文中描述的,這個Patch大小為70佣耐。
這個70是如何計算出來的呢政勃?
感受野計算公式我參考的是[7],下面的表格是PatchGAN網(wǎng)絡(luò)感受野的計算晰赞,可以看到30*30的矩陣稼病,每一個pixel對應(yīng)的感受野的確是70*70选侨。
[圖片上傳失敗...(image-c67138-1576999041338)]
Layer | Input Size | Kernel Size | Stride | Output Size | Receptive Field |
---|---|---|---|---|---|
Conv1 | 256*256 | 4*4 | 2 | 128*128 | 4 |
Conv2 | 128*128 | 4*4 | 2 | 64*64 | 10 |
Conv3 | 64*64 | 4*4 | 2 | 32*32 | 22 |
Conv4 | 32*32 | 4*4 | 1 | 31*31 | 46 |
Conv5 | 31*31 | 4*4 | 1 | 30*30 | 70 |
另外,可以點擊這個網(wǎng)站:Fomoro AI然走,可以自動幫你分析計算感受野援制。
這樣,使用PatchGAN處理之后芍瑞,pix2pix就將圖像切割成30*30份晨仑,每一份對應(yīng)一個70*70的patch,我們想要每個patch的結(jié)果都為真拆檬。通過聚焦于一個patch的局部位置洪己,可以更好地提高整體識別和判斷效果。
參考
[1]李宏毅生成對抗網(wǎng)絡(luò)2018
[2]訓(xùn)練生成對抗網(wǎng)絡(luò)的過程中竟贯,訓(xùn)練gan的地方為什么這里沒有detach,怎么保證訓(xùn)練生成器的時候不會改變判別器
[3Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解
[4]pix2pix主要代碼學(xué)習(xí)
[5][GAN筆記] pix2pix
[6]關(guān)于PatchGAN的理解
[7]關(guān)于感受野的理解與計算