1. 程式人生 > 實用技巧 >FFDNet: Toward a Fast and Flexible Solution for CNN-Based Image Denoising

FFDNet: Toward a Fast and Flexible Solution for CNN-Based Image Denoising

論文來源:FFDNet: Toward a Fast and Flexible Solution for CNN-Based Image Denoising

筆記參考:論文閱讀:FFDNet

程式碼參考:FFDNet_pytorch

DnCNN利用Batch Normalization和residual learning可以有效地去除均勻高斯噪聲,且對一定噪聲水平範圍的噪聲都有抑制作用。然而真實的噪聲並不是均勻的高斯噪聲,其是訊號依賴的,各顏色通道相關的,而且是不均勻的,可能隨空間位置變化的。在這種情況下,FFDNet使用噪聲估計圖作為輸入,權衡對均布噪聲的抑制和細節的保持,從而應對更加複雜的真實場景。而CBDNet進一步發揮了這種優勢,其將噪聲水平估計過程也用一個子網路實現,從而使得整個網路可以實現盲去噪。

文章貢獻:

  1. 針對影象去噪問題,提出了一種快速靈活的去噪網路FFDNet。通過將一個可調噪聲級別圖作為輸入,一個單一的FFDNet能夠處理不同級別的噪聲,以及空間變化的噪聲。
  2. 我們強調了確保噪音水平圖在控制降噪和細節保留之間的平衡方面的重要性。
  3. FFDNet在被AWGN破壞的合成噪聲影象和真實噪聲影象上都展示了具有感知吸引力的結果,展示了它在實際影象去噪方面的潛力。

在DnCNN的基礎上添加了下采樣和上取樣:

引入可逆下采樣運算元將W×H×C的輸入影象重塑為4個下采樣(W/2)×(H/2)× 4C的子影象。這裡C為通道數,灰度影象C = 1,彩色影象C = 3。為了使噪聲級圖能夠在不引入視覺偽影的情況下,有效地控制噪聲降低和細節保留之間的平衡,對卷積濾波器採用了正交初始化方法。

程式碼(pytorch):

 1 import torch
 2 import torch.nn as nn
 3 import torch.nn.functional as F
 4 import torch.optim as optim
 5 from torch.autograd import Variable
 6 
 7 import utils
 8 
 9 class FFDNet(nn.Module):
10 
11     def __init__(self, is_gray):
12         super(FFDNet, self).__init__()
13 
14
if is_gray: 15 self.num_conv_layers = 15 # all layers number 16 self.downsampled_channels = 5 # Conv_Relu in 17 self.num_feature_maps = 64 # Conv_Bn_Relu in 18 self.output_features = 4 # Conv out 19 else: 20 self.num_conv_layers = 12 21 self.downsampled_channels = 15 22 self.num_feature_maps = 96 23 self.output_features = 12 24 25 self.kernel_size = 3 26 self.padding = 1 27 28 layers = [] 29 # Conv + Relu 30 layers.append(nn.Conv2d(in_channels=self.downsampled_channels, out_channels=self.num_feature_maps, \ 31 kernel_size=self.kernel_size, padding=self.padding, bias=False)) 32 layers.append(nn.ReLU(inplace=True)) 33 34 # Conv + BN + Relu 35 for _ in range(self.num_conv_layers - 2): 36 layers.append(nn.Conv2d(in_channels=self.num_feature_maps, out_channels=self.num_feature_maps, \ 37 kernel_size=self.kernel_size, padding=self.padding, bias=False)) 38 layers.append(nn.BatchNorm2d(self.num_feature_maps)) 39 layers.append(nn.ReLU(inplace=True)) 40 41 # Conv 42 layers.append(nn.Conv2d(in_channels=self.num_feature_maps, out_channels=self.output_features, \ 43 kernel_size=self.kernel_size, padding=self.padding, bias=False)) 44 45 self.intermediate_dncnn = nn.Sequential(*layers) 46 47 def forward(self, x, noise_sigma): 48 noise_map = noise_sigma.view(x.shape[0], 1, 1, 1).repeat(1, x.shape[1], x.shape[2] // 2, x.shape[3] // 2) 49 50 x_up = utils.downsample(x.data) # 4 * C * H/2 * W/2 51 x_cat = torch.cat((noise_map.data, x_up), 1) # 4 * (C + 1) * H/2 * W/2 52 x_cat = Variable(x_cat) 53 54 h_dncnn = self.intermediate_dncnn(x_cat) 55 y_pred = utils.upsample(h_dncnn) 56 return y_pred