1. 程式人生 > 其它 >【CVPR2022】Restormer: Efficient Transformer for High-Resolution Image Restoration

【CVPR2022】Restormer: Efficient Transformer for High-Resolution Image Restoration

a

論文連結:https://arxiv.org/abs/2111.09881

程式碼連結:https://github.com/swz30/Restormer

1、研究動機

論文的 motivation 非常簡單,就是認為CNN感受野有限,因此無法對長距離畫素相關性進行建模。因此,想使用 Transformer 的思路來進行影象修復。

2、主要方法

論文整體框架如下圖所示,還是類似UNet的結構,按著1/2,1/4, 1/8 下采樣,在中間新增skip connection。如圖中畫紅圈的部分展示,每個 Transformer block 由兩個部分串聯組成:MDTA 和 GDFN。

對於特徵上下采樣,作者分別採用 PyTorch 裡的 pixel-unshuffle 和 pixel-shuffle 實現,非常類似 swin transformer 裡的 patch merging (不清楚實現是不是一樣的,還沒時間比較,汗 ~~~)。

MDTA (Multi-Dconv Head Transposed Attention)

Transformer中計算量主要來自於注意力計算部分,為了降低計算量,作者構建了MDTA,不在畫素維度計算 attention,而是在通道維度計算。過程很簡單,先用 point-wise conv 和 dconv 預處理,在通道維計算 atteniton,如下圖所示。

直接看程式碼:

## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        
    def forward(self, x):
        b,c,h,w = x.shape

        # 升維,卷積,分塊得到qkv
        qkv = self.qkv_dwconv(self.qkv(x))
        q,k,v = qkv.chunk(3, dim=1)   
        
        # 維度變化 [B, C, H, W] ==> [B, head, C/head, HW] 
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        # [B, head, C/head, HW] * [B, head, HW, C/head] * [head, 1, 1] ==> [B, head, C/head, C/head]
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        # [B, head, C/head, C/head] * [B, head, C/head, HW] ==> [B, head, C/head, HW]
        out = (attn @ v)
        
        # [B, head, C/head, HW] ==> [B, head, C/head, H, W]
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out

GDFN (Gated-Dconv Feed-Forward Network)

VIT中使用全連線網路FFN處理,在本文中作者有兩個改進:1)引入 gating mechanism, 下面分支使用GELU啟用。2)使用 dconv 學習影象區域性結構資訊。

直接看程式碼:

## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim*ffn_expansion_factor)
        self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x

其它細節與實驗分析

網路在下圖中畫紅圈的部分還有一個細節,這個位置沒有像之前的兩個 block 使用 1X1 的卷積來降維,而是又使用了幾個 Transformer block 來處理,叫做 Refinement stage。作者有一個實驗專門驗證這個 Refinement 階段的有效性。

從 Level-1 到 Level-4 ,Transformer block的數量是 [4,6,6,8],MDTA中的 head 數量為[1,2,4,8],通道數為[48,96,192,384]。Refinement階段有4個block。同時,作者還採用了 progressive training 的策略,輸入影象尺寸從 128 到 384 漸增。

作者在影象去雨、單影象運動去模糊、散焦去模糊(在單影象和雙畫素資料上)、影象去噪(在合成和真實資料上)四個任務做了大量實驗以證明方法的有效性。具體可以參照作者論文,這裡不過多介紹了。