函數(shù)原型,nd
的意思是可以收集n dimension
的tensor
tf.gather_nd(
params,
indices,
name=None
)
- 意思是要收集
[params[0][0],params[1][1]]
indices = [[0, 0], [1, 1]]
params = [['a', 'b'], ['c', 'd']]
output = ['a', 'd']
- 意思是要收集
[params[1],params[0]]
indices = [[1], [0]]
params = [['a', 'b'], ['c', 'd']]
output = [['c', 'd'], ['a', 'b']]
- 意思是要收集
[params[1]]
indices = [[1]]
params = [[['a0', 'b0'], ['c0', 'd0']],
[['a1', 'b1'], ['c1', 'd1']]]
output = [[['a1', 'b1'], ['c1', 'd1']]]
- 我們使用這個(gè)函數(shù)的一般是想完成這樣一個(gè)功能:T是一個(gè)二維
tensor
莽囤,我們想要根據(jù)另外一個(gè)二維tensor
value的最后一維最大元素的下標(biāo)選出tensor
T 中最后一維最大的元素,組成一個(gè)新的一維的tensor
,那么就可以首先選出最后一維度的下標(biāo)[1,2,3]
坠韩,然后將將其擴(kuò)展成[[0,1],[1,2],[2,3]]
挟秤,然后使用這個(gè)函數(shù)選擇即可蒂教。
max_indicies = tf.argmax(T, 1)
import tensorflow as tf
sess = tf.InteractiveSession()
values = tf.constant([[0, 0, 0, 1],
[0, 1, 0, 0],
[0, 0, 1, 0]])
T = tf.constant([[0, 1, 2 , 3],
[4, 5, 6 , 7],
[8, 9, 10, 11]])
max_indices = tf.argmax(values, axis=1)
# If T.get_shape()[0] is None, you can replace it with tf.shape(T)[0].
result = tf.gather_nd(T, tf.stack((tf.range(T.get_shape()[0],
dtype=max_indices.dtype),
max_indices),
axis=1))
print(result.eval())