Pyro簡介:產(chǎn)生式模型實(shí)現(xiàn)庫(六)邪驮,Pyro的張量尺寸

太長不看版

  • 模型在學(xué)習(xí)或調(diào)試過程中,設(shè)置pyro.enable_validation(True)泵三;
  • 張量的“廣播”耕捞,維度對齊自右向左:torch.ones(3,4,5) + torch.ones(5)
  • 分布的尺寸 .sample().shape == batch_shape + event_shape烫幕;
  • 分布的尺寸 .log_prob(x).shape == batch_shape(沒有event_shape)俺抽;
  • 使用expand()從Pyro中采樣一批數(shù)據(jù),或使用plate機(jī)制自動(dòng)擴(kuò)展较曼;
  • 使用my_dist.to_event(1)聲明維度為依賴(dependent)磷斧,或說不獨(dú)立;
  • 使用with pyro.plate('name', size):聲明條件獨(dú)立;
  • 所有維度要么是依賴的弛饭,要么是條件獨(dú)立的冕末;
  • 支持維度最左方的批處理,啟動(dòng)Pyro的并行處理侣颂;
    • 使用負(fù)號(hào)指標(biāo)档桃,如x.sum(-1),而不是x.sum(2)憔晒;
    • 使用省略號(hào)藻肄,如pixel = image[...,i, j]
    • 如果要枚舉i,j拒担,使用Vindex嘹屯,如pixel = Vindex(image)[...,i, j]

內(nèi)容列表

  • 概率分布的形狀
  • plate聲明條件獨(dú)立
  • 在plate中部分采樣
  • 并行地枚舉,張量的廣播

文件頭如下

import os
import torch
import pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam

smoke_test = ('CI' in os.environ)
pyro.enable_validation(True) #這句話最好加上

# 我們借助這個(gè)函數(shù)低零,檢查模型是否正確
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)

概率分布的尺寸:batch_shapeevent_shape

Pytorch的張量Tensor只有一個(gè)尺寸.shape婆翔,但是Distributions有兩個(gè)尺寸.batch_shape.event_shape,分別表示條件獨(dú)立的隨機(jī)變量的大小和不獨(dú)立的隨機(jī)變量的大小毁兆。這兩部分構(gòu)成了一個(gè)樣本的尺寸浙滤。

x = d.sample()
assert x.shape == d.batch_shape + d.event_shape

由于計(jì)算對數(shù)似然只牽涉不獨(dú)立的變量阴挣,所以.log_prob()方法后气堕,event_shape就被縮并了,只剩下batch_shape畔咧。

assert d.log_prob(x) == d.batch_shape

Distributions.sample()方法可以輸入一個(gè)參數(shù)sample_shape茎芭,作為獨(dú)立同分布(iid)的隨機(jī)變量,所以指定樣本大小的采樣誓沸,具有三個(gè)尺寸梅桩。

x2 = d.sample(sample_shape)
assert x2.shape == sample_shape + batch_shape + event_shape

總結(jié)來說

      |      iid     | independent | dependent
------+--------------+-------------+------------
shape = sample_shape + batch_shape + event_shape

由上可推論,單變量隨機(jī)分布的event_shape為0拜隧,因?yàn)槊看尾蓸又凳且粋€(gè)實(shí)數(shù)宿百,所以沒有不獨(dú)立的維度。像MultivariateNormal多元高斯分布這樣的概率分布洪添,具有len(event_shape) == 1垦页,因?yàn)槊總€(gè)采樣是一個(gè)向量,向量內(nèi)部是彼此依賴的(這里假定方差矩陣不是對角陣)干奢。而InverseWishart逆威沙特分布具有len(event_shape) == 2痊焊,等等。

關(guān)于概率分布尺寸的舉例

從單變量隨機(jī)分布開始。

d = Bernoulli(0.5)
assert d.batch_shape == ()
assert d.event_shape == ()
x = d.sample()
# x是一個(gè)Pytorch張量薄啥,沒有batch_shape和event_shape
assert x.shape == () 
assert d.log_prob(x).shape == ()

通過傳入批參數(shù)辕羽,概率分布數(shù)據(jù)可以分成批推掸。

d = Bernoulli(0.5 * torch.ones(3, 4))
assert d.batch_shape == (3,4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)

另一種成批的方法熟妓,是通過expand()。不過只在參數(shù)的最左側(cè)維度獨(dú)立時(shí)才可使用豪墅。

d = Bernoulli(torch.tensor([.1, .2, .3, .4])).expand([3, 4])
# 注意expand的參數(shù)寫在一個(gè)列表中
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)

多元高斯分布具有非空的event_shape維度到逊。對于這些分布來說酌毡,.sample().log_prob()的維度是不同的。

d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == ()
assert d.event_shape == (3, )
x = d.sample()
assert x.shape == (3, ) # == batch_shape + event_shape
assert d.log_prob(x).shape == () # == batch_shape

改變分布的維度獨(dú)立性

使用關(guān)鍵字.to_event(n)改變不獨(dú)立維度的情況蕾管,其中n表示從數(shù)第n維度開始枷踏,聲明為不獨(dú)立維度。

d = Bernoulli(0.5 * torch.ones(3, 4)).to_event(1)
assert d.batch_shape == (3, )
assert d.event_shape == (4, )
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, )

用戶必須小心地設(shè)置.to_event(n)batch_shape縮減到合適的水平上掰曾,或者用pyro.plate聲明維度的獨(dú)立性旭蠕。采樣仍舊會(huì)保留batch_shape+event_shape的尺寸,然而log_prob(x)只剩下batch_shape旷坦。

聲明為不獨(dú)立掏熬,通常是安全的做法

在Pyro中,我們常常會(huì)聲明維度是不獨(dú)立的秒梅,哪怕它們實(shí)際上是獨(dú)立的旗芬。請看這個(gè)例子:

x = pyro.sample('x', dist.Normal(0, 1).expand([10]).to_event(1))
assert x.shape == (10,)

上面的例子很容易就可以換成MultivariateNormal分布。它將下面的寫法簡化了:

with pyro.plate('x_plate', 10):
    x = pyro.sample('x', dist.Normal(0, 1)) #不需要expand捆蜀,系統(tǒng)自動(dòng)補(bǔ)全
    assert x.shape == (10,)

實(shí)際上疮丛,這兩份代碼存在一點(diǎn)小小的差別。上面的代碼中辆它,Pyro默認(rèn)x之間是不獨(dú)立的誊薄,而下面的x則是條件獨(dú)立的。聲明為不獨(dú)立通常是安全的锰茉,這與圖論中的d-separation基于同一個(gè)原理:在不同節(jié)點(diǎn)之間連一條邊呢蔫,即便節(jié)點(diǎn)之間不存在互相依賴關(guān)系,隨著優(yōu)化該邊的權(quán)重將越來越低飒筑,并不影響最終結(jié)果片吊;而本就存在依賴的節(jié)點(diǎn)連了一條邊,任優(yōu)化策略多么高明协屡,都無法彌補(bǔ)這一錯(cuò)誤俏脊。這種錯(cuò)誤常見于平均場假設(shè)的模型中。不過著瓶,在實(shí)際執(zhí)行時(shí)联予,Pyro的SVI模塊在估算Normal分布時(shí)啼县,兩份代碼的梯度估計(jì)值是一樣的。

通過plate聲明維度為獨(dú)立

Pyro的上下文管理器pyro.plate能夠聲明特定的維度為獨(dú)立維度沸久。推斷算法可以利用這一獨(dú)立性做一些算法優(yōu)化季眷,例如構(gòu)造低方差的梯度估計(jì)器,再如求解推斷問題不在指數(shù)空間而在線性空間采樣卷胯。下面的例子中子刮,我們將聲明同一批次中的數(shù)據(jù)之間是互相獨(dú)立的。
最簡單的方法窑睁,是不聲明獨(dú)立維度挺峡,系統(tǒng)將缺省值-1——即最右邊的維度,作為獨(dú)立維度担钮。

with pyro.plate('my_plate'):
    # 在該上下文中橱赠,維度-1將作為獨(dú)立維度

雖然效果是一樣的,不過我們?nèi)蕴岢脩魧懗鰜眢锝颍詭椭脩粽{(diào)試代碼:

with pyro.plate('my_plate', len(data)):
    #  在該上下文中狭姨,維度-1將作為獨(dú)立維度

從Pyro 0.2版本開始,plate語句可以嵌套使用苏遥。比如聲明圖像的每個(gè)像素都是獨(dú)立的:

with pyro.plate('x_axis', 320):
    #  在該上下文中饼拍,維度-1將作為獨(dú)立維度
    with pyro.plate('y_axis', 200):
        #  在該上下文中,維度-2和-1將作為獨(dú)立維度

我們習(xí)慣上總從右向左聲明獨(dú)立維度田炭,所以指標(biāo)是負(fù)的师抄,如-1,-2教硫,等等叨吮。
有時(shí)情況會(huì)更復(fù)雜一些,比如我們希望聲明一些噪聲依賴x栋豫,另一些噪聲依賴y挤安,還有一些噪聲依賴二者谚殊。這時(shí)Pyro允許用戶聲明多重獨(dú)立丧鸯,為了清楚地標(biāo)明獨(dú)立維度,必須指定dim這一參數(shù)嫩絮,如下面的例子:

x_axis = pyro.plate('x_axis', dim = -2)
y_axis = pyro.plate('y_axis', dim = -3)
with x_axis:
    #  在該上下文中丛肢,維度-2將作為獨(dú)立維度
with y_axis:
    #  在該上下文中,維度-3將作為獨(dú)立維度
with x_axis, y_axis:
    #  在該上下文中剿干,維度-2和-3將作為獨(dú)立維度

讓我們舉更多例子蜂怎,來展示plate的用法。

def model1():
    a = pyro.sample('a', Normal(0, 1))
    b = pyro.sample('b', Normal(torch.zeros(2), 1).to_event(1))
    with pyro.plate('c_plate', 2):
        c = pyro.sample('c', Normal(torch.zeros(2), 1))
    with pyro.plate('d_plate', 3):
        d = pyro.sample('d', Normal(torch.zeros(3, 4, 5), 1).to_event(2))
    assert a.shape == ()                  # batch_shape == (), event_shape == ()
    assert b.shape == (2,)                # batch_shape == (), event_shape == (2,)
    assert c.shape == (2,)                # batch_shape == (2,), event_shape == ()
    assert d.shape == (3, 4, 5)           # batch_shape == (3), event_shape == (4, 5)
    ##
    x_axis = pyro.plate('x_axis', 3, dim=-2)
    y_axis = pyro.plate('y_axis', 2, dim=-3)
    with x_axis:
        x = pyro.sample('x', Normal(0, 1))
    with y_axis:
        y = pyro.sample('y', Normal(0, 1))
    with x_axis, y_axis:
        xy = pyro.sample('xy', Normal(0, 1))
        z = pyro.sample('z', Normal(0, 1).expand([5]).to_event(1))
    assert x.shape == (3, 1)               # batch_shape == (3, 1), event_shape==()
    assert y.shape == (2, 1, 1)            # batch_shape == (2, 1, 1), event_shape==()
    assert xy.shape == (2, 3, 1)           # batch_shape == (2, 3, 1), event_shape==()
    assert z.shape == (2, 3, 1, 5)         # batch_shape == (2, 3, 1), event_shape==(5,)

test_model(model1, model1, Trace_ELBO())

可視化如下:

batch dims | event dims
-----------+-----------
           |        a = sample("a", Normal(0, 1))
           |2       b = sample("b", Normal(zeros(2), 1)
           |                        .to_event(1))
           |        with plate("c", 2):
          2|            c = sample("c", Normal(zeros(2), 1))
           |        with plate("d", 3):
          3|4 5         d = sample("d", Normal(zeros(3,4,5), 1)
           |                       .to_event(2))
           |
           |        x_axis = plate("x", 3, dim=-2)
           |        y_axis = plate("y", 2, dim=-3)
           |        with x_axis:
        3 1|            x = sample("x", Normal(0, 1))
           |        with y_axis:
      2 1 1|            y = sample("y", Normal(0, 1))
           |        with x_axis, y_axis:
      2 3 1|            xy = sample("xy", Normal(0, 1))
      2 3 1|5           z = sample("z", Normal(0, 1).expand([5])
           |                       .to_event(1))

為了在調(diào)試代碼時(shí)方便地查看隨機(jī)變量的形狀置尔,Pyro提供了Trace.format_shapes()
方法杠步,在采樣點(diǎn)上打印分布的形狀(包含site['fn'].batch_shapesite['fn'].event_shape)、變量的形狀(site['value'].shape)、如果計(jì)算對數(shù)似然概率時(shí)log_prob的形狀(site['log_prob'].shape)幽歼。

trace = poutine.trace(model1).get_trace()
trace.compute_log_prob()  #  可選的朵锣,這句話可以打印log_prob的形狀
print(trace.format_shapes())

打印結(jié)果:

Trace Shapes:
 Param Sites:
Sample Sites:
       a dist       |
        value       |
     log_prob       |
       b dist       | 2
        value       | 2
     log_prob       |
 c_plate dist       |
        value     2 |
     log_prob       |
       c dist     2 |
        value     2 |
     log_prob     2 |
 d_plate dist       |
        value     3 |
     log_prob       |
       d dist     3 | 4 5
        value     3 | 4 5
     log_prob     3 |
  x_axis dist       |
        value     3 |
     log_prob       |
  y_axis dist       |
        value     2 |
     log_prob       |
       x dist   3 1 |
        value   3 1 |
     log_prob   3 1 |
       y dist 2 1 1 |
        value 2 1 1 |
     log_prob 2 1 1 |
      xy dist 2 3 1 |
        value 2 3 1 |
     log_prob 2 3 1 |
       z dist 2 3 1 | 5
        value 2 3 1 | 5
     log_prob 2 3 1 |

plate句塊中采樣部分張量

plate最重要的功能之一就是部分采樣,plate句塊中的隨機(jī)變量都是條件獨(dú)立的甸私。如果樣本量為總樣本的一半诚些,那么樣本損失的值將被認(rèn)為是總損失的一半。
在實(shí)現(xiàn)部分時(shí)皇型,用戶需要通知Pyro采樣量和樣本總量的值诬烹,Pyro就會(huì)隨機(jī)產(chǎn)生一定量的數(shù)據(jù)指標(biāo)作為樣本。

data = torch.arange(100.)

def model2():
    mean = pyro.param('mean', torch.zeros(len(data)))
    with pyro.plate('data', len(data), subsample_size=10) as ind: 
        assert len(ind) == 10
        batch = data[ind]
        mean_batch = mean[ind]
        # 在batch中做一些計(jì)算
        x = pyro.sample('x', Normal(mean_batch, 1), obs=batch)
        assert x.shape == (10,)

test_model(model2, guide=lambda: None, loss=Trace_ELBO())

廣播功能弃鸦,實(shí)現(xiàn)數(shù)據(jù)的并行枚舉

Pyro 0.2后的版本都支持離散隨機(jī)變量的并行枚舉功能绞吁。這一功能可以極大地減少計(jì)算變分推斷時(shí)梯度估計(jì)的方差,確保優(yōu)化的穩(wěn)定性唬格。
為了實(shí)現(xiàn)枚舉掀泳,Pyro需要用戶指定哪些維度是不獨(dú)立的,哪些是獨(dú)立的西轩,只有不獨(dú)立的維度才允許枚舉员舵。自然地,這一指定需要用到plate語句藕畔,我們需要聲明最大數(shù)量的枚舉范圍马僻,這一關(guān)鍵字為max_plate_nesting,它是SVI類的一個(gè)參數(shù)(而且通過TraceEnum_ELBO傳入)注服。通常來說韭邓,Pyro可以自動(dòng)地指定枚舉范圍(只要運(yùn)行一次modelguide,系統(tǒng)將了解枚舉范圍)溶弟,不過在動(dòng)態(tài)變化的模型中女淑,用戶需要人工地指定max_plate_nesting的數(shù)值。
為了弄清楚max_plate_nesting的作用機(jī)制辜御,我們重新回顧model1()鸭你,這一次我們關(guān)心三種維度的形狀:最左邊的枚舉維度,中間的批維度擒权,最右邊的不獨(dú)立維度袱巨。而max_plate_nesting規(guī)定了中間的批維度

      max_plate_nesting = 3
           |<--->|
enumeration|batch|event
-----------+-----+-----
           |. . .|      a = sample("a", Normal(0, 1))
           |. . .|2     b = sample("b", Normal(zeros(2), 1)
           |     |                      .to_event(1))
           |     |      with plate("c", 2):
           |. . 2|          c = sample("c", Normal(zeros(2), 1))
           |     |      with plate("d", 3):
           |. . 3|4 5       d = sample("d", Normal(zeros(3,4,5), 1)
           |     |                     .to_event(2))
           |     |
           |     |      x_axis = plate("x", 3, dim=-2)
           |     |      y_axis = plate("y", 2, dim=-3)
           |     |      with x_axis:
           |. 3 1|          x = sample("x", Normal(0, 1))
           |     |      with y_axis:
           |2 1 1|          y = sample("y", Normal(0, 1))
           |     |      with x_axis, y_axis:
           |2 3 1|          xy = sample("xy", Normal(0, 1))
           |2 3 1|5         z = sample("z", Normal(0, 1).expand([5]))
           |     |                     .to_event(1))

上面的例子中碳抄,如果我們聲明(過度)充裕的max_plate_nesting=4也是可以的愉老,但不能聲明例如max_plate_nesting=2,因?yàn)?<3剖效,這時(shí)系統(tǒng)將會(huì)報(bào)錯(cuò)嫉入。
我們再舉一個(gè)例子:

@config_enumerate
#該修飾符表示枚舉類型焰盗,不能省略!咒林!
def model3():
    p = pyro.param('p', torch.arange(6) / 6.)
    locs = pyro.param('locs', torch.tensor([-1., 1.]))
    # locs in [-1, 1]
    # a in [0, 1, 2, 3, 4, 5]
    a = pyro.sample('a', Categorical(torch.ones(6) / 6.))
    # p[a] in [0, 1/6, 2/6, 3/6, 4/6, 5/6]
    b = pyro.sample('b', Bernoulli(p[a])) # 聲明b依賴于a
    # b in [0, 1]
    with pyro.plate('c_plate', 4):
        c = pyro.sample('c',  Bernoulli(0.4))
        # c in [0, 1]
        with pyro.plate('d_plate', 5):
            d = pyro.sample('d', Bernoulli(0.3))
            # d in [0, 1]
            e_loc = locs[d.long()].unsqueeze(-1)
            # e_loc in [-1, 1]
            e_scale = torch.arange(1., 8.)
            # e_scale in [1, 2, ..., 7]
            e = pyro.sample('e', Normal(e_loc, e_scale).to_event(1)) # 依賴于d
    #                            枚舉維度|批維度(獨(dú)立維度)|不獨(dú)立維度
    assert a.shape == (                6,            1,1            )  # 多類別分布的維度大小為6
    assert b.shape == (              2,1,            1,1            )  # 枚舉伯努利分布姨谷,非擴(kuò)增
    assert c.shape == (            2,1,1,            1,1            )  # 伯努利分布,非擴(kuò)增
    assert d.shape == (          2,1,1,1,            1,1            )  # 伯努利分布映九,非擴(kuò)增
    assert e.shape == (          2,1,1,1,            5,4,          7)  # e是采樣出來的梦湘,依賴于d
    #
    assert e_loc.shape ==   (    2,1,1,1,            1,1,         1,) # 最后的逗號(hào)可以省略
    assert e_scale.shape == (                                     7,) # 注意逗號(hào)不能省略!件甥!

test_model(model3, model3, TraceEnum_ELBO(max_plate_nesting=2))

我們重新來可視化一下:

     max_plate_nesting = 2
            |<->|
enumeration batch event
------------|---|-----
           6|1 1|     a = pyro.sample("a", Categorical(torch.ones(6) / 6))
         2 1|1 1|     b = pyro.sample("b", Bernoulli(p[a]))
            |   |     with pyro.plate("c_plate", 4):
       2 1 1|1 1|         c = pyro.sample("c", Bernoulli(0.3))
            |   |         with pyro.plate("d_plate", 5):
     2 1 1 1|1 1|             d = pyro.sample("d", Bernoulli(0.4))
     2 1 1 1|1 1|1            e_loc = locs[d.long()].unsqueeze(-1)
            |   |7            e_scale = torch.arange(1., 8.)
     2 1 1 1|5 4|7            e = pyro.sample("e", Normal(e_loc, e_scale)
            |   |                             .to_event(1))

我們分析一下這些維度捌议。我們?yōu)镻yro指定了枚舉的維度max_plate_nesting:Pyro給a賦予枚舉維度-3,給b賦予枚舉維度-4引有,給c賦予枚舉維度-5瓣颅,給d賦予枚舉維度-6。當(dāng)用戶不指定維度擴(kuò)展后的數(shù)值時(shí)譬正,新維度被默認(rèn)為1宫补,這方便計(jì)算。我們還可以觀察到曾我,log_prob的形狀廣播的范圍是枚舉維度和獨(dú)立維度粉怕,比如trace.nodes['d']['log_prob'].shape == (2,1,1,1,5,4)

使用Pyro的自帶工具Trace.format_shapes():

trace = poutine.trace(poutine.enum(model3, first_available_dim=-3)).get_trace()
trace.compute_log_prob() # 可選
print(trace.format_shapes())

結(jié)果:

Trace Shapes:                
 Param Sites:                
            p             6  
         locs             2  
Sample Sites:                
       a dist             |  
        value       6 1 1 |  
     log_prob       6 1 1 |  
       b dist       6 1 1 |  
        value     2 1 1 1 |  
     log_prob     2 6 1 1 |  
 c_plate dist             |  
        value           4 |  
     log_prob             |  
       c dist           4 |  
        value   2 1 1 1 1 |  
     log_prob   2 1 1 1 4 |  
 d_plate dist             |  
        value           5 |  
     log_prob             |  
       d dist         5 4 |  
        value 2 1 1 1 1 1 |  
     log_prob 2 1 1 1 5 4 |  
       e dist 2 1 1 1 5 4 | 7
        value 2 1 1 1 5 4 | 7
     log_prob 2 1 1 1 5 4 |  

編寫并行代碼

在Pyro中,我們需要掌握兩個(gè)取巧的技術(shù)抒巢,來實(shí)現(xiàn)并行采樣:廣播 贫贝、 橢圓分片。我們通過下面的例子來分別介紹枚舉情形和非枚舉情形下的用法蛉谜。

width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
enumeration = None # 設(shè)為True或False

def fun(observe):
    p_x = pyro.param('p_x', torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param('p_y', torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.plate('x_axis', width, dim=-2)
    y_axis = pyro.plate('y_axis', height, dim=-1)
    # 在這些樣本點(diǎn)上稚晚,分布形狀取決于Pyro是否枚舉
    with x_axis:
        x_active = pyro.sample('x_active', Bernoulli(p_x))
    with y_axis:
        y_active = pyro.sample('y_active', Bernoulli(p_y))
    if enumerated:
        assert x_active.shape == (2, 1, 1) # max_plate_nesting==2
        assert y_active.shape == (2, 1, 1, 1)
    else:
        assert x_active.shape == (width, 1)
        assert y_active.shape == (height, )
    # 第一個(gè)trick:廣播,broadcast型诚。枚舉和非枚舉都可使用客燕。
    p = 0.1 + 0.5 * x_active * y_active
    if enumerated:
        assert p.shape == (2, 2, 1, 1)
    else:
        assert p.shape == (width, height)
    dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
    # 第二個(gè)trick:橢圓分片。Pyro可以在左方任意增加維度狰贯。
    for x, y in sparse_pixels:
        dense_pixels[..., x, y] = 1
    if enumerated:
        assert dense_pixels.shape == (2, 2, width, height)
    else:
        assert dense_pixels.shape == (width, height)
    #
    with x_axis, y_axis:
        if observe:
            pyro.sample('pixels', Bernoulli(p), obs=dense_pixels)

def model4():
    fun(observe=True)

def guide4():
    fun(observe=False)

# Test: 非枚舉
enumerated = False
test_model(model4, guide4, Trace_ELBO())

# Test: 枚舉也搓。注意目標(biāo)函數(shù)為TraceEnum_ELBO
enumerated = True
test_model(model4, config_enumerate(guide4, 'parallel'), TraceEnum_ELBO(max_plate_nesting=2))

在pyro.plate內(nèi)部實(shí)現(xiàn)自動(dòng)廣播

在以上所有model/plate的實(shí)現(xiàn)中,我們都使用了pyro.plate的自動(dòng)擴(kuò)增功能暮现,使變量滿足pyro.sample規(guī)定的形狀还绘。這一廣播方式等價(jià)于.expand()
我們稍許更改上面的代碼作為例子栖袋,注意幾點(diǎn)區(qū)別:

  • 我們僅考慮并行枚舉的情況,但對于串行的抚太、非枚舉的情況也適用塘幅;
  • 我們將采樣函數(shù)分離出來昔案,model代碼使用常規(guī)的形式,這樣做有利于代碼的維護(hù)电媳;
  • pyro.plate使用ELBO的num_particles參數(shù)踏揣,將上下文中最遠(yuǎn)的內(nèi)容打包。
# 規(guī)定采樣的樣本量
num_particals = 100
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])

def sample_pixel_locations_no_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample('x_active', Bernoulli(p_x).expand([num_particals, width, 1]))
    with y_axis:
        y_active = pyro.sample('y_active', Bernoulli(p_y).expand([num_particals, 1, height]))
    return x_active, y_active

def sample_pixel_locations_full_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample('x_active', Bernoulli(p_x))
    with y_axis:
        y_active = pyro.sample('y_acitve', Bernoulli(p_y))
    return x_active, y_active

def sample_pixel_locations_partial_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample('x_active', Bernoulli(p_x).expand([width, 1]))
    with y_axis:
        y_active = pyro.sample('y_active', Bernoulli(p_y).expand([height]))
    return x_acitve, y_active

def fun(observe, sample_fn):
    p_x = pyro.param('p_x', torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param('p_y', torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.plate('x_axis', width, dim=-2)
    y_axis = pyro.plate('y_axis', height, dim=-1)
    # 
    with pyro.plate('num_particals', 100, dim=-3):
        x_active, y_active = sample_fn(p_x, p_y, x_axis, y_axis)
        ## 并行枚舉指標(biāo)被擴(kuò)增在“num_particals”最左邊
        assert x_active.shape == (2, 1, 1, 1) 
        assert y_active.shape == (2, 1, 1, 1, 1)
        p = 0.1 + 0.5 * x_active * y_active
        assert p.shape == (2, 2, 1, 1, 1)
        dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
        for x, y in sparse_pixels:
            dense_pixels[..., x, y] = 1
        assert dense_pixels.shape == (2, 2, 1, width, height)
        #
        with x_axis, y_axis:
            if observe:
                pyro.sample('pixels', Bernoulli(p), obs=dense_pixels)

def test_model_with_sample_fn(sample_fn):
    def model():
        fun(observe=True, sample_fn=sample_fn)
    #
    @config_enumerate
    def guide():
        fun(observe=False, sample_fn=sample_fn)

test_model_with_sample_fn(sample_pixel_locations_no_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_full_broadcasting) 
test_model_with_sample_fn(sample_pixel_locations_partial_broadcasting)

在第一個(gè)采樣函數(shù)中匾乓,我們像賬房先生那樣捞稿,仔細(xì)規(guī)定了Bernoulli分布的的形狀。請仔細(xì)觀察num_particles, widthheight傳入sample_pixel_locations函數(shù)的方式拼缝。這一方式有些笨拙娱局。
對于第二個(gè)采樣函數(shù),我們需要注意pyro.plate的參數(shù)必須要提供咧七,這樣系統(tǒng)才能猜出批維度的形狀衰齐。
我們可以看到,對于張量操作继阻,使用pyro.plate實(shí)現(xiàn)并行是多么容易耻涛!
pyro.plate還具有將代碼模塊化的效果。

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
  • 序言:七十年代末瘟檩,一起剝皮案震驚了整個(gè)濱河市抹缕,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌墨辛,老刑警劉巖歉嗓,帶你破解...
    沈念sama閱讀 211,265評論 6 490
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異背蟆,居然都是意外死亡鉴分,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 90,078評論 2 385
  • 文/潘曉璐 我一進(jìn)店門带膀,熙熙樓的掌柜王于貴愁眉苦臉地迎上來志珍,“玉大人,你說我怎么就攤上這事垛叨÷着矗” “怎么了?”我有些...
    開封第一講書人閱讀 156,852評論 0 347
  • 文/不壞的土叔 我叫張陵嗽元,是天一觀的道長敛纲。 經(jīng)常有香客問我,道長剂癌,這世上最難降的妖魔是什么淤翔? 我笑而不...
    開封第一講書人閱讀 56,408評論 1 283
  • 正文 為了忘掉前任,我火速辦了婚禮佩谷,結(jié)果婚禮上旁壮,老公的妹妹穿的比我還像新娘监嗜。我一直安慰自己,他們只是感情好抡谐,可當(dāng)我...
    茶點(diǎn)故事閱讀 65,445評論 5 384
  • 文/花漫 我一把揭開白布裁奇。 她就那樣靜靜地躺著,像睡著了一般麦撵。 火紅的嫁衣襯著肌膚如雪刽肠。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 49,772評論 1 290
  • 那天免胃,我揣著相機(jī)與錄音音五,去河邊找鬼。 笑死杜秸,一個(gè)胖子當(dāng)著我的面吹牛放仗,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播撬碟,決...
    沈念sama閱讀 38,921評論 3 406
  • 文/蒼蘭香墨 我猛地睜開眼诞挨,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了呢蛤?” 一聲冷哼從身側(cè)響起惶傻,我...
    開封第一講書人閱讀 37,688評論 0 266
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎其障,沒想到半個(gè)月后银室,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 44,130評論 1 303
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡励翼,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 36,467評論 2 325
  • 正文 我和宋清朗相戀三年蜈敢,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片汽抚。...
    茶點(diǎn)故事閱讀 38,617評論 1 340
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡抓狭,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出造烁,到底是詐尸還是另有隱情否过,我是刑警寧澤,帶...
    沈念sama閱讀 34,276評論 4 329
  • 正文 年R本政府宣布惭蟋,位于F島的核電站苗桂,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏告组。R本人自食惡果不足惜煤伟,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 39,882評論 3 312
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧持偏,春花似錦驼卖、人聲如沸氨肌。這莊子的主人今日做“春日...
    開封第一講書人閱讀 30,740評論 0 21
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽怎囚。三九已至卿叽,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間恳守,已是汗流浹背考婴。 一陣腳步聲響...
    開封第一講書人閱讀 31,967評論 1 265
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留催烘,地道東北人沥阱。 一個(gè)月前我還...
    沈念sama閱讀 46,315評論 2 360
  • 正文 我出身青樓,卻偏偏與公主長得像伊群,于是被迫代替她去往敵國和親考杉。 傳聞我的和親對象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 43,486評論 2 348

推薦閱讀更多精彩內(nèi)容