1. 程式人生 > 實用技巧 >pytorch BatchNorm引數詳解,計算過程

pytorch BatchNorm引數詳解,計算過程

https://blog.csdn.net/weixin_39228381/article/details/107896863

目錄

說明

BatchNorm1d引數

num_features

eps

momentum

affine

track_running_stats

BatchNorm1d訓練時前向傳播

BatchNorm1d評估時前向傳播

總結


說明

網路訓練時和網路評估時,BatchNorm模組的計算方式不同。如果一個網路裡包含了BatchNorm,則在訓練時需要先呼叫train(),使網路裡的BatchNorm模組的training=True(預設是True),在網路評估時,需要先呼叫eval()使網路的training=False。

BatchNorm1d引數

torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

num_features

輸入維度是(N, C, L)時,num_features應該取C;這裡N是batch size,C是資料的channel,L是資料長度。

輸入維度是(N, L)時,num_features應該取L;這裡N是batch size,L是資料長度,這時可以認為每條資料只有一個channel,省略了C

eps

對輸入資料進行歸一化時加在分母上,防止除零,詳情見下文。

momentum

更新全域性均值running_mean和方差running_var時使用該值進行平滑,詳情見下文。

affine

設為True時,BatchNorm層才會學習引數和,否則不包含這兩個變數,變數名是weight和bias,詳情見下文。

track_running_stats

設為True時,BatchNorm層會統計全域性均值running_mean和方差running_var,詳情見下文。

BatchNorm1d訓練時前向傳播

  1. 首先對輸入batch求和,並用這兩個結果把batch歸一化,使其均值為0,方差為1。歸一化公式用到了eps(),即。如下輸入內容,shape是(3, 4),即batch_size=3,此時num_features需要傳入4。
    1. tensor = torch.FloatTensor([[1, 2, 4, 1],
    2. [6, 3, 2, 4],
    3. [2, 4, 6, 1]])
    此時,(無偏樣本方差)和(有偏樣本方差),有偏和無偏的區別在於無偏的分母是N-1,有偏的分母是N。注意在BatchNorm中,用於更新running_var時,使用無偏樣本方差即,但是在對batch進行歸一化時,使用有偏樣本方差,因此如果batch_size=1,會報錯。歸一化後的內容如下。
    1. [[-0.9258, -1.2247, 0.0000, -0.7071],
    2. [ 1.3887, 0.0000, -1.2247, 1.4142],
    3. [-0.4629, 1.2247, 1.2247, -0.7071]]
  2. 如果track_running_stats==True,則使用momentum更新模組內部的running_mean(初值是[0., 0., 0., 0.])和running_var(初值是[1., 1., 1., 1.]),更新公式是,其中代表更新後的running_mean和running_var,表示更新前的running_mean和running_var,表示當前batch的均值和無偏樣本方差。
  3. 如果track_running_stats==False,則BatchNorm中不含有running_mean和running_var兩個變數。
  4. 如果affine==True,則對歸一化後的batch進行仿射變換,即乘以模組內部的weight(初值是[1., 1., 1., 1.])然後加上模組內部的bias(初值是[0., 0., 0., 0.]),這兩個變數會在反向傳播時得到更新。
  5. 如果affine==False,則BatchNorm中不含有weight和bias兩個變數,什麼都都不做。

BatchNorm1d評估時前向傳播

  1. 如果track_running_stats==True,則對batch進行歸一化,公式為,注意這裡的均值和方差是running_mean和running_var,在網路訓練時統計出來的全域性均值和無偏樣本方差。
  2. 如果track_running_stats==False,則對batch進行歸一化,公式為,注意這裡的均值和方差是batch自己的mean和var,此時BatchNorm裡不含有running_mean和running_var。注意此時使用的是無偏樣本方差(和訓練時不同),因此如果batch_size=1,會使分母為0,就報錯了。
  3. 如果affine==True,則對歸一化後的batch進行放射變換,即乘以模組內部的weight然後加上模組內部的bias,這兩個變數都是網路訓練時學習到的。
  4. 如果affine==False,則BatchNorm中不含有weight和bias兩個變數,什麼都不做。

總結

在使用batchNorm時,通常只需要指定num_features就可以了。網路訓練前呼叫train(),訓練時BatchNorm模組會統計全域性running_mean和running_var,學習weight和bias,即文獻中的和。網路評估前呼叫eval(),評估時,對傳入的batch,使用統計的全域性running_mean和running_var對batch進行歸一化,然後使用學習到的weight和bias進行仿射變換。