淺談TF的共享變數
先說說為什麼需要共享變數。
我們在訓練模型的時候,需要一次次的輸入訓練資料,網路的權重和偏執在一次次的迭代過程中,不斷地修正自身的值,這個迭代過程,我們通常的程式設計思路是這麼做:
conver1_weight=tf.xxx(conver1_weight,…)
我們從兩個方面考慮這麼做的後果:
1,迭代過程被封裝在自己編寫的函式內部(考慮到模組化或者程式碼易讀性需要這麼做),那麼在函式內部的這個變數就是區域性變數,無法影響函式外部的conver1_weight的值,當然我們可以將conver1_weight設定為全域性變數。比如下面的例子:
import tensorflow as tf
import numpy as np
global_var=tf.Variable(tf.constant(0.5))
def change_global_var():
global global_var
global_var=tf.add(global_var,0.4)
return global_var
sess=tf.Session()
init=tf.global_variables_initializer()
sess.run(init)
print("global_var=",sess.run(global_var))
tmp=change_global_var( )
print("after add,global_var=",sess.run(global_var))
但是,這麼做會破會工程的封裝性,沒錯,就是這個cao蛋的理由,也是設計和使用共享變數的理由之一,雖然它看起來比什麼共享變數更簡單直觀易用。
這麼做的另一個缺點,和我們說的第二條缺點一樣。接著看:
2.神經網路很少是簡單的,主要是反映在節點的數量和訓練資料的量上。設想我們有一個3層,每層100個節點的網路,而且有10000條訓練資料。這樣的話,就有兩個100x100的方陣資料,每訓練一次,產生一個這樣的資料集(conver1_weight=tf.xxx(conver1_weight,…)會產生一個新的conver1_weight,name和原先的cover1_weight不一樣,大家可以編寫簡單程式碼測試),這時候產生的訓練變數有多少?1000x100x100,而且這還是隻有一個weight引數,加上bias呢?或者如果這是一個複雜的神經網路,有上億個神經元的時候呢?消耗的記憶體無疑是驚人的。怎麼處理這個問題呢?TF的設計者想出了共享變數這個點子,核心思想就是:如果根據name可知該變數存在,那麼使用該變數的值運算,不再建立新的tensor變數。
共享變數的宣告、建立和使用不復雜。下面說明:
第一次宣告共享變數,需要在tf.variable_scope中宣告,指明該共享變數的作用域,類似於其他語言的宣告一個靜態的類成員,該成員只能在類範圍內共享
[程式碼段1]
with tf.variable_scope("scope1"):
get_var1=tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.3))
如果程式的其他地方需要用到這個共享變數,那麼,也要宣告這段程式和變數屬於上面宣告的作用域scope1,並且宣告引數reuse=True,這時候,才可以用tf.get_variable()來取得該變數。格式如下:
[程式碼段2]
with tf.variable_scope("scope1",reuse=True):
get_var3=tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.5))
此時,在scope1中不能再用get_variable建立或取得[程式碼段1]沒有的額變數,否則會提示錯誤:Variable scope1/firstvar2 does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?
get_variable會從系統維護的變數列表中查詢name為firstvar的變數,並用get_var3指向該變數,並不會建立新name的新變數(和程式碼1中不一樣,程式碼1中,如果沒有該name的變數,則建立一個)。
當然:resuse=tf.AUTO_REUSE更方便,可以實現第一次reuse=False,第二次自動為True。
完整的簡單演示程式碼如下:
import tensorflow as tf
with tf.variable_scope("scope1"):
get_var1=tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.3))
print("get_var1:",get_var1.name)
with tf.variable_scope("scope2"):
get_var2=tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.4))
print("get_var2:",get_var2.name)
with tf.variable_scope("scope1",reuse=True):
get_var3=tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.5))
print("get_var3:",get_var3.name)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print("get_var1=",get_var1.eval())
print("get_var2=",get_var2.eval())
print("get_var3=",get_var3.eval())