On Using Batch Normalization in TensorFlow
阿新 · Published: 2019-02-17
When implementing Batch Normalization in TensorFlow (normalizing each layer's outputs, which stabilizes training and also has a mild regularizing effect), two APIs are mainly involved:
1) tf.nn.moments(x, axes, name=None, keep_dims=False) ⇒ mean, variance

The returned values are statistical moments: mean is the first moment and variance is the second central moment. The parameters mean the following:
- x: the data to normalize, e.g. a tensor of shape [batchsize, height, width, kernels]
- axes: the dimensions to reduce over, given as a list, e.g. [0, 1, 2]
- name: an optional name for the operation
- keep_dims: whether to keep the reduced dimensions in the result
```python
img = tf.Variable(tf.random_normal([2, 3]))
# Reduce over all dimensions except the last: here axis = [0]
axis = list(range(len(img.get_shape()) - 1))
mean, variance = tf.nn.moments(img, axis)
```
The outputs are:

```
img = [[ 0.69495416  2.08983064 -1.08764684]
       [ 0.31431156 -0.98923939 -0.34656194]]
mean = [ 0.50463283  0.55029559 -0.71710438]
variance = [ 0.0362222   2.37016821  0.13730171]
```

This example is easy to follow: the function simply computes a mean and a variance along dimension [0].

2) tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)

tf.nn.batch_norm_with_global_normalization(t, m, v, beta, gamma, variance_epsilon, scale_after_normalization, name=None)

As the interfaces show, the mean and variance returned by tf.nn.moments are passed on as arguments to tf.nn.batch_normalization.
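To make the arithmetic behind these two APIs concrete, here is a minimal NumPy sketch (not TensorFlow code) of what tf.nn.batch_normalization computes, namely scale * (x - mean) / sqrt(variance + variance_epsilon) + offset, applied to the example values above:

```python
import numpy as np

def batch_normalization(x, mean, variance, offset, scale, variance_epsilon):
    """NumPy mirror of tf.nn.batch_normalization's arithmetic."""
    inv_std = 1.0 / np.sqrt(variance + variance_epsilon)
    return scale * (x - mean) * inv_std + offset

x = np.array([[0.69495416, 2.08983064, -1.08764684],
              [0.31431156, -0.98923939, -0.34656194]])
# Mirror tf.nn.moments(x, [0]): per-column mean and variance
mean = x.mean(axis=0)
variance = x.var(axis=0)
# offset = 0 and scale = 1, the usual initial values
y = batch_normalization(x, mean, variance,
                        offset=np.zeros(3), scale=np.ones(3),
                        variance_epsilon=1e-3)
# Each column of y now has zero mean and variance close to 1
# (slightly below 1 because of the epsilon in the denominator)
```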
Among these parameters, three (x, mean, and variance) are already known — they come from the moments computation. The other two, offset and scale, generally need to be trained: offset is typically initialized to 0 and scale to 1, and both have the same shape as mean.
```python
def batch_norm(x, name_scope, training, epsilon=1e-3, decay=0.99):
    """Assumes a 2-D [batch, values] tensor."""
    with tf.variable_scope(name_scope):
        size = x.get_shape().as_list()[1]
        # scale (gamma) initialized to 1, offset (beta) to 0, as described above
        scale = tf.get_variable('scale', [size],
                                initializer=tf.ones_initializer())
        offset = tf.get_variable('offset', [size],
                                 initializer=tf.zeros_initializer())
        # Population statistics for inference, kept as non-trainable
        # moving averages of the batch statistics
        pop_mean = tf.get_variable('pop_mean', [size],
                                   initializer=tf.zeros_initializer(),
                                   trainable=False)
        pop_var = tf.get_variable('pop_var', [size],
                                  initializer=tf.ones_initializer(),
                                  trainable=False)
        batch_mean, batch_var = tf.nn.moments(x, [0])
        train_mean_op = tf.assign(pop_mean,
                                  pop_mean * decay + batch_mean * (1 - decay))
        train_var_op = tf.assign(pop_var,
                                 pop_var * decay + batch_var * (1 - decay))

        def batch_statistics():
            # Training branch: normalize with the current batch statistics
            # and update the population moving averages as a side effect.
            with tf.control_dependencies([train_mean_op, train_var_op]):
                return tf.nn.batch_normalization(x, batch_mean, batch_var,
                                                 offset, scale, epsilon)

        def population_statistics():
            # Inference branch: normalize with the accumulated
            # population statistics.
            return tf.nn.batch_normalization(x, pop_mean, pop_var,
                                             offset, scale, epsilon)

        return tf.cond(training, batch_statistics, population_statistics)
```
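As a sanity check on the moving-average bookkeeping in the function above, here is a NumPy sketch (illustrative, not part of the original code) of the train_mean_op update rule, pop_mean = pop_mean * decay + batch_mean * (1 - decay). Repeated over many training steps, the population mean converges toward the true mean of the data:

```python
import numpy as np

decay = 0.99
pop_mean = np.zeros(3)                 # same initialization as 'pop_mean' above
true_mean = np.array([0.5, -1.0, 2.0])  # hypothetical data distribution mean

rng = np.random.default_rng(0)
for _ in range(2000):
    # A fresh mini-batch of 32 rows drawn around true_mean
    batch = true_mean + rng.normal(scale=0.1, size=(32, 3))
    batch_mean = batch.mean(axis=0)
    # The same update rule as train_mean_op
    pop_mean = pop_mean * decay + batch_mean * (1 - decay)
# pop_mean is now close to true_mean
```

This is why the inference branch can use pop_mean and pop_var instead of per-batch statistics: at test time the batch may be tiny (even a single example), so the accumulated population estimates are more reliable.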