BoTNet:Bottleneck Transformers for Visual Recognition
【GiantPandaCV導語】基於Transformer的骨幹網路,同時使用卷積與自注意力機制來保持全域性性和區域性性。模型在ResNet最後三個BottleNeck中使用了MHSA替換3x3卷積。屬於早期的結合CNN+Transformer的工作。簡單來講Non-Local+Self Attention+BottleNeck = BoTNet
引言
本文的發展脈絡如下圖所示:
實際上沿著Transformer Block改進的方向進行的,與CNN架構也是相容的。具體結構如下圖所示:
兩者都遵循了BottleNeck的設計原則,可以有效降低計算量。
設計Transformer中self attention存在幾個挑戰:
- 圖片尺寸比較大,比如目標檢測中解析度在1024x1024
- 記憶體和計算量的佔用高,導致訓練開銷比較大。
本文設計如下:
- 使用卷積識別底層特徵的抽象資訊。
- 使用self attention處理通過卷積層得到的高層資訊。
這樣可以有效處理大解析度影象。
方法
BoTNet中MHSA模組如下圖所示:
上邊的這個MHSA Block是核心創新點,其與Transformer中的MHSA有所不同:
- 由於處理物件不是一維的,而是類似CNN模型,所以有非常多特性與此相關。
- 歸一化這裡並沒有使用Layer Norm而是採用的Batch Norm,與CNN一致。
- 非線性啟用,BoTNet使用了三個非線性啟用
- 左側content-position模組引入了二維的位置編碼,這是與Transformer中最大區別。
由於該模組是處理BxCHW的形式,所以難免讓人想起來Non Local 操作,這裡列出筆者以前繪製的一幅圖:
可以看出主要區別就是在於Content-postion模組引入的位置資訊。
BoTNet細節設計:
整體的設計和ResNet50幾乎一樣,唯一不同在於最後一個階段中三個BottleNeck使用了MHSA模組。具體這樣做的原因是Self attention需要消耗巨大的計算量,在模型最後加入時候feature map的size比較小,相對而言計算量比較小。
實驗
在目標檢測和分割領域效能對比
解析度改變對BoTNet幫助更大
消融實驗-相對位置編碼
BoTNet對ResNet系列模型的提升
BoTNet與更大的圖片適配
BoTNet與Non-Local Net的比較
與ImageNet上結果比較
模型放縮的影響
顯示卡香氣飄來,又是谷歌的騷操作,將EfficientNet方法放在BoTNet上:
可以看出與期望相符合,Transformer架構帶來的效能上限要高於CNN,雖然模型大小比較小的時候效能比較弱,但是模型量變大以後其效能就有了保證。
程式碼
核心模組:MHSA (由第三方進行實現)
class MHSA(nn.Module):
def __init__(self, n_dims, width=14, height=14, heads=4):
super(MHSA, self).__init__()
self.heads = heads
self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
self.rel_h = nn.Parameter(torch.randn([1, heads, n_dims // heads, 1, height]), requires_grad=True)
self.rel_w = nn.Parameter(torch.randn([1, heads, n_dims // heads, width, 1]), requires_grad=True)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
n_batch, C, width, height = x.size()
q = self.query(x).view(n_batch, self.heads, C // self.heads, -1)
k = self.key(x).view(n_batch, self.heads, C // self.heads, -1)
v = self.value(x).view(n_batch, self.heads, C // self.heads, -1)
content_content = torch.matmul(q.permute(0, 1, 3, 2), k)
content_position = (self.rel_h + self.rel_w).view(1, self.heads, C // self.heads, -1).permute(0, 1, 3, 2)
content_position = torch.matmul(content_position, q)
energy = content_content + content_position
attention = self.softmax(energy)
out = torch.matmul(v, attention.permute(0, 1, 3, 2))
out = out.view(n_batch, C, width, height)
return out
參考
https://arxiv.org/abs/2101.11605
https://zhuanlan.zhihu.com/p/347849929
https://github.com/leaderj1001/BottleneckTransformers/blob/main/model.py
跑不動ImageNet,想試試Vision Transformer的同學可以看看這個倉庫,
https://github.com/pprp/pytorch-cifar-model-zoo
在CIFAR10上測試:
python train.py --model 'botnet' --name "fast_training" --sched 'cosine' --epochs 100 --cutout True --lr 0.1 --bs 128 --nw 4
目前可以在100個epoch內達到驗證集91.1%的準確率。
程式碼改變世界