Tensorflow: 動態的給變數tf.Variable賦值【tf.assign】
阿新 • • 發佈:2018-12-25
Motivation
錯誤:
tensorflow不能直接給Variable賦值,比如:
embedding_var = tf.Variable(1)
test_var = 10
embedding_var = test_var
或者:
embedding_var = tf.Variable(1)
init = tf.initialize_all_variables()
sess = tf.InteractiveSession()
sess.run(init)
x.assign(1)
解決方法
正確:
如果只需要給Variable賦值一次,可以通過assign這樣進行賦值:
import tensorflow as tf
x = tf.Variable(0)
y = tf.assign(x, 1)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(x)
print sess.run(y)
print sess.run(x)
但是通常賦一次值的意義不大,因為有時我們想將網路中的一些輸出通過saver()儲存下來,或者通過tensorboard檢視網路的embedding投影,那麼就需要將網路中產生的輸出以變數的形式儲存,這樣就可以在saver.save()的時候將這些輸出給儲存到本地,又因為tensorflow不能在圖外面直接對變數進行操作,所以我通過用一個佔位符來傳輸網路的輸出結果,然後再session裡面取出網路的輸出值,feed給該佔位符,然後將佔位符的值賦給一個臨時變數作為儲存,如下,親測有效:
flat_value = np.zeros([200,4*4*32]) mid_vari = tf.placeholder(tf.float32, [200,4*4*32],name="mid_vari") embedding_var = tf.Variable(tf.zeros([200,4*4*32]), name=NAME_TO_VISUALISE_VARIABLE) mid_vari_2 = tf.assign(embedding_var,mid_vari) with tf.Session() as sess: saver = tf.train.Saver() sess.run(tf.global_variables_initializer()) for i in range(200): flat_value,_=sess.run([flat,mid_vari_2],feed_dict={x:one_x,y:labels,mid_vari:flat_value})
比較周折,不過也是試了很多辦法才找到的解決方案T_T。