Tensorflow---Saver和restore的用法
阿新 • • 發佈:2017-05-19
restore val 打印 多個 point == 一次 path 例如
Saver的作用是將我們訓練好的模型的參數保存下來,以便下一次繼續用於訓練或測試;Restore的用法是將訓練好的參數提取出來。
1.Saver類訓練完後,是以checkpoints文件形式保存。提取的時候也是從checkpoints文件中恢復變量。Checkpoints文件是一個二進制文件,它把變量名映射到對應的tensor值 。
2.通過for循環,Saver類可以自動的生成checkpoint文件。這樣我們就可以保存多個訓練結果。例如,我們可以保存每一步訓練的結果。但是為了避免填滿整個磁盤,Saver可以自動的管理Checkpoints文件。例如,我們可以指定保存最近的N個Checkpoints文件。
應用實例:
#保存變量
import tensorflow as tf # 創建兩個變量 v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v_1") v2= tf.Variable(tf.zeros([200]), name="v_2") # 添加用於初始化變量的節點 init_op = tf.global_variables_initializer() # Create a saver. saver = tf.train.Saver(tf.global_variables()) # 運行,保存變量 with tf.Session() as sess: tf.global_variables_initializer().run()for step in range(5000): sess.run(init_op) if step % 1000 == 0:
saver.save(sess,basicpath+‘my-model‘, global_step=step) print("v1 = ", v1.eval()) print("v2 = ", v2.eval()) print_tensors_in_checkpoint_file(basicpath+"my-model-0", None, True) #通過這個方法,我們可以打印出保存了什麽變量和值。
恢復變量:
saver = tf.train.Saver() with tf.Session() as sess: #tf.global_variables_initializer().run() module_file = tf.train.latest_checkpoint(‘C:/Users/defadiannao/Desktop/saver/‘) saver.restore(sess, module_file) #print("w1:", sess.run(v3)) #print("b1:", sess.run(v4)) print("w1:", sess.run(v1)) print("b1:", sess.run(v2))
Tensorflow---Saver和restore的用法