TensorFlow Notes (1): Understanding tf.control_dependencies and control_flow_ops.with_dependencies
阿新 • Published: 2019-01-31
Introduction
When implementing neural networks we often see tf.control_dependencies being used. But what exactly does this function do, and when should we use it? Let's take a closer look.
Understanding
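As the name suggests, control_dependencies means "control dependencies", so we can guess that this function controls the dependencies between nodes of the computation graph. That is exactly what it does: tf.control_dependencies() controls the computation graph by specifying the order in which certain nodes are executed. Concretely, it adds control edges, so an op can be forced to run after another op even when no data flows between the two.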
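To make the idea concrete, here is a minimal sketch, assuming TensorFlow is installed and using the tf.compat.v1 graph-mode API (available in TF 1.14+ and 2.x). The variable `counter` and the op names are illustrative, not taken from the original post. The read created inside the context must wait for the increment, so every run observes the freshly incremented value.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Illustrative counter variable and an op that increments it.
    counter = tf.compat.v1.Variable(0, dtype=tf.int32, name="counter")
    increment = counter.assign_add(1, read_value=False)  # returns an Operation

    with tf.control_dependencies([increment]):
        # read_value() creates a new read op inside the context,
        # so this read cannot start until `increment` has finished.
        ordered_read = counter.read_value()

    init = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)
    # Fetching ordered_read also runs its control input `increment` first.
    results = [int(sess.run(ordered_read)) for _ in range(3)]

print(results)  # [1, 2, 3]
```

Without the `with` block, a read of `counter` fetched together with `increment` would have no guaranteed order relative to it.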
Signature
tf.control_dependencies(control_inputs)   (a wrapper for Graph.control_dependencies(self, control_inputs) on the default graph)
Args:
  control_inputs: A list of `Operation` or `Tensor` objects
    which must be executed or computed before running the operations
    defined in the context. (Note that control_inputs is a list.)
Returns:
  A context manager that specifies control dependencies
  for all operations constructed within the context.
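From this description we can see that the argument control_inputs is a list of Operation or Tensor objects, and the return value is a context manager that governs the dependencies of the operations inside that context. In other words, the operations defined under the context manager depend on the operations in control_inputs: control_dependencies ensures that the operations in control_inputs have run before any operation defined in the context is executed. Only operations created inside the with block are affected; a tensor created elsewhere and merely referenced inside the block does not acquire the dependency.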
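The following sketch checks which ops actually receive the dependency by inspecting `Operation.control_inputs`. It assumes the same tf.compat.v1 graph-mode setup, and all op names are illustrative. It shows three points: only ops created inside the block are affected, a Tensor in control_inputs is replaced by the op that produces it, and nested contexts accumulate their dependencies.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    x = tf.constant(1.0, name="x")
    first = tf.no_op(name="first")            # an Operation
    second = tf.constant(2.0, name="second")  # a Tensor; its producing op is used

    before = tf.identity(x, name="before")    # created before entering the context

    with tf.control_dependencies([first]):
        inside = tf.identity(x, name="inside")
        with tf.control_dependencies([second]):  # nested contexts accumulate
            nested = tf.identity(x, name="nested")

    after = tf.identity(x, name="after")      # created after the context exited


def control_input_names(tensor):
    """Sorted names of the ops that `tensor`'s op must wait for."""
    return sorted(op.name for op in tensor.op.control_inputs)


for t in (before, inside, nested, after):
    print(t.op.name, control_input_names(t))
# before []
# inside ['first']
# nested ['first', 'second']
# after []
```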
Example 1
If we want to make sure we always fetch the parameters after they have been updated, we can organize our code like this:
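import tensorflow as tf  # TensorFlow 1.x API

# `loss` and `weight` are defined elsewhere; tf.train.Optimizer itself is an abstract
# base class, so a concrete optimizer (here GradientDescentOptimizer) is an illustrative choice.
opt = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(loss)
with tf.control_dependencies([opt]):  # run opt first
    # tf.identity creates a new op inside the context, so it picks up the dependency
    updated_weight = tf.identity(weight)  # then run this op
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    sess.run(updated_weight, feed_dict={...})  # each run returns the weight after the update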
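Example 1 leaves `loss` and `weight` undefined, so here is a self-contained variant under stated assumptions: a hypothetical one-weight model with loss = weight ** 2, trained by plain gradient descent at learning rate 0.1 (both illustrative choices), using the tf.compat.v1 graph-mode API. Since the gradient is 2 * weight, each step multiplies the weight by 0.8, and the fetched value is always the post-update weight.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Hypothetical toy model: a single scalar weight and loss = weight ** 2.
    weight = tf.compat.v1.Variable(1.0, name="weight")
    loss = tf.square(weight)
    opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.1).minimize(
        loss, var_list=[weight])

    with tf.control_dependencies([opt]):
        # This read of `weight` is created inside the context,
        # so it happens only after the gradient step.
        updated_weight = tf.identity(weight)

    init = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)
    # Each step: w <- w - 0.1 * 2w = 0.8 * w; the fetched value is the new w.
    history = [float(sess.run(updated_weight)) for _ in range(3)]

print(history)  # approximately [0.8, 0.64, 0.512] (float32 rounding)
```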
control_flow_ops.with_dependencies
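Besides the commonly used tf.control_dependencies(), we will also come across control_flow_ops.with_dependencies(). Both functions can enforce dependencies; they just do it differently. tf.control_dependencies() is a context manager that applies to every op created inside it, whereas with_dependencies() takes a single output_tensor and returns a new tensor with the same value that becomes available only after the dependencies have run.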
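A minimal sketch comparing the two styles, assuming the internal module tensorflow.python.ops.control_flow_ops is still importable in your TensorFlow version (it is not part of the public API) and using illustrative names. In graph mode, with_dependencies essentially wraps output_tensor in an identity op created inside a control_dependencies context, which is why both fetches below return the same value and each one triggers the increment.

```python
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops  # internal module, as in the post

graph = tf.Graph()
with graph.as_default():
    counter = tf.compat.v1.Variable(0, dtype=tf.int32, name="counter")
    increment = counter.assign_add(1, read_value=False)
    value = tf.constant(10, name="value")

    # Style 1: a context manager plus an explicit identity op.
    with tf.control_dependencies([increment]):
        out_a = tf.identity(value, name="out_a")

    # Style 2: a single call that returns the gated tensor.
    out_b = control_flow_ops.with_dependencies([increment], value, name="out_b")

    counter_value = counter.read_value()  # plain read, no dependency attached
    init = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)
    a = int(sess.run(out_a))  # runs increment, then returns 10
    b = int(sess.run(out_b))  # runs increment again, then returns 10
    count = int(sess.run(counter_value))

print(a, b, count)  # 10 10 2
```

One consequence, and the reason for the caveat in the with_dependencies docstring: only the new identity op is gated. If output_tensor is a tensor computed outside the call, such as a variable read created earlier, that tensor may still be evaluated before the dependencies run and hand back a stale value.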
with_dependencies(dependencies, output_tensor, name=None)
Produces the content of `output_tensor` only after `dependencies`.
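(In short: the returned tensor carries the value of output_tensor and becomes available only after every op in dependencies has finished.)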
In some cases, a user may want the output of an operation to be
consumed externally only after some other dependencies have run
first. This function ensures that it returns `output_tensor`, but only after all
operations in `dependencies` have run. Note that this means that there is
no guarantee that `output_tensor` will be evaluated after any `dependencies`
have run.
See also `tf.tuple` and `tf.group`.
Args:
  dependencies: Iterable of operations to run before this op finishes.
  output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
  name: (Optional) A name for this operation.
Returns:
  Same as `output_tensor`.
Raises:
  TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
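The docstring points to tf.tuple and tf.group. Here is a short sketch of how they differ, under the same tf.compat.v1 graph-mode assumptions and with illustrative variables: tf.group returns a single Operation that has no value, while tf.tuple returns tensors whose values are unchanged but which become available only after all the inputs (and any control_inputs) are done.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    a = tf.compat.v1.Variable(0, dtype=tf.int32, name="a")
    b = tf.compat.v1.Variable(0, dtype=tf.int32, name="b")
    inc_a = a.assign_add(1, read_value=False)
    inc_b = b.assign_add(10, read_value=False)

    # tf.group: one Operation that completes once both updates have run.
    grouped = tf.group(inc_a, inc_b)

    # tf.tuple: same values as the inputs, gated on all inputs and on control_inputs.
    x = tf.constant(1, name="x")
    y = tf.constant(2, name="y")
    gated_x, gated_y = tf.tuple([x, y], control_inputs=[grouped])

    read_a, read_b = a.read_value(), b.read_value()
    init = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)
    group_result = sess.run(grouped)       # fetching an Operation returns None
    values = sess.run([gated_x, gated_y])  # runs `grouped` once more first
    counts = sess.run([read_a, read_b])

print(group_result, [int(v) for v in values], [int(c) for c in counts])
# None [1, 2] [2, 20]
```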
Example 2
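import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

# Fetch everything registered in the UPDATE_OPS collection
# (e.g. batch-norm moving-average updates); the result is a list
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# ......
# model_deploy is the multi-clone deployment helper from the TF-Slim model library
total_loss, clones_gradients = model_deploy.optimize_clones(
    clones,
    optimizer,
    var_list=variables_to_train)
# ......
# tf.group() combines several tensors or ops into a single op that runs them all
update_op = tf.group(*update_ops)
# train_tensor has the value of total_loss, but is produced only after update_op has run
train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
                                                  name='train_op')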
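Example 2 relies on TF-Slim's model_deploy and on elided code, so it cannot run on its own. The sketch below reproduces only the dependency pattern, with hypothetical stand-ins: two variable updates registered in UPDATE_OPS and a constant in place of total_loss (tf.compat.v1 graph mode assumed). Fetching train_tensor returns the loss value, and every fetch also runs all the update ops.

```python
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops  # internal module, as in Example 2

graph = tf.Graph()
with graph.as_default():
    # Hypothetical stand-ins for ops a real model registers in UPDATE_OPS.
    moving_mean = tf.compat.v1.Variable(0.0, name="moving_mean")
    steps_seen = tf.compat.v1.Variable(0, dtype=tf.int32, name="steps_seen")
    tf.compat.v1.add_to_collection(
        tf.compat.v1.GraphKeys.UPDATE_OPS,
        moving_mean.assign_add(0.5, read_value=False))
    tf.compat.v1.add_to_collection(
        tf.compat.v1.GraphKeys.UPDATE_OPS,
        steps_seen.assign_add(1, read_value=False))

    # Hypothetical stand-in for the loss returned by model_deploy.optimize_clones.
    total_loss = tf.constant(3.0, name="total_loss")

    update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
    update_op = tf.group(*update_ops)
    train_tensor = control_flow_ops.with_dependencies(
        [update_op], total_loss, name="train_op")

    reads = [moving_mean.read_value(), steps_seen.read_value()]
    init = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)
    # Each fetch of train_tensor runs both update ops before returning the loss.
    losses = [float(sess.run(train_tensor)) for _ in range(2)]
    mean, steps = sess.run(reads)

print(losses, float(mean), int(steps))  # [3.0, 3.0] 1.0 2
```

This is the point of the pattern: fetching total_loss directly would not run the update ops, whereas fetching train_tensor does.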