實驗——基於pytorch的blind restoration聯合網路訓練
阿新 • • 發佈:2019-01-10
加噪和超分的處理可以參考本人的 GitHub 程式碼:https://github.com/gwpscut/degradation-model-for-image-restoration
去噪和超分網路都可以參考本人之前的博文哈
subnetwork 為 DnCNN,主網路為 SRResNet。subnetwork 的輸出為 noise level map。注意:博文《基於pytorch的超分和去噪網路聯合fine tuning》裡面採用的 subnetwork 輸出為 clean image。
setting
{ "name": "finetune_all_subnetc16s06_basic_resnet_DIV2K", "tb_logger_dir": "sr_c16s06", "use_tb_logger": true, "model": "sr_sub", "scale": 4, "crop_scale": 0, "gpu_ids": [ 3, 5 ], "datasets": { "train": { "name": "DIV2K", "mode": "LRMRHR", "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub", "dataroot_MR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4_residualALL", "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4_noiseALL", "subset_file": null, "use_shuffle": true, "n_workers": 8, "batch_size": 24, "HR_size": 128, "use_flip": true, "use_rot": true, "phase": "train", "scale": 4, "data_type": "img" }, "val": { "name": "val_set5_x4_c03s08_mod4", "mode": "LRMRHR", "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5", "dataroot_MR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4_residualALL", "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4_noiseALL", "phase": "val", "scale": 4, "data_type": "img" } }, "path": { "root": "/home/guanwp/jingwen/sr_c16s06", "pretrain_model_sub": "/home/guanwp/jingwen/sr/experiments/LR_x4_subnet_residual_DIV2K_guan/models/51000_G.pth", "experiments_root": "/home/guanwp/jingwen/sr_c16s06/experiments/finetune_all_subnetc16s06_basic_resnet_DIV2K", "models": "/home/guanwp/jingwen/sr_c16s06/experiments/finetune_all_subnetc16s06_basic_resnet_DIV2K/models", "log": "/home/guanwp/jingwen/sr_c16s06/experiments/finetune_all_subnetc16s06_basic_resnet_DIV2K", "val_images": "/home/guanwp/jingwen/sr_c16s06/experiments/finetune_all_subnetc16s06_basic_resnet_DIV2K/val_images" }, "network_G": { "which_model_G": "sr_resnet", "norm_type": null, "mode": "CNA", "nf": 64, "nb": 16, "in_nc": 6, "out_nc": 3, "group": 1, "scale": 4 }, "network_sub": { "which_model_sub": "noise_subnet", "norm_type": "batch", "mode": "CNA", "nf": 64, "in_nc": 3, "out_nc": 3, "group": 1 }, "train": { "lr_G": 0.0001, "lr_scheme": "MultiStepLR", "lr_steps": [ 500000 ], "lr_gamma": 0.1, "pixel_criterion_basic": "l2", 
"pixel_criterion_noise": "l2", "pixel_weight_basic": 1.0, "pixel_weight_noise": 1.0, "val_freq": 2000.0, "manual_seed": 0, "niter": 1000000.0 }, "logger": { "print_freq": 200, "save_checkpoint_freq": 2000.0 }, "timestamp": "190129-133631", "is_train": true, "adabn": null }
model
import os
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.optim import lr_scheduler

import models.networks as networks
from .base_model import BaseModel


class SRModel(BaseModel):
    """Joint training wrapper for a noise-estimation subnet plus an SR generator.

    The subnet maps the LR input to an estimate (``fake_noise``); that estimate
    is concatenated with the LR input along dim 1 and fed to the generator,
    which produces ``fake_H``. Training minimises the weighted sum of

    * a "noise" pixel loss between ``fake_noise`` and the ``'MR'`` batch entry,
    * a "basic" pixel loss between ``fake_H`` and the ``'HR'`` batch entry.

    NOTE(review): ``self.device``, ``self.is_train``, ``self.save_dir``,
    ``self.opt``, ``self.optimizers``, ``self.schedulers`` and the
    ``load_network`` / ``save_network`` / ``get_network_description`` helpers
    are presumably provided by ``BaseModel`` -- confirm there. ``opt`` is
    indexed with keys that the dumped config does not contain
    (``finetune_type``, ``weight_decay_G``, ``pretrain_model_G``), so it is
    assumed to return ``None`` for missing keys.
    """

    def __init__(self, opt):
        """Build both networks, load pretrained weights and set up training state.

        Args:
            opt: nested options mapping (see the class docstring).

        Raises:
            NotImplementedError: for an unknown pixel criterion or a learning
                rate scheme other than ``'MultiStepLR'``.
        """
        super(SRModel, self).__init__(opt)
        train_opt = opt['train']
        finetune_type = opt['finetune_type']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        self.subnet = networks.define_sub(opt).to(self.device)
        self.load()

        if self.is_train:
            self.netG.train()
            self.subnet.train()

            # losses (one criterion + weight per output)
            self.cri_pix_noise = self._build_pixel_criterion(train_opt['pixel_criterion_noise'])
            self.l_pix_noise_w = train_opt['pixel_weight_noise']
            self.cri_pix_basic = self._build_pixel_criterion(train_opt['pixel_criterion_basic'])
            self.l_pix_basic_w = train_opt['pixel_weight_basic']

            # optimizers: a single Adam instance over the selected parameters
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            self.optim_params = self.__define_grad_params(finetune_type)
            self.optimizer_G = torch.optim.Adam(
                self.optim_params, lr=train_opt['lr_G'], weight_decay=wd_G)
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(
                        optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def _build_pixel_criterion(self, loss_type):
        """Return the pixel loss module for ``'l1'`` or ``'l2'`` on ``self.device``."""
        if loss_type == 'l1':
            return nn.L1Loss().to(self.device)
        if loss_type == 'l2':
            return nn.MSELoss().to(self.device)
        raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))

    def feed_data(self, data, need_HR=True):
        """Move one batch onto ``self.device``.

        Args:
            data: mapping with an ``'LR'`` tensor and, when available, ``'HR'``
                and ``'MR'`` tensors.
            need_HR: when False the ``'HR'`` entry is not read, so datasets
                without ground truth can be fed (previously this flag was
                ignored and a missing ``'HR'`` raised ``KeyError``).
        """
        self.var_L = data['LR'].to(self.device)  # LR
        if need_HR:
            self.real_H = data['HR'].to(self.device)  # HR
        # 'MR' is the training target of the subnet; it is optional at test time.
        if 'MR' in data:
            self.mid_L = data['MR'].to(self.device)  # MR

    @staticmethod
    def _select_params_by_name(net, keyword, optim_params):
        """Freeze ``net`` except parameters whose name contains ``keyword``.

        Matching parameters get ``requires_grad = True`` and are appended to
        ``optim_params`` (mutated in place).
        """
        for k, v in net.named_parameters():
            v.requires_grad = False
            if k.find(keyword) >= 0:
                v.requires_grad = True
                optim_params.append(v)
                print('we only optimize params: {}'.format(k))

    @staticmethod
    def _collect_trainable_params(net, optim_params):
        """Append every parameter of ``net`` that already requires grad."""
        for k, v in net.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
                print('params [{:s}] will optimize.'.format(k))
            else:
                print('WARNING: params [{:s}] will not optimize.'.format(k))

    def __define_grad_params(self, finetune_type=None):
        """Choose which parameters the optimizer updates.

        * ``'sft'``: only generator parameters whose name contains ``'Gate'``.
        * ``'sub_sft'``: the ``'Gate'`` generator parameters plus subnet
          parameters whose name contains ``'degration'``.
        * ``'basic'`` / ``'sft_basic'``: every generator parameter; the subnet
          is frozen.
        * anything else (including ``None``): every parameter of both networks
          that currently has ``requires_grad`` set.

        Returns:
            list of parameters to hand to the optimizer.
        """
        optim_params = []
        if finetune_type == 'sft':
            self._select_params_by_name(self.netG, 'Gate', optim_params)
        elif finetune_type == 'sub_sft':
            self._select_params_by_name(self.netG, 'Gate', optim_params)
            self._select_params_by_name(self.subnet, 'degration', optim_params)
        elif finetune_type == 'basic' or finetune_type == 'sft_basic':
            for k, v in self.netG.named_parameters():
                v.requires_grad = True
                optim_params.append(v)
                print('we only optimize params: {}'.format(k))
            for k, v in self.subnet.named_parameters():
                v.requires_grad = False
        else:
            self._collect_trainable_params(self.netG, optim_params)
            self._collect_trainable_params(self.subnet, optim_params)
        return optim_params

    def optimize_parameters(self, step):
        """Run one optimisation step on the batch set by ``feed_data``.

        Args:
            step: current iteration index (unused here; kept for the caller's
                interface).
        """
        self.optimizer_G.zero_grad()

        self.fake_noise = self.subnet(self.var_L)
        l_pix_noise = self.l_pix_noise_w * self.cri_pix_noise(self.fake_noise, self.mid_L)

        # The generator consumes the LR input and the subnet estimate stacked on
        # the channel axis. (The tuple form netG((var_L, fake_noise)) is the
        # alternative for the Modulate* generators.)
        self.fake_H = self.netG(torch.cat((self.var_L, self.fake_noise), 1))
        l_pix_basic = self.l_pix_basic_w * self.cri_pix_basic(self.fake_H, self.real_H)

        l_pix = l_pix_noise + l_pix_basic
        l_pix.backward()
        self.optimizer_G.step()

        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        """Run a gradient-free forward pass and store ``fake_noise`` / ``fake_H``.

        ``torch.no_grad()`` replaces the former ``requires_grad`` toggling,
        which only covered ``optim_params`` during training and therefore still
        built an autograd graph through parameters outside the optimizer.
        Both networks are returned to train mode afterwards (as before), even
        if the forward pass raises.
        """
        self.netG.eval()
        self.subnet.eval()
        try:
            with torch.no_grad():
                self.fake_noise = self.subnet(self.var_L)
                self.fake_H = self.netG(torch.cat((self.var_L, self.fake_noise), 1))
        finally:
            self.netG.train()
            self.subnet.train()

    def get_current_log(self):
        """Return the ordered dict of logged scalars."""
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        """Return the first sample of each tensor as a detached float CPU tensor.

        Keys: ``'LR'``, ``'MR'`` (the subnet output), ``'SR'`` and, when
        ``need_HR`` is true, ``'HR'``.
        """
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
        out_dict['MR'] = self.fake_noise.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if need_HR:
            out_dict['HR'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Print parameter counts and, when training, dump both networks to network.txt."""
        # G
        s, n = self.get_network_description(self.netG)
        print('Number of parameters in G: {:,d}'.format(n))
        if self.is_train:
            message = '-------------- Generator --------------\n' + s + '\n'
            network_path = os.path.join(self.save_dir, '../', 'network.txt')
            with open(network_path, 'w') as f:
                f.write(message)

            # subnet (appended to the same file, hence inside the training branch
            # where network_path is defined)
            s, n = self.get_network_description(self.subnet)
            print('Number of parameters in subnet: {:,d}'.format(n))
            message = '\n\n\n-------------- subnet --------------\n' + s + '\n'
            with open(network_path, 'a') as f:
                f.write(message)

    def load(self):
        """Load pretrained weights for either network when a path is configured."""
        load_path_G = self.opt['path']['pretrain_model_G']
        load_path_sub = self.opt['path']['pretrain_model_sub']
        if load_path_G is not None:
            print('loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)
        if load_path_sub is not None:
            print('loading model for subnet [{:s}] ...'.format(load_path_sub))
            self.load_network(load_path_sub, self.subnet)

    def save(self, iter_label):
        """Save both networks under ``self.save_dir`` with labels 'G' and 'sub'."""
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        self.save_network(self.save_dir, self.subnet, 'sub', iter_label)
network
import functools

import torch
import torch.nn as nn
from torch.nn import init

import models.modules.architecture as arch
import models.modules.sft_arch as sft_arch

####################
# initialize
####################


def _zero_bias(m):
    """Zero the bias of ``m`` in place when the module has one."""
    if m.bias is not None:
        m.bias.data.zero_()


def _init_norm_affine(m, std=None):
    """Initialise the affine parameters of a normalisation layer.

    The weight is set to 1 (or drawn from N(1, std) when ``std`` is given) and
    the bias to 0. Layers built with ``affine=False`` have ``weight`` and
    ``bias`` set to None; those are skipped instead of raising AttributeError.
    """
    if m.weight is not None:
        if std is None:
            init.constant_(m.weight.data, 1.0)
        else:
            init.normal_(m.weight.data, 1.0, std)  # BN also uses norm
    if m.bias is not None:
        init.constant_(m.bias.data, 0.0)


def weights_init_normal(m, std=0.02):
    """Normal(0, std) init for Conv/Linear weights; N(1, std) for BatchNorm2d weights.

    Modules are matched by substring of their class name. Intended for use
    with ``net.apply``.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, std)
        _zero_bias(m)
    elif classname.find('BatchNorm2d') != -1:
        _init_norm_affine(m, std=std)


def weights_init_kaiming(m, scale=1):
    """Kaiming-normal (fan_in) init for Conv/Linear weights, multiplied by ``scale``.

    BatchNorm2d / InstanceNorm2d affine parameters are reset to weight 1 and
    bias 0 when present.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        m.weight.data *= scale
        _zero_bias(m)
    elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNorm2d') != -1:
        _init_norm_affine(m)


def weights_init_orthogonal(m):
    """Orthogonal (gain 1) init for Conv/Linear weights; BatchNorm2d reset to (1, 0)."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        init.orthogonal_(m.weight.data, gain=1)
        _zero_bias(m)
    elif classname.find('BatchNorm2d') != -1:
        _init_norm_affine(m)


def init_weights(net, init_type='kaiming', scale=1, std=0.02):
    """Apply one of the initialisers above to every sub-module of ``net``.

    Args:
        net: module to initialise in place.
        init_type: ``'normal'``, ``'kaiming'`` or ``'orthogonal'``.
        scale: weight multiplier, used by ``'kaiming'`` only.
        std: standard deviation, used by ``'normal'`` only.

    Raises:
        NotImplementedError: for any other ``init_type``.
    """
    # scale for 'kaiming', std for 'normal'.
    print('initialization method [{:s}]'.format(init_type))
    if init_type == 'normal':
        weights_init_normal_ = functools.partial(weights_init_normal, std=std)
        net.apply(weights_init_normal_)
    elif init_type == 'kaiming':
        weights_init_kaiming_ = functools.partial(weights_init_kaiming, scale=scale)
        net.apply(weights_init_kaiming_)
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type))


####################
# define network
####################


# Generator
def define_G(opt):
    """Build the generator selected by ``opt['network_G']['which_model_G']``.

    When ``opt['init_type']`` is not None the network is initialised with that
    method (scale 0.1). When ``opt['gpu_ids']`` is truthy the network is
    wrapped in ``nn.DataParallel``.

    Raises:
        NotImplementedError: for an unknown model name.
        AssertionError: if GPUs are requested but CUDA is unavailable.
    """
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    if which_model == 'sr_resnet':  # SRResNet
        netG = arch.SRResNet(
            in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
            act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle')
    elif which_model == 'modulate_sr_resnet':
        netG = arch.ModulateSRResNet(
            in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
            mode=opt_net['mode'], upsample_mode='pixelshuffle',
            ada_ksize=opt_net['ada_ksize'], gate_conv_bias=opt_net['gate_conv_bias'])
    elif which_model == 'arcnn':
        netG = arch.ARCNN(
            in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            norm_type=opt_net['norm_type'], mode=opt_net['mode'],
            ada_ksize=opt_net['ada_ksize'])
    elif which_model == 'srcnn':
        netG = arch.SRCNN(
            in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            norm_type=opt_net['norm_type'], mode=opt_net['mode'],
            ada_ksize=opt_net['ada_ksize'])
    elif which_model == 'denoise_resnet':
        netG = arch.DenoiseResNet(
            in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
            mode=opt_net['mode'], upsample_mode='pixelshuffle',
            ada_ksize=opt_net['ada_ksize'], down_scale=opt_net['down_scale'],
            fea_norm=opt_net['fea_norm'], upsample_norm=opt_net['upsample_norm'])
    elif which_model == 'modulate_denoise_resnet':
        netG = arch.ModulateDenoiseResNet(
            in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
            mode=opt_net['mode'], upsample_mode='pixelshuffle',
            ada_ksize=opt_net['ada_ksize'], gate_conv_bias=opt_net['gate_conv_bias'])
    elif which_model == 'noise_subnet':
        netG = arch.NoiseSubNet(
            in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], norm_type=opt_net['norm_type'], mode=opt_net['mode'])
    elif which_model == 'cond_denoise_resnet':
        netG = arch.CondDenoiseResNet(
            in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], upscale=opt_net['scale'], upsample_mode='pixelshuffle',
            ada_ksize=opt_net['ada_ksize'], down_scale=opt_net['down_scale'],
            num_classes=opt_net['num_classes'], norm_type=opt_net['norm_type'])
    elif which_model == 'adabn_denoise_resnet':
        netG = arch.AdaptiveDenoiseResNet(
            in_nc=opt_net['in_nc'], nf=opt_net['nf'], nb=opt_net['nb'],
            upscale=opt_net['scale'], down_scale=opt_net['down_scale'])
    elif which_model == 'sft_arch':  # SFT-GAN
        netG = sft_arch.SFT_Net()
    elif which_model == 'RRDB_net':  # RRDB
        netG = arch.RRDBNet(
            in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'],
            norm_type=opt_net['norm_type'], act_type='leakyrelu',
            mode=opt_net['mode'], upsample_mode='upconv')
    else:
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))

    if opt['init_type'] is not None:
        init_weights(netG, init_type=opt['init_type'], scale=0.1)
    if gpu_ids:
        assert torch.cuda.is_available()
        netG = nn.DataParallel(netG)
    return netG


def define_sub(opt):
    """Build the sub-network selected by ``opt['network_sub']['which_model_sub']``.

    Only ``'noise_subnet'`` is supported. No weight initialisation is applied
    here. The network is wrapped in ``nn.DataParallel`` when ``opt['gpu_ids']``
    is truthy.

    Raises:
        NotImplementedError: for an unknown model name.
        AssertionError: if GPUs are requested but CUDA is unavailable.
    """
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_sub']
    which_model = opt_net['which_model_sub']

    if which_model == 'noise_subnet':
        subnet = arch.NoiseSubNet(
            in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
            nb=opt_net['nb'], norm_type=opt_net['norm_type'], mode=opt_net['mode'])
    else:
        raise NotImplementedError('subnet model [{:s}] not recognized'.format(which_model))

    if gpu_ids:
        assert torch.cuda.is_available()
        subnet = nn.DataParallel(subnet)
    return subnet


# Discriminator
def define_D(opt):
    """Build the discriminator selected by ``opt['network_D']['which_model_D']``.

    The network is always Kaiming-initialised (scale 1) and wrapped in
    ``nn.DataParallel`` when ``opt['gpu_ids']`` is truthy.

    Raises:
        NotImplementedError: for an unknown model name.
    """
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    if which_model == 'discriminator_vgg_128':
        netD = arch.Discriminator_VGG_128(
            in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
            norm_type=opt_net['norm_type'], mode=opt_net['mode'],
            act_type=opt_net['act_type'])
    elif which_model == 'dis_acd':  # sft-gan, Auxiliary Classifier Discriminator
        netD = sft_arch.ACD_VGG_BN_96()
    elif which_model == 'discriminator_vgg_96':
        netD = arch.Discriminator_VGG_96(
            in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
            norm_type=opt_net['norm_type'], mode=opt_net['mode'],
            act_type=opt_net['act_type'])
    elif which_model == 'discriminator_vgg_192':
        netD = arch.Discriminator_VGG_192(
            in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
            norm_type=opt_net['norm_type'], mode=opt_net['mode'],
            act_type=opt_net['act_type'])
    elif which_model == 'discriminator_vgg_128_SN':
        netD = arch.Discriminator_VGG_128_SN()
    else:
        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))

    init_weights(netD, init_type='kaiming', scale=1)
    if gpu_ids:
        netD = nn.DataParallel(netD)
    return netD


def define_F(opt, use_bn=False):
    """Build the VGG feature extractor in eval mode.

    Args:
        opt: options mapping; only ``opt['gpu_ids']`` is read.
        use_bn: selects feature layer 49 (with batch norm) or 34 (without).
    """
    gpu_ids = opt['gpu_ids']
    device = torch.device('cuda' if gpu_ids else 'cpu')
    # pytorch pretrained VGG19-54, before ReLU.
    if use_bn:
        feature_layer = 49
    else:
        feature_layer = 34
    netF = arch.VGGFeatureExtractor(
        feature_layer=feature_layer, use_bn=use_bn, use_input_norm=True, device=device)
    if gpu_ids:
        netF = nn.DataParallel(netF)
    netF.eval()  # No need to train
    return netF
architecture.py
import math
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from . import block as B
from . import spectral_norm as SN
from . import adaptive_norm as AN
####################
# Generator
####################
class SRCNN(nn.Module):
    """Three-stage convolutional network assembled from ``B.conv_block``.

    Stages, in order: 9x9 conv (in_nc -> nf), 1x1 conv (nf -> nf // 2) and a
    5x5 conv (nf // 2 -> out_nc) that is built with ``act_type=None``.
    ``norm_type``, ``mode`` and ``ada_ksize`` are forwarded to every stage.
    """

    def __init__(self, in_nc, out_nc, nf, norm_type='batch', act_type='relu', mode='CNA', ada_ksize=None):
        super().__init__()
        # Options shared by all three stages.
        common = dict(norm_type=norm_type, mode=mode, ada_ksize=ada_ksize)
        stages = [
            B.conv_block(in_nc, nf, kernel_size=9, act_type=act_type, **common),
            B.conv_block(nf, nf // 2, kernel_size=1, act_type=act_type, **common),
            B.conv_block(nf // 2, out_nc, kernel_size=5, act_type=None, **common),
        ]
        self.model = B.sequential(*stages)

    def forward(self, x):
        """Apply the stacked stages to ``x`` and return the result."""
        return self.model(x)
class ARCNN(nn.Module):
    """Four-stage convolutional network assembled from ``B.conv_block``.

    Stages, in order: 9x9 conv (in_nc -> nf), 7x7 conv (nf -> nf // 2),
    1x1 conv (nf // 2 -> nf // 4) and a 5x5 conv (nf // 4 -> out_nc) built with
    ``act_type=None``. ``norm_type``, ``mode`` and ``ada_ksize`` are forwarded
    to every stage.
    """

    def __init__(self, in_nc, out_nc, nf, norm_type='batch', act_type='relu', mode='CNA', ada_ksize=None):
        super().__init__()
        # Options shared by all four stages.
        common = dict(norm_type=norm_type, mode=mode, ada_ksize=ada_ksize)
        stages = [
            B.conv_block(in_nc, nf, kernel_size=9, act_type=act_type, **common),
            B.conv_block(nf, nf // 2, kernel_size=7, act_type=act_type, **common),
            B.conv_block(nf // 2, nf // 4, kernel_size=1, act_type=act_type, **common),
            B.conv_block(nf // 4, out_nc, kernel_size=5, act_type=None, **common),
        ]
        self.model = B.sequential(*stages)

    def forward(self, x):
        """Apply the stacked stages to ``x`` and return the result."""
        return self.model(x)
class SRResNet(nn.Module):
    """Residual super-resolution generator.

    Pipeline: 3x3 feature conv -> shortcut-wrapped trunk of ``nb``
    ``B.ResNetBlock`` plus a 3x3 conv -> upsampling stages -> two 3x3
    reconstruction convs (the last one without activation).

    ``upscale == 3`` uses a single x3 upsampling stage; any other value uses
    ``int(log2(upscale))`` default-factor stages.

    Raises:
        NotImplementedError: if ``upsample_mode`` is neither ``'upconv'`` nor
            ``'pixelshuffle'``.
    """

    def __init__(self, in_nc, out_nc, nf, nb, upscale=4, norm_type='batch', act_type='relu',
                 mode='NAC', res_scale=1, upsample_mode='upconv'):
        super().__init__()
        num_stages = 1 if upscale == 3 else int(math.log(upscale, 2))

        head = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        trunk = [
            B.ResNetBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type,
                          mode=mode, res_scale=res_scale)
            for _ in range(nb)
        ]
        trunk.append(B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode))

        if upsample_mode == 'pixelshuffle':
            make_upsampler = B.pixelshuffle_block
        elif upsample_mode == 'upconv':
            make_upsampler = B.upconv_blcok
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))

        if upscale == 3:
            # A single x3 block; its children are spliced into the model below.
            up_stages = list(make_upsampler(nf, nf, 3, act_type=act_type))
        else:
            up_stages = [make_upsampler(nf, nf, act_type=act_type) for _ in range(num_stages)]

        tail = [
            B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type),
            B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None),
        ]

        layers = [head, B.ShortcutBlock(B.sequential(*trunk))] + up_stages + tail
        self.model = B.sequential(*layers)

    def forward(self, x):
        """Run the full pipeline on ``x``."""
        return self.model(x)
class ModulateSRResNet(nn.Module):
    """SR generator whose residual trunk is modulated by a second input.

    ``forward`` takes a pair ``x = (image, condition)``: ``x[0]`` goes through
    the feature conv, while ``x[1]`` is handed to every ``B.TwoStreamSRResNet``
    block and to the final ``LR_norm`` layer.

    Args:
        in_nc: input channels of the feature conv; also passed as
            ``input_dim`` / first argument of the modulation layers.
        out_nc: output channels.
        nf: feature channels.
        nb: number of two-stream residual blocks.
        upscale: 3 selects one x3 upsampling stage, otherwise
            ``int(log2(upscale))`` default-factor stages are used.
        norm_type: ``'sft'`` (``AN.GateNonLinearLayer``) or ``'sft_conv'``
            (``AN.MetaLayer``).
        upsample_mode: ``'upconv'`` or ``'pixelshuffle'``.

    Raises:
        NotImplementedError: for an unsupported ``norm_type`` or
            ``upsample_mode``. Previously an unsupported ``norm_type`` left
            ``self.LR_norm`` undefined and only failed inside ``forward``.
    """

    def __init__(self, in_nc, out_nc, nf, nb, upscale=4, norm_type='sft', act_type='relu',
                 mode='CNA', res_scale=1, upsample_mode='upconv', gate_conv_bias=True, ada_ksize=None):
        super(ModulateSRResNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        self.fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, stride=1)
        resnet_blocks = [B.TwoStreamSRResNet(nf, nf, nf, norm_type=norm_type, act_type=act_type,
                                             mode=mode, res_scale=res_scale, gate_conv_bias=gate_conv_bias,
                                             ada_ksize=ada_ksize, input_dim=in_nc) for _ in range(nb)]
        self.LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None, mode=mode)

        if norm_type == 'sft':
            self.LR_norm = AN.GateNonLinearLayer(in_nc, conv_bias=gate_conv_bias)
        elif norm_type == 'sft_conv':
            self.LR_norm = AN.MetaLayer(in_nc, conv_bias=gate_conv_bias, kernel_size=ada_ksize)
        else:
            # forward() unconditionally uses self.LR_norm, so fail early and clearly.
            raise NotImplementedError('norm type [%s] is not found' % norm_type)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)

        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.norm_branch = B.sequential(*resnet_blocks)
        self.HR_branch = B.sequential(*upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        """Compute the output for ``x = (image, condition)``."""
        fea = self.fea_conv(x[0])
        # The trunk takes and returns a (features, condition) pair.
        fea_res_block, _ = self.norm_branch((fea, x[1]))
        fea_LR = self.LR_conv(fea_res_block)
        res = self.LR_norm((fea_LR, x[1]))
        # Global skip connection around the modulated trunk.
        out = self.HR_branch(fea + res)
        return out
class DenoiseResNet(nn.Module):
    """Residual denoising network that downsamples first and upsamples back.

    The feature conv uses ``stride=down_scale``; the upsampling section uses
    one x3 stage when ``down_scale == 3`` and ``int(log2(down_scale))``
    default-factor stages otherwise. ``fea_norm`` is the norm of the feature
    conv, ``norm_type`` that of the residual trunk, and ``upsample_norm`` that
    of the upsampling stages and both reconstruction convs.

    ``upscale`` is accepted for interface compatibility but is not used.

    Raises:
        NotImplementedError: if ``upsample_mode`` is neither ``'upconv'`` nor
            ``'pixelshuffle'``.
    """

    def __init__(self, in_nc, out_nc, nf, nb, upscale=1, norm_type='batch', act_type='relu',
                 mode='NAC', res_scale=1, upsample_mode='upconv', ada_ksize=None, down_scale=2,
                 fea_norm=None, upsample_norm=None):
        super().__init__()
        num_stages = 1 if down_scale == 3 else int(math.log(down_scale, 2))

        head = B.conv_block(in_nc, nf, kernel_size=3, norm_type=fea_norm, act_type=None,
                            stride=down_scale, ada_ksize=ada_ksize)
        trunk = [
            B.ResNetBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type,
                          mode=mode, res_scale=res_scale, ada_ksize=ada_ksize)
            for _ in range(nb)
        ]
        trunk.append(B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None,
                                  mode=mode, ada_ksize=ada_ksize))

        if upsample_mode == 'pixelshuffle':
            make_upsampler = B.pixelshuffle_block
        elif upsample_mode == 'upconv':
            make_upsampler = B.upconv_blcok
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)

        # Options shared by the upsampling stages.
        up_kwargs = dict(act_type=act_type, norm_type=upsample_norm, ada_ksize=ada_ksize)
        if down_scale == 3:
            # A single x3 block; its children are spliced into the model below.
            up_stages = list(make_upsampler(nf, nf, 3, **up_kwargs))
        else:
            up_stages = [make_upsampler(nf, nf, **up_kwargs) for _ in range(num_stages)]

        tail = [
            B.conv_block(nf, nf, kernel_size=3, norm_type=upsample_norm, act_type=act_type,
                         ada_ksize=ada_ksize),
            B.conv_block(nf, out_nc, kernel_size=3, norm_type=upsample_norm, act_type=None,
                         ada_ksize=ada_ksize),
        ]

        layers = [head, B.ShortcutBlock(B.sequential(*trunk))] + up_stages + tail
        self.model = B.sequential(*layers)

    def forward(self, x):
        """Run the full pipeline on ``x``."""
        return self.model(x)
# class ModulateDenoiseResNet(nn.Module):
# def __init__(self, in_nc, out_nc, nf, nb, upscale=1, norm_type='sft', act_type='relu',
# mode='CNA', res_scale=1, upsample_mode='upconv', gate_conv_bias=False, ada_ksize=None):
# super(ModulateDenoiseResNet, self).__init__()
#
# self.fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, stride=2)
# resnet_blocks = [B.TwoStreamSRResNet(nf, nf, nf, norm_type=norm_type, act_type=act_type,
# mode=mode, res_scale=res_scale, gate_conv_bias=gate_conv_bias,
# ada_ksize=ada_ksize, input_dim=in_nc) for _ in range(nb)]
# degration_block = [B.conv_block(in_nc, nf, kernel_size=3, norm_type='batch', act_type='relu')]
# degration_block.extend([B.conv_block(nf, nf, kernel_size=3, norm_type='batch', act_type='relu')
# for _ in range(15)])
# degration_block.append(B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None))
#
# LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None, mode=mode)
# if norm_type == 'sft':
# LR_norm = AN.GateNonLinearLayer(in_nc, conv_bias=gate_conv_bias)
# elif norm_type == 'sft_conv':
# LR_norm = AN.MetaLayer(in_nc, conv_bias=gate_conv_bias, kernel_size=ada_ksize)
#
# if upsample_mode == 'upconv':
# upsample_block = B.upconv_blcok
# elif upsample_mode == 'pixelshuffle':
# upsample_block = B.pixelshuffle_block
# else:
# raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
# upsampler = upsample_block(nf, nf, act_type=act_type)
# HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
# HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
#
# self.norm_branch = B.sequential(*resnet_blocks)
# self.LR_conv = LR_conv
# self.LR_norm = LR_norm
# self.degration_block = B.sequential(*degration_block)
# self.HR_branch = B.sequential(upsampler, HR_conv0, HR_conv1)
#
# def forward(self, x):
# fea = self.fea_conv(x)
# # noise estimation part
# # deg_estimate = self.degration_block(x) + x
# deg_estimate = self.degration_block(x)
# fea_res_block, _ = self.norm_branch((fea, deg_estimate))
# fea_LR = self.LR_conv(fea_res_block)
# res = self.LR_norm((fea_LR, deg_estimate))
# out = self.HR_branch(fea+res)
# return out
class ModulateDenoiseResNet(nn.Module):
    """Denoising generator whose residual trunk is modulated by a second input.

    ``forward`` takes a pair ``x = (image, condition)``. The feature conv uses
    ``stride=2`` and a single default-factor upsampling block is applied before
    the two reconstruction convs.

    Args:
        in_nc: input channels; also passed as ``input_dim`` / first argument of
            the modulation layers.
        out_nc: output channels.
        nf: feature channels.
        nb: number of ``B.TwoStreamSRResNet`` blocks.
        upscale: accepted for interface compatibility; not used.
        norm_type: ``'sft'`` (``AN.GateNonLinearLayer``) or ``'sft_conv'``
            (``AN.MetaLayer``).
        upsample_mode: ``'upconv'`` or ``'pixelshuffle'``.

    Raises:
        NotImplementedError: for an unsupported ``norm_type`` or
            ``upsample_mode``. Previously an unsupported ``norm_type`` surfaced
            as an ``UnboundLocalError`` on ``LR_norm``.
    """

    def __init__(self, in_nc, out_nc, nf, nb, upscale=1, norm_type='sft', act_type='relu',
                 mode='CNA', res_scale=1, upsample_mode='upconv', gate_conv_bias=True, ada_ksize=None):
        super(ModulateDenoiseResNet, self).__init__()

        self.fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, stride=2)
        resnet_blocks = [B.TwoStreamSRResNet(nf, nf, nf, norm_type=norm_type, act_type=act_type,
                                             mode=mode, res_scale=res_scale, gate_conv_bias=gate_conv_bias,
                                             ada_ksize=ada_ksize, input_dim=in_nc) for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=None, mode=mode)

        if norm_type == 'sft':
            LR_norm = AN.GateNonLinearLayer(in_nc, conv_bias=gate_conv_bias)
        elif norm_type == 'sft_conv':
            LR_norm = AN.MetaLayer(in_nc, conv_bias=gate_conv_bias, kernel_size=ada_ksize)
        else:
            # Without this branch LR_norm would be unbound below.
            raise NotImplementedError('norm type [%s] is not found' % norm_type)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)

        upsampler = upsample_block(nf, nf, act_type=act_type)
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        # Registration order is kept as-is so parameter ordering does not change.
        self.norm_branch = B.sequential(*resnet_blocks)
        self.LR_conv = LR_conv
        self.LR_norm = LR_norm
        self.HR_branch = B.sequential(upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        """Compute the output for ``x = (image, condition)``."""
        fea = self.fea_conv(x[0])
        # The condition x[1] is supplied by the caller rather than estimated here.
        fea_res_block, _ = self.norm_branch((fea, x[1]))
        fea_LR = self.LR_conv(fea_res_block)
        res = self.LR_norm((fea_LR, x[1]))
        # Global skip connection around the modulated trunk.
        out = self.HR_branch(fea + res)
        return out
class NoiseSubNet(nn.Module):
    """Plain stack of 3x3 conv blocks used as the estimation sub-network.

    Layout: one conv block (in_nc -> nf), fifteen conv blocks (nf -> nf) and a
    final conv block (nf -> out_nc) built with ``norm_type=None`` and
    ``act_type=None``.

    NOTE: ``nb`` is accepted but not used; the number of middle blocks is fixed
    at 15. Changing that would alter the layer layout of saved checkpoints.
    """

    def __init__(self, in_nc, out_nc, nf, nb, norm_type='batch', act_type='relu', mode='CNA'):
        super().__init__()
        layers = [B.conv_block(in_nc, nf, kernel_size=3, norm_type=norm_type,
                               act_type=act_type, mode=mode)]
        for _ in range(15):
            layers.append(B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type,
                                       act_type=act_type, mode=mode))
        layers.append(B.conv_block(nf, out_nc, kernel_size=3, norm_type=None,
                                   act_type=None, mode=mode))
        # The attribute name is part of the parameter names; keep the spelling.
        self.degration_block = B.sequential(*layers)

    def forward(self, x):
        """Return the output of the conv stack for input ``x``."""
        return self.degration_block(x)
class CondDenoiseResNet(nn.Module):
    """Conditional denoising ResNet (jingwen's addition).

    ``forward(x, y)`` passes the condition ``y`` to every
    ``B.CondResNetBlock`` and to the final conditional layer
    ``self.cond_adaptive``.

    Args:
        in_nc: input channels.
        out_nc: output channels.
        nf: feature channels.
        nb: number of conditional residual blocks.
        upscale: accepted for interface compatibility; not used.
        res_scale: accepted for interface compatibility; not used.
        down_scale: stride of the feature conv; 3 selects one x3 upsampling
            stage, otherwise ``int(log2(down_scale))`` stages are used.
        num_classes: forwarded to the conditional layers.
        ada_ksize: forwarded to the conditional layers that take it.
        upsample_mode: ``'upconv'`` or ``'pixelshuffle'``.
        norm_type: one of ``'cond_adaptive_conv_res'``,
            ``'interp_adaptive_conv_res'``, ``'cond_instance'``,
            ``'cond_transform_res'``.

    Raises:
        NotImplementedError: for an unsupported ``norm_type`` or
            ``upsample_mode``. Previously an unsupported ``norm_type`` left
            ``self.cond_adaptive`` undefined and only failed inside
            ``forward``.
    """

    def __init__(self, in_nc, out_nc, nf, nb, upscale=1, res_scale=1, down_scale=2, num_classes=1, ada_ksize=None,
                 upsample_mode='upconv', act_type='relu', norm_type='cond_adaptive_conv_res'):
        super(CondDenoiseResNet, self).__init__()
        n_upscale = int(math.log(down_scale, 2))
        if down_scale == 3:
            n_upscale = 1

        self.fea_conv = nn.Conv2d(in_nc, nf, kernel_size=3, stride=down_scale, padding=1)
        resnet_blocks = [B.CondResNetBlock(nf, nf, nf, num_classes=num_classes, ada_ksize=ada_ksize,
                                           norm_type=norm_type, act_type=act_type) for _ in range(nb)]
        self.resnet_blocks = B.sequential(*resnet_blocks)
        self.LR_conv = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)

        if norm_type == 'cond_adaptive_conv_res':
            self.cond_adaptive = AN.CondAdaptiveConvResNorm(nf, num_classes=num_classes)
        elif norm_type == "interp_adaptive_conv_res":
            self.cond_adaptive = AN.InterpAdaptiveResNorm(nf, ada_ksize)
        elif norm_type == "cond_instance":
            self.cond_adaptive = AN.CondInstanceNorm2d(nf, num_classes=num_classes)
        elif norm_type == "cond_transform_res":
            self.cond_adaptive = AN.CondResTransformer(nf, ada_ksize, num_classes=num_classes)
        else:
            # forward() unconditionally uses self.cond_adaptive, so fail early.
            raise NotImplementedError('norm type [%s] is not found' % norm_type)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)

        if down_scale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
        self.upsample = B.sequential(*upsampler, HR_conv0, HR_conv1)

    def forward(self, x, y):
        """Compute the output for input ``x`` under condition ``y``."""
        # the first feature extraction
        fea = self.fea_conv(x)
        # The trunk takes and returns a (features, condition) pair.
        fea1, _ = self.resnet_blocks((fea, y))
        fea2 = self.LR_conv(fea1)
        fea3 = self.cond_adaptive(fea2, y)
        # res: global skip connection around the conditional trunk
        out = self.upsample(fea3 + fea)
        return out
class AdaptiveDenoiseResNet(nn.Module):
    """AdaBN helper network (jingwen's addition).

    ``forward`` does not produce an output tensor. It pushes each element of
    the input through the feature conv, the adaptive residual blocks and
    ``LR_conv``, then overwrites the running mean / variance of
    ``self.batch_norm`` with the statistics returned by
    ``B.computing_mean_variance``.

    ``upscale`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, in_nc, nf, nb, upscale=1, res_scale=1, down_scale=2):
        super().__init__()
        self.fea_conv = nn.Conv2d(in_nc, nf, kernel_size=3, stride=down_scale, padding=1)
        blocks = [B.AdaptiveResNetBlock(nf, nf, nf, res_scale=res_scale) for _ in range(nb)]
        self.resnet_blocks = B.sequential(*blocks)
        self.LR_conv = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)
        self.batch_norm = nn.BatchNorm2d(nf, affine=True, track_running_stats=True, momentum=0)

    def forward(self, x):
        """Update the batch-norm running statistics from ``x``; returns None.

        NOTE(review): ``x`` is iterated element by element and each element is
        given a leading dimension with the in-place ``unsqueeze_``; presumably
        ``x`` is a batch tensor or a list of per-sample tensors -- confirm with
        the caller.
        """
        stem_feats = [self.fea_conv(sample.unsqueeze_(0)) for sample in x]
        # The residual blocks operate on the whole list of feature maps.
        block_feats = self.resnet_blocks(stem_feats)
        lr_feats = [self.LR_conv(feat) for feat in block_feats]
        fea_mean, fea_var = B.computing_mean_variance(lr_feats)

        # Write the statistics into the BN layer via its state dict.
        bn_state = self.batch_norm.state_dict()
        bn_state['running_mean'] = fea_mean
        bn_state['running_var'] = fea_var
        self.batch_norm.load_state_dict(bn_state)
        return None
class RRDBNet(nn.Module):
    """Generator built from ``nb`` ``B.RRDB`` blocks.

    Pipeline: 3x3 feature conv -> shortcut-wrapped trunk of RRDB blocks plus a
    3x3 conv -> upsampling stages -> two 3x3 reconstruction convs.

    Args:
        in_nc: input channels.
        out_nc: output channels.
        nf: feature channels.
        nb: number of RRDB blocks.
        gc: ``gc`` argument forwarded to every ``B.RRDB`` block. It used to be
            ignored (32 was hard-coded); it is now honoured, and ``None`` falls
            back to 32 so configs without a ``gc`` entry keep working.
            NOTE(review): a checkpoint trained with a config whose ``gc``
            differed from 32 was effectively built with 32 -- load such
            checkpoints with ``gc=32``.
        upscale: 3 selects one x3 upsampling stage, otherwise
            ``int(log2(upscale))`` default-factor stages are used.
        norm_type: norm used inside the RRDB blocks and the trunk conv.
        act_type: activation type.
        mode: mode of the trunk conv only; the RRDB blocks are always built
            with ``mode='CNA'`` (kept as in the original).
        upsample_mode: ``'upconv'`` or ``'pixelshuffle'``.

    Raises:
        NotImplementedError: for an unsupported ``upsample_mode``.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None,
                 act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
        super(RRDBNet, self).__init__()
        if gc is None:
            gc = 32
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        rb_blocks = [B.RRDB(nf, kernel_size=3, gc=gc, stride=1, bias=True, pad_type='zero',
                            norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))

        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),
                                  *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        """Run the full pipeline on ``x``."""
        x = self.model(x)
        return x
####################
# Discriminator
####################
# VGG style Discriminator with input size 128*128
class Discriminator_VGG_128(nn.Module):
    """VGG-style discriminator for 128x128 inputs.

    Ten conv blocks alternate a 3x3 stride-1 block with a 4x4 stride-2 block
    while the width grows from ``base_nf`` to ``base_nf * 8``. The original
    annotations trace the feature map as 128 -> 64 -> 32 -> 16 -> 8 -> 4
    (this depends on the padding inside ``B.conv_block``). A two-layer
    classifier maps the flattened features to one score per sample.

    Args:
        in_nc: number of input channels.
        base_nf: width of the first conv block.
        norm_type: normalisation for every block except the first.
        act_type: activation for every conv block.
        mode: forwarded to ``B.conv_block``.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_128, self).__init__()
        # First pair: the block that sees the raw input gets no normalisation.
        blocks = [
            B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                         mode=mode),
            B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
                         act_type=act_type, mode=mode),
        ]
        # Remaining pairs: a 3x3 block changes the width, then a 4x4 stride-2
        # block keeps it.
        widths = ((base_nf, base_nf * 2), (base_nf * 2, base_nf * 4),
                  (base_nf * 4, base_nf * 8), (base_nf * 8, base_nf * 8))
        for in_ch, out_ch in widths:
            blocks.append(B.conv_block(in_ch, out_ch, kernel_size=3, stride=1,
                                       norm_type=norm_type, act_type=act_type, mode=mode))
            blocks.append(B.conv_block(out_ch, out_ch, kernel_size=4, stride=2,
                                       norm_type=norm_type, act_type=act_type, mode=mode))
        self.features = B.sequential(*blocks)

        # The input width follows base_nf (previously hard-coded to 512, which
        # only matched base_nf == 64).
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample, shape (N, 1)."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
# VGG style Discriminator with input size 128*128, Spectral Normalization
class Discriminator_VGG_128_SN(nn.Module):
    """VGG-style discriminator for 128x128 RGB inputs with spectral normalisation.

    Every convolution and linear layer is wrapped in ``SN.spectral_norm``.
    The layers are registered as ``conv0`` .. ``conv9``, ``linear0`` and
    ``linear1``.
    """

    # (in_channels, out_channels, kernel_size, stride) for conv0 .. conv9.
    # Input is 128x128; each stride-2 layer halves the size: 64, 32, 16, 8, 4.
    _CONV_SPECS = (
        (3, 64, 3, 1), (64, 64, 4, 2),
        (64, 128, 3, 1), (128, 128, 4, 2),
        (128, 256, 3, 1), (256, 256, 4, 2),
        (256, 512, 3, 1), (512, 512, 4, 2),
        (512, 512, 3, 1), (512, 512, 4, 2),
    )

    def __init__(self):
        super(Discriminator_VGG_128_SN, self).__init__()
        self.lrelu = nn.LeakyReLU(0.2, True)

        # features
        for idx, (in_ch, out_ch, ksize, stride) in enumerate(self._CONV_SPECS):
            conv = nn.Conv2d(in_ch, out_ch, ksize, stride, 1)
            setattr(self, 'conv%d' % idx, SN.spectral_norm(conv))

        # classifier
        self.linear0 = SN.spectral_norm(nn.Linear(512 * 4 * 4, 100))
        self.linear1 = SN.spectral_norm(nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample, shape (N, 1)."""
        for idx in range(len(self._CONV_SPECS)):
            x = self.lrelu(getattr(self, 'conv%d' % idx)(x))
        flat = x.view(x.size(0), -1)
        hidden = self.lrelu(self.linear0(flat))
        return self.linear1(hidden)
class Discriminator_VGG_96(nn.Module):
    """VGG-style discriminator for 96x96 inputs.

    Ten conv blocks alternate a 3x3 stride-1 block with a 4x4 stride-2 block
    while the width grows from ``base_nf`` to ``base_nf * 8``. The original
    annotations trace the feature map as 96 -> 48 -> 24 -> 12 -> 6 -> 3
    (this depends on the padding inside ``B.conv_block``). A two-layer
    classifier maps the flattened features to one score per sample.

    Args:
        in_nc: number of input channels.
        base_nf: width of the first conv block.
        norm_type: normalisation for every block except the first.
        act_type: activation for every conv block.
        mode: forwarded to ``B.conv_block``.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_96, self).__init__()
        # First pair: the block that sees the raw input gets no normalisation.
        blocks = [
            B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                         mode=mode),
            B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
                         act_type=act_type, mode=mode),
        ]
        # Remaining pairs: a 3x3 block changes the width, then a 4x4 stride-2
        # block keeps it.
        widths = ((base_nf, base_nf * 2), (base_nf * 2, base_nf * 4),
                  (base_nf * 4, base_nf * 8), (base_nf * 8, base_nf * 8))
        for in_ch, out_ch in widths:
            blocks.append(B.conv_block(in_ch, out_ch, kernel_size=3, stride=1,
                                       norm_type=norm_type, act_type=act_type, mode=mode))
            blocks.append(B.conv_block(out_ch, out_ch, kernel_size=4, stride=2,
                                       norm_type=norm_type, act_type=act_type, mode=mode))
        self.features = B.sequential(*blocks)

        # The input width follows base_nf (previously hard-coded to 512, which
        # only matched base_nf == 64).
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample, shape (N, 1)."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
class Discriminator_VGG_192(nn.Module):
    """VGG-style discriminator for 192x192 inputs.

    Twelve conv blocks alternate a 3x3 stride-1 block with a 4x4 stride-2
    block while the width grows from ``base_nf`` to ``base_nf * 8``. The
    original annotations trace the feature map as
    192 -> 96 -> 48 -> 24 -> 12 -> 6 -> 3 (this depends on the padding inside
    ``B.conv_block``). A two-layer classifier maps the flattened features to
    one score per sample.

    Args:
        in_nc: number of input channels.
        base_nf: width of the first conv block.
        norm_type: normalisation for every block except the first.
        act_type: activation for every conv block.
        mode: forwarded to ``B.conv_block``.
    """

    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_192, self).__init__()
        # First pair: the block that sees the raw input gets no normalisation.
        blocks = [
            B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
                         mode=mode),
            B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
                         act_type=act_type, mode=mode),
        ]
        # Remaining pairs: a 3x3 block changes the width, then a 4x4 stride-2
        # block keeps it. The last two pairs stay at base_nf * 8.
        widths = ((base_nf, base_nf * 2), (base_nf * 2, base_nf * 4),
                  (base_nf * 4, base_nf * 8), (base_nf * 8, base_nf * 8),
                  (base_nf * 8, base_nf * 8))
        for in_ch, out_ch in widths:
            blocks.append(B.conv_block(in_ch, out_ch, kernel_size=3, stride=1,
                                       norm_type=norm_type, act_type=act_type, mode=mode))
            blocks.append(B.conv_block(out_ch, out_ch, kernel_size=4, stride=2,
                                       norm_type=norm_type, act_type=act_type, mode=mode))
        self.features = B.sequential(*blocks)

        # The input width follows base_nf (previously hard-coded to 512, which
        # only matched base_nf == 64).
        self.classifier = nn.Sequential(
            nn.Linear(base_nf * 8 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        """Return one score per sample, shape (N, 1)."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
####################
# Perceptual Network
####################
# Assume input range is [0, 1]
class VGGFeatureExtractor(nn.Module):
    """Frozen, truncated pretrained VGG19 used as a feature extractor.

    The input is assumed to be in the range [0, 1].

    Args:
        feature_layer: index (inclusive) of the last layer of the backbone's
            ``features`` stack to keep.
        use_bn: use ``vgg19_bn`` instead of ``vgg19``.
        use_input_norm: normalise the input with the per-channel ``mean`` and
            ``std`` buffers before the backbone.
        device: device the ``mean`` / ``std`` tensors are created on before
            they are registered as buffers.
    """

    def __init__(self,
                 feature_layer=34,
                 use_bn=False,
                 use_input_norm=True,
                 device=torch.device('cpu')):
        super(VGGFeatureExtractor, self).__init__()
        build_vgg = torchvision.models.vgg19_bn if use_bn else torchvision.models.vgg19
        backbone = build_vgg(pretrained=True)

        self.use_input_norm = use_input_norm
        if use_input_norm:
            # Statistics for inputs in [0, 1]. For inputs in [-1, 1] use
            # mean [0.485-1, 0.456-1, 0.406-1] and std [0.229*2, 0.224*2, 0.225*2].
            channel_stats = (
                ('mean', [0.485, 0.456, 0.406]),
                ('std', [0.229, 0.224, 0.225]),
            )
            for buf_name, values in channel_stats:
                self.register_buffer(buf_name,
                                     torch.Tensor(values).view(1, 3, 1, 1).to(device))

        kept_layers = list(backbone.features.children())[:feature_layer + 1]
        self.features = nn.Sequential(*kept_layers)
        # The extractor is fixed: no gradients for its parameters.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the truncated backbone's features for ``x``."""
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        return self.features(x)
# Assume input range is [0, 1]
class ResNet101FeatureExtractor(nn.Module):
    """Frozen pretrained ResNet-101 truncated to its first eight child modules.

    The input is assumed to be in the range [0, 1].

    Args:
        use_input_norm: normalise the input with the per-channel ``mean`` and
            ``std`` buffers before the backbone.
        device: device the ``mean`` / ``std`` tensors are created on before
            they are registered as buffers.
    """

    def __init__(self, use_input_norm=True, device=torch.device('cpu')):
        super(ResNet101FeatureExtractor, self).__init__()
        backbone = torchvision.models.resnet101(pretrained=True)

        self.use_input_norm = use_input_norm
        if use_input_norm:
            # Statistics for inputs in [0, 1]. For inputs in [-1, 1] use
            # mean [0.485-1, 0.456-1, 0.406-1] and std [0.229*2, 0.224*2, 0.225*2].
            channel_stats = (
                ('mean', [0.485, 0.456, 0.406]),
                ('std', [0.229, 0.224, 0.225]),
            )
            for buf_name, values in channel_stats:
                self.register_buffer(buf_name,
                                     torch.Tensor(values).view(1, 3, 1, 1).to(device))

        kept_layers = list(backbone.children())[:8]
        self.features = nn.Sequential(*kept_layers)
        # The extractor is fixed: no gradients for its parameters.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the truncated backbone's features for ``x``."""
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        return self.features(x)
class MINCNet(nn.Module):
    """Thirteen 3x3 convolutions in five stages with four ceil-mode max-pools.

    Layers are registered as ``conv<stage><index>`` and ``maxpool<stage>``.
    Every convolution is followed by an in-place ReLU except the last one
    (``conv53``), whose raw output is returned.
    """

    # (stage number, input channels, output channels, number of convs)
    _STAGES = (
        (1, 3, 64, 2),
        (2, 64, 128, 2),
        (3, 128, 256, 3),
        (4, 256, 512, 3),
        (5, 512, 512, 3),
    )

    def __init__(self):
        super(MINCNet, self).__init__()
        self.ReLU = nn.ReLU(True)
        final_stage = self._STAGES[-1][0]
        for stage, in_ch, out_ch, n_convs in self._STAGES:
            for idx in range(1, n_convs + 1):
                # Only the first conv of a stage changes the channel count.
                src_ch = in_ch if idx == 1 else out_ch
                setattr(self, 'conv%d%d' % (stage, idx), nn.Conv2d(src_ch, out_ch, 3, 1, 1))
            # Every stage except the last ends with a 2x2 max-pool.
            if stage != final_stage:
                setattr(self, 'maxpool%d' % stage,
                        nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True))

    def forward(self, x):
        """Return the output of ``conv53`` (no activation applied to it)."""
        out = x
        for stage, _, _, n_convs in self._STAGES[:-1]:
            for idx in range(1, n_convs + 1):
                out = self.ReLU(getattr(self, 'conv%d%d' % (stage, idx))(out))
            out = getattr(self, 'maxpool%d' % stage)(out)
        out = self.ReLU(self.conv51(out))
        out = self.ReLU(self.conv52(out))
        return self.conv53(out)
# Assume input range is [0, 1]
class MINCFeatureExtractor(nn.Module):
    """Frozen ``MINCNet`` loaded from a pretrained checkpoint.

    The input is assumed to be in the range [0, 1]; no input normalisation is
    applied in ``forward``.

    Args:
        feature_layer: unused.
        use_bn: unused.
        use_input_norm: unused.
        device: unused.
            (These four are presumably kept so the signature matches
            ``VGGFeatureExtractor`` -- TODO confirm.)
    """

    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
                 device=torch.device('cpu')):
        super(MINCFeatureExtractor, self).__init__()

        self.features = MINCNet()
        # Load onto the CPU first so that a checkpoint saved with CUDA tensors
        # also loads on a CPU-only host; load_state_dict then copies the values
        # into the freshly built module.
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load('../experiments/pretrained_models/VGG16minc_53.pth',
                                map_location=torch.device('cpu'))
        self.features.load_state_dict(state_dict, strict=True)
        self.features.eval()
        # The extractor is fixed: no gradients for its parameters.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the ``MINCNet`` features for ``x``."""
        output = self.features(x)
        return output
實驗結果:
兩者結果對比如上圖所示。直觀上,把 noise level map 和 LR concat 到一起的視覺效果好一點。
先denoise後SR的級聯網路效果如下圖
而先noise estimation 後SR的級聯網路的效果如下圖
後者PSNR高0.3dB左右