1. 程式人生 > >淺談TF的共享變數

淺談TF的共享變數

先說說為什麼需要共享變數。
我們在訓練模型的時候,需要一次次的輸入訓練資料,網路的權重和偏執在一次次的迭代過程中,不斷地修正自身的值,這個迭代過程,我們通常的程式設計思路是這麼做:
conver1_weight=tf.xxx(conver1_weight,…)
我們從兩個方面考慮這麼做的後果:
1,迭代過程被封裝在自己編寫的函式內部(考慮到模組化或者程式碼易讀性需要這麼做),那麼在函式內部的這個變數就是區域性變數,無法影響函式外部的conver1_weight的值,當然我們可以將conver1_weight設定為全域性變數。比如下面的例子:

import tensorflow as
tf import numpy as np global_var=tf.Variable(tf.constant(0.5)) def change_global_var(): global global_var global_var=tf.add(global_var,0.4) return global_var sess=tf.Session() init=tf.global_variables_initializer() sess.run(init) print("global_var=",sess.run(global_var)) tmp=change_global_var(
) print("after add,global_var=",sess.run(global_var))

但是,這麼做會破會工程的封裝性,沒錯,就是這個cao蛋的理由,也是設計和使用共享變數的理由之一,雖然它看起來比什麼共享變數更簡單直觀易用。
這麼做的另一個缺點,和我們說的第二條缺點一樣。接著看:
2.神經網路很少是簡單的,主要是反映在節點的數量和訓練資料的量上。設想我們有一個3層,每層100個節點的網路,而且有10000條訓練資料。這樣的話,就有兩個100x100的方陣資料,每訓練一次,產生一個這樣的資料集(conver1_weight=tf.xxx(conver1_weight,…)會產生一個新的conver1_weight,name和原先的cover1_weight不一樣,大家可以編寫簡單程式碼測試),這時候產生的訓練變數有多少?1000x100x100,而且這還是隻有一個weight引數,加上bias呢?或者如果這是一個複雜的神經網路,有上億個神經元的時候呢?消耗的記憶體無疑是驚人的。怎麼處理這個問題呢?TF的設計者想出了共享變數這個點子,核心思想就是:如果根據name可知該變數存在,那麼使用該變數的值運算,不再建立新的tensor變數。
共享變數的宣告、建立和使用不復雜。下面說明:
第一次宣告共享變數,需要在tf.variable_scope中宣告,指明該共享變數的作用域,類似於其他語言的宣告一個靜態的類成員,該成員只能在類範圍內共享

[程式碼段1]
with tf.variable_scope("scope1"):
    get_var1=tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.3))

如果程式的其他地方需要用到這個共享變數,那麼,也要宣告這段程式和變數屬於上面宣告的作用域scope1,並且宣告引數reuse=True,這時候,才可以用tf.get_variable()來取得該變數。格式如下:

[程式碼段2]
with tf.variable_scope("scope1",reuse=True):
    get_var3=tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.5))

此時,在scope1中不能再用get_variable建立或取得[程式碼段1]沒有的額變數,否則會提示錯誤:Variable scope1/firstvar2 does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?
get_variable會從系統維護的變數列表中查詢name為firstvar的變數,並用get_var3指向該變數,並不會建立新name的新變數(和程式碼1中不一樣,程式碼1中,如果沒有該name的變數,則建立一個)。
當然:resuse=tf.AUTO_REUSE更方便,可以實現第一次reuse=False,第二次自動為True。
完整的簡單演示程式碼如下:

import tensorflow as tf
with tf.variable_scope("scope1"):
    get_var1=tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.3))
    print("get_var1:",get_var1.name)
with tf.variable_scope("scope2"):
    get_var2=tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.4))
    print("get_var2:",get_var2.name)
with tf.variable_scope("scope1",reuse=True):
    get_var3=tf.get_variable("firstvar",[1],initializer=tf.constant_initializer(0.5))
    print("get_var3:",get_var3.name)    
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("get_var1=",get_var1.eval())
    print("get_var2=",get_var2.eval())
    print("get_var3=",get_var3.eval())
    

點選這裡執行