1. 程式人生 > >tf.cond, tensorflow下的三目運算子

tf.cond, tensorflow下的三目運算子

# tf.cond(
#     pred,
#     true_fn=None,
#     false_fn=None,
#     strict=False,
#     name=None,
#     fn1=None,
#     fn2=None
# )
# tensorflow下的三目運算子

import tensorflow as tf

x = tf.constant(1.0)
y = tf.constant(2.0)
z = tf.constant(3.0)

def f1():
    return tf.Print(x, [x])

def f2():
    return tf.Print(y, [y])

op = tf.cond(x > y, true_fn=f2, false_fn=f1)
with tf.Session() as sess:
    sess.run(op)

如果pred正確,執行true_fn,否則執行false_fn。