tf.cond 與 tf.control_dependencies 的控制問題
問題引入
在搜尋tf.cond
的使用方法時,找到了這樣的一個問題:
執行下面的一段tensorflow程式碼:
pred = tf.constant(True)
x = tf.Variable([1])
assign_x_2 = tf.assign(x, [2])
def update_x_2():
with tf.control_dependencies([assign_x_2]):
return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
session.run(tf.initialize_all_variables())
print(y.eval())
從程式碼上看,tf.cond
經過判斷pred
的值對x
進行更新。但實際上無論在pred = Ture 還是 False,輸出的結果都是2,都是pred = tf.constant(True)
的情況。
這是怎麼回事呢?
順序執行
先不進行解釋,有人在回覆中給出了一個可以正確執行的程式碼,看一下有什麼區別:
pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
def update_x_2():
with tf.control_dependencies([tf.assign(x, [2])]):
return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
session.run(tf.initialize_all_variables())
print(y.eval(feed_dict={pred: False})) # ==> [1]
print(y.eval(feed_dict={pred: True })) # ==> [2]
區別也不大,只是把assign_x_2 = tf.assign(x, [2])
這句整體移動到了tf.control_dependencies([tf.assign(x, [2])])
的內部。
給出的解釋是:
如果要讓
tf.cond()
在其中一個分支中執行命令(如分配),你必須在你要傳遞給的函式建立執行副命令的操作。
If you want to perform a side effect (like an assignment) in one of the branches, you must create the op that performs the side effect inside the function that you pass to .
因為在TensorFlow圖中的執行是依次向前流過圖形的,所以在任一分支中引用的所有操作必須在條件進行求值之前執行。這意味著true和false分支都接受對tf.assign()
op 的控制依賴。
Because execution in a TensorFlow graph flows forward through the graph, all operations that you refer to in either branch must execute before the conditional is evaluated. This means that both the true and the false branches receive a control dependency on thetf.assign()
op.
翻譯的可能不夠準確,大意就是assign_x_2 = tf.assign(x, [2])
這句話在tf.cond
已經執行過了,因此無論執行update_x_2
(讓x=2)或lambda: tf.identity(x)
(保持x不變),得到的結果都是x=2
。
這麼來看其實是一個很簡單的問題,定義時不僅定義了模型,也隱含著定義了執行順序。
tf.control_dependencies()
這個函式加不加看起來沒有什麼區別,比如:
import tensorflow as tf
pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
# x_2 = tf.assign(x, [2])
def update_x_2():
# with tf.control_dependencies([x_2]): #[tf.assign(x, [2])]):
return tf.assign(x, [2])
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
session.run(tf.global_variables_initializer())
print(y.eval(feed_dict={pred: False})) # ==> [1]
print(y.eval(feed_dict={pred: True})) # ==> [2]
去掉之後執行結果和正確的相同。具體作用還是看一下官網咖……
直接搜tf.control_dependencies
得到的資訊並不多:
Wrapper for Graph.control_dependencies() using the default graph.
Seetf.Graph.control_dependencies
for more details.
在tf.Graph.control_dependencies
這裡確實講得很詳細,其作用簡單來說就是控制計算順序。
with g.control_dependencies([a, b, c]):
# `d` and `e` will only run after `a`, `b`, and `c` have executed.
d = ...
e = ...
有了這句話,with
中的語句就會在control_dependencies()
中的操作執行之後執行,並且也支援巢狀操作。在給出的錯誤例子中,很像開頭提出的問題:
# WRONG
def my_func(pred, tensor):
t = tf.matmul(tensor, tensor)
with tf.control_dependencies([pred]):
# The matmul op is created outside the context, so no control
# dependency will be added.
return t
# RIGHT
def my_func(pred, tensor):
with tf.control_dependencies([pred]):
# The matmul op is created in the context, so a control dependency
# will be added.
return tf.matmul(tensor, tensor)
上面t
操作在tf.control_dependencies
之前已經被執行了,因此就無法控制t
的先後順序。如果我們把my_func
看作是tf.cond
中的分支操作函式,那麼很可能在pred
更新之前就已經進行了操作,因此可能造成一些錯誤。
總結
這麼一看,好像我自己寫的沒有注意這麼多細節,但目前從結果上看好像還都沒什麼問題,或許需要重新改寫一下。