如何檢視Tensoflow模型中已儲存的引數
1.儲存和讀取
1.1 儲存
import tensorflow as tf
aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='var')
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(aa))
# Step 1 儲存
saver.save(sess,'./ttt')
>>
[[ 0.8604646 0.45935377 -0.24135743 -2.2841513 ]
[-0.20688622 0.60574555 -0.26031223 -0.441991 ]
[-0.22254886 1.4805079 -1.7360271 1.1423918 ]]
這兒我們定義了一個name=var
的變數(隨便說一句aa
這類名稱是我們寫程式時用以區分各個變數之間的依據,換句話說是給我們自己看的;而var
這個名字是tensorflow計算圖上用來區分各個變數和操作的依據),並且將其進行了儲存。
1.2 讀取
說到讀取,就有兩個方面了:第一,知道引數的名字(上面的var)時之間讀取該變數;第二,不知道引數的名稱時可以先打出所有變數,然後找你所要變數對應的名稱再按名讀取就行。
#----------------------直接按名讀取---------------------------
import tensorflow as tf
aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='var')
saver = tf.train.Saver()
with tf.Session() as sess:
print(sess.run(tf.get_default_graph().get_tensor_by_name('var:0')))
>>
[[ 0.8604646 0.45935377 -0.24135743 -2.2841513 ]
[-0.20688622 0.60574555 -0.26031223 -0.441991 ]
[-0.22254886 1.4805079 -1.7360271 1.1423918 ]]
#----------------------檢視所有變數名---------------------------
import tensorflow as tf
aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='var')
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, './ttt')
var_list = [v.name for v in tf.global_variables()]
print(var_list)
print(sess.run(var_list))
['var:0']
[array([[ 0.8604646 , 0.45935377, -0.24135743, -2.2841513 ],
[-0.20688622, 0.60574555, -0.26031223, -0.441991 ],
[-0.22254886, 1.4805079 , -1.7360271 , 1.1423918 ]],
dtype=float32)]
可以看到讀取變數後的輸出值和儲存時的一樣。
2.哪些變數能夠儲存
其實saver.save()
在儲存引數的時候是有選擇的(我說的選擇不是通過save()引數裡面控制的引數),看例子:
aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='aa')
bb = tf.constant(0,dtype=tf.float32,name='bb')
cc = tf.zeros(shape=[5],dtype=tf.float32,name='cc')
dd = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='cc')
ee = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='ee',trainable=False)
i = 10
saver = tf.train.Saver()
with tf.Session() as sess:
# Step 1 儲存
sess.run(tf.global_variables_initializer())
saver.save(sess,'./ttt')
這兒我們一共定義了6個引數,其中有三個tensor變數(aa,dd,ee)和兩個tensor常量(bb,cc),和一個普通變數,我們來看一下哪些引數儲存成功:
aa = tf.Variable(tf.random_normal(shape=[3, 4], dtype=tf.float32), name='aa')
bb = tf.constant(0,dtype=tf.float32,name='bb')
cc = tf.zeros(shape=[5],dtype=tf.float32,name='cc')
dd = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='cc')
ee = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='ee',trainable=False)
i = 10
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, './ttt')
var_list = [v.name for v in tf.global_variables()]
# print(sess.run(list_before_train))
print(var_list)
>>
['aa:0', 'cc_1:0', 'ee:0']
我們可以看到,這兒只有3個變數被儲存成功,aa,ee,cc_1
。明顯,aa指得就是第1行程式碼定義得變數,ee指得就是第5行程式碼,那麼cc_1指得是第3行還是第4行呢? 指得是第4行,這也印證tensorflow內部是通過name='var'
這個引數來區分的。
由此我們可以得出:saver.save()
只儲存tensor變數,也就是tf.Variable()
定義的變數,其它量包括tensor常量都是不被儲存的。
3.網路模型的引數也能這樣來儲存麼?
答案是:能!
這裡以一個rnn cell按時間維度展開為例:
#--------------------------------------------儲存--------------------------------------
import tensorflow as tf
import numpy as np
output_size = 5
batch_size = 4
time_step = 3
dim = 3
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
inputs = tf.placeholder(dtype=tf.float32, shape=[time_step, batch_size, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
X = np.array([[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]], # x1
[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]], # x2
[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]]]) # x3
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, time_major=True)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# Step 1 儲存
saver.save(sess,'./ttt')
#-------------------------------------------讀取--------------------------------------
import tensorflow as tf
import numpy as np
output_size = 5
batch_size = 4
time_step = 3
dim = 3
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=output_size)
inputs = tf.placeholder(dtype=tf.float32, shape=[time_step, batch_size, dim])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
X = np.array([[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]], # x1
[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]], # x2
[[1, 2, 1], [2, 0, 0], [2, 1, 0], [1, 1, 0]]]) # x3
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, time_major=True)
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, './ttt')
var_list = [v.name for v in tf.global_variables()]
print(var_list)
>>
['rnn/basic_rnn_cell/kernel:0', 'rnn/basic_rnn_cell/bias:0']
既然都儲存了,那為什麼這兒只有兩個變數呢?那是因為tensorflow內部在計算時為了方便或是更快,把所有的weight和bias都疊在一起了,具體參見此處!
另外說明一下:
在網上看到很多人提問LSTM訓練好的模型“儲存不了”。為什麼會覺得儲存不了呢? 因為在當訓練到某個時候loss已經很低了,當stop後再次載入最新幾個模型時都發現loss急劇升高,因此就會決定是因為模型的引數沒有儲存成功而導致的,因為本人在這兩天也出現了這個問題。於是網上各種搜查LSTM模型儲存的方法,試了一大堆依舊無效,後來終於發現是由於同一個函式在不同平臺(windwo,linux)上的處理結果居然不一樣,導致預處理後的訓練集一直在變而導致的!
另外,你還可以通過在每次儲存LSTM模型時,打印出其中某個引數的具體值,然後手動stop;當你再次載入模型時,立馬輸出同一個變數,對比一下是否相同,如果相同則說明儲存成功。依照我自己的實驗來看,兩者是相同的。
print('------儲存時的值----->')
print(sess.run(tf.get_default_graph().get_tensor_by_name(
'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0'))[:5, :4])
print('載入時的值----------->')
print(sess.run(tf.get_default_graph().get_tensor_by_name(
'rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel:0'))[:5,:4])
# 同一個變數,相同部分的值