有很多時(shí)候卖陵,我們需要對深度學(xué)習(xí)過程中的tensor進(jìn)行一些非整齊遭顶、離散化的賦值操作,例如我們讓網(wǎng)絡(luò)的一支輸出可能的索引值泪蔫,而另外一支可能需要去取對應(yīng)索引值的內(nèi)容棒旗。PyTorch提供了幾種方法實(shí)現(xiàn)上述操作,但是其實(shí)際效果之間存在差異撩荣,在這里整理一下嗦哆。
-
scatter_(dim, index, src)
按照index
谤祖,將src
的數(shù)據(jù)散放到self
的'dim'維度中。例如老速,對于三維Tensor粥喜,效果如下:self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
-
dim (int)
- 要散布拷貝的維度 -
index (LongTensor)
- 散布拷貝的索引 -
src (Tensor or float)
- 要散布拷貝的源,可以是單個(gè)浮點(diǎn)值或是tensor
-
-
index_fill_(dim, index, val)
按照index
橘券,將val
的值填充self
的dim
維度额湘。效果如下:>>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float) >>> index = torch.tensor([0, 2]) >>> x.index_fill_(1, index, -1) tensor([[-1., 2., -1.], [-1., 5., -1.], [-1., 8., -1.]])
-
dim (int)
- 要填充的維度 -
index (LongTensor)
- 要填充的索引 -
val (float)
- 要填充的值
-
-
index_put_(indices, value)
按照indices
,將val
的值填充到self
的對應(yīng)位置旁舰。效果如下:>>> a = torch.zeros([5,5]) >>> index = (torch.LongTensor([0,1]),torch.LongTensor([1,2]) >>> a.index_put_(index), torch.Tensor([1,1])) tensor([[ 0., 1., 0., 0., 0.], [ 0., 0., 1., 0., 0.], [ 0., 0., 0., 0., 0.], [ 0., 0., 0., 0., 0.], [ 0., 0., 0., 0., 0.]])
-
indices (tuple of LongTensor)
- 要填充的索引 -
value (Tensor)
- 要填充的值組成的tensor
-
這三者的參數(shù)名相像锋华,但實(shí)際上對各參數(shù)的定義有差別,要仔細(xì)跟據(jù)參數(shù)類型和例子好好分析箭窜。