1. 程式人生 > >tensorflow中Batch Normalization的實現

tensorflow中Batch Normalization的實現

tensorflow版本1.4

tensorflow目前還沒實現完全封裝好的Batch Normalization的實現,這裡主要試著實現一下。
關於理論可參見《 解讀Batch Normalization》

對於TensorFlow下的BN的實現,首先我們列舉一下需要注意的事項:

  • (1)需要自動適應卷積層(batch_size*height*width*channel)和全連線層(batch_size*channel);
  • (2)需要能夠分別處理Training和Testing的情況,Training時需要更新均值和方差,Testing時使用歷史滑動平均得到的均值與方差,即需要提供is_training的標誌位引數;
  • (3)最好提供滑動平均係數可調;
  • (4)BN的計算量較大,儘量提高儲存與運算效率;
  • (5)需要注意alpha和beta引數可以被BP更新,而均值和方差通過計算得到;
  • (6)load模型時,歷史均值、方差以及alpha和beta引數需要被正常載入;

最終的實現如下:

#coding=utf-8
# util.py 用於實現一些功能函式

import tensorflow as tf

# 實現Batch Normalization
def bn_layer(x,is_training,name='BatchNorm',moving_decay=0.9,eps=1e-5):
    # 獲取輸入維度並判斷是否匹配卷積層(4)或者全連線層(2)
shape = x.shape assert len(shape) in [2,4] param_shape = shape[-1] with tf.variable_scope(name): # 宣告BN中唯一需要學習的兩個引數,y=gamma*x+beta gamma = tf.get_variable('gamma',param_shape,initializer=tf.constant_initializer(1)) beta = tf.get_variable('beat', param_shape,initializer=tf.constant_initializer(0
)) # 計算當前整個batch的均值與方差 axes = list(range(len(shape)-1)) batch_mean, batch_var = tf.nn.moments(x,axes,name='moments') # 採用滑動平均更新均值與方差 ema = tf.train.ExponentialMovingAverage(moving_decay) def mean_var_with_update(): ema_apply_op = ema.apply([batch_mean,batch_var]) with tf.control_dependencies([ema_apply_op]): return tf.identity(batch_mean), tf.identity(batch_var) # 訓練時,更新均值與方差,測試時使用之前最後一次儲存的均值與方差 mean, var = tf.cond(tf.equal(is_training,True),mean_var_with_update, lambda:(ema.average(batch_mean),ema.average(batch_var))) # 最後執行batch normalization return tf.nn.batch_normalization(x,mean,var,beta,gamma,eps)

測試函式如下:

import util
import tensorflow as tf


# 注意bn_layer中滑動平均的操作導致該層只支援半精度、float32和float64型別變數
x = tf.constant([[1,2,3],[2,4,8],[3,9,27]],dtype=tf.float32)
y = util.bn_layer(x,True)

# 注意bn_layer中的一些操作必須被提前初始化
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print('x = ',x.eval())
    print('y = ',y.eval())

結果輸出如下:(證明我們的初步計算是正確的)

x =  [[  1.   2.   3.]
 [  2.   4.   8.]
 [  3.   9.  27.]]
y =  [[-1.22473562 -1.01904869 -0.93499756]
 [ 0.         -0.33968294 -0.45137817]
 [ 1.2247355   1.35873151  1.38637543]]

下面介紹一下,實現過程中遇到的一些函式:

tf.nn.moments

# 用於在指定維度計算均值與方差
tf.nn.moments(
    x,
    axes,
    shift=None,
    name=None,
    keep_dims=False
)
  • x: 輸入Tensor
  • axes: int型Array,用於指定在哪些維度計算均值與方差。如果x是1-D向量且axes=[0] 那麼該函式就是計算整個向量的均值與方差
  • shift: 暫時無用

tf.train.ExponentialMovingAverage

# 類,用於計算滑動平均
tf.train.ExponentialMovingAverage

__init__(
    decay,
    num_updates=None,
    zero_debias=False,
    name='ExponentialMovingAverage'
)

具體的滑動公式如下,等價於一種指數衰減:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

tf.control_dependencies

# tf.control_dependencies(control_inputs)返回一個控制依賴的上下文管理器,
# 使用with關鍵字可以讓在這個上下文環境中的操作都在control_inputs之後執行
# 比如:
with tf.control_dependencies([a, b]):
  # 只有在a和b執行完後,c和d才會被執行
  c = ...
  d = ...

tf.cond

# 用於有條件的執行函式,當pred為True時,執行true_fn函式,否則執行false_fn函式
tf.cond(
    pred,
    true_fn=None,
    false_fn=None,
    strict=False,
    name=None,
    fn1=None,
    fn2=None
)

尤其需要注意的是,pred引數是tf.bool型變數,直接寫“True”或者“False”是python型bool,會報錯的。因此在我的BN實現中使用了tf.equal(is_training,True)的操作。

tf.nn.batch_normalization

# 用於最中執行batch normalization的函式
tf.nn.batch_normalization(
    x,
    mean,
    variance,
    offset,
    scale,
    variance_epsilon,
    name=None
)

計算公式為: y = scale*(x-mean)/var + offset