1. 程式人生 > >Tensorflow訓練和預測中的BN層的坑

Tensorflow訓練和預測中的BN層的坑

  以前使用Caffe的時候沒注意這個,現在使用預訓練模型來動手做時遇到了。在slim中的自帶模型中inception, resnet, mobilenet等都自帶BN層,這個坑在《實戰Google深度學習框架》第二版這本書P166裡只是提了一句,沒有做出解答。

  書中說訓練時和測試時使用的引數is_training都為True,然後給出了一個連結供參考。本人剛開始使用時也是按照書中的做法沒有改動,後來從儲存後的checkpoint中載入模型做預測時出了問題:當改變需要預測資料的batchsize時預測的label也跟著變,這意味著checkpoint裡面沒有儲存訓練中BN層的引數,使用的BN層引數還是從需要預測的資料中計算而來的。這顯然會出問題,當預測的batchsize越大,假如你的預測資料集和訓練資料集的分佈一致,結果就越接近於訓練結果,但如果batchsize=1,那BN層就發揮不了作用,結果很難看。

  那如果在預測時is_traning=false呢,但BN層的引數沒有從訓練中儲存,那使用的就是隨機初始化的引數,結果不堪想象。

  所以需要在訓練時把BN層的引數儲存下來,然後在預測時載入,參考幾位大佬的部落格,有了以下訓練時新增的程式碼:

 1 update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
 2 with tf.control_dependencies(update_ops):
 3         train_step = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)
4 5 # 設定儲存模型 6 var_list = tf.trainable_variables() 7 g_list = tf.global_variables() 8 bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name] 9 bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name] 10 var_list += bn_moving_vars 11 saver = tf.train.Saver(var_list=var_list, max_to_keep=5)

這樣就可以在預測時從checkpoint檔案載入BN層的引數並設定is_training=False。

最後要說的是,雖然這麼做可以解決這個問題,但也可以利用預測資料來計算BN層的引數,不是說一定要儲存訓練時的引數,兩種方案可以作為超引數來調節使用,看哪種方法的結果更好。

感謝幾位大佬的部落格解惑:

  https://blog.csdn.net/dongjbstrong/article/details/80447110?utm_source=blogxgwz0

  http://www.cnblogs.com/hrlnw/p/7227447.html