1. 程式人生 > 其它 >Vision MLP 之 S2-MLP V1&V2 : Spatial-Shift MLP Architecture for Vision

Vision MLP 之 S2-MLP V1&V2 : Spatial-Shift MLP Architecture for Vision

Vision MLP 之 S2-MLP V1&V2 : Spatial-Shift MLP Architecture for Vision

原始文件:https://www.yuque.com/lart/papers/dgdu2b

這裡將會總結關於 S2-MLP 的兩篇文章。這兩篇文章核心思路是一樣的,即基於空間偏移操作替換空間 MLP。

從摘要理解文章

V1

Recently, visual Transformer (ViT) and its following works abandon the convolution and exploit the self-attention operation

, attaining a comparable or even higher accuracy than CNNs. More recently, MLP-Mixer abandons both the convolution and the self-attention operation, proposing an architecture containing only MLP layers.
To achieve cross-patch communications, it devises an additional token-mixing MLP besides the channel-mixing MLP. It achieves promising results when training on an extremely large-scale dataset. But it cannot achieve as outstanding performance as its CNN and ViT counterparts when training on medium-scale datasets such as ImageNet1K and ImageNet21K
. The performance drop of MLP-Mixer motivates us to rethink the token-mixing MLP.

這裡引出了本文的主要內容,即改進空間 MLP。

We discover that the token-mixing MLP is a variant of the depthwise convolution with a global reception field and spatial-specific configuration. But the global reception field and the spatial-specific property make token-mixing MLP prone to over-fitting

.

指出了空間 MLP 的問題,由於其全域性感受野和空間特定的屬性使得模型容易過擬合

In this paper, we propose a novel pure MLP architecture, spatial-shift MLP (S2-MLP). Different from MLP-Mixer, our S2-MLP only contains channel-mixing MLP.

這裡提到僅有通道 MLP,說明想到了新的辦法來擴張通道 MLP 的感受野還可以保留點運算。

We utilize a spatial-shift operation for communications between patches. It has a local reception field and is spatial-agnostic. It is parameter-free and efficient for computation.

引出本文的核心內容,也就是標題中提到的空間偏移操作。看上去這一操作不帶引數,僅僅是用來調整特徵的一個處理手段。
Spatial-Shift 操作可以參考這裡的幾篇文章:https://www.yuque.com/lart/architecture/conv#i8nnp

The proposed S2-MLP attains higher recognition accuracy than MLP-Mixer when training on ImageNet-1K dataset. Meanwhile, S2-MLP accomplishes as excellent performance as ViT on ImageNet-1K dataset with considerably simpler architecture and fewer FLOPs and parameters.

V2

Recently, MLP-based vision backbones emerge. MLP-based vision architectures with less inductive bias achieve competitive performance in image recognition compared with CNNs and vision Transformers. Among them, spatial-shift MLP (S2-MLP), adopting the straightforward spatial-shift operation, achieves better performance than the pioneering works including MLP-mixer and ResMLP. More recently, using smaller patches with a pyramid structure, Vision Permutator (ViP) and Global Filter Network (GFNet) achieve better performance than S2-MLP.

這裡引出了金字塔結構,看來 V2 版本要使用類似的構造。

In this paper, we improve the S2-MLP vision backbone. We expand the feature map along the channel dimension and split the expanded feature map into several parts. We conduct different spatial-shift operations on split parts.

依然延續了空間偏移的策略,但是不知道相較於 V1 版本改動如何

Meanwhile, we exploit the split-attention operation to fuse these split parts.

這裡還引入了 split-attention(ResNeSt)來融合分組。難道這裡是要使用並行分支?

Moreover, like the counterparts, we adopt smaller-scale patches and use a pyramid structure for boosting the image recognition accuracy.
We term the improved spatial-shift MLP vision backbone as S2-MLPv2. Using 55M parameters, our medium-scale model, S2-MLPv2-Medium achieves an 83.6% top-1 accuracy on the ImageNet-1K benchmark using 224×224 images without self-attention and external training data.

在我看來,V2 相較於 V1,主要是借鑑了 CycleFC 的一些想法,並進行了適應性的調整。整體改動有兩方面:

  1. 引入多分支處理的思想,並應用 Split-Attention 來融合不同分支。
  2. 受現有工作的啟發,使用更小的 patch 和分層金字塔結構。

主要內容

核心結構比較

V1 中,整體流程延續的是 MLP-Mixer 的思路,仍然保持直筒狀結構。

MLP-Mixer 的結構圖:

從圖中可以看到,不同於 MLP-Mixer 中的 Pre-Norm 結構,S2MLP 使用的是 Post-Norm 結構。
另外,S2MLP 的改動主要集中在空間 MLP 的位置,由原來的Spatial-MLP(Linear->GeLU->Linear)轉變為Spatial-Shifted Channel-MLP(Linear->GeLU->Spatial-Shift->Lienar)
關於空間偏移的核心虛擬碼如下:

可以看到,這裡就是將輸入劃分成四個不同的分組,各自沿著不同的軸向(H 和 W 軸)偏移,由於實現的原因,在邊界部分會有重複值出現。分組數依賴於方向的數量,這裡預設使用 4,即向四個方向偏移。
雖然從單個空間偏移模組上來看,僅僅關聯了相鄰的 patch,但是從整體堆疊後的結構來看,可以實現一個近似的長距離互動過程。

而在 V2 版本相較於 V1 版本引入了多分支處理的策略,並且在結構上開始使用 Pre-Norm 形式。

關於多分支結構的構造思路與 CycleFC 非常類似。不同支路使用不同的處理策略,同時在多分支整合時,使用了 Split-Attention 的方式進行融合。

Split-Attention: Vision Permutator (Hou et al., 2021) adopts split attention proposed in ResNeSt (Zhang et al., 2020) for enhancing multiple feature maps from different operations. 本文借鑑使用來融合多分支。
主要操作過程:

  1. 輸入 \(K\) 個特徵圖(可以來自不同分支)\(\mathbf{X} = \{X_k \in \mathbb{R}^{N \times C}\}^{K}_{k=1}, \, N=HW\)
  2. 將所有特診圖的列求和後的結果累加:\(a \in \mathbb{R}^{C} = \sum_{k=1}^{K}\sum_{n=1}^{N}\mathbf{X}_{k}[n, :]\)
  3. 通過堆疊的全連線層進行變換,得到針對不同特徵圖的通道注意力 logits:\(\hat{a} \in \mathbb{R}^{KC} = \sigma(a W_1) W_2, \, W_1 \in \mathbb{R}^{C \times \bar{C}}, \, W_2 \in \mathbb{R}^{\bar{C} \times KC}\)
  4. 使用 reshape 來調整注意力向量的形狀:\(\hat{a} \in \mathbb{R}^{KC} \rightarrow \hat{A} \in \mathbb{R}^{K \times C}\)
  5. 使用 softmax 沿著索引 \(k\) 計算,來獲得針對不同樣本的歸一化注意力權重:\(\bar{A}[:, c] \in \mathbb{R}^{K} = \text{softmax}(\hat{A}[:, c])\)
  6. 對輸入的 \(K\) 個特徵圖加權求和得到結果 \(Y\),其一行的結果可以表示為:\(Y[n, :] \in \mathbb{R}^{C} = \sum_{k=1}^{K} X_{k}[n, :] \odot \bar{A}[k, :]\)

不過需要注意的是,這裡第三個分支是一個恆等分支,直接將輸入的部分通道取了過來,這一點延續了 GhostNet 的想法,而不同於 CycleFC,使用的是一個獨立的通道 MLP。

GhostNet的核心結構:

關於該多分支結構的核心虛擬碼如下:

其他細節

Spatial-Shift 與 Depthwise Convolution 的關係

實際上,四個方向的偏移都是可以通過特定的卷積核構造來實現的:

所以分組空間偏移操作可以通過為 Depthwise Convolution 的不同分組指定對應上面的卷積核來實現。

實際上實現偏移的方法非常多,除了文中提到的切片索引和構造核的 depthwise convolution 的方式,還可以通過分組torch.roll和自定義 offset 的deform_conv2d來實現。

import torch
import torch.nn.functional as F
from torchvision.ops import deform_conv2d

xs = torch.meshgrid(torch.arange(5), torch.arange(5))
x = torch.stack(xs, dim=0)
x = x.unsqueeze(0).repeat(1, 4, 1, 1).float()

direct_shift = torch.clone(x)
direct_shift[:, 0:2, :, 1:] = torch.clone(direct_shift[:, 0:2, :, :4])
direct_shift[:, 2:4, :, :4] = torch.clone(direct_shift[:, 2:4, :, 1:])
direct_shift[:, 4:6, 1:, :] = torch.clone(direct_shift[:, 4:6, :4, :])
direct_shift[:, 6:8, :4, :] = torch.clone(direct_shift[:, 6:8, 1:, :])
print(direct_shift)

pad_x = F.pad(x, pad=[1, 1, 1, 1], mode="replicate")  # 這裡需要藉助padding來保留邊界的資料

roll_shift = torch.cat(
    [
        torch.roll(pad_x[:, c * 2 : (c + 1) * 2, ...], shifts=(shift_h, shift_w), dims=(2, 3))
        for c, (shift_h, shift_w) in enumerate([(0, 1), (0, -1), (1, 0), (-1, 0)])
    ],
    dim=1,
)
roll_shift = roll_shift[..., 1:6, 1:6]
print(roll_shift)

k1 = torch.FloatTensor([[0, 0, 0], [1, 0, 0], [0, 0, 0]]).reshape(1, 1, 3, 3)
k2 = torch.FloatTensor([[0, 0, 0], [0, 0, 1], [0, 0, 0]]).reshape(1, 1, 3, 3)
k3 = torch.FloatTensor([[0, 1, 0], [0, 0, 0], [0, 0, 0]]).reshape(1, 1, 3, 3)
k4 = torch.FloatTensor([[0, 0, 0], [0, 0, 0], [0, 1, 0]]).reshape(1, 1, 3, 3)
weight = torch.cat([k1, k1, k2, k2, k3, k3, k4, k4], dim=0)  # 每個輸出通道對應一個輸入通道
conv_shift = F.conv2d(pad_x, weight=weight, groups=8)
print(conv_shift)

offset = torch.empty(1, 2 * 8 * 1 * 1, 1, 1)
for c, (rel_offset_h, rel_offset_w) in enumerate([(0, -1), (0, -1), (0, 1), (0, 1), (-1, 0), (-1, 0), (1, 0), (1, 0)]):
    offset[0, c * 2 + 0, 0, 0] = rel_offset_h
    offset[0, c * 2 + 1, 0, 0] = rel_offset_w
offset = offset.repeat(1, 1, 7, 7).float()
weight = torch.eye(8).reshape(8, 8, 1, 1).float()
deconv_shift = deform_conv2d(pad_x, offset=offset, weight=weight)
deconv_shift = deconv_shift[..., 1:6, 1:6]
print(deconv_shift)

"""
tensor([[[[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.]],

         [[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]],

         [[1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.],
          [4., 4., 4., 4., 4.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]]]])
tensor([[[[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.]],

         [[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]],

         [[1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.],
          [4., 4., 4., 4., 4.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]]]])
tensor([[[[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.]],

         [[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]],

         [[1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.],
          [4., 4., 4., 4., 4.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]]]])
tensor([[[[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.],
          [0., 0., 1., 2., 3.]],

         [[0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.]],

         [[1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.],
          [1., 2., 3., 4., 4.]],

         [[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]],

         [[1., 1., 1., 1., 1.],
          [2., 2., 2., 2., 2.],
          [3., 3., 3., 3., 3.],
          [4., 4., 4., 4., 4.],
          [4., 4., 4., 4., 4.]],

         [[0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.],
          [0., 1., 2., 3., 4.]]]])
"""

偏移方向的影響

實驗是在 ImageNet 的子集上跑的。

V1 中針對不同的偏移方向進行了消融實驗,這裡的模型中都是按照方向個數對通道分組。從結果中可以看到:

  • 偏移確實可以帶來效能增益。
  • a 和 b:四個方向和八個方向相比,差異並不大。
  • e 和 f:水平偏移效果更好。
  • c 和 e/f:兩個軸的偏移要好於單個軸的偏移。

輸入尺寸以及 patchsize 的影響

實驗是在 ImageNet 的子集上跑的。

V1 中在固定 patchsize 後,不同的輸入尺寸 WxH 的表現也不同。過大的 patchsize 效果也不好,會丟失更多的細節資訊,但是卻可以有效提升推理速度。

金字塔結構的有效性

V2 中,構造了兩個不同的結構,一個有著更小的 patch,並且使用金字塔結構,另一個更大的 patch,不使用金字塔結構。可以看到,同時受益於小 patchsize 帶來的細節資訊的效能增強和金字塔結構帶來的更優的計算效率,前者獲得了更好的表現。

Split-Attention 的效果

V2 將 split-attention 與特徵直接相加取平均對比。可以看到,前者更優。不過這裡引數量也不一樣了,其實更合理的比較應該最起碼是加幾層帶引數的結構來融合三分支的特徵。

三分支結構的有效性

這裡的實驗說明有些模糊,作者說道“In this section, we evaluate the influence of removing one of them.”但是卻沒有說明去掉特定分支後其他結構的調整方式。

實驗結果

實驗結果直接看 V2 論文的表格即可:

連結