1. 程式人生 > >batch normalization 理解

batch normalization 理解

對batch normalization 一直屬於一知半解狀態,二面被問的一臉懵逼,所以決定好好理一理這個問題。

1、What is batch normalization?

batch normalization 其實就是對資料做一個批量的規範化操作,使得在深度神經網路訓練過程中使得每一層神經網路的輸入保持相同分佈的。

具體實現過程如下:

A、對於一個mini-batch,求資料的均值,方差。這裡其實是

                                            

B、得到歸一化之後的資料x,使得結果(輸出訊號各個維度)的均值為0,方差為1.

C、“scale and shift”操作則是為了讓因訓練所需而“刻意”加入的BN能夠有可能還原最初的輸入。

                                    

                                    

2、為什麼要做“scale and shift”?

一句話“模型的表達能力不下降”

我是這樣理解的,你的每次資料肯定是不同分佈的,你才能從學到東西嘛,如果你全都歸一化了,就相當於每次資料都變成一樣的了,那人家網路還怎麼學習。

第一步的規範化會將幾乎所有資料對映到啟用函式的非飽和區(線性區),僅利用到了線性變化能力,從而降低了神經網路的表達能力。

於是加了一個scale和shift,這兩個引數可以經過學習得到,意思是通過scale和shift把這個值從標準正態分佈左移或者由移一點並長胖一點或者變瘦一點,每個例項挪動的程度不一樣,這樣等價於非線性函式的值從正中心周圍的線性區往非線性區動了動。

核心思想應該是想找到一個線性和非線性的較好平衡點,既能享受非線性的較強表達能力的好處,又避免太靠非線性區兩頭使得網路收斂速度太慢。

3、Why batch normalization?

是用來解決“Internal Covariate Shift”(隱層中資料分佈不同)。

A、上層引數需要不斷適應新的輸入資料分佈,降低學習速度。

B、資料的分佈一直在發生變換,可能後逐漸像非線性啟用函式的飽和區域移動,如sigmod函式,可能導致後向傳播的時候淺層神經網路的梯度消失,收斂越來越慢。

C、每層的更新都會影響到其它層,因此每層的引數更新策略需要儘可能的謹慎。所以對引數的選擇特別重要。

4、為什麼有效?

A、Normalization 的資料伸縮不變性。

做了batch normalization之後,可以向成強行將資料拉回均值為0,方差為1的標準分佈,使資料回到梯度變化比較大的敏感區域,就可以避免梯度消失,加快訓練速度。

B、Normalization 的權重伸縮不變性。

大概的意思就是,加入BN之後,不管權重怎麼變換,對於梯度的反向傳播都是沒有影響的。所以可以有效的解決梯度消失和梯度爆炸。

C、同時,權重越大的更新時梯度越小。引數的變化就越穩定,相當於實現了引數正則化的效果,避免參數的大幅震盪,提高網路的泛化效能。

5、到底解決了什麼問題?

A、提升了訓練速度,收斂過程大大加快,還能增加分類效果。

B、類似於Dropout的一種防止過擬合的正則化表達方式,所以不用Dropout也能達到相當的效果。

C、另外調參過程也簡單多了,對於初始化要求沒那麼高,而且可以使用大的學習率等。

D、解決梯度消失和梯度爆炸。

6、Where to use BN?

在每一隱層的啟用函式之前,相當於先歸一化,再拿去啟用。

7、什麼時候不要用?

用計算一階二階統計量,均值和方差。不適用於動態網路結構和RNN。

8、什麼時候效果比較好?

每個mini-batch比較接近,和整體資料之間應該近似同分布。