tensorflow 儲存和載入模型 -2
阿新 • • 發佈:2019-01-23
1、
我們經常在訓練完一個模型之後希望儲存訓練的結果,這些結果指的是模型的引數,以便下次迭代的訓練或者用作測試。Tensorflow針對這一需求提供了Saver類。- Saver類提供了向checkpoints檔案儲存和從checkpoints檔案中恢復變數的相關方法。Checkpoints檔案是一個二進位制檔案,它把變數名對映到對應的tensor值。
- 只要提供一個計數器,當計數器觸發時,Saver類可以自動的生成checkpoint檔案。這讓我們可以在訓練過程中儲存多箇中間結果。例如,我們可以儲存每一步訓練的結果。
- 為了避免填滿整個磁碟,Saver可以自動的管理Checkpoints檔案。例如,我們可以指定儲存最近的N個Checkpoints檔案
import tensorflow as tf import numpy as np isTrain = True train_steps = 100 checkpoint_steps = 50 checkpoint_dir = '/home/jdlu/jdluTensor/test/tmp/' x = tf.placeholder(tf.float32, shape=[None, 1]) y = 4 * x + 4 w = tf.Variable(tf.random_normal([1], -1, 1)) b = tf.Variable(tf.zeros([1])) y_predict = w * x + b loss = tf.reduce_mean(tf.square(y - y_predict)) optimizer = tf.train.GradientDescentOptimizer(0.5) train = optimizer.minimize(loss) isTrain = False train_steps = 100 checkpoint_steps = 50 checkpoint_dir = '' saver = tf.train.Saver() # defaults to saving all variables - in this case w and b x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1)) with tf.Session() as sess: sess.run(tf.initialize_all_variables()) if isTrain: for i in xrange(train_steps): sess.run(train, feed_dict={x: x_data}) if (i + 1) % checkpoint_steps == 0: saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) else: ckpt = tf.train.get_checkpoint_state(checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) else: pass print(sess.run(w)) print(sess.run(b))
說明:
訓練的過程:
1、先設定isTrain=True,然後會儲存模型,設定isTrain=False會將訓練好的模型載入進來進行測試
2、train_steps:表示訓練的次數,例子中使用100
3、checkpoint_steps:表示訓練多少次儲存一下checkpoints,例子中使用50
4、checkpoint_dir:表示checkpoints檔案的儲存路徑,例子中使用當前路徑
說明:每訓練checkpoint_steps就儲存一次模型,在訓練的過程中,就可以多次儲存模型。if isTrain: for i in xrange(train_steps): sess.run(train, feed_dict={x: x_data}) if (i + 1) % checkpoint_steps == 0: saver.save(sess, checkpoint_dir + 'model.ckpt',global_step = i+1)
測試的過程:
1、測試的過程就是載入訓練模型好的模型
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
pass
print(sess.run(w))
print(sess.run(b))
說明:
checkpoint的檔案內容:
儲存model的路徑下的檔案內容:
saver.save(sess, checkpoint_dir + 'model.ckpt',global_step = i+1)
每次儲存一次都會相應生成三個檔案,分別是.data-00000-of-00001,.index,.meta
==================================================================================================================