Using BatchNorm in TensorFlow
x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, BN_EPSILON)
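Here x is the input data, mean is the mean of the batch x, and variance is the variance of the batch x (note that both are computed separately for each dimension). beta and gamma are the learnable offset and scale parameters, and BN_EPSILON keeps the denominator from becoming zero when the variance is 0 (it is usually set to 0.001).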
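To make the role of each argument concrete, here is a minimal NumPy sketch of the arithmetic that tf.nn.batch_normalization performs. It is an illustration only, not the TensorFlow implementation; the input batch and the default eps value are made up for the example.

```python
import numpy as np

def batch_normalization(x, mean, variance, beta, gamma, eps=0.001):
    # Normalize with the given statistics, then scale by gamma and shift by beta.
    return gamma * (x - mean) / np.sqrt(variance + eps) + beta

# Hypothetical batch: 4 samples, 3 features.
x = np.array([[1.0, 2.0, 3.0],
              [2.0, 4.0, 6.0],
              [3.0, 6.0, 9.0],
              [4.0, 8.0, 12.0]])

mean = x.mean(axis=0)      # one mean per feature, shape (3,)
variance = x.var(axis=0)   # one variance per feature, shape (3,)
beta = np.zeros(3)         # offset starts at 0
gamma = np.ones(3)         # scale starts at 1

y = batch_normalization(x, mean, variance, beta, gamma)
```

With beta = 0 and gamma = 1, every column of y has a mean of 0 and a standard deviation slightly below 1 (slightly below because eps is added to the variance).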
The complete bn function is as follows:
import tensorflow as tf
# Assumed imports for the two TensorFlow 1.x internal modules used below.
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.training import moving_averages

# _get_variable, BN_DECAY, BN_EPSILON and UPDATE_OPS_COLLECTION
# are defined elsewhere in the surrounding code.

def bn(x, use_bn, is_training):
    x_shape = x.get_shape()
    params_shape = x_shape[-1:]

    # Without batch norm, only add a learnable bias.
    if not use_bn:
        bias = _get_variable('bias', params_shape,
                             initializer=tf.zeros_initializer)
        return x + bias

    # Reduce over every axis except the last (channel) axis.
    axis = list(range(len(x_shape) - 1))

    beta = _get_variable('beta', params_shape,
                         initializer=tf.zeros_initializer)
    gamma = _get_variable('gamma', params_shape,
                          initializer=tf.ones_initializer)

    moving_mean = _get_variable('moving_mean', params_shape,
                                initializer=tf.zeros_initializer,
                                trainable=False)
    moving_variance = _get_variable('moving_variance', params_shape,
                                    initializer=tf.ones_initializer,
                                    trainable=False)

    # these ops will only be performed when training.
    mean, variance = tf.nn.moments(x, axis)
    update_moving_mean = moving_averages.assign_moving_average(
        moving_mean, mean, BN_DECAY)
    update_moving_variance = moving_averages.assign_moving_average(
        moving_variance, variance, BN_DECAY)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
    tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)

    # Batch statistics while training, moving averages otherwise.
    mean, variance = control_flow_ops.cond(
        is_training,
        lambda: (mean, variance),
        lambda: (moving_mean, moving_variance))

    x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, BN_EPSILON)
    return x
1. A brief introduction to the functions used above:
tf.nn.moments(x, axes, shift=None, name=None, keep_dims=False) computes the mean and variance of x. When x has shape [batch, height, width, depth], global normalization uses axes = [0, 1, 2], so mean and variance have shape [depth]; simple batch normalization uses axes = [0], so mean and variance have shape [height, width, depth].
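The effect of the axes argument on the output shapes can be checked with a small NumPy stand-in. The moments helper below is a hypothetical re-implementation for illustration, not the TensorFlow function, and the tensor sizes are arbitrary.

```python
import numpy as np

def moments(x, axes):
    # Hypothetical NumPy stand-in: mean and (population) variance over `axes`.
    axes = tuple(axes)
    return x.mean(axis=axes), x.var(axis=axes)

# Arbitrary tensor laid out as [batch, height, width, depth].
x = np.arange(8 * 5 * 5 * 16, dtype=float).reshape(8, 5, 5, 16)

# Global normalization: reduce over batch, height and width.
global_mean, global_variance = moments(x, [0, 1, 2])

# Simple batch normalization: reduce over the batch axis only.
simple_mean, simple_variance = moments(x, [0])

# The bn function builds its axes this way: every axis except the last.
bn_axes = list(range(x.ndim - 1))
```

Here global_mean has shape (16,), simple_mean has shape (5, 5, 16), and bn_axes equals [0, 1, 2], so the bn function performs global normalization for a 4-D input.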
moving_averages.assign_moving_average(variable, value, decay, zero_debias=True, name=None) computes the moving average of a variable. After the update, the variable holds variable * decay + value * (1 - decay). In this example decay = BN_DECAY = 0.9997.
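A plain-Python sketch of this update rule shows how slowly the average moves when decay is close to 1. It implements only the formula quoted above and ignores the zero_debias option; the constant input value 5.0 and the step counts are made up for the example.

```python
def assign_moving_average(variable, value, decay):
    # Plain update rule only; zero_debias is not modelled here.
    return variable * decay + value * (1.0 - decay)

BN_DECAY = 0.9997

# Feed the same hypothetical batch mean (5.0) for 1000 training steps.
slow_average = 0.0
for _ in range(1000):
    slow_average = assign_moving_average(slow_average, 5.0, BN_DECAY)

# Same input with a much smaller decay, for comparison.
fast_average = 0.0
for _ in range(100):
    fast_average = assign_moving_average(fast_average, 5.0, 0.9)
```

After 1000 steps with decay 0.9997, slow_average is still only about 1.3, while fast_average with decay 0.9 is already very close to 5.0 after 100 steps. A large decay therefore gives stable statistics but needs many training steps before the moving averages are usable at test time.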
tf.add_to_collection(name, value) adds value to the collection stored under the key name.
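control_flow_ops.cond(pred, true_fn=None, false_fn=None, strict=False, name=None, fn1=None, fn2=None) returns the result of true_fn if pred is true, and the result of false_fn otherwise. In this example it selects the mean and variance used for training versus testing. Note that pred must not be a Python bool; pass a boolean tensor instead.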
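The pieces above fit together as follows. This is a NumPy sketch of the behaviour of the bn function in its two modes, under the assumption that the moving averages are updated on every training call; the class name BatchNormSketch and the sample data are hypothetical, and the decay is set to 0.9 only to make the numbers easy to follow.

```python
import numpy as np

class BatchNormSketch:
    """Hypothetical NumPy model of the train/test switch in bn()."""

    def __init__(self, depth, decay=0.9997, eps=0.001):
        self.beta = np.zeros(depth)
        self.gamma = np.ones(depth)
        self.moving_mean = np.zeros(depth)
        self.moving_variance = np.ones(depth)
        self.decay = decay
        self.eps = eps

    def __call__(self, x, is_training):
        if is_training:
            # Batch statistics over every axis except the last one.
            axes = tuple(range(x.ndim - 1))
            mean, variance = x.mean(axis=axes), x.var(axis=axes)
            # Update the moving averages as a side effect of training.
            d = self.decay
            self.moving_mean = self.moving_mean * d + mean * (1 - d)
            self.moving_variance = self.moving_variance * d + variance * (1 - d)
        else:
            # At test time, use the accumulated moving averages.
            mean, variance = self.moving_mean, self.moving_variance
        return self.gamma * (x - mean) / np.sqrt(variance + self.eps) + self.beta

layer = BatchNormSketch(depth=3, decay=0.9)
x = np.array([[1.0, 2.0, 3.0],
              [3.0, 6.0, 9.0]])
y_train = layer(x, is_training=True)    # normalized with batch statistics
y_test = layer(x, is_training=False)    # normalized with moving averages
```

In the TensorFlow graph the same choice has to be made by cond on a boolean tensor, because a Python if statement would be evaluated only once, when the graph is built, rather than on each run.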
2. Next, a brief look at how to update the BN parameters:
batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
batchnorm_updates_op = tf.group(*batchnorm_updates)
train_op = tf.group(apply_gradient_op, batchnorm_updates_op)
tf.get_collection(key, scope=None) returns everything stored in the collection with the given key. If scope is not None, only the items whose name matches scope are returned.
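tf.group(*inputs, **kwargs) groups several operations or variables into a single operation. Running the grouped operation runs all of its inputs, so train_op here applies the gradients and updates the moving averages in the same step.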