基於PyTorch的ESRGAN(論文閱讀筆記+復現)
阿新 • • 發佈:2018-12-06
程式碼的框架——《https://github.com/xinntao/BasicSR》
ESRGAN論文《ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks》的連結——https://arxiv.org/pdf/1809.00219.pdf
程式碼在目錄/home/guanwp/BasicSR-master/codes/下,執行以下命令實現train和test
# Train: run from the codes/ directory with the training options file.
python train.py -opt options/train/train_esrgan.json
# Test: run from the codes/ directory with the testing options file.
python test.py -opt options/test/test_esrgan.json
理論
程式碼
給出setting
// Training options for the "ESRGAN_x4_DIV2K" experiment.
// NOTE: this file uses `//` comments and leading commas, so it is NOT strict JSON.
// It must be read by a loader that strips `//` comments line by line; every comment
// therefore has to end at a line break (one property per line).
{
  "name": "ESRGAN_x4_DIV2K" // please remove "debug_" during training
  , "use_tb_logger": true
  , "model": "srgan"
  , "scale": 4
  , "gpu_ids": [3, 4, 5]

  , "datasets": {
    "train": {
      "name": "DIV2K"
      , "mode": "LRHR"
      , "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub"
      , "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4"
      , "subset_file": null
      , "use_shuffle": true
      , "n_workers": 8
      , "batch_size": 16
      , "HR_size": 128
      , "use_flip": true
      , "use_rot": true
    }
    , "val": {
      "name": "val_set5"
      , "mode": "LRHR"
      , "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5"
      , "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4"
    }
  }

  , "path": {
    "root": "/home/guanwp/BasicSR-master"
    , "pretrain_model_G": null
    , "experiments_root": "/home/guanwp/BasicSR-master/experiments/"
    , "models": "/home/guanwp/BasicSR-master/experiments/ESRGAN_x4_DIV2K/models"
    , "log": "/home/guanwp/BasicSR-master/experiments/ESRGAN_x4_DIV2K"
    , "val_images": "/home/guanwp/BasicSR-master/experiments/ESRGAN_x4_DIV2K/val_images"
  }

  , "network_G": {
    "which_model_G": "RRDB_net" // RRDB_net | sr_resnet
    , "norm_type": null
    , "mode": "CNA"
    , "nf": 64
    , "nb": 23 // number of residual blocks
    , "in_nc": 3
    , "out_nc": 3
    , "gc": 32
    , "group": 1
  }

  , "network_D": {
    "which_model_D": "discriminator_vgg_128"
    , "norm_type": "batch"
    , "act_type": "leakyrelu"
    , "mode": "CNA"
    , "nf": 64
    , "in_nc": 3
  }

  , "train": {
    "lr_G": 1e-4
    , "weight_decay_G": 0
    , "beta1_G": 0.9
    , "lr_D": 1e-4
    , "weight_decay_D": 0
    , "beta1_D": 0.9
    , "lr_scheme": "MultiStepLR"
    , "lr_steps": [50000, 100000, 200000, 300000]
    , "lr_gamma": 0.5

    , "pixel_criterion": "l1"
    , "pixel_weight": 0 // 1e-2 // just for the NIQE, you should set to 0
    , "feature_criterion": "l1"
    , "feature_weight": 1
    , "gan_type": "vanilla"
    , "gan_weight": 5e-3

    // for wgan-gp
    , "D_update_ratio": 1 // for the D network
    , "D_init_iters": 0
    // NOTE(review): the entry below appears to have been commented out in the original
    // single-line text ("// , ..."); key spelling kept exactly as written. Confirm before enabling.
    // , "gp_weigth": 10

    , "manual_seed": 0
    , "niter": 5e5 // 6e5 // 5e5
    , "val_freq": 2000 // 5e3
  }

  , "logger": {
    "print_freq": 200
    , "save_checkpoint_freq": 5e3
  }
}
執行程式碼
結果