tf.group()用於組合多個操作
tf.group()用於創造一個操作,可以將傳入引數的所有操作進行分組。API手冊如:
tf.group(
*inputs,
**kwargs
)
ops = tf.group(tensor1, tensor2,...)
其中*inputs是0個或者多個用於組合tensor,一旦ops完成了,那麼傳入的tensor1,tensor2,...等等都會完成了,經常用於組合一些訓練節點,如在Cycle GAN中的多個訓練節點,例子如:
generator_train_op = tf.train.AdamOptimizer(g_loss, ...)
discriminator_train_op = tf.train.AdamOptimizer(d_loss,...)
train_ops = tf.groups(generator_train_op ,discriminator_train_op)
with tf.Session() as sess:
sess.run(train_ops)
# 一旦運行了train_ops,那麼裡面的generator_train_op和discriminator_train_op都將被呼叫
注意的是,tf.group()返回的是個操作,而不是值,如果你想下面一樣用,返回的將不是值
a = tf.Variable([5])
b = tf.Variable([6])
c = a+b
d = a*b
e = a/b
ops = tf.group(c,d,e)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
ee = sess.run(ops)
返回的將不是c,d,e的運算結果,而是一個None,就是因為這個是一個操作,而不是一個張量。如果需要返回結果,請參考tf.tuple()
---------------------
作者:FesianXu
來源:CSDN
原文:https://blog.csdn.net/LoseInVain/article/details/81703786
版權宣告:本文為博主原創文章,轉載請附上博文連結!