1. 程式人生 > 其它 >【ARXIV2202】Visual Attention Network

【ARXIV2202】Visual Attention Network

【ARXIV2202】Visual Attention Network

論文地址:https://arxiv.org/abs/2202.09741

程式碼地址:https://github.com/Visual-Attention-Network

一些想法

  • 這個方法看起來非常簡單,有些像在MobileNet 中間加了一個 帶空洞的 depth-wise conv
  • 論文題目說是提出了一個 attention 模組,但網路本質還是四階段的 transformer
  • 沒有任何一個模組是新提出的,但組合起來是在 計算量 和 準確率 間取得了平衡
  • 效能的提升,可能關鍵點還是在於四階段 transformer 網路的獨特結構和訓練策略

研究動機

作者指出 self-attention 存在三個不足:(1)將影象處理為一維序列,忽略了其二維結構。(2)很難處理高解析度影象。(2)它只捕捉了空間適應性,而忽略了通道適應性。因此,作者提出了一種新的大核注意力(LKA)模組,並進一步介紹了一種基於LKA的新的神經網路——視覺注意網路(VAN)。

方法介紹

1、Large Kernel Attention (LKA)

LKA 與 MobileNet 很相似,MobileNet將標準卷積解耦為兩部分,即 Depth-wise conv 和 Point-wise conv。作者將卷積分解為三個部分:大核 Depth-wise conv、大核帶空洞的 Depth-wise conv 和 Point-wise conv。這樣就有效地分解大的卷積核。

上圖中,彩色網格表示卷積核的位置,黃色網格表示中心點。從圖中可以看出,13×13卷積分解為5×5 DConv,5×5 dilated DConv,空洞步長為3 ,1×1點卷積。

如下所示,實際程式碼實現中為 5×5 的 DConv,7×7 的 dilated DConv,dilation=3,最後是1×1點卷積。

class AttentionModule(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)#深度卷積
        self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)#深度空洞卷積
        self.conv1 = nn.Conv2d(dim, dim, 1)#逐點卷積


    def forward(self, x):
        u = x.clone()        
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)

        return u * attn   #注意力操作
     
class SpatialAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()

        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = AttentionModule(d_model)  #注意力操作
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)

    def forward(self, x):
        shorcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)  #注意力操作
        x = self.proj_2(x)
        x = x + shorcut   #殘差連線
        return x

class Block(nn.Module):
    def __init__(self, dim, mlp_ratio=4., drop=0.,drop_path=0., act_layer=nn.GELU):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(dim)
        self.attn = SpatialAttention(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = nn.BatchNorm2d(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        layer_scale_init_value = 1e-2            
        self.layer_scale_1 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))#drop_path分支中,每個batch有概率使樣本在self.attn或者mlp不會”執行“,會以0直接傳遞。
        x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
        return 

網路結構上,作者還是使用了 swin 的結構分為四個階段:H/4×W/4、H/8×W/8、H/16×W/16和H/32×W/32。隨著解析度的降低,輸出通道的數量也在不斷增加。網路細節如下表所示,其中e.r. 表示FFN中的 expansion ratio。

每個階段的圖示如下:

實驗分析

1、影象分類

在影象分類任務上,VAN優於其他引數計算成本相似的CNN,ViTs 和 MLPs。作者在每個類別中選擇了一個具有代表性的網路進行討論。ConvNeXt[53]是一種特殊的CNN,它吸收了VIT的一些優點,如大的感受野(7×7卷積)和先進的訓練策略(300個epoch、資料增強等)。VAN和ConvNeXt[53]相比,VAN-base比CoNvNeXt-t多出0.7%,因為VAN具有更大的感受野和自適應能力。Swin-Transformer是一種著名的ViT變體,採用區域性注意力和移動視窗的方式。由於VAN對二維結構資訊非常友好,具有較大的感受野,並在通道維度上實現了自適應性,VAN-Base比Swin-T提高了1.5%。從結果中可以看出,在小模型上面VAN的表現更加出色。

2、消融實驗

DW-D-Conv提供了深度空洞卷積,這在捕獲LKA中的長程依賴性中發揮了作用。DW-Conv可以利用影象的區域性上下文資訊。注意力機制的引入可以看作是使網路實現了自適應特性。受益於此,VAN-Tiny實現了約1.1%的提升。1×1 Conv捕獲了通道維度中的關係。結合注意機制,引入了通道維度的自適應性,提高了0.8%,證明了通道維度自適應性的必要性。

3、視覺化

從視覺化的比較,可以看出VAN方法能夠更好的聚焦目標區域。尤其是當目標占影象比例較大時,效果更好,也說明VAN可以捕捉長距離依賴關係。