Tensorflow模型引數的Saver儲存讀取
阿新 • • 發佈:2018-12-19
一、Saver儲存
import tensorflow as tf import numpy as np #定義W和b W = tf.Variable([[1,2,3],[3,5,6]],dtype = tf.float32,name = 'weight') b = tf.Variable([1,2,3],dtype = tf.float32,name = 'biases') #注:初始化變數Variable init = tf.global_variables_initializer() #建立tf.train.Saver() 來儲存, 提取變數。 #建立my_net資料夾,儲存變數 saver = tf.train.Saver() sess = tf.Session() sess.run(init) #儲存變數到路徑my_net save_path = saver.save(sess,"my_net/save_net.ckpt")#儲存格式為ckpt #輸出儲存的變數 print("save path:",save_path)
結果:
二、Saver讀取
import tensorflow as tf import numpy as np #建立W,b的空容器 W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights") b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases") #不需要初始化變數 saver = tf.train.Saver() with tf.Session() as sess: saver.restore(sess, "my_net/save_net.ckpt") print("weights:", sess.run(W)) print("biases:", sess.run(b))