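[Deep Learning] Saving and Restoring TensorFlow Models and Parameters
阿新 • Published: 2019-02-01
In most cases we save a trained network so that it can be reused later, or so that its parameters can be fine-tuned starting from the current network.
Saving Variables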
import tensorflow as tf
# The file path to save the data
save_file = './model.ckpt'
# Two Tensor Variables: weights and bias
weights = tf.Variable(tf.truncated_normal([2, 3]))
bias = tf.Variable(tf.truncated_normal([3]))
# Class used to save and/or restore Tensor Variables
saver = tf.train.Saver()
with tf.Session() as sess:
    # Initialize all the Variables
    sess.run(tf.global_variables_initializer())
    # Show the values of weights and bias
    print('Weights:')
    print(sess.run(weights))
    print('Bias:')
    print(sess.run(bias))
    # Save the model
    saver.save(sess, save_file)
Weights:
[[-0.97990924 1.03016174 0.74119264]
[-0.82581609 -0.07361362 -0.86653847]]
Bias:
[ 1.62978125 -0.37812829 0.64723819]
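It can help to look at what saver.save actually writes to disk. The sketch below is not part of the original post: it assumes TensorFlow 2.x is installed and reaches the 1.x-style API used in this article through tf.compat.v1. The helper name save_and_list_files and the temporary directory are hypothetical, chosen only for illustration.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TensorFlow 2.x, so the 1.x graph/session API lives under tf.compat.v1.
tf1 = tf.compat.v1
tf1.disable_eager_execution()


def save_and_list_files(directory):
    """Save two variables under directory/model.ckpt and return the file names written."""
    tf1.reset_default_graph()
    weights = tf1.Variable(tf1.truncated_normal([2, 3]), name='weights')
    bias = tf1.Variable(tf1.truncated_normal([3]), name='bias')
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        # The path is a prefix, not a single file: several files share it.
        saver.save(sess, os.path.join(directory, 'model.ckpt'))
    return sorted(os.listdir(directory))


checkpoint_dir = tempfile.mkdtemp()  # hypothetical scratch location
written_files = save_and_list_files(checkpoint_dir)
for file_name in written_files:
    print(file_name)
```

With the default checkpoint format, the listing should contain a small checkpoint bookkeeping file plus files that start with the model.ckpt prefix (an index, one or more data shards, and a .meta graph description). That is why the same prefix, rather than a concrete file name, is passed to saver.restore.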
Restoring Variables
# Remove the previous weights and bias
tf.reset_default_graph()
# Two Variables: weights and bias
weights = tf.Variable(tf.truncated_normal([2, 3]))
bias = tf.Variable(tf.truncated_normal([3]))
# Class used to save and/or restore Tensor Variables
saver = tf.train.Saver()
with tf.Session() as sess:
    # Load the weights and bias
    saver.restore(sess, save_file)
    # Show the values of weights and bias
    print('Weight:')
    print(sess.run(weights))
    print('Bias:')
    print(sess.run(bias))
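To check that a restore really brings back the saved values, the save and restore steps can be run back to back and compared. The following is a minimal sketch rather than the article's own code: it assumes TensorFlow 2.x with the 1.x API under tf.compat.v1, and the helper build_variables and the temporary save path are hypothetical. The variables are given explicit names because the Saver matches checkpoint entries to variables by name.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Assumption: TensorFlow 2.x, so the 1.x graph/session API lives under tf.compat.v1.
tf1 = tf.compat.v1
tf1.disable_eager_execution()

save_file = os.path.join(tempfile.mkdtemp(), 'model.ckpt')  # hypothetical path


def build_variables():
    """Start from an empty graph and recreate both variables with fixed names."""
    tf1.reset_default_graph()
    weights = tf1.Variable(tf1.truncated_normal([2, 3]), name='weights')
    bias = tf1.Variable(tf1.truncated_normal([3]), name='bias')
    return weights, bias


# First graph: initialize randomly, remember the values, then save them.
weights, bias = build_variables()
saver = tf1.train.Saver()
with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    saved_weights, saved_bias = sess.run([weights, bias])
    saver.save(sess, save_file)

# Second graph: no initializer is run; restore supplies the values instead.
weights, bias = build_variables()
saver = tf1.train.Saver()
with tf1.Session() as sess:
    saver.restore(sess, save_file)
    restored_weights, restored_bias = sess.run([weights, bias])

print(np.array_equal(saved_weights, restored_weights))
print(np.array_equal(saved_bias, restored_bias))
```

Note that the second session never calls the global initializer. Running it after saver.restore would overwrite the restored values with fresh random ones.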