簡(jiǎn)介:
用Python框架實(shí)現(xiàn)產(chǎn)生式模型逼争,最基本要實(shí)現(xiàn)的,就是概率函數(shù)劝赔。這類(lèi)函數(shù)的實(shí)現(xiàn)誓焦,包含兩個(gè)要素
- 確定性的Python代碼
- 隨機(jī)數(shù)產(chǎn)生器
具體來(lái)說(shuō),隨機(jī)函數(shù)可以是任何具備__call__()
方法的Python對(duì)象着帽,或者Pytorch框架里的nn.Module
方法杂伟。
在本教程里,所有的隨機(jī)函數(shù)被叫做模型仍翰,表達(dá)模型的方法赫粥,和正常的Python方法沒(méi)有區(qū)別。
安裝:
pip install pyro-ppl
或者:
git clone https://github.com/uber/pyro.git
cd pyro
python setup.py install
注意:Pyro只支持python3.*,不支持Python 2.*!
最基礎(chǔ)的隨機(jī)函數(shù)模型
Pyro利用了Pytorch的distribution libraray予借。舉個(gè)例子越平,我們想采樣x
服從標(biāo)準(zhǔn)正態(tài)分布,我們這樣做:
import torch
import pyro
pyro.set_rng_seed(101)
loc = 0. # 均值為0
scale = 1. #標(biāo)準(zhǔn)差為1
normal = torch.distributions.Normal(loc, scale) #構(gòu)造一個(gè)正態(tài)分布的對(duì)象
x = normal.rsample() #從N(0,1)采樣
print("sample", x)
print("log prob", normal.log_prob(x)) # 采樣分?jǐn)?shù)
##結(jié)果:
##sample tensor(-1.3905)
##log prob tensor(-1.8857)
normal = pyro.distributions.Normal(loc, scale) #構(gòu)造一個(gè)正態(tài)分布的對(duì)象
x = normal.rsample() #從N(0,1)采樣
print("sample", x)
print("log prob", normal.log_prob(x)) # 采樣分?jǐn)?shù)
##結(jié)果:
##sample tensor(1.3834)
##log prob tensor(-1.8759)
這里torch.distributions.Normal
是Distribution
類(lèi)的子類(lèi)灵迫,它已經(jīng)實(shí)現(xiàn)了采樣和打分功能秦叛。Pyro的庫(kù)pyro.distributions
打包了torch.distributions
的方法,這樣做科研利用Pytorch的數(shù)學(xué)方法和自動(dòng)求導(dǎo)功能龟再。
一個(gè)簡(jiǎn)單的例子
假設(shè)我們手里有一批數(shù)據(jù)书闸,記錄了日常氣溫和陰晴尼变。我們希望研究氣溫和陰晴的關(guān)系利凑。于是我們首先用Pytorch構(gòu)造如下函數(shù),來(lái)描述數(shù)據(jù)的生成過(guò)程:
def weather():
cloudy_ = torch.distributions.Bernoulli(0.3).sample()
cloudy = 'cloudy' if cloudy_.item() == 1. else 'sunny'
mean_temp = {'cloudy': 55., 'sunny': 75.}[cloudy] # 注意嫌术,溫度是華氏單位
scale_temp = {'cloudy': 10., 'sunny': 15.}[cloudy]
temp = torch.distributions.Normal(mean_temp, scale_temp).rsample()
return cloudy, temp.item()
我們逐行解釋這個(gè)函數(shù)哀澈。在第二行,我們定義一個(gè)二值變量cloudy_
度气,其伯努利參數(shù)0.3
表示其值為1的概率割按。第三行cloudy
為字符串變量,取值為‘cloudy’(陰天)或者‘sunny’(晴天)磷籍。模型規(guī)定适荣,30%的可能性是陰天,70%的可能性是晴天院领。第四行規(guī)定弛矛,陰天的平均溫度為華氏55度(約12.78°C),晴天為華氏75度(約23.89°C)比然。第五行規(guī)定了標(biāo)準(zhǔn)差丈氓。
下面我們使用Pyro重寫(xiě)上面的函數(shù)。
def weather():
cloudy_ = pyro.sample('cloudy', pyro.distributions.Bernoulli(0.3))
cloudy = 'cloudy' if cloudy_.item() == 1.0 else 'sunny'
mean_temp = {'cloudy': 55.0, 'sunny': 75.0}[cloudy]
scale_temp = {'cloudy': 10.0, 'sunny': 15.0}[cloudy]
temp = pyro.sample('temp', pyro.distributions.Normal(mean_temp, scale_temp))
return cloudy, temp.item()
for _ in range(3):
print(weather())
###結(jié)果:
#('cloudy', 64.5440444946289)
#('sunny', 94.37557983398438)
#('sunny', 72.5186767578125)
從表面來(lái)看,我們僅僅利用了pyro.sample
万俗,然而并非如此湾笛。假如我們想問(wèn),得到采樣為70度闰歪,有多大概率是陰天嚎研?利用Pyro回答這個(gè)問(wèn)題的方法,為了不破壞教程的連續(xù)性库倘,我們將在下個(gè)教程中講解嘉赎。
框架的通用性:遞歸隨機(jī)函數(shù)、高階隨機(jī)函數(shù)于樟、隨機(jī)數(shù)控制流
我們定義下面的簡(jiǎn)單模型:
def ice_cream_scales():
cloudy, temp = weather()
expected_sales = 200. if cloudy == 'sunny' and temp > 80. else 50.
ice_cream = pyro.sample('ice_cream', pyro.distributions.Normal(expected_sales, 10.))
return ice_cream
到此為止都是令人滿(mǎn)意的公条。我們要問(wèn),Pyro能否涵蓋更加復(fù)雜的模型迂曲?答案是肯定的靶橱。
由于Pyro是基于Python代碼的,它可以定義任意復(fù)雜的控制流路捧,哪怕其中包含了隨機(jī)數(shù)关霸,也沒(méi)有問(wèn)題。例如杰扫,我們可以定義遞歸的隨機(jī)函數(shù)队寇,每次我們輸入給pyro.sample
確定性的采樣。再如章姓,我們定義幾何分布佳遣,即計(jì)數(shù)實(shí)驗(yàn)失敗的次數(shù),直到實(shí)驗(yàn)成功為止:
def geometric(p, t=None):
if t is None:
t = 0
x = pyro.sample('x_{}'.format(t), pyro.distributions.Bernoulli(p))
if x.item() == 1.:
return 0
else:
return 1 + geometric(p, t + 1)
print(geometric(0.5))
# 結(jié)果凡伊,注意該結(jié)果每次采樣不一定相同
# 3
注意零渐,在上面的geometric()
函數(shù)里,系統(tǒng)將動(dòng)態(tài)生成諸如x_0
系忙、x_1
這樣的變量诵盼。
我們還可以把別的隨機(jī)函數(shù)的結(jié)果作為新定義函數(shù)的變量,或者創(chuàng)造一個(gè)新函數(shù)银还。請(qǐng)看下面的例子:
def normal_product(loc, scale):
z1 = pyro.sample('z1', pyro.distributions.Normal(loc, scale))
z2 = pyro.sample('z2', pyro.distributions.Normal(loc, scale))
y = z1 * z2
return y
def make_normal_normal():
mu_latent = pyro.sample('mu_latent'', pyro.distributions.Normal(0., 1.))
fn = lambda scale: normal_product(mu_latent, scale)
return fn
在make_normal_normal()
中风宁,其中一個(gè)輸入是隨機(jī)函數(shù),該隨機(jī)函數(shù)包含3個(gè)隨機(jī)變量蛹疯。
Pyro支持Python代碼的各種形式:循環(huán)戒财,遞歸,高階函數(shù)苍苞,等等固翰。這意味著Pyro具有通用性狼纬,更由于Pyro基于Pytorch而可以靈活地使用GPU加速。我們將在后面的教程中骂际,逐漸介紹Pyro的強(qiáng)大用法疗琉。