tf.cond()的用法
由於tensorflow使用的是graph的計算概念,在沒有涉及控制資料流向的時候程式設計和普通程式語言的程式設計差別不大,但是涉及到控制資料流向的操作時,就要特別小心,不然很容易出錯。這也是TensorFlow比較反直覺的地方。
在TensorFlow中,tf.cond()類似於c語言中的if...else...,用來控制資料流向,但是僅僅類似而已,其中差別還是挺大的。關於tf.cond()函式的具體操作,我參考了tf的說明文件。
format:tf.cond(pred, fn1, fn2, name=None)
Return :either fn1() or fn2() based on the boolean predicate `pred`.(注意這裡,也就是說'fnq'和‘fn2’是兩個函式)
arguments:`fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have the same non-zero number and type of outputs('fnq'和‘fn2’返回的是非零的且型別相同的輸出)
官方例子:
z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
上面例子執行這樣的操作,如果x<y則result這個操作是tf.add(x,z),反之則是tf.square(y)。這一點上,確實很像邏輯控制中的if...else...,但是官方說明裡也提到
Since z is needed for at least one branch of the cond,branch of the cond, the tf.mul operation is always executed, unconditionally.
因為z在cond函式中的至少一個分支被用到,所以
z = tf.multiply(a, b)
總是被無條件執行,這個確實很反直覺,跟我想象中的不太一樣,按一般的邏輯不應該是不用到就不執行麼?,然後查閱官方文件,我感受到了來之官方文件深深的鄙視0.0
Although this behavior is consistent with the dataflow model of TensorFlow,it has occasionally surprised some users who expected a lazier semantics.
翻譯過來應該是:儘管這樣的操作與TensorFlow的資料流模型一致,但是偶爾還是會令那些期望慵懶語法的使用者吃驚。(應該是這麼翻譯的吧,淦,我就那個懶人0.0)
好吧,我就大概記錄一下我自己的理解(如果錯了,歡迎拍磚)。因為TensorFlow是基於圖的計算,資料以流的形式存在,所以只要構建好了圖,有資料來源,那麼應該都會 資料流過,所以在執行tf.cond之前,兩個資料流一個是tf.add()中的x,z,一個是tf.square(y)中的y,而tf.cond()就決定了是資料流x,z從tf.add()流過,還是資料流y從tf.square()流過。這裡這個tf.cond也就像個控制水流的閥門,水流管道x,z,y在這個閥門交匯,而tf.cond決定了誰將流向後面的管道,但是不管哪一個水流流向下一個管道,在閥門作用之前,水流應該都是要到達閥門的。(囉囉嗦嗦了一大堆,還是不太理解)
栗子:
import tensorflow as tf
a=tf.constant(2)
b=tf.constant(3)
x=tf.constant(4)
y=tf.constant(5)
z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
with tf.Session() as session:
print(result.eval())