【ARXIV2202】Visual Attention Network
【ARXIV2202】Visual Attention Network
一些想法
- 這個方法看起來非常簡單,有些像在MobileNet 中間加了一個 帶空洞的 depth-wise conv
- 論文題目說是提出了一個 attention 模組,但網路本質還是四階段的 transformer
- 沒有任何一個模組是新提出的,但組合起來是在 計算量 和 準確率 間取得了平衡
- 效能的提升,可能關鍵點還是在於四階段 transformer 網路的獨特結構和訓練策略
研究動機
作者指出 self-attention 存在三個不足:(1)將影象處理為一維序列,忽略了其二維結構。(2)很難處理高解析度影象。(2)它只捕捉了空間適應性,而忽略了通道適應性。因此,作者提出了一種新的大核注意力(LKA)模組,並進一步介紹了一種基於LKA的新的神經網路——視覺注意網路(VAN)。
方法介紹
1、Large Kernel Attention (LKA)
LKA 與 MobileNet 很相似,MobileNet將標準卷積解耦為兩部分,即 Depth-wise conv 和 Point-wise conv。作者將卷積分解為三個部分:大核 Depth-wise conv、大核帶空洞的 Depth-wise conv 和 Point-wise conv。這樣就有效地分解大的卷積核。
上圖中,彩色網格表示卷積核的位置,黃色網格表示中心點。從圖中可以看出,13×13卷積分解為5×5 DConv,5×5 dilated DConv,空洞步長為3 ,1×1點卷積。
如下所示,實際程式碼實現中為 5×5 的 DConv,7×7 的 dilated DConv,dilation=3,最後是1×1點卷積。
class AttentionModule(nn.Module): def __init__(self, dim): super().__init__() self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)#深度卷積 self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)#深度空洞卷積 self.conv1 = nn.Conv2d(dim, dim, 1)#逐點卷積 def forward(self, x): u = x.clone() attn = self.conv0(x) attn = self.conv_spatial(attn) attn = self.conv1(attn) return u * attn #注意力操作 class SpatialAttention(nn.Module): def __init__(self, d_model): super().__init__() self.proj_1 = nn.Conv2d(d_model, d_model, 1) self.activation = nn.GELU() self.spatial_gating_unit = AttentionModule(d_model) #注意力操作 self.proj_2 = nn.Conv2d(d_model, d_model, 1) def forward(self, x): shorcut = x.clone() x = self.proj_1(x) x = self.activation(x) x = self.spatial_gating_unit(x) #注意力操作 x = self.proj_2(x) x = x + shorcut #殘差連線 return x class Block(nn.Module): def __init__(self, dim, mlp_ratio=4., drop=0.,drop_path=0., act_layer=nn.GELU): super().__init__() self.norm1 = nn.BatchNorm2d(dim) self.attn = SpatialAttention(dim) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = nn.BatchNorm2d(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) layer_scale_init_value = 1e-2 self.layer_scale_1 = nn.Parameter( layer_scale_init_value * torch.ones((dim)), requires_grad=True) self.layer_scale_2 = nn.Parameter( layer_scale_init_value * torch.ones((dim)), requires_grad=True) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() def forward(self, x): x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))#drop_path分支中,每個batch有概率使樣本在self.attn或者mlp不會”執行“,會以0直接傳遞。 x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x))) return
網路結構上,作者還是使用了 swin 的結構分為四個階段:H/4×W/4、H/8×W/8、H/16×W/16和H/32×W/32。隨著解析度的降低,輸出通道的數量也在不斷增加。網路細節如下表所示,其中e.r. 表示FFN中的 expansion ratio。
每個階段的圖示如下:
實驗分析
1、影象分類
在影象分類任務上,VAN優於其他引數計算成本相似的CNN,ViTs 和 MLPs。作者在每個類別中選擇了一個具有代表性的網路進行討論。ConvNeXt[53]是一種特殊的CNN,它吸收了VIT的一些優點,如大的感受野(7×7卷積)和先進的訓練策略(300個epoch、資料增強等)。VAN和ConvNeXt[53]相比,VAN-base比CoNvNeXt-t多出0.7%,因為VAN具有更大的感受野和自適應能力。Swin-Transformer是一種著名的ViT變體,採用區域性注意力和移動視窗的方式。由於VAN對二維結構資訊非常友好,具有較大的感受野,並在通道維度上實現了自適應性,VAN-Base比Swin-T提高了1.5%。從結果中可以看出,在小模型上面VAN的表現更加出色。
2、消融實驗
DW-D-Conv提供了深度空洞卷積,這在捕獲LKA中的長程依賴性中發揮了作用。DW-Conv可以利用影象的區域性上下文資訊。注意力機制的引入可以看作是使網路實現了自適應特性。受益於此,VAN-Tiny實現了約1.1%的提升。1×1 Conv捕獲了通道維度中的關係。結合注意機制,引入了通道維度的自適應性,提高了0.8%,證明了通道維度自適應性的必要性。
3、視覺化
從視覺化的比較,可以看出VAN方法能夠更好的聚焦目標區域。尤其是當目標占影象比例較大時,效果更好,也說明VAN可以捕捉長距離依賴關係。