1. 程式人生 > 實用技巧 >Pytorch Batch Normalization 中 track_running_stats問題

Pytorch Batch Normalization 中 track_running_stats問題

Batch Normalization,批規範化

BatchNormalization(簡稱為BN)[2],中文翻譯成批規範化,是在深度學習中普遍使用的一種技術,通常用於解決多層神經網路中間層的協方差偏移(Internal Covariate Shift)問題,類似於網路輸入進行零均值化和方差歸一化的操作,不過是在中間層的輸入中操作而已,具體原理不累述了,見[2-4]的描述即可。

在BN操作中,最重要的無非是這四個式子:

注意到這裡的最後一步也稱之為仿射(affine),引入這一步的目的主要是設計一個通道,使得輸出output至少能夠回到輸入input的狀態(當γ=1,β
=0
時)使得BN的引入至少不至於降低模型的表現,這是深度網路設計的一個套路。

整個過程見流程圖,BN在輸入後插入,BN的輸出作為規範後的結果輸入的後層網路中。 好了,這裡我們記住了,在BN中,一共有這四個引數我們要考慮的:
  • γ,β:分別是仿射中的weightbias,在pytorch中用weightbias表示。
  • μBσB2:和上面的引數不同,這兩個是根據輸入的batch的統計特性計算的,嚴格來說不算是“學習”到的引數,不過對於整個計算是很重要的。在pytorch中,這兩個統計引數,用running_meanrunning_var
    表示[5],這裡的running指的就是當前的統計引數不一定只是由當前輸入的batch決定,還可能和歷史輸入的batch有關,詳情見以下的討論,特別是引數momentum那部分。

Update 2020/3/16:
因為BN層的考核,在工作面試中實在是太常見了,在本文順帶補充下BN層的引數的具體shape大小。
以圖片輸入作為例子,在pytorch中即是nn.BatchNorm2d(),我們實際中的BN層一般是對於通道進行的,舉個例子而言,我們現在的輸入特徵(可以視為之前討論的batch中的其中一個樣本的shape)為xRC×W×H\mathbf{x} \in \mathbb{R}^{C \times W \times H}

xRC×W×H(其中C是通道數,W是width,H是height),那麼我們的μBRC\mu_{\mathcal{B}} \in \mathbb{R}^{C}μBRC,而方差σ2BRC\sigma^{2}_{\mathcal{B}} \in \mathbb{R}^CσB2RC。而仿射中weight,γRC\mathrm{weight}, \gamma \in \mathbb{R}^{C}weight,γRC以及bias,βRC\mathrm{bias}, \beta \in \mathbb{R}^{C}bias,βRC。我們會發現,這些引數,無論是學習引數還是統計引數都會通道數有關,其實在pytorch中,通道數的另一個稱呼是num_features,也即是特徵數量,因為不同通道的特徵資訊通常很不相同,因此需要隔離開通道進行處理。

有些朋友可能會認為這裡的weight應該是一個張量,而不應該是一個向量,其實不是的,這裡的weight其實應該看成是對輸入特徵圖的每個通道得到的歸一化後的xˆ\hat{\mathbf{x}}x^進行尺度放縮的結果,因此對於一個通道數為CCC的輸入特徵圖,那麼每個通道都需要一個尺度放縮因子,同理,bias也是對於每個通道而言的。這裡切勿認為yiγxˆi+βy_i \leftarrow \gamma \hat{x}_i+\betayiγx^i+β這一步是一個全連線層,他其實只是一個尺度放縮而已。關於這些引數的形狀,其實可以直接從pytorch原始碼看出,這裡截取了_NormBase層的部分初始程式碼,便可一見端倪。



在Pytorch中使用

Pytorch中的BatchNorm的API主要有:

一般來說pytorch中的模型都是繼承nn.Module類的,都有一個屬性trainning指定是否是訓練狀態,訓練狀態與否將會影響到某些層的引數是否是固定的,比如BN層或者Dropout層。通常用model.train()指定當前模型model為訓練狀態,model.eval()指定當前模型為測試狀態。
同時,BN的API中有幾個引數需要比較關心的,一個是affine指定是否需要仿射,還有個是track_running_stats指定是否跟蹤當前batch的統計特性。容易出現問題也正好是這三個引數:trainningaffinetrack_running_stats

  • 其中的affine指定是否需要仿射,也就是是否需要上面算式的第四個,如果affine=False,則γ=1,β=0\gamma=1,\beta=0γ=1,β=0,並且不能學習被更新。一般都會設定成affine=True[10]
  • trainningtrack_running_statstrack_running_stats=True表示跟蹤整個訓練過程中的batch的統計特性,得到方差和均值,而不只是僅僅依賴與當前輸入的batch的統計特性。相反的,如果track_running_stats=False那麼就只是計算當前輸入的batch的統計特性中的均值和方差了。當在推理階段的時候,如果track_running_stats=False,此時如果batch_size比較小,那麼其統計特性就會和全域性統計特性有著較大偏差,可能導致糟糕的效果。

一般來說,trainningtrack_running_stats有四種組合[7]

  1. trainning=True,track_running_stats=True。這個是期望中的訓練階段的設定,此時BN將會跟蹤整個訓練過程中batch的統計特性。
  2. trainning=True,track_running_stats=False。此時BN只會計算當前輸入的訓練batch的統計特性,可能沒法很好地描述全域性的資料統計特性。
  3. trainning=False,track_running_stats=True。這個是期望中的測試階段的設定,此時BN會用之前訓練好的模型中的(假設已經儲存下了)running_meanrunning_var並且不會對其進行更新。一般來說,只需要設定model.eval()其中model中含有BN層,即可實現這個功能。[6,8]
  4. trainning=False,track_running_stats=False效果同(2),只不過是位於測試狀態,這個一般不採用,這個只是用測試輸入的batch的統計特性,容易造成統計特性的偏移,導致糟糕效果。

同時,我們要注意到,BN層中的running_meanrunning_var的更新是在forward()操作中進行的,而不是optimizer.step()中進行的,因此如果處於訓練狀態,就算你不進行手動step(),BN的統計特性也會變化的。如

這個時候要將model.eval()轉到測試階段,才能固定住running_meanrunning_var。有時候如果是先預訓練模型然後載入模型,重新跑測試的時候結果不同,有一點效能上的損失,這個時候十有八九是trainningtrack_running_stats設定的不對,這裡需要多注意。 [8]

假設一個場景,如下圖所示:

此時為了收斂容易控制,先預訓練好模型model_A,並且model_A內含有若干BN層,後續需要將model_A作為一個inference推理模型和model_B聯合訓練,此時就希望model_A中的BN的統計特性值running_meanrunning_var不會亂變化,因此就必須將model_A.eval()設定到測試模式,否則在trainning模式下,就算是不去更新該模型的引數,其BN都會改變的,這個將會導致和預期不同的結果。

Update 2020/3/17:
評論區的Oshrin朋友提出問題

作者您好,寫的很好,但是是否存在問題。即使將track_running_stats設定為False,如果momentum不為None的話,還是會用滑動平均來計算running_mean和running_var的,而非是僅僅使用本batch的資料情況。而且關於凍結bn層,有一些更好的方法。

這裡的momentum的作用,按照文件,這個引數是在對統計引數進行更新過程中,進行指數平滑使用的,比如統計引數的更新策略將會變成:

其中的更新後的統計引數xˆnew\hat{x}_{\mathrm{new}}x^new,是根據當前觀察xtx_txt和歷史觀察xˆ\hat{x}x^進行加權平均得到的(差分的加權平均相當於歷史序列的指數平滑),預設的momentum=0.1。然而跟蹤歷史資訊並且更新的這個行為是基於track_running_statstrue並且training=true的情況同時成立的時候,才會進行的,當在track_running_stats=true, training=false時(在預設的model.eval()情況下,即是之前談到的四種組合的第三個,既滿足這種情況),將不涉及到統計引數的指數滑動更新了。[12,13]

這裡引用一個不錯的BN層凍結的例子,如:[14]

總結來說,在某些情況下,即便整體的模型處於model.train()的狀態,但是某些BN層也可能需要按照需求設定為model_bn.eval()的狀態。

Update 2020.6.19:
評論區有個同學問了一個問題:

K.G.lee:想問博主,為什麼模型測試時的引數為trainning=False, track_running_stats=True啊??測試不是用訓練時的滑動平均值嗎?為什麼track_running_stats=True呢?為啥要跟蹤當前batch??

我感覺這個問題問得挺好的,我們需要去翻下原始碼[15],我們發現我們所有的BatchNorm層都有個共同的父類_BatchNorm,我們最需要關注的是return F.batch_norm()這一段,我們發現,其對training的判斷邏輯是

那麼,其實其在eval階段,這裡的track_running_stats並不能設定為False,原因很簡單,這樣會使得上面談到的training=True,導致最終的期望程式錯誤。至於設定了track_running_stats=True是不是會導致在eval階段跟蹤測試集的batch的統計引數呢?我覺得是不會的,我們追蹤會發現[16],整個流程的最後一步其實是呼叫了torch.batch_norm(),其是呼叫C++的底層函式,其引數列表可和track_running_stats一點關係都沒有,只是由training控制,因此當training=False時,其不會跟蹤統計引數的,只是會呼叫訓練集訓練得到的統計引數。(當然,時間有限,我也沒有繼續追到C++層次去看原始碼了)。

Reference

[1].用pytorch踩過的坑
[2]. Ioffe S, Szegedy C. Batch normalization: accelerating deep network training by reducing internal covariate shift[C]// International Conference on International Conference on Machine Learning. JMLR.org, 2015:448-456.
[3].<深度學習優化策略-1>Batch Normalization(BN)
[4].詳解深度學習中的Normalization,BN/LN/WN
[5]. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L23-L24
[6]. https://discuss.pytorch.org/t/what-is-the-running-mean-of-batchnorm-if-gradients-are-accumulated/18870
[7].BatchNorm2d增加的引數track_running_stats如何理解?
[8].Why track_running_stats is not set to False during eval
[9].How to train with frozen BatchNorm?
[10].Proper way of fixing batchnorm layers during training
[11].大白話《Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift》
[12]. https://discuss.pytorch.org/t/what-does-model-eval-do-for-batchnorm-layer/7146/2
[13]. https://zhuanlan.zhihu.com/p/65439075
[14]. https://github.com/NVIDIA/apex/issues/122
[15]. https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#BatchNorm2d
[16]. https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#batch_norm

來源:https://blog.csdn.net/LoseInVain/article/details/86476010

<wiz_tmp_tag id="wiz-table-range-border" contenteditable="false" style="display: none;">



來自為知筆記(Wiz)