For an introduction to generative adversarial networks (GANs), see: http://www.reibang.com/p/34d9d0755f51
The WGAN-GP introduced here adds a regularization term (a gradient penalty) to the loss function.
Here the authors propose a new scheme for defining the loss function. For an ordinary GAN, the loss is the familiar minimax objective, where $\tilde{x}$ denotes the fake data produced by the generator $G$ and $x$ denotes the real data. For WGAN-GP, the critic loss drops the log terms and adds a gradient penalty evaluated at $\hat{x}$, which in this study denotes a random interpolation between a real sample and a fake sample.
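For reference, the standard formulations (consistent with the code below, where $\lambda$ corresponds to LAMBDA and $\epsilon$ to alpha) are: the ordinary GAN objective

$$ \min_G \max_D \; \mathbb{E}_{x \sim P_r}[\log D(x)] + \mathbb{E}_{\tilde{x} \sim P_g}[\log(1 - D(\tilde{x}))], $$

the WGAN-GP critic loss

$$ L_D = \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}_{x \sim P_r}[D(x)] + \lambda\, \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big], $$

the corresponding generator loss

$$ L_G = -\mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})], $$

and the interpolated sample

$$ \hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}, \qquad \epsilon \sim U[0, 1]. $$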
Here we focus only on the part of the code that implements this objective-function constraint:
ITERS = 100000
CRITIC_ITERS = 5

# Training loop
for iteration in range(ITERS):
    # Unfreeze the critic's parameters so they can be updated
    for p in netD.parameters():
        p.requires_grad = True

    data = inf_train_gen()
    real_data = torch.FloatTensor(data)
    real_data_v = autograd.Variable(real_data)
    noise = torch.randn(BATCH_SIZE, 258)
    noisev = autograd.Variable(noise, volatile=True)
    # Re-wrapping .data detaches the fake samples, so no gradients flow back into netG here
    fake = autograd.Variable(netG(noisev, real_data_v).data)
    fake_output = fake.data.cpu().numpy()

    # Train the critic netD
    for iter_d in range(CRITIC_ITERS):
        # Zero the gradients
        netD.zero_grad()
        D_real, hidden_output_real_1, hidden_output_real_2, hidden_output_real_3 = netD(real_data_v)
        # Average over the batch to reduce the output to a scalar
        D_real = D_real.mean()
        noise = torch.randn(BATCH_SIZE, 258)
        noisev = autograd.Variable(noise, volatile=True)
        fake = autograd.Variable(netG(noisev, real_data_v).data)
        inputv = fake
        D_fake, hidden_output_fake_1, hidden_output_fake_2, hidden_output_fake_3 = netD(inputv)
        # Average over the batch to reduce the output to a scalar
        D_fake = D_fake.mean()
        # Gradient-penalty regularization term
        gradient_penalty = calc_gradient_penalty(netD, real_data_v.data, fake.data)
        # WGAN-GP critic loss
        D_cost = D_fake - D_real + gradient_penalty
        # Estimate of the Wasserstein distance (not backpropagated)
        Wasserstein_D = D_real - D_fake
        # Backpropagate the loss
        D_cost.backward()
        # Update the critic's parameters
        optimizerD.step()

    # Train the generator netG (freeze the critic's parameters)
    for p in netD.parameters():
        p.requires_grad = False
    netG.zero_grad()
    real_data = torch.Tensor(data)
    real_data_v = autograd.Variable(real_data)
    noise = torch.randn(BATCH_SIZE, 258)
    noisev = autograd.Variable(noise)
    fake = netG(noisev, real_data_v)
    G, hidden_output_ignore_1, hidden_output_ignore_2, hidden_output_ignore_3 = netD(fake)
    G = G.mean()
    # The generator tries to maximize the critic's score on fake data
    G_cost = -G
    # Backpropagate the loss
    G_cost.backward()
    # Update the generator's parameters
    optimizerG.step()
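Note that the critic is updated CRITIC_ITERS = 5 times for every generator update, as recommended in the WGAN/WGAN-GP papers. Only D_cost (the penalized critic loss) and G_cost are backpropagated; Wasserstein_D = D_real - D_fake is just an estimate of the Wasserstein distance that can be logged to monitor training.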
The code that computes gradient_penalty is:
def calc_gradient_penalty(netD, real_data, fake_data):
    # One random interpolation coefficient per example, broadcast across the feature dimension
    alpha = torch.rand(BATCH_SIZE, 1)
    alpha = alpha.expand(real_data.size())
    alpha = alpha.cuda() if use_cuda else alpha
    # Random points on the straight lines between real and fake samples
    interpolates = alpha * real_data + ((1 - alpha) * fake_data)
    interpolates = autograd.Variable(interpolates, requires_grad=True)
    disc_interpolates, hidden_output_1, hidden_output_2, hidden_output_3 = netD(interpolates)
    # Gradient of the critic output with respect to the interpolated inputs
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).cuda() if use_cuda
                              else torch.ones(disc_interpolates.size()),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    # Penalty term: squared deviation of the per-example gradient L2 norm from 1
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty
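The snippets above use the autograd.Variable / volatile API from PyTorch 0.3 and earlier. As a rough sketch only, here is how the same penalty might look on PyTorch >= 1.0, where Variable has been merged into Tensor; gradient_penalty_modern and lambda_gp are illustrative names, and the critic here is assumed to return only its score (unlike the netD above, which also returns three hidden-layer outputs):

import torch

def gradient_penalty_modern(netD, real_data, fake_data, lambda_gp=10.0):
    # Sketch for PyTorch >= 1.0; assumes netD(x) returns only the critic score
    batch_size = real_data.size(0)
    # One interpolation coefficient per example, broadcast over the feature dimension
    alpha = torch.rand(batch_size, 1, device=real_data.device)
    # Interpolate, then mark the result as requiring gradients
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).detach().requires_grad_(True)
    disc_interpolates = netD(interpolates)
    gradients = torch.autograd.grad(outputs=disc_interpolates,
                                    inputs=interpolates,
                                    grad_outputs=torch.ones_like(disc_interpolates),
                                    create_graph=True,
                                    retain_graph=True,
                                    only_inputs=True)[0]
    # Squared deviation of the per-example gradient L2 norm from 1
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp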