TensorFlow的Saver儲存類
一、Saver的介紹
有時可能只需要儲存或者載入部分變數。
比如,可能有一個之前訓練好的5層神經網路模型,但現在想寫一個6層的神經網路,那麼可以將之前5層神經網路中的引數直接載入到新的模型,而僅僅將最後一層神經網路重新訓練。
為了儲存或者載入部分變數,在宣告tf.train.Saver類時可以提供一個列表來指定需要儲存或者載入的變數。比如在載入模型的程式碼中使用saver = tf.train.Saver([v1])命令來構建tf.train.Saver類,那麼只有變數v1會被載入進來。
1、tf.train.Saver.save()方法儲存模型
模型格式:
tf.train.Saver.save(sess, save_path, global_step=None , latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True, write_state=True)
'''
引數說明:
sess: 用於儲存變數操作的會話。
save_path: String型別,用於指定訓練結果的儲存路徑。
global_step: 如果提供的話,這個數字會新增到save_path後面,用於構建checkpoint檔案。這個引數有助於我們區分不同訓練階段的結果。
'''
2、tf.train.Saver.restore方法提取模型
tf.train.Saver.restore(sess, save_path)
'''
引數說明:
sess: 用於載入變數操作的會話。
save_path: 同儲存模型是用到的的save_path引數。
'''
3、Tensorflow模型是什麼?
Tensorflow模型主要包含網路的設計或者圖(graph),和我們已經訓練好的網路引數的值。因此Tensorflow模型有兩個主要的檔案:
A) Meta graph:
這是一個儲存完整Tensorflow graph的protocol buffer,比如說,所有的 variables, operations, collections等等。這個檔案的字尾是 .meta 。
B) Checkpoint file:
這是一個包含所有權重(weights),偏置(biases),梯度(gradients)和所有其他儲存的變數(variables)的二進位制檔案。它包含兩個檔案:
mymodel.data-00000-of-00001
mymodel.index
其中,.data檔案包含了我們的訓練變數。
另外,除了這兩個檔案,Tensorflow有一個叫做checkpoint的檔案,記錄著已經最新的儲存的模型檔案。