TensorFLow collection的用法
阿新 • • 發佈:2018-11-07
tf.add_to_collection: 把變數放入一個集合當中
tf.get_collection: 把集合中的全部變數取出,是一個列表
tf.add_n: 把列表中的變數加起來
例子:
import tensorflow as tf
tf.reset_default_graph()
w1 = tf.get_variable('w1', shape=[4], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.1))
w2 = tf.get_variable('w2', shape= [4], dtype=tf.float32, initializer=tf.constant_initializer(0.1))
tf.add_to_collection('w', w1)
tf.add_to_collection('w', w2)
get_w = tf.get_collection('w')
add_w = tf.add_n(get_w)
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
print("生成資料的樣結果:")
print (sess.run(w1))
print(sess.run(w2))
print("放入集合w資料的結果:")
print(sess.run(get_w))
print("集合w中的資料相加的記過")
print(sess.run(add_w))
執行結果如下:
生成資料的樣結果:
[ 0.00524002 0.0679507 -0.0160088 -0.04484038]
[0.1 0.1 0.1 0.1]
放入集合w資料的結果:
[array([ 0.00524002, 0.0679507 , -0.0160088 , -0.04484038], dtype=float32), array([0.1 , 0.1, 0.1, 0.1], dtype=float32)]
集合w中的資料相加的結果:
[0.10524002 0.16795069 0.0839912 0.05515962]