tensorflow 恢復(restore)模型的兩種方式
0.前言
首先我們要理解TensorFlow的一個規則,首先構建計算圖(graph),然後初始化graph中的data,這兩步是分開的。
1.如何恢復模型
有兩種方式(這兩種方式有比較大的不同):
1.1 重新使用程式碼構建圖
舉個例子(完整程式碼):
def build_graph():
w1 = tf.Variable([1,3,10,15],name='W1',dtype=tf.float32)
w2 = tf.Variable([3,4,2,18],name='W2',dtype=tf.float32)
w3 = tf.placeholder(shape=[4],dtype=tf.float32,name='W3')
w4 = tf.Variable([100,100,100,100],dtype=tf.float32,name='W4')
add = tf.add(w1,w2,name='add')
add1 = tf.add(add,w3,name='add1')
return w3,add1
with tf.Session() as sess:
ckpt_state = tf.train.get_checkpoint_state('./temp/')
if ckpt_state:
w3,add1=build_graph()
saver = tf.train.Saver()
saver.restore(sess, ckpt_state.model_checkpoint_path)
else:
w3,add1=build_graph()
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(init_op)
saver = tf.train.Saver()
a = sess.run(add1,feed_dict={
w3:[1,2,3,4]
})
print(a)
saver.save(sess,'./temp/model'
上面的流程很簡單,首先build_graph(),然後如果有ckpt檔案就從該檔案中讀取資料,否則用sess.run(init_op)初始化資料。
那麼第一種restore方法就出來了:
build_graph()
saver = tf.train.Saver()
saver.restore(sess, ckpt_state.model_checkpoint_path)
首先build graph,等於是將圖重新建立了一遍,和之前圖的一樣,然後將ckpt檔案裡的資料restore到圖裡的變數裡。
當然,在build graph的過程中,你可以在原有的圖裡加一些變數,但是加的變數一定要初始化,但是要注意到一個問題,如果使用:
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(init_op)
這種方式時,如果定義init_op時的graph中已經存在原有圖的變數,那麼sess.run(init_op)會將載入進來的資料清空。
為了解決這個問題,兩種方式:
新定義的變數放在init_op之前,在init_op之後restore(注意,載入好變數後才run(init_op)同樣會覆蓋)
即,init_op得到當前圖中的所有變數,sess.run(init_op)對init_op中的變數進行初始化,所以什麼時候定義init_op和什麼時候執行run(init_op)都很重要
只初始化未初始化的變數
def get_uninitialized_variables(sess):
global_vars = tf.global_variables()
# print([str(i.name) for i in global_vars])
is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]
print([str(i.name) for i in not_initialized_vars])
return not_initialized_vars
sess.run(tf.variables_initializer(get_uninitialized_variables(sess)))
PS:注意saver = tf.train.Saver()要定義在圖構建完成之後
即將被restore的變數不用初始化,但是隻有在restore之後,這些變數才會被初始化,所以在restore之前執行這些值會報沒有初始化的錯。
1.2 利用儲存的.meta檔案恢復圖
上面的方式適用於斷點續訓,且自己有構建圖的完整程式碼,如果我要用別人的網路(fine tune),或者在自己原有網路上修改(即修改原有網路的某個部分),那麼將網路的圖重新構建一遍會很麻煩,那麼我們可以直接從.meta檔案中載入網路結構。
1.2.1 get_tensor_by_name
完整程式碼:
def build_graph():
w1 = tf.Variable([1,3,10,15],name='W1',dtype=tf.float32)
w2 = tf.Variable([3,4,2,18],name='W2',dtype=tf.float32)
w3 = tf.placeholder(shape=[4],dtype=tf.float32,name='W3')
w4 = tf.Variable([100,100,100,100],dtype=tf.float32,name='W4')
add = tf.add(w1,w2,name='add')
add1 = tf.add(add,w3,name='add1')
return w3,add1
with tf.Session() as sess:
ckpt_state = tf.train.get_checkpoint_state('./temp/')
if ckpt_state:
saver = tf.train.import_meta_graph('./temp/model.meta')
graph = tf.get_default_graph()
w3 = graph.get_tensor_by_name('W3:0')
add1 = graph.get_tensor_by_name('add1:0')
saver.restore(sess, tf.train.latest_checkpoint('./temp/'))
print(sess.run(tf.get_collection('w1')[0]))
else:
w3,add1=build_graph()
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(init_op)
saver = tf.train.Saver()
a = sess.run(add1,feed_dict={
w3:[1,2,3,4]
})
print(a)
saver.save(sess,'./temp/model')
上面使用了import_meta_graph()來載入圖,並用restore給變數賦值。
通過get_tensor_by_name來獲取儲存的圖中的op或變數,之後可以對獲取的值進行操作,如果之後save的話,也會將import_meta_graph()中圖引用的部分儲存下來。
1.2.2
def build_graph():
w1 = tf.Variable([1,3,10,15],name='W1',dtype=tf.float32)
w2 = tf.Variable([3,4,2,18],name='W2',dtype=tf.float32)
w3 = tf.placeholder(shape=[4],dtype=tf.float32,name='W3')
w4 = tf.Variable([100,100,100,100],dtype=tf.float32,name='W4')
add = tf.add(w1,w2,name='add')
add1 = tf.add(add,w3,name='add1')
tf.add_to_collection('w1','W1:0')
tf.add_to_collection('w3',w3)
tf.add_to_collection('add1',add1)
return w3,add1
with tf.Session() as sess:
ckpt_state = tf.train.get_checkpoint_state('./temp/')
if ckpt_state:
saver = tf.train.import_meta_graph('./temp/model.meta')
w3 = tf.get_collection('w3')[0]
add1 = tf.get_collection('add1')[0]
# run init_op before restore
saver.restore(sess, tf.train.latest_checkpoint('./temp/'))
else:
w3,add1=build_graph()
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(init_op)
saver = tf.train.Saver()
a = sess.run(add1,feed_dict={
w3:[1,2,3,4]
})
print(a)
saver.save(sess,'./temp/model')
通過import_meta_graph引進圖,通過get_collection獲得變數,其實和get_tensor_by_name差不多,但是可能會更方便一點。
2 總結
總的來說,兩種方式都是先構造好圖,然後通過restore來給圖裡的變數賦值。
一個常見的問題是,要引入新的變數,對以前的圖進行改造,那麼如何初始化新的變數且不覆蓋原來的資料?
可以先啥都不管把所有的圖相關的部分構造好後,得到init_op,然後在restore前run(init_op)
對未初始化的變數進行初始化