1. 程式人生 > >TensorFLow collection的用法

TensorFLow collection的用法

tf.add_to_collection: 把變數放入一個集合當中
tf.get_collection: 把集合中的全部變數取出,是一個列表
tf.add_n: 把列表中的變數加起來

例子:
import tensorflow as tf
tf.reset_default_graph()

w1 = tf.get_variable('w1', shape=[4], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.1))
w2 = tf.get_variable('w2', shape=
[4], dtype=tf.float32, initializer=tf.constant_initializer(0.1)) tf.add_to_collection('w', w1) tf.add_to_collection('w', w2) get_w = tf.get_collection('w') add_w = tf.add_n(get_w) init_op = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init_op) print("生成資料的樣結果:") print
(sess.run(w1)) print(sess.run(w2)) print("放入集合w資料的結果:") print(sess.run(get_w)) print("集合w中的資料相加的記過") print(sess.run(add_w))
執行結果如下:
生成資料的樣結果:
[ 0.00524002  0.0679507  -0.0160088  -0.04484038]
[0.1 0.1 0.1 0.1]
放入集合w資料的結果:
[array([ 0.00524002,  0.0679507 , -0.0160088 , -0.04484038], dtype=float32), array([0.1
, 0.1, 0.1, 0.1], dtype=float32)] 集合w中的資料相加的結果: [0.10524002 0.16795069 0.0839912 0.05515962]