WGAN-GP 原理和代碼分析

生成對(duì)抗模型(GAN)簡(jiǎn)介可以參考:http://www.reibang.com/p/34d9d0755f51
這里介紹的WGAN汉嗽,將損失函數(shù)進(jìn)行了正則化

文章鏈接:《Improving protein function prediction with synthetic feature samples created by generative adversarial networks

這里作者提出了一種新的損失函數(shù)定義模式绅络,對(duì)于普通GAN的損失函數(shù)定義:


\widetilde{x} 由生成器 G 產(chǎn)生的 fake data,x 代表 real data炊甲,那么對(duì)于WGAN-GP 它的損失函數(shù)為:


其中 \widetilde{x} 由生成器 G 產(chǎn)生的 fake data锰什,x 代表 real data迈螟,\widehat{x} 在本研究中代表:
α 代表隨機(jī)的參數(shù)弄息,λ 這一項(xiàng)代表正則項(xiàng)作為梯度約束

代碼部分:https://github.com/psipred/FFPredGAN/blob/master/src/Generating_Synthetic_Positive_Samples_FFPred-GAN.py

這里只重點(diǎn)講講目標(biāo)函數(shù)約束的代碼部分:

ITERS = 100000 
CRITIC_ITERS = 5

# 訓(xùn)練模型
for iteration in range(ITERS):
    for p in netD.parameters():  
        p.requires_grad = True  

    data = inf_train_gen()
    real_data = torch.FloatTensor(data)
    real_data_v = autograd.Variable(real_data)
    
    noise = torch.randn(BATCH_SIZE, 258)
    noisev = autograd.Variable(noise, volatile=True)  
    fake = autograd.Variable(netG(noisev, real_data_v).data)

    fake_output=fake.data.cpu().numpy()
    
    # 訓(xùn)練判別器 netD
    for iter_d in range(CRITIC_ITERS):
        # 梯度清零
        netD.zero_grad()

        D_real, hidden_output_real_1, hidden_output_real_2, hidden_output_real_3 = netD(real_data_v)

        # 高維張量取平均值,變成一個(gè)標(biāo)量
        D_real = D_real.mean()

        noise = torch.randn(BATCH_SIZE, 258)
        noisev = autograd.Variable(noise, volatile=True)  
        fake = autograd.Variable(netG(noisev, real_data_v).data)
        
        inputv = fake
        D_fake, hidden_output_fake_1, hidden_output_fake_2, hidden_output_fake_3 = netD(inputv)
       
        # 高維張量取平均值绵疲,變成一個(gè)標(biāo)量
        D_fake = D_fake.mean()
        
        # 計(jì)算正則項(xiàng)
        gradient_penalty = calc_gradient_penalty(netD, real_data_v.data, fake.data)
        
        # WGAN-GP 損失函數(shù)
        D_cost = D_fake - D_real + gradient_penalty
        Wasserstein_D = D_real - D_fake

        # 反向傳播損失函數(shù)
        D_cost.backward()
        # 迭代更新
        optimizerD.step()

    # 訓(xùn)練生成器 netG
    for p in netD.parameters():
            p.requires_grad = False

        netG.zero_grad()
        real_data = torch.Tensor(data)
        real_data_v = autograd.Variable(real_data)

        noise = torch.randn(BATCH_SIZE, 258)
        noisev = autograd.Variable(noise)
        fake = netG(noisev, real_data_v)
        G, hidden_output_ignore_1, hidden_output_ignore_2, hidden_output_ignore_3 = netD(fake)

        G = G.mean()
        G_cost = -G
        # 反向傳播損失函數(shù)
        G_cost.backward()
        # 迭代更新
        optimizerG.step()

計(jì)算gradient_penalty的代碼為:

def calc_gradient_penalty(netD, real_data, fake_data):
    alpha = torch.rand(BATCH_SIZE, 1)
    alpha = alpha.expand(real_data.size())
    alpha = alpha.cuda() if use_cuda else alpha

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)

    interpolates = autograd.Variable(interpolates, requires_grad=True)

    disc_interpolates, hidden_output_1, hidden_output_2, hidden_output_3 = netD(interpolates) 
    
    # 求梯度
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).cuda() if use_cuda else torch.ones(
                                  disc_interpolates.size()),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    
    # 正則項(xiàng)哲鸳,二階范數(shù)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市盔憨,隨后出現(xiàn)的幾起案子徙菠,更是在濱河造成了極大的恐慌,老刑警劉巖郁岩,帶你破解...
    沈念sama閱讀 206,013評(píng)論 6 481
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件婿奔,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡问慎,警方通過查閱死者的電腦和手機(jī)萍摊,發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 88,205評(píng)論 2 382
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來如叼,“玉大人冰木,你說我怎么就攤上這事∞闭” “怎么了片酝?”我有些...
    開封第一講書人閱讀 152,370評(píng)論 0 342
  • 文/不壞的土叔 我叫張陵囚衔,是天一觀的道長(zhǎng)挖腰。 經(jīng)常有香客問我,道長(zhǎng)练湿,這世上最難降的妖魔是什么猴仑? 我笑而不...
    開封第一講書人閱讀 55,168評(píng)論 1 278
  • 正文 為了忘掉前任,我火速辦了婚禮肥哎,結(jié)果婚禮上辽俗,老公的妹妹穿的比我還像新娘。我一直安慰自己篡诽,他們只是感情好崖飘,可當(dāng)我...
    茶點(diǎn)故事閱讀 64,153評(píng)論 5 371
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著杈女,像睡著了一般朱浴。 火紅的嫁衣襯著肌膚如雪吊圾。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 48,954評(píng)論 1 283
  • 那天翰蠢,我揣著相機(jī)與錄音项乒,去河邊找鬼。 笑死梁沧,一個(gè)胖子當(dāng)著我的面吹牛檀何,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播廷支,決...
    沈念sama閱讀 38,271評(píng)論 3 399
  • 文/蒼蘭香墨 我猛地睜開眼频鉴,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了恋拍?” 一聲冷哼從身側(cè)響起砚殿,我...
    開封第一講書人閱讀 36,916評(píng)論 0 259
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎芝囤,沒想到半個(gè)月后似炎,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 43,382評(píng)論 1 300
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡悯姊,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 35,877評(píng)論 2 323
  • 正文 我和宋清朗相戀三年羡藐,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片悯许。...
    茶點(diǎn)故事閱讀 37,989評(píng)論 1 333
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡仆嗦,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出先壕,到底是詐尸還是另有隱情瘩扼,我是刑警寧澤,帶...
    沈念sama閱讀 33,624評(píng)論 4 322
  • 正文 年R本政府宣布垃僚,位于F島的核電站集绰,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏谆棺。R本人自食惡果不足惜栽燕,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,209評(píng)論 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望改淑。 院中可真熱鬧碍岔,春花似錦、人聲如沸朵夏。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,199評(píng)論 0 19
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)仰猖。三九已至捏肢,卻和暖如春掠河,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背猛计。 一陣腳步聲響...
    開封第一講書人閱讀 31,418評(píng)論 1 260
  • 我被黑心中介騙來泰國(guó)打工唠摹, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人奉瘤。 一個(gè)月前我還...
    沈念sama閱讀 45,401評(píng)論 2 352
  • 正文 我出身青樓勾拉,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親盗温。 傳聞我的和親對(duì)象是個(gè)殘疾皇子藕赞,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 42,700評(píng)論 2 345

推薦閱讀更多精彩內(nèi)容