1. 程式人生 > 程式設計 >TensorFlow Saver:儲存和讀取模型引數.ckpt例項

TensorFlow Saver:儲存和讀取模型引數.ckpt例項

在使用TensorFlow的過程中,儲存模型引數變數是很重要的一個環節,既可以保證訓練過程資訊不丟失,也可以幫助我們在需要快速恢復或使用一個模型的時候,利用之前儲存好的引數之間匯入,可以節省大量的訓練時間。本文通過最簡單的例程教大家如何儲存和讀取.ckpt檔案。

一、儲存到檔案

首先是匯入必要的東西:

import tensorflow as tf
import numpy as np

隨便寫幾個變數:

# Save to file
# remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]],dtype=tf.float32,name='weights')
b = tf.Variable([[1,3]],name='biases')
 
init= tf.initialize_all_variables()

定義一個saver,來儲存我們的各種變數:

saver = tf.train.Saver()

儲存的檔案用.ckpt字尾:

with tf.Session() as sess:
  sess.run(init)
  save_path = saver.save(sess,"my_net/save_net.ckpt")
  print("Save to path: ",save_path)

上面我們就完成了儲存操作。

接下來我們要把之前儲存過的變數取出來。

二、取出之前儲存的變數

這裡要注意,取出時要先開闢一個容器來裝,shape和type要和我們之前儲存的.ckpt一樣。

# restore variables
# redefine the same shape and same type for your variables
W = tf.Variable(np.arange(6).reshape((2,3)),name="weights")
b = tf.Variable(np.arange(3).reshape((1,name="biases")

restore時,不需要進行init= tf.initialize_all_variables()操作。

利用saver提取檔案:

saver = tf.train.Saver()
with tf.Session() as sess:
  saver.restore(sess,"my_net/save_net.ckpt")
  print("weights:",sess.run(W))
  print("biases:",sess.run(b))

結果:

TensorFlow Saver:儲存和讀取模型引數.ckpt例項

以上這篇TensorFlow Saver:儲存和讀取模型引數.ckpt例項就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。