pytorch BatchNorm引數詳解,計算過程
https://blog.csdn.net/weixin_39228381/article/details/107896863
目錄
說明
網路訓練時和網路評估時,BatchNorm模組的計算方式不同。如果一個網路裡包含了BatchNorm,則在訓練時需要先呼叫train(),使網路裡的BatchNorm模組的training=True(預設是True),在網路評估時,需要先呼叫eval()使網路的training=False。
BatchNorm1d引數
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
num_features
輸入維度是(N, C, L)時,num_features應該取C;這裡N是batch size,C是資料的channel,L是資料長度。
輸入維度是(N, L)時,num_features應該取L;這裡N是batch size,L是資料長度,這時可以認為每條資料只有一個channel,省略了C
eps
對輸入資料進行歸一化時加在分母上,防止除零,詳情見下文。
momentum
更新全域性均值running_mean和方差running_var時使用該值進行平滑,詳情見下文。
affine
設為True時,BatchNorm層才會學習引數和,否則不包含這兩個變數,變數名是weight和bias,詳情見下文。
track_running_stats
設為True時,BatchNorm層會統計全域性均值running_mean和方差running_var,詳情見下文。
BatchNorm1d訓練時前向傳播
- 首先對輸入batch求和,並用這兩個結果把batch歸一化,使其均值為0,方差為1。歸一化公式用到了eps(),即。如下輸入內容,shape是(3, 4),即batch_size=3,此時num_features需要傳入4。
-
tensor = torch.FloatTensor([[1, 2, 4, 1],
- [6, 3, 2, 4],
- [2, 4, 6, 1]])
- [[-0.9258, -1.2247, 0.0000, -0.7071],
- [ 1.3887, 0.0000, -1.2247, 1.4142],
- [-0.4629, 1.2247, 1.2247, -0.7071]]
-
tensor = torch.FloatTensor([[1, 2, 4, 1],
- 如果track_running_stats==True,則使用momentum更新模組內部的running_mean(初值是[0., 0., 0., 0.])和running_var(初值是[1., 1., 1., 1.]),更新公式是,其中代表更新後的running_mean和running_var,表示更新前的running_mean和running_var,表示當前batch的均值和無偏樣本方差。
- 如果track_running_stats==False,則BatchNorm中不含有running_mean和running_var兩個變數。
- 如果affine==True,則對歸一化後的batch進行仿射變換,即乘以模組內部的weight(初值是[1., 1., 1., 1.])然後加上模組內部的bias(初值是[0., 0., 0., 0.]),這兩個變數會在反向傳播時得到更新。
- 如果affine==False,則BatchNorm中不含有weight和bias兩個變數,什麼都都不做。
BatchNorm1d評估時前向傳播
- 如果track_running_stats==True,則對batch進行歸一化,公式為,注意這裡的均值和方差是running_mean和running_var,在網路訓練時統計出來的全域性均值和無偏樣本方差。
- 如果track_running_stats==False,則對batch進行歸一化,公式為,注意這裡的均值和方差是batch自己的mean和var,此時BatchNorm裡不含有running_mean和running_var。注意此時使用的是無偏樣本方差(和訓練時不同),因此如果batch_size=1,會使分母為0,就報錯了。
- 如果affine==True,則對歸一化後的batch進行放射變換,即乘以模組內部的weight然後加上模組內部的bias,這兩個變數都是網路訓練時學習到的。
- 如果affine==False,則BatchNorm中不含有weight和bias兩個變數,什麼都不做。
總結
在使用batchNorm時,通常只需要指定num_features就可以了。網路訓練前呼叫train(),訓練時BatchNorm模組會統計全域性running_mean和running_var,學習weight和bias,即文獻中的和。網路評估前呼叫eval(),評估時,對傳入的batch,使用統計的全域性running_mean和running_var對batch進行歸一化,然後使用學習到的weight和bias進行仿射變換。