來自官網(wǎng)
tf.nn.embedding_lookup_sparse(
params,
sp_ids,
sp_weights,
partition_strategy='mod',
name=None,
combiner=None,
max_norm=None
)
-
params
embedding使用的lookup table. -
sp_ids
查找lookup table的SparseTensor. -
combiner
通過什么運算把一行的數(shù)據(jù)結(jié)合起來mean
,sum
等. - 其它沒用到過
例子
首先定義embedding的矩陣
import numpy as np
import tensorflow as tf
### embedding matrix
example = np.arange(24).reshape(6, 4).astype(np.float32)
embedding = tf.Variable(example)
其實這個矩陣就是
#------------------------------------------------------#
array([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.]], dtype=float32)
#------------------------------------------------------#
接下來使用tf.SparseTensor
來定義一個稀疏矩陣
### embedding lookup SparseTensor
idx = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 1], [1, 2], [2, 0]],
values=[0, 1, 2, 3, 0], dense_shape=[3, 3])
# 這個稀疏矩陣寫成普通形式這樣
#---------------------------------------------------------------------#
array([[0, 1, None],
[None, 2, 3],
[0, None, None]]) # 為了與0元素相區(qū)別勃刨,沒有填充的部分寫成了None
#---------------------------------------------------------------------#
使用查找表妥色,打印出結(jié)果
embed = tf.nn.embedding_lookup_sparse(embedding, idx, None, combiner='sum')
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(embed))
#----------------結(jié)果----------------------#
[[ 4. 6. 8. 10.]
[20. 22. 24. 26.]
[ 0. 1. 2. 3.]]
#-------------------------------------------#
現(xiàn)在分析一下結(jié)果,結(jié)果的shape
=(idx.shape[0], embedding.shape[1])
雪猪,其中結(jié)果的第一行等于“embeding的第一行加上embedding的第二行慈俯,也就是idx的第一行非None的元素的value渤刃,對應(yīng)了embedding的行數(shù),然后這些行相加“贴膘;結(jié)果第二行為”embedding第3行和第四行相加“卖子;結(jié)果第三行也同理。