變分自編碼,英文是Variational AutoEncoder肴甸,簡(jiǎn)稱VAE寂殉。它是包含隱變量的一種模型
變分自編碼器與對(duì)抗生成網(wǎng)絡(luò)類(lèi)似塑径,均是為了解決數(shù)據(jù)生成問(wèn)題而生的。在自編碼器結(jié)構(gòu)中荣赶,通常需要一個(gè)輸入數(shù)據(jù)捺球,而且所生成的數(shù)據(jù)與輸入數(shù)據(jù)是相同的。但是通常希望生成的數(shù)據(jù)具有一定程度的不同柳沙,這需要輸入隨機(jī)向量并且模型能夠?qū)W習(xí)生成圖像的風(fēng)格化特點(diǎn),因此在后續(xù)研究中以隨機(jī)化向量作為輸入生成特定樣本的對(duì)抗生成網(wǎng)絡(luò)結(jié)構(gòu)便產(chǎn)生了。變分自編碼器同樣的以特定分布的隨機(jī)樣本作為輸入澳泵,并且可以生成相應(yīng)的圖像,從此方面來(lái)看其與對(duì)抗生成網(wǎng)絡(luò)目標(biāo)是相似的兼呵。但是變分自編碼器不需要判別器兔辅,而是使用編碼器來(lái)估計(jì)特定分布』魑梗總體結(jié)構(gòu)來(lái)看與自編碼器結(jié)構(gòu)類(lèi)似维苔,但是中間傳遞向量為特定分布的隨機(jī)向量,這里需要特別區(qū)分:編碼器懂昂、解碼器介时、生成器和判別器
一. VAE原理
先假設(shè)一個(gè)隱變量Z的分布,構(gòu)建一個(gè)從Z到目標(biāo)數(shù)據(jù)X的模型凌彬,即構(gòu)建沸柔,使得學(xué)出來(lái)的目標(biāo)數(shù)據(jù)與真實(shí)數(shù)據(jù)的概率分布相近
VAE的結(jié)構(gòu)圖如下:
VAE對(duì)每一個(gè)樣本匹配一個(gè)高斯分布,隱變量就是從高斯分布中采樣得到的铲敛。對(duì)個(gè)樣本來(lái)說(shuō)褐澎,每個(gè)樣本的高斯分布假設(shè)為,問(wèn)題就在于如何擬合這些分布伐蒋。VAE構(gòu)建兩個(gè)神經(jīng)網(wǎng)絡(luò)來(lái)進(jìn)行擬合均值與方差工三。即,擬合的原因是這樣無(wú)需加激活函數(shù)
此外先鱼,VAE讓每個(gè)高斯分布盡可能地趨于標(biāo)準(zhǔn)高斯分布俭正。這擬合過(guò)程中的誤差損失則是采用KL散度作為計(jì)算,下面做詳細(xì)推導(dǎo):
VAE與同為生成模型的GMM(高斯混合模型)也有很相似型型,實(shí)際上VAE可看成是GMM的一個(gè)distributed representation
的版本段审。GMM是有限個(gè)高斯分布的隱變量的混合,而VAE可看成是無(wú)窮個(gè)隱變量的混合,VAE中的可以是高斯也可以是非高斯的
原始樣本數(shù)據(jù)的概率分布:
假設(shè)服從標(biāo)準(zhǔn)高斯分布寺枉,先驗(yàn)分布是高斯的抑淫,即。是兩個(gè)函數(shù)始苇, 分別是對(duì)應(yīng)的高斯分布的均值和方差,則就是在積分域上所有高斯分布的累加:
由于是已知的筐喳,未知催式,所以求解問(wèn)題實(shí)際上就是求這兩個(gè)函數(shù)。最開(kāi)始的目標(biāo)是求解避归,且希望越大越好荣月,這等價(jià)于求解關(guān)于最大對(duì)數(shù)似然:
而可變換為:
到這里我們發(fā)現(xiàn),第二項(xiàng)其實(shí)就是和的KL散度梳毙,即哺窄,因?yàn)镵L散度是大于等于0的,所以上式進(jìn)一步可寫(xiě)成:
這樣就找到了一個(gè)下界(lower bound)账锹,也就是式子的右項(xiàng)萌业,即:
原式也可表示成:
為了讓越大,目的就是要最大化它的這個(gè)下界
推到這里奸柬,可能會(huì)有個(gè)疑問(wèn):為什么要引入生年,這里的可以是任何分布?
實(shí)際上,因?yàn)楹篁?yàn)分布很難求(intractable)廓奕,所以才用來(lái)逼近這個(gè)后驗(yàn)分布抱婉。在優(yōu)化的過(guò)程中發(fā)現(xiàn),首先跟是完全沒(méi)有關(guān)系的懂从,只跟有關(guān)授段,調(diào)節(jié)是不會(huì)影響似然也就是的。所以番甩,當(dāng)固定住時(shí)侵贵,調(diào)節(jié)最大化下界,KL則越小缘薛。當(dāng)與不斷逼近后驗(yàn)分布時(shí)窍育,KL散度趨于為0,就和等價(jià)宴胧。所以最大化就等價(jià)于最大化
回顧:
顯然漱抓,最大化就是等價(jià)于最小化和最大化。
第一項(xiàng)恕齐,最小化KL散度:前面已假設(shè)了是服從標(biāo)準(zhǔn)高斯分布的乞娄,且是服從高斯分布,于是代入計(jì)算可得:
對(duì)上式中的積分進(jìn)一步求解,實(shí)際就是概率密度仪或,而概率密度函數(shù)的積分就是1确镊,所以積分第一項(xiàng)等于;而又因?yàn)楦咚狗植嫉亩A矩就是范删,正好對(duì)應(yīng)積分第二項(xiàng)蕾域。又根據(jù)方差的定義可知,所以積分第三項(xiàng)為-1
最終化簡(jiǎn)得到的結(jié)果如下:
第二項(xiàng)到旦,最大化期望旨巷。也就是表明在給定(編碼器輸出)的情況下(解碼器輸出)的值盡可能高
- 第一步,利用encoder的神經(jīng)網(wǎng)絡(luò)計(jì)算出均值與方差添忘,從中采樣得到采呐,這一過(guò)程就對(duì)應(yīng)式子中的
- 第二步,利用decoder的計(jì)算的均值方差昔汉,讓均值(或也考慮方差)越接近,則產(chǎn)生的幾率越大靶病,對(duì)應(yīng)于式子中的最大化這一部分
重參數(shù)技巧:
最后模型在實(shí)現(xiàn)的時(shí)候,有一個(gè)重參數(shù)技巧沪停,就是想從高斯分布中采樣時(shí),其實(shí)是相當(dāng)于從中采樣一個(gè),然后再來(lái)計(jì)算 躬它。這么做的原因是祖娘,采樣這個(gè)操作是不可導(dǎo)的失尖,而采樣的結(jié)果是可導(dǎo)的,這樣做個(gè)參數(shù)變換渐苏,這個(gè)就可以參與梯度下降掀潮,模型就可以訓(xùn)練了
class VAE(nn.Module):
"""Implementation of VAE(Variational Auto-Encoder)"""
def __init__(self):
super(VAE, self).__init__()
self.fc1 = nn.Linear(784, 200)
self.fc2_mu = nn.Linear(200, 10)
self.fc2_log_std = nn.Linear(200, 10)
self.fc3 = nn.Linear(10, 200)
self.fc4 = nn.Linear(200, 784)
def encode(self, x):
h1 = F.relu(self.fc1(x))
mu = self.fc2_mu(h1)
log_std = self.fc2_log_std(h1)
return mu, log_std
def decode(self, z):
h3 = F.relu(self.fc3(z))
recon = torch.sigmoid(self.fc4(h3)) # use sigmoid because the input image's pixel is between 0-1
return recon
def reparametrize(self, mu, log_std):
std = torch.exp(log_std)
eps = torch.randn_like(std) # simple from standard normal distribution
z = mu + eps * std
return z
def forward(self, x):
mu, log_std = self.encode(x)
z = self.reparametrize(mu, log_std)
recon = self.decode(z)
return recon, mu, log_std
def loss_function(self, recon, x, mu, log_std) -> torch.Tensor:
recon_loss = F.mse_loss(recon, x, reduction="sum") # use "mean" may have a bad effect on gradients
kl_loss = -0.5 * (1 + 2*log_std - mu.pow(2) - torch.exp(2*log_std))
kl_loss = torch.sum(kl_loss)
loss = recon_loss + kl_loss
return loss
二. CVAE原理
在條件變分自編碼器(CVAE)中,模型的輸出就不是了琼富,而是對(duì)應(yīng)于輸入的任務(wù)相關(guān)數(shù)據(jù)仪吧,不過(guò)套路和VAE是一樣的,這次的最大似然估計(jì)變成了鞠眉,即::
則ELBO(Empirical Lower Bound)
為薯鼠,進(jìn)一步:
網(wǎng)絡(luò)結(jié)構(gòu)包含三個(gè)部分:
- 先驗(yàn)網(wǎng)絡(luò),如下圖(b)所示
- Recognition網(wǎng)絡(luò)械蹋, 如下圖(c)所示D
- ecoder網(wǎng)絡(luò)出皇,如下圖(b)所示
class CVAE(nn.Module):
"""Implementation of CVAE(Conditional Variational Auto-Encoder)"""
def __init__(self, feature_size, class_size, latent_size):
super(CVAE, self).__init__()
self.fc1 = nn.Linear(feature_size + class_size, 200)
self.fc2_mu = nn.Linear(200, latent_size)
self.fc2_log_std = nn.Linear(200, latent_size)
self.fc3 = nn.Linear(latent_size + class_size, 200)
self.fc4 = nn.Linear(200, feature_size)
def encode(self, x, y):
h1 = F.relu(self.fc1(torch.cat([x, y], dim=1))) # concat features and labels
mu = self.fc2_mu(h1)
log_std = self.fc2_log_std(h1)
return mu, log_std
def decode(self, z, y):
h3 = F.relu(self.fc3(torch.cat([z, y], dim=1))) # concat latents and labels
recon = torch.sigmoid(self.fc4(h3)) # use sigmoid because the input image's pixel is between 0-1
return recon
def reparametrize(self, mu, log_std):
std = torch.exp(log_std)
eps = torch.randn_like(std) # simple from standard normal distribution
z = mu + eps * std
return z
def forward(self, x, y):
mu, log_std = self.encode(x, y)
z = self.reparametrize(mu, log_std)
recon = self.decode(z, y)
return recon, mu, log_std
def loss_function(self, recon, x, mu, log_std) -> torch.Tensor:
recon_loss = F.mse_loss(recon, x, reduction="sum") # use "mean" may have a bad effect on gradients
kl_loss = -0.5 * (1 + 2*log_std - mu.pow(2) - torch.exp(2*log_std))
kl_loss = torch.sum(kl_loss)
loss = recon_loss + kl_loss
return loss