Learning Notes: SFTGAN in PyTorch (a study of xintao's code, focusing on the data-processing part)
阿新 · Published 2018-12-06
The code framework is still BasicSR: https://github.com/xinntao/BasicSR
Here is the SFTGAN paper, "Recovering Realistic Texture in Image Super-resolution by Deep Spatial Feature Transform": https://arxiv.org/pdf/1804.02815.pdf
I have already written reading notes on the SFTGAN paper (Reading Notes: "Recovering Realistic Texture in Image Super-resolution by Deep Spatial Feature Transform").
The SFTGAN network structure is shown in the figure below.
SFT stands for Spatial Feature Transform, i.e. a transform applied in feature space.
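Concretely, an SFT layer applies a condition-dependent affine transform to the feature maps: a scale map and a shift map are predicted from the prior condition (here, segmentation probability maps) and applied element-wise. A minimal sketch of just this modulation step (names are illustrative, not the repo code):

import torch

# Minimal sketch of the SFT modulation itself (illustrative names).
fea = torch.randn(1, 64, 32, 32)    # image features
gamma = torch.randn(1, 64, 32, 32)  # scale, predicted from the condition
beta = torch.randn(1, 64, 32, 32)   # shift, predicted from the condition
out = fea * (gamma + 1) + beta      # same form as the repo's x[0] * (scale + 1) + shift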
First, look at models/__init__.py:
def create_model(opt):
    model = opt['model']  # this value comes from the .json options file
    # the 'model' field in the json decides which model class to import,
    # so if you add a new model, this .py must be modified
    if model == 'sr':  # the plain SR model
        from .SR_model import SRModel as M  # take SR as an example
    elif model == 'srgan':  # SRGAN
        from .SRGAN_model import SRGANModel as M
    elif model == 'srragan':
        from .SRRaGAN_model import SRRaGANModel as M
    elif model == 'sftgan':  # SFTGAN
        from .SFTGAN_ACD_model import SFTGAN_ACD_Model as M
    else:
        raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
    m = M(opt)
    print('Model [{:s}] is created.'.format(m.__class__.__name__))
    return m  # return the model
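For context, opt here is the options dict parsed from the .json config, so model creation in a training script looks roughly like this (a sketch only; a real options file carries many more keys, which the model constructor expects):

from models import create_model

opt = {'model': 'sftgan'}  # in practice, parsed from the .json options file with many more keys
model = create_model(opt)  # dispatches to SFTGAN_ACD_Model and prints its class name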
The SFTGAN network lives in SFTGAN_ACD_model.py.
Since this post is mainly about the SFT-Net part, I will leave the GAN-related code aside for now.
class SFTLayer(nn.Module):
    def __init__(self):
        super(SFTLayer, self).__init__()
        self.SFT_scale_conv0 = nn.Conv2d(32, 32, 1)
        self.SFT_scale_conv1 = nn.Conv2d(32, 64, 1)
        self.SFT_shift_conv0 = nn.Conv2d(32, 32, 1)
        self.SFT_shift_conv1 = nn.Conv2d(32, 64, 1)

    def forward(self, x):
        # x[0]: fea; x[1]: cond
        scale = self.SFT_scale_conv1(F.leaky_relu(self.SFT_scale_conv0(x[1]), 0.1, inplace=True))
        shift = self.SFT_shift_conv1(F.leaky_relu(self.SFT_shift_conv0(x[1]), 0.1, inplace=True))
        return x[0] * (scale + 1) + shift


class ResBlock_SFT(nn.Module):
    def __init__(self):
        super(ResBlock_SFT, self).__init__()
        self.sft0 = SFTLayer()
        self.conv0 = nn.Conv2d(64, 64, 3, 1, 1)
        self.sft1 = SFTLayer()
        self.conv1 = nn.Conv2d(64, 64, 3, 1, 1)

    def forward(self, x):
        # x[0]: fea; x[1]: cond
        fea = self.sft0(x)
        fea = F.relu(self.conv0(fea), inplace=True)
        fea = self.sft1((fea, x[1]))
        fea = self.conv1(fea)
        return (x[0] + fea, x[1])  # return a tuple containing features and conditions


class SFT_Net(nn.Module):  # the main network
    def __init__(self):
        super(SFT_Net, self).__init__()
        self.conv0 = nn.Conv2d(3, 64, 3, 1, 1)

        sft_branch = []
        for i in range(16):
            sft_branch.append(ResBlock_SFT())  # residual block + SFT layer
        sft_branch.append(SFTLayer())
        sft_branch.append(nn.Conv2d(64, 64, 3, 1, 1))
        self.sft_branch = nn.Sequential(*sft_branch)

        self.HR_branch = nn.Sequential(
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),
            nn.ReLU(True),
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),
            nn.ReLU(True),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(64, 3, 3, 1, 1)
        )

        self.CondNet = nn.Sequential(
            nn.Conv2d(8, 128, 4, 4),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, 1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, 1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, 1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 32, 1)
        )

    def forward(self, x):
        # x[0]: img; x[1]: seg
        cond = self.CondNet(x[1])
        fea = self.conv0(x[0])
        res = self.sft_branch((fea, cond))  # the SFT branch takes two inputs: the convolved image features and the segmentation condition
        fea = fea + res
        out = self.HR_branch(fea)
        return out
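A quick shape check helps here (my own sanity-check sketch, assuming a 4x model and the classes above in scope): the 8-channel segmentation maps come in at HR resolution, CondNet's stride-4 convolution shrinks them to the LR feature size, and the two PixelShuffle(2) stages in HR_branch upscale by 4x overall.

import torch

# Shape sanity check (assumes SFT_Net from above is in scope).
net = SFT_Net()
img = torch.randn(1, 3, 32, 32)    # LR image
seg = torch.randn(1, 8, 128, 128)  # 8-channel seg probability maps at HR (4x) resolution
out = net((img, seg))
print(out.shape)  # torch.Size([1, 3, 128, 128])
# CondNet's Conv2d(8, 128, 4, 4) (kernel 4, stride 4) maps 128x128 -> 32x32,
# matching the LR feature map produced by conv0.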
The network structure itself is fairly easy to understand; the key question is how the data gets fed into the network.
The data-feeding part of the model is:
def feed_data(self, data, need_HR=True):  # data['LR'] is the LR image, data['seg'] is the segmentation map
    # LR
    self.var_L = data['LR'].to(self.device)
    # seg
    self.var_seg = data['seg'].to(self.device)
    # category
    self.var_cat = data['category'].long().to(self.device)
    if need_HR:  # train or val
        self.var_H = data['HR'].to(self.device)
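So the dataset's job is to return a dict with at least 'LR', 'seg', 'category', and (for train/val) 'HR'. Roughly, one batch looks like this (the shapes are illustrative only, assuming batch size B, 4x scale, and 96x96 HR patches):

# Illustrative batch contents (shapes are examples, not fixed by the code):
# data['LR']        float tensor [B, 3, 24, 24]   LR patch
# data['HR']        float tensor [B, 3, 96, 96]   HR patch
# data['seg']       float tensor [B, 8, 96, 96]   segmentation probability maps at HR size
# data['category']  tensor [B]                    scene category index, cast to long in feed_data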
By contrast, in the earlier code, e.g. the SR model, it is:
def feed_data(self, data, need_HR=True):  # feed the data
    self.var_L = data['LR'].to(self.device)  # LR
    if need_HR:
        self.real_H = data['HR'].to(self.device)  # HR
The key difference is presumably in the data module, starting with data/__init__.py. I had never carefully read through the data-processing files before, so this is a good opportunity to properly work through the data-processing part of xintao's code framework.
import torch.utils.data


def create_dataloader(dataset, dataset_opt):  # the dataloader wraps the dataset (from create_dataset) for batched loading
    phase = dataset_opt['phase']
    if phase == 'train':
        batch_size = dataset_opt['batch_size']
        shuffle = dataset_opt['use_shuffle']
        num_workers = dataset_opt['n_workers']
    else:
        batch_size = 1
        shuffle = False
        num_workers = 1
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
        pin_memory=True)  # see the DataLoader notes in the supplement at the end of this post


def create_dataset(dataset_opt):  # the dataset object holds the actual data
    mode = dataset_opt['mode']
    if mode == 'LR':
        from data.LR_dataset import LRDataset as D
    elif mode == 'LRHR':  # this one is worth reading closely
        from data.LRHR_dataset import LRHRDataset as D
    elif mode == 'LRHRseg_bg':
        from data.LRHR_seg_bg_dataset import LRHRSeg_BG_Dataset as D
    else:
        raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
    dataset = D(dataset_opt)
    print('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
                                                     dataset_opt['name']))
    return dataset
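Putting the two factories together, a hypothetical configuration for the SFTGAN dataset might look like this (the key names mirror those read in the code above and in the dataset class below; the paths and values are placeholders):

# Hypothetical usage sketch; paths and values are placeholders.
dataset_opt = {
    'name': 'OST_train',
    'phase': 'train',
    'mode': 'LRHRseg_bg',
    'data_type': 'img',           # or 'lmdb'
    'dataroot_HR': '/path/to/OST/train/img',
    'dataroot_HR_bg': '/path/to/DIV2K/img',
    'dataroot_LR': None,          # None -> generate LR on-the-fly
    'scale': 4,
    'HR_size': 96,
    'batch_size': 16,
    'use_shuffle': True,
    'n_workers': 8,
    'use_flip': True,
    'use_rot': True,
}
train_set = create_dataset(dataset_opt)
train_loader = create_dataloader(train_set, dataset_opt)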
Let's pick LRHR_seg_bg_dataset.py and read through it:
import os.path
import random
import numpy as np
import cv2
import torch
import torch.utils.data as data
import data.util as util


class LRHRSeg_BG_Dataset(data.Dataset):
    '''
    Read HR images and segmentation probability maps; generate the LR image and
    the category label for SFTGAN. Also samples general scenes as background.
    LR images need to be generated on-the-fly.
    '''

    def __init__(self, opt):
        super(LRHRSeg_BG_Dataset, self).__init__()
        self.opt = opt
        self.paths_LR = None
        self.paths_HR = None
        self.paths_HR_bg = None  # HR images for background scenes
        self.LR_env = None  # environment for lmdb
        self.HR_env = None
        self.HR_env_bg = None

        # read the image lists (LR, HR, and background HR) from lmdb or image files
        self.HR_env, self.paths_HR = util.get_image_paths(opt['data_type'], opt['dataroot_HR'])
        self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])
        self.HR_env_bg, self.paths_HR_bg = util.get_image_paths(opt['data_type'],
                                                                opt['dataroot_HR_bg'])

        assert self.paths_HR, 'Error: HR path is empty.'
        if self.paths_LR and self.paths_HR:
            assert len(self.paths_LR) == len(self.paths_HR), \
                'HR and LR datasets have different number of images - {}, {}.'.format(
                    len(self.paths_LR), len(self.paths_HR))

        self.random_scale_list = [1, 0.9, 0.8, 0.7, 0.6, 0.5]
        self.ratio = 10  # roughly 10 OST samples for every 1 DIV2K general (background) sample
    def __getitem__(self, index):
        HR_path, LR_path = None, None
        scale = self.opt['scale']  # the upscale factor
        HR_size = self.opt['HR_size']  # the HR patch size

        # get HR image
        if self.opt['phase'] == 'train' and \
                random.choice(list(range(self.ratio))) == 0:  # with probability 1/ratio, read a background image
            bg_index = random.randint(0, len(self.paths_HR_bg) - 1)
            HR_path = self.paths_HR_bg[bg_index]
            img_HR = util.read_img(self.HR_env_bg, HR_path)
            seg = torch.FloatTensor(8, img_HR.shape[0], img_HR.shape[1]).fill_(0)
            seg[0, :, :] = 1  # background
        else:
            HR_path = self.paths_HR[index]
            img_HR = util.read_img(self.HR_env, HR_path)
            seg = torch.load(HR_path.replace('/img/', '/bicseg/').replace('.png', '.pth'))
            # read the segmentation files; change this to match your own directory layout

        # modcrop in the validation / test phase
        if self.opt['phase'] != 'train':
            img_HR = util.modcrop(img_HR, 8)

        seg = np.transpose(seg.numpy(), (1, 2, 0))

        # get LR image
        if self.paths_LR:
            LR_path = self.paths_LR[index]
            img_LR = util.read_img(self.LR_env, LR_path)
        else:  # down-sampling on-the-fly
            # randomly scale during training
            if self.opt['phase'] == 'train':
                random_scale = random.choice(self.random_scale_list)
                H_s, W_s, _ = seg.shape

                def _mod(n, random_scale, scale, thres):
                    rlt = int(n * random_scale)
                    rlt = (rlt // scale) * scale
                    return thres if rlt < thres else rlt

                H_s = _mod(H_s, random_scale, scale, HR_size)
                W_s = _mod(W_s, random_scale, scale, HR_size)
                img_HR = cv2.resize(np.copy(img_HR), (W_s, H_s), interpolation=cv2.INTER_LINEAR)
                seg = cv2.resize(np.copy(seg), (W_s, H_s), interpolation=cv2.INTER_NEAREST)

            H, W, _ = img_HR.shape
            # using matlab-style imresize
            img_LR = util.imresize_np(img_HR, 1 / scale, True)
            if img_LR.ndim == 2:
                img_LR = np.expand_dims(img_LR, axis=2)

        H, W, C = img_LR.shape
        if self.opt['phase'] == 'train':
            LR_size = HR_size // scale

            # randomly crop
            rnd_h = random.randint(0, max(0, H - LR_size))
            rnd_w = random.randint(0, max(0, W - LR_size))
            img_LR = img_LR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
            rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
            img_HR = img_HR[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :]
            seg = seg[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :]

            # augmentation - flip, rotate
            img_LR, img_HR, seg = util.augment([img_LR, img_HR, seg], self.opt['use_flip'],
                                               self.opt['use_rot'])

            # category: determined from the HR file path (the OST category folder names)
            if 'building' in HR_path:
                category = 1
            elif 'plant' in HR_path:
                category = 2
            elif 'mountain' in HR_path:
                category = 3
            elif 'water' in HR_path:
                category = 4
            elif 'sky' in HR_path:
                category = 5
            elif 'grass' in HR_path:
                category = 6
            elif 'animal' in HR_path:
                category = 7
            else:
                category = 0  # background
        else:
            category = -1  # not used during validation

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_HR.shape[2] == 3:
            img_HR = img_HR[:, :, [2, 1, 0]]
            img_LR = img_LR[:, :, [2, 1, 0]]
        img_HR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HR, (2, 0, 1)))).float()
        img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float()
        seg = torch.from_numpy(np.ascontiguousarray(np.transpose(seg, (2, 0, 1)))).float()

        if LR_path is None:
            LR_path = HR_path
        return {
            'LR': img_LR,
            'HR': img_HR,
            'seg': seg,
            'category': category,
            'LR_path': LR_path,
            'HR_path': HR_path
        }

    def __len__(self):
        return len(self.paths_HR)
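The inner _mod helper is worth a worked example: it applies the random scale, rounds the result down to a multiple of the upscale factor (so the LR and HR sizes stay exactly aligned), and clamps at HR_size so the random crop later always fits.

def _mod(n, random_scale, scale, thres):
    rlt = int(n * random_scale)
    rlt = (rlt // scale) * scale  # round down to a multiple of the upscale factor
    return thres if rlt < thres else rlt

print(_mod(500, 0.7, 4, 96))  # 348: int(500 * 0.7) = 350, floored to a multiple of 4
print(_mod(120, 0.5, 4, 96))  # 96:  int(120 * 0.5) = 60 < 96, clamped to HR_size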
Compare this with the earlier plain-SR version:
LRHR_dataset.py
import os.path
import random
import numpy as np
import cv2
import torch
import torch.utils.data as data
import data.util as util


class LRHRDataset(data.Dataset):
    '''
    Read LR and HR image pairs.
    If only HR images are provided, generate the LR images on-the-fly.
    The pairing relies on sorted file lists, so check the naming convention.
    '''

    def __init__(self, opt):
        super(LRHRDataset, self).__init__()
        self.opt = opt
        self.paths_LR = None
        self.paths_HR = None
        self.LR_env = None  # environment for lmdb
        self.HR_env = None

        # read the image list from a subset list txt
        if opt['subset_file'] is not None and opt['phase'] == 'train':
            with open(opt['subset_file']) as f:
                self.paths_HR = sorted([os.path.join(opt['dataroot_HR'], line.rstrip('\n'))
                                        for line in f])
            if opt['dataroot_LR'] is not None:
                raise NotImplementedError('Now subset only supports generating LR on-the-fly.')
        else:  # read the image list from lmdb or image files
            self.HR_env, self.paths_HR = util.get_image_paths(opt['data_type'], opt['dataroot_HR'])
            self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])

        assert self.paths_HR, 'Error: HR path is empty.'
        if self.paths_LR and self.paths_HR:
            assert len(self.paths_LR) == len(self.paths_HR), \
                'HR and LR datasets have different number of images - {}, {}.'.format(
                    len(self.paths_LR), len(self.paths_HR))

        self.random_scale_list = [1]
    def __getitem__(self, index):  # makes the dataset indexable, so it can be iterated
        HR_path, LR_path = None, None
        scale = self.opt['scale']  # the upscale factor
        HR_size = self.opt['HR_size']  # the HR patch size

        # get HR image
        HR_path = self.paths_HR[index]
        img_HR = util.read_img(self.HR_env, HR_path)  # read the image

        # modcrop in the validation / test phase
        if self.opt['phase'] != 'train':
            img_HR = util.modcrop(img_HR, scale)

        # change color space if necessary
        if self.opt['color']:
            img_HR = util.channel_convert(img_HR.shape[2], self.opt['color'], [img_HR])[0]

        # get LR image (the branch below is the training pipeline)
        if self.paths_LR:
            LR_path = self.paths_LR[index]
            img_LR = util.read_img(self.LR_env, LR_path)
        else:  # down-sampling on-the-fly
            # randomly scale during training
            if self.opt['phase'] == 'train':
                random_scale = random.choice(self.random_scale_list)
                H_s, W_s, _ = img_HR.shape

                def _mod(n, random_scale, scale, thres):
                    rlt = int(n * random_scale)
                    rlt = (rlt // scale) * scale
                    return thres if rlt < thres else rlt

                H_s = _mod(H_s, random_scale, scale, HR_size)
                W_s = _mod(W_s, random_scale, scale, HR_size)
                img_HR = cv2.resize(np.copy(img_HR), (W_s, H_s), interpolation=cv2.INTER_LINEAR)
                # force to 3 channels
                if img_HR.ndim == 2:
                    img_HR = cv2.cvtColor(img_HR, cv2.COLOR_GRAY2BGR)

            H, W, _ = img_HR.shape
            # using matlab-style imresize
            img_LR = util.imresize_np(img_HR, 1 / scale, True)
            if img_LR.ndim == 2:
                img_LR = np.expand_dims(img_LR, axis=2)

        if self.opt['phase'] == 'train':
            # if the image size is too small
            H, W, _ = img_HR.shape
            if H < HR_size or W < HR_size:
                img_HR = cv2.resize(
                    np.copy(img_HR), (HR_size, HR_size), interpolation=cv2.INTER_LINEAR)
                # using matlab-style imresize
                img_LR = util.imresize_np(img_HR, 1 / scale, True)
                if img_LR.ndim == 2:
                    img_LR = np.expand_dims(img_LR, axis=2)

            H, W, C = img_LR.shape
            LR_size = HR_size // scale  # keeps the LR patch size consistent with the HR patch size

            ################ random crop and augmentation ################
            # randomly crop
            rnd_h = random.randint(0, max(0, H - LR_size))
            rnd_w = random.randint(0, max(0, W - LR_size))
            img_LR = img_LR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
            rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
            img_HR = img_HR[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :]

            # augmentation - flip, rotate
            img_LR, img_HR = util.augment([img_LR, img_HR], self.opt['use_flip'],
                                          self.opt['use_rot'])
            ##############################################################

        # change color space if necessary
        if self.opt['color']:
            img_LR = util.channel_convert(C, self.opt['color'], [img_LR])[0]

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_HR.shape[2] == 3:
            img_HR = img_HR[:, :, [2, 1, 0]]
            img_LR = img_LR[:, :, [2, 1, 0]]
        img_HR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HR, (2, 0, 1)))).float()
        img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float()

        if LR_path is None:
            LR_path = HR_path
        return {'LR': img_LR, 'HR': img_HR, 'LR_path': LR_path, 'HR_path': HR_path}

    def __len__(self):
        return len(self.paths_HR)
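For reference, util.modcrop (used in the validation branches of both datasets) crops the height and width down to multiples of the given factor, so that the image divides evenly by the scale. A minimal sketch of that behaviour:

import numpy as np

def modcrop_sketch(img, scale):
    # Crop H and W down to the nearest multiples of `scale` (mirrors util.modcrop's behaviour).
    if img.ndim == 2:
        H, W = img.shape
        return img[:H - H % scale, :W - W % scale]
    H, W, _ = img.shape
    return img[:H - H % scale, :W - W % scale, :]

print(modcrop_sketch(np.zeros((101, 203, 3)), 4).shape)  # (100, 200, 3)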
Supplement
Usage of torch.utils.data.DataLoader
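A minimal, self-contained example of the DataLoader API as used above (a toy dataset; num_workers=0 for simplicity):

import torch
import torch.utils.data as data

class ToyDataset(data.Dataset):
    def __init__(self):
        self.samples = torch.arange(10).float()
    def __getitem__(self, index):
        return self.samples[index]
    def __len__(self):
        return len(self.samples)

loader = data.DataLoader(ToyDataset(), batch_size=4, shuffle=True, num_workers=0)
for batch in loader:
    print(batch.shape)  # torch.Size([4]), torch.Size([4]), torch.Size([2])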
On def __getitem__(self, index):
https://blog.csdn.net/qq_24805141/article/details/81411775
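One detail that ties __getitem__ back to feed_data: both datasets return a dict per sample, and PyTorch's default collate function batches dict values key by key, which is why feed_data can index data['LR'], data['seg'], and data['category'] directly. A toy demonstration:

import torch
import torch.utils.data as data

class DictDataset(data.Dataset):
    def __getitem__(self, index):
        return {'LR': torch.zeros(3, 8, 8), 'category': index}
    def __len__(self):
        return 4

loader = data.DataLoader(DictDataset(), batch_size=2)
batch = next(iter(loader))
print(batch['LR'].shape)  # torch.Size([2, 3, 8, 8])
print(batch['category'])  # tensor([0, 1]) - plain ints are collated into a tensor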