好久沒看的VAE又不太記得了吵取,重新梳理一下思路,在S同學(xué)的指導(dǎo)下锯厢,又有了一些新的理解皮官。之前寫過一篇關(guān)于VAE的入門教程,但是感覺還不夠簡練实辑,刪掉重新寫一個(gè)哈哈哈臣疑。這篇文章主要是從一個(gè)熟悉machine learning但是對于VAE一點(diǎn)都不懂的視角,進(jìn)行寫作的徙菠,并不涉及很多復(fù)雜的理論知識讯沈,輔助理解為主。另外,上次那篇文章是先給結(jié)論缺狠,在慢慢講細(xì)節(jié)问慎。這篇文章將會循序漸進(jìn)逐漸推導(dǎo)出VAE的各種Trick存在的必要性。
VAE 入門
我們首先要明確一點(diǎn)挤茄,VAE是一個(gè)生成式的模型如叼,什么是生成式的模型?簡單來說穷劈,就是可以用來生成數(shù)據(jù)的模型笼恰。怎么樣才能生成數(shù)據(jù)呢?就是我們是知道數(shù)據(jù)的分布的歇终?有了這個(gè)分布之后社证,我們就可以從這個(gè)分布中采樣,獲得新的數(shù)據(jù)评凝。
這個(gè)思路好像很簡單啊追葡,但是問題是這個(gè)是怎么得到的。有很多方法啊奕短,其中包含這樣兩大類:1. 基于概率的宜肉,如MCMC, Variational Inference等。以及2. 基于機(jī)器學(xué)習(xí)的翎碑。前面提到過谬返,本文主要面向的是對于機(jī)器學(xué)習(xí)比較熟悉的人,所以這里就對概率方法不多說了日杈,主要講一下機(jī)器學(xué)習(xí)的思路朱浴。
其實(shí)機(jī)器學(xué)習(xí)來解決這種問題的思路是很清晰的,大多數(shù)的機(jī)器學(xué)習(xí)問題都有這樣一個(gè)思路达椰。我們想要優(yōu)化某個(gè)目標(biāo)O翰蠢,我們先對這個(gè)問題建個(gè)模型,模型可以表示為某個(gè)數(shù)學(xué)表達(dá)式 啰劲,其中是參數(shù)梁沧。我們用數(shù)據(jù)去訓(xùn)練這個(gè)模型,然后根據(jù)目標(biāo)O蝇裤,去調(diào)整我們的參數(shù)廷支,我們希望訓(xùn)練結(jié)束的時(shí)候,能夠找到一組最優(yōu)的栓辜。對應(yīng)到我們這個(gè)生成式的問題恋拍,我們希望能夠生成一個(gè)新的數(shù)據(jù),那么我們構(gòu)造一個(gè)模型 藕甩,我們的目標(biāo)呢就是這個(gè)生成的數(shù)據(jù)越真越好施敢,就是在眾多的中,我們希望能夠找到一個(gè)最好的,能夠讓這個(gè)數(shù)據(jù)存在的概率越大越好僵娃。有了目標(biāo)概作,我們就能計(jì)算出損失函數(shù),然后就是利用梯度下降默怨,逐步調(diào)整參數(shù)讯榕,最終找到最優(yōu)。
前文這個(gè)過程好像很熟悉匙睹,但是存在幾個(gè)問題:
- 我們建模其實(shí)是根據(jù)我們的assumption來的愚屁,我們的模型結(jié)構(gòu),初始參數(shù)設(shè)置都是根據(jù)我們的assumption來的痕檬, 但是我們的assumption有可能是錯(cuò)的霎槐,而且很有可能是錯(cuò)的。因此引入過多或者過強(qiáng)的assumption都會導(dǎo)致我們的模型效果很差谆棺。
- 因?yàn)槲覀兊腶ssumption和真實(shí)數(shù)據(jù)分布存在偏差栽燕,相應(yīng)的罕袋,我們在優(yōu)化的過程中改淑,很容易陷入到局部最優(yōu)中。
- 如果我們直接采用建模的方式來解決生成式問題浴讯,那么我們通常需要構(gòu)造一個(gè)相對復(fù)雜的模型朵夏,或者說參數(shù)很多的模型,來獲得較大的Capacity榆纽。這樣就導(dǎo)致我們優(yōu)化的過程非常的耗時(shí)(Computationally Expensive)仰猖。
上述三個(gè)問題的存在,讓我們對直接建模這個(gè)思路產(chǎn)生了動搖奈籽,至少直接建模并不適用于所有的場景饥侵。所以我們在直接建模的基礎(chǔ)上做出修改,引入了隱變量的概念衣屏。我們認(rèn)為數(shù)據(jù)的生成是受到隱變量的影響的躏升。比如手寫數(shù)字生成的任務(wù),我們在生成數(shù)字的時(shí)候狼忱,會首先考慮膨疏,我們要生成的是數(shù)字幾啊,因?yàn)槲覀冎挥?0種數(shù)字可以生成钻弄,這個(gè)數(shù)字幾就是我們的隱變量佃却。有了這個(gè)隱變量,我們就不再是漫天生成數(shù)字了窘俺,我們只有10個(gè)方向去生成饲帅,這大大的縮小了我們的生成空間,降低了計(jì)算量。
引入隱變量洒闸,用數(shù)學(xué)公式可以表示為:
模型的學(xué)習(xí)過程也因?yàn)殡[變量的引入發(fā)生了改變染坯。我們最終的目標(biāo)還是要計(jì)算,我們對隱變量的概率建個(gè)模丘逸,參數(shù)是单鹿。我們要找到能讓值最大的參數(shù)。但是現(xiàn)在是要把隱變量的所有可能取值都找到深纲,然后求一個(gè)上面這樣的積分仲锄,來確定的值,從而通過比較不同的對應(yīng)的的值來確定哪個(gè)才是最合適的湃鹊。是不是覺得天衣無縫儒喊?對于手寫數(shù)字來說,我們的隱變量的取值只有10個(gè)币呵,所以這個(gè)積分就退化成了只有10項(xiàng)的求和怀愧。但是對于很多其他問題,這個(gè)隱變量的取值就有可能變得非常多余赢,又不太好做了芯义。所以我們的做法是不去計(jì)算積分了,我們做了一步近似妻柒。我們就采樣一個(gè)隱變量扛拨,我們希望挑出來的參數(shù)能夠在這一個(gè)隱變量上表現(xiàn)好就行了。What?這樣的近似是不是差的太多举塔?這樣真的呆膠布嗎绑警?答案是肯定的。我們雖然用單個(gè)變量代替了積分央渣,或者說计盒,代替了期望值,但是我們機(jī)器學(xué)習(xí)的過程是在不斷的迭代的隨機(jī)過程(Stochastic process)芽丹。簡單來說北启,就是我們會找一個(gè)又一個(gè)的sample,重復(fù)的進(jìn)行優(yōu)化志衍,理論上講依然能夠得到最優(yōu)解(可以參考機(jī)器學(xué)習(xí)的學(xué)習(xí)理論)暖庄。總而言之楼肪,就是這樣用單樣本代替期望是可行的培廓。
這個(gè)過程是不是聽起來又是很合理?但是存在一個(gè)問題春叫,我們說想要去采樣隱變量肩钠,但是從什么分布里采樣泣港?從?可以嗎价匠?可以当纱。但是我們再思考一個(gè)問題,真的所有的隱變量都是平等的嗎踩窖?回到手寫數(shù)字的例子坡氯,如果我們現(xiàn)在要生成的是數(shù)字7,那么隱變量如果是0洋腮,8這種帶圓圈的概率是不是不大箫柳。假如我們采樣到了一個(gè)隱變量代表的是數(shù)字0,那么這一次采樣是不是相當(dāng)于浪費(fèi)了啥供?你本來就不怎么能指導(dǎo)我做這一次生成呀悯恍。所以,為了減少這樣的無效采樣伙狐,從而進(jìn)一步的降低計(jì)算量涮毫,我們并不是從中采樣的,而是從中采樣的贷屎。
好了罢防,現(xiàn)在又有了一個(gè)新問題,這個(gè)我們知道嘛豫尽?答案是不知道篙梢,不知道怎么辦顷帖?不知道那就去求美旧?用什么樣的方法去求?用機(jī)器學(xué)習(xí)的去求贬墩,和前面對建模一樣榴嗅,我們在這里對建模為,然后求個(gè)陶舞,再梯度下降去優(yōu)化它嗽测。常見的做法是把建模為一個(gè)正態(tài)分布:
講到這里VAE的主體框架已經(jīng)出來了:
我們用Q采樣出來一個(gè)隱變量,然后我們根據(jù)這個(gè)隱變量肿孵,利用生成新的圖片唠粥。我們從這個(gè)過程中,計(jì)算損失值停做,通過梯度下降的方式晤愧,不斷優(yōu)化和的參數(shù),從而我們能夠生成越來越好的圖片
這個(gè)框架到目前為止已經(jīng)可以說是相對很完整了蛉腌,里面呢有兩個(gè)函數(shù)需要去優(yōu)化:和官份。這兩個(gè)函數(shù)我們都用神經(jīng)網(wǎng)絡(luò)去建模只厘,但是我們依然需要做一件事,就是去定義一個(gè)損失函數(shù)舅巷。我們這樣思考一下羔味,我們定義損失函數(shù),是為了能夠讓和更好钠右。我們首先來考慮赋元,如何讓變得更好?我們回憶一下飒房,是我們定義出來用來估計(jì)分布的们陆,最好的當(dāng)然就是能夠跟一毛一樣啦。那么我們很自然的就想要把目標(biāo)函數(shù)情屹,或者說損失值定義成這兩個(gè)分布之間的差距了坪仇。而計(jì)算分布差距,最常用的metric之一就是KL 散度垃你。所以椅文,我們想要讓這個(gè)公式最小化:
有人可能要說啦:這,怎么最小化惜颇?我們不知道這個(gè)是啥我們才估計(jì)的呀皆刺。沒錯(cuò),不過我們可以試著把這個(gè)公式變個(gè)形式凌摄,試試看:
$$
\begin{aligned}
KL(Q(z|X);||;P(z|X)) & =E(logQ(z|X)-logP(z|X))\
&=E(logQ(z|X))-E(logP(z|X))\
&=E(logQ(z|X))-E(log(\frac{P(X|z)P(z)}{P(X)}))\
&=E(logQ(z|X))-E(logP(X|z))-E(logP(z))+E(logP(X))\
&=E(logQ(z|X))-E(logP(X|z))-E(logP(z))+logP(X)\
&=KL(logQ(z|X)||P(z))-E(logP(X|z))+logP(X)
\end{aligned}
logP(X)-KL(Q(z|X);||;P(z|X))=E(logP(X|z))-KL(logQ(z|X)||P(z))
$$
上面這個(gè)公式的轉(zhuǎn)換過程中,沒有用到什么很特別的技巧锨亏,主要就是貝葉斯公式套進(jìn)去了一下痴怨,我就不多說了。重點(diǎn)來看一下最終的公式形式器予,我們發(fā)現(xiàn)公式左邊浪藻,恰好是我們想要優(yōu)化的目標(biāo),當(dāng)我們讓左邊最大化的時(shí)候乾翔,不僅和越來越接近爱葵,我們的也越來越大,也就是我們的函數(shù)的參數(shù)也越來越好反浓。一石二鳥萌丈!本來我們只是在考慮的問題,現(xiàn)在連帶著把的問題也解決了雷则。我們看一下右邊是我們能計(jì)算的東西嗎辆雾?答案是肯定的,第一項(xiàng)是我們建模的函數(shù)巧婶,直接可以得到結(jié)果乾颁。第二項(xiàng)中的是我們建模的函數(shù)涂乌,呢是隱變量的分布,我們可以把這個(gè)分布定義為一個(gè)標(biāo)準(zhǔn)正態(tài)分布英岭,因?yàn)閺倪@個(gè)標(biāo)準(zhǔn)正態(tài)分布湾盒,理論上講我們可以映射到任意的輸出空間上(當(dāng)然標(biāo)準(zhǔn)正態(tài)分布也只是一個(gè)選項(xiàng),很多別的分布都是可以的)诅妹。所以公式右邊的兩項(xiàng)都是可以求的罚勾,我們就可以把這個(gè)作為我們的優(yōu)化目標(biāo)。至此吭狡,我們的計(jì)算過程算是完善了尖殃,可以用下面這樣一張圖表示
我們首先有一個(gè)輸入,對應(yīng)到我們的例子里就是一張手寫數(shù)字圖片划煮。我們將這個(gè)輸入到自己定義的函數(shù)中送丰,因?yàn)?img class="math-inline" src="https://math.jianshu.com/math?formula=Q" alt="Q" mathimg="1">是個(gè)正態(tài)分布函數(shù),所以我們的做法是用兩個(gè)神經(jīng)網(wǎng)絡(luò)(管他們叫encoder)分別去計(jì)算期望和方差 弛秋。有了這個(gè)分布之后器躏,我們就可以采樣出來一個(gè)隱變量的樣本,然后用這個(gè)樣本在通過神經(jīng)網(wǎng)絡(luò) (管他叫decoder)去生成新的數(shù)據(jù)樣本蟹略。在這個(gè)計(jì)算過程中登失,我們在encoder里計(jì)算了一個(gè)損失函數(shù),在decoder里計(jì)算了一個(gè)損失函數(shù)(在數(shù)據(jù)為正態(tài)分布的時(shí)候等價(jià)于)挖炬。
上面描述的過程更完整了揽浙,不過還有一個(gè)問題,就是中間采樣的那一步意敛。我們知道馅巷,神經(jīng)網(wǎng)絡(luò)的學(xué)習(xí)依賴于梯度下降,梯度下降就要求整個(gè)損失函數(shù)的梯度鏈條是存在的空闲,或者說參數(shù)是可導(dǎo)的令杈。我們看decoder的這個(gè)損失函數(shù)走敌,在計(jì)算碴倾,除了涉及分布以外,還依賴于隱變量掉丽,隱變量又是從分布中采樣出來的跌榔, 所以我們在對decoder的損失函數(shù)進(jìn)行梯度下降的時(shí)候,是要對的參數(shù)也梯度下降的捶障。但是因?yàn)橹虚g這一步采樣僧须,我們的梯度斷掉了。采樣還怎么知道是什么梯度项炼?所以這里用到了一個(gè)小trick: Reparamterization担平。簡單來說就是我們不再是采樣了示绊,而是看做按照分布的期望和方差,加上一些小噪音暂论,生成出來的樣本面褐。也就是說:
其中這個(gè)是一個(gè)隨機(jī)噪聲,我們可以從一個(gè)標(biāo)準(zhǔn)正態(tài)分布中采樣得到取胎。
這種做法非常 好理解對吧展哭,我們的每個(gè)樣本都可以看作是這個(gè)樣本服從分布的期望,根據(jù)方差進(jìn)行波動的結(jié)果闻蛀。通過這種變換匪傍,原來斷掉的梯度鏈恢復(fù)啦,我們的梯度下降終于能夠進(jìn)行下去了觉痛。修正后的計(jì)算過程如下圖役衡。
以上就是VAE的主要內(nèi)容,這里額外說明一點(diǎn):從decoder的損失函數(shù)看薪棒,我們希望Q在估計(jì)隱變量的分布映挂,而隱變量的分布就是一個(gè)標(biāo)準(zhǔn)正態(tài)分布,所以我們在實(shí)際生成的過程中盗尸,不需要用到encoder柑船,只需要從標(biāo)準(zhǔn)正態(tài)分布里隨便采樣一個(gè)隱變量就能進(jìn)行生成了。
與自編碼模型Autoencoder比較
很多人可能很熟悉自編碼模型泼各,自編碼模型英文叫Autoencoder鞍时。而我們這個(gè)VAE呢,叫Variational Autoencoder扣蜻,聽起來好像關(guān)系很大逆巍,但是其實(shí)關(guān)系真的不是很大。只不過我們這個(gè)VAE呢也像Autoencoder一樣有一個(gè)encoder莽使,一個(gè)decoder锐极。但是Autoencoder對于隱變量沒有什么限制,它的過程就很簡單芳肌,就是從輸入計(jì)算一個(gè)隱變量灵再,然后再把映射到一個(gè)新的,損失函數(shù)只有一個(gè)亿笤,就是比較和計(jì)算一個(gè)重構(gòu)損失翎迁。這樣做呢并沒有很好的利用隱變量。但是它最大的缺點(diǎn)是生成出來的東西是和輸入的高度相關(guān)的净薛,并不能生成出來什么很新奇的玩意汪榔,所以Autoencoder一般只能用來做做降噪什么的,并不能真正用來做生成肃拜。但是VAE就不一樣了痴腌,前面我們講過雌团,VAE在生成階段是完全拋開了encoder的,隱變量是從標(biāo)準(zhǔn)正態(tài)分布里隨隨便便采樣出來的士聪,這就擺脫了對輸入的依賴辱姨,想怎么生成就怎么生成。
代碼實(shí)現(xiàn)
"""
@Time : 26/01/2023 @Software: PyCharm @File : model.py
"""
import torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
import torch.nn.functional as F
class Encoder(nn.Module):
def __init__(self, hidden_dim=2):
super(Encoder, self).__init__()
self.linear1 = nn.Linear(28 * 28, 512)
self.linear2 = nn.Linear(512, hidden_dim)
def forward(self, x):
"""x:[N,1,28,28]"""
x = torch.flatten(x, start_dim=1) # [N,764]
x = self.linear1(x) # [N, 512]
x = F.relu(x)
return self.linear2(x) # [N,2]
class Decoder(nn.Module):
def __init__(self, hidden_dim):
super(Decoder, self).__init__()
self.linear1 = nn.Linear(hidden_dim, 512)
self.linear2 = nn.Linear(512, 28 * 28)
def forward(self, x):
"""x:[N,2]"""
hidden = self.linear1(x) # [N, 512]
hidden = torch.relu(hidden)
hidden = self.linear2(hidden) # [N,764]
hidden = torch.sigmoid(hidden)
return torch.reshape(hidden, (-1, 1, 28, 28))
class AutoEncoder(nn.Module):
def __init__(self, hidden_dim=2):
super(AutoEncoder, self).__init__()
self.name = "ae"
self.encoder = Encoder(hidden_dim)
self.decoder = Decoder(hidden_dim)
def forward(self, x):
return self.decoder(self.encoder(x))
class VAEEncoder(nn.Module):
def __init__(self, hidden_dim=2):
super(VAEEncoder, self).__init__()
self.linear1 = nn.Linear(28 * 28, 512)
self.linear2 = nn.Linear(512, hidden_dim)
self.linear3 = nn.Linear(512, hidden_dim)
self.noise_dist = torch.distributions.Normal(0, 1)
self.kl = 0
def forward(self, x):
x = torch.flatten(x, start_dim=1)
x = self.linear1(x)
x = torch.relu(x)
mu = self.linear2(x)
sigma = torch.exp(self.linear3(x))
hidden = mu + self.noise_dist.sample(mu.shape) * sigma
self.kl = (sigma ** 2 + mu ** 2 - torch.log(sigma) - 1 / 2).sum()
return hidden
class VAE(nn.Module):
def __init__(self, hidden_dim=2):
super(VAE, self).__init__()
self.name = "vae"
self.encoder = VAEEncoder(hidden_dim=hidden_dim)
self.decoder = Decoder(hidden_dim)
self.kl = 0
def forward(self, x):
hidden = self.encoder(x)
self.kl = self.encoder.kl
return self.decoder(hidden)
if __name__ == '__main__':
dataset = torchvision.datasets.MNIST("data", transform=torchvision.transforms.ToTensor(), download=True)
print(dataset[0][0].shape)
核心的VAE 模型代碼實(shí)現(xiàn)我貼在這里戚嗅,其余代碼已經(jīng)上傳到Github雨涛。