DCGAN
網(wǎng)絡(luò)結(jié)構(gòu)
DCGAN其主要貢獻(xiàn)在于把原始GAN中的全連接層替換為了卷積層。具體如下:
- 首先是全卷積網(wǎng)絡(luò),它使用跨步卷積代替了確定性的空間池化操作(例如最大池化),從而讓網(wǎng)絡(luò)能夠?qū)W習(xí)自身的空間下采樣。
- 其次是在卷積特徵之上消除全連接層,全局平均池化就是一個最好的例子。
- 第三是採用了BatchNormalization,它通過將輸入歸一化來穩(wěn)定訓(xùn)練過程,並有助於梯度在更深的模型中流動;BN並不用於生成器輸出層和鑒別器輸入層。
- 生成器中使用了ReLU激活函數(shù),並發(fā)現(xiàn)在鑒別器中使用LeakyReLU激活函數(shù)效果很好,特別是對於更高分辨率的建模。
DCGAN的生成器結(jié)構(gòu)可以用如下的圖來表示:
DCGAN生成64*64圖像的生成器結(jié)構(gòu)
DCGAN的判別器和生成器的結(jié)構(gòu)基本相反,其主要通過卷積逐步降維,把輸入的圖像映射為一個標(biāo)量,再使用Sigmoid激活層將其轉(zhuǎn)換為概率。
一些的DCGAN結(jié)構(gòu)指南
- 用跨步卷積(針對鑒別器)和分?jǐn)?shù)跨步卷積(針對生成器)替換掉所有的池化層。
- 在生成器和鑒別器中都使用BN,并且需要注意的是不對生成器的最后一層和鑒別器的輸入層使用BN态罪。
- 刪除掉全連接的隱藏層從而實現(xiàn)更深層次的體系結(jié)構(gòu)噩茄。
- 在生成器中全都使用ReLU激活函數(shù),并在最后一層使用Tanh激活函數(shù)
- 在鑒別器中,對所有層使用LeakyReLU激活函數(shù)。
訓(xùn)練的一些細(xì)節(jié):
- 使用了batch_size=128
- 所有權(quán)重都服從0中心方差為0.02的正態(tài)分布。
- LeakyReLU的負(fù)半軸斜率均設(shè)為0.2
- 使用Adam優(yōu)化器,lr=0.0002,β1=0.5
  (作者發(fā)現(xiàn)β1取默認(rèn)值0.9時訓(xùn)練會不穩(wěn)定)
代碼實現(xiàn):
# Code run with PyTorch on Ubuntu 20
# GPU: Nvidia RTX 2070s, 8 GB VRAM
import os,math,torch,torchvision
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import torch.nn as nn
import random
from torch.utils.data import Dataset
random.seed(666)
torch.manual_seed(666)
from torch.autograd import Variable
import torch.nn.functional as F
os.makedirs('myImages', exist_ok=True)  # output directory for generated sample images
# ---- Hyperparameters / configuration constants ----
n_epochs=2            # configured epoch count (NOTE(review): the training loop below hardcodes its own count)
batch_size=512
lr=0.0002             # Adam learning rate from the DCGAN paper
b1=0.5                # Adam beta1 (paper uses 0.5 instead of the default 0.9)
b2=0.999              # Adam beta2
n_cpu=8               # DataLoader worker processes
latent_dim=100        # size of the generator's latent vector z
img_size=64           # images are resized/cropped to img_size x img_size
channels=3            # RGB
sample_interval=400   # sampling cadence, in batches
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Alternative dataset: MNIST (kept for reference, currently unused)
# dataset = torchvision.datasets.MNIST(root='../../data/mnist',download=True,
# transform=transforms.Compose([transforms.Resize(size=img_size),
# transforms.ToTensor(),
# transforms.Normalize([0.5]*3,[0.5]*3)]
# )
# )
# dataloader = DataLoader(dataset=dataset,batch_size=batch_size,shuffle=False,num_workers=n_cpu)
import PIL.Image as Image
class CeleBaDataset(torch.utils.data.Dataset):
    """CelebA image-folder dataset: yields (optionally transformed) RGB images, no labels.

    Parameters
    ----------
    img_root : str
        Directory containing the .png images.
    transform : callable, optional
        Applied to each PIL image before it is returned.
    """
    def __init__(self, img_root: str, transform=None):
        super(CeleBaDataset, self).__init__()
        # BUG FIX: the original used `if s.find('.png'):` — str.find returns
        # -1 (truthy) on a miss and 0 (falsy) on a hit at index 0, so *every*
        # file was accepted except names starting with '.png'. Use endswith.
        self.datalist = [os.path.join(img_root, s)
                         for s in os.listdir(path=img_root)
                         if s.endswith('.png')]
        self.transform = transform

    def __len__(self):
        return len(self.datalist)

    def __getitem__(self, idx):
        # convert('RGB') guards against grayscale/RGBA files, which would
        # otherwise break the 3-channel Normalize in the transform pipeline.
        image = Image.open(self.datalist[idx]).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image
# DataLoader over CelebA. Fixes vs. original: the transform chained two Resize
# calls (size=img_size, then 64) of which one is redundant since img_size == 64;
# and shuffle was False, which feeds the GAN the identical batch order every
# epoch — training data should be shuffled.
dataloader = DataLoader(
    dataset=CeleBaDataset(
        img_root='/home/hx/Desktop/WorkDisk/DadaSets/CelebA/Img/img_align_celeba_png.7z/img_align_celeba_png/',
        transform=transforms.Compose([
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3),
        ])),
    batch_size=batch_size,
    num_workers=n_cpu,
    shuffle=True,
    pin_memory=True)
#%%
def weight_init(modules: torch.nn.Module):
    """Initialise weights per the DCGAN paper: conv weights ~ N(0, 0.02),
    BatchNorm scale ~ N(1, 0.02) with zero bias.

    Fixes vs. original: nn.Conv2d was never matched, so the discriminator's
    convolutions were left at PyTorch defaults; and the BatchNorm scale was
    zero-centred, which would nearly zero out activations at the start.
    """
    for m in modules.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            # BN scale is centred at 1 (identity transform), not 0.
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)
def weight_init_apply(m: object):
    """Per-module initialiser intended for use with nn.Module.apply().

    Fixes the str.find truthiness bug in the original: find() returns -1
    (truthy) on a miss and 0 (falsy) on a hit at index 0, so the first branch
    fired for *every* module (crashing on modules without a .weight) and
    skipped class names that start with 'Conv'. Also initialises BatchNorm
    scale around 1 with zero bias, per the DCGAN paper.
    """
    name = m.__class__.__name__
    if name.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif name.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)
class Generator(nn.Module):
    """DCGAN generator: maps a latent vector of shape (N, latent_dim, 1, 1)
    to an image of shape (N, out_channels, 64, 64) with values in [-1, 1].

    Fixes vs. original (which violated the DCGAN guidelines documented above):
    BatchNorm was placed *before* each transposed conv — including directly on
    the latent input — and LeakyReLU was used where the generator should use
    ReLU. This version uses the canonical ConvTranspose -> BN -> ReLU order,
    with no BN on the output layer and Tanh at the end.

    Parameters
    ----------
    latent_dim : int, optional
        Size of the latent vector z (default 100, matching the script config).
    out_channels : int, optional
        Number of image channels produced (default 3, RGB).
    """
    def __init__(self, latent_dim=100, out_channels=3):
        super(Generator, self).__init__()
        # Channel progression: latent -> 512 -> 256 -> 128 -> 64 -> out_channels.
        in_chs = [latent_dim, 512, 256, 128, 64]
        out_chs = [512, 256, 128, 64, out_channels]
        paddings = [0, 1, 1, 1, 1]
        strides = [1, 2, 2, 2, 2]
        layers = []
        for i in range(5):
            # bias=False because the following BatchNorm cancels any bias.
            layers.append(nn.ConvTranspose2d(in_channels=in_chs[i],
                                             out_channels=out_chs[i],
                                             kernel_size=4,
                                             stride=strides[i],
                                             padding=paddings[i],
                                             bias=False))
            if i != 4:
                layers.append(nn.BatchNorm2d(out_chs[i]))
                layers.append(nn.ReLU(inplace=True))
            else:
                # Tanh maps the output into [-1, 1], matching Normalize([0.5],[0.5]).
                layers.append(nn.Tanh())
        self.G = nn.Sequential(*layers)

    def forward(self, x):
        return self.G(x)
class Discriminator(nn.Module):
    """DCGAN discriminator: scores a (N, 3, 64, 64) image, returning the
    probability of it being real with shape (N, 1, 1, 1) via a final Sigmoid."""
    def __init__(self):
        super(Discriminator, self).__init__()
        # Per conv stage: (in_ch, out_ch, stride, padding, use_bn, use_leaky).
        # BN (when enabled) normalises the *input* of the conv; the first
        # stage skips BN, matching the "no BN on discriminator input" rule.
        specs = [
            (3,   64,  2, 1, False, True),   # -> 64 x 32 x 32
            (64,  128, 2, 1, True,  True),   # -> 128 x 16 x 16
            (128, 256, 2, 1, True,  True),   # -> 256 x 8 x 8
            (256, 512, 2, 1, True,  True),   # -> 512 x 4 x 4
            (512, 1,   1, 0, True,  False),  # -> 1 x 1 x 1, Sigmoid output
        ]
        modules = []
        for c_in, c_out, stride, pad, use_bn, use_leaky in specs:
            if use_bn:
                modules.append(nn.BatchNorm2d(c_in))
            modules.append(nn.Conv2d(in_channels=c_in, out_channels=c_out,
                                     kernel_size=4, stride=stride, padding=pad))
            modules.append(nn.LeakyReLU(negative_slope=0.2, inplace=True)
                           if use_leaky else nn.Sigmoid())
        self.D = nn.Sequential(*modules)

    def forward(self, x):
        return self.D(x)
#%%
# Build the two networks and apply DCGAN-style weight initialisation.
generator = Generator()
weight_init(generator)
discriminator=Discriminator()
weight_init(discriminator)
# Binary cross-entropy: the discriminator emits a probability via its Sigmoid.
loss_fn = torch.nn.BCELoss()
generator.to(device)
discriminator.to(device)
loss_fn.to(device)  # no-op for a parameter-free loss, but harmless
# Separate Adam optimisers for G and D; betas (b1, b2) follow the DCGAN paper.
opm_G = torch.optim.Adam(generator.parameters(),lr=lr,betas=(b1,b2))
opm_D = torch.optim.Adam(discriminator.parameters(),lr=lr,betas=(b1,b2))
#%%
# Quick smoke check: the dataset can be built and an arbitrary sample loads
# without error (result intentionally discarded). Uses idiomatic indexing
# instead of calling __getitem__ directly.
data = CeleBaDataset(img_root='/home/hx/Desktop/WorkDisk/DadaSets/CelebA/Img/img_align_celeba_png.7z/img_align_celeba_png/')
_ = data[10000]
#%%
# Adversarial training loop.
# Fixes vs. original: L164 had a syntax error (missing '.' before view);
# L171 trained the generator against the *fake* labels, i.e. it rewarded G
# for being detected — the generator loss must use the real labels; L175's
# print used '{}' placeholders without .format(). Also saves sample grids at
# sample_interval, which the config and the save_image import clearly intend.
for epoch in range(20):  # NOTE(review): hardcoded; config defines n_epochs=2 — confirm intent
    for i, img in enumerate(dataloader):
        img = img.to(device)
        batch = img.shape[0]
        # Target labels: 1 = real, 0 = fake.
        real = torch.ones((batch, 1), device=device)
        fake = torch.zeros((batch, 1), device=device)

        # ---- Train the discriminator ----
        opm_D.zero_grad()
        z = torch.randn(size=(batch, latent_dim, 1, 1), device=device)
        real_loss = loss_fn(discriminator(img).view(batch, -1), real)
        # detach() blocks gradients from flowing into the generator here.
        fake_loss = loss_fn(discriminator(generator(z).detach()).view(batch, -1), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        opm_D.step()

        # ---- Train the generator ----
        opm_G.zero_grad()
        z = torch.randn(size=(batch, latent_dim, 1, 1), device=device)
        gen_imgs = generator(z)
        # Non-saturating GAN loss: push D(G(z)) towards the *real* label.
        g_loss = loss_fn(discriminator(gen_imgs).view(batch, -1), real)
        g_loss.backward()
        opm_G.step()

        # Periodically dump a 5x5 grid of generated samples.
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(gen_imgs.data[:25], 'myImages/{}.png'.format(batches_done),
                       nrow=5, normalize=True)
        print('epoch:{} batch:{} Dloss:{:.4f} Gloss:{:.4f}'.format(
            epoch, i, d_loss.item(), g_loss.item()))