【CVPR2022】Restormer: Efficient Transformer for High-Resolution Image Restoration
a
1、研究動機
論文的 motivation 非常簡單,就是認為CNN感受野有限,因此無法對長距離畫素相關性進行建模。因此,想使用 Transformer 的思路來進行影象修復。
2、主要方法
論文整體框架如下圖所示,還是類似UNet的結構,按著1/2,1/4, 1/8 下采樣,在中間新增skip connection。如圖中畫紅圈的部分展示,每個 Transformer block 由兩個部分串聯組成:MDTA 和 GDFN。
對於特徵上下采樣,作者分別採用 PyTorch 裡的 pixel-unshuffle 和 pixel-shuffle 實現,非常類似 swin transformer 裡的 patch merging (不清楚實現是不是一樣的,還沒時間比較,汗 ~~~)。
MDTA (Multi-Dconv Head Transposed Attention)
Transformer中計算量主要來自於注意力計算部分,為了降低計算量,作者構建了MDTA,不在畫素維度計算 attention,而是在通道維度計算。過程很簡單,先用 point-wise conv 和 dconv 預處理,在通道維計算 atteniton,如下圖所示。
直接看程式碼:
## Multi-DConv Head Transposed Self-Attention (MDTA) class Attention(nn.Module): def __init__(self, dim, num_heads, bias): super(Attention, self).__init__() self.num_heads = num_heads self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) def forward(self, x): b,c,h,w = x.shape # 升維,卷積,分塊得到qkv qkv = self.qkv_dwconv(self.qkv(x)) q,k,v = qkv.chunk(3, dim=1) # 維度變化 [B, C, H, W] ==> [B, head, C/head, HW] q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) q = torch.nn.functional.normalize(q, dim=-1) k = torch.nn.functional.normalize(k, dim=-1) # [B, head, C/head, HW] * [B, head, HW, C/head] * [head, 1, 1] ==> [B, head, C/head, C/head] attn = (q @ k.transpose(-2, -1)) * self.temperature attn = attn.softmax(dim=-1) # [B, head, C/head, C/head] * [B, head, C/head, HW] ==> [B, head, C/head, HW] out = (attn @ v) # [B, head, C/head, HW] ==> [B, head, C/head, H, W] out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) out = self.project_out(out) return out
GDFN (Gated-Dconv Feed-Forward Network)
VIT中使用全連線網路FFN處理,在本文中作者有兩個改進:1)引入 gating mechanism, 下面分支使用GELU啟用。2)使用 dconv 學習影象區域性結構資訊。
直接看程式碼:
## Gated-Dconv Feed-Forward Network (GDFN) class FeedForward(nn.Module): def __init__(self, dim, ffn_expansion_factor, bias): super(FeedForward, self).__init__() hidden_features = int(dim*ffn_expansion_factor) self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) def forward(self, x): x = self.project_in(x) x1, x2 = self.dwconv(x).chunk(2, dim=1) x = F.gelu(x1) * x2 x = self.project_out(x) return x
其它細節與實驗分析
網路在下圖中畫紅圈的部分還有一個細節,這個位置沒有像之前的兩個 block 使用 1X1 的卷積來降維,而是又使用了幾個 Transformer block 來處理,叫做 Refinement stage。作者有一個實驗專門驗證這個 Refinement 階段的有效性。
從 Level-1 到 Level-4 ,Transformer block的數量是 [4,6,6,8],MDTA中的 head 數量為[1,2,4,8],通道數為[48,96,192,384]。Refinement階段有4個block。同時,作者還採用了 progressive training 的策略,輸入影象尺寸從 128 到 384 漸增。
作者在影象去雨、單影象運動去模糊、散焦去模糊(在單影象和雙畫素資料上)、影象去噪(在合成和真實資料上)四個任務做了大量實驗以證明方法的有效性。具體可以參照作者論文,這裡不過多介紹了。