Tensorflow儲存和讀取模型
阿新 • • 發佈:2019-01-24
1.概述
將深度學習應用到工業領域實時處理資料時,我們需要訓練好的模型實時計算,那就需要儲存和讀取模型,tensorflow目前提供了這方面的初步工作。因為tensorflow只能儲存變數而不是儲存整個網路,所以在提取模型時,我們還需要重新第一網路結構。
2.程式碼演示
(1)儲存
import tensorflow as tf
import numpy as np
#儲存時dtype型別要一致,一般使用float32,另外要定義變數名
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1 ,2,3]], dtype=tf.float32, name='biases')
# 初始化所有變數
init = tf.initialize_all_variables()
# 構建儲存模型
saver = tf.train.Saver()
#啟動
with tf.Session() as sess:
sess.run(init)
#定義儲存路徑
save_path = saver.save(sess, "/Users/chunsoft/Desktop/savemodel/save_test.ckpt")
print("Save to path: ", save_path)
(2)讀取
import tensorflow as tf
import numpy as np
# 重新定義相同的變數的dtype和shape
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
# 不需要初始化
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "/Users/chunsoft/Desktop/savemodel/save_test.ckpt" )
print("weights:", sess.run(W))
print("biases:", sess.run(b))
讀取和儲存還是很方便的,期待tensorflow版本更新後,能夠儲存整個網路。