tensorflow(三) 模型儲存
阿新 • • 發佈:2019-01-01
tensorflow最簡單的儲存與載入模型的方法是Saver物件(存放在tensorflow.train)。構造器給graph所有的變數,或者定義在列表中的變數,新增save和restore的操作,分別為儲存和載入。變數儲存在二進位制的檔案中,主要包含的是從變數名到tensor值的對映關係。
儲存變數
通過下面的一段程式碼穿件Saver物件來管理模型中的變數(預設情況下是所有的變數,也可以自行選擇)。
import tensorflow as tf
v1 = tf.Variable(tf.random_normal([1,2]), name="v1")
v2 = tf.Variable (tf.random_normal([2,3]), name="v2")
init_op = tf.initialize_all_variables()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
saver_path = saver.save(sess, "/home/yang/data/model.ckpt")
print "Model saved in file: ", saver_path
恢復變數
用同一個Saver物件來恢復變數,注意,當你從檔案恢復變數是,不需要對它進行初始化,否則會報錯。
import tensorflow as tf
v1 = tf.Variable(tf.random_normal([1,2]), name="v1")
v2 = tf.Variable(tf.random_normal([2,3]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "/home/yang/data/model.ckpt")
print "Model restored"