Browse Source

improve codes comments

Xintao 3 years ago
parent
commit
35ee6f781e

+ 4 - 5
.github/workflows/no-response.yml

@@ -1,12 +1,11 @@
 name: No Response
 
+# TODO: it seems not to work
 # Modified from: https://raw.githubusercontent.com/github/docs/main/.github/workflows/no-response.yaml
 
-# **What it does**: Closes issues that don't have enough information to be
-#                   actionable.
-# **Why we have it**: To remove the need for maintainers to remember to check
-#                     back on issues periodically to see if contributors have
-#                     responded.
+# **What it does**: Closes issues that don't have enough information to be actionable.
+# **Why we have it**: To remove the need for maintainers to remember to check back on issues periodically
+#                     to see if contributors have responded.
 # **Who does it impact**: Everyone that works on docs or docs-internal.
 
 on:

+ 4 - 1
inference_realesrgan.py

@@ -8,6 +8,8 @@ from realesrgan import RealESRGANer
 
 
 def main():
+    """Inference demo for Real-ESRGAN.
+    """
     parser = argparse.ArgumentParser()
     parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
     parser.add_argument(
@@ -53,7 +55,7 @@ def main():
         pre_pad=args.pre_pad,
         half=args.half)
 
-    if args.face_enhance:
+    if args.face_enhance:  # Use GFPGAN for face enhancement
         from gfpgan import GFPGANer
         face_enhancer = GFPGANer(
             model_path='https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth',
@@ -78,6 +80,7 @@ def main():
         else:
             img_mode = None
 
+        # give warnings for too large/small images
         h, w = img.shape[0:2]
         if max(h, w) > 1000 and args.netscale == 4:
             import warnings

+ 1 - 2
options/finetune_realesrgan_x4plus.yml

@@ -90,7 +90,6 @@ network_g:
   num_block: 23
   num_grow_ch: 32
 
-
 network_d:
   type: UNetDiscriminatorSN
   num_in_ch: 3
@@ -169,7 +168,7 @@ train:
 #   save_img: True
 
 #   metrics:
-#     psnr: # metric name, can be arbitrary
+#     psnr: # metric name
 #       type: calculate_psnr
 #       crop_border: 4
 #       test_y_channel: false

+ 1 - 2
options/finetune_realesrgan_x4plus_pairdata.yml

@@ -52,7 +52,6 @@ network_g:
   num_block: 23
   num_grow_ch: 32
 
-
 network_d:
   type: UNetDiscriminatorSN
   num_in_ch: 3
@@ -131,7 +130,7 @@ train:
 #   save_img: True
 
 #   metrics:
-#     psnr: # metric name, can be arbitrary
+#     psnr: # metric name
 #       type: calculate_psnr
 #       crop_border: 4
 #       test_y_channel: false

+ 1 - 2
options/train_realesrgan_x2plus.yml

@@ -91,7 +91,6 @@ network_g:
   num_grow_ch: 32
   scale: 2
 
-
 network_d:
   type: UNetDiscriminatorSN
   num_in_ch: 3
@@ -167,7 +166,7 @@ train:
 #   save_img: True
 
 #   metrics:
-#     psnr: # metric name, can be arbitrary
+#     psnr: # metric name
 #       type: calculate_psnr
 #       crop_border: 4
 #       test_y_channel: false

+ 1 - 2
options/train_realesrgan_x4plus.yml

@@ -90,7 +90,6 @@ network_g:
   num_block: 23
   num_grow_ch: 32
 
-
 network_d:
   type: UNetDiscriminatorSN
   num_in_ch: 3
@@ -166,7 +165,7 @@ train:
 #   save_img: True
 
 #   metrics:
-#     psnr: # metric name, can be arbitrary
+#     psnr: # metric name
 #       type: calculate_psnr
 #       crop_border: 4
 #       test_y_channel: false

+ 1 - 1
options/train_realesrnet_x2plus.yml

@@ -125,7 +125,7 @@ train:
 #   save_img: True
 
 #   metrics:
-#     psnr: # metric name, can be arbitrary
+#     psnr: # metric name
 #       type: calculate_psnr
 #       crop_border: 4
 #       test_y_channel: false

+ 1 - 1
options/train_realesrnet_x4plus.yml

@@ -124,7 +124,7 @@ train:
 #   save_img: True
 
 #   metrics:
-#     psnr: # metric name, can be arbitrary
+#     psnr: # metric name
 #       type: calculate_psnr
 #       crop_border: 4
 #       test_y_channel: false

+ 1 - 1
realesrgan/__init__.py

@@ -3,4 +3,4 @@ from .archs import *
 from .data import *
 from .models import *
 from .utils import *
-from .version import __gitsha__, __version__
+from .version import __version__

+ 14 - 7
realesrgan/archs/discriminator_arch.py

@@ -6,15 +6,23 @@ from torch.nn.utils import spectral_norm
 
 @ARCH_REGISTRY.register()
 class UNetDiscriminatorSN(nn.Module):
-    """Defines a U-Net discriminator with spectral normalization (SN)"""
+    """Defines a U-Net discriminator with spectral normalization (SN)
+
+    It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    Arg:
+        num_in_ch (int): Channel number of inputs. Default: 3.
+        num_feat (int): Channel number of base intermediate features. Default: 64.
+        skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
+    """
 
     def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
         super(UNetDiscriminatorSN, self).__init__()
         self.skip_connection = skip_connection
         norm = spectral_norm
-
+        # the first convolution
         self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
-
+        # downsample
         self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
         self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
         self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
@@ -22,14 +30,13 @@ class UNetDiscriminatorSN(nn.Module):
         self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
         self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
         self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
-
-        # extra
+        # extra convolutions
         self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
         self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
-
         self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
 
     def forward(self, x):
+        # downsample
         x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
         x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
         x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
@@ -52,7 +59,7 @@ class UNetDiscriminatorSN(nn.Module):
         if self.skip_connection:
             x6 = x6 + x0
 
-        # extra
+        # extra convolutions
         out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
         out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
         out = self.conv9(out)

+ 28 - 11
realesrgan/data/realesrgan_dataset.py

@@ -15,18 +15,31 @@ from torch.utils import data as data
 
 @DATASET_REGISTRY.register()
 class RealESRGANDataset(data.Dataset):
-    """
-    Dataset used for Real-ESRGAN model.
+    """Dataset used for Real-ESRGAN model:
+    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    It loads gt (Ground-Truth) images, and augments them.
+    It also generates blur kernels and sinc kernels for generating low-quality images.
+    Note that the low-quality images are processed in tensors on GPUS for faster processing.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            meta_info (str): Path for meta information file.
+            io_backend (dict): IO backend type and other kwarg.
+            use_hflip (bool): Use horizontal flips.
+            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+            Please see more options in the codes.
     """
 
     def __init__(self, opt):
         super(RealESRGANDataset, self).__init__()
         self.opt = opt
-        # file client (io backend)
         self.file_client = None
         self.io_backend_opt = opt['io_backend']
         self.gt_folder = opt['dataroot_gt']
 
+        # file client (lmdb io backend)
         if self.io_backend_opt['type'] == 'lmdb':
             self.io_backend_opt['db_paths'] = [self.gt_folder]
             self.io_backend_opt['client_keys'] = ['gt']
@@ -35,18 +48,20 @@ class RealESRGANDataset(data.Dataset):
             with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                 self.paths = [line.split('.')[0] for line in fin]
         else:
+            # disk backend with meta_info
+            # Each line in the meta_info describes the relative path to an image
             with open(self.opt['meta_info']) as fin:
-                paths = [line.strip() for line in fin]
+                paths = [line.strip().split(' ')[0] for line in fin]
                 self.paths = [os.path.join(self.gt_folder, v) for v in paths]
 
         # blur settings for the first degradation
         self.blur_kernel_size = opt['blur_kernel_size']
         self.kernel_list = opt['kernel_list']
-        self.kernel_prob = opt['kernel_prob']
+        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
         self.blur_sigma = opt['blur_sigma']
-        self.betag_range = opt['betag_range']
-        self.betap_range = opt['betap_range']
-        self.sinc_prob = opt['sinc_prob']
+        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
+        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
+        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters
 
         # blur settings for the second degradation
         self.blur_kernel_size2 = opt['blur_kernel_size2']
@@ -61,6 +76,7 @@ class RealESRGANDataset(data.Dataset):
         self.final_sinc_prob = opt['final_sinc_prob']
 
         self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
+        # TODO: kernel range is now hard-coded, should be in the configure file
         self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
         self.pulse_tensor[10, 10] = 1
 
@@ -89,10 +105,11 @@ class RealESRGANDataset(data.Dataset):
                 retry -= 1
         img_gt = imfrombytes(img_bytes, float32=True)
 
-        # -------------------- augmentation for training: flip, rotation -------------------- #
+        # -------------------- Do augmentation for training: flip, rotation -------------------- #
         img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
 
-        # crop or pad to 400: 400 is hard-coded. You may change it accordingly
+        # crop or pad to 400
+        # TODO: 400 is hard-coded. You may change it accordingly
         h, w = img_gt.shape[0:2]
         crop_pad_size = 400
         # pad
@@ -154,7 +171,7 @@ class RealESRGANDataset(data.Dataset):
         pad_size = (21 - kernel_size) // 2
         kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
 
-        # ------------------------------------- sinc kernel ------------------------------------- #
+        # ------------------------------------- the final sinc kernel ------------------------------------- #
         if np.random.uniform() < self.opt['final_sinc_prob']:
             kernel_size = random.choice(self.kernel_range)
             omega_c = np.random.uniform(np.pi / 3, np.pi)

+ 12 - 10
realesrgan/data/realesrgan_paired_dataset.py

@@ -11,8 +11,7 @@ from torchvision.transforms.functional import normalize
 class RealESRGANPairedDataset(data.Dataset):
     """Paired image dataset for image restoration.
 
-    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
-    GT image pairs.
+    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
 
     There are three modes:
     1. 'lmdb': Use lmdb files.
@@ -28,8 +27,8 @@ class RealESRGANPairedDataset(data.Dataset):
             dataroot_lq (str): Data root path for lq.
             meta_info (str): Path for meta information file.
             io_backend (dict): IO backend type and other kwarg.
-            filename_tmpl (str): Template for each filename. Note that the
-                template excludes the file extension. Default: '{}'.
+            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
+                Default: '{}'.
             gt_size (int): Cropped patched size for gt patches.
             use_hflip (bool): Use horizontal flips.
             use_rot (bool): Use rotation (use vertical flip and transposing h
@@ -42,25 +41,25 @@ class RealESRGANPairedDataset(data.Dataset):
     def __init__(self, opt):
         super(RealESRGANPairedDataset, self).__init__()
         self.opt = opt
-        # file client (io backend)
         self.file_client = None
         self.io_backend_opt = opt['io_backend']
+        # mean and std for normalizing the input images
         self.mean = opt['mean'] if 'mean' in opt else None
         self.std = opt['std'] if 'std' in opt else None
 
         self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
-        if 'filename_tmpl' in opt:
-            self.filename_tmpl = opt['filename_tmpl']
-        else:
-            self.filename_tmpl = '{}'
+        self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
 
+        # file client (lmdb io backend)
         if self.io_backend_opt['type'] == 'lmdb':
             self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
             self.io_backend_opt['client_keys'] = ['lq', 'gt']
             self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
         elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
+            # disk backend with meta_info
+            # Each line in the meta_info describes the relative path to an image
             with open(self.opt['meta_info']) as fin:
-                paths = [line.strip() for line in fin]
+                paths = [line.strip().split(' ')[0] for line in fin]
             self.paths = []
             for path in paths:
                 gt_path, lq_path = path.split(', ')
@@ -68,6 +67,9 @@ class RealESRGANPairedDataset(data.Dataset):
                 lq_path = os.path.join(self.lq_folder, lq_path)
                 self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
         else:
+            # disk backend
+            # it will scan the whole folder to get meta info
+            # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
             self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
 
     def __getitem__(self, index):

+ 26 - 11
realesrgan/models/realesrgan_model.py

@@ -13,35 +13,45 @@ from torch.nn import functional as F
 
 @MODEL_REGISTRY.register()
 class RealESRGANModel(SRGANModel):
-    """RealESRGAN Model"""
+    """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    It mainly performs:
+    1. randomly synthesize LQ images in GPU tensors
+    2. optimize the networks with GAN training.
+    """
 
     def __init__(self, opt):
         super(RealESRGANModel, self).__init__(opt)
-        self.jpeger = DiffJPEG(differentiable=False).cuda()
-        self.usm_sharpener = USMSharp().cuda()
+        self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
+        self.usm_sharpener = USMSharp().cuda()  # do usm sharpening
         self.queue_size = opt.get('queue_size', 180)
 
     @torch.no_grad()
     def _dequeue_and_enqueue(self):
-        # training pair pool
+        """It is the training pair pool for increasing the diversity in a batch.
+
+        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
+        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
+        to increase the degradation diversity in a batch.
+        """
         # initialize
         b, c, h, w = self.lq.size()
         if not hasattr(self, 'queue_lr'):
-            assert self.queue_size % b == 0, 'queue size should be divisible by batch size'
+            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
             self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
             _, c, h, w = self.gt.size()
             self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
             self.queue_ptr = 0
-        if self.queue_ptr == self.queue_size:  # full
+        if self.queue_ptr == self.queue_size:  # the pool is full
             # do dequeue and enqueue
             # shuffle
             idx = torch.randperm(self.queue_size)
             self.queue_lr = self.queue_lr[idx]
             self.queue_gt = self.queue_gt[idx]
-            # get
+            # get first b samples
             lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
             gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
-            # update
+            # update the queue
             self.queue_lr[0:b, :, :, :] = self.lq.clone()
             self.queue_gt[0:b, :, :, :] = self.gt.clone()
 
@@ -55,6 +65,8 @@ class RealESRGANModel(SRGANModel):
 
     @torch.no_grad()
     def feed_data(self, data):
+        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
+        """
         if self.is_train and self.opt.get('high_order_degradation', True):
             # training data synthesis
             self.gt = data['gt'].to(self.device)
@@ -79,7 +91,7 @@ class RealESRGANModel(SRGANModel):
                 scale = 1
             mode = random.choice(['area', 'bilinear', 'bicubic'])
             out = F.interpolate(out, scale_factor=scale, mode=mode)
-            # noise
+            # add noise
             gray_noise_prob = self.opt['gray_noise_prob']
             if np.random.uniform() < self.opt['gaussian_noise_prob']:
                 out = random_add_gaussian_noise_pt(
@@ -93,7 +105,7 @@ class RealESRGANModel(SRGANModel):
                     rounds=False)
             # JPEG compression
             jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
-            out = torch.clamp(out, 0, 1)
+            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
             out = self.jpeger(out, quality=jpeg_p)
 
             # ----------------------- The second degradation process ----------------------- #
@@ -111,7 +123,7 @@ class RealESRGANModel(SRGANModel):
             mode = random.choice(['area', 'bilinear', 'bicubic'])
             out = F.interpolate(
                 out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
-            # noise
+            # add noise
             gray_noise_prob = self.opt['gray_noise_prob2']
             if np.random.uniform() < self.opt['gaussian_noise_prob2']:
                 out = random_add_gaussian_noise_pt(
@@ -162,7 +174,9 @@ class RealESRGANModel(SRGANModel):
             self._dequeue_and_enqueue()
             # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
             self.gt_usm = self.usm_sharpener(self.gt)
+            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
         else:
+            # for paired training or validation
             self.lq = data['lq'].to(self.device)
             if 'gt' in data:
                 self.gt = data['gt'].to(self.device)
@@ -175,6 +189,7 @@ class RealESRGANModel(SRGANModel):
         self.is_train = True
 
     def optimize_parameters(self, current_iter):
+        # usm sharpening
         l1_gt = self.gt_usm
         percep_gt = self.gt_usm
         gan_gt = self.gt_usm

+ 27 - 12
realesrgan/models/realesrnet_model.py

@@ -12,35 +12,46 @@ from torch.nn import functional as F
 
 @MODEL_REGISTRY.register()
 class RealESRNetModel(SRModel):
-    """RealESRNet Model"""
+    """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    It is trained without GAN losses.
+    It mainly performs:
+    1. randomly synthesize LQ images in GPU tensors
+    2. optimize the networks with GAN training.
+    """
 
     def __init__(self, opt):
         super(RealESRNetModel, self).__init__(opt)
-        self.jpeger = DiffJPEG(differentiable=False).cuda()
-        self.usm_sharpener = USMSharp().cuda()
+        self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
+        self.usm_sharpener = USMSharp().cuda()  # do usm sharpening
         self.queue_size = opt.get('queue_size', 180)
 
     @torch.no_grad()
     def _dequeue_and_enqueue(self):
-        # training pair pool
+        """It is the training pair pool for increasing the diversity in a batch.
+
+        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
+        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
+        to increase the degradation diversity in a batch.
+        """
         # initialize
         b, c, h, w = self.lq.size()
         if not hasattr(self, 'queue_lr'):
-            assert self.queue_size % b == 0, 'queue size should be divisible by batch size'
+            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
             self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
             _, c, h, w = self.gt.size()
             self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
             self.queue_ptr = 0
-        if self.queue_ptr == self.queue_size:  # full
+        if self.queue_ptr == self.queue_size:  # the pool is full
             # do dequeue and enqueue
             # shuffle
             idx = torch.randperm(self.queue_size)
             self.queue_lr = self.queue_lr[idx]
             self.queue_gt = self.queue_gt[idx]
-            # get
+            # get first b samples
             lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
             gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
-            # update
+            # update the queue
             self.queue_lr[0:b, :, :, :] = self.lq.clone()
             self.queue_gt[0:b, :, :, :] = self.gt.clone()
 
@@ -54,10 +65,12 @@ class RealESRNetModel(SRModel):
 
     @torch.no_grad()
     def feed_data(self, data):
+        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
+        """
         if self.is_train and self.opt.get('high_order_degradation', True):
             # training data synthesis
             self.gt = data['gt'].to(self.device)
-            # USM the GT images
+            # USM sharpen the GT images
             if self.opt['gt_usm'] is True:
                 self.gt = self.usm_sharpener(self.gt)
 
@@ -80,7 +93,7 @@ class RealESRNetModel(SRModel):
                 scale = 1
             mode = random.choice(['area', 'bilinear', 'bicubic'])
             out = F.interpolate(out, scale_factor=scale, mode=mode)
-            # noise
+            # add noise
             gray_noise_prob = self.opt['gray_noise_prob']
             if np.random.uniform() < self.opt['gaussian_noise_prob']:
                 out = random_add_gaussian_noise_pt(
@@ -94,7 +107,7 @@ class RealESRNetModel(SRModel):
                     rounds=False)
             # JPEG compression
             jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
-            out = torch.clamp(out, 0, 1)
+            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
             out = self.jpeger(out, quality=jpeg_p)
 
             # ----------------------- The second degradation process ----------------------- #
@@ -112,7 +125,7 @@ class RealESRNetModel(SRModel):
             mode = random.choice(['area', 'bilinear', 'bicubic'])
             out = F.interpolate(
                 out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
-            # noise
+            # add noise
             gray_noise_prob = self.opt['gray_noise_prob2']
             if np.random.uniform() < self.opt['gaussian_noise_prob2']:
                 out = random_add_gaussian_noise_pt(
@@ -160,7 +173,9 @@ class RealESRNetModel(SRModel):
 
             # training pair pool
             self._dequeue_and_enqueue()
+            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
         else:
+            # for paired training or validation
             self.lq = data['lq'].to(self.device)
             if 'gt' in data:
                 self.gt = data['gt'].to(self.device)

+ 27 - 4
realesrgan/utils.py

@@ -12,6 +12,19 @@ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 
 
 class RealESRGANer():
+    """A helper class for upsampling images with RealESRGAN.
+
+    Args:
+        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
+        model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
+        model (nn.Module): The defined network. If None, the model will be constructed here. Default: None.
+        tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
+            input images into tiles, and then process each of them. Finally, they will be merged into one image.
+            0 denotes for do not use tile. Default: 0.
+        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
+        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
+        half (float): Whether to use half precision during inference. Default: False.
+    """
 
     def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
         self.scale = scale
@@ -26,10 +39,12 @@ class RealESRGANer():
         if model is None:
             model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
 
+        # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
         if model_path.startswith('https://'):
             model_path = load_file_from_url(
                 url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None)
         loadnet = torch.load(model_path)
+        # prefer to use params_ema
         if 'params_ema' in loadnet:
             keyname = 'params_ema'
         else:
@@ -41,6 +56,8 @@ class RealESRGANer():
             self.model = self.model.half()
 
     def pre_process(self, img):
+        """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
+        """
         img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
         self.img = img.unsqueeze(0).to(self.device)
         if self.half:
@@ -49,7 +66,7 @@ class RealESRGANer():
         # pre_pad
         if self.pre_pad != 0:
             self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
-        # mod pad
+        # mod pad for divisible borders
         if self.scale == 2:
             self.mod_scale = 2
         elif self.scale == 1:
@@ -64,10 +81,14 @@ class RealESRGANer():
             self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
 
     def process(self):
+        # model inference
         self.output = self.model(self.img)
 
     def tile_process(self):
-        """Modified from: https://github.com/ata4/esrgan-launcher
+        """It will first crop input images to tiles, and then process each tile.
+        Finally, all the processed tiles are merged into one images.
+
+        Modified from: https://github.com/ata4/esrgan-launcher
         """
         batch, channel, height, width = self.img.shape
         output_height = height * self.scale
@@ -188,7 +209,7 @@ class RealESRGANer():
                 output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                 output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                 output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
-            else:
+            else:  # use the cv2 resize for alpha channel
                 h, w = alpha.shape[0:2]
                 output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
 
@@ -213,7 +234,9 @@ class RealESRGANer():
 
 
 def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
-    """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+    """Load file form http url, will download models if necessary.
+
+    Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
     """
     if model_dir is None:
         hub_dir = get_dir()

+ 7 - 17
scripts/extract_subimages.py

@@ -14,34 +14,24 @@ def main(args):
 
     opt (dict): Configuration dict. It contains:
         n_thread (int): Thread number.
-        compression_level (int):  CV_IMWRITE_PNG_COMPRESSION from 0 to 9.
-            A higher value means a smaller size and longer compression time.
-            Use 0 for faster CPU decompression. Default: 3, same in cv2.
-
+        compression_level (int):  CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size
+            and longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2.
         input_folder (str): Path to the input folder.
         save_folder (str): Path to save folder.
         crop_size (int): Crop size.
         step (int): Step for overlapped sliding window.
-        thresh_size (int): Threshold size. Patches whose size is lower
-            than thresh_size will be dropped.
+        thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
 
     Usage:
         For each folder, run this script.
-        Typically, there are four folders to be processed for DIV2K dataset.
-            DIV2K_train_HR
-            DIV2K_train_LR_bicubic/X2
-            DIV2K_train_LR_bicubic/X3
-            DIV2K_train_LR_bicubic/X4
-        After process, each sub_folder should have the same number of
-        subimages.
+        Typically, there are GT folder and LQ folder to be processed for DIV2K dataset.
+        After process, each sub_folder should have the same number of subimages.
         Remember to modify opt configurations according to your settings.
     """
 
     opt = {}
     opt['n_thread'] = args.n_thread
     opt['compression_level'] = args.compression_level
-
-    # HR images
     opt['input_folder'] = args.input
     opt['save_folder'] = args.output
     opt['crop_size'] = args.crop_size
@@ -68,6 +58,7 @@ def extract_subimages(opt):
         print(f'Folder {save_folder} already exists. Exit.')
         sys.exit(1)
 
+    # scan all images
     img_list = list(scandir(input_folder, full_path=True))
 
     pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
@@ -88,8 +79,7 @@ def worker(path, opt):
         opt (dict): Configuration dict. It contains:
             crop_size (int): Crop size.
             step (int): Step for overlapped sliding window.
-            thresh_size (int): Threshold size. Patches whose size is lower
-                than thresh_size will be dropped.
+            thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
             save_folder (str): Path to save folder.
             compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
 

+ 2 - 0
scripts/generate_meta_info.py

@@ -11,6 +11,7 @@ def main(args):
         for img_path in img_paths:
             status = True
             if args.check:
+                # read the image once for check, as some images may have errors
                 try:
                     img = cv2.imread(img_path)
                 except Exception as error:
@@ -20,6 +21,7 @@ def main(args):
                     status = False
                     print(f'Img is None: {img_path}')
             if status:
+                # get the relative path
                 img_name = os.path.relpath(img_path, root)
                 print(img_name)
                 txt_file.write(f'{img_name}\n')

+ 3 - 1
scripts/generate_meta_info_pairdata.py

@@ -5,6 +5,7 @@ import os
 
 def main(args):
     txt_file = open(args.meta_info, 'w')
+    # sca images
     img_paths_gt = sorted(glob.glob(os.path.join(args.input[0], '*')))
     img_paths_lq = sorted(glob.glob(os.path.join(args.input[1], '*')))
 
@@ -12,6 +13,7 @@ def main(args):
                                                     f'{len(img_paths_gt)} and {len(img_paths_lq)}.')
 
     for img_path_gt, img_path_lq in zip(img_paths_gt, img_paths_lq):
+        # get the relative paths
         img_name_gt = os.path.relpath(img_path_gt, args.root[0])
         img_name_lq = os.path.relpath(img_path_lq, args.root[1])
         print(f'{img_name_gt}, {img_name_lq}')
@@ -19,7 +21,7 @@ def main(args):
 
 
 if __name__ == '__main__':
-    """Generate meta info (txt file) for paired images.
+    """This script is used to generate meta info (txt file) for paired images.
     """
     parser = argparse.ArgumentParser()
     parser.add_argument(

+ 3 - 1
scripts/generate_multiscale_DF2K.py

@@ -5,7 +5,6 @@ from PIL import Image
 
 
 def main(args):
-
     # For DF2K, we consider the following three scales,
     # and the smallest image whose shortest edge is 400
     scale_list = [0.75, 0.5, 1 / 3]
@@ -37,6 +36,9 @@ def main(args):
 
 
 if __name__ == '__main__':
+    """Generate multi-scale versions for GT images with LANCZOS resampling.
+    It is now used for DF2K dataset (DIV2K + Flickr 2K)
+    """
     parser = argparse.ArgumentParser()
     parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
     parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_multiscale', help='Output folder')

+ 30 - 11
scripts/pytorch2onnx.py

@@ -1,17 +1,36 @@
+import argparse
 import torch
 import torch.onnx
 from basicsr.archs.rrdbnet_arch import RRDBNet
 
-# An instance of your model
-model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
-model.load_state_dict(torch.load('experiments/pretrained_models/RealESRGAN_x4plus.pth')['params_ema'])
-# set the train mode to false since we will only run the forward pass.
-model.train(False)
-model.cpu().eval()
 
-# An example input you would normally provide to your model's forward() method
-x = torch.rand(1, 3, 64, 64)
+def main(args):
+    # An instance of the model
+    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+    if args.params:
+        keyname = 'params'
+    else:
+        keyname = 'params_ema'
+    model.load_state_dict(torch.load(args.input)[keyname])
+    # set the train mode to false since we will only run the forward pass.
+    model.train(False)
+    model.cpu().eval()
 
-# Export the model
-with torch.no_grad():
-    torch_out = torch.onnx._export(model, x, 'realesrgan-x4.onnx', opset_version=11, export_params=True)
+    # An example input
+    x = torch.rand(1, 3, 64, 64)
+    # Export the model
+    with torch.no_grad():
+        torch_out = torch.onnx._export(model, x, args.output, opset_version=11, export_params=True)
+    print(torch_out.shape)
+
+
+if __name__ == '__main__':
+    """Convert pytorch model to onnx models"""
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--input', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth', help='Input model path')
+    parser.add_argument('--output', type=str, default='realesrgan-x4.onnx', help='Output onnx path')
+    parser.add_argument('--params', action='store_false', help='Use params instead of params_ema')
+    args = parser.parse_args()
+
+    main(args)