1. 程式人生 > 其它 >tensorflow儲存和恢復模型saver.restore

tensorflow儲存和恢復模型saver.restore

1.本文只對一些細節點做補充,大體的步驟就不詳述了
2.儲存模型
① 首先我使用的是tensorflow-gpu 1.4.0
② 這個版本生成的ckpt檔案是這樣的:

其中.meta存放的是網路模型和所有的變數;
.index 和.data一起存放變數資料
-0 -500表示checkpoint點
③ 儲存的配置(一定細看程式碼註釋!!!)

import tensorflow as tf
w1 = tf.Variable(變數的初始化, name='w1')
w2 = tf.Variable(變數的初始化, name='w2')
saver = tf.train.Saver([w1,w2],max_to_keep=5, keep_checkpoint_every_n_hours=2)   # 這裡是細節部分,可以指定儲存的變數,每兩小時儲存最近的5個模型
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './checkpoint_dir/MyModel',global_step=step,write_meta_graph=False))   # 因為模型沒必要多次儲存,所以寫為False

3.恢復模型(一定細看程式碼註釋!!!)
程式碼:

import tensorflow as tf
with tf.Session() as sess:    
    saver = tf.train.import_meta_graph(模型路徑)  # 模型路徑中必須指定到具體的模型下如:xx.ckpt-500.meta,且一般來講,所有模型都是一樣的,如果沒有改變模型的條件下。
    # 下面的restore就是在當前的sess下恢復了所有的變數
    saver.restore(sess,資料路徑)  # 資料路徑也必須指定到具體某個模型的資料,但建立這個路徑的方法很多,比如呼叫最後一個儲存的模型tf.train.latest_checkpoint('./checkpoint_dir'),也可以是xx.ckpt-500.data,並且這兩個是等效的,如果是xx.ckpt-0.data,就是第一個模型的資料
    print(sess.run('w1:0'))  # 這裡的w1必須加上:0

————————————————
版權宣告:本文為CSDN博主「做一隻AI小能手」的原創文章,遵循CC 4.0 BY-SA版權協議,轉載請附上原文出處連結及本宣告。
原文連結:https://blog.csdn.net/qq_37285386/article/details/88957558