TensorFlow儲存和載入訓練模型
阿新 • • 發佈:2018-12-06
儲存:使用saver.save()方法儲存
載入:使用saver.restore()方法載入
下面是個完整例子:
儲存:
import tensorflow as tf W = tf.Variable([[1, 1, 1], [2, 2, 2]], dtype=tf.float32, name='w') b = tf.Variable([[0, 1, 2]], dtype=tf.float32, name='b') saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) save_path = saver.save(sess, r"D:\test\wb") # 將W、b儲存到指定位置
載入:
import tensorflow as tf W = tf.Variable(tf.truncated_normal(shape=(2, 3)), dtype=tf.float32, name='w') b = tf.Variable(tf.truncated_normal(shape=(1, 3)), dtype=tf.float32, name='b') saver = tf.train.Saver() with tf.Session() as sess: saver.restore(sess, r"D:\test\wb") # 從指定位置載入模型 print(sess.run(W)) print(sess.run(b)) """ 輸出: [[1. 1. 1.] [2. 2. 2.]] [[0. 1. 2.]] """
就算W和b定義了不同於模型的值,但是仍會輸出載入模型的值,如:
import tensorflow as tf W = tf.Variable([[0,0,0],[0,0,0]],dtype = tf.float32,name='w') b = tf.Variable([[0,0,0]],dtype = tf.float32,name='b') saver = tf.train.Saver() with tf.Session() as sess: saver.restore(sess, r"D:\test\wb") print(sess.run(W)) print(sess.run(b)) """ 輸出: [[1. 1. 1.] [2. 2. 2.]] [[0. 1. 2.]] """
這種方法不方便的在於,在使用模型的時候,必須把模型的結構重新定義一遍,然後載入對應名字的變數的值。