3.4 命名張量
張量的尺寸(或軸)通常索引諸如像素位置或顏色通道之類的東西惭婿。 這意味著當(dāng)我們想索引張量時(shí)赤拒,我們需要記住維度的順序并相應(yīng)地編寫索引孩革。 隨著數(shù)據(jù)通過多個(gè)張量轉(zhuǎn)換广鳍,跟蹤哪個(gè)維度包含哪些數(shù)據(jù)可能容易出錯(cuò)。
為了使事情具體苹支,假設(shè)我們?cè)?.1.4節(jié)中的img_t這樣的3D張量(為簡化起見,我們將使用虛擬數(shù)據(jù))误阻,并將其轉(zhuǎn)換為灰度债蜜。 我們查找了顏色的典型權(quán)重以得出單個(gè)亮度值:
# In[2]:
img_t = torch.randn(3, 5, 5) # shape [channels, rows, columns]
weights = torch.tensor([0.2126, 0.7152, 0.0722])
我們還經(jīng)常希望代碼泛化,例如究反,從表示為具有高度和寬度尺寸的2D張量的灰度圖像到添加第三個(gè)通道尺寸的彩色圖像(如RGB)策幼,或者從單個(gè)圖像到一批圖像。 在2.1.4節(jié)中奴紧,我們?cè)赽atch_t中引入了另一個(gè)批處理維度特姐; 在這里,我們假裝有兩個(gè)批次:
# In[3]:
batch_t = torch.randn(2, 3, 5, 5) # shape [batch, channels, rows, columns]
因此黍氮,有時(shí)RGB通道的尺寸為0唐含,有時(shí)它們的尺寸為1。但是我們可以通過從末尾開始計(jì)數(shù)來進(jìn)行歸納:它們始終在3維中沫浆,是末尾的第三個(gè)捷枯。 簡單的,沒有權(quán)重的均值可以這樣寫:
# In[4]:
img_gray_naive = img_t.mean(-3)
batch_gray_naive = batch_t.mean(-3)
img_gray_naive.shape, batch_gray_naive.shape
# Out[4]:
(torch.Size([5, 5]), torch.Size([2, 5, 5]))
但是現(xiàn)在我們也有分量了专执。 PyTorch將允許我們乘以相同形狀的東西淮捆,以及在給定維度上一個(gè)操作數(shù)的大小為1的形狀。 它還會(huì)自動(dòng)附加尺寸為1的前導(dǎo)尺寸本股。 這是稱為廣播的功能攀痊。 形狀(2,3拄显,5苟径,5)的batch_t乘以形狀(3,1,1)的未壓縮的權(quán)重躬审,得出一個(gè)形狀為(2棘街、3蟆盐、5、5)的張量遭殉,然后我們可以根據(jù)該張量求和來自末端的第三個(gè)維度(三個(gè)通道):
# In[5]:
unsqueezed_weights = weights.unsqueeze(-1).unsqueeze_(-1)
img_weights = (img_t * unsqueezed_weights)
batch_weights = (batch_t * unsqueezed_weights)
img_gray_weighted = img_weights.sum(-3)
batch_gray_weighted = batch_weights.sum(-3)
batch_weights.shape, batch_t.shape, unsqueezed_weights.shape
# Out[5]:
(torch.Size([2, 3, 5, 5]), torch.Size([2, 3, 5, 5]), torch.Size([3, 1, 1]))
因?yàn)檫@很快就會(huì)變得混亂(并且為了提高效率)石挂,所以PyTorch函數(shù)einsum(改編自NumPy)指定了索引迷你語言,為此類產(chǎn)品的總和提供了索引名稱险污。 就像在Python中一樣誊稚,廣播是一種總結(jié)未命名事物的形式,它使用三個(gè)點(diǎn)“ ...”來完成罗心; 但不必太擔(dān)心einsum里伯,因?yàn)槲覀儗⒃谝韵聝?nèi)容中不使用它:
# In[6]:
img_gray_weighted_fancy = torch.einsum('...chw,c->...hw', img_t, weights)
batch_gray_weighted_fancy = torch.einsum('...chw,c->...hw', batch_t, weights)
batch_gray_weighted_fancy.shape
# Out[6]:
torch.Size([2, 5, 5])
如我們所見,其中涉及大量簿記渤闷。 這很容易出錯(cuò)疾瓮,尤其是在我們的代碼中創(chuàng)建和使用張量的位置相距甚遠(yuǎn)的情況下,這引起了從業(yè)人員的注意飒箭,因此建議改用維度名稱狼电。
PyTorch 1.3添加了命名張量作為實(shí)驗(yàn)功能(請(qǐng)參閱https://pytorch.org/tutorials/intermediate/named_tensor_tutorial.html和https://pytorch.org/docs/stable/named_tensor.html)。 張量工廠函數(shù)(例如tensor和rank)帶有一個(gè)名稱參數(shù)弦蹂。 名稱應(yīng)為字符串序列:
# In[7]:
weights_named = torch.tensor([0.2126, 0.7152, 0.0722], names=['channels'])
weights_named
# Out[7]:tensor([0.2126, 0.7152, 0.0722], names=('channels',))
當(dāng)我們已經(jīng)有一個(gè)張量并且想要添加名稱(但不更改現(xiàn)有名稱)時(shí)肩碟,可以在其上調(diào)用方法fine_names。 與索引類似凸椿,省略號(hào)(...)允許您省略任意數(shù)量的尺寸削祈。 使用重命名同級(jí)方法,您還可以覆蓋或刪除(通過傳入None)現(xiàn)有名稱:
# In[8]:
img_named = img_t.refine_names(..., 'channels', 'rows', 'columns')
batch_named = batch_t.refine_names(..., 'channels', 'rows', 'columns')
print("img named:", img_named.shape, img_named.names)
print("batch named:", batch_named.shape, batch_named.names)
# Out[8]:
img named: torch.Size([3, 5, 5]) ('channels', 'rows', 'columns')
batch named: torch.Size([2, 3, 5, 5]) (None, 'channels', 'rows', 'columns')
對(duì)于具有兩個(gè)輸入的操作脑漫,除了通常的尺寸檢查(大小是否相同髓抑,或者一個(gè)是否為1并可以廣播給另一個(gè))之外,PyTorch現(xiàn)在將為我們檢查名稱优幸。 到目前為止吨拍,它不會(huì)自動(dòng)對(duì)齊尺寸,因此我們需要明確地進(jìn)行此操作网杆。 方法align_as返回一個(gè)張量羹饰,其中添加了缺失的維,而現(xiàn)有的維以正確的順序排列:
# In[9]:
weights_aligned = weights_named.align_as(img_named)
weights_aligned.shape, weights_aligned.names
# Out[9]:
(torch.Size([3, 1, 1]), ('channels', 'rows', 'columns'))
接受維度參數(shù)(例如sum)的函數(shù)也采用命名維度:
# In[10]:
gray_named = (img_named * weights_aligned).sum('channels')
gray_named.shape, gray_named.names
# Out[10]:
(torch.Size([5, 5]), ('rows', 'columns'))
如果嘗試將尺寸與不同的名稱組合在一起碳却,則會(huì)出現(xiàn)錯(cuò)誤:
gray_named = (img_named[..., :3] * weights_named).sum('channels')
RuntimeError: Error when attempting to broadcast dims ['channels', 'rows','columns'] and dims ['channels']: dim 'columns' and dim 'channels'are at the same position from the right but do not match.
如果要在對(duì)命名張量進(jìn)行運(yùn)算的函數(shù)之外使用張量队秩,則需要通過將名稱重命名為None來刪除名稱。 以下內(nèi)容使我們回到了未命名尺寸的世界:
# In[12]:
gray_plain = gray_named.rename(None)
gray_plain.shape, gray_plain.names
# Out[12]:
(torch.Size([5, 5]), (None, None))
鑒于在撰寫本文時(shí)此功能具有實(shí)驗(yàn)性質(zhì)追城,并且為了避免因索引和對(duì)齊方式而搞混刹碾,在本書的其余部分中燥撞,我們將堅(jiān)持未命名座柱。 指定的張量可能消除許多潛在的對(duì)齊的誤差源(會(huì)令人頭疼的錯(cuò)誤)迷帜,如果PyTorch論壇有任何跡象。 看到它們將被廣泛采用將會(huì)很有趣色洞。