1. 程式人生 > >tensorflow下實現ResNet網路對資料集cifar-10的影象分類

tensorflow下實現ResNet網路對資料集cifar-10的影象分類

DenseNet傳送門:DenseNet

先來簡單講講ResNet的網路結構。ResNet的出現是為了解決深度網路中由於層數太多,導致的degradation problem(退化問題),作者在原論文中對比了較為“耿直”的深度卷積網路(例如以VGG為原型,不斷加深層數)在不同層數的訓練精度:


從圖中可見,18層的卷積結構反而要比34層的準確率要高,這就是所謂的degradation problem。而ResNet提出了一種比較新穎的解決方法,即identity mapping:在常用的卷積結構中加入一個shortcut connection(捷徑)。如下圖所示:


以上圖為例,對於一個input,對其做兩層卷積操作(包含BN,RELU),在輸出端加上卷積操作之前的數值。這裡得到的輸出作為後續卷積層的輸入。這樣一個操作便構成了Residual module。

其效果也很顯著,相比未加入residual結構的卷積網路,其準確率是隨著層數加深而增加的,如下圖所示:

其中作者在論文中還介紹了一些ResNet的結構變化,比如卷積核的改變:


以及相應的ResNet A(在shortcut connection的部分為了使維度一致,僅使用zero-padding,不加任何訓練引數), ResNet B(僅在需要使維度一致的情況下使用權值引數W計算Wx得到與輸出一致的維度), ResNet C(在每一個shortcut connection均使用權值引數W對shortcut connection的輸入進行訓練)。不過C版的引數量相較於B版較大,且並沒有較大的提升效果,並不推薦使用。

而在進行cifar10的資料集訓練的時候,我這裡使用的是A版,即不帶引數訓練的shortcut connection。網路的輸入是32x32的圖片,並且經過預處理對所有點減去平均值。第一層為一個3x3的卷積層。之後一共有6n個3x3的卷積層,其中對於為{32,16,8}的features maps分別為2n個卷積層。殘差網路最後接一個全域性的平均池化,以及一個10類的全連線。如果想修改成其他的網路結構(如resnet_56),只需要修改residual_blocks中的n引數即可。

其模型結構部分程式碼如下所示:


其程式碼下載地址:Architecture        其中包含一個之前寫好的inceptionv3 和一個正在寫的DenseNet.