1. 程式人生 > >tf.group()用於組合多個操作

tf.group()用於組合多個操作

tf.group()用於創造一個操作,可以將傳入引數的所有操作進行分組。API手冊如:

tf.group(
    *inputs,
    **kwargs
)

ops = tf.group(tensor1, tensor2,...) 
其中*inputs是0個或者多個用於組合tensor,一旦ops完成了,那麼傳入的tensor1,tensor2,...等等都會完成了,經常用於組合一些訓練節點,如在Cycle GAN中的多個訓練節點,例子如:

generator_train_op = tf.train.AdamOptimizer(g_loss, ...)
discriminator_train_op = tf.train.AdamOptimizer(d_loss,...)
train_ops = tf.groups(generator_train_op ,discriminator_train_op)

with tf.Session() as sess:
  sess.run(train_ops) 
  # 一旦運行了train_ops,那麼裡面的generator_train_op和discriminator_train_op都將被呼叫

注意的是,tf.group()返回的是個操作,而不是值,如果你想下面一樣用,返回的將不是值

a = tf.Variable([5])
b = tf.Variable([6])
c = a+b
d = a*b
e = a/b
ops = tf.group(c,d,e)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ee = sess.run(ops)

返回的將不是c,d,e的運算結果,而是一個None,就是因為這個是一個操作,而不是一個張量。如果需要返回結果,請參考tf.tuple()
--------------------- 
作者:FesianXu 
來源:CSDN 
原文:https://blog.csdn.net/LoseInVain/article/details/81703786 
版權宣告:本文為博主原創文章,轉載請附上博文連結!