1. 程式人生 > 其它 >『論文筆記』Swin Transformer

『論文筆記』Swin Transformer

https://zhuanlan.zhihu.com/p/361366090

目前transform的兩個非常嚴峻的問題

  1. 受限於影象的矩陣性質,一個能表達資訊的圖片往往至少需要幾百個畫素點,而建模這種幾百個長序列的資料恰恰是Transformer的天生缺陷;
  2. 目前的基於Transformer框架更多的是用來進行影象分類,對例項分割這種密集預測的場景Transformer並不擅長解決。

在Swin Transformer之前的ViT和iGPT,它們都使用了小尺寸的影象作為輸入,這種直接resize的策略無疑會損失很多資訊。與它們不同的是,Swin Transformer的輸入是影象的原始尺寸另外Swin Transformer使用的是CNN中最常用的層次的網路結構,在CNN中一個特別重要的一點是隨著網路層次的加深,節點的感受野也在不斷擴大,這個特徵在Swin Transformer中也是滿足的。Swin Transformer的這種層次結構,也賦予了它可以像FPN,U-Net等結構實現可以進行分割或者檢測的任務。


圖1:Swin Transformer和ViT的對比

圖2:Swin-T的網路結構

在圖2中,輸入影象之後是一個Patch Partition,再之後是一個Linear Embedding層,這兩個加在一起其實就是一個Patch Merging層(至少上面的原始碼中是這麼實現的)。這一部分的原始碼如下:

class PatchMerging(nn.Module):
    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
        self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)

    def forward(self, x):
        b, c, h, w = x.shape
        new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor
        x = self.patch_merge(x) # (1, 48, 3136)
        x = x.view(b, -1, new_h, new_w).permute(0, 2, 3, 1) # (1, 56, 56, 48)
        x = self.linear(x) # (1, 56, 56, 96)
        return x


Patch Merging的作用是對影象進行降取樣,類似於CNN中Pooling層。Patch Merging是主要是通過nn.Unfold函式實現降取樣的,nn.Unfold的功能是對影象進行滑窗,相當於卷積操作的第一步,因此它的引數包括視窗的大小和滑窗的步長。根據原始碼中給出的超參我們知道這一步降取樣的比例是
,因此經過nn.Unfold之後會得到
個長度為
的特徵向量,其中是輸入到這個stage的Feature Map的通道數,第一個stage的輸入是RGB影象,因此通道數為3,表示為式(1)。

接著的viewpermute是將得到的向量序列還原到的二維矩陣,linear是將長度是的特徵向量對映到out_channels

的長度,因此stage-1的Patch Merging的輸出向量維度是,對比原始碼的註釋,這裡省略了第一個batch為的維度。

可以看出Patch Partition/Patch Merging起到的作用像是CNN中通過帶有步長的滑窗來降低解析度,再通過卷積來調整通道數。不同的是在CNN中最常使用的降取樣的最大池化或者平均池化往往會丟棄一些資訊,例如最大池化會丟棄一個視窗內的地響應值,而Patch Merging的策略並不會丟棄其它響應,但它的缺點是帶來運算量的增加。在一些需要提升模型容量的場景中,我們其實可以考慮使用Patch Merging來替代CNN中的池化。