Batch normalization及其在tensorflow中的實現
阿新 • • 發佈:2018-11-30
Batch normalization(BN)
BN是對輸入的特徵圖進行標準化的操作,其公式為:
- xx - 原輸入
- x^x^ - 標準化後的輸入
- μμ - 一個batch中的均值
- σ2σ2 - 一個batch中的方差
- ϵϵ - 一個很小的數,防止除0
- ββ - 中心偏移量(center)
- γγ - 縮放(scale)係數
tensorflow中提供了三種BN方法:
tf.nn.batch_normalization
tf.layers.batch_normalization
tf.contrib.layers.batch_norm
以tf.layers.batch_normalization
為例介紹裡面所包含的主要引數:
tf.layers.batch_normalization(inputs, decay=0.999, center=True, scale=True, is_training=True, epsilon=0.001)
- 1
一般使用只要定義以下的引數即可:
-
inputs: 輸入張量[N, H, W, C]
-
decay: 滑動平均的衰減係數,一般取接近1的值,這樣能在驗證和測試集上獲得較好結果
-
center: 中心偏移量,上述的ββ ,為True,則自動新增,否則忽略
-
scale: 縮放係數,上述的γγ,為True,則自動新增,否則忽略
-
epsilon: 為防止除0而加的一個很小的數
-
is_training: 是否是訓練過程,為True則代表是訓練過程,那麼將根據decay用指數滑動平均求得moments,並累加儲存到
moving_mean
和moving_variance
中。否則是測試過程,函式直接取這兩個引數來用。如果是True,則需在訓練的session中新增將BN引數更新操作加入訓練的程式碼:
# execute update_ops to update batch_norm weights update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): optimizer = tf.train.AdamOptimizer(decayed_learning_rate) train_op = optimizer.minimize(loss, global_step = global_step)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
Note
需要看上述函式的詳細引數,可在python終端通過以下命令獲取:
import tensorflow as tf
help(tf.layers.batch_normalization) # help中新增函式名