太長不看版
- 模型在學(xué)習(xí)或調(diào)試過程中,設(shè)置
pyro.enable_validation(True)
泵三; - 張量的“廣播”耕捞,維度對齊自右向左:
torch.ones(3,4,5) + torch.ones(5)
; - 分布的尺寸
.sample().shape == batch_shape + event_shape
烫幕; - 分布的尺寸
.log_prob(x).shape == batch_shape
(沒有event_shape
)俺抽; - 使用
expand()
從Pyro中采樣一批數(shù)據(jù),或使用plate
機(jī)制自動(dòng)擴(kuò)展较曼; - 使用
my_dist.to_event(1)
聲明維度為依賴(dependent)磷斧,或說不獨(dú)立; - 使用
with pyro.plate('name', size):
聲明條件獨(dú)立; - 所有維度要么是依賴的弛饭,要么是條件獨(dú)立的冕末;
- 支持維度最左方的批處理,啟動(dòng)Pyro的并行處理侣颂;
- 使用負(fù)號(hào)指標(biāo)档桃,如
x.sum(-1)
,而不是x.sum(2)
憔晒; - 使用省略號(hào)藻肄,如
pixel = image[...,i, j]
; - 如果要枚舉
i,j
拒担,使用Vindex嘹屯,如pixel = Vindex(image)[...,i, j]
;
- 使用負(fù)號(hào)指標(biāo)档桃,如
- 在調(diào)試過程中从撼,使用Trace.format_shapes檢查維度定義州弟。
內(nèi)容列表
- 概率分布的形狀
-
plate
聲明條件獨(dú)立 - 在plate中部分采樣
- 并行地枚舉,張量的廣播
文件頭如下
import os
import torch
import pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam
smoke_test = ('CI' in os.environ)
pyro.enable_validation(True) #這句話最好加上
# 我們借助這個(gè)函數(shù)低零,檢查模型是否正確
def test_model(model, guide, loss):
pyro.clear_param_store()
loss.loss(model, guide)
概率分布的尺寸:batch_shape
和event_shape
Pytorch的張量Tensor
只有一個(gè)尺寸.shape
婆翔,但是Distributions
有兩個(gè)尺寸.batch_shape
和.event_shape
,分別表示條件獨(dú)立的隨機(jī)變量的大小和不獨(dú)立的隨機(jī)變量的大小毁兆。這兩部分構(gòu)成了一個(gè)樣本的尺寸浙滤。
x = d.sample()
assert x.shape == d.batch_shape + d.event_shape
由于計(jì)算對數(shù)似然只牽涉不獨(dú)立的變量阴挣,所以.log_prob()
方法后气堕,event_shape
就被縮并了,只剩下batch_shape
畔咧。
assert d.log_prob(x) == d.batch_shape
Distributions.sample()
方法可以輸入一個(gè)參數(shù)sample_shape
茎芭,作為獨(dú)立同分布(iid)的隨機(jī)變量,所以指定樣本大小的采樣誓沸,具有三個(gè)尺寸梅桩。
x2 = d.sample(sample_shape)
assert x2.shape == sample_shape + batch_shape + event_shape
總結(jié)來說
| iid | independent | dependent
------+--------------+-------------+------------
shape = sample_shape + batch_shape + event_shape
由上可推論,單變量隨機(jī)分布的event_shape
為0拜隧,因?yàn)槊看尾蓸又凳且粋€(gè)實(shí)數(shù)宿百,所以沒有不獨(dú)立的維度。像MultivariateNormal
多元高斯分布這樣的概率分布洪添,具有len(event_shape) == 1
垦页,因?yàn)槊總€(gè)采樣是一個(gè)向量,向量內(nèi)部是彼此依賴的(這里假定方差矩陣不是對角陣)干奢。而InverseWishart
逆威沙特分布具有len(event_shape) == 2
痊焊,等等。
關(guān)于概率分布尺寸的舉例
從單變量隨機(jī)分布開始。
d = Bernoulli(0.5)
assert d.batch_shape == ()
assert d.event_shape == ()
x = d.sample()
# x是一個(gè)Pytorch張量薄啥,沒有batch_shape和event_shape
assert x.shape == ()
assert d.log_prob(x).shape == ()
通過傳入批參數(shù)辕羽,概率分布數(shù)據(jù)可以分成批推掸。
d = Bernoulli(0.5 * torch.ones(3, 4))
assert d.batch_shape == (3,4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)
另一種成批的方法熟妓,是通過expand()
。不過只在參數(shù)的最左側(cè)維度獨(dú)立時(shí)才可使用豪墅。
d = Bernoulli(torch.tensor([.1, .2, .3, .4])).expand([3, 4])
# 注意expand的參數(shù)寫在一個(gè)列表中
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)
多元高斯分布具有非空的event_shape
維度到逊。對于這些分布來說酌毡,.sample()
和.log_prob()
的維度是不同的。
d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == ()
assert d.event_shape == (3, )
x = d.sample()
assert x.shape == (3, ) # == batch_shape + event_shape
assert d.log_prob(x).shape == () # == batch_shape
改變分布的維度獨(dú)立性
使用關(guān)鍵字.to_event(n)改變不獨(dú)立維度的情況蕾管,其中n
表示從右數(shù)第n維度開始枷踏,聲明為不獨(dú)立維度。
d = Bernoulli(0.5 * torch.ones(3, 4)).to_event(1)
assert d.batch_shape == (3, )
assert d.event_shape == (4, )
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, )
用戶必須小心地設(shè)置.to_event(n)
將batch_shape
縮減到合適的水平上掰曾,或者用pyro.plate
聲明維度的獨(dú)立性旭蠕。采樣仍舊會(huì)保留batch_shape+event_shape
的尺寸,然而log_prob(x)
只剩下batch_shape
旷坦。
聲明為不獨(dú)立掏熬,通常是安全的做法
在Pyro中,我們常常會(huì)聲明維度是不獨(dú)立的秒梅,哪怕它們實(shí)際上是獨(dú)立的旗芬。請看這個(gè)例子:
x = pyro.sample('x', dist.Normal(0, 1).expand([10]).to_event(1))
assert x.shape == (10,)
上面的例子很容易就可以換成MultivariateNormal
分布。它將下面的寫法簡化了:
with pyro.plate('x_plate', 10):
x = pyro.sample('x', dist.Normal(0, 1)) #不需要expand捆蜀,系統(tǒng)自動(dòng)補(bǔ)全
assert x.shape == (10,)
實(shí)際上疮丛,這兩份代碼存在一點(diǎn)小小的差別。上面的代碼中辆它,Pyro默認(rèn)x之間是不獨(dú)立的誊薄,而下面的x則是條件獨(dú)立的。聲明為不獨(dú)立通常是安全的锰茉,這與圖論中的d-separation基于同一個(gè)原理:在不同節(jié)點(diǎn)之間多連一條邊呢蔫,即便節(jié)點(diǎn)之間不存在互相依賴關(guān)系,隨著優(yōu)化該邊的權(quán)重將越來越低飒筑,并不影響最終結(jié)果片吊;而本就存在依賴的節(jié)點(diǎn)少連了一條邊,任優(yōu)化策略多么高明协屡,都無法彌補(bǔ)這一錯(cuò)誤俏脊。這種錯(cuò)誤常見于平均場假設(shè)的模型中。不過著瓶,在實(shí)際執(zhí)行時(shí)联予,Pyro的SVI模塊在估算Normal
分布時(shí)啼县,兩份代碼的梯度估計(jì)值是一樣的。
通過plate
聲明維度為獨(dú)立
Pyro的上下文管理器pyro.plate能夠聲明特定的維度為獨(dú)立維度沸久。推斷算法可以利用這一獨(dú)立性做一些算法優(yōu)化季眷,例如構(gòu)造低方差的梯度估計(jì)器,再如求解推斷問題不在指數(shù)空間而在線性空間采樣卷胯。下面的例子中子刮,我們將聲明同一批次中的數(shù)據(jù)之間是互相獨(dú)立的。
最簡單的方法窑睁,是不聲明獨(dú)立維度挺峡,系統(tǒng)將缺省值-1——即最右邊的維度,作為獨(dú)立維度担钮。
with pyro.plate('my_plate'):
# 在該上下文中橱赠,維度-1將作為獨(dú)立維度
雖然效果是一樣的,不過我們?nèi)蕴岢脩魧懗鰜眢锝颍詭椭脩粽{(diào)試代碼:
with pyro.plate('my_plate', len(data)):
# 在該上下文中狭姨,維度-1將作為獨(dú)立維度
從Pyro 0.2版本開始,plate語句可以嵌套使用苏遥。比如聲明圖像的每個(gè)像素都是獨(dú)立的:
with pyro.plate('x_axis', 320):
# 在該上下文中饼拍,維度-1將作為獨(dú)立維度
with pyro.plate('y_axis', 200):
# 在該上下文中,維度-2和-1將作為獨(dú)立維度
我們習(xí)慣上總從右向左聲明獨(dú)立維度田炭,所以指標(biāo)是負(fù)的师抄,如-1,-2教硫,等等叨吮。
有時(shí)情況會(huì)更復(fù)雜一些,比如我們希望聲明一些噪聲依賴x
栋豫,另一些噪聲依賴y
挤安,還有一些噪聲依賴二者谚殊。這時(shí)Pyro允許用戶聲明多重獨(dú)立丧鸯,為了清楚地標(biāo)明獨(dú)立維度,必須指定dim
這一參數(shù)嫩絮,如下面的例子:
x_axis = pyro.plate('x_axis', dim = -2)
y_axis = pyro.plate('y_axis', dim = -3)
with x_axis:
# 在該上下文中丛肢,維度-2將作為獨(dú)立維度
with y_axis:
# 在該上下文中,維度-3將作為獨(dú)立維度
with x_axis, y_axis:
# 在該上下文中剿干,維度-2和-3將作為獨(dú)立維度
讓我們舉更多例子蜂怎,來展示plate
的用法。
def model1():
a = pyro.sample('a', Normal(0, 1))
b = pyro.sample('b', Normal(torch.zeros(2), 1).to_event(1))
with pyro.plate('c_plate', 2):
c = pyro.sample('c', Normal(torch.zeros(2), 1))
with pyro.plate('d_plate', 3):
d = pyro.sample('d', Normal(torch.zeros(3, 4, 5), 1).to_event(2))
assert a.shape == () # batch_shape == (), event_shape == ()
assert b.shape == (2,) # batch_shape == (), event_shape == (2,)
assert c.shape == (2,) # batch_shape == (2,), event_shape == ()
assert d.shape == (3, 4, 5) # batch_shape == (3), event_shape == (4, 5)
##
x_axis = pyro.plate('x_axis', 3, dim=-2)
y_axis = pyro.plate('y_axis', 2, dim=-3)
with x_axis:
x = pyro.sample('x', Normal(0, 1))
with y_axis:
y = pyro.sample('y', Normal(0, 1))
with x_axis, y_axis:
xy = pyro.sample('xy', Normal(0, 1))
z = pyro.sample('z', Normal(0, 1).expand([5]).to_event(1))
assert x.shape == (3, 1) # batch_shape == (3, 1), event_shape==()
assert y.shape == (2, 1, 1) # batch_shape == (2, 1, 1), event_shape==()
assert xy.shape == (2, 3, 1) # batch_shape == (2, 3, 1), event_shape==()
assert z.shape == (2, 3, 1, 5) # batch_shape == (2, 3, 1), event_shape==(5,)
test_model(model1, model1, Trace_ELBO())
可視化如下:
batch dims | event dims
-----------+-----------
| a = sample("a", Normal(0, 1))
|2 b = sample("b", Normal(zeros(2), 1)
| .to_event(1))
| with plate("c", 2):
2| c = sample("c", Normal(zeros(2), 1))
| with plate("d", 3):
3|4 5 d = sample("d", Normal(zeros(3,4,5), 1)
| .to_event(2))
|
| x_axis = plate("x", 3, dim=-2)
| y_axis = plate("y", 2, dim=-3)
| with x_axis:
3 1| x = sample("x", Normal(0, 1))
| with y_axis:
2 1 1| y = sample("y", Normal(0, 1))
| with x_axis, y_axis:
2 3 1| xy = sample("xy", Normal(0, 1))
2 3 1|5 z = sample("z", Normal(0, 1).expand([5])
| .to_event(1))
為了在調(diào)試代碼時(shí)方便地查看隨機(jī)變量的形狀置尔,Pyro提供了Trace.format_shapes()
方法杠步,在采樣點(diǎn)上打印分布的形狀(包含site['fn'].batch_shape
和site['fn'].event_shape
)、變量的形狀(site['value'].shape
)、如果計(jì)算對數(shù)似然概率時(shí)log_prob
的形狀(site['log_prob'].shape
)幽歼。
trace = poutine.trace(model1).get_trace()
trace.compute_log_prob() # 可選的朵锣,這句話可以打印log_prob的形狀
print(trace.format_shapes())
打印結(jié)果:
Trace Shapes:
Param Sites:
Sample Sites:
a dist |
value |
log_prob |
b dist | 2
value | 2
log_prob |
c_plate dist |
value 2 |
log_prob |
c dist 2 |
value 2 |
log_prob 2 |
d_plate dist |
value 3 |
log_prob |
d dist 3 | 4 5
value 3 | 4 5
log_prob 3 |
x_axis dist |
value 3 |
log_prob |
y_axis dist |
value 2 |
log_prob |
x dist 3 1 |
value 3 1 |
log_prob 3 1 |
y dist 2 1 1 |
value 2 1 1 |
log_prob 2 1 1 |
xy dist 2 3 1 |
value 2 3 1 |
log_prob 2 3 1 |
z dist 2 3 1 | 5
value 2 3 1 | 5
log_prob 2 3 1 |
在plate
句塊中采樣部分張量
plate最重要的功能之一就是部分采樣,plate
句塊中的隨機(jī)變量都是條件獨(dú)立的甸私。如果樣本量為總樣本的一半诚些,那么樣本損失的值將被認(rèn)為是總損失的一半。
在實(shí)現(xiàn)部分時(shí)皇型,用戶需要通知Pyro采樣量和樣本總量的值诬烹,Pyro就會(huì)隨機(jī)產(chǎn)生一定量的數(shù)據(jù)指標(biāo)作為樣本。
data = torch.arange(100.)
def model2():
mean = pyro.param('mean', torch.zeros(len(data)))
with pyro.plate('data', len(data), subsample_size=10) as ind:
assert len(ind) == 10
batch = data[ind]
mean_batch = mean[ind]
# 在batch中做一些計(jì)算
x = pyro.sample('x', Normal(mean_batch, 1), obs=batch)
assert x.shape == (10,)
test_model(model2, guide=lambda: None, loss=Trace_ELBO())
廣播功能弃鸦,實(shí)現(xiàn)數(shù)據(jù)的并行枚舉
Pyro 0.2后的版本都支持離散隨機(jī)變量的并行枚舉功能绞吁。這一功能可以極大地減少計(jì)算變分推斷時(shí)梯度估計(jì)的方差,確保優(yōu)化的穩(wěn)定性唬格。
為了實(shí)現(xiàn)枚舉掀泳,Pyro需要用戶指定哪些維度是不獨(dú)立的,哪些是獨(dú)立的西轩,只有不獨(dú)立的維度才允許枚舉员舵。自然地,這一指定需要用到plate
語句藕畔,我們需要聲明最大數(shù)量的枚舉范圍马僻,這一關(guān)鍵字為max_plate_nesting
,它是SVI
類的一個(gè)參數(shù)(而且通過TraceEnum_ELBO傳入)注服。通常來說韭邓,Pyro可以自動(dòng)地指定枚舉范圍(只要運(yùn)行一次model
和guide
,系統(tǒng)將了解枚舉范圍)溶弟,不過在動(dòng)態(tài)變化的模型中女淑,用戶需要人工地指定max_plate_nesting
的數(shù)值。
為了弄清楚max_plate_nesting
的作用機(jī)制辜御,我們重新回顧model1()
鸭你,這一次我們關(guān)心三種維度的形狀:最左邊的枚舉維度,中間的批維度擒权,最右邊的不獨(dú)立維度袱巨。而max_plate_nesting
規(guī)定了中間的批維度。
max_plate_nesting = 3
|<--->|
enumeration|batch|event
-----------+-----+-----
|. . .| a = sample("a", Normal(0, 1))
|. . .|2 b = sample("b", Normal(zeros(2), 1)
| | .to_event(1))
| | with plate("c", 2):
|. . 2| c = sample("c", Normal(zeros(2), 1))
| | with plate("d", 3):
|. . 3|4 5 d = sample("d", Normal(zeros(3,4,5), 1)
| | .to_event(2))
| |
| | x_axis = plate("x", 3, dim=-2)
| | y_axis = plate("y", 2, dim=-3)
| | with x_axis:
|. 3 1| x = sample("x", Normal(0, 1))
| | with y_axis:
|2 1 1| y = sample("y", Normal(0, 1))
| | with x_axis, y_axis:
|2 3 1| xy = sample("xy", Normal(0, 1))
|2 3 1|5 z = sample("z", Normal(0, 1).expand([5]))
| | .to_event(1))
上面的例子中碳抄,如果我們聲明(過度)充裕的max_plate_nesting=4
也是可以的愉老,但不能聲明例如max_plate_nesting=2
,因?yàn)?<3剖效,這時(shí)系統(tǒng)將會(huì)報(bào)錯(cuò)嫉入。
我們再舉一個(gè)例子:
@config_enumerate
#該修飾符表示枚舉類型焰盗,不能省略!咒林!
def model3():
p = pyro.param('p', torch.arange(6) / 6.)
locs = pyro.param('locs', torch.tensor([-1., 1.]))
# locs in [-1, 1]
# a in [0, 1, 2, 3, 4, 5]
a = pyro.sample('a', Categorical(torch.ones(6) / 6.))
# p[a] in [0, 1/6, 2/6, 3/6, 4/6, 5/6]
b = pyro.sample('b', Bernoulli(p[a])) # 聲明b依賴于a
# b in [0, 1]
with pyro.plate('c_plate', 4):
c = pyro.sample('c', Bernoulli(0.4))
# c in [0, 1]
with pyro.plate('d_plate', 5):
d = pyro.sample('d', Bernoulli(0.3))
# d in [0, 1]
e_loc = locs[d.long()].unsqueeze(-1)
# e_loc in [-1, 1]
e_scale = torch.arange(1., 8.)
# e_scale in [1, 2, ..., 7]
e = pyro.sample('e', Normal(e_loc, e_scale).to_event(1)) # 依賴于d
# 枚舉維度|批維度(獨(dú)立維度)|不獨(dú)立維度
assert a.shape == ( 6, 1,1 ) # 多類別分布的維度大小為6
assert b.shape == ( 2,1, 1,1 ) # 枚舉伯努利分布姨谷,非擴(kuò)增
assert c.shape == ( 2,1,1, 1,1 ) # 伯努利分布,非擴(kuò)增
assert d.shape == ( 2,1,1,1, 1,1 ) # 伯努利分布映九,非擴(kuò)增
assert e.shape == ( 2,1,1,1, 5,4, 7) # e是采樣出來的梦湘,依賴于d
#
assert e_loc.shape == ( 2,1,1,1, 1,1, 1,) # 最后的逗號(hào)可以省略
assert e_scale.shape == ( 7,) # 注意逗號(hào)不能省略!件甥!
test_model(model3, model3, TraceEnum_ELBO(max_plate_nesting=2))
我們重新來可視化一下:
max_plate_nesting = 2
|<->|
enumeration batch event
------------|---|-----
6|1 1| a = pyro.sample("a", Categorical(torch.ones(6) / 6))
2 1|1 1| b = pyro.sample("b", Bernoulli(p[a]))
| | with pyro.plate("c_plate", 4):
2 1 1|1 1| c = pyro.sample("c", Bernoulli(0.3))
| | with pyro.plate("d_plate", 5):
2 1 1 1|1 1| d = pyro.sample("d", Bernoulli(0.4))
2 1 1 1|1 1|1 e_loc = locs[d.long()].unsqueeze(-1)
| |7 e_scale = torch.arange(1., 8.)
2 1 1 1|5 4|7 e = pyro.sample("e", Normal(e_loc, e_scale)
| | .to_event(1))
我們分析一下這些維度捌议。我們?yōu)镻yro指定了枚舉的維度max_plate_nesting
:Pyro給a
賦予枚舉維度-3,給b
賦予枚舉維度-4引有,給c
賦予枚舉維度-5瓣颅,給d
賦予枚舉維度-6。當(dāng)用戶不指定維度擴(kuò)展后的數(shù)值時(shí)譬正,新維度被默認(rèn)為1宫补,這方便計(jì)算。我們還可以觀察到曾我,log_prob
的形狀廣播的范圍是枚舉維度和獨(dú)立維度粉怕,比如trace.nodes['d']['log_prob'].shape == (2,1,1,1,5,4)
使用Pyro的自帶工具Trace.format_shapes():
trace = poutine.trace(poutine.enum(model3, first_available_dim=-3)).get_trace()
trace.compute_log_prob() # 可選
print(trace.format_shapes())
結(jié)果:
Trace Shapes:
Param Sites:
p 6
locs 2
Sample Sites:
a dist |
value 6 1 1 |
log_prob 6 1 1 |
b dist 6 1 1 |
value 2 1 1 1 |
log_prob 2 6 1 1 |
c_plate dist |
value 4 |
log_prob |
c dist 4 |
value 2 1 1 1 1 |
log_prob 2 1 1 1 4 |
d_plate dist |
value 5 |
log_prob |
d dist 5 4 |
value 2 1 1 1 1 1 |
log_prob 2 1 1 1 5 4 |
e dist 2 1 1 1 5 4 | 7
value 2 1 1 1 5 4 | 7
log_prob 2 1 1 1 5 4 |
編寫并行代碼
在Pyro中,我們需要掌握兩個(gè)取巧的技術(shù)抒巢,來實(shí)現(xiàn)并行采樣:廣播 贫贝、 橢圓分片。我們通過下面的例子來分別介紹枚舉情形和非枚舉情形下的用法蛉谜。
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
enumeration = None # 設(shè)為True或False
def fun(observe):
p_x = pyro.param('p_x', torch.tensor(0.1), constraint=constraints.unit_interval)
p_y = pyro.param('p_y', torch.tensor(0.1), constraint=constraints.unit_interval)
x_axis = pyro.plate('x_axis', width, dim=-2)
y_axis = pyro.plate('y_axis', height, dim=-1)
# 在這些樣本點(diǎn)上稚晚,分布形狀取決于Pyro是否枚舉
with x_axis:
x_active = pyro.sample('x_active', Bernoulli(p_x))
with y_axis:
y_active = pyro.sample('y_active', Bernoulli(p_y))
if enumerated:
assert x_active.shape == (2, 1, 1) # max_plate_nesting==2
assert y_active.shape == (2, 1, 1, 1)
else:
assert x_active.shape == (width, 1)
assert y_active.shape == (height, )
# 第一個(gè)trick:廣播,broadcast型诚。枚舉和非枚舉都可使用客燕。
p = 0.1 + 0.5 * x_active * y_active
if enumerated:
assert p.shape == (2, 2, 1, 1)
else:
assert p.shape == (width, height)
dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
# 第二個(gè)trick:橢圓分片。Pyro可以在左方任意增加維度狰贯。
for x, y in sparse_pixels:
dense_pixels[..., x, y] = 1
if enumerated:
assert dense_pixels.shape == (2, 2, width, height)
else:
assert dense_pixels.shape == (width, height)
#
with x_axis, y_axis:
if observe:
pyro.sample('pixels', Bernoulli(p), obs=dense_pixels)
def model4():
fun(observe=True)
def guide4():
fun(observe=False)
# Test: 非枚舉
enumerated = False
test_model(model4, guide4, Trace_ELBO())
# Test: 枚舉也搓。注意目標(biāo)函數(shù)為TraceEnum_ELBO
enumerated = True
test_model(model4, config_enumerate(guide4, 'parallel'), TraceEnum_ELBO(max_plate_nesting=2))
在pyro.plate內(nèi)部實(shí)現(xiàn)自動(dòng)廣播
在以上所有model/plate的實(shí)現(xiàn)中,我們都使用了pyro.plate的自動(dòng)擴(kuò)增功能暮现,使變量滿足pyro.sample
規(guī)定的形狀还绘。這一廣播方式等價(jià)于.expand()
。
我們稍許更改上面的代碼作為例子栖袋,注意幾點(diǎn)區(qū)別:
- 我們僅考慮并行枚舉的情況,但對于串行的抚太、非枚舉的情況也適用塘幅;
- 我們將采樣函數(shù)分離出來昔案,model代碼使用常規(guī)的形式,這樣做有利于代碼的維護(hù)电媳;
-
pyro.plate
使用ELBO的num_particles參數(shù)踏揣,將上下文中最遠(yuǎn)的內(nèi)容打包。
# 規(guī)定采樣的樣本量
num_particals = 100
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
def sample_pixel_locations_no_broadcasting(p_x, p_y, x_axis, y_axis):
with x_axis:
x_active = pyro.sample('x_active', Bernoulli(p_x).expand([num_particals, width, 1]))
with y_axis:
y_active = pyro.sample('y_active', Bernoulli(p_y).expand([num_particals, 1, height]))
return x_active, y_active
def sample_pixel_locations_full_broadcasting(p_x, p_y, x_axis, y_axis):
with x_axis:
x_active = pyro.sample('x_active', Bernoulli(p_x))
with y_axis:
y_active = pyro.sample('y_acitve', Bernoulli(p_y))
return x_active, y_active
def sample_pixel_locations_partial_broadcasting(p_x, p_y, x_axis, y_axis):
with x_axis:
x_active = pyro.sample('x_active', Bernoulli(p_x).expand([width, 1]))
with y_axis:
y_active = pyro.sample('y_active', Bernoulli(p_y).expand([height]))
return x_acitve, y_active
def fun(observe, sample_fn):
p_x = pyro.param('p_x', torch.tensor(0.1), constraint=constraints.unit_interval)
p_y = pyro.param('p_y', torch.tensor(0.1), constraint=constraints.unit_interval)
x_axis = pyro.plate('x_axis', width, dim=-2)
y_axis = pyro.plate('y_axis', height, dim=-1)
#
with pyro.plate('num_particals', 100, dim=-3):
x_active, y_active = sample_fn(p_x, p_y, x_axis, y_axis)
## 并行枚舉指標(biāo)被擴(kuò)增在“num_particals”最左邊
assert x_active.shape == (2, 1, 1, 1)
assert y_active.shape == (2, 1, 1, 1, 1)
p = 0.1 + 0.5 * x_active * y_active
assert p.shape == (2, 2, 1, 1, 1)
dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
for x, y in sparse_pixels:
dense_pixels[..., x, y] = 1
assert dense_pixels.shape == (2, 2, 1, width, height)
#
with x_axis, y_axis:
if observe:
pyro.sample('pixels', Bernoulli(p), obs=dense_pixels)
def test_model_with_sample_fn(sample_fn):
def model():
fun(observe=True, sample_fn=sample_fn)
#
@config_enumerate
def guide():
fun(observe=False, sample_fn=sample_fn)
test_model_with_sample_fn(sample_pixel_locations_no_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_full_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_partial_broadcasting)
在第一個(gè)采樣函數(shù)中匾乓,我們像賬房先生那樣捞稿,仔細(xì)規(guī)定了Bernoulli
分布的的形狀。請仔細(xì)觀察num_particles
, width
和height
傳入sample_pixel_locations
函數(shù)的方式拼缝。這一方式有些笨拙娱局。
對于第二個(gè)采樣函數(shù),我們需要注意pyro.plate
的參數(shù)必須要提供咧七,這樣系統(tǒng)才能猜出批維度的形狀衰齐。
我們可以看到,對于張量操作继阻,使用pyro.plate
實(shí)現(xiàn)并行是多么容易耻涛!
pyro.plate
還具有將代碼模塊化的效果。