1. 程式人生 > >空間變換網路--spatial transform network

空間變換網路--spatial transform network

CNN分類時,通常需要考慮輸入樣本的區域性性、平移不變性、縮小不變性,旋轉不變性等,以提高分類的準確度。這些不變性的本質就是影象處理的經典方法,即影象的裁剪、平移、縮放、旋轉,而這些方法實際上就是對影象進行空間座標變換,我們所熟悉的一種空間變換就是仿射變換,影象的仿射變換公式可以表示如下:
這裡寫圖片描述

上式中,實現影象的放大、縮小、平移、旋轉等。

那麼,神經網路是否有辦法,用一種統一的結構,自適應實現這些變換呢?本文提出了一種叫做空間變換網路(Spatial Transform Networks, STN)的網路模型,該網路不需要關鍵點的標定,能夠根據分類或者其它任務自適應地將資料進行空間變換和對齊(包括平移、縮放、旋轉以及其它幾何變換等)。在輸入資料空間差異較大的情況下,這個網路可以加在現有的卷積網路中,提高分類的準確性。

本文所提的空間變換網路的主要作用在於

1.可以將輸入變換為網路的下一層期望的形式;

2.可以在訓練的過程中自動選擇感興趣的區域特徵;

3.可以實現對各種形變的資料進行空間變換;

這裡寫圖片描述

例如對於上圖中輸入手寫字型,我們感興趣的是黃色框中的包含數字的區域,那麼在訓練的過程中,學習到的空間變換網路會自動提取黃色框中的區域性資料特徵,並對框內的資料進行空間變換,得到輸出output。

2. 空間變換網路原理詳解

所謂空間變換網路,實際上是在神經網路的某兩層之間引入一個空間變換網路,該空間變換網路包括兩個部分,網路結構如下圖所示:

這裡寫圖片描述

第一部分為為”localization net”,其用於生成仿射變換系絡結構進行設計,”localization net”網路中的引數則為空間變換網路需要訓練的引數;第二部分就是空間變換即仿射變換。通過該區域性網路產生仿射變換系數θ

θ後對U進行仿射變換的原理。

對於仿射變換,如果直接由仿射變換系數θθ大小的座標位置點矩陣,接下來就是要對該座標位置點進行仿射變換,仿射變換公式及示意圖如下:

這裡寫圖片描述

需要注意的是,這時候的仿射係數為θθ的逆矩陣,經過仿射變換後可以得到V中的位置座標點在U中對應的位置。但是,這時候求得的U中的座標點仍然可能為非整數,因此通常需要進行插值得到對應的座標點。得到U中的座標點後,則可以將其複製到V中,從而得到仿射變換結果V。具體仿射變換過程,也可以結合下圖進行理解:

這裡寫圖片描述

上圖中由Localisation Net生成仿射變換系數θθ後,仿射變換的過程是依次執行步驟1,2,3,4。

3. 空間變換網路的實際應用

以上講解的是空間變換網路的理解,那麼在實際應用中,我們該如何新增空間變換網路到我們自己的網路中呢?接下來重點講解空間變換網路的應用。

 3.1.空間變換網路作為網路的第一層

空間變換網路可以直接作為網路的第一層,即Localisation Net的輸入為input,從而直接對輸入進行仿射變換,對於Localisation Net的設計,可以根據輸入input的大小設計Localisation Net為全連線層或卷積層,例如對於手寫字型,輸入圖片大小為40x40,即input=[batch_size,1600],那麼我們可以設計Localisation Net包含兩個全連線層,第一個全連線層w1=[1600,20],b1=[20],第一個全連線層w2=[20,6],b2=[6],則第二個全連線層的輸出為[batch_size,6],即為仿射變換系數。若輸入Localisation Net的input尺寸較大,則需要在Localisation Net中添加捲積和池化層,最後再輸入到全連線層,得到仿射變換系數;關於全連線層和卷積層的設計詳見參考程式碼4.1,4.2。

 3.2.空間變換網路插入CNN的中間層

空間變換網路還可以新增在CNN的中間層,可以直接將空間變換網路插入conv或者max-pooling層的前面或者後面。此外,還可以在CNN的同一層插入多個空間變換網路,下面給出空間變換網路插入CNN的手寫字型網路結構圖:

這裡寫圖片描述

上圖中第一個空間變換網路ST1作用於輸入影象,直接對輸入影象進行空間變換,第二、三個空間變換網路ST2a,ST2b作用於conv1,用於對第一層的卷積特徵進行空間變換,而ST3用於對更深層的卷積特徵進行空間變換。

由於空間變換網路能夠自動提取區域性區域特徵,因此在網路的同一層插入父哦個空間變換網路可以提取多個區域性區域特徵,從而可以結合多個區域性區域特徵進行分類,如下如的網路是實現兩張輸入的圖片中的手寫字型相加,在網路的第一層插入兩層空間變換網路ST1,ST2,並將其直接作用語輸入影象。圖中第三列為空間變換結果,有圖可知,網路ST1,ST2分別提取了輸入手寫字型的不同區域的特徵。

這裡寫圖片描述

4. 程式碼分析

首先看一仿射變換的程式碼實現,程式碼的實現如上所述,首先由函式_meshgrid生成輸出V的座標位置點grid,在通過仿射變換系數theta對grid進行仿射變換得到U中對於位置座標點T_g,之後對T_g進行雙線性插值,並複製插值後的U中的座標點的畫素值到V中,得到輸出V。具體程式碼實現如下:


    def transform(theta, input_dim, out_size):
        with tf.variable_scope('_transform'):
            num_batch = tf.shape(input_dim)[0]
            height = tf.shape(input_dim)[1]
            width = tf.shape(input_dim)[2]
            num_channels = tf.shape(input_dim)[3]
            theta = tf.reshape(theta, (-1, 2, 3))
            theta = tf.cast(theta, 'float32')

            # grid of (x_t, y_t, 1), eq (1) in ref [1]
            height_f = tf.cast(height, 'float32')
            width_f = tf.cast(width, 'float32')
            out_height = out_size[0]
            out_width = out_size[1]
            grid = _meshgrid(out_height, out_width)
            grid = tf.expand_dims(grid, 0)
            grid = tf.reshape(grid, [-1])
            grid = tf.tile(grid, tf.pack([num_batch]))
            grid = tf.reshape(grid, tf.pack([num_batch, 3, -1]))#得到輸出座標位置點

            # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
            T_g = tf.batch_matmul(theta, grid)#仿射變換
            x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])#
            y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
            x_s_flat = tf.reshape(x_s, [-1])
            y_s_flat = tf.reshape(y_s, [-1])

            input_transformed = _interpolate(
                input_dim, x_s_flat, y_s_flat,
                out_size)#插值,並得到輸出

            output = tf.reshape(
                input_transformed, tf.pack([num_batch, out_height, out_width, num_channels]))
            return output

接下來結合兩個具體的例項分別講解”localization net”的為全連線層和卷積層的設計。

(4.1)”localization net”的為全連線層的例項:

該示例中,空間變換網路用於對輸入影象進行變換處理,”localization net”包括兩個全連線層,具體網路設計如下:



    x = tf.placeholder(tf.float32, [None, 1600])#輸入
    y = tf.placeholder(tf.float32, [None, 10])


    x_tensor = tf.reshape(x, [-1, 40, 40, 1])


    W_fc_loc1 = weight_variable([1600, 20])#第一個全連線層
    b_fc_loc1 = bias_variable([20])

    W_fc_loc2 = weight_variable([20, 6])#第二個全連線層

    initial = np.array([[1., 0, 0], [0, 1., 0]])
    initial = initial.astype('float32')
    initial = initial.flatten()
    b_fc_loc2 = tf.Variable(initial_value=initial, name='b_fc_loc2')


    h_fc_loc1 = tf.nn.tanh(tf.matmul(x, W_fc_loc1) + b_fc_loc1)

    keep_prob = tf.placeholder(tf.float32)
    h_fc_loc1_drop = tf.nn.dropout(h_fc_loc1, keep_prob)

    h_fc_loc2 = tf.nn.tanh(tf.matmul(h_fc_loc1_drop, W_fc_loc2) + b_fc_loc2)#仿射變換系數theta,大小為bath_size*6

    out_size = (40, 40)
    h_trans = transformer(x_tensor, h_fc_loc2, out_size)

(4.2)”localization net”的為卷積層的例項:

如以下程式碼所示,該”localization net”層包括兩個卷積層,幷包括兩個全連線層,得到仿射變換系數theta。


from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D

#localization net,得到仿射變換系數theta

locnet = Sequential()
locnet.add(MaxPooling2D(pool_size=(2,2), input_shape=input_shape))#
locnet.add(Convolution2D(20, 5, 5))
locnet.add(MaxPooling2D(pool_size=(2,2)))
locnet.add(Convolution2D(20, 5, 5))

locnet.add(Flatten())
locnet.add(Dense(50))
locnet.add(Activation('relu'))
locnet.add(Dense(6, weights=weights))#輸出仿射變換系數theta
#locnet.add(Activation('sigmoid'))

#build the model

model = Sequential()
model.add(SpatialTransformer(localization_net=locnet,
                             downsample_factor=3, input_shape=input_shape))#仿射變換

model.add(Convolution2D(32, 3, 3, border_mode='same'))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Convolution2D(32, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(256))
model.add(Activation('relu'))

model.add(Dense(nb_classes))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam')