在tensorflow中儲存模型引數
阿新 • • 發佈:2019-01-22
想要儲存訓練之後得到的神經網路引數,一般有兩種辦法。
第一種,可以將tensor物件轉換為numpy陣列進行儲存。
即,
numpy.savetxt('weight.txt', weight.eval())第二種,是利用tensorflow自帶的Saver物件。
import tensorflow as tf ##################################################3 w1 = tf.Variable(tf.constant(1.0), name='w1') w2 = tf.Variable(tf.constant(2.0), name='w2'with tf.Session() as sess: sess.run(tf.global_variables_initializer()) w1 = tf.add(w1, w2) saver.save(sess, './my-model.ckpt')
上面的程式碼中,建立了容器vars。它收集了tensor變數w1和w2。之後,tensorflow將這一容器儲存。
在session中執行,就能將資料儲存到tensorflow建立的幾個檔案中。
上面的程式碼執行結束後,當前目錄下出現四個檔案:
my-model.ckpt.meta
my-model.ckpt.data-*
my-model.ckpt.index
checkpoint
利用這四個檔案就能恢復出 w1和w2這兩個變數。
with tf.Session() as sess: new_saver = tf.train.import_meta_graph('my-model.ckpt.meta') new_saver.restore(sess, tf.train.latest_checkpoint('./')) all_vars = tf.get_collection('vars'執行結果為:
[<tf.Tensor 'w1:0' shape=() dtype=float32_ref>, <tf.Tensor 'w2:0' shape=() dtype=float32_ref>] Tensor("w1:0", shape=(), dtype=float32_ref) w1:0 1.0 Tensor("w2:0", shape=(), dtype=float32_ref) w2:0 2.0