
Using BatchNorm in TensorFlow

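x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, BN_EPSILON)

Here x is the input data, mean is the mean of the batch x, and variance is the variance of the batch x (note that both are computed separately for each dimension). beta and gamma are the learnable shift and scale parameters, respectively, and BN_EPSILON keeps the denominator away from zero when the variance is 0 (it is usually set to 0.001).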
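As a rough illustration of what this call computes, the pure-Python sketch below normalizes a tiny batch per feature dimension using `gamma * (x - mean) / sqrt(variance + epsilon) + beta`. The helper name `batch_norm_reference` and the sample values are assumptions made for this example; it mirrors the formula only and is not TensorFlow's implementation.

```python
import math

BN_EPSILON = 0.001  # the value the text suggests


def batch_norm_reference(batch, beta, gamma, eps=BN_EPSILON):
    """Normalize a batch (a list of rows) per feature dimension.

    Hypothetical helper for illustration only.
    """
    n = len(batch)
    dims = len(batch[0])
    # Mean and (population) variance of each feature dimension over the batch.
    mean = [sum(row[d] for row in batch) / n for d in range(dims)]
    variance = [
        sum((row[d] - mean[d]) ** 2 for row in batch) / n for d in range(dims)
    ]
    # y = gamma * (x - mean) / sqrt(variance + eps) + beta
    return [
        [
            gamma[d] * (row[d] - mean[d]) / math.sqrt(variance[d] + eps) + beta[d]
            for d in range(dims)
        ]
        for row in batch
    ]


# Two features on very different scales end up on a comparable scale.
batch = [[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]]
normalized = batch_norm_reference(batch, beta=[0.0, 0.0], gamma=[1.0, 1.0])
```

With beta at 0 and gamma at 1, each column of `normalized` is centred on zero; the learnable parameters then let the network shift and rescale that output.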

The complete bn function is shown below:

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.training import moving_averages

# _get_variable, BN_DECAY, BN_EPSILON and UPDATE_OPS_COLLECTION are assumed
# to be defined elsewhere (they are not shown in this post).

def bn(x, use_bn, is_training):
  x_shape = x.get_shape()
  params_shape = x_shape[-1:]  # one parameter per channel (last dimension)
  if not use_bn:
    # Without BN, only add a learnable bias.
    bias = _get_variable('bias', params_shape, initializer=tf.zeros_initializer)
    return x + bias
  axis = list(range(len(x_shape) - 1))  # reduce over every axis except the last
  beta = _get_variable('beta', params_shape, initializer=tf.zeros_initializer)
  gamma = _get_variable('gamma', params_shape, initializer=tf.ones_initializer)
  # Moving statistics are not trained by gradient descent.
  moving_mean = _get_variable('moving_mean', params_shape, initializer=tf.zeros_initializer, trainable=False)
  moving_variance = _get_variable('moving_variance', params_shape, initializer=tf.ones_initializer, trainable=False)

  # these ops will only be performed when training.
  mean, variance = tf.nn.moments(x, axis)
  update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, BN_DECAY)
  update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, BN_DECAY)
  tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
  tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)
  # Use batch statistics when training, moving statistics otherwise.
  mean, variance = control_flow_ops.cond(is_training, lambda:(mean, variance), lambda:(moving_mean, moving_variance))
  x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, BN_EPSILON)
  return x

1. A brief introduction to the functions used above:

tf.nn.moments(x, axes, shift=None, name=None, keep_dims=False) computes the mean and variance of x. When x has shape [batch, height, width, depth], global normalization uses axes = [0, 1, 2], so mean and variance have shape [depth]; simple batch normalization uses axes = [0], so mean and variance have shape [height, width, depth].
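To make the shape bookkeeping concrete, here is a small pure-Python sketch of the axes = [0, 1, 2] case on a nested list shaped [batch, height, width, depth]. The helper `moments_over_leading_axes` is a hypothetical stand-in written for this example, not part of TensorFlow.

```python
def moments_over_leading_axes(x):
    """Per-channel mean and variance of a nested list [batch, height, width, depth].

    Hypothetical helper mirroring the axes=[0, 1, 2] case described above.
    """
    depth = len(x[0][0][0])
    # Flatten batch, height and width so that each entry is one depth-vector.
    pixels = [px for image in x for row in image for px in row]
    count = len(pixels)
    mean = [sum(px[c] for px in pixels) / count for c in range(depth)]
    variance = [
        sum((px[c] - mean[c]) ** 2 for px in pixels) / count for c in range(depth)
    ]
    return mean, variance


# Shape [batch=2, height=1, width=2, depth=3]
x = [
    [[[1.0, 0.0, 2.0], [3.0, 0.0, 2.0]]],
    [[[5.0, 0.0, 2.0], [7.0, 0.0, 2.0]]],
]
mean, variance = moments_over_leading_axes(x)
# mean and variance each hold `depth` entries, one per channel.
```

Reducing over every axis except the last is also what `axis = list(range(len(x_shape) - 1))` does in the bn function, which is why beta, gamma and the moving statistics all share the shape of the last dimension.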

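moving_averages.assign_moving_average(variable, value, decay, zero_debias=True, name=None) maintains a moving average of a variable. Leaving aside the zero-debias correction, the updated value is variable * decay + value * (1 - decay). In this example decay = BN_DECAY = 0.9997.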
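The update rule can be tried out in a few lines of plain Python. This is only a sketch of the formula quoted above (no zero-debias correction); the function name `update_moving_average`, the constant batch mean and the step count are assumptions chosen for illustration.

```python
BN_DECAY = 0.9997  # the decay used in this post


def update_moving_average(moving, value, decay=BN_DECAY):
    """One update step: moving * decay + value * (1 - decay)."""
    return moving * decay + value * (1.0 - decay)


moving_mean = 0.0  # matches the zeros initializer used for moving_mean
batch_mean = 1.0   # pretend every batch has the same mean
steps = 1000
for _ in range(steps):
    moving_mean = update_moving_average(moving_mean, batch_mean)

# Closed form for a constant target when starting from 0.
expected = batch_mean * (1.0 - BN_DECAY ** steps)
```

With a decay this close to 1, the moving value is still far from the batch mean after 1000 steps, so the moving statistics likely need many thousands of training steps before they are representative at test time.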

tf.add_to_collection(name, value) adds value to the collection whose key is name.

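control_flow_ops.cond(pred, true_fn=None, false_fn=None, strict=False, name=None, fn1=None, fn2=None) returns the result of true_fn if pred is True, otherwise the result of false_fn. Here it selects which mean and variance to use: the batch statistics during training and the moving statistics during testing. Note that pred must be a tensor, not a Python bool.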

2. Next, a brief look at how the BN parameters (the moving mean and variance) are updated:

batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
batchnorm_updates_op = tf.group(*batchnorm_updates)
train_op = tf.group(apply_gradient_op, batchnorm_updates_op)

tf.get_collection(key, scope=None): returns all the items in the collection whose key is key. If scope is not None, only the items whose names match scope are returned.

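tf.group(*inputs, **kwargs): groups several operations or variables into a single op. Here, running train_op both applies the gradients and updates the moving mean and variance.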