這篇文章簡(jiǎn)單的講一講如何在TensorFlow里指定修改Variable類(lèi)型張量指定坐標(biāo)位置的值过牙。
不得不吐槽TensorFlow的張量設(shè)計(jì)得蛋疼但壮,明明支持下標(biāo)和切片操作,卻只支持到一半木人,只能讀不能改麸锉。比如matrix是個(gè)二維的Variable钠绍,用matrix[x][y]
下標(biāo),或者matrix[x1:x2][y1:y2]
這樣的切片能讀取出指定位置或者范圍的值花沉,但是要是想局部更新一個(gè)張量可就沒(méi)那么容易了柳爽。想寫(xiě)matrix[x][y] = 0
?試試您就知道了0_0主穗。(說(shuō)它蛋疼是那是因?yàn)橛袑?duì)比泻拦,隔壁老李家MXNet的ndarray這么寫(xiě)就沒(méi)得問(wèn)題,溜溜的)
那莫忽媒,就只好曲線(xiàn)救國(guó)啦。首先搜StackOverflow腋粥,找到一篇回答How to update a subset of 2D tensor in Tensorflow?晦雨,大致的思路是,TensorFlow不讓你直接單獨(dú)改指定位置的值隘冲,但是留了個(gè)歪門(mén)兒闹瞧,就是tf.scatter_update
這個(gè)方法,它可以批量替換張量某一維上的所有數(shù)據(jù)展辞。
照著這個(gè)思路改改奥邮,寫(xiě)出了第一版的解決方法。提取個(gè)函數(shù)的話(huà)罗珍,長(zhǎng)成下面這個(gè)樣子:
def set_value(matrix, x, y, val):
# 提取出要更新的行
row = tf.gather(matrix, x)
# 構(gòu)造這行的新數(shù)據(jù)
new_row = tf.concat([row[:y], [val], row[y+1:]], axis=0)
# 使用 tf.scatter_update 方法進(jìn)正行替換
matrix.assign(tf.scatter_update(matrix, x, new_row))
其中matrix
是要更新的張量洽腺,x和y是目標(biāo)坐標(biāo),val是要寫(xiě)入的值覆旱。其余的代碼注釋得很清楚了蘸朋,不贅述。
問(wèn)題解決扣唱,但是這么做有沒(méi)什么缺點(diǎn)呢藕坯?有团南,那就是慢,特別是矩陣很大的時(shí)候炼彪,那是真心的慢吐根。
繼續(xù)想辦法,TensorFlow是對(duì)張量運(yùn)算(其實(shí)二維的就是矩陣運(yùn)算)有速度優(yōu)化的辐马,能不能將張量修改的操作變成一個(gè)普通的張量運(yùn)算呢佑惠?能,再構(gòu)建一個(gè)差值張量然后做個(gè)加法齐疙,哎膜楷,又是一條旁門(mén)邪道。把剛剛的函數(shù)改改贞奋,參數(shù)不變赌厅,計(jì)算過(guò)程變成這樣:
def set_value(matrix, x, y, val):
# 得到張量的寬和高,即第一維和第二維的Size
w = int(matrix.get_shape()[0])
h = int(matrix.get_shape()[1])
# 構(gòu)造一個(gè)只有目標(biāo)位置有值的稀疏矩陣轿塔,其值為目標(biāo)值于原始值的差
val_diff = val - matrix[x][y]
diff_matrix = tf.sparse_tensor_to_dense(tf.SparseTensor(indices=[x, y], values=[val_diff], dense_shape=[w, h]))
# 用 Variable.assign_add 將兩個(gè)矩陣相加
matrix.assign_add(diff_matrix)
注意在這個(gè)方法里面我用了一個(gè)tf.SparseTensor
類(lèi)型特愿,這是一個(gè)TensorFlow里的稀疏張量(或者叫稀疏矩陣),構(gòu)造它的時(shí)候只需要指定有值位置的內(nèi)容勾缭,其余位置默認(rèn)為0揍障。這樣一方面方便了差值張量的構(gòu)造,另一方面大大的減少了內(nèi)存的消耗(別忘了我們是要修改一個(gè)很大的矩陣)俩由。
實(shí)測(cè)在我的場(chǎng)景下毒嫡,后一種方法的效率大概提升了4倍。我的場(chǎng)景是什么呢幻梯?其實(shí)是cs20si課程作業(yè)1的第3題兜畸,具體的代碼和上下文可以看Github倉(cāng)庫(kù)的這個(gè)文件。
最后碘梢,祝各位TF Boy/Girl們咬摇,Happy Hacking。