tf常用集合及其獲取方式
集合
tensorflow用集合colletion
組織不同類別的物件。tf.GraphKeys
中包含了所有預設集合的名稱。
collection
提供了一種“零存整取”的思路:在任意位置,任意層次都可以創造物件,存入相應collection
中;創造完成後,統一從一個collection
中取出一類變數,施加相應操作。
例如,
tf.Optimizer
只優化tf.GraphKeys.TRAINABLE_VARIABLES
中的變數。
本文介紹幾個常用集合
- Variable
集合:模型引數
- Summary
集合:監測
- 自定義集合
Variable
被收集在名為tf.GraphKeys.VARIABLES
colletion
中
定義
Tensorflow使用Variable
類表達、更新、儲存模型引數。
Variable
是在可變更的,具有保持性的記憶體控制代碼,儲存著Tensor
。必須使用Tensor
進行初始化。
k = tf.Variable(tf.random_normal([]), name='k')
建立的Variable
被新增到預設的collection
中。
初始化
在整個session
執行之前,圖中的全部Variable
必須被初始化。
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
在執行完初始化之後,Variable
特別強調:Variable
的值在sess.run(init)之後就確定了;Tensor
的值要在sess.run(x)之後才確定。
獲取
和Tensor
, Operation
一樣,Variable
也是全域性的。
可以通過tf.all_variables()檢視所有tf.GraphKeys.VARIABLES
中的物件:
# example for y = k*x
x = tf.constant(1.0, shape=[]) # 0D tensor
k = tf.Variable(tf.constant(0.5, shape=[]) )
y = tf.mul(x, k)
v = tf.all _variables()
也可以用通用方法直接訪問collection
:
v = tf.get_collection(tf.GraphKeys.VARIABLES)
各類Variable
另外,tensorflow還維護另外幾個collection
:
函式 | 集合名 | 意義 |
---|---|---|
tf.all_variables() | VARIABLES | 儲存和讀取checkpoints時,使用其中所有變數 |
tf.trainable_variables() | TRAINABLE_VARIABLES | 訓練時,更新其中所有變數 |
tf.moving_average_variables() | MOVING_AVERAGE_VARIABLES | ExponentialMovingAverage 物件會生成此類變數 |
tf.local_variables() | LOCAL_VARIABLES | 在all_variables() 之外,需要用tf.init_local_variables()初始化 |
tf.model_variables() | MODEL_VARIABLES |
Summary
被收集在名為tf.GraphKeys.SUMMARIES
的colletion
中
定義
Summary
是對網路中Tensor
取值進行監測的一種Operation
。這些操作在圖中是“外圍”操作,不影響資料流本身。
用例
我們模仿常見的訓練過程,建立一個最簡單的用例。
# 迭代的計數器
global_step = tf.Variable(0, trainable=False)
# 迭代的+1操作
increment_op = tf.assign_add(global_step, tf.constant(1))
# 例項應用中,+1操作往往在`tf.train.Optimizer.apply_gradients`內部完成。
# 建立一個根據計數器衰減的Tensor
lr = tf.train.exponential_decay(0.1, global_step, decay_steps=1, decay_rate=0.9, staircase=False)
# 把Tensor新增到觀測中
tf.scalar_summary('learning_rate', lr)
# 並獲取所有監測的操作`sum_opts`
sum_ops = tf.merge_all_summaries()
# 初始化sess
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init) # 在這裡global_step被賦初值
# 指定監測結果輸出目錄
summary_writer = tf.train.SummaryWriter('/tmp/log/', sess.graph)
# 啟動迭代
for step in range(0, 10):
s_val = sess.run(sum_ops) # 獲取serialized監測結果:bytes型別的字串
summary_writer.add_summary(s_val, global_step=step) # 寫入檔案
sess.run(increment_op) # 計數器+1
呼叫tf.scalar_summary系列函式時,就會向預設的collection
中新增一個Operation
。
再次回顧“零存整取”原則:建立網路的各個層次都可以新增監測;在新增完所有監測,初始化sess之前,統一用tf.merge_all_summaries獲取。
檢視
SummaryWriter檔案中儲存的是序列化的結果,需要藉助TensorBoard才能檢視。
在命令列中執行tensorboard,傳入儲存SummaryWriter檔案的目錄:
tensorboard --logdir /tmp/log
完成後會提示:
You can navigate to http://127.0.1.1:6006
可以直接使用伺服器本地瀏覽器訪問這個地址(本機6006埠),或者使用遠端瀏覽器訪問伺服器ip地址的6006埠。
自定義
除了預設的集合,我們也可以自己創造collection
組織物件。網路損失就是一類適宜物件。
tensorflow中的Loss提供了許多建立損失Tensor
的方式。
x1 = tf.constant(1.0)
l1 = tf.nn.l2_loss(x1)
x2 = tf.constant([2.5, -0.3])
l2 = tf.nn.l2_loss(x2)
建立損失不會自動新增到集合中,需要手工指定一個collection
:
tf.add_to_collection("losses", l1)
tf.add_to_collection("losses", l2)
建立完成後,可以統一獲取所有損失,losses
是個Tensor
型別的list:
losses = tf.get_collection('losses')
另一種常見操作把所有損失累加起來得到一個Tensor
:
loss_total = tf.add_n(losses)
執行操作可以得到損失取值:
sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
losses_val = sess.run(losses)
loss_total_val = sess.run(loss_total)
實際上,如果使用TF-Slim包的losses系列函式建立損失,會自動新增到名為”losses”的collection
中。