DCGAN學(xué)習(xí)和編程

DCGAN

網(wǎng)絡(luò)結(jié)構(gòu)

DCGAN其主要貢獻(xiàn)在于把原始GAN中的全連接層替換為了卷積層。具體如下:

  • 首先是全卷積網(wǎng)絡(luò),這使用了跨步卷積代替了確定性的空間池化功能(例如最大池化等操作),從而能讓網(wǎng)絡(luò)能夠?qū)W習(xí)自身的空間下采樣召廷。
  • 其次是再卷積層特征上消除全連接層的趨勢唁奢,全局池化就是一個最好的例子。
  • 第三是采用了BatchNormalization喉恋,這個通過將輸入歸一化從而穩(wěn)定了訓(xùn)練的過程,并有助于在梯度在更深的模型中進(jìn)行流動母廷,BN并不用于生成器輸出層和鑒別器輸入層轻黑。
  • 使用了ReLU激活函數(shù),并發(fā)現(xiàn)使用LeaklyReLU函數(shù)可以讓正常工作琴昆,特別是對于更高分辨率的建模氓鄙。

DCGAN的生成器結(jié)構(gòu)可以用如下的圖來表示:


DCGAN生成64*64圖像的生成器結(jié)構(gòu)

DCGAN的判別器和生成器的結(jié)構(gòu)基本相反,其主要是通過進(jìn)行卷積降維從而把輸入的圖像生成為一個標(biāo)量业舍,從而使用Sigmoid激活層確認(rèn)其概率抖拦。

一些的DCGAN結(jié)構(gòu)指南

  • 用跨步卷積(針對鑒別器)和分?jǐn)?shù)跨步卷積(針對生成器)替換掉所有的池化層升酣。
  • 在生成器和鑒別器中都使用BN,并且需要注意的是不對生成器的最后一層和鑒別器的輸入層使用BN态罪。
  • 刪除掉全連接的隱藏層從而實現(xiàn)更深層次的體系結(jié)構(gòu)噩茄。
  • 在生成器中全都使用ReLU激活函數(shù),并在最后一層使用Tanh激活函數(shù)
  • 在鑒別其中复颈,對所有層使用LeakyReLU激活函數(shù)绩聘。

訓(xùn)練的一些細(xì)節(jié):

  • 使用了batch_size=128
  • 所有權(quán)重都服從0中心方差為0.02的正態(tài)分布。
  • 在LeakyReLU的泄露斜率值都為0.2
  • 使用Adam的優(yōu)化器耗啦,lr=0.0002君纫,\beta 1=0.5(作者發(fā)現(xiàn)0.9會有不穩(wěn)定的情況發(fā)生)

代碼實現(xiàn):

# 使用pytorch在ununtu20上使用的代碼
# gpu:Nvidia RTX2070s 8g顯存

import os,math,torch,torchvision
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import torch.nn as nn
import random
from torch.utils.data import Dataset
random.seed(666)
torch.manual_seed(666)
from torch.autograd import Variable
import torch.nn.functional as F
os.makedirs('myImages', exist_ok=True)
#下面是一些初始化數(shù)據(jù)的定義
n_epochs=2
batch_size=512
lr=0.0002
b1=0.5
b2=0.999
n_cpu=8
latent_dim=100
img_size=64
channels=3
sample_interval=400
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# dataset = torchvision.datasets.MNIST(root='../../data/mnist',download=True,
#                             transform=transforms.Compose([transforms.Resize(size=img_size),
#                                                          transforms.ToTensor(),
#                                                          transforms.Normalize([0.5]*3,[0.5]*3)]
#                                                           )
#                             )
# dataloader = DataLoader(dataset=dataset,batch_size=batch_size,shuffle=False,num_workers=n_cpu)
import PIL.Image as Image
class CeleBaDataset(torch.utils.data.Dataset):
    def __init__(self,img_root:str,transform=None):
        super(CeleBaDataset,self).__init__()
        temp_list=list()
        for s in os.listdir(path=img_root):
            if s.find('.png'):
                temp_list.append(os.path.join(img_root,s))
        self.datalist = temp_list
        self.transform = transform
    def __len__(self):
        return len(self.datalist)
    def __getitem__(self,idx):
        image = Image.open(self.datalist[idx])
        if self.transform:
            image = self.transform(image)
        return image
dataloader = DataLoader(dataset=CeleBaDataset(img_root='/home/hx/Desktop/WorkDisk/DadaSets/CelebA/Img/img_align_celeba_png.7z/img_align_celeba_png/'
                                              ,transform=transforms.Compose([transforms.Resize(size=img_size),
                                                         transforms.Resize(64),
                                                         transforms.CenterCrop(64),
                                                         transforms.ToTensor(),
                                                         transforms.Normalize([0.5]*3,[0.5]*3),
                                                        ])
                                              ),
                        batch_size=batch_size,
                        num_workers=n_cpu,
                        shuffle=False,
                        pin_memory=True)

#%%

def weight_init(modules:torch.nn.Module):
    for m in modules.modules():
        if isinstance(m,nn.ConvTranspose2d):
            nn.init.normal_(m.weight.data,0,0.02)
        elif isinstance(m,nn.BatchNorm2d):
            nn.init.normal_(m.weight.data,0,0.02)
def weight_init_apply(m:object):
    if m.__class__.__name__.find('Conv'):
        nn.init.normal_(m.weight.data,0,0.02)
    elif m.__class__.__name__.find('BatchNorm'):
        nn.init.normal_(m.weight.data,0,0.02)

class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        in_channels=[latent_dim,512,256,128,64]
        out_channels=[512,256,128,64,3]
        paddings=[0,1,1,1,1]
        strides=[1,2,2,2,2]
        layers=[]
        for i in range(5):
            layers.append(nn.BatchNorm2d(num_features=in_channels[i]))
            layers.append(nn.ConvTranspose2d(in_channels=in_channels[i],
                                             out_channels=out_channels[i],
                                             kernel_size=4,
                                             stride=strides[i],
                                             padding=paddings[i]))
            if i != 4:
                layers.append(nn.LeakyReLU(negative_slope=0.2,inplace=True))
            else:
                layers.append(nn.Tanh())
        self.G=nn.Sequential(*layers)

    def forward(self,x):
        return self.G(x)


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        layers=[]
        def block(in_channels,out_channels,stride=2,padding=1,if_bn=True,if_relu=True):
            if if_bn:
                layers.append(nn.BatchNorm2d(in_channels))
            layers.append(nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=4,stride=stride,padding=padding))
            if if_relu:
                layers.append(nn.LeakyReLU(negative_slope=0.2,inplace=True))
            else:
                layers.append(nn.Sigmoid())
        block(3,64,stride=2,padding=1,if_bn=False)  # 此時64*32*32
        block(64,128,2,1)                           # 此時128*16*16
        block(128,256,2,1)                          # 此時256*8*8
        block(256,512,2,1)                          # 此時512*4*4
        block(512,1,1,0,if_relu=False)              # 此時1*1*1
        self.D=nn.Sequential(*layers)

    def forward(self,x):
        return self.D(x)

#%%

generator = Generator()
weight_init(generator)
discriminator=Discriminator()
weight_init(discriminator)
loss_fn = torch.nn.BCELoss()
generator.to(device)
discriminator.to(device)
loss_fn.to(device)

opm_G = torch.optim.Adam(generator.parameters(),lr=lr,betas=(b1,b2))
opm_D = torch.optim.Adam(discriminator.parameters(),lr=lr,betas=(b1,b2))

#%%
data = CeleBaDataset(img_root='/home/hx/Desktop/WorkDisk/DadaSets/CelebA/Img/img_align_celeba_png.7z/img_align_celeba_png/')
data.__getitem__(10000)

#%%

for epoch in range(20):
    for i,img in enumerate(dataloader):
        img = img.to(device)
        real = torch.ones((img.shape[0],1),device=device)
        fake = torch.zeros((img.shape[0],1),device=device)
        z = torch.randn(size=(img.shape[0],latent_dim,1,1),device=device)
        opm_D.zero_grad()
        real_loss = loss_fn(discriminator(img).view(img.shape[0],-1),real)
        fake_loss = loss_fn(discriminator(generator(z).detach())view(img.shape[0],-1),fake)
        d_loss = (real_loss+fake_loss)/2
        d_loss.backward()
        opm_D.step()
        print('Dloss:',d_loss)

        opm_G.zero_grad()
        z = torch.randn(size=(img.shape[0],latent_dim,1,1),device=device)
        g_loss = loss_fn(discriminator(generator(z)).view(img.shape[0],-1),fake)
        g_loss.backward()
        opm_G.step()
        print('Gloss:',g_loss)
    print('epoch:{}Dloss:{}Gloss:{}',epoch,d_loss,g_loss)

最后的圖像生成效果。芹彬。蓄髓。待跑完

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市舒帮,隨后出現(xiàn)的幾起案子会喝,更是在濱河造成了極大的恐慌,老刑警劉巖玩郊,帶你破解...
    沈念sama閱讀 218,204評論 6 506
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件肢执,死亡現(xiàn)場離奇詭異,居然都是意外死亡译红,警方通過查閱死者的電腦和手機预茄,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 93,091評論 3 395
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來侦厚,“玉大人耻陕,你說我怎么就攤上這事∨俾伲” “怎么了诗宣?”我有些...
    開封第一講書人閱讀 164,548評論 0 354
  • 文/不壞的土叔 我叫張陵,是天一觀的道長想诅。 經(jīng)常有香客問我召庞,道長,這世上最難降的妖魔是什么来破? 我笑而不...
    開封第一講書人閱讀 58,657評論 1 293
  • 正文 為了忘掉前任篮灼,我火速辦了婚禮,結(jié)果婚禮上徘禁,老公的妹妹穿的比我還像新娘诅诱。我一直安慰自己,他們只是感情好晌坤,可當(dāng)我...
    茶點故事閱讀 67,689評論 6 392
  • 文/花漫 我一把揭開白布逢艘。 她就那樣靜靜地躺著,像睡著了一般骤菠。 火紅的嫁衣襯著肌膚如雪它改。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 51,554評論 1 305
  • 那天商乎,我揣著相機與錄音央拖,去河邊找鬼。 笑死鹉戚,一個胖子當(dāng)著我的面吹牛鲜戒,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播抹凳,決...
    沈念sama閱讀 40,302評論 3 418
  • 文/蒼蘭香墨 我猛地睜開眼遏餐,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了赢底?” 一聲冷哼從身側(cè)響起失都,我...
    開封第一講書人閱讀 39,216評論 0 276
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎幸冻,沒想到半個月后粹庞,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 45,661評論 1 314
  • 正文 獨居荒郊野嶺守林人離奇死亡洽损,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點故事閱讀 37,851評論 3 336
  • 正文 我和宋清朗相戀三年庞溜,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片碑定。...
    茶點故事閱讀 39,977評論 1 348
  • 序言:一個原本活蹦亂跳的男人離奇死亡流码,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出延刘,到底是詐尸還是另有隱情旅掂,我是刑警寧澤,帶...
    沈念sama閱讀 35,697評論 5 347
  • 正文 年R本政府宣布访娶,位于F島的核電站商虐,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏崖疤。R本人自食惡果不足惜秘车,卻給世界環(huán)境...
    茶點故事閱讀 41,306評論 3 330
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望劫哼。 院中可真熱鬧叮趴,春花似錦、人聲如沸权烧。這莊子的主人今日做“春日...
    開封第一講書人閱讀 31,898評論 0 22
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至妻率,卻和暖如春乱顾,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背宫静。 一陣腳步聲響...
    開封第一講書人閱讀 33,019評論 1 270
  • 我被黑心中介騙來泰國打工走净, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人孤里。 一個月前我還...
    沈念sama閱讀 48,138評論 3 370
  • 正文 我出身青樓伏伯,卻偏偏與公主長得像,于是被迫代替她去往敵國和親捌袜。 傳聞我的和親對象是個殘疾皇子说搅,可洞房花燭夜當(dāng)晚...
    茶點故事閱讀 44,927評論 2 355

推薦閱讀更多精彩內(nèi)容

  • 16.批量歸一化和殘差網(wǎng)絡(luò) 批量歸一化(BatchNormalization) BN是由Google于2015年提...
    0d3382cf56eb閱讀 823評論 0 0
  • GAN 由Goodfellow等人于2014年引入的生成對抗網(wǎng)絡(luò)(GAN)是用于學(xué)習(xí)圖像潛在空間的VAE的替代方案...
    七八音閱讀 7,774評論 1 3
  • 想從Tensorflow循環(huán)生成對抗網(wǎng)絡(luò)開始。但是發(fā)現(xiàn)從最難的內(nèi)容入手還是虏等?太復(fù)雜了所以搜索了一下他的始祖也就是深...
    Feather輕飛閱讀 5,043評論 1 4
  • (轉(zhuǎn))生成對抗網(wǎng)絡(luò)(GANs)最新家譜:為你揭秘GANs的前世今生 生成對抗網(wǎng)絡(luò)(GAN)一...
    Eric_py閱讀 4,297評論 0 4
  • 推薦指數(shù): 6.0 書籍主旨關(guān)鍵詞:特權(quán)蜓堕、焦點、注意力博其、語言聯(lián)想套才、情景聯(lián)想 觀點: 1.統(tǒng)計學(xué)現(xiàn)在叫數(shù)據(jù)分析,社會...
    Jenaral閱讀 5,721評論 0 5