tensorflow之模型集引數的儲存
最近在研究移動端的時候涉及到了模型的冰凍,即將訓練得到的模型生成對應的pb檔案。因此研究了一下tensorflow中的幾種儲存模型的方式,具體如下:
一、save儲存。
save儲存一定要在session中進行,並且save儲存時會儲存所有的引數資訊,而這些資訊是我們不一定需要的。並且save儲存一般儲存的是引數,而不儲存網路結構。對於一個訓練好的網路模型來說,我們儲存引數和儲存結構是同樣重要的。所以save儲存一般並不常用
save的儲存
tf.global_variables_initializer().run() # 初始化所有變數 saver=tf.train.Saver() # 引數為空,預設儲存所有變數 saver=tf.train.Saver([w,b]) # 儲存部分變數 saver.save(sess,logdir+'model.ckpt')
save的使用:
tf.global_variables_initializer().run() # 初始化所有變數 # 驗證之前是否已經儲存了檢查點檔案 ckpt = tf.train.get_checkpoint_state(logdir) if ckpt and ckpt.model_checkpoint_path: try: saver = tf.train.Saver() # 引數為空,預設儲存所有變數,這裡只有變數w1、b1 saver.restore(sess, ckpt.model_checkpoint_path) saver=None except: saver = tf.train.Saver([w1,b1]) # 引數為空,預設儲存所有變數,這裡只有變數w1、b1 saver.restore(sess, ckpt.model_checkpoint_path) saver = None
注意在使用前要進行初始化。主要是對沒有儲存的變數賦值。另外此方法可以對前幾層的變數不變,最後一層變數賦初始值從新訓練。主要在借鑑前人的模型的時候可以使用。
二、write_grape:該方法主要是將圖儲存起來,儲存結果只含圖不含其它任何資料。以後遇到再補充。
三、convert_variabe_to_constants:該方法將圖和資料一起儲存。在儲存時會將圖中的變數取值以常量的形式儲存。在儲存模型時只儲存了GraphDef部分,GraphDef儲存了從輸入層到輸出層的計算過程。在儲存時通過convert_variable_to_constants函式來指定儲存的節點名稱。具體使用方法為:
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def,outpit_node_name=['name']])
with tf.gfile.FastGFile('path/name.pb', mode='wb') as f:
f.write(constant_graph.SerializeToString())
sess.close()
呼叫方法為:
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(filename, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
# 模型執行
with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
匯入之後:首先通過下面的語句來進行張量的名稱。然後將張量的名稱傳給具體的節點。注意後面必須是0
input_x = sess.graph.get_tensor_by_name("input:0")
keep_prob = sess.graph.get_tensor_by_name("keep_prob:0")
最後連結幾個關於儲存的介紹
https://blog.csdn.net/c2a2o2/article/details/72778628
https://blog.csdn.net/sinat_29957455/article/details/78511119
https://blog.csdn.net/wc781708249/article/details/78039029