Variable和get_variable的用法以及區別
在tensorflow中,可以使用tf.Variable來創建一個變量,也可以使用tf.get_variable來創建一個變量,但是在一個模型需要使用其他模型的變量時,tf.get_variable就派上大用場了。
先分別介紹兩個函數的用法:
1 import tensorflow as tf 2 var1 = tf.Variable(1.0,name=‘firstvar‘) 3 print(‘var1:‘,var1.name) 4 var1 = tf.Variable(2.0,name=‘firstvar‘) 5 print(‘var1:‘,var1.name) 6 var2 = tf.Variable(3.0)7 print(‘var2:‘,var2.name) 8 var2 = tf.Variable(4.0) 9 print(‘var2:‘,var2.name) 10 get_var1 = tf.get_variable(name=‘firstvar‘,shape=[1],dtype=tf.float32,initializer=tf.constant_initializer(0.3)) 11 print(‘get_var1:‘,get_var1.name) 12 get_var1 = tf.get_variable(name=‘firstvar1‘,shape=[1],dtype=tf.float32,initializer=tf.constant_initializer(0.4))13 print(‘get_var1:‘,get_var1.name) 14 15 with tf.Session() as sess: 16 sess.run(tf.global_variables_initializer()) 17 print(‘var1=‘,var1.eval()) 18 print(‘var2=‘,var2.eval()) 19 print(‘get_var1=‘,get_var1.eval())
結果如下:
我們來分析一下代碼,tf.Varibale是以定義的變量名稱為唯一標識的,如var1,var2,所以可以重復地創建name=‘firstvar‘的變量,但是tensorflow會給它們按順序取後綴,如firstvar_1:0,firstval_2:0,...,如果沒有制定名字,系統會自動加上一個名字Variable:0。而且由於tf.Varibale是以定義的變量名稱為唯一標識的,所以當第二次命名同一個變量名時,第一個變量就會被覆蓋,所以var1由1.0變成2.0。
對於tf.get_variable,它是以指定的name屬性為唯一標識,而不是定義的變量名稱,所以不能同時定義兩個變量name是相同的,例如下面這種就會報錯:
1 get_var1 = tf.get_variable(name=‘a‘,shape=[1],dtype=tf.float32,initializer=tf.constant_initializer(0.3)) 2 print(‘get_var1:‘,get_var1.name) 3 get_var2 = tf.get_variable(name=‘a‘,shape=[1],dtype=tf.float32,initializer=tf.constant_initializer(0.4)) 4 print(‘get_var1:‘,get_var1.name)
這樣就會報錯了。如果我們想聲明兩次相同name的變量,這時variable_scope就派上用場了,可以使用variable_scope將它們分開:
1 import tensorflow as tf 2 with tf.variable_scope(‘test1‘): 3 get_var1 = tf.get_variable(name=‘firstvar‘,shape=[2],dtype=tf.float32) 4 with tf.variable_scope(‘test2‘): 5 get_var2 = tf.get_variable(name=‘firstvar‘,shape=[2],dtype=tf.float32) 6 print(‘get_var1:‘,get_var1.name) 7 print(‘get_var2:‘,get_var2.name)
這樣就不會報錯了,variable_scope相當於聲明了作用域,這樣在不同的作用域存在相同的變量就不會沖突了,結果如下:
當然,scope還支持嵌套:
1 import tensorflow as tf 2 with tf.variable_scope(‘test1‘,): 3 get_var1 = tf.get_variable(name=‘firstvar‘,shape=[2],dtype=tf.float32) 4 with tf.variable_scope(‘test2‘,): 5 get_var2 = tf.get_variable(name=‘firstvar‘,shape=[2],dtype=tf.float32) 6 print(‘get_var1:‘,get_var1.name) 7 print(‘get_var2:‘,get_var2.name)
輸出結果為:
怎麽樣?可以對照上面的結果體會一下不同。那麽如何通過get_variable來實現變量共享呢?這就要用到variable_scope裏的一個屬性:reuse,顧名思義嘛,當把reuse設置成True時就可以了,它表示使用已經定義過的變量,這是get_variable就不會再創建新的變量,而是去找與name相同的變量:
import tensorflow as tf with tf.variable_scope(‘test1‘,): get_var1 = tf.get_variable(name=‘firstvar‘,shape=[2],dtype=tf.float32) with tf.variable_scope(‘test2‘,): get_var2 = tf.get_variable(name=‘firstvar‘,shape=[2],dtype=tf.float32) print(‘get_var1:‘,get_var1.name) print(‘get_var2:‘,get_var2.name) with tf.variable_scope(‘test1‘,reuse=True): get_var3 = tf.get_variable(name=‘firstvar‘,shape=[2],dtype=tf.float32) with tf.variable_scope(‘test2‘,): get_var4 = tf.get_variable(name=‘firstvar‘,shape=[2],dtype=tf.float32) print(‘get_var3:‘,get_var3.name) print(‘get_var4:‘,get_var4.name)
輸出結果如下:
當然前面說過,reuse=True是使用前面已經創建過的變量,如果僅僅只有從第八行到最後的代碼,也會報錯的,如果還是想這麽做,就需要把reuse屬性設置成tf.AUTO_REUSE
1 import tensorflow as tf 2 with tf.variable_scope(‘test1‘,reuse=tf.AUTO_REUSE): 3 get_var3 = tf.get_variable(name=‘firstvar‘,shape=[2],dtype=tf.float32) 4 with tf.variable_scope(‘test2‘,): 5 get_var4 = tf.get_variable(name=‘firstvar‘,shape=[2],dtype=tf.float32) 6 print(‘get_var3:‘,get_var3.name) 7 print(‘get_var4:‘,get_var4.name)
此時就不會報錯,tf.AUTO_REUSE可以實現第一次調用variable_scope時,傳入的reuse值為False,再次調用時,傳入reuse的值就會自動變為True。
Variable和get_variable的用法以及區別