『論文筆記』Swin Transformer
阿新 • • 發佈:2021-07-26
https://zhuanlan.zhihu.com/p/361366090
目前transform的兩個非常嚴峻的問題
- 受限於影象的矩陣性質,一個能表達資訊的圖片往往至少需要幾百個畫素點,而建模這種幾百個長序列的資料恰恰是Transformer的天生缺陷;
- 目前的基於Transformer框架更多的是用來進行影象分類,對例項分割這種密集預測的場景Transformer並不擅長解決。
在Swin Transformer之前的ViT和iGPT,它們都使用了小尺寸的影象作為輸入,這種直接resize的策略無疑會損失很多資訊。與它們不同的是,Swin Transformer的輸入是影象的原始尺寸另外Swin Transformer使用的是CNN中最常用的層次的網路結構,在CNN中一個特別重要的一點是隨著網路層次的加深,節點的感受野也在不斷擴大,這個特徵在Swin Transformer中也是滿足的。Swin Transformer的這種層次結構,也賦予了它可以像FPN,U-Net等結構實現可以進行分割或者檢測的任務。
圖1:Swin Transformer和ViT的對比
圖2:Swin-T的網路結構
在圖2中,輸入影象之後是一個Patch Partition,再之後是一個Linear Embedding層,這兩個加在一起其實就是一個Patch Merging層(至少上面的原始碼中是這麼實現的)。這一部分的原始碼如下:
class PatchMerging(nn.Module): def __init__(self, in_channels, out_channels, downscaling_factor): super().__init__() self.downscaling_factor = downscaling_factor self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0) self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels) def forward(self, x): b, c, h, w = x.shape new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor x = self.patch_merge(x) # (1, 48, 3136) x = x.view(b, -1, new_h, new_w).permute(0, 2, 3, 1) # (1, 56, 56, 48) x = self.linear(x) # (1, 56, 56, 96) return x
Patch Merging的作用是對影象進行降取樣,類似於CNN中Pooling層。Patch Merging是主要是通過nn.Unfold
函式實現降取樣的,nn.Unfold
的功能是對影象進行滑窗,相當於卷積操作的第一步,因此它的引數包括視窗的大小和滑窗的步長。根據原始碼中給出的超參我們知道這一步降取樣的比例是
,因此經過nn.Unfold
之後會得到
個長度為的特徵向量,其中是輸入到這個stage的Feature Map的通道數,第一個stage的輸入是RGB影象,因此通道數為3,表示為式(1)。
接著的view
和permute
是將得到的向量序列還原到的二維矩陣,linear
是將長度是的特徵向量對映到out_channels
可以看出Patch Partition/Patch Merging起到的作用像是CNN中通過帶有步長的滑窗來降低解析度,再通過卷積來調整通道數。不同的是在CNN中最常使用的降取樣的最大池化或者平均池化往往會丟棄一些資訊,例如最大池化會丟棄一個視窗內的地響應值,而Patch Merging的策略並不會丟棄其它響應,但它的缺點是帶來運算量的增加。在一些需要提升模型容量的場景中,我們其實可以考慮使用Patch Merging來替代CNN中的池化。