1. 程式人生 > >Spatial Transformer Network

Spatial Transformer Network

求導 ram ret als 每一個 mage 部分 設置 row

https://blog.csdn.net/yaoqi_isee/article/details/72784881

Abstract:

作者說明了CNN對於輸入的數據缺乏空間變換不變形(lack of spatially invariant ability to input data),因此作者引入了一個spatial transformer module,不需要額外的監督,能夠以data-driven的方式學習得到輸入圖像的空間變換參數,賦予網絡spatial invariant能力。

Introduction:

普通的神經網絡通過max-pooling實現了一定程度上的translation invariance,但是這種不變形是通過網絡的max-pooling層的堆疊實現的,對於網絡內部的feature map來說,輸入的圖像如果進行了大範圍的(平移)變換,feature map還是無法做到invariance(因為每一個max-pooling就只是2x2大小的模塊,只能保證在2x2大小範圍內的微小的變換,輸出是不變的,通過堆疊這些2x2的池化單元,才能實現對大尺度平移變換的不變性)。

這篇文章中,作者提出了一個spatial transformer module(記為ST模塊),這個模塊對於任意輸入的圖像或者feature map,產生一個對應的spatial transform的參數,然後根據這個參數將原來的圖像或者feature map做一個全局(而非局部)的空間變換,得到最終的canonical pose(也就是正正方方的圖,比如原來物體是斜的,通過ST模塊之後變成正的了)。

Spatial Transformer:
ST模塊可以分成三個部分:localization network根據輸入的feature map回歸spatial transform的參數 θθ,然後用這個參數去生成一個采樣的grid,最後根據這個grid以及輸入的feature map得到輸出的經過空間變換的feature map,如下圖所示
技術分享圖片

Localization network
localization的網絡輸入一張feature map URH×W×CU∈RH×W×C,輸出 θ=floc(U)θ=floc(U), θθ 的size取決於我們預先定義的空間變換的類型,比如仿射變換的話,大小就是6維。

Parameterized Sampling Grid
有了空間變換的參數之後,我們就可以知道輸出的feature map上的每一個點在輸入的feature map上的位置了。比如說對於二維的仿射變換,我們可以建立輸出feature map上的坐標和輸入feature map上坐標之間的映射關系:
技術分享圖片
其中 (xti,yti)

(xit,yit) 表示輸出的feature map上的坐標,(xsi,ysi)(xis,yis) 表示輸出feature map上坐標對應在輸入feature map上的采樣點坐標。
當我們把上面的參數特殊化之後,其實就可以model其他的變換,比如attention,crop,translation,以attention為例,參數為
技術分享圖片

Differentiable Image Sampling
知道了輸出feature map在輸入feature map上的采樣點坐標之後,接下來就是要根據采樣點的值確定輸出目標點的值了。這裏一般會用到kernel,以采樣點為中心的kernel範圍內的點對輸出目標點的值都有貢獻。
技術分享圖片
上式中,V表示輸出特征圖,i表示特征圖的下標,c表示第c個channel,所有的空間變換對於各個channel都是一樣的。H,WH′,W′ 表示輸出的特征圖的長寬。U表示輸入的特征圖,k表示預定義的kernel,xsi,ysixis,yis 表示采樣點坐標,ΦΦ 表示kernel的參數。

一般常用的kernel為雙線性插值kernel(根據輸出feature map上規律的坐標值計算輸入的feature map上采樣點的坐標通常得到的坐標值是小數,所以可以用雙線性插值計算采樣點處的feature值),這個時候,上式就退化成
技術分享圖片
求導的話也很方便
技術分享圖片
根據采樣點的坐標值可以繼續對參數 θθ 求導,從而更新localization網絡的參數。

Experiment

1.Distorted MNIST
作者首先在Distorted MNIST數據集上進行實驗,主要存在以下幾種空間變換:R(旋轉)、RTS(旋轉平移尺度變換)、P(投影變換)、E(彈性變換)
作者設置了兩個baseline:fully-connected NN(FCN)以及convolutional NN(CNN)。
實驗組加入了ST模塊,分別是Aff(仿射變換)、Proj(投影變換)以及TPS(plate spline transformation)
實驗結果如下所示:
技術分享圖片
表格中的數字表示不同的模型在不同的distorted mnist數據集上的錯誤率。可以看到加了ST模塊的模型相比沒有加ST的對照模型,錯誤率降低了。
右圖中第一欄表示輸入的distorted的圖像,(b)欄表示根據 θθ 得到的sampling grid,(c)欄表示spatial transformer的輸出。

3.Fine-Grained Classification
通過在一個網絡裏面加入多個ST模塊,可以提高網絡model各種空間變換的能力。作者在CUB數據集(the birds appear at a range of scales and orientations, are not tightly cropped)上進行了實驗。
作者以state-of-art作為baseline
作者自己的網絡采用了2個或者4個ST模塊,模型如下圖所示,ST模塊用的是attention機制的:
技術分享圖片
通過locoliation網絡預測兩個 θθ,然後根據這兩個 θθ 得到兩個sampling的結果,分別取提取特征做分類。
技術分享圖片
上圖表示CUB上的結果。可以看到2ST-CNN中一個集中在頭部一個集中在身體。

--------------------- 本文來自 yj_isee 的CSDN 博客 ,全文地址請點擊:https://blog.csdn.net/yaoqi_isee/article/details/72784881?utm_source=copy

Spatial Transformer Network