TensorFlow:儲存/恢復和混合多個模型
目錄
在學習這篇部落格之前,我希望你已經掌握了Tensorflow基本的操作。如果沒有,你可以閱讀這篇入門文章。
為什麼要學習模型的儲存和恢復呢?因為這對於避免資料的混亂無序是至關重要的,特別是在你程式碼中的不同圖。
如何儲存和載入模型
saver類
在不同的會話中,當需要將資料在硬碟上面進行儲存時,那麼我們就可以使用Saver這個類。這個Saver構造類允許你去控制3個目標:
- 目標(The target):這個引數設定目標。在分散式架構的情況下,我們可以指定要計算哪個TF伺服器或者“目標”。
-
圖(The graph):
- 配置(The config):這個引數設定配置。你可以使用 ConfigProto 引數來進行配置Tensorflow。點選這裡,檢視更多資訊。
Saver類可以處理你的圖中元資料和變數資料的儲存和恢復。而我們唯一需要做的是,告訴Saver類我們需要儲存哪個圖和哪些變數。
在預設情況下,Saver類能處理預設圖中包含的所有變數。但是,你也可以去建立很多的Saver類,去儲存你想要的任何子圖。
import tensorflow as tf
# First, you design your mathematical operations
# We are the default graph scope
# Let's design a variable
v1 = tf.Variable(1. , name="v1")
v2 = tf.Variable(2. , name="v2")
# Let's design an operation
a = tf.add(v1, v2)
# Let's create a Saver object
# By default, the Saver handles every Variables related to the default graph
all_saver = tf.train.Saver()
# But you can precise which vars you want to save under which name
v2_saver = tf.train.Saver({"v2": v2})
# By default the Session handles the default graph and all its included variables
with tf.Session() as sess:
# Init v and v2
sess.run(tf.global_variables_initializer())
# Now v1 holds the value 1.0 and v2 holds the value 2.0
# We can now save all those values
all_saver.save(sess, 'data.chkp')
# or saves only v2
v2_saver.save(sess, 'data-v2.chkp')
當你運行了上面的程式之後,如果你去看資料夾,那麼你會發現資料夾中存在了七個檔案(如下)。在接下來的部落格中,我會詳細解釋這些檔案的意義。目前你只需要知道,模型的權重是儲存在 .chkp
檔案中,模型的圖是儲存在 .chkp.meta
檔案中。
├── checkpoint
├── data-v2.chkp.data-00000-of-00001
├── data-v2.chkp.index
├── data-v2.chkp.meta
├── data.chkp.data-00000-of-00001
├── data.chkp.index
├── data.chkp.meta
恢復操作和其它元資料
我想分享的最後一個資訊是,Saver將儲存與圖有關聯的任何元資料。這就意味著,當我們恢復一個模型的時候,我們還同時恢復了所有與圖相關的變數、操作和集合。
當我們恢復一個元模型(restore a meta checkpoint)時,實際上我們執行的操作是將恢復的圖載入到當前的預設圖中。所有當你完成模型恢復之後,你可以在預設圖中訪問載入的任何內容,比如一個張量,一個操作或者集合。
import tensorflow as tf
# Let's laod a previous meta graph in the current graph in use: usually the default graph
# This actions returns a Saver
saver = tf.train.import_meta_graph('results/model.ckpt-1000.meta')
# We can now access the default graph where all our metadata has been loaded
graph = tf.get_default_graph()
# Finally we can retrieve tensors, operations, etc.
global_step_tensor = graph.get_tensor_by_name('loss/global_step:0')
train_op = graph.get_operation_by_name('loss/train_op')
hyperparameters = tf.get_collection('hyperparameters')
恢復權重
請記住,在實際的環境中,真實的權重只能存在於一個會話中。也就是說,restore
這個操作必須在一個會話中啟動,然後將資料權重匯入到圖中。理解恢復操作的最好方法是將它簡單的看做是一種資料初始化操作。
with tf.Session() as sess:
# To initialize values with saved data
saver.restore(sess, 'results/model.ckpt-1000-00000-of-00001')
print(sess.run(global_step_tensor)) # returns 1000
在新圖中匯入預訓練模型
至此,你應該已經明白瞭如何去儲存和恢復一個模型。然而,我們還可以使用一些技巧去幫助你更快的儲存和恢復一個模型。比如:
- 一個圖的輸出能成為另一個圖的輸入嗎?
答案是確定的。但是目前我的做法是先將第一個圖進行儲存,然後在另一個圖中進行恢復。但是這種方案感覺很笨重,我不知道是否有更好的方法。
但是這種方法確實能工作,除非你想要去重新訓練第一個圖。在這種情況下,你需要將輸入的梯度重新輸入到第一張圖中的特定的訓練步驟中。我想你已經被這種複雜的方案給逼瘋了把。:-)
- 我可以在一個圖中混合不同的圖嗎?
答案當然是肯定的,但是你必須非常小心名稱空間。這種方法有一點好處是,簡化了一切。比如,你可以預載入一個VGG-19模型。然後訪問圖中的任何節點,並執行你自己的後續操作,從而訓練一整個完整的模型。
如果你只想微調你自己的節點,那麼你可以在你想要的地方中斷梯度。
import tensorflow as tf
# Load the VGG-16 model in the default graph
vgg_saver = tf.train.import_meta_graph(dir + '/vgg/results/vgg-16.meta')
# Access the graph
vgg_graph = tf.get_default_graph()
# Retrieve VGG inputs
self.x_plh = vgg_graph.get_tensor_by_name('input:0')
# Choose which node you want to connect your own graph
output_conv =vgg_graph.get_tensor_by_name('conv1_2:0')
# output_conv =vgg_graph.get_tensor_by_name('conv2_2:0')
# output_conv =vgg_graph.get_tensor_by_name('conv3_3:0')
# output_conv =vgg_graph.get_tensor_by_name('conv4_3:0')
# output_conv =vgg_graph.get_tensor_by_name('conv5_3:0')
# Stop the gradient for fine-tuning
output_conv_sg = tf.stop_gradient(output_conv) # It's an identity function
# Build further operations
output_conv_shape = output_conv_sg.get_shape().as_list()
W1 = tf.get_variable('W1', shape=[1, 1, output_conv_shape[3], 32], initializer=tf.random_normal_initializer(stddev=1e-1))
b1 = tf.get_variable('b1', shape=[32], initializer=tf.constant_initializer(0.1))
z1 = tf.nn.conv2d(output_conv_sg, W1, strides=[1, 1, 1, 1], padding='SAME') + b1
a = tf.nn.relu(z1)
References: