[Tensorflow]L2正則化和collection【tf.GraphKeys】
阿新 • • 發佈:2019-01-22
L2-Regularization 實現的話,需要把所有的引數放在一個集合內,最後計算loss時,再減去加權值。
相比自己亂搞,程式碼一團糟,Tensorflow 提供了更優美的實現方法。
一、tf.GraphKeys : 多個包含Variables(Tensor)集合
(1)GLOBAL_VARIABLES:使用tf.get_variable()時,預設會將vairable放入這個集合。
我們熟悉的tf.global_variables_initializer()就是初始化這個集合內的Variables。
Tips: tf.GraphKeys.GLOBAL_VARIABLES == "variable"。即其儲存的是一個字串。import tensorflow as tf sess=tf.Session() a=tf.get_variable("a",[3,3,32,64],initializer=tf.random_normal_initializer()) b=tf.get_variable("b",[64],initializer=tf.random_normal_initializer()) #collections=None等價於 collection=[tf.GraphKeys.GLOBAL_VARIABLES] gv= tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) #tf.get_collection(collection_name)返回某個collection的列表 for var in gv: print(var is a) print(var.get_shape())
(2)自定義集合
想個集合的名字,然後在tf.get_variable時,把集合名字傳給 collection 就好了。
import tensorflow as tf
sess=tf.Session()
a=tf.get_variable("a",shape=[10],collections=["mycollection"]) #不把GLOBAL_VARIABLES加進去,那麼就不在那個集合裡了。
keys=tf.get_collection("mycollection")
for key in keys:
print(key.name)
二、L2正則化
先看看tf.contrib.layers.l2_regularizer(weight_decay)都執行了什麼:
import tensorflow as tf sess=tf.Session() weight_decay=0.1 tmp=tf.constant([0,1,2,3],dtype=tf.float32) """ l2_reg=tf.contrib.layers.l2_regularizer(weight_decay) a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp) """ #**上面程式碼的等價程式碼 a=tf.get_variable("I_am_a",initializer=tmp) a2=tf.reduce_sum(a*a)*weight_decay/2; a3=tf.get_variable(a.name.split(":")[0]+"/Regularizer/l2_regularizer",initializer=a2) tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,a2) #** sess.run(tf.global_variables_initializer()) keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) for key in keys: print("%s : %s" %(key.name,sess.run(key)))
我們很容易可以模擬出tf.contrib.layers.l2_regularizer都做了什麼,不過會讓程式碼變醜。 以下比較完整實現L2 正則化。
import tensorflow as tf
sess=tf.Session()
weight_decay=0.1 #(1)定義weight_decay
l2_reg=tf.contrib.layers.l2_regularizer(weight_decay) #(2)定義l2_regularizer()
tmp=tf.constant([0,1,2,3],dtype=tf.float32)
a=tf.get_variable("I_am_a",regularizer=l2_reg,initializer=tmp) #(3)建立variable,l2_regularizer複製給regularizer引數。
#目測REXXX_LOSSES集合
#regularizer定義會將a加入REGULARIZATION_LOSSES集合
print("Global Set:")
keys = tf.get_collection("variables")
for key in keys:
print(key.name)
print("Regular Set:")
keys = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
for key in keys:
print(key.name)
print("--------------------")
sess.run(tf.global_variables_initializer())
print(sess.run(a))
reg_set=tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) #(4)則REGULARIAZTION_LOSSES集合會包含所有被weight_decay後的引數和,將其相加
l2_loss=tf.add_n(reg_set)
print("loss=%s" %(sess.run(l2_loss)))
"""
此處輸出0.7,即:
weight_decay*sigmal(w*2)/2=0.1*(0*0+1*1+2*2+3*3)/2=0.7
其實程式碼自己寫也很方便,用API看著比較正規。
在網路模型中,直接將l2_loss加入loss就好了。(loss變大,執行train自然會decay)
"""