tf.gather
tf.gather可以實(shí)現(xiàn)根據(jù)索引號(hào)收集數(shù)據(jù)的目的覆旱≌号螅考慮班級成績冊的例子,假設(shè)共有4個(gè)班級扣唱,每個(gè)班級35個(gè)學(xué)生藕坯,8門科目,保存成績冊的張量shape為[4,35,8]
x=tf.random.uniform([4,35,8],maxval=100,dtype=tf.int32) # 成績冊張量
現(xiàn)在需要收集第1~2個(gè)班級的成績冊噪沙,可以給定需要收集班級的索引號(hào):[0,1]炼彪,并指定班級的維度axis=0,通過tf.gather函數(shù)收集數(shù)據(jù)正歼,帶碼如下:
tf.gather(x,[0,1],axis=0) # 在班級維度收集1~2班級成績冊
實(shí)際上辐马,對于上述需求,通過切片x[:2]可以更加方便的實(shí)現(xiàn)局义。但是對于不規(guī)則的索引方式喜爷,比如,需要抽查所有班級的第1萄唇、4檩帐、9、12另萤、13湃密、27號(hào)學(xué)生的成績數(shù)據(jù),則切片方式實(shí)現(xiàn)起來非常麻煩四敞,而tf.gather則是針對于此需求設(shè)計(jì)的勾缭,使用起來更加方便,實(shí)現(xiàn)如下:
tf.gather(x, [0,3,8,11,12,26],axis=1)
如果需要收集所有同學(xué)的第3和第5門科目的成績目养,則可以指定科目維度axis=2俩由,實(shí)現(xiàn)如下:
tf.gather(x,[2,4],axis=2)
可以看到,tf.gather非常適合索引號(hào)沒有規(guī)則的場合癌蚁,其中索引號(hào)可以亂序排序幻梯,此時(shí)收集的數(shù)據(jù)也是對應(yīng)順序兜畸,例如:
a = tf.range(8)
a = tf.reshape(a, [4,2])
print(a)
print(tf.gather(a, [3,1,0,2], axis=0))
我們將問題變得稍微復(fù)雜一點(diǎn)。如果希望抽查第[2,3]班級的第[3,4,6,27]號(hào)同學(xué)的科目成績碘梢,則可以通過組合多個(gè)tf.gather實(shí)現(xiàn)咬摇。首先抽出第[2,3]班級,實(shí)現(xiàn)如下:
student = tf.gather(x, [1,2], axis=0)
再從這2個(gè)班級的同學(xué)中提取對應(yīng)學(xué)生成績煞躬,代碼如下:
tf.gather(student, [2,3,5,26],axis=1)
此時(shí)得到這2個(gè)班級4個(gè)學(xué)生的成績張量肛鹏,shape為[2,4,8]
tf.gather_nd
通過tf.gather_nd函數(shù),可以通過指定每次采樣點(diǎn)的多維坐標(biāo)實(shí)現(xiàn)采樣多個(gè)點(diǎn)的目的恩沛。抽查第2個(gè)班級的第2個(gè)同學(xué)的所有科目在扰,第3個(gè)班級的第3個(gè)同學(xué)的所有科目,第4個(gè)班級的第4個(gè)同學(xué)的所有科目雷客。那么這3個(gè)采樣點(diǎn)的索引坐標(biāo)可以記為:[1,1][2,2][3,3]芒珠,我們將這個(gè)采樣方案合并一個(gè)List參數(shù),即[[1,1][2,2][3,3]]搅裙,通過tf.gather_nd函數(shù)即可皱卓,實(shí)現(xiàn)如下:
tf.gather_nd(x, [[1,1],[2,2],[3,3]])
可以看到,結(jié)果與串行采樣方式的完全一樣部逮,實(shí)現(xiàn)更加簡潔娜汁,計(jì)算效率大大提升。