【TensorFlow】tf.scatter_update()
在看tensorflow官網的API的時候,看到一個更新資料的函式。該函式的目的是為了能更新tensor的值,這個函式也解決了之前我想要更新tensor值的想法。在網上找了很多關於 tf.scatter_update() 的資料,但是找到的基本都是tensorflow官網上的API介紹和Stack Overflow上的提問,可見關於這個API的中文資料是相當少的,所以我打算寫下這篇部落格來介紹 tf.scatter_update()。
在這裡我簡短的介紹一下這個函式的使用:
tf.scatter_update
scatter_update(
ref,
indices,
updates,
use_locking=None,
name=None
)
在原始碼,函式的定義的位置在 tensorflow/python/ops/gen_state_ops.py.
引數介紹:
ref: 原來的tensor;
indices: 原來tensor中要更新的索引值,同樣也 tensor;
updates: 用於替代原來tensor的tensor值,注意,這個tensor和原來的tensor的shape要相同。
use_locking=None, name=None,一般情況下,我們使用預設的就好。
返回:依舊是一個tensor,shape和原來的tensor相同,是按照indices更新過tensor值的tensor;
介紹完了這個函式,那麼我來舉一個示例來讓大家明白怎麼去用這個函式。
程式碼如下:
輸出:import tensorflow as tf g = tf.Graph() with g.as_default(): a = tf.Variable(initial_value=[[0, 0, 0, 0],[0, 0, 0, 0]]) b = tf.scatter_update(a, [0, 1], [[1, 1, 0, 0], [1, 0, 4, 0]]) with tf.Session(graph=g) as sess: sess.run(tf.global_variables_initializer()) print(sess.run(a)) print(sess.run(b))
[[0 0 0 0]
[0 0 0 0]]
[[1 1 0 0]
[1 0 4 0]]
我們能看到原來的tensor是
[[0 0 0 0]
[0 0 0 0]]
更新tensor值後的tensor是
[[1 1 0 0]
[1 0 4 0]]
總結:1、對於tf.scatter_update()來說,ref和updates的shape一定要相同,要不然會報錯;
2、indices也是一個tensor,我們需要更新哪一維就寫哪一維;
3、這樣的方式適合更新整個tensor的值,特別適合批量化更新tensor;