tensorflow學習筆記--三(Variables: 建立,初始化,儲存,和恢復)
tensorflow學習筆記--三
Variables: 建立,初始化,儲存,和恢復
TensorFlow Variables 是記憶體中的容納 tensor 的快取。這一小節介紹了用它們在模型訓練時(during training)建立、儲存和更新模型引數(model parameters) 的方法。
當訓練模型時,用變數來儲存和更新引數。變數包含張量 (Tensor)存放於記憶體的快取區。建模時它們需要被明確地初始化,模型訓練後它們必須被儲存到磁碟。這些變數的值可在之後模型訓練和分析是被載入。
本文件描述以下兩個TensorFlow類。
建立
當建立一個變數時,你將一個張量
作為初始值傳入建構函式Variable()
注意,所有這些操作符都需要你指定張量的shape。那個形狀自動成為變數的shape。變數的shape通常是固定的,但TensorFlow提供了高階的機制來重新調整其行列數。
# Create two variables. weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="weights") biases = tf.Variable(tf.zeros([200]), name="biases")
呼叫tf.Variable()
新增一些操作(Op, operation)到graph:
- 一個
Variable
操作存放變數的值。 - 一個初始化op將變數設定為初始值。這事實上是一個
tf.assign
操作. - 初始值的操作,例如示例中對
biases
變數的zeros
操作也被加入了graph。
tf.Variable
的返回值是Python的tf.Variable
類的一個例項。
初始化
變數的初始化必須在模型的其它操作執行之前先明確地完成。最簡單的方法就是新增一個給所有變數初始化的操作,並在使用模型之前首先執行那個操作。
你或者可以從檢查點檔案中重新獲取變數值,詳見下文。
使用tf.initialize_all_variables()
# Create two variables.
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
name="weights")
biases = tf.Variable(tf.zeros([200]), name="biases")
...
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()
# Later, when launching the model
with tf.Session() as sess:
# Run the init operation.
sess.run(init_op)
...
# Use the model
...
由另一個變數初始化
你有時候會需要用另一個變數的初始化值給當前變數初始化。由於tf.initialize_all_variables()
是並行地初始化所有變數,所以在有這種需求的情況下需要小心。
用其它變數的值初始化一個新的變數時,使用其它變數的initialized_value()
屬性。你可以直接把已初始化的值作為新變數的初始值,或者把它當做tensor計算得到一個值賦予新變數。
# Create a variable with a random value.
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="weights")
# Create another variable with the same value as 'weights'.
w2 = tf.Variable(weights.initialized_value(), name="w2")
# Create another variable with twice the value of 'weights'
w_twice = tf.Variable(weights.initialized_value() * 0.2, name="w_twice")
自定義初始化
tf.initialize_all_variables()
函式便捷地新增一個op來初始化模型的所有變數。你也可以給它傳入一組變數進行初始化。最簡單的儲存和恢復模型的方法是使用tf.train.Saver
物件。構造器給graph的所有變數,或是定義在列表裡的變數,新增save
和restore
ops。saver物件提供了方法來執行這些ops,定義檢查點檔案的讀寫路徑。
儲存和載入
最簡單的儲存和恢復模型的方法是使用tf.train.Saver
物件。構造器給graph的所有變數,或是定義在列表裡的變數,新增save
和restore
ops。saver物件提供了方法來執行這些ops,定義檢查點檔案的讀寫路徑。
檢查點檔案
變數儲存在二進位制檔案裡,主要包含從變數名到tensor值的對映關係。
當你建立一個Saver
物件時,你可以選擇性地為檢查點檔案中的變數挑選變數名。預設情況下,將每個變數Variable.name
屬性的值。
儲存變數
用tf.train.Saver()
建立一個Saver
來管理模型中的所有變數。
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
..
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
print "Model saved in file: ", save_path
恢復變數
用同一個Saver
物件來恢復變數。注意,當你從檔案中恢復變數時,不需要事先對它們做初始化。
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
print "Model restored."
# Do some work with the model
...
選擇儲存和恢復哪些變數
如果你不給tf.train.Saver()
傳入任何引數,那麼saver將處理graph中的所有變數。其中每一個變數都以變數建立時傳入的名稱被儲存。
有時候在檢查點檔案中明確定義變數的名稱很有用。舉個例子,你也許已經訓練得到了一個模型,其中有個變數命名為"weights"
,你想把它的值恢復到一個新的變數"params"
中。
有時候僅儲存和恢復模型的一部分變數很有用。再舉個例子,你也許訓練得到了一個5層神經網路,現在想訓練一個6層的新模型,可以將之前5層模型的引數匯入到新模型的前5層中。
你可以通過給tf.train.Saver()
建構函式傳入Python字典,很容易地定義需要保持的變數及對應名稱:鍵對應使用的名稱,值對應被管理的變數。
注意:
- 如果需要儲存和恢復模型變數的不同子集,可以建立任意多個saver物件。同一個變數可被列入多個saver物件中,只有當saver的
restore()
函式被執行時,它的值才會發生改變。 - 如果你僅在session開始時恢復模型變數的一個子集,你需要對剩下的變數執行初始化op。詳情請見
tf.initialize_variables()
。
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore only 'v2' using the name "my_v2"
saver = tf.train.Saver({"my_v2": v2})
# Use the saver object normally after that.
...
原文連結: http://tensorflow.org/how_tos/variables/index.html