TensorFlow的batch_normalization
批量標準化(batch normalization簡稱BN)主要是為了克服當神經網路層數加深而導致難以訓練而誕生的。當深度神經網路隨著網路深度加深,訓練起來會越來越困難,收斂速度會很慢,還會產生梯度消失問題(vanishing gradient problem)。
在統計機器學習領域中有一個ICS(Internal Covariate Shift)理論:源域(source domain)和目標域(target domain)的資料分佈是一致的。也就是訓練資料和測試資料滿足相同的分佈,這是通過訓練資料獲得的模型在測試資料上有一個好的效果的保證。
Covariate Shift是指訓練資料的樣本和測試資料的樣本分佈不一致時,訓練獲取的模型無法很好的泛化
解決Covariate Shift問題可以通過對訓練樣本和測試樣本的比例對訓練樣本做一個矯正,通過批量標準化來標準化某些層或所有層的輸入,從而固定每層輸入訊號的均值與方差。
一、批量標準化的實現
批量標準化是在啟用函式之前,對z=wx+b做標準化,使得輸出結果滿足標準的正態分佈,即均值為0,方差為1。讓每一層的輸入有一個穩定的分佈便於網路的訓練。
二、批量標準化的優點
1、加大探索的步長,加快模型收斂的速度
2、更容易跳出區域性最小值
3、破壞原來的資料分佈,在一定程度上可以緩解過擬合。
當遇到神經網路收斂速度很慢或梯度爆炸等無法訓練的情況時,可以嘗試使用批量標準化來解決問題。
三、TensorFlow的批量標準化例項
1、tf.nn.moments(x,axes,shift=None,name=None,keep_dims=False)
函式介紹:計算x的均值和方差
引數介紹:
- x:需要計算均值和方差的tensor
- axes:指定求解x某個維度上的均值和方差,如果x是一維tensor,則axes=[0]
- name:用於計算均值和方差操作的名稱
- keep_dims:是否產生與輸入相同相同維度的結果
z = tf.constant([[1,1,1,1,1],[2,2,2,2,2]],dtype=tf.float32)
#計算z的均值和方差
#計算列的均值和方差
z_mean_col,z_var_col = tf.nn.moments(z,axes=[0])
#[1.5 1.5 1.5 1.5 1.5] [0.25 0.25 0.25 0.25 0.25]
#計算行的均值和方差
z_mean_row,z_var_row = tf.nn.moments(z,axes=[1])
#等價於axes=[-1],-1表示最後一維
#[1. 2.] [0. 0.]
#計算整個陣列的均值和方差
z_mean,z_var = tf.nn.moments(z,axes=[0,1])
#1.5 0.25
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(z))
print(sess.run(z_mean_col),sess.run(z_var_col))
print(sess.run(z_mean_row),sess.run(z_var_row))
print(sess.run(z_mean),sess.run(z_var))
2、tf.nn.batch_normalization(x,mean,variance,offset,scale,variance_epsilon,name=None)
函式介紹:計算batch normalization
引數介紹:
- x:輸入的tensor,具有任意的維度
- mean:輸入tensor的均值
- variance:輸入tensor的方差
- offset:偏置tensor,初始化為1
- scale:比例tensor,初始化為0
- variance_epsilon:一個接近於0的數,避免除以0
z = tf.constant([[1,1,1,1,1],[2,2,2,2,2]],dtype=tf.float32)
#計算z的均值和方差
z_mean,z_var = tf.nn.moments(z,axes=[0,1])
scale = tf.Variable(tf.ones([2,5]))
shift = tf.Variable(tf.zeros([2,5]))
#計算batch normalization
z_bath_norm = tf.nn.batch_normalization(z,z_mean,z_var,shift,scale,variance_epsilon=0.001)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(z))
print(sess.run(z_bath_norm))