1. 程式人生 > >tensorflow凍結變量方法(tensorflow freeze variable)

tensorflow凍結變量方法(tensorflow freeze variable)

常用 原來 定義變量 針對 spa art 名詞 splay res

最近由於項目需要,要對tensorflow構造的模型中部分變量凍結,然後繼續訓練,因此研究了一下tf中凍結變量的方法,目前找到三種,各有優缺點,記錄如下:

1.名詞解釋

凍結變量,指的是在訓練模型時,對某些可訓練變量不更新,即僅參與前向loss計算,不參與後向傳播,一般用於模型的finetuning等場景。例如:我們在其他數據上訓練了一個resnet152模型,然後希望在目前數據上做finetuning,一般來講,網絡的前幾層卷積是用來提取底層圖像特征的,因此可以對前3個卷積層進行凍結,不改變其weight和bias的數值。

2.方法介紹

目前我找到了三種tf凍結變量的方法,各有優缺點,具體如下:

2.1 trainable=False

一切tf.Variable或tf.Variable的子類,在創建時,都有一個trainable參數,在tf官方文檔(https://www.tensorflow.org/api_docs/python/tf/Variable)中有對這個參數的定義,

技術分享圖片

意思是,如果trainable設置為True,就會把變量添加到GraphKeys.TRAINABLE_VARIABLES集合中,如果是False,則不添加。而在計算梯度進行後向傳播時,我們一般會使用一個optimizer,然後調用該optimizer的compute_gradients方法。在compute_gradients中,第二個參數var_list如果不傳入,則默認為GraphKeys.TRAINABLE_VARIABLES。

技術分享圖片

總結下,trainable=False凍結變量的邏輯:trainable=False → 該變量不會放入GraphKeys.TRAINABLE_VARIABLES → 調用optimizer.compute_gradients方法時默認變量列表為GraphKeys.TRAINABLE_VARIABLES,該變量不在其中,因此不參與後向傳播,值不進行更新,達到凍結變量效果。

優點:操作簡單,只要在你創建變量時設置trainable=False即可

缺點:不知道大家發現沒有,我上面的總結中,optimizer.compute_gradients方法默認變量列表是GraphKeys.TRAINABLE_VARIABLES,這句話還意味著,如果我不想用默認變量列表,而使用自定義變量列表,那麽即使設置了trainable=False,只要把該變量加入到自定義變量列表

中,變量還是會參與後向傳播的,值也會更新。另外,tf.layerstf.contrib.rnn等一些高度封裝的API是不支持這個參數的,沒法用該方法凍結變量。最後,如果我們在使用Saver保存ckpt時,一般調動tf.trainable_variables()方法只存可訓練參數,這時返回的變量列表,也有上面的問題,即設置了trainable=False的變量不會在裏面。

2.2 tf.stop_gradient()

我們還可以通過在某個變量外面包裹一層tf.stop_gradient()函數來達到凍結變量的目的。例如我們想凍結w1,可以寫成這樣:

w1 = tf.stop_gradient(w1)

在後向傳播時,w1的值就不會更新。下面說下優缺點。

優點:操作簡單,針對想凍結的變量,添加上面這一行即可,而且相比於上一個方法,設置了tf.stop_gradient()的變量,不會從GraphKeys.TRAINABLE_VARIABLES集合中去除,因此不會影響梯度計算和保存模型

缺點:和上一個方法類似,tf.stop_gradient()的輸入是Tensor,tf.layers、tf.contrib.rnn等一些高度封裝的API的返回值沒法作為參數傳入,即不能用該方法凍結

2.3 optimizer.compute_gradients(loss,var_list=no_freeze_vars)

optimizer.compute_gradients在2.1中提到過,其實我們只需要在計算梯度時,指定變量列表,把希望凍結的變量去除,即可完成凍結變量。但這麽做有一個前提,我們必須知道所有可訓練變量的名字,並根據一些規則去除變量。獲取所有可訓練變量名字調用tf.trainable_variables()方法即可,但去除變量則需要我們在構建網絡的時候,合理利用tf.variable_scope,對不同變量做區分。例如,我們如果想把可訓練變量中所有卷積層變量凍結,可以這麽寫:

trainable_vars = tf.trainable_variables()
freeze_conv_var_list = [t for t in trainable_vars if not t.name.startswith(uconv)]
grads = opt.compute_gradients(loss, var_list=freeze_conv_var_list)

下面總結下優缺點,

優點:沒有2.1和2.2的缺點,是一種適用範圍更加廣泛的方法

缺點:相對2.1,2.2使用起來比較復雜,需要自己去除凍結變量,並且variable_scope不能隨意改動,因為可能使去除變量的過濾操作無效化。例如:如果把原來‘cnn‘ scope改為‘vgg‘,那麽上面的代碼就無效了

3.總結

tf對於一些常用操作,往往會提供多種方法,但每種方法一般都是有區別的,並且操作原理和後面的邏輯也會有不同,要謹慎使用

tensorflow凍結變量方法(tensorflow freeze variable)