函數(shù)torch.gather(input, dim, index, out=None) → Tensor
沿給定軸 dim ,將輸入索引張量 index 指定位置的值進行聚合.
對一個 3 維張量,輸出可以定義為:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
Parameters:
- input (Tensor) – 源張量
- dim (int) – 索引的軸
- index (LongTensor) – 聚合元素的下標(index需要是torch.longTensor類型)
- out (Tensor, optional) – 目標張量
使用說明舉例:
- dim = 1
import torch
a = torch.randint(0, 30, (2, 3, 5))
print(a)
'''
tensor([[[ 18., 5., 7., 1., 1.],
[ 3., 26., 9., 7., 9.],
[ 10., 28., 22., 27., 0.]],
[[ 26., 10., 20., 29., 18.],
[ 5., 24., 26., 21., 3.],
[ 10., 29., 10., 0., 22.]]])
'''
index = torch.LongTensor([[[0,1,2,0,2],
[0,0,0,0,0],
[1,1,1,1,1]],
[[1,2,2,2,2],
[0,0,0,0,0],
[2,2,2,2,2]]])
print(a.size()==index.size())
b = torch.gather(a, 1,index)
print(b)
'''
True
tensor([[[ 18., 26., 22., 1., 0.],
[ 18., 5., 7., 1., 1.],
[ 3., 26., 9., 7., 9.]],
[[ 5., 29., 10., 0., 22.],
[ 26., 10., 20., 29., 18.],
[ 10., 29., 10., 0., 22.]]])
可以看到沿著dim=1诵闭,也就是列的時候采蚀。輸出tensor第一頁內(nèi)容,
第一行分別是 按照index指定的来屠,
input tensor的第一頁
第一列的下標為0的元素 第二列的下標為1元素 第三列的下標為2的元素虑椎,第四列下標為0元素,第五列下標為2元素
index-->0,1,2,0,2 output--> 18., 26., 22., 1., 0.
'''
- dim =2
c = torch.gather(a, 2,index)
print(c)
'''
tensor([[[ 18., 5., 7., 18., 7.],
[ 3., 3., 3., 3., 3.],
[ 28., 28., 28., 28., 28.]],
[[ 10., 20., 20., 20., 20.],
[ 5., 5., 5., 5., 5.],
[ 10., 10., 10., 10., 10.]]])
dim = 2的時候就安裝 行 聚合了俱笛。參照上面的舉一反三捆姜。
'''
- dim = 0
index2 = torch.LongTensor([[[0,1,1,0,1],
[0,1,1,1,1],
[1,1,1,1,1]],
[[1,0,0,0,0],
[0,0,0,0,0],
[1,1,0,0,0]]])
d = torch.gather(a, 0,index2)
print(d)
'''
tensor([[[ 18., 10., 20., 1., 18.],
[ 3., 24., 26., 21., 3.],
[ 10., 29., 10., 0., 22.]],
[[ 26., 5., 7., 1., 1.],
[ 3., 26., 9., 7., 9.],
[ 10., 29., 22., 27., 0.]]])
這個有點特殊,dim = 0的時候(三維情況下)迎膜,是從不同的頁收集元素的泥技。
這里舉的例子只有兩頁。所有index在0,1兩個之間選擇磕仅。
輸出的矩陣元素也是按照index的指定珊豹。分別在第一頁和第二頁之間跳著選的。
index [0,1,1,0,1]的意思就是宽涌。
在第一頁選這個位置的元素平夜,在第二頁選這個位置的元素,在第二頁選卸亮,第一頁選,第二頁選玩裙。
'''