1. 程式人生 > >Variable和get_variable的用法以及區別

Variable和get_variable的用法以及區別

沒有 constant src 分開 true iba 順序 () lse

在tensorflow中,可以使用tf.Variable來創建一個變量,也可以使用tf.get_variable來創建一個變量,但是在一個模型需要使用其他模型的變量時,tf.get_variable就派上大用場了。

先分別介紹兩個函數的用法:

 1 import tensorflow as tf
 2 var1 = tf.Variable(1.0,name=firstvar)
 3 print(var1:,var1.name)
 4 var1 = tf.Variable(2.0,name=firstvar)
 5 print(var1:,var1.name)
 6 var2 = tf.Variable(3.0)
7 print(var2:,var2.name) 8 var2 = tf.Variable(4.0) 9 print(var2:,var2.name) 10 get_var1 = tf.get_variable(name=firstvar,shape=[1],dtype=tf.float32,initializer=tf.constant_initializer(0.3)) 11 print(get_var1:,get_var1.name) 12 get_var1 = tf.get_variable(name=firstvar1,shape=[1],dtype=tf.float32,initializer=tf.constant_initializer(0.4))
13 print(get_var1:,get_var1.name) 14 15 with tf.Session() as sess: 16 sess.run(tf.global_variables_initializer()) 17 print(var1=,var1.eval()) 18 print(var2=,var2.eval()) 19 print(get_var1=,get_var1.eval())

結果如下:

技術分享圖片

我們來分析一下代碼,tf.Varibale是以定義的變量名稱為唯一標識的,如var1,var2,所以可以重復地創建name=‘firstvar‘的變量,但是tensorflow會給它們按順序取後綴,如firstvar_1:0,firstval_2:0,...,如果沒有制定名字,系統會自動加上一個名字Variable:0。而且由於tf.Varibale是以定義的變量名稱為唯一標識的,所以當第二次命名同一個變量名時,第一個變量就會被覆蓋,所以var1由1.0變成2.0。

對於tf.get_variable,它是以指定的name屬性為唯一標識,而不是定義的變量名稱,所以不能同時定義兩個變量name是相同的,例如下面這種就會報錯:

1 get_var1 = tf.get_variable(name=a,shape=[1],dtype=tf.float32,initializer=tf.constant_initializer(0.3))
2 print(get_var1:,get_var1.name)
3 get_var2 = tf.get_variable(name=a,shape=[1],dtype=tf.float32,initializer=tf.constant_initializer(0.4))
4 print(get_var1:,get_var1.name)

這樣就會報錯了。如果我們想聲明兩次相同name的變量,這時variable_scope就派上用場了,可以使用variable_scope將它們分開:

1 import tensorflow as tf
2 with tf.variable_scope(test1):
3     get_var1 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
4 with tf.variable_scope(test2):
5     get_var2 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
6 print(get_var1:,get_var1.name)
7 print(get_var2:,get_var2.name)

這樣就不會報錯了,variable_scope相當於聲明了作用域,這樣在不同的作用域存在相同的變量就不會沖突了,結果如下:

技術分享圖片

當然,scope還支持嵌套:

1 import tensorflow as tf
2 with tf.variable_scope(test1,):
3     get_var1 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
4     with tf.variable_scope(test2,):
5         get_var2 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
6 print(get_var1:,get_var1.name)
7 print(get_var2:,get_var2.name)

輸出結果為:

技術分享圖片

怎麽樣?可以對照上面的結果體會一下不同。那麽如何通過get_variable來實現變量共享呢?這就要用到variable_scope裏的一個屬性:reuse,顧名思義嘛,當把reuse設置成True時就可以了,它表示使用已經定義過的變量,這是get_variable就不會再創建新的變量,而是去找與name相同的變量:

import tensorflow as tf
with tf.variable_scope(test1,):
    get_var1 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
    with tf.variable_scope(test2,):
        get_var2 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
print(get_var1:,get_var1.name)
print(get_var2:,get_var2.name)
with tf.variable_scope(test1,reuse=True):
    get_var3 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
    with tf.variable_scope(test2,):
        get_var4 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
print(get_var3:,get_var3.name)
print(get_var4:,get_var4.name)

輸出結果如下:

技術分享圖片

當然前面說過,reuse=True是使用前面已經創建過的變量,如果僅僅只有從第八行到最後的代碼,也會報錯的,如果還是想這麽做,就需要把reuse屬性設置成tf.AUTO_REUSE

1 import tensorflow as tf
2 with tf.variable_scope(test1,reuse=tf.AUTO_REUSE):
3     get_var3 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
4     with tf.variable_scope(test2,):
5         get_var4 = tf.get_variable(name=firstvar,shape=[2],dtype=tf.float32)
6 print(get_var3:,get_var3.name)
7 print(get_var4:,get_var4.name)

此時就不會報錯,tf.AUTO_REUSE可以實現第一次調用variable_scope時,傳入的reuse值為False,再次調用時,傳入reuse的值就會自動變為True。

Variable和get_variable的用法以及區別