1. 程式人生 > >深度神經網路優化策略之——殘差學習

深度神經網路優化策略之——殘差學習

問題起源

  深度學習普遍認為發端於2006年,根據Bengio的定義,深層網路由多層自適應非線性單元組成——即多層非線性模組的級聯,所有層次上都包含可訓練的引數,在工程實際操作中,深層神經網路通常是五層及以上,包含數百萬個可學習的自由引數的龐然大物。理論上,網路模型無論深淺與否,都能通過函式逼近資料的內在關係和本質特徵,但在解決真實世界的複雜問題時,需要指數增長的計算單元,淺層網路往往出現函式表達能力不足,而深層網路則可能僅僅需要較少的計算單元。
  不過網路並不是像理論上那樣越深越好,除了顯而易見的因為層數過多而導致浪費性質的佔用視訊記憶體和“吃”計算力的問題,還會出現以下三種問題。
- 過擬合 (over fit)
- 梯度彌散 (vanishing gradient problem)
- 網路退化

(degenerate)

  其中,問題一、二並不是本文所講的殘差學習主要要解決的問題,所以就不多贅述,只講述網路退化的問題。其現象如下圖所示,是隨著網路層數的增多,整體模型的表達能力增強,但是訓練精度反而變差,並且因為訓練精度本身也下降的緣故,故而可以排除是過擬合的原因,而確定是網路退化。

When deeper networks are able to start converging, a degradation problem has been exposed: with the network depth increasing, accuracy gets saturated which might be unsurprising and then degrades rapidly. Unexpectedly, such degradation is not caused by overfitting, and adding more layers to a suitably deep model leads to higher training error,as reported in and thoroughly verified by our experiments.
image

  但是,很可惜的是,業界依然對於網路退化的原因及其標準情況仍然沒有定論,甚至說出現了隨著網路變深而效果變差的問題的時候,也有可能無法分辨出是梯度彌散還是網路退化的問題。讀者如果有興趣,可以自行去尋找網路退化方面的研究論文,各家的觀點雖然都不盡相同,但我們還是可以發現不少有用的資訊。

殘差學習

而對於上述問題,Kaiming He大神提出了一種簡潔而不失優雅的殘差學習的方法。多的不談,我們直接甩出模型結構來講解殘差學習的思想。

image

  首先,只看圖的左半邊,也就是橘紅色的部分。左側與普通網路連線方式的區別一目瞭然——在順次直連而下的基礎上加入了每隔兩層的跨接橋(其實官方的叫法並非如此,然而這麼叫它顯得更加直觀)。不過純憑看圖的感覺畢竟流於表面,用數學說話才是嚴謹的態度。

對於一個神經網路而言,我們需要用反向傳播來更新引數,就像這樣:

image

image

此時,第二個式子所得的結果就是我們常說的梯度。

而當如下圖網路越來越深的時候:

image

……

image

image

這時候再通過算偏導求梯度,就會是這樣:

image

其實數列的每一項都很小,再依此相乘就會越來越小,最後趨近於0,舉個簡單的例子就是0.9雖然很接近於1,但當有n個0.9相乘時(n趨近於無限大),最後的結果就會無限趨近於0。

而當有了“跨接橋”之後,我們再算偏導的時候就會變成這樣:

image

說白了就是1.01的n次方依然大於1。

最後,我們可以發現對於相同的資料集來講,殘差網路比同等深度的其他網路表現出了更好的效能。

image

image

  不過,這是大神的測試結果,沒有什麼說服力,而我在自己的專案裡做了一組關於有無殘差學習的對比,下面是資料圖(專案是和影象增強有關,所以用PSNR作為評判標準):

image

image

  最後可見,Loss的下降趨勢,殘差學習的方法明顯更加平穩,而最後結果Loss和PSNR雖然差距目測不大,但最後的圖片視覺效果卻千差萬別。

下一節我們會講模型結構圖的右半邊——同樣是Kaiming He大神的Skip Connection策略。