tensorflow儲存和恢復模型saver.restore
阿新 • • 發佈:2021-09-28
1.本文只對一些細節點做補充,大體的步驟就不詳述了
2.儲存模型
① 首先我使用的是tensorflow-gpu 1.4.0
② 這個版本生成的ckpt檔案是這樣的:
其中.meta存放的是網路模型和所有的變數;
.index 和.data一起存放變數資料
-0 -500表示checkpoint點
③ 儲存的配置(一定細看程式碼註釋!!!)
import tensorflow as tf w1 = tf.Variable(變數的初始化, name='w1') w2 = tf.Variable(變數的初始化, name='w2') saver = tf.train.Saver([w1,w2],max_to_keep=5, keep_checkpoint_every_n_hours=2) # 這裡是細節部分,可以指定儲存的變數,每兩小時儲存最近的5個模型 sess = tf.Session() sess.run(tf.global_variables_initializer()) saver.save(sess, './checkpoint_dir/MyModel',global_step=step,write_meta_graph=False)) # 因為模型沒必要多次儲存,所以寫為False
3.恢復模型(一定細看程式碼註釋!!!)
程式碼:
import tensorflow as tf with tf.Session() as sess: saver = tf.train.import_meta_graph(模型路徑) # 模型路徑中必須指定到具體的模型下如:xx.ckpt-500.meta,且一般來講,所有模型都是一樣的,如果沒有改變模型的條件下。 # 下面的restore就是在當前的sess下恢復了所有的變數 saver.restore(sess,資料路徑) # 資料路徑也必須指定到具體某個模型的資料,但建立這個路徑的方法很多,比如呼叫最後一個儲存的模型tf.train.latest_checkpoint('./checkpoint_dir'),也可以是xx.ckpt-500.data,並且這兩個是等效的,如果是xx.ckpt-0.data,就是第一個模型的資料 print(sess.run('w1:0')) # 這裡的w1必須加上:0
————————————————
版權宣告:本文為CSDN博主「做一隻AI小能手」的原創文章,遵循CC 4.0 BY-SA版權協議,轉載請附上原文出處連結及本宣告。
原文連結:https://blog.csdn.net/qq_37285386/article/details/88957558