1. 程式人生 > >TensorFlow的batch_normalization

TensorFlow的batch_normalization

批量標準化(batch normalization簡稱BN)主要是為了克服當神經網路層數加深而導致難以訓練而誕生的。當深度神經網路隨著網路深度加深,訓練起來會越來越困難收斂速度會很慢,還會產生梯度消失問題(vanishing gradient problem)。

在統計機器學習領域中有一個ICS(Internal Covariate Shift)理論:源域(source domain)和目標域(target domain)的資料分佈是一致的。也就是訓練資料和測試資料滿足相同的分佈,這是通過訓練資料獲得的模型在測試資料上有一個好的效果的保證。

Covariate Shift是指訓練資料的樣本和測試資料的樣本分佈不一致時,訓練獲取的模型無法很好的泛化

。它是分佈不一致假設之下的一個分支問題,也就是指源域和目標域的條件概率是一致的,但是其邊緣概率不同。對於神經網路而言,神經網路的各層輸出,在經過了層內操作後,各層輸出分佈會隨著輸入分佈的變化而變化,而且差異會隨著網路的深度增加而加大,但是每一層隨指向的樣本標記是不會改變的。

解決Covariate Shift問題可以通過對訓練樣本和測試樣本的比例對訓練樣本做一個矯正,通過批量標準化來標準化某些層或所有層的輸入,從而固定每層輸入訊號的均值與方差。

一、批量標準化的實現

批量標準化是在啟用函式之前,對z=wx+b做標準化,使得輸出結果滿足標準的正態分佈,即均值為0,方差為1。讓每一層的輸入有一個穩定的分佈便於網路的訓練。

二、批量標準化的優點

1、加大探索的步長,加快模型收斂的速度

2、更容易跳出區域性最小值

3、破壞原來的資料分佈,在一定程度上可以緩解過擬合。

當遇到神經網路收斂速度很慢或梯度爆炸等無法訓練的情況時,可以嘗試使用批量標準化來解決問題。

三、TensorFlow的批量標準化例項

1、tf.nn.moments(x,axes,shift=None,name=None,keep_dims=False)

函式介紹:計算x的均值和方差

引數介紹:

  • x:需要計算均值和方差的tensor
  • axes:指定求解x某個維度上的均值和方差,如果x是一維tensor,則axes=[0]
  • name:用於計算均值和方差操作的名稱
  • keep_dims:是否產生與輸入相同相同維度的結果
    z = tf.constant([[1,1,1,1,1],[2,2,2,2,2]],dtype=tf.float32)
    #計算z的均值和方差
    #計算列的均值和方差
    z_mean_col,z_var_col = tf.nn.moments(z,axes=[0])
    #[1.5 1.5 1.5 1.5 1.5] [0.25 0.25 0.25 0.25 0.25]
    #計算行的均值和方差
    z_mean_row,z_var_row = tf.nn.moments(z,axes=[1])
    #等價於axes=[-1],-1表示最後一維
    #[1. 2.] [0. 0.]
    #計算整個陣列的均值和方差
    z_mean,z_var = tf.nn.moments(z,axes=[0,1])
    #1.5 0.25
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    print(sess.run(z))
    print(sess.run(z_mean_col),sess.run(z_var_col))
    print(sess.run(z_mean_row),sess.run(z_var_row))
    print(sess.run(z_mean),sess.run(z_var))

2、tf.nn.batch_normalization(x,mean,variance,offset,scale,variance_epsilon,name=None)

函式介紹:計算batch normalization

引數介紹:

  • x:輸入的tensor,具有任意的維度
  • mean:輸入tensor的均值
  • variance:輸入tensor的方差
  • offset:偏置tensor,初始化為1
  • scale:比例tensor,初始化為0
  • variance_epsilon:一個接近於0的數,避免除以0
    z = tf.constant([[1,1,1,1,1],[2,2,2,2,2]],dtype=tf.float32)
    #計算z的均值和方差
    z_mean,z_var = tf.nn.moments(z,axes=[0,1])
    scale = tf.Variable(tf.ones([2,5]))
    shift = tf.Variable(tf.zeros([2,5]))
    #計算batch normalization
    z_bath_norm = tf.nn.batch_normalization(z,z_mean,z_var,shift,scale,variance_epsilon=0.001)
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    print(sess.run(z))
    print(sess.run(z_bath_norm))