scatter_(input, dim, index, src)將src中數(shù)據(jù)根據(jù)index中的索引按照dim的方向填進(jìn)input中。
>>> x = torch.rand(2, 5)
>>> x
0.4319 0.6500 0.4080 0.8760 0.2355
0.2609 0.4711 0.8486 0.8573 0.1029
[torch.FloatTensor of size 2x5]
index的shape剛好與x的shape對(duì)應(yīng)案怯,也就是index中每個(gè)元素指定x中一個(gè)數(shù)據(jù)的填充位置。dim=0少办,表示按行填充扼鞋,主要理解按行填充。舉例index中的第0行第2列的值為2阴颖,表示在第2行(從0開始)進(jìn)行填充活喊,對(duì)應(yīng)到input = zeros(3, 5)中就是位置(2,2)。所以此處要求input的列數(shù)要與x列數(shù)相同量愧,而index中的最大值應(yīng)與zeros(3, 5)行數(shù)相一致钾菊。
>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
0.4319 0.4711 0.8486 0.8760 0.2355
0.0000 0.6500 0.0000 0.8573 0.0000
0.2609 0.0000 0.4080 0.0000 0.1029
[torch.FloatTensor of size 3x5]
同上理,可以把1.23看成[[1.23], [1.23]]偎肃。此處按列填充煞烫,index中的2對(duì)應(yīng)zeros(2, 4)的(0,2)位置。
>>> z = torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23)
>>> z
0.0000 0.0000 1.2300 0.0000
0.0000 0.0000 0.0000 1.2300
[torch.FloatTensor of size 2x4]
綜上累颂,幾點(diǎn)要注意:
index的shape要與填充數(shù)據(jù)src的shape一致滞详,如果不一致,將進(jìn)行廣播
index中的索引指的是要把src中對(duì)應(yīng)位置的數(shù)據(jù)按照指定那個(gè)維度(即dim)填充到原數(shù)據(jù)input中,我們知道了要填充的數(shù)據(jù)是什么料饥,填充到input的哪行那列呢蒲犬,dim指定哪個(gè)維度,這個(gè)維度就是index索引值稀火,另一個(gè)維度就是這個(gè)索引在index中的位置暖哨。
scatter() 和 scatter_() 的作用是一樣的,只不過 scatter() 不會(huì)直接修改原來的 Tensor凰狞,而 scatter_() 會(huì)篇裁。PyTorch 中,一般函數(shù)加下劃線代表直接在原來的 Tensor 上修改
scatter() 一般可以用來對(duì)標(biāo)簽進(jìn)行 one-hot 編碼赡若,這就是一個(gè)典型的用標(biāo)量來修改張量的一個(gè)例子
class_num = 10
batch_size = 4
label = torch.LongTensor(batch_size, 1).random_() % class_num
#tensor([[6],
# [0],
# [3],
# [2]])
torch.zeros(batch_size, class_num).scatter_(1, label, 1)
#tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
# [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
# [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
# [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])
轉(zhuǎn)載于https://blog.csdn.net/qq_16234613/article/details/79827006
轉(zhuǎn)載于https://www.cnblogs.com/dogecheng/p/11938009.html