Pytorch——BatchNorm層和LayerNorm層的引數含義以及應用理解
在我們平常面試和工程中會用到BN和LN,但或許沒有去了解過BN和LN到底在那個維度上進行的正則化(減均值除以標準差)。下面將會採用各種例子來為大家介紹BN層和LN層各個引數以及差別。
一、BatchNorm(批標準化):
BatchNorm一共有三個函式分別是BatchNorm1d,BatchNorm2d,BatchNorm3d,她們的輸入的tensor的維度是不一樣的,以及引數的定義也是不一樣的,我們一個一個的說。
BatchNorm1d:
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
引數含義:
num_features:如果你輸出的tensor是(N,C,L)維度的,那麼這裡定義為C;如果你輸入的tensor是(N,L)維度的,則此處設定為L。這裡N表示batch_size,C是資料的channel(通道),L是特徵維度(資料長度)。
eps:對輸入資料進行歸一化時加在分母上,防止除零。
momentum :計算整個樣本全域性均值running_mean和方差running_var時是採用動量的模式進行,這個設定的是這個動量的大小,後文會提到。
affine:一個布林值,當設定為True時,此模組具有可學習的仿射引數weight和bias,一般我們做正則化是使得資料服從N(0,1),但是經過仿射變換可以到到N(bias,weight^2)的正態分佈。這兩個引數可以學習的,初始化時weigt為1,bias為0
track_running_stats:設為True時,BatchNorm層會統計全域性均值running_mean和方差running_var。
從引數的含義我們可以知道,針對不同的tensor輸出,我們提前設定的num_features時不一樣的,當做BN的tensor維度是(N,C,L)時,我們定義的num_features是C,意味著我們會根據每個通道的不同樣本的L長度特徵進行相加再除以N*L得到均值,因此會得到C個均值。再具體點的例項就是輸入為(5,3,10)的tensor,我們會取[0,0,:],[1,0,:],....[4,0,:]這些向量的數值(一共有5*10個數字)加起來除以5*10得到第一個通道的均值,並對以上數字進行正則化。當做BN的tensor維度是(N,L),我們定義的num_features是L,因此我們會計算出L個均值和方差,可以看成(N,L,1)的形式,每一個特徵長度為1,只有L個通道,具體點的例項:輸入維度為(4,5)的tensor,會取[0,0],[1,0],[2,0],[3,0]這4個數進行正則化,可以知道我們最終會得到L個均值和方差
momentum引數的應用是為了計算全域性樣本的均值和方差的,因為當訓練完樣本資料後,我們可以得到一個整個樣本的均值和方差,但是這個均值和方差的得到不是把所有樣本都計算遍歷一遍計算得到的,而是在每一個betch經過BatchNorm1d的時候,內部會儲存下該次batch的均值和方差,並通過以下等式來計算得到全域性的均值和方差。
$$x_{new} = (1-momentum) \timesx_{cur} + momentum \times x_{batch}$$
如果track_running_stats=False,則在內部不會進行計算全域性均值running_mean和方差running_var。
下面直接上例子來看看:
m = nn.BatchNorm1d(5, affine=False, momentum=0.1)
tensor = torch.FloatTensor([i for i in range(20)]).reshape(4,5)
print(tensor)
output = m(tensor)
print(output)
print(m.running_mean)
print(m.running_var)
結果如下:
### tensor的取值
tensor([[ 0., 1., 2., 3., 4.], [ 5., 6., 7., 8., 9.], [10., 11., 12., 13., 14.], [15., 16., 17., 18., 19.]])
### BN之後的結果 tensor([[-1.3416, -1.3416, -1.3416, -1.3416, -1.3416], [-0.4472, -0.4472, -0.4472, -0.4472, -0.4472], [ 0.4472, 0.4472, 0.4472, 0.4472, 0.4472], [ 1.3416, 1.3416, 1.3416, 1.3416, 1.3416]])
### 全域性均值(由於momentum=0.1,running_mean初始值為0,所以這是根據公式計算過後的結果) tensor([0.7500, 0.8500, 0.9500, 1.0500, 1.1500])
### 全域性方差 tensor([5.0667, 5.0667, 5.0667, 5.0667, 5.0667])
我們再來看看,輸入的tensor是三維的情況:
m = nn.BatchNorm1d(2, affine=False, momentum=0.1)
tensor = torch.FloatTensor([i for i in range(18)]).reshape(3,2,3)
print(tensor)
output = m(tensor)
print(output)
print(m.running_mean)
print(m.running_var)
結果如下:
### tensor的取值
tensor([[[ 0., 1., 2.], [ 3., 4., 5.]], [[ 6., 7., 8.], [ 9., 10., 11.]], [[12., 13., 14.], [15., 16., 17.]]])
### BN之後的取值 tensor([[[-1.4094e+00, -1.2081e+00, -1.0067e+00], [-1.4094e+00, -1.2081e+00, -1.0067e+00]], [[-2.0135e-01, -2.9802e-08, 2.0135e-01], [-2.0135e-01, 5.9605e-08, 2.0135e-01]], [[ 1.0067e+00, 1.2081e+00, 1.4094e+00], [ 1.0067e+00, 1.2081e+00, 1.4094e+00]]])
### 全域性均值和方差 tensor([0.7000, 1.0000]) tensor([3.6750, 3.6750])
大概檢驗一下,根據計算公式,第一個均值應該是每個樣本通道1的所有特徵求和得到的均值也就是(0+1+2+6+7+8+12+13+14)/ 9 = 7 (全域性均值是乘以了momentum=0.1的結果),方差應該是:$[(0-7)^2+.....+(14-7)^2]/9=24.66$(注意這裡是有偏樣本方差分母是N)。[0,0,0]位置這個數BN之後為:$(0-7)/ \sqrt{24.66}=-1.4094$,有些同學可能或說為什麼全域性方差不對呢?因為全域性方差中計算的是無偏的樣本方差(分母是N-1),並且初始值running_var=1.
BatchNorm2d:
torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
引數:
num_features:輸入的tensor是(N,C,H,W)維度,num_features定義為C。其中N表示batch_size
為什麼會出現BatchNorm2d呢?那是因為1d只能處理(N,L)和(N,C,L)兩種輸入,但是在CV領域特徵圖常常是(N,C,H,W),這時候1d就沒法處理了,就需要2d了,其實有了1d的瞭解,2d就是把剩下兩個維度的數全部搞在一起進行計算均值和方差。具體點的例子就是輸入tensor為(5,3,100,120),那我們定義的num_featurs=3那麼我們將提取[0,0,:,:],[1,0,:,:],[2,0,:,:]...[4,0,:,:]這幾個矩陣的所有數(一共有5*100*120個數字)計算均值和方差,因此我們可以知道我們最後得到的依然是3組均值和方差。
其他引數和1d是一樣的,作用也是一樣的。我們接下來看下例子。
m = nn.BatchNorm2d(3, affine=False, momentum=0.1) tensor = torch.FloatTensor([i for i in range(36)]).reshape(3,3,2,2) print(tensor) output = m(tensor) print(output) print(m.running_mean) print(m.running_var)
結果如下:
tensor([[[[ 0., 1.], [ 2., 3.]], [[ 4., 5.], [ 6., 7.]], [[ 8., 9.], [10., 11.]]], [[[12., 13.], [14., 15.]], [[16., 17.], [18., 19.]], [[20., 21.], [22., 23.]]], [[[24., 25.], [26., 27.]], [[28., 29.], [30., 31.]], [[32., 33.], [34., 35.]]]]) tensor([[[[-1.3690, -1.2676], [-1.1661, -1.0647]], [[-1.3690, -1.2676], [-1.1661, -1.0647]], [[-1.3690, -1.2676], [-1.1661, -1.0647]]], [[[-0.1521, -0.0507], [ 0.0507, 0.1521]], [[-0.1521, -0.0507], [ 0.0507, 0.1521]], [[-0.1521, -0.0507], [ 0.0507, 0.1521]]], [[[ 1.0647, 1.1661], [ 1.2676, 1.3690]], [[ 1.0647, 1.1661], [ 1.2676, 1.3690]], [[ 1.0647, 1.1661], [ 1.2676, 1.3690]]]])
## 全域性均值和方差,計算方式看前文 tensor([1.3500, 1.7500, 2.1500]) tensor([11.5091, 11.5091, 11.5091])
我們就簡單計算下前兩個均值就好了,第一個均值:(0+1+2+3+12+13+14+15+24+25+26+27)/ 12 = 13.5,第二個均值:(4+5+6+7+16+17+18+19+28+29+30+31)/ 12 = 17.5大家可以看看我就計算了那些數字。
BatchNorm3d:
torch.nn.BatchNorm3d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
有了前面的知識,其實我們也能猜到3d處理的就是輸入tensor為(N,C,D,H,W)的情況,此時num_features應該定義C。給一個例子來說就是進行把最後三個維度的數字按照樣本個數全部加起來,數字總數為N*D*H*W,最後依然會產生C個均值和方差。
二、LayerNorm(層標準化):
torch.nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None)
引數看起來和BatchNorm差不多,但是LayerNorm不會記錄全域性的均值和方差。最重要的就是前三個引數。
normalized_shape:可以設定為:int,列表,或者torch.Size([3, 4])
eps:對輸入資料進行歸一化時加在分母上,防止除零。
elementwise_affine:是否進行仿射變換,如果是True則此模組具有可學習的仿射引數weight和bias,使得能夠使得資料不僅僅是服從N(0,1)正態分佈。
normalized_shape傳入整數
如果normalized_shape傳入的是整數,那麼會在輸入tensor的最後一維一定要和這個整數一樣,比如normalized_shape=4,則tensor的最後一個維度一定要為4,而進行的正則化就是最最後一位的數字進行的,更具體的例子,輸入tensor維度為(3,4),那麼會對[0,0],[0,1],[0,2],[0,3]位置的加一起求均值和方差(一共只有4個數),那麼就會出現3個均值和方差。看例子:
m = nn.LayerNorm(4, elementwise_affine=True) tensor = torch.FloatTensor([i for i in range(12)]).reshape(3,4) print(tensor) output = m(tensor) print(output) print(m.weight) print(m.bias)
結果:
tensor([[ 0., 1., 2., 3.], [ 4., 5., 6., 7.], [ 8., 9., 10., 11.]]) tensor([[-1.3416, -0.4472, 0.4472, 1.3416], [-1.3416, -0.4472, 0.4472, 1.3416], [-1.3416, -0.4472, 0.4472, 1.3416]], grad_fn=<NativeLayerNormBackward0>) Parameter containing: tensor([1., 1., 1., 1.], requires_grad=True) Parameter containing: tensor([0., 0., 0., 0.], requires_grad=True)
大家可以自行計算下[0,1,2,3]這四個數正則化之後的結果。是不是就是圖中所示。因此normalized_shape傳入的是整數還是比較好理解的。
normalized_shape傳入列表
如果normalized_shape傳入的是列表,比如[3,4],那麼需要要求傳入的tensor需要最後兩個維度需要滿足[3, 4],會把最後兩個維度以用12個數據進行求均值和方差並正則化。具體一點的例子,傳入的tensor維度為(N,C,3,4)那麼會對【0,0,:,:】這12個數進行正則化,【0,1,:,:】這12個數進行正則化.....因此最後得到會得到N*C個均值和方差。看例子。
m = nn.LayerNorm([3,4], elementwise_affine=True) tensor = torch.FloatTensor([i for i in range(12*4)]).reshape(2,2,3,4) print(tensor) output = m(tensor) print(output) print(m.weight) print(m.bias)
結果如下:
tensor([[[[ 0., 1., 2., 3.], [ 4., 5., 6., 7.], [ 8., 9., 10., 11.]], [[12., 13., 14., 15.], [16., 17., 18., 19.], [20., 21., 22., 23.]]], [[[24., 25., 26., 27.], [28., 29., 30., 31.], [32., 33., 34., 35.]], [[36., 37., 38., 39.], [40., 41., 42., 43.], [44., 45., 46., 47.]]]]) tensor([[[[-1.5933, -1.3036, -1.0139, -0.7242], [-0.4345, -0.1448, 0.1448, 0.4345], [ 0.7242, 1.0139, 1.3036, 1.5933]], [[-1.5933, -1.3036, -1.0139, -0.7242], [-0.4345, -0.1448, 0.1448, 0.4345], [ 0.7242, 1.0139, 1.3036, 1.5933]]], [[[-1.5933, -1.3036, -1.0139, -0.7242], [-0.4345, -0.1448, 0.1448, 0.4345], [ 0.7242, 1.0139, 1.3036, 1.5933]], [[-1.5933, -1.3036, -1.0139, -0.7242], [-0.4345, -0.1448, 0.1448, 0.4345], [ 0.7242, 1.0139, 1.3036, 1.5933]]]], grad_fn=<NativeLayerNormBackward0>) Parameter containing: tensor([[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]], requires_grad=True) Parameter containing: tensor([[0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]], requires_grad=True)
大家也可以自行計算下[0, 1, 2, 3,4,5...11]這12個數字最後正則化之後的結果。
這裡要注意的是weight和bias在訓練過程中是會更新的,並且會在一次正則化中使用多次,比如上面(2,2,3,4)的例子,(0,0,:,:)會使用weight和bias對應位置的數字,(0,1,:,:)也會使用對應位置數字。
參考網頁:
BatchNorm2d — PyTorch 1.10 documentation
BatchNorm1d — PyTorch 1.10 documentation
BatchNorm3d — PyTorch 1.10 documentation
LayerNorm — PyTorch 1.10 documentation
pytorch LayerNorm引數詳解,計算過程_拿鐵大俠的部落格-CSDN部落格_nn.layernorm使用
pytorch BatchNorm引數詳解,計算過程_拿鐵大俠的部落格-CSDN部落格_batchnorm 引數
【PyTorch】詳解pytorch中nn模組的BatchNorm2d()函式_安靜-CSDN部落格_nn.batchnorm2d
以上內容如有錯誤,懇請指正