1. 程式人生 > >Batch Normalization文章學習筆記

Batch Normalization文章學習筆記

BN學習筆記

Batch Normalization的提出

BN是谷歌提出的一種深度學習,網路優化的結構,能夠加速網路的訓練. 文章在提出方法之前,對之前的一些成果進行了回顧包括

  1. 深度學習網路在訓練過程中,訓練變慢的原因:隨著網路深度的加深,由於訓練過程的資訊前向傳遞的過程中,一旦前一層的Layer的資訊的distribution發生改變.對後面的分佈也隨之改變(文章中稱為Internal Covariate Shift),隨著深度的逐漸深入資訊的distrution也會發生改變,而這些又需要重新學習,這造成了深度學習訓練困難的難題(包括模型梯度爆炸或者消逝).原文也回顧了一些經驗上的解決方法: ReLu結構,權重初值的小心的初始化,網路模型分步訓練等等…

  2. 文章將問題歸結到Layer輸出資訊的Internal Covariate Shift以後,說了一些前人與之相關的類似的工作,例如對資料集進行白化,或者簡單的中心化都能夠加速網路的訓練使之達到比較好的訓練效果,但是如果要計算資料集的協方差矩陣,存在很大的計算量,包括求逆矩陣等等,這會帶來很大的計算複雜度,所以原文提出了一種簡化的方法對每層的的輸出進行尺度變換

    μ=1mi=1mxi\mu=\frac{1}{m}\sum_{i=1}^m{x_i} σ2=1mi=1m(xiμ)2\sigma^2=\frac{1}{m}\sum_{i=1}^m{(x_i-\mu)^2}

    x^=xμσ2+ϵ\hat{x}=\frac{x-\mu}{\sqrt{\sigma^2+\epsilon}} x^\hat{x}是變換後的一個尺度無關的變數,經過這一層以後變數被固定在了一個尺度無關的區域,然後對x^\hat{x}進行線性變換 y=γx^+βy=\gamma*\hat{x}+\beta 其中y,βy , \beta都是自動學習的變數,一旦訓練完成後,就屬於固定不變的數值了 在驗證集中,由於沒有batch的概念,所以網路之前的引數都是E(x)Var[x]E(x) \quad Var[x]
    使用之前的穩定的數值,這點要特別注意. E(x)=Eb(μ)E(x)=E_b(\mu) Var(σ2)=mm1E(σ2)Var(\sigma^2)=\frac{m}{m-1}E(\sigma^2)

  3. 最後對BN對梯度的改善做了一些說明BN不改變輸出相對於輸入的梯度,當梯度變大時會減小輸出相對於權值的梯度,增加網路的穩定性,防止模型崩潰 BN(Wu)u=BN(αWu)u\frac{\partial{BN(Wu)}}{\partial{u}}=\frac{\partial{BN(\alpha Wu)}}{\partial{u}} BN(Wu)u=1αBN(αWu)W\frac{\partial{BN(Wu)}}{\partial{u}}=\frac{1}{\alpha}\frac{\partial{BN(\alpha Wu)}}{\partial{W}}

總結:在知乎看到的帖子,對於BN的原理的揭示重點是放在了梯度上,BN最重要的是對梯度進行了優化,大概率防止了原來模型中因為BP鏈式求導帶來的累積乘法項帶來的梯度消逝和爆炸.個人感覺也傾向於這點. BN雖然有效,但是理論依據個人感覺也並沒有那麼充分,有些推理過程中的論斷更多的依賴於實踐經驗,而不是數學…